Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions libensemble/gen_funcs/persistent_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ def persistent_uniform(_, persis_info, gen_specs, libE_info):
ps = PersistentSupport(libE_info, EVAL_GEN_TAG)

# Send batches until manager sends stop tag
tag = None
while tag not in [STOP_TAG, PERSIS_STOP]:
while not ps.instructed_to_exit():
H_o = np.zeros(b, dtype=gen_specs["out"])
H_o["x"] = persis_info["rand_stream"].uniform(lb, ub, (b, n))
if "obj_component" in H_o.dtype.fields:
H_o["obj_component"] = persis_info["rand_stream"].integers(
low=0, high=gen_specs["user"]["num_components"], size=b
)
tag, Work, calc_in = ps.send_recv(H_o)
calc_in = ps.send_recv(H_o)
if hasattr(calc_in, "__len__"):
b = len(calc_in)

Expand Down Expand Up @@ -229,14 +228,13 @@ def batched_history_matching(_, persis_info, gen_specs, libE_info):

mu = np.zeros(n)
Sigma = np.eye(n)
tag = None

while tag not in [STOP_TAG, PERSIS_STOP]:
while not ps.instructed_to_exit():
H_o = np.zeros(b, dtype=gen_specs["out"])
H_o["x"] = persis_info["rand_stream"].multivariate_normal(mu, Sigma, b)

# Send data and get next assignment
tag, Work, calc_in = ps.send_recv(H_o)
calc_in = ps.send_recv(H_o)
if calc_in is not None:
all_inds = np.argsort(calc_in["f"])
best_inds = all_inds[:q]
Expand Down
11 changes: 9 additions & 2 deletions libensemble/tools/persistent_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, libE_info: Dict[str, Dict[Any, Any]], calc_type: int) -> None
EVAL_SIM_TAG,
], f"The calc_type: {self.calc_type} specifies neither a simulator nor generator."
self.calc_str = calc_type_strings[self.calc_type]
self.tag = None

def send(self, output: npt.NDArray, calc_status: int = UNSET_TAG, keep_state=False) -> None:
"""
Expand Down Expand Up @@ -79,7 +80,8 @@ def recv(self, blocking: bool = True) -> (int, dict, npt.NDArray):
logger.debug(f"Persistent {self.calc_str} received signal {tag} from manager")
if not isinstance(Work, dict):
self.comm.push_to_buffer(tag, Work)
return tag, Work, None
self.tag = tag
return None
else:
logger.debug(f"Persistent {self.calc_str} received work request from manager")

Expand All @@ -100,7 +102,9 @@ def recv(self, blocking: bool = True) -> (int, dict, npt.NDArray):
return data_tag, calc_in, None # calc_in is signal identifier

logger.debug(f"Persistent {self.calc_str} received work rows from manager")
return tag, Work, calc_in
self.last_H_rows = Work["libE_info"]["H_rows"]
self.tag = tag
return calc_in

def send_recv(self, output: npt.NDArray, calc_status: int = UNSET_TAG) -> (int, dict, npt.NDArray):
"""
Expand All @@ -115,6 +119,9 @@ def send_recv(self, output: npt.NDArray, calc_status: int = UNSET_TAG) -> (int,
self.send(output, calc_status)
return self.recv()

def instructed_to_exit(self) -> bool:
return self.tag in [STOP_TAG, PERSIS_STOP]

def request_cancel_sim_ids(self, sim_ids: List[int]):
"""Request cancellation of sim_ids.

Expand Down