From 66296eac1648941ccb73c8d0c86d65f85c654568 Mon Sep 17 00:00:00 2001 From: Alex Ganose Date: Wed, 8 Nov 2023 13:29:43 +0000 Subject: [PATCH] Fix JobStoreDocument validator --- src/jobflow/core/job.py | 2 +- src/jobflow/core/schemas.py | 34 ++++++++++ src/jobflow/core/store.py | 2 +- src/jobflow/schemas/job_output_schema.py | 63 ------------------- ...t_job_output_schema.py => test_schemas.py} | 4 +- 5 files changed, 38 insertions(+), 67 deletions(-) create mode 100644 src/jobflow/core/schemas.py delete mode 100644 src/jobflow/schemas/job_output_schema.py rename tests/core/{test_job_output_schema.py => test_schemas.py} (95%) diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index e7ba11a3..f11e1917 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -10,7 +10,6 @@ from monty.json import MSONable, jsanitize from jobflow.core.reference import OnMissing, OutputReference -from jobflow.schemas.job_output_schema import JobStoreDocument from jobflow.utils.uuid import suuid if typing.TYPE_CHECKING: @@ -560,6 +559,7 @@ def run(self, store: jobflow.JobStore) -> Response: from jobflow import CURRENT_JOB from jobflow.core.flow import get_flow + from jobflow.core.schemas import JobStoreDocument index_str = f", {self.index}" if self.index != 1 else "" logger.info(f"Starting job - {self.name} ({self.uuid}{index_str})") diff --git a/src/jobflow/core/schemas.py b/src/jobflow/core/schemas.py new file mode 100644 index 00000000..09c84107 --- /dev/null +++ b/src/jobflow/core/schemas.py @@ -0,0 +1,34 @@ +"""A Pydantic model for Jobstore document.""" + +from typing import Any + +from pydantic import BaseModel, Field + + +class JobStoreDocument(BaseModel): + """A Pydantic model for Jobstore document.""" + + uuid: str = Field( + None, description="An unique identifier for the job. Generated automatically." + ) + index: int = Field( + None, + description="The index of the job (number of times the job has been replaced).", + ) + output: Any = Field( + None, + description="This is a reference to the future job output.", + ) + completed_at: str = Field(None, description="The time the job was completed.") + metadata: dict = Field( + None, + description="Metadata information supplied by the user.", + ) + hosts: list[str] = Field( + None, + description="The list of UUIDs of the hosts containing the job.", + ) + name: str = Field( + None, + description="The name of the job.", + ) diff --git a/src/jobflow/core/store.py b/src/jobflow/core/store.py index 923cbb46..45da2246 100644 --- a/src/jobflow/core/store.py +++ b/src/jobflow/core/store.py @@ -18,7 +18,7 @@ from maggma.core import Sort - from jobflow.schemas.job_output_schema import JobStoreDocument + from jobflow.core.schemas import JobStoreDocument obj_type = Union[str, Enum, type[MSONable], list[Union[Enum, str, type[MSONable]]]] save_type = Optional[dict[str, obj_type]] diff --git a/src/jobflow/schemas/job_output_schema.py b/src/jobflow/schemas/job_output_schema.py deleted file mode 100644 index da08d98b..00000000 --- a/src/jobflow/schemas/job_output_schema.py +++ /dev/null @@ -1,63 +0,0 @@ -"""A Pydantic model for Jobstore document.""" - -from typing import Generic, TypeVar - -from monty.json import MontyDecoder -from pydantic import BaseModel, Field, field_validator - -T = TypeVar("T") - - -class JobStoreDocument(BaseModel, Generic[T]): - """A Pydantic model for Jobstore document.""" - - uuid: str = Field( - None, description="An unique identifier for the job. Generated automatically." - ) - index: int = Field( - None, - description="The index of the job (number of times the job has been replaced).", - ) - output: T = Field( - None, - description="This is a reference to the future job output.", - ) - completed_at: str = Field(None, description="The time the job was completed.") - metadata: dict = Field( - None, - description="Metadeta information supplied by the user.", - ) - hosts: list[str] = Field( - None, - description="The list of UUIDs of the hosts containing the job.", - ) - name: str = Field( - None, - description="The name of the job.", - ) - - @field_validator("output", mode="before") - @classmethod - def reserialize_output(cls, v): - """ - Pre-validator for the 'output' field. - - This method checks if the input 'v' is a dictionary with specific keys - ('@module' and '@class'). If these keys are present, it reprocesses - the input dictionary using MontyDecoder to deserialize it. - - Parameters - ---------- - cls : Type[JobStoreDocument] - The class this validator is applied to. - v : Any - The input value to validate. - - Returns - ------- - Any - The validated and potentially deserialized value. - """ - if isinstance(v, dict) and "@module" in v and "@class" in v: - v = MontyDecoder().process_decoded(v) - return v diff --git a/tests/core/test_job_output_schema.py b/tests/core/test_schemas.py similarity index 95% rename from tests/core/test_job_output_schema.py rename to tests/core/test_schemas.py index 2db2c330..84af8ace 100644 --- a/tests/core/test_job_output_schema.py +++ b/tests/core/test_schemas.py @@ -5,7 +5,7 @@ @pytest.fixture() def sample_data(): - from jobflow.schemas.job_output_schema import JobStoreDocument + from jobflow.core.schemas import JobStoreDocument return JobStoreDocument( uuid="abc123", @@ -33,7 +33,7 @@ def test_job_store_document_model(sample_data): def test_job_store_update(memory_jobstore, sample_data): # Storing document as a JobStoreDocument - from jobflow.schemas.job_output_schema import JobStoreDocument + from jobflow.core.schemas import JobStoreDocument d = { "index": 1,