diff --git a/libensemble/gen_funcs/persistent_sampling.py b/libensemble/gen_funcs/persistent_sampling.py index 401ccdaa9..7408d4179 100644 --- a/libensemble/gen_funcs/persistent_sampling.py +++ b/libensemble/gen_funcs/persistent_sampling.py @@ -48,15 +48,14 @@ def persistent_uniform(_, persis_info, gen_specs, libE_info): ps = PersistentSupport(libE_info, EVAL_GEN_TAG) # Send batches until manager sends stop tag - tag = None - while tag not in [STOP_TAG, PERSIS_STOP]: + while not ps.instructed_to_exit(): H_o = np.zeros(b, dtype=gen_specs["out"]) H_o["x"] = persis_info["rand_stream"].uniform(lb, ub, (b, n)) if "obj_component" in H_o.dtype.fields: H_o["obj_component"] = persis_info["rand_stream"].integers( low=0, high=gen_specs["user"]["num_components"], size=b ) - tag, Work, calc_in = ps.send_recv(H_o) + calc_in = ps.send_recv(H_o) if hasattr(calc_in, "__len__"): b = len(calc_in) @@ -229,14 +228,13 @@ def batched_history_matching(_, persis_info, gen_specs, libE_info): mu = np.zeros(n) Sigma = np.eye(n) - tag = None - while tag not in [STOP_TAG, PERSIS_STOP]: + while not ps.instructed_to_exit(): H_o = np.zeros(b, dtype=gen_specs["out"]) H_o["x"] = persis_info["rand_stream"].multivariate_normal(mu, Sigma, b) # Send data and get next assignment - tag, Work, calc_in = ps.send_recv(H_o) + calc_in = ps.send_recv(H_o) if calc_in is not None: all_inds = np.argsort(calc_in["f"]) best_inds = all_inds[:q] diff --git a/libensemble/tools/persistent_support.py b/libensemble/tools/persistent_support.py index dca7d37ca..16606be6f 100644 --- a/libensemble/tools/persistent_support.py +++ b/libensemble/tools/persistent_support.py @@ -28,6 +28,7 @@ def __init__(self, libE_info: Dict[str, Dict[Any, Any]], calc_type: int) -> None EVAL_SIM_TAG, ], f"The calc_type: {self.calc_type} specifies neither a simulator nor generator." self.calc_str = calc_type_strings[self.calc_type] + self.tag = None def send(self, output: npt.NDArray, calc_status: int = UNSET_TAG, keep_state=False) -> None: """ @@ -79,7 +80,8 @@ def recv(self, blocking: bool = True) -> (int, dict, npt.NDArray): logger.debug(f"Persistent {self.calc_str} received signal {tag} from manager") if not isinstance(Work, dict): self.comm.push_to_buffer(tag, Work) - return tag, Work, None + self.tag = tag + return None else: logger.debug(f"Persistent {self.calc_str} received work request from manager") @@ -100,7 +102,9 @@ def recv(self, blocking: bool = True) -> (int, dict, npt.NDArray): return data_tag, calc_in, None # calc_in is signal identifier logger.debug(f"Persistent {self.calc_str} received work rows from manager") - return tag, Work, calc_in + self.last_H_rows = Work["libE_info"]["H_rows"] + self.tag = tag + return calc_in def send_recv(self, output: npt.NDArray, calc_status: int = UNSET_TAG) -> (int, dict, npt.NDArray): """ @@ -115,6 +119,9 @@ def send_recv(self, output: npt.NDArray, calc_status: int = UNSET_TAG) -> (int, self.send(output, calc_status) return self.recv() + def instructed_to_exit(self) -> bool: + return self.tag in [STOP_TAG, PERSIS_STOP] + def request_cancel_sim_ids(self, sim_ids: List[int]): """Request cancellation of sim_ids.