diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index 82726d09..c6dc1fe5 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -10,6 +10,7 @@ from monty.json import MSONable, jsanitize from jobflow.core.reference import OnMissing, OutputReference +from jobflow.schemas.job_output_schema import JobStoreDocument from jobflow.utils.uuid import suuid if typing.TYPE_CHECKING: @@ -633,15 +634,15 @@ def run(self, store: jobflow.JobStore) -> Response: ) from err save = {k: "output" if v is True else v for k, v in self._kwargs.items()} - data = { - "uuid": self.uuid, - "index": self.index, - "output": output, - "completed_at": datetime.now().isoformat(), - "metadata": self.metadata, - "hosts": self.hosts, - "name": self.name, - } + data: JobStoreDocument = JobStoreDocument( + uuid=self.uuid, + index=self.index, + output=output, + completed_at=datetime.now().isoformat(), + metadata=self.metadata, + hosts=self.hosts, + name=self.name, + ) store.update(data, key=["uuid", "index"], save=save) CURRENT_JOB.reset() diff --git a/src/jobflow/core/store.py b/src/jobflow/core/store.py index b665e3cd..b2e31574 100644 --- a/src/jobflow/core/store.py +++ b/src/jobflow/core/store.py @@ -8,6 +8,7 @@ from monty.json import MSONable from jobflow.core.reference import OnMissing +from jobflow.schemas.job_output_schema import JobStoreDocument from jobflow.utils.find import get_root_locations if typing.TYPE_CHECKING: @@ -254,7 +255,7 @@ def query_one( def update( self, - docs: list[dict] | dict, + docs: list[dict] | dict | JobStoreDocument | list[JobStoreDocument], key: list | str = None, save: bool | save_type = None, ): @@ -264,7 +265,7 @@ def update( Parameters ---------- docs - The document or list of documents to update. + The Pydantic document or list of Pydantic documents to update. 
key Field name(s) to determine uniqueness for a document, can be a list of multiple fields, a single field, or None if the Store's key field is to diff --git a/src/jobflow/schemas/job_output_schema.py b/src/jobflow/schemas/job_output_schema.py new file mode 100644 index 00000000..b061c16f --- /dev/null +++ b/src/jobflow/schemas/job_output_schema.py @@ -0,0 +1,63 @@ +"""A Pydantic model for Jobstore document.""" + +from typing import Generic, List, TypeVar + +from monty.json import MontyDecoder +from pydantic import BaseModel, Field, field_validator + +T = TypeVar("T") + + +class JobStoreDocument(BaseModel, Generic[T]): + """A Pydantic model for Jobstore document.""" + + uuid: str = Field( + None, description="A unique identifier for the job. Generated automatically." + ) + index: int = Field( + None, + description="The index of the job (number of times the job has been replaced).", + ) + output: T = Field( + None, + description="This is a reference to the future job output.", + ) + completed_at: str = Field(None, description="The time the job was completed.") + metadata: dict = Field( + None, + description="Metadata information supplied by the user.", + ) + hosts: List[str] = Field( + None, + description="The list of UUIDs of the hosts containing the job.", + ) + name: str = Field( + None, + description="The name of the job.", + ) + + @field_validator("output", mode="before") + @classmethod + def reserialize_output(cls, v): + """ + Pre-validator for the 'output' field. + + This method checks if the input 'v' is a dictionary with specific keys + ('@module' and '@class'). If these keys are present, it reprocesses + the input dictionary using MontyDecoder to deserialize it. + + Parameters + ---------- + cls : Type[JobStoreDocument] + The class this validator is applied to. + v : Any + The input value to validate. + + Returns + ------- + Any + The validated and potentially deserialized value. 
+ """ + if isinstance(v, dict) and "@module" in v and "@class" in v: + v = MontyDecoder().process_decoded(v) + return v diff --git a/tests/core/test_job_output_schema.py b/tests/core/test_job_output_schema.py new file mode 100644 index 00000000..99777e7e --- /dev/null +++ b/tests/core/test_job_output_schema.py @@ -0,0 +1,87 @@ +from datetime import datetime + +import pytest + + +@pytest.fixture +def sample_data(): + from jobflow.schemas.job_output_schema import JobStoreDocument + + return JobStoreDocument( + uuid="abc123", + index=1, + output=None, + completed_at=datetime.now().isoformat(), + metadata={"key": "value"}, + hosts=["host1", "host2"], + name="my_job", + ) + + +def test_job_store_document_model(sample_data): + # Test creating model + data = sample_data + + assert data.uuid == "abc123" + assert data.index == 1 + assert data.output is None + assert datetime.fromisoformat(data.completed_at).hour == datetime.now().hour + assert data.metadata == {"key": "value"} + assert data.hosts == ["host1", "host2"] + assert data.name == "my_job" + + +def test_job_store_update(memory_jobstore, sample_data): + # Storing document as a JobStoreDocument + from jobflow.schemas.job_output_schema import JobStoreDocument + + d = { + "index": 1, + "uuid": "abc123", + "metadata": {"key": "value"}, + "hosts": ["host1", "host2"], + "name": "my_job", + "e": 6, + "d": 4, + } + sample_data = JobStoreDocument(**d) + memory_jobstore.update(sample_data) + + # Check document was inserted + results = memory_jobstore.query_one(criteria={"hosts": {"$exists": 1}}) + assert results["index"] == 1 + assert results["uuid"] == "abc123" + assert results["metadata"] == {"key": "value"} + assert results["hosts"] == ["host1", "host2"] + assert results["name"] == "my_job" + assert "e" not in results + assert "d" not in results + + # Further checks to see if two documents get inserted + e = d.copy() + e["uuid"] = "def456" + new_data_e = JobStoreDocument(**e) + f = d.copy() + f["uuid"] = "ghi789" + 
new_data_f = JobStoreDocument(**f) + memory_jobstore.update([new_data_e, new_data_f]) + + # Check if document new_data_e is present in the store + results = memory_jobstore.query_one(criteria={"uuid": "def456"}) + assert results["index"] == 1 + assert results["uuid"] == "def456" + assert results["metadata"] == {"key": "value"} + assert results["hosts"] == ["host1", "host2"] + assert results["name"] == "my_job" + assert "e" not in results + assert "d" not in results + + # Check if document new_data_f is present in the store + results = memory_jobstore.query_one(criteria={"uuid": "ghi789"}) + assert results["index"] == 1 + assert results["uuid"] == "ghi789" + assert results["metadata"] == {"key": "value"} + assert results["hosts"] == ["host1", "host2"] + assert results["name"] == "my_job" + assert "e" not in results + assert "d" not in results