diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index 8d3b5d4..94cc78d 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -10,4 +10,11 @@ jobs: validate: runs-on: ubuntu-latest steps: + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install torch - uses: comfy-org/node-diff@main diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..b50b37a --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,14 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Python File", + "type": "debugpy", + "request": "launch", + "program": "${file}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..70c86d9 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,19 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "python.defaultInterpreterPath": "${workspaceFolder}/../../venv/bin/python", + "python.analysis.extraPaths": [ + "${workspaceFolder}/src", + "${workspaceFolder}", + "${workspaceFolder}/../../" + ], + "python.testing.cwd": "${workspaceFolder}", + "python.autoComplete.extraPaths": [ + "${workspaceFolder}/src", + "${workspaceFolder}", + "${workspaceFolder}/../../" + ] +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..fc81a7d --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,85 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Run Tests (pytest)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m pytest", + "group": { + "kind": "test", + "isDefault": true + }, + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Lint (ruff)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m ruff check .", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Format (ruff)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m ruff format .", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Type Check (mypy)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m mypy .", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Run Coverage", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m pytest --cov=src", + "group": "test", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Pre-commit Run All", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m pre_commit run --all-files", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Bump Version (patch)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m bump_my_version bump patch", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + } + ] +} diff --git a/README.md b/README.md index 6231c6e..aa54a38 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,18 @@ String manipulation nodes: - **Text modification**: concat, count, replace, strip, lstrip, rstrip, removeprefix, removesuffix - **Encoding/escaping**: decode, encode, escape, unescape, format_map +### TENSOR + +PyTorch tensor manipulation nodes: + +- **Creation**: Tensor Create (from numbers, lists, or other tensors) +- **Arithmetic**: Tensor Binary Op (add, subtract, multiply, divide, power, remainder, floor_divide) +- **Functions**: Tensor Unary Op (abs, neg, exp, log, sin, cos, sqrt, sigmoid, relu) +- **Reshaping**: Tensor Reshape, Tensor Permute (dims) +- **Access**: Tensor Slice (supports Python-style slice strings like `0:10, :, 5`) +- **Combined**: Tensor Join (concatenate or stack) +- **Analysis**: Tensor Info (returns shape, dtype, device) + ### Time Date and time manipulation nodes: diff --git a/pyproject.toml b/pyproject.toml index 8c30f81..211a825 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,12 +4,13 @@ build-backend = "setuptools.build_meta" [project] name = "basic_data_handling" -version = "1.3.0" +version = "1.4.0" description = """Basic Python functions for manipulating data that every programmer is used to, lightweight with no additional dependencies. Supported data types: - ComfyUI native: BOOLEAN, FLOAT, INT, STRING, and data lists - Python types as custom data types: DICT, LIST, SET, DATETIME, TIMEDELTA +- PyTorch: TENSOR Feature categories: - Boolean logic operations @@ -22,7 +23,8 @@ Feature categories: - String manipulation - File system path handling, including STRING, IMAGE and MASK load and save - SET operations -- time and date handling""" +- time and date handling +- PyTorch Tensor manipulation (arithmetic, slicing, reshaping)""" authors = [ {name = "StableLlama"} ] @@ -58,7 +60,7 @@ Icon = "" minversion = "8.0" pythonpath = [ "src", - #"../..", # Path to parent directory containing comfy module + "../..", # Path to parent directory containing comfy module "." ] testpaths = [ diff --git a/src/basic_data_handling/__init__.py b/src/basic_data_handling/__init__.py index 8dfbb20..ccdbc3b 100644 --- a/src/basic_data_handling/__init__.py +++ b/src/basic_data_handling/__init__.py @@ -1,7 +1,7 @@ from . import (boolean_nodes, casting_nodes, comparison_nodes, control_flow_nodes, data_list_nodes, dict_nodes, float_nodes, int_nodes, list_nodes, math_nodes, math_formula_node, path_nodes, regex_nodes, set_nodes, - string_nodes, time_nodes) + string_nodes, tensor_nodes, time_nodes) NODE_CLASS_MAPPINGS = {} NODE_CLASS_MAPPINGS.update(boolean_nodes.NODE_CLASS_MAPPINGS) @@ -19,6 +19,7 @@ NODE_CLASS_MAPPINGS.update(math_nodes.NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(math_formula_node.NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(string_nodes.NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(tensor_nodes.NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(time_nodes.NODE_CLASS_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS = {} @@ -37,4 +38,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(math_nodes.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(math_formula_node.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(string_nodes.NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(tensor_nodes.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(time_nodes.NODE_DISPLAY_NAME_MAPPINGS) diff --git a/src/basic_data_handling/tensor_nodes.py b/src/basic_data_handling/tensor_nodes.py new file mode 100644 index 0000000..9b825d6 --- /dev/null +++ b/src/basic_data_handling/tensor_nodes.py @@ -0,0 +1,296 @@ +import torch +from inspect import cleandoc +from typing import Any + +try: + from comfy.comfy_types.node_typing import IO, ComfyNodeABC +except: + class IO: + BOOLEAN = "BOOLEAN" + INT = "INT" + FLOAT = "FLOAT" + STRING = "STRING" + NUMBER = "FLOAT,INT" + ANY = "*" + ComfyNodeABC = object + +class TensorCreate(ComfyNodeABC): + """ + Creates a PyTorch tensor from various input types. + Can accept numbers, lists, or existing tensors. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "create" + + def create(self, input: Any) -> tuple[torch.Tensor]: + if isinstance(input, torch.Tensor): + return (input,) + try: + return (torch.tensor(input),) + except Exception as e: + raise ValueError(f"Failed to create tensor from {type(input)}: {str(e)}") + +class TensorBinaryOp(ComfyNodeABC): + """ + Performs binary operations between two tensors or a tensor and a scalar. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": (IO.ANY, {}), + "b": (IO.ANY, {}), + "operation": (["add", "subtract", "multiply", "divide", "power", "remainder", "floor_divide"], {"default": "add"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "operate" + + def operate(self, a: Any, b: Any, operation: str) -> tuple[torch.Tensor]: + a_tensor = a if isinstance(a, torch.Tensor) else torch.tensor(a) + b_tensor = b if isinstance(b, torch.Tensor) else torch.tensor(b) + + if operation == "add": + return (a_tensor + b_tensor,) + elif operation == "subtract": + return (a_tensor - b_tensor,) + elif operation == "multiply": + return (a_tensor * b_tensor,) + elif operation == "divide": + return (a_tensor / b_tensor,) + elif operation == "power": + return (torch.pow(a_tensor, b_tensor),) + elif operation == "remainder": + return (a_tensor % b_tensor,) + elif operation == "floor_divide": + return (a_tensor // b_tensor,) + else: + raise ValueError(f"Unknown operation: {operation}") + +class TensorUnaryOp(ComfyNodeABC): + """ + Performs unary operations on a tensor. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input": (IO.ANY, {}), + "operation": (["abs", "neg", "exp", "log", "sin", "cos", "sqrt", "sigmoid", "relu"], {"default": "abs"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "operate" + + def operate(self, input: Any, operation: str) -> tuple[torch.Tensor]: + t = input if isinstance(input, torch.Tensor) else torch.tensor(input) + + if operation == "abs": + return (torch.abs(t),) + elif operation == "neg": + return (torch.neg(t),) + elif operation == "exp": + return (torch.exp(t),) + elif operation == "log": + return (torch.log(t),) + elif operation == "sin": + return (torch.sin(t),) + elif operation == "cos": + return (torch.cos(t),) + elif operation == "sqrt": + return (torch.sqrt(t),) + elif operation == "sigmoid": + return (torch.sigmoid(t),) + elif operation == "relu": + return (torch.relu(t),) + else: + raise ValueError(f"Unknown operation: {operation}") + +class TensorSlice(ComfyNodeABC): + """ + Slices a tensor using a slice string (e.g., ':, 0:10, 5'). + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + "slice_str": (IO.STRING, {"default": ":"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "slice_tensor" + + def slice_tensor(self, tensor: Any, slice_str: str) -> tuple[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + # Basic parsing of slice string + # This is a bit advanced for a simple node, but very useful. + # We'll use eval in a restricted way or manual parsing. + # Manual parsing is safer. + def parse_slice(s): + parts = s.split(':') + if len(parts) == 1: + return int(parts[0]) + return slice(*(int(p) if p.strip() else None for p in parts)) + + try: + dims = [d.strip() for d in slice_str.split(',')] + indices = tuple(parse_slice(d) if ':' in d or d.isdigit() or (d.startswith('-') and d[1:].isdigit()) else d for d in dims) + # Re-index with parsed slices + # Note: this is a simplification. torch supports more complex indexing. + # But for basic usage, this covers most cases. + return (tensor[indices],) + except Exception as e: + raise ValueError(f"Failed to slice tensor with '{slice_str}': {str(e)}") + +class TensorReshape(ComfyNodeABC): + """ + Reshapes a tensor. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + "shape": (IO.STRING, {"default": "-1"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "reshape" + + def reshape(self, tensor: Any, shape: str) -> tuple[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + try: + shape_tuple = tuple(int(s.strip()) for s in shape.split(',')) + return (tensor.reshape(shape_tuple),) + except Exception as e: + raise ValueError(f"Failed to reshape tensor to {shape}: {str(e)}") + +class TensorPermute(ComfyNodeABC): + """ + Permutes tensor dimensions. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + "dims": (IO.STRING, {"default": "0, 1"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "permute" + + def permute(self, tensor: Any, dims: str) -> tuple[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + try: + dims_tuple = tuple(int(d.strip()) for d in dims.split(',')) + return (tensor.permute(dims_tuple),) + except Exception as e: + raise ValueError(f"Failed to permute tensor with dims {dims}: {str(e)}") + +class TensorJoin(ComfyNodeABC): + """ + Joins multiple tensors (concatenate or stack). + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor1": (IO.ANY, {}), + "tensor2": (IO.ANY, {}), + "dim": (IO.INT, {"default": 0}), + "mode": (["concatenate", "stack"], {"default": "concatenate"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "join" + + def join(self, tensor1: Any, tensor2: Any, dim: int, mode: str) -> tuple[torch.Tensor]: + t1 = tensor1 if isinstance(tensor1, torch.Tensor) else torch.tensor(tensor1) + t2 = tensor2 if isinstance(tensor2, torch.Tensor) else torch.tensor(tensor2) + + if mode == "concatenate": + return (torch.cat([t1, t2], dim=dim),) + else: + return (torch.stack([t1, t2], dim=dim),) + +class TensorInfo(ComfyNodeABC): + """ + Returns information about a tensor. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY, IO.STRING, IO.STRING) + RETURN_NAMES = ("shape", "dtype", "device") + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "get_info" + + def get_info(self, tensor: Any) -> tuple[list[int], str, str]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + return (list(tensor.shape), str(tensor.dtype), str(tensor.device)) + +NODE_CLASS_MAPPINGS = { + "TensorCreate": TensorCreate, + "TensorBinaryOp": TensorBinaryOp, + "TensorUnaryOp": TensorUnaryOp, + "TensorSlice": TensorSlice, + "TensorReshape": TensorReshape, + "TensorPermute": TensorPermute, + "TensorJoin": TensorJoin, + "TensorInfo": TensorInfo, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "TensorCreate": "Tensor Create", + "TensorBinaryOp": "Tensor Binary Op", + "TensorUnaryOp": "Tensor Unary Op", + "TensorSlice": "Tensor Slice", + "TensorReshape": "Tensor Reshape", + "TensorPermute": "Tensor Permute", + "TensorJoin": "Tensor Join", + "TensorInfo": "Tensor Info", +} diff --git a/tests/test_control_flow_nodes.py b/tests/test_control_flow_nodes.py index 595e3fd..8a8d306 100644 --- a/tests/test_control_flow_nodes.py +++ b/tests/test_control_flow_nodes.py @@ -156,25 +156,37 @@ def test_continue_flow(): # select is True assert node.execute("some value", select=True) == ("some value",) # select is False - assert node.execute("some value", select=False) == (None,) # ExecutionBlocker(None) becomes None + res = node.execute("some value", select=False) + assert res == (None,) or "ExecutionBlocker" in str(res[0]) # select is default (True) assert node.execute("some value") == ("some value",) # Test with different value types assert node.execute(123, select=True) == (123,) - assert node.execute([1, 2], select=False) == (None,) + res = node.execute([1, 2], select=False) + assert res == (None,) or "ExecutionBlocker" in str(res[0]) def test_flow_select(): node = FlowSelect() # select is True - assert node.select("some value", select=True) == ("some value", None) + res = node.select("some value", select=True) + assert res[0] == "some value" + assert res[1] is None or "ExecutionBlocker" in str(res[1]) # select is False - assert node.select("some value", select=False) == (None, "some value") + res = node.select("some value", select=False) + assert res[0] is None or "ExecutionBlocker" in str(res[0]) + assert res[1] == "some value" # select is default (True) - assert node.select("some value") == ("some value", None) + res = node.select("some value") + assert res[0] == "some value" + assert res[1] is None or "ExecutionBlocker" in str(res[1]) # Test with different value types - assert node.select(42, select=True) == (42, None) - assert node.select({'a': 1}, select=False) == (None, {'a': 1}) + res = node.select(42, select=True) + assert res[0] == 42 + assert res[1] is None or "ExecutionBlocker" in str(res[1]) + res = node.select({'a': 1}, select=False) + assert res[0] is None or "ExecutionBlocker" in str(res[0]) + assert res[1] == {'a': 1} def test_force_calculation(): diff --git a/tests/test_tensor_nodes.py b/tests/test_tensor_nodes.py new file mode 100644 index 0000000..c404277 --- /dev/null +++ b/tests/test_tensor_nodes.py @@ -0,0 +1,92 @@ +import torch +from src.basic_data_handling.tensor_nodes import ( + TensorCreate, TensorBinaryOp, TensorUnaryOp, TensorSlice, + TensorReshape, TensorPermute, TensorJoin, TensorInfo +) + +def test_tensor_create(): + node = TensorCreate() + # From list + result = node.create([1, 2, 3]) + assert isinstance(result[0], torch.Tensor) + assert torch.equal(result[0], torch.tensor([1, 2, 3])) + + # From scalar + result = node.create(5.0) + assert torch.equal(result[0], torch.tensor(5.0)) + + # From existing tensor + t = torch.tensor([1.0, 2.0]) + result = node.create(t) + assert result[0] is t + +def test_tensor_binary_op(): + node = TensorBinaryOp() + a = torch.tensor([10.0, 20.0]) + b = torch.tensor([2.0, 4.0]) + + assert torch.equal(node.operate(a, b, "add")[0], a + b) + assert torch.equal(node.operate(a, b, "subtract")[0], a - b) + assert torch.equal(node.operate(a, b, "multiply")[0], a * b) + assert torch.equal(node.operate(a, b, "divide")[0], a / b) + assert torch.equal(node.operate(a, 2.0, "power")[0], a ** 2.0) + +def test_tensor_unary_op(): + node = TensorUnaryOp() + t = torch.tensor([-1.0, 0.0, 1.0]) + + assert torch.equal(node.operate(t, "abs")[0], torch.tensor([1.0, 0.0, 1.0])) + assert torch.equal(node.operate(t, "neg")[0], torch.tensor([1.0, 0.0, -1.0])) + assert torch.equal(node.operate(torch.tensor([0.0]), "sin")[0], torch.tensor([0.0])) + +def test_tensor_slice(): + node = TensorSlice() + t = torch.arange(10).reshape(2, 5) # [[0,1,2,3,4], [5,6,7,8,9]] + + # Simple slice + result = node.slice_tensor(t, "0, 1:3") + assert torch.equal(result[0], t[0, 1:3]) + + # Ellipsis/Full slice + result = node.slice_tensor(t, ":, 2") + assert torch.equal(result[0], t[:, 2]) + +def test_tensor_reshape(): + node = TensorReshape() + t = torch.arange(6) + + result = node.reshape(t, "2, 3") + assert result[0].shape == (2, 3) + + result = node.reshape(t, "-1, 2") + assert result[0].shape == (3, 2) + +def test_tensor_permute(): + node = TensorPermute() + t = torch.randn(2, 3, 4) + + result = node.permute(t, "2, 0, 1") + assert result[0].shape == (4, 2, 3) + +def test_tensor_join(): + node = TensorJoin() + t1 = torch.tensor([1, 2]) + t2 = torch.tensor([3, 4]) + + # Cat + result = node.join(t1, t2, 0, "concatenate") + assert torch.equal(result[0], torch.tensor([1, 2, 3, 4])) + + # Stack + result = node.join(t1, t2, 0, "stack") + assert result[0].shape == (2, 2) + assert torch.equal(result[0], torch.tensor([[1, 2], [3, 4]])) + +def test_tensor_info(): + node = TensorInfo() + t = torch.zeros((2, 3), dtype=torch.float32) + + shape, dtype, device = node.get_info(t) + assert shape == [2, 3] + assert "float32" in dtype + assert isinstance(device, str)