Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
dependencies = [
"numpy",
"scipy",
"node-graph~=0.6.3",
"node-graph~=0.6.5",
"node-graph-widget>=0.0.5",
"aiida-core~=2.7.1",
"cloudpickle",
Expand Down
24 changes: 24 additions & 0 deletions src/aiida_workgraph/executors/builtins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from aiida import orm


def UnavailableExecutor(*args, **kwargs):
Expand All @@ -7,12 +8,14 @@ def UnavailableExecutor(*args, **kwargs):

def get_context(context: dict, key: str) -> Any:
    """Look up *key* in the graph-level context and return it under 'result'.

    The key may arrive wrapped in an ``orm.Str`` data node; it is unwrapped
    to a plain string before the lookup.
    """
    if isinstance(key, orm.Str):
        key = key.value
    ctx_store = context._task_results['graph_ctx']
    return {'result': ctx_store.get(key)}


def update_ctx(context: dict, key: str, value: Any) -> None:
    """Store *value* in the graph-level context under *key*.

    An ``orm.Str``-wrapped key is unwrapped to a plain string first.
    """
    if isinstance(key, orm.Str):
        key = key.value
    ctx_store = context._task_results['graph_ctx']
    ctx_store[key] = value


Expand All @@ -31,3 +34,24 @@ def get_item(data: dict, key: str) -> Any:
def return_input(**kwargs: Any) -> dict:
    """Echo the received keyword arguments back as a plain dict."""
    return dict(kwargs)


def load_node(pk: int = None, uuid: str = None) -> orm.Node:
    """Load an AiiDA node by its primary key or UUID.

    When both are supplied, ``uuid`` takes precedence. Either argument may be
    wrapped in an AiiDA data node (``orm.Str`` / ``orm.Int``) and is unwrapped
    to its plain Python value before the lookup.
    """
    if uuid is not None:
        identifier = uuid
        if isinstance(identifier, orm.Str):
            identifier = identifier.value
    else:
        identifier = pk
        if isinstance(identifier, orm.Int):
            identifier = identifier.value
    return orm.load_node(identifier)


def load_code(pk: int = None, uuid: str = None, label: str = None) -> orm.Code:
    """Load an AiiDA code by its primary key, UUID, or label.

    Precedence when several identifiers are given: ``uuid``, then ``label``,
    then ``pk``. Each argument may be wrapped in an AiiDA data node
    (``orm.Str`` / ``orm.Int``) and is unwrapped to its plain Python value
    before being passed to ``orm.load_code``.
    """
    if uuid is not None:
        identifier = uuid.value if isinstance(uuid, orm.Str) else uuid
    elif label is not None:
        identifier = label.value if isinstance(label, orm.Str) else label
    else:
        identifier = pk.value if isinstance(pk, orm.Int) else pk
    # NOTE: removed leftover debug `print` of the identifier — executors
    # should not write to stdout on every load.
    return orm.load_code(identifier)
38 changes: 38 additions & 0 deletions src/aiida_workgraph/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from typing import Any, Dict, Optional

from aiida_pythonjob.data.serializer import all_serializers
from aiida_pythonjob.utils import serialize_ports
from node_graph.serializer import SerializationAdapter
from node_graph.utils import resolve_tagged_values


class AiidaSerializationAdapter(SerializationAdapter):
    """node-graph serialization adapter backed by aiida-pythonjob serializers.

    Converts plain Python values into AiiDA-storable form by delegating to
    ``aiida_pythonjob.utils.serialize_ports`` with the registered serializers.
    """

    # Registry identifiers used by node-graph to select this adapter.
    id: str = 'aiida'
    name: str = 'AiiDA'

    def __init__(self, serializers: Optional[Dict[str, str]] = None, user: Any = None) -> None:
        """Initialize the adapter.

        :param serializers: mapping of type identifiers to serializer entry
            points; falls back to ``all_serializers`` when not given.
        :param user: user passed through to ``serialize_ports``.
        """
        self.serializers = serializers or all_serializers
        self.user = user

    def serialize(self, value: Any, socket: Any, *, store: bool) -> Any:
        """Serialize *value* against the port spec derived from *socket*.

        Returns *value* unchanged when no socket is given (nothing to derive
        a schema from).

        NOTE(review): the ``store`` flag is accepted but not forwarded to
        ``serialize_ports`` — confirm whether it should influence storing.
        """
        if socket is None:
            return value
        spec = socket._to_spec()
        # Presumably unwraps TaggedValue wrappers in place; the return value
        # is intentionally ignored — TODO confirm against node_graph.utils.
        resolve_tagged_values(value)
        return serialize_ports(
            python_data=value,
            port_schema=spec,
            serializers=self.serializers,
            user=self.user,
        )

    def serialize_ports(self, python_data: Any, port_schema: Any, *, store: bool) -> Any:
        """Serialize *python_data* against an explicit *port_schema*.

        NOTE(review): as in ``serialize``, ``store`` is accepted but unused.
        """
        # In-place TaggedValue resolution, as above.
        resolve_tagged_values(python_data)
        return serialize_ports(
            python_data=python_data,
            port_schema=port_schema,
            serializers=self.serializers,
            user=self.user,
        )
7 changes: 5 additions & 2 deletions src/aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Annotated
from aiida_workgraph.executors.builtins import update_ctx, get_context, select, return_input
from node_graph.task import BuiltinPolicy
from aiida_workgraph.executors.builtins import load_node, load_code


class GraphLevelTask(_GraphIOSharedMixin, Task):
Expand All @@ -37,6 +38,8 @@ class Zone(Task):
identifier='workgraph.zone',
task_type='ZONE',
catalog='Control',
inputs=namespace(),
outputs=namespace(),
base_class_path='aiida_workgraph.tasks.builtins.Zone',
)

Expand Down Expand Up @@ -258,7 +261,7 @@ class AiiDANode(Task):
uuid=Annotated[str, meta(required=False)],
),
outputs=namespace(node=orm.Node),
executor=RuntimeExecutor.from_callable(orm.load_node),
executor=RuntimeExecutor.from_callable(load_node),
base_class_path='aiida_workgraph.task.Task',
)

Expand All @@ -279,6 +282,6 @@ class AiiDACode(Task):
label=Annotated[str, meta(required=False)],
),
outputs=namespace(code=orm.Code),
executor=RuntimeExecutor.from_callable(orm.load_code),
executor=RuntimeExecutor.from_callable(load_code),
base_class_path='aiida_workgraph.task.Task',
)
29 changes: 0 additions & 29 deletions src/aiida_workgraph/tasks/pythonjob_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from typing import Any, Dict, Optional, Callable, Annotated
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiida_pythonjob.data.serializer import all_serializers
from aiida_workgraph.utils import create_and_pause_process
from aiida.engine import run_get_node
from aiida_pythonjob import pyfunction, PythonJob, PyFunction, MonitorPyFunction
from aiida_pythonjob.utils import serialize_ports
from aiida_workgraph.task import Task
from node_graph.socket_spec import SocketSpec, SocketSpecSelect, SocketMeta
from node_graph.task_spec import TaskSpec
Expand All @@ -25,33 +23,6 @@ class BaseSerializablePythonTask(Task):
Subclasses must implement their own `execute` method.
"""

def serialize_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Called during Task -> dict conversion. We walk over the input sockets
and run our specialized Python serialization.
"""
function_inputs = {key: data['inputs'][key] for key in data['inputs'] if key not in self.non_function_inputs}
serialized_inputs = serialize_ports(
python_data=function_inputs,
port_schema=self.spec.inputs,
serializers=all_serializers,
)
data['inputs'].update(serialized_inputs)

def update_from_dict(self, data: Dict[str, Any], **kwargs) -> 'BaseSerializablePythonTask':
"""
Called when reloading from a dict. Note, we do not run `_deserialize_python_data` here.
Thus, the value of the socket will be AiiDA data nodes.
"""
super().update_from_dict(data, **kwargs)
return self

def execute(self, *args, **kwargs):
"""
Subclasses must override.
"""
raise NotImplementedError('Subclasses must implement `execute`.')

@property
def non_function_inputs(self):
return self.spec.metadata.get('non_function_inputs', [])
Expand Down
21 changes: 13 additions & 8 deletions src/aiida_workgraph/tasks/shelljob_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
from aiida import orm


def _serialize_value(self, store: bool = False) -> Any:
    """Serialize a socket's callable value into a RuntimeExecutor dict.

    Installed on the ShellJobTask ``parser`` input socket via
    ``set_serializer``, so ``self`` is the socket instance and ``self._value``
    its current value (expected to be a callable, e.g. a parser function).
    Returns ``None`` when the socket holds no value.

    NOTE(review): ``store`` is accepted but unused — confirm whether the
    serializer protocol requires it.
    """
    from node_graph.utils import resolve_tagged_values

    # Unwrap any TaggedValue wrapper before building the executor payload.
    value = resolve_tagged_values(self._value)
    if value is None:
        return None
    return RuntimeExecutor.from_callable(value).to_dict()


class ShellJobTask(Task):
"""Runtime for ShellJob nodes.

Expand All @@ -25,14 +34,10 @@ class ShellJobTask(Task):
task_type = 'SHELLJOB'
catalog = 'AIIDA'

def serialize_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Overwrite the serialize_data method to handle the parser function."""
import inspect

parser = data['inputs'].get('parser')
if parser is not None:
if inspect.isfunction(parser):
data['inputs']['parser'] = RuntimeExecutor.from_callable(parser).to_dict()
def __init__(self, *args, **kwargs):
    """Initialize the task and install the custom serializer on the
    ``parser`` input, so parser callables are stored as RuntimeExecutor
    dicts instead of going through the default serialization path."""
    super().__init__(*args, **kwargs)
    # override the _serialize_value
    self.inputs['parser'].set_serializer(_serialize_value)

def execute(self, engine_process, args=None, kwargs=None, var_kwargs=None):
"""Submit/launch the AiiDA ShellJob.
Expand Down
9 changes: 3 additions & 6 deletions src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,12 +528,9 @@ def make_json_serializable(data):

def resolve_tagged_values(inputs: Dict[str, Any]) -> None:
"""Recursively resolve all TaggedValue either in a dictionary or a TaggedValue."""
if isinstance(inputs, dict):
for key, value in inputs.items():
if isinstance(value, TaggedValue):
inputs[key] = value.__wrapped__
else:
resolve_tagged_values(value)
from node_graph.utils import resolve_tagged_values as _resolve_tagged_values

_resolve_tagged_values(inputs)


def serialize_graph_level_data(
Expand Down
22 changes: 14 additions & 8 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(
inputs: Optional[type | List[str]] = None,
outputs: Optional[type | List[str]] = None,
error_handlers: Optional[Dict[str, ErrorHandlerSpec]] = None,
serialization: Optional[object] = None,
serialization_policy: str = 'off',
**kwargs,
) -> None:
"""
Expand All @@ -47,7 +49,18 @@ def __init__(
name (str, optional): The name of the WorkGraph. Defaults to 'WorkGraph'.
**kwargs: Additional keyword arguments to be passed to the WorkGraph class.
"""
super().__init__(name, inputs=inputs, outputs=outputs, **kwargs)
from aiida_workgraph.serialization import AiidaSerializationAdapter

if serialization is None:
serialization = AiidaSerializationAdapter()
super().__init__(
name,
inputs=inputs,
outputs=outputs,
serialization=serialization,
serialization_policy=serialization_policy,
**kwargs,
)
self.process = None
self.restart_process = None
self.max_number_jobs = 1000000
Expand Down Expand Up @@ -223,8 +236,6 @@ def build_connectivity(self) -> None:
def to_dict(self, include_sockets: bool = False, should_serialize: bool = False) -> Dict[str, Any]:
"""Convert the workgraph to a dictionary."""
from aiida.orm.utils.serialize import serialize
from aiida_workgraph.utils import serialize_graph_level_data
from aiida_pythonjob.data.serializer import all_serializers

wgdata = super().to_dict(include_sockets=include_sockets, should_serialize=should_serialize)
wgdata.update(
Expand All @@ -239,11 +250,6 @@ def to_dict(self, include_sockets: bool = False, should_serialize: bool = False)
wgdata['connectivity'] = self.build_connectivity()
wgdata['process'] = serialize(self.process) if self.process else serialize(None)
wgdata['metadata']['pk'] = self.process.pk if self.process else None
if should_serialize:
# serialize the graph-level tasks
wgdata['tasks']['graph_inputs']['inputs'] = serialize_graph_level_data(
wgdata['tasks']['graph_inputs']['inputs'], self.spec.inputs, all_serializers
)

return wgdata

Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ def test_load_code_task(add_code) -> None:
"""Test AiiDA Code task."""

wg = WorkGraph('test_load_code_task')
task1 = wg.add_task('workgraph.load_code', name='task1', label=add_code.label)
task1 = wg.add_task('workgraph.load_code', name='task1', label=add_code.full_label)
wg.run()
assert task1.outputs.code.value.label == add_code.label
11 changes: 3 additions & 8 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from aiida_workgraph import WorkGraph, task
import pytest


@task.graph()
Expand All @@ -11,11 +12,5 @@ def test_func_as_input(capsys):

wg = WorkGraph('test_func_as_input')
wg.add_task(sub_workflow, func=add, name='sub_workflow')
wg.save()

# load and capture stdout
loaded_wg = WorkGraph.load(wg.pk)
captured = capsys.readouterr()

assert 'Info: could not deserialize input' in captured.out
assert 'sub_workflow' in loaded_wg.tasks
with pytest.raises(Exception, match='Cannot serialize the provided object'):
wg.save()
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.