diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index aa084923..d00501b2 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -1265,7 +1265,13 @@ class Response(typing.Generic[T]): output: T = None detour: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None addition: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None - replace: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None + replace: ( + jobflow.Flow + | Job + | tuple[jobflow.Job | jobflow.Flow] + | list[jobflow.Job | jobflow.Flow] + | dict[Any, jobflow.Job | jobflow.Flow] + ) = None stored_data: dict[Hashable, Any] = None stop_children: bool = False stop_jobflow: bool = False @@ -1312,8 +1318,13 @@ def is_job_or_flow(x): return isinstance(x, Job | Flow) should_replace = is_job_or_flow(job_returns) - if job_returns and isinstance(job_returns, (list, tuple)): - should_replace = all(is_job_or_flow(resp) for resp in job_returns) + if job_returns is not None: + if isinstance(job_returns, (list, tuple)) and len(job_returns) > 0: + should_replace = all(is_job_or_flow(resp) for resp in job_returns) + elif isinstance(job_returns, dict) and len(job_returns) > 0: + should_replace = all( + is_job_or_flow(resp) for resp in job_returns.values() + ) if should_replace: job_returns = Response(replace=job_returns) @@ -1334,6 +1345,13 @@ def is_job_or_flow(x): raise ValueError( "Response cannot be returned in combination with other outputs." ) + elif isinstance(job_returns, dict): + # check that a Response object is not given as one of many outputs + for resp in job_returns.values(): + if isinstance(resp, Response): + raise ValueError( + "Response cannot be returned in combination with other outputs." + ) return cls(output=apply_schema(job_returns, output_schema), job_dir=job_dir) @@ -1393,7 +1411,11 @@ def store_inputs(inputs: Any) -> Any: def prepare_replace( - replace: jobflow.Flow | Job | list[Job], + replace: jobflow.Flow + | Job + | tuple[jobflow.Job | jobflow.Flow] + | list[jobflow.Job | jobflow.Flow] + | dict[Any, jobflow.Job | jobflow.Flow], current_job: Job, ) -> jobflow.Flow: """ @@ -1419,8 +1441,16 @@ def prepare_replace( """ from jobflow.core.flow import Flow + if isinstance(replace, dict): + output = {k: j.output for k, j in replace.items()} + replace = Flow(jobs=list(replace.values()), output=output) + if isinstance(replace, (list, tuple)): - replace = Flow(jobs=replace) + if isinstance(replace, tuple): + seq_output: tuple | list = tuple(j.output for j in replace) + else: + seq_output = [j.output for j in replace] + replace = Flow(jobs=list(replace), output=seq_output) if isinstance(replace, Flow) and replace.output is not None: # add a job with same UUID as the current job to store the outputs of the diff --git a/tests/core/test_flow.py b/tests/core/test_flow.py index 8b78c25b..2335ac9b 100644 --- a/tests/core/test_flow.py +++ b/tests/core/test_flow.py @@ -1376,7 +1376,7 @@ def add_distributed_list(list_a): job1 = make_list_of_n(2, 3) dynamic_job = add_distributed_list(job1.output) flow = Flow([job1, dynamic_job]) - results = run_locally(flow, store=memory_jobstore) + results = run_locally(flow, store=memory_jobstore, ensure_success=True) # Ensure the final result (3 instances of 3s) is in the results. all_responses = [ @@ -1390,10 +1390,55 @@ def add_distributed_list(list_a): # as a replace job1 = make_list_of_n(2, 0) flow = Flow([job1]) - responses = run_locally(flow, store=memory_jobstore) + responses = run_locally(flow, store=memory_jobstore, ensure_success=True) assert responses[job1.uuid][1].output == [] +def test_job_autoreplace_dict(memory_jobstore): + # test to check if a job that returns a dict of jobs + # is interpreted as a replace. + from jobflow import Flow, job + from jobflow.managers.local import run_locally + + @job + def add(x, y): + return x + y + + @job + def make_named_jobs(a, b): + return {"sum": add(a, b), "diff": add(a, -b)} + + @job + def double(x): + return x * 2 + + job1 = make_named_jobs(3, 2) + job2 = double(job1.output["sum"]) + flow = Flow([job1, job2]) + results = run_locally(flow, store=memory_jobstore, ensure_success=True) + + assert results[job2.uuid][1].output == 10 + + all_outputs = [ + response.output + for index_to_response in results.values() + for response in index_to_response.values() + ] + assert 5 in all_outputs # sum + assert 1 in all_outputs # diff + + # check that if an empty dict is returned the output is not misinterpreted + # as a replace + @job + def make_empty_dict(): + return {} + + job1 = make_empty_dict() + flow = Flow([job1]) + responses = run_locally(flow, store=memory_jobstore, ensure_success=True) + assert responses[job1.uuid][1].output == {} + + def test_get_item(): from jobflow import Flow, job, run_locally @@ -1412,3 +1457,29 @@ def capitalize(s): responses = run_locally(flow, ensure_success=True) assert responses[job2.uuid][1].output == "WORLD" + + +def test_flow_no_truthiness(memory_jobstore): + # Test to check if a job that returns a value for which truthiness on the output + # cannot be ascertained works correctly. + # This can happen when a job returns a numpy array, for example. + + from monty.json import MSONable + + from jobflow import Flow, job + from jobflow.managers.local import run_locally + + class NoTruthiness(MSONable): + def __bool__(self): + raise ValueError("No truthiness") + + no_truthiness_obj = NoTruthiness() + + @job + def no_truthiness_job(): + return no_truthiness_obj + + job1 = no_truthiness_job() + flow = Flow([job1]) + results = run_locally(flow, store=memory_jobstore, ensure_success=True) + assert results[job1.uuid][1].output == no_truthiness_obj diff --git a/tests/core/test_job.py b/tests/core/test_job.py index d1c8f7f2..58bd00f2 100644 --- a/tests/core/test_job.py +++ b/tests/core/test_job.py @@ -135,6 +135,7 @@ def add_response(a, b): def test_replace_response(memory_jobstore): from jobflow import Flow, Job, Response + from jobflow.core.job import store_inputs def replace_job(): job = Job(add, function_args=(1,)) @@ -183,9 +184,11 @@ def replace_list_flow(): test_job = Job(replace_list_job, metadata=metadata, output_schema="123") response = test_job.run(memory_jobstore) assert isinstance(response.replace, Flow) - assert response.replace[-1].function == add - assert len(response.replace) == 2 - # currently output schema and metadata ignored in this case + assert response.replace[-2].function == add + assert response.replace[-1].function == store_inputs.original + + assert len(response.replace) == 3 + # currently output schema and metadata ignored for all but the last `store_inputs` for j in response.replace: assert j.hosts == [response.replace.uuid] @@ -218,12 +221,13 @@ def replace_list_flow(): test_job = Job(replace_list_flow, metadata=metadata, output_schema="123") response = test_job.run(memory_jobstore) assert isinstance(response.replace, Flow) - assert isinstance(response.replace[-1], Flow) - assert len(response.replace) == 2 + assert isinstance(response.replace[-2], Flow) + assert len(response.replace) == 3 for f in response.replace: - for j in f: - assert j.hosts == [f.uuid, response.replace.uuid] - # currently output schema and metadata ignored in this case + if isinstance(f, Flow): + for j in f: + assert j.hosts == [f.uuid, response.replace.uuid] + # currently output schema and metadata ignored for all but the last `store_inputs` def test_job_config(memory_jobstore): @@ -352,7 +356,7 @@ def detour_job(): # test replace with flow test_job = Job(replace_flow, config=nopass_config) response = test_job.run(memory_jobstore) - for j in response.replace: + for j in response.replace[:-1]: assert j.config.manager_config == {} test_job = Job(replace_flow, config=pass_config)