diff --git a/recipes/ispyb/em-spa-refine-wrapper.json b/recipes/ispyb/em-spa-refine-wrapper.json index c5ae8b3b..f5751b0f 100644 --- a/recipes/ispyb/em-spa-refine-wrapper.json +++ b/recipes/ispyb/em-spa-refine-wrapper.json @@ -10,7 +10,8 @@ "original_pixel_size": "{pixel_size}", "refine_class_nr": "{class_number}", "refine_job_dir": "{refine_job_dir}", - "relion_options": {} + "relion_options": {}, + "submit_to_slurm": "True" }, "queue": "extract_class", "service": "ExtractClass" diff --git a/src/cryoemservices/services/class2d.py b/src/cryoemservices/services/class2d.py index 8b7a659c..8557f4d3 100644 --- a/src/cryoemservices/services/class2d.py +++ b/src/cryoemservices/services/class2d.py @@ -19,7 +19,7 @@ class Class2D(CommonService): def initializing(self): """Subscribe to a queue. Received messages must be acknowledged.""" self.log.info("Class2D service starting") - wrap_subscribe( + self.subscription_id = wrap_subscribe( self._transport, self._environment["queue"] or "class2d", self.class2d, @@ -34,7 +34,7 @@ def class2d(self, rw, header: dict, message: dict): if not isinstance(message, dict): self.log.error("Rejected invalid simple message") self._transport.nack(header) - return + return False # Create a wrapper-like object that can be passed to functions # as if a recipe wrapper was present. @@ -58,7 +58,7 @@ def class2d(self, rw, header: dict, message: dict): f"with exception: {e}" ) rw.transport.nack(header) - return + return False # In this setup we cannot nack messages on failure, so instead check here if message.get("requeue", 0) >= 5: @@ -71,6 +71,8 @@ def class2d(self, rw, header: dict, message: dict): f"Running disconnected Class2D job for {class2d_params.particles_file}" ) rw.transport.ack(header) + rw.transport.unsubscribe(self.subscription_id) + rw.transport.drop_callback_reference(self.subscription_id) # Run the class2d job try: @@ -88,3 +90,7 @@ def class2d(self, rw, header: dict, message: dict): # Send back to the queue but mark a failure in the message message["requeue"] = message.get("requeue", 0) + 1 rw.send_to("class2d", message) + + # Reconnect to rabbitmq + self.initializing() + return True diff --git a/src/cryoemservices/services/class3d.py b/src/cryoemservices/services/class3d.py index 5f43b6ab..0c5892ee 100644 --- a/src/cryoemservices/services/class3d.py +++ b/src/cryoemservices/services/class3d.py @@ -19,7 +19,7 @@ class Class3D(CommonService): def initializing(self): """Subscribe to a queue. Received messages must be acknowledged.""" self.log.info("Class3D service starting") - wrap_subscribe( + self.subscription_id = wrap_subscribe( self._transport, self._environment["queue"] or "class3d", self.class3d, @@ -34,7 +34,7 @@ def class3d(self, rw, header: dict, message: dict): if not isinstance(message, dict): self.log.error("Rejected invalid simple message") self._transport.nack(header) - return + return False # Create a wrapper-like object that can be passed to functions # as if a recipe wrapper was present. @@ -58,7 +58,7 @@ def class3d(self, rw, header: dict, message: dict): f"with exception: {e}" ) rw.transport.nack(header) - return + return False # In this setup we cannot nack messages on failure, so instead check here if message.get("requeue", 0) >= 5: @@ -71,6 +71,8 @@ def class3d(self, rw, header: dict, message: dict): f"Running disconnected Class3D job for {class3d_params.particles_file}" ) rw.transport.ack(header) + rw.transport.unsubscribe(self.subscription_id) + rw.transport.drop_callback_reference(self.subscription_id) # Run the class3d job try: @@ -88,3 +90,7 @@ def class3d(self, rw, header: dict, message: dict): # Send back to the queue but mark a failure in the message message["requeue"] = message.get("requeue", 0) + 1 rw.send_to("class3d", message) + + # Reconnect to rabbitmq + self.initializing() + return True diff --git a/src/cryoemservices/services/common_service.py b/src/cryoemservices/services/common_service.py index 85e91182..ed57b293 100644 --- a/src/cryoemservices/services/common_service.py +++ b/src/cryoemservices/services/common_service.py @@ -31,6 +31,7 @@ def __init__( self.log = logging.getLogger(self._logger_name) self.log.setLevel(logging.INFO) self.single_message_mode: bool = single_message_mode + self.subscription_id: int = 0 def _transport_interceptor(self, callback): """Takes a callback function and adds headers and messages""" diff --git a/src/cryoemservices/services/cryolo.py b/src/cryoemservices/services/cryolo.py index b265b827..57f8263c 100644 --- a/src/cryoemservices/services/cryolo.py +++ b/src/cryoemservices/services/cryolo.py @@ -346,6 +346,18 @@ def cryolo(self, rw, header: dict, message: dict): rw.transport.ack(header) return + # Rename any flattened files + if scaled_input_path != cryolo_params.input_path: + (job_dir / f"CBOX/{Path(cryolo_params.output_path).stem}_flat.cbox").rename( + job_dir / f"CBOX/{Path(cryolo_params.output_path).stem}.cbox" + ) + (job_dir / f"STAR/{Path(cryolo_params.output_path).stem}_flat.star").rename( + job_dir / f"STAR/{Path(cryolo_params.output_path).stem}.star" + ) + (job_dir / f"EMAN/{Path(cryolo_params.output_path).stem}_flat.box").rename( + job_dir / f"EMAN/{Path(cryolo_params.output_path).stem}.box" + ) + # Read in the cbox file for particle selection and finding sizes try: cbox_file = cif.read_file( diff --git a/src/cryoemservices/services/extract_class.py b/src/cryoemservices/services/extract_class.py index ef0d789a..a59da6c0 100644 --- a/src/cryoemservices/services/extract_class.py +++ b/src/cryoemservices/services/extract_class.py @@ -3,6 +3,7 @@ import math import os import re +import subprocess from pathlib import Path from pydantic import BaseModel, Field, ValidationError @@ -30,6 +31,7 @@ class ExtractClassParameters(BaseModel): downscale: bool = True normalise: bool = True invert_contrast: bool = True + submit_to_slurm: bool = False relion_options: RelionServiceOptions @@ -254,19 +256,22 @@ def extract_class(self, rw, header: dict, message: dict): if extract_params.downscale: command.append("--downscale") - result = slurm_submission_for_services( - log=self.log, - service_config_file=self._environment["config"], - slurm_cluster=self._environment["slurm_cluster"], - job_name="ReExtract", - command=command, - project_dir=extract_job_dir, - output_file=extract_job_dir / "slurm_run", - cpus=40, - use_gpu=False, - use_singularity=False, - script_extras="module load EM/cryoem-services", - ) + if extract_params.submit_to_slurm: + result = slurm_submission_for_services( + log=self.log, + service_config_file=self._environment["config"], + slurm_cluster=self._environment["slurm_cluster"], + job_name="ReExtract", + command=command, + project_dir=extract_job_dir, + output_file=extract_job_dir / "slurm_run", + cpus=40, + use_gpu=False, + use_singularity=False, + script_extras="module load EM/cryoem-services", + ) + else: + result = subprocess.run(command, capture_output=True) # Register the Re-extraction job with the node creator self.log.info(f"Sending {self.extract_job_type} to node creator") diff --git a/src/cryoemservices/services/motioncorr.py b/src/cryoemservices/services/motioncorr.py index 2a9daba5..14ea6b34 100644 --- a/src/cryoemservices/services/motioncorr.py +++ b/src/cryoemservices/services/motioncorr.py @@ -579,7 +579,7 @@ def motion_correction(self, rw, header: dict, message: dict): ) if mc_params.do_icebreaker_jobs and not icebreaker_output.is_file(): # Three IceBreaker jobs: CtfFind job is MC+4 - ctf_job_number = 6 + ctf_job_number = job_number + 4 # Both IceBreaker micrographs and flattening inherit from motioncorr self.log.info( @@ -624,10 +624,10 @@ def motion_correction(self, rw, header: dict, message: dict): ctf_job_number = job_number + 4 else: # No IceBreaker jobs: CtfFind job is MC+1 - ctf_job_number = 3 + ctf_job_number = job_number + 1 else: # Tomography: CtfFind job is MC+1 - ctf_job_number = 3 + ctf_job_number = job_number + 1 # Forward results to ctffind (in both SPA and tomography) self.log.info(f"Sending to ctf: {mc_params.mrc_out}") diff --git a/src/cryoemservices/services/refine3d.py b/src/cryoemservices/services/refine3d.py index 45864436..ca8334a8 100644 --- a/src/cryoemservices/services/refine3d.py +++ b/src/cryoemservices/services/refine3d.py @@ -19,7 +19,7 @@ class Refine3D(CommonService): def initializing(self): """Subscribe to a queue. Received messages must be acknowledged.""" self.log.info("Refine3D service starting") - wrap_subscribe( + self.subscription_id = wrap_subscribe( self._transport, self._environment["queue"] or "refine3d", self.refine3d, @@ -34,7 +34,7 @@ def refine3d(self, rw, header: dict, message: dict): if not isinstance(message, dict): self.log.error("Rejected invalid simple message") self._transport.nack(header) - return + return False # Create a wrapper-like object that can be passed to functions # as if a recipe wrapper was present. @@ -58,7 +58,7 @@ def refine3d(self, rw, header: dict, message: dict): f"with exception: {e}" ) rw.transport.nack(header) - return + return False # In this setup we cannot nack messages on failure, so instead check here if message.get("requeue", 0) >= 5: @@ -71,6 +71,8 @@ def refine3d(self, rw, header: dict, message: dict): f"Running disconnected Refine3D job for {refine_params.particles_file}" ) rw.transport.ack(header) + rw.transport.unsubscribe(self.subscription_id) + rw.transport.drop_callback_reference(self.subscription_id) # Run the refinement job try: @@ -90,3 +92,7 @@ def refine3d(self, rw, header: dict, message: dict): # Send back to the queue but mark a failure in the message message["requeue"] = message.get("requeue", 0) + 1 rw.send_to("refine3d", message) + + # Reconnect to rabbitmq + self.initializing() + return True diff --git a/src/cryoemservices/util/ispyb_commands.py b/src/cryoemservices/util/ispyb_commands.py index 48478c9f..d8127429 100644 --- a/src/cryoemservices/util/ispyb_commands.py +++ b/src/cryoemservices/util/ispyb_commands.py @@ -7,7 +7,7 @@ import ispyb.sqlalchemy as models import sqlalchemy.exc -import sqlalchemy.orm +from sqlalchemy.orm import Session from cryoemservices.util import ispyb_buffer @@ -36,27 +36,35 @@ def parameters_with_replacement(param: str, message: dict, all_parameters: Calla return value_to_return -def multipart_message( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def multipart_message(message: dict, parameters: Callable, session: Session): """ The multipart_message command allows the recipe or client to specify a list of API calls to run. Each API call may have a return value that can be stored and passed on. """ commands = parameters("ispyb_command_list") - step = message.get("checkpoint", 0) + 1 if not commands or not isinstance(commands, list): logger.error("Received multipart message containing no command list") return False current_command = commands[0] - command = globals().get(current_command.get("ispyb_command")) + command: Callable | None = globals().get(current_command.get("ispyb_command")) if not command: logger.error( f"Multipart command {current_command} does not have a valid ispyb_command" ) return False + return run_multipart_command(message, parameters, session, command) + + +def run_multipart_command( + message: dict, parameters: Callable, session: Session, command: Callable +): + """Run specific commands from a multipart message""" + # Get the commands and steps again + commands = parameters("ispyb_command_list") + step = message.get("checkpoint", 0) + 1 + current_command = commands[0] logger.info( f"Processing step {step} of multipart message ({current_command}) " f"with {len(commands) - 1} further steps", @@ -102,7 +110,7 @@ def step_parameters(parameter): } -def buffer(message: dict, parameters: Callable, session: sqlalchemy.orm.Session): +def buffer(message: dict, parameters: Callable, session: Session): """ The buffer command supports running buffer lookups before running a command, and storing the result in a buffer after running the command. @@ -171,7 +179,7 @@ def buffer(message: dict, parameters: Callable, session: sqlalchemy.orm.Session) return result -def insert_movie(message: dict, parameters: Callable, session: sqlalchemy.orm.Session): +def insert_movie(message: dict, parameters: Callable, session: Session): try: foil_hole_id = ( parameters("foil_hole_id") if parameters("foil_hole_id") != "None" else None @@ -205,9 +213,7 @@ def insert_movie(message: dict, parameters: Callable, session: sqlalchemy.orm.Se return False -def insert_motion_correction( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def insert_motion_correction(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -265,7 +271,7 @@ def movie_parameters(p): def insert_relative_ice_thickness( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session + message: dict, parameters: Callable, session: Session ): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -292,7 +298,7 @@ def full_parameters(param): return False -def insert_ctf(message: dict, parameters: Callable, session: sqlalchemy.orm.Session): +def insert_ctf(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -330,9 +336,7 @@ def full_parameters(param): return False -def insert_particle_picker( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def insert_particle_picker(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -359,7 +363,7 @@ def full_parameters(param): def insert_particle_classification( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session + message: dict, parameters: Callable, session: Session ): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -421,7 +425,7 @@ def full_parameters(param): def insert_particle_classification_group( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session + message: dict, parameters: Callable, session: Session ): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -479,9 +483,7 @@ def full_parameters(param): return False -def insert_cryoem_initial_model( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def insert_cryoem_initial_model(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -513,9 +515,7 @@ def full_parameters(param): return False -def insert_bfactor_fit( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def insert_bfactor_fit(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -558,9 +558,7 @@ def full_parameters(param): return False -def insert_tomogram( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def insert_tomogram(message: dict, parameters: Callable, session: Session): if not message: message = {} @@ -624,9 +622,7 @@ def full_parameters(param): return False -def insert_processed_tomogram( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def insert_processed_tomogram(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -651,9 +647,7 @@ def full_parameters(param): return False -def insert_tilt_image_alignment( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def insert_tilt_image_alignment(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -706,9 +700,7 @@ def full_parameters(param): return False -def update_processing_status( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def update_processing_status(message: dict, parameters: Callable, session: Session): def full_parameters(param): return parameters_with_replacement(param, message, parameters) @@ -755,9 +747,7 @@ def full_parameters(param): # These are needed for the old relion-zocalo wrapper -def add_program_attachment( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def add_program_attachment(message: dict, parameters: Callable, session: Session): file_name = parameters("file_name") file_path = parameters("file_path") logger.error( @@ -767,9 +757,7 @@ def add_program_attachment( return {"success": True, "return_value": 0} -def register_processing( - message: dict, parameters: Callable, session: sqlalchemy.orm.Session -): +def register_processing(message: dict, parameters: Callable, session: Session): program = parameters("program") cmdline = parameters("cmdline") environment = parameters("environment") or "" diff --git a/src/cryoemservices/util/murfey_db_commands.py b/src/cryoemservices/util/murfey_db_commands.py index f6f09ac2..1f58fb42 100644 --- a/src/cryoemservices/util/murfey_db_commands.py +++ b/src/cryoemservices/util/murfey_db_commands.py @@ -14,6 +14,29 @@ logger.setLevel(logging.INFO) +def multipart_message(message: dict, parameters: Callable, session: Session): + """ + Override of the multipart message command, + which uses commands from this file rather than ispyb commands + """ + commands = parameters("ispyb_command_list") + if not commands or not isinstance(commands, list): + logger.error("Received multipart message containing no command list") + return False + + current_command = commands[0] + ispyb_command = current_command.get("ispyb_command") + command: Callable | None = globals().get(ispyb_command) or getattr( + ispyb_commands, ispyb_command, None + ) + if not command: + logger.error( + f"Multipart command {current_command} does not have a valid ispyb_command" + ) + return False + return ispyb_commands.run_multipart_command(message, parameters, session, command) + + def buffer(message: dict, parameters: Callable, session: Session): """ Override of the buffer command,