Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions src/jobflow/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,13 @@ class Response(typing.Generic[T]):
output: T = None
detour: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None
addition: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None
replace: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None
replace: (
jobflow.Flow
| Job
| tuple[jobflow.Job | jobflow.Flow]
| list[jobflow.Job | jobflow.Flow]
| dict[Any, jobflow.Job | jobflow.Flow]
) = None
stored_data: dict[Hashable, Any] = None
stop_children: bool = False
stop_jobflow: bool = False
Expand Down Expand Up @@ -1312,8 +1318,13 @@ def is_job_or_flow(x):
return isinstance(x, Job | Flow)

should_replace = is_job_or_flow(job_returns)
if job_returns and isinstance(job_returns, (list, tuple)):
should_replace = all(is_job_or_flow(resp) for resp in job_returns)
if job_returns is not None:
if isinstance(job_returns, (list, tuple)) and len(job_returns) > 0:
should_replace = all(is_job_or_flow(resp) for resp in job_returns)
elif isinstance(job_returns, dict) and len(job_returns) > 0:
should_replace = all(
is_job_or_flow(resp) for resp in job_returns.values()
)

if should_replace:
job_returns = Response(replace=job_returns)
Expand All @@ -1334,6 +1345,13 @@ def is_job_or_flow(x):
raise ValueError(
"Response cannot be returned in combination with other outputs."
)
elif isinstance(job_returns, dict):
# check that a Response object is not given as one of many outputs
for resp in job_returns.values():
if isinstance(resp, Response):
raise ValueError(
"Response cannot be returned in combination with other outputs."
)

return cls(output=apply_schema(job_returns, output_schema), job_dir=job_dir)

Expand Down Expand Up @@ -1393,7 +1411,11 @@ def store_inputs(inputs: Any) -> Any:


def prepare_replace(
replace: jobflow.Flow | Job | list[Job],
replace: jobflow.Flow
| Job
| tuple[jobflow.Job | jobflow.Flow]
| list[jobflow.Job | jobflow.Flow]
| dict[Any, jobflow.Job | jobflow.Flow],
current_job: Job,
) -> jobflow.Flow:
"""
Expand All @@ -1419,8 +1441,16 @@ def prepare_replace(
"""
from jobflow.core.flow import Flow

if isinstance(replace, dict):
output = {k: j.output for k, j in replace.items()}
replace = Flow(jobs=list(replace.values()), output=output)

if isinstance(replace, (list, tuple)):
replace = Flow(jobs=replace)
if isinstance(replace, tuple):
seq_output: tuple | list = tuple(j.output for j in replace)
else:
seq_output = [j.output for j in replace]
replace = Flow(jobs=list(replace), output=seq_output)

if isinstance(replace, Flow) and replace.output is not None:
# add a job with same UUID as the current job to store the outputs of the
Expand Down
75 changes: 73 additions & 2 deletions tests/core/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,7 @@ def add_distributed_list(list_a):
job1 = make_list_of_n(2, 3)
dynamic_job = add_distributed_list(job1.output)
flow = Flow([job1, dynamic_job])
results = run_locally(flow, store=memory_jobstore)
results = run_locally(flow, store=memory_jobstore, ensure_success=True)

# Ensure the final result (3 instances of 3s) is in the results.
all_responses = [
Expand All @@ -1390,10 +1390,55 @@ def add_distributed_list(list_a):
# as a replace
job1 = make_list_of_n(2, 0)
flow = Flow([job1])
responses = run_locally(flow, store=memory_jobstore)
responses = run_locally(flow, store=memory_jobstore, ensure_success=True)
assert responses[job1.uuid][1].output == []


def test_job_autoreplace_dict(memory_jobstore):
    """A job returning a dict of jobs should be interpreted as a replacement,
    while an empty dict must pass through as a plain output."""
    from jobflow import Flow, job
    from jobflow.managers.local import run_locally

    @job
    def add(x, y):
        return x + y

    @job
    def make_named_jobs(a, b):
        return {"sum": add(a, b), "diff": add(a, -b)}

    @job
    def double(x):
        return x * 2

    producer = make_named_jobs(3, 2)
    consumer = double(producer.output["sum"])
    results = run_locally(
        Flow([producer, consumer]), store=memory_jobstore, ensure_success=True
    )

    # downstream job resolved the "sum" entry of the replacement dict: (3+2)*2
    assert results[consumer.uuid][1].output == 10

    # gather every job output across all run indices
    all_outputs = []
    for index_to_response in results.values():
        all_outputs.extend(resp.output for resp in index_to_response.values())
    assert 5 in all_outputs  # sum
    assert 1 in all_outputs  # diff

    # an empty dict has no jobs, so it must not be misread as a replace
    @job
    def make_empty_dict():
        return {}

    empty_job = make_empty_dict()
    responses = run_locally(
        Flow([empty_job]), store=memory_jobstore, ensure_success=True
    )
    assert responses[empty_job.uuid][1].output == {}


def test_get_item():
from jobflow import Flow, job, run_locally

Expand All @@ -1412,3 +1457,29 @@ def capitalize(s):

responses = run_locally(flow, ensure_success=True)
assert responses[job2.uuid][1].output == "WORLD"


def test_flow_no_truthiness(memory_jobstore):
    """A job output whose truthiness cannot be evaluated (e.g. a numpy array)
    must still run and round-trip correctly."""
    from monty.json import MSONable

    from jobflow import Flow, job
    from jobflow.managers.local import run_locally

    class NoTruthiness(MSONable):
        # any bool(...) on this object raises, mimicking ambiguous-truth types
        def __bool__(self):
            raise ValueError("No truthiness")

    payload = NoTruthiness()

    @job
    def no_truthiness_job():
        return payload

    sole_job = no_truthiness_job()
    results = run_locally(
        Flow([sole_job]), store=memory_jobstore, ensure_success=True
    )
    assert results[sole_job.uuid][1].output == payload
22 changes: 13 additions & 9 deletions tests/core/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def add_response(a, b):

def test_replace_response(memory_jobstore):
from jobflow import Flow, Job, Response
from jobflow.core.job import store_inputs

def replace_job():
job = Job(add, function_args=(1,))
Expand Down Expand Up @@ -183,9 +184,11 @@ def replace_list_flow():
test_job = Job(replace_list_job, metadata=metadata, output_schema="123")
response = test_job.run(memory_jobstore)
assert isinstance(response.replace, Flow)
assert response.replace[-1].function == add
assert len(response.replace) == 2
# currently output schema and metadata ignored in this case
assert response.replace[-2].function == add
assert response.replace[-1].function == store_inputs.original

assert len(response.replace) == 3
# currently output schema and metadata ignored for all but the last `store_inputs`
for j in response.replace:
assert j.hosts == [response.replace.uuid]

Expand Down Expand Up @@ -218,12 +221,13 @@ def replace_list_flow():
test_job = Job(replace_list_flow, metadata=metadata, output_schema="123")
response = test_job.run(memory_jobstore)
assert isinstance(response.replace, Flow)
assert isinstance(response.replace[-1], Flow)
assert len(response.replace) == 2
assert isinstance(response.replace[-2], Flow)
assert len(response.replace) == 3
for f in response.replace:
for j in f:
assert j.hosts == [f.uuid, response.replace.uuid]
# currently output schema and metadata ignored in this case
if isinstance(f, Flow):
for j in f:
assert j.hosts == [f.uuid, response.replace.uuid]
# currently output schema and metadata ignored for all but the last `store_inputs`


def test_job_config(memory_jobstore):
Expand Down Expand Up @@ -352,7 +356,7 @@ def detour_job():
# test replace with flow
test_job = Job(replace_flow, config=nopass_config)
response = test_job.run(memory_jobstore)
for j in response.replace:
for j in response.replace[:-1]:
assert j.config.manager_config == {}

test_job = Job(replace_flow, config=pass_config)
Expand Down
Loading