From df3be99529a6a0d38d131bae3800333435873586 Mon Sep 17 00:00:00 2001 From: StableLlama Date: Sun, 8 Feb 2026 21:44:17 +0100 Subject: [PATCH 1/4] Add PyTorch tensor manipulation nodes and update documentation --- .vscode/launch.json | 14 ++ .vscode/settings.json | 19 ++ .vscode/tasks.json | 85 +++++++ README.md | 12 + pyproject.toml | 8 +- src/basic_data_handling/__init__.py | 4 +- src/basic_data_handling/tensor_nodes.py | 296 ++++++++++++++++++++++++ tests/test_tensor_nodes.py | 93 ++++++++ 8 files changed, 527 insertions(+), 4 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json create mode 100644 .vscode/tasks.json create mode 100644 src/basic_data_handling/tensor_nodes.py create mode 100644 tests/test_tensor_nodes.py diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..b50b37a --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,14 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Python File", + "type": "debugpy", + "request": "launch", + "program": "${file}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..70c86d9 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,19 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "python.defaultInterpreterPath": "${workspaceFolder}/../../venv/bin/python", + "python.analysis.extraPaths": [ + "${workspaceFolder}/src", + "${workspaceFolder}", + "${workspaceFolder}/../../" + ], + "python.testing.cwd": "${workspaceFolder}", + "python.autoComplete.extraPaths": [ + "${workspaceFolder}/src", + "${workspaceFolder}", + "${workspaceFolder}/../../" + ] +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..fc81a7d --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,85 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Run Tests (pytest)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m pytest", + "group": { + "kind": "test", + "isDefault": true + }, + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Lint (ruff)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m ruff check .", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Format (ruff)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m ruff format .", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Type Check (mypy)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m mypy .", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Run Coverage", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m pytest --cov=src", + "group": "test", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Pre-commit Run All", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m pre_commit run --all-files", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + }, + { + "label": "Bump Version (patch)", + "type": "shell", + "command": "${workspaceFolder}/../../venv/bin/python -m bump_my_version bump patch", + "group": "build", + "presentation": { + "reveal": "always", + "panel": "dedicated" + }, + "problemMatcher": [] + } + ] +} diff --git a/README.md b/README.md index 6231c6e..aa54a38 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,18 @@ String manipulation nodes: - **Text modification**: concat, count, replace, strip, lstrip, rstrip, removeprefix, removesuffix - **Encoding/escaping**: decode, encode, escape, unescape, format_map +### TENSOR + +PyTorch tensor manipulation nodes: + +- **Creation**: Tensor Create (from numbers, lists, or other tensors) +- **Arithmetic**: Tensor Binary Op (add, subtract, multiply, divide, power, remainder, floor_divide) +- **Functions**: Tensor Unary Op (abs, neg, exp, log, sin, cos, sqrt, sigmoid, relu) +- **Reshaping**: Tensor Reshape, Tensor Permute (dims) +- **Access**: Tensor Slice (supports Python-style slice strings like `0:10, :, 5`) +- **Combined**: Tensor Join (concatenate or stack) +- **Analysis**: Tensor Info (returns shape, dtype, device) + ### Time Date and time manipulation nodes: diff --git a/pyproject.toml b/pyproject.toml index 8c30f81..e2384ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ description = """Basic Python functions for manipulating data that every program Supported data types: - ComfyUI native: BOOLEAN, FLOAT, INT, STRING, and data lists - Python types as custom data types: DICT, LIST, SET, DATETIME, TIMEDELTA +- PyTorch: TENSOR Feature categories: - Boolean logic operations @@ -22,14 +23,15 @@ Feature categories: - String manipulation - File system path handling, including STRING, IMAGE and MASK load and save - SET operations -- time and date handling""" +- time and date handling +- PyTorch Tensor manipulation (arithmetic, slicing, reshaping)""" authors = [ {name = "StableLlama"} ] readme = "README.md" license = { file = "LICENSE" } classifiers = [] -dependencies = [] +dependencies = ["torch"] [project.optional-dependencies] dev = [ @@ -58,7 +60,7 @@ Icon = "" minversion = "8.0" pythonpath = [ "src", - #"../..", # Path to parent directory containing comfy module + "../..", # Path to parent directory containing comfy module "." ] testpaths = [ diff --git a/src/basic_data_handling/__init__.py b/src/basic_data_handling/__init__.py index 8dfbb20..ccdbc3b 100644 --- a/src/basic_data_handling/__init__.py +++ b/src/basic_data_handling/__init__.py @@ -1,7 +1,7 @@ from . import (boolean_nodes, casting_nodes, comparison_nodes, control_flow_nodes, data_list_nodes, dict_nodes, float_nodes, int_nodes, list_nodes, math_nodes, math_formula_node, path_nodes, regex_nodes, set_nodes, - string_nodes, time_nodes) + string_nodes, tensor_nodes, time_nodes) NODE_CLASS_MAPPINGS = {} NODE_CLASS_MAPPINGS.update(boolean_nodes.NODE_CLASS_MAPPINGS) @@ -19,6 +19,7 @@ NODE_CLASS_MAPPINGS.update(math_nodes.NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(math_formula_node.NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(string_nodes.NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(tensor_nodes.NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(time_nodes.NODE_CLASS_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS = {} @@ -37,4 +38,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(math_nodes.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(math_formula_node.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(string_nodes.NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(tensor_nodes.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(time_nodes.NODE_DISPLAY_NAME_MAPPINGS) diff --git a/src/basic_data_handling/tensor_nodes.py b/src/basic_data_handling/tensor_nodes.py new file mode 100644 index 0000000..c5f4f51 --- /dev/null +++ b/src/basic_data_handling/tensor_nodes.py @@ -0,0 +1,296 @@ +import torch +from inspect import cleandoc +from typing import Any, Union + +try: + from comfy.comfy_types.node_typing import IO, ComfyNodeABC +except: + class IO: + BOOLEAN = "BOOLEAN" + INT = "INT" + FLOAT = "FLOAT" + STRING = "STRING" + NUMBER = "FLOAT,INT" + ANY = "*" + ComfyNodeABC = object + +class TensorCreate(ComfyNodeABC): + """ + Creates a PyTorch tensor from various input types. + Can accept numbers, lists, or existing tensors. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "create" + + def create(self, input: Any) -> tuple[torch.Tensor]: + if isinstance(input, torch.Tensor): + return (input,) + try: + return (torch.tensor(input),) + except Exception as e: + raise ValueError(f"Failed to create tensor from {type(input)}: {str(e)}") + +class TensorBinaryOp(ComfyNodeABC): + """ + Performs binary operations between two tensors or a tensor and a scalar. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": (IO.ANY, {}), + "b": (IO.ANY, {}), + "operation": (["add", "subtract", "multiply", "divide", "power", "remainder", "floor_divide"], {"default": "add"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "operate" + + def operate(self, a: Any, b: Any, operation: str) -> tuple[torch.Tensor]: + a_tensor = a if isinstance(a, torch.Tensor) else torch.tensor(a) + b_tensor = b if isinstance(b, torch.Tensor) else torch.tensor(b) + + if operation == "add": + return (a_tensor + b_tensor,) + elif operation == "subtract": + return (a_tensor - b_tensor,) + elif operation == "multiply": + return (a_tensor * b_tensor,) + elif operation == "divide": + return (a_tensor / b_tensor,) + elif operation == "power": + return (torch.pow(a_tensor, b_tensor),) + elif operation == "remainder": + return (a_tensor % b_tensor,) + elif operation == "floor_divide": + return (a_tensor // b_tensor,) + else: + raise ValueError(f"Unknown operation: {operation}") + +class TensorUnaryOp(ComfyNodeABC): + """ + Performs unary operations on a tensor. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input": (IO.ANY, {}), + "operation": (["abs", "neg", "exp", "log", "sin", "cos", "sqrt", "sigmoid", "relu"], {"default": "abs"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "operate" + + def operate(self, input: Any, operation: str) -> tuple[torch.Tensor]: + t = input if isinstance(input, torch.Tensor) else torch.tensor(input) + + if operation == "abs": + return (torch.abs(t),) + elif operation == "neg": + return (torch.neg(t),) + elif operation == "exp": + return (torch.exp(t),) + elif operation == "log": + return (torch.log(t),) + elif operation == "sin": + return (torch.sin(t),) + elif operation == "cos": + return (torch.cos(t),) + elif operation == "sqrt": + return (torch.sqrt(t),) + elif operation == "sigmoid": + return (torch.sigmoid(t),) + elif operation == "relu": + return (torch.relu(t),) + else: + raise ValueError(f"Unknown operation: {operation}") + +class TensorSlice(ComfyNodeABC): + """ + Slices a tensor using a slice string (e.g., ':, 0:10, 5'). + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + "slice_str": (IO.STRING, {"default": ":"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "slice_tensor" + + def slice_tensor(self, tensor: Any, slice_str: str) -> tuple[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + # Basic parsing of slice string + # This is a bit advanced for a simple node, but very useful. + # We'll use eval in a restricted way or manual parsing. + # Manual parsing is safer. + def parse_slice(s): + parts = s.split(':') + if len(parts) == 1: + return int(parts[0]) + return slice(*(int(p) if p.strip() else None for p in parts)) + + try: + dims = [d.strip() for d in slice_str.split(',')] + indices = tuple(parse_slice(d) if ':' in d or d.isdigit() or (d.startswith('-') and d[1:].isdigit()) else d for d in dims) + # Re-index with parsed slices + # Note: this is a simplification. torch supports more complex indexing. + # But for basic usage, this covers most cases. + return (tensor[indices],) + except Exception as e: + raise ValueError(f"Failed to slice tensor with '{slice_str}': {str(e)}") + +class TensorReshape(ComfyNodeABC): + """ + Reshapes a tensor. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + "shape": (IO.STRING, {"default": "-1"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "reshape" + + def reshape(self, tensor: Any, shape: str) -> tuple[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + try: + shape_tuple = tuple(int(s.strip()) for s in shape.split(',')) + return (tensor.reshape(shape_tuple),) + except Exception as e: + raise ValueError(f"Failed to reshape tensor to {shape}: {str(e)}") + +class TensorPermute(ComfyNodeABC): + """ + Permutes tensor dimensions. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + "dims": (IO.STRING, {"default": "0, 1"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "permute" + + def permute(self, tensor: Any, dims: str) -> tuple[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + try: + dims_tuple = tuple(int(d.strip()) for d in dims.split(',')) + return (tensor.permute(dims_tuple),) + except Exception as e: + raise ValueError(f"Failed to permute tensor with dims {dims}: {str(e)}") + +class TensorJoin(ComfyNodeABC): + """ + Joins multiple tensors (concatenate or stack). + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor1": (IO.ANY, {}), + "tensor2": (IO.ANY, {}), + "dim": (IO.INT, {"default": 0}), + "mode": (["concatenate", "stack"], {"default": "concatenate"}), + } + } + + RETURN_TYPES = (IO.ANY,) + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "join" + + def join(self, tensor1: Any, tensor2: Any, dim: int, mode: str) -> tuple[torch.Tensor]: + t1 = tensor1 if isinstance(tensor1, torch.Tensor) else torch.tensor(tensor1) + t2 = tensor2 if isinstance(tensor2, torch.Tensor) else torch.tensor(tensor2) + + if mode == "concatenate": + return (torch.cat([t1, t2], dim=dim),) + else: + return (torch.stack([t1, t2], dim=dim),) + +class TensorInfo(ComfyNodeABC): + """ + Returns information about a tensor. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "tensor": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY, IO.STRING, IO.STRING) + RETURN_NAMES = ("shape", "dtype", "device") + CATEGORY = "Basic/tensor" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "get_info" + + def get_info(self, tensor: Any) -> tuple[list[int], str, str]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + + return (list(tensor.shape), str(tensor.dtype), str(tensor.device)) + +NODE_CLASS_MAPPINGS = { + "TensorCreate": TensorCreate, + "TensorBinaryOp": TensorBinaryOp, + "TensorUnaryOp": TensorUnaryOp, + "TensorSlice": TensorSlice, + "TensorReshape": TensorReshape, + "TensorPermute": TensorPermute, + "TensorJoin": TensorJoin, + "TensorInfo": TensorInfo, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "TensorCreate": "Tensor Create", + "TensorBinaryOp": "Tensor Binary Op", + "TensorUnaryOp": "Tensor Unary Op", + "TensorSlice": "Tensor Slice", + "TensorReshape": "Tensor Reshape", + "TensorPermute": "Tensor Permute", + "TensorJoin": "Tensor Join", + "TensorInfo": "Tensor Info", +} diff --git a/tests/test_tensor_nodes.py b/tests/test_tensor_nodes.py new file mode 100644 index 0000000..d206881 --- /dev/null +++ b/tests/test_tensor_nodes.py @@ -0,0 +1,93 @@ +import pytest +import torch +from src.basic_data_handling.tensor_nodes import ( + TensorCreate, TensorBinaryOp, TensorUnaryOp, TensorSlice, + TensorReshape, TensorPermute, TensorJoin, TensorInfo +) + +def test_tensor_create(): + node = TensorCreate() + # From list + result = node.create([1, 2, 3]) + assert isinstance(result[0], torch.Tensor) + assert torch.equal(result[0], torch.tensor([1, 2, 3])) + + # From scalar + result = node.create(5.0) + assert torch.equal(result[0], torch.tensor(5.0)) + + # From existing tensor + t = torch.tensor([1.0, 2.0]) + result = node.create(t) + assert result[0] is t + +def test_tensor_binary_op(): + node = TensorBinaryOp() + a = torch.tensor([10.0, 20.0]) + b = torch.tensor([2.0, 4.0]) + + assert torch.equal(node.operate(a, b, "add")[0], a + b) + assert torch.equal(node.operate(a, b, "subtract")[0], a - b) + assert torch.equal(node.operate(a, b, "multiply")[0], a * b) + assert torch.equal(node.operate(a, b, "divide")[0], a / b) + assert torch.equal(node.operate(a, 2.0, "power")[0], a ** 2.0) + +def test_tensor_unary_op(): + node = TensorUnaryOp() + t = torch.tensor([-1.0, 0.0, 1.0]) + + assert torch.equal(node.operate(t, "abs")[0], torch.tensor([1.0, 0.0, 1.0])) + assert torch.equal(node.operate(t, "neg")[0], torch.tensor([1.0, 0.0, -1.0])) + assert torch.equal(node.operate(torch.tensor([0.0]), "sin")[0], torch.tensor([0.0])) + +def test_tensor_slice(): + node = TensorSlice() + t = torch.arange(10).reshape(2, 5) # [[0,1,2,3,4], [5,6,7,8,9]] + + # Simple slice + result = node.slice_tensor(t, "0, 1:3") + assert torch.equal(result[0], t[0, 1:3]) + + # Ellipsis/Full slice + result = node.slice_tensor(t, ":, 2") + assert torch.equal(result[0], t[:, 2]) + +def test_tensor_reshape(): + node = TensorReshape() + t = torch.arange(6) + + result = node.reshape(t, "2, 3") + assert result[0].shape == (2, 3) + + result = node.reshape(t, "-1, 2") + assert result[0].shape == (3, 2) + +def test_tensor_permute(): + node = TensorPermute() + t = torch.randn(2, 3, 4) + + result = node.permute(t, "2, 0, 1") + assert result[0].shape == (4, 2, 3) + +def test_tensor_join(): + node = TensorJoin() + t1 = torch.tensor([1, 2]) + t2 = torch.tensor([3, 4]) + + # Cat + result = node.join(t1, t2, 0, "concatenate") + assert torch.equal(result[0], torch.tensor([1, 2, 3, 4])) + + # Stack + result = node.join(t1, t2, 0, "stack") + assert result[0].shape == (2, 2) + assert torch.equal(result[0], torch.tensor([[1, 2], [3, 4]])) + +def test_tensor_info(): + node = TensorInfo() + t = torch.zeros((2, 3), dtype=torch.float32) + + shape, dtype, device = node.get_info(t) + assert shape == [2, 3] + assert "float32" in dtype + assert isinstance(device, str) From 26616079b33c7ae89f76c14510ad9d939fe314ad Mon Sep 17 00:00:00 2001 From: StableLlama Date: Sun, 8 Feb 2026 21:52:02 +0100 Subject: [PATCH 2/4] Bump to 1.4.0, don't specify to install torch as that must be available already --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e2384ce..211a825 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "basic_data_handling" -version = "1.3.0" +version = "1.4.0" description = """Basic Python functions for manipulating data that every programmer is used to, lightweight with no additional dependencies. Supported data types: @@ -31,7 +31,7 @@ authors = [ readme = "README.md" license = { file = "LICENSE" } classifiers = [] -dependencies = ["torch"] +dependencies = [] [project.optional-dependencies] dev = [ From 2861936fa3287cf9e9d29dd8b0b5ccf893f50523 Mon Sep 17 00:00:00 2001 From: StableLlama Date: Sun, 8 Feb 2026 22:00:25 +0100 Subject: [PATCH 3/4] Fix linting --- src/basic_data_handling/tensor_nodes.py | 2 +- tests/test_tensor_nodes.py | 25 ++++++++++++------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/basic_data_handling/tensor_nodes.py b/src/basic_data_handling/tensor_nodes.py index c5f4f51..9b825d6 100644 --- a/src/basic_data_handling/tensor_nodes.py +++ b/src/basic_data_handling/tensor_nodes.py @@ -1,6 +1,6 @@ import torch from inspect import cleandoc -from typing import Any, Union +from typing import Any try: from comfy.comfy_types.node_typing import IO, ComfyNodeABC diff --git a/tests/test_tensor_nodes.py b/tests/test_tensor_nodes.py index d206881..c404277 100644 --- a/tests/test_tensor_nodes.py +++ b/tests/test_tensor_nodes.py @@ -1,4 +1,3 @@ -import pytest import torch from src.basic_data_handling.tensor_nodes import ( TensorCreate, TensorBinaryOp, TensorUnaryOp, TensorSlice, @@ -11,11 +10,11 @@ def test_tensor_create(): result = node.create([1, 2, 3]) assert isinstance(result[0], torch.Tensor) assert torch.equal(result[0], torch.tensor([1, 2, 3])) - + # From scalar result = node.create(5.0) assert torch.equal(result[0], torch.tensor(5.0)) - + # From existing tensor t = torch.tensor([1.0, 2.0]) result = node.create(t) @@ -25,7 +24,7 @@ def test_tensor_binary_op(): node = TensorBinaryOp() a = torch.tensor([10.0, 20.0]) b = torch.tensor([2.0, 4.0]) - + assert torch.equal(node.operate(a, b, "add")[0], a + b) assert torch.equal(node.operate(a, b, "subtract")[0], a - b) assert torch.equal(node.operate(a, b, "multiply")[0], a * b) @@ -35,7 +34,7 @@ def test_tensor_binary_op(): def test_tensor_unary_op(): node = TensorUnaryOp() t = torch.tensor([-1.0, 0.0, 1.0]) - + assert torch.equal(node.operate(t, "abs")[0], torch.tensor([1.0, 0.0, 1.0])) assert torch.equal(node.operate(t, "neg")[0], torch.tensor([1.0, 0.0, -1.0])) assert torch.equal(node.operate(torch.tensor([0.0]), "sin")[0], torch.tensor([0.0])) @@ -43,11 +42,11 @@ def test_tensor_unary_op(): def test_tensor_slice(): node = TensorSlice() t = torch.arange(10).reshape(2, 5) # [[0,1,2,3,4], [5,6,7,8,9]] - + # Simple slice result = node.slice_tensor(t, "0, 1:3") assert torch.equal(result[0], t[0, 1:3]) - + # Ellipsis/Full slice result = node.slice_tensor(t, ":, 2") assert torch.equal(result[0], t[:, 2]) @@ -55,17 +54,17 @@ def test_tensor_slice(): def test_tensor_reshape(): node = TensorReshape() t = torch.arange(6) - + result = node.reshape(t, "2, 3") assert result[0].shape == (2, 3) - + result = node.reshape(t, "-1, 2") assert result[0].shape == (3, 2) def test_tensor_permute(): node = TensorPermute() t = torch.randn(2, 3, 4) - + result = node.permute(t, "2, 0, 1") assert result[0].shape == (4, 2, 3) @@ -73,11 +72,11 @@ def test_tensor_join(): node = TensorJoin() t1 = torch.tensor([1, 2]) t2 = torch.tensor([3, 4]) - + # Cat result = node.join(t1, t2, 0, "concatenate") assert torch.equal(result[0], torch.tensor([1, 2, 3, 4])) - + # Stack result = node.join(t1, t2, 0, "stack") assert result[0].shape == (2, 2) @@ -86,7 +85,7 @@ def test_tensor_join(): def test_tensor_info(): node = TensorInfo() t = torch.zeros((2, 3), dtype=torch.float32) - + shape, dtype, device = node.get_info(t) assert shape == [2, 3] assert "float32" in dtype From f306fd4c2dd9b01e4088887ce839d2c37d3cc81c Mon Sep 17 00:00:00 2001 From: StableLlama Date: Sun, 8 Feb 2026 22:06:32 +0100 Subject: [PATCH 4/4] Fix tests --- .github/workflows/validate.yml | 7 +++++++ tests/test_control_flow_nodes.py | 26 +++++++++++++++++++------- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index 8d3b5d4..94cc78d 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -10,4 +10,11 @@ jobs: validate: runs-on: ubuntu-latest steps: + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install torch - uses: comfy-org/node-diff@main diff --git a/tests/test_control_flow_nodes.py b/tests/test_control_flow_nodes.py index 595e3fd..8a8d306 100644 --- a/tests/test_control_flow_nodes.py +++ b/tests/test_control_flow_nodes.py @@ -156,25 +156,37 @@ def test_continue_flow(): # select is True assert node.execute("some value", select=True) == ("some value",) # select is False - assert node.execute("some value", select=False) == (None,) # ExecutionBlocker(None) becomes None + res = node.execute("some value", select=False) + assert res == (None,) or "ExecutionBlocker" in str(res[0]) # select is default (True) assert node.execute("some value") == ("some value",) # Test with different value types assert node.execute(123, select=True) == (123,) - assert node.execute([1, 2], select=False) == (None,) + res = node.execute([1, 2], select=False) + assert res == (None,) or "ExecutionBlocker" in str(res[0]) def test_flow_select(): node = FlowSelect() # select is True - assert node.select("some value", select=True) == ("some value", None) + res = node.select("some value", select=True) + assert res[0] == "some value" + assert res[1] is None or "ExecutionBlocker" in str(res[1]) # select is False - assert node.select("some value", select=False) == (None, "some value") + res = node.select("some value", select=False) + assert res[0] is None or "ExecutionBlocker" in str(res[0]) + assert res[1] == "some value" # select is default (True) - assert node.select("some value") == ("some value", None) + res = node.select("some value") + assert res[0] == "some value" + assert res[1] is None or "ExecutionBlocker" in str(res[1]) # Test with different value types - assert node.select(42, select=True) == (42, None) - assert node.select({'a': 1}, select=False) == (None, {'a': 1}) + res = node.select(42, select=True) + assert res[0] == 42 + assert res[1] is None or "ExecutionBlocker" in str(res[1]) + res = node.select({'a': 1}, select=False) + assert res[0] is None or "ExecutionBlocker" in str(res[0]) + assert res[1] == {'a': 1} def test_force_calculation():