diff --git a/Helm/charts/tomo_align/templates/deployment.yaml b/Helm/charts/tomo_align/templates/deployment.yaml index 88d9028d..55ab951d 100644 --- a/Helm/charts/tomo_align/templates/deployment.yaml +++ b/Helm/charts/tomo_align/templates/deployment.yaml @@ -33,6 +33,9 @@ spec: - -c - >- {{ .Values.command }} + env: + - name: HOME + value: "/tmp" volumeMounts: - name: config-file mountPath: /cryoemservices/config diff --git a/Helm/charts/tomo_align_slurm/templates/deployment.yaml b/Helm/charts/tomo_align_slurm/templates/deployment.yaml index 6f6d5ea4..4d2c76de 100644 --- a/Helm/charts/tomo_align_slurm/templates/deployment.yaml +++ b/Helm/charts/tomo_align_slurm/templates/deployment.yaml @@ -32,6 +32,8 @@ spec: - >- {{ .Values.command }} env: + - name: TILT_DENOISING_SIF + value: "{{ .Values.tiltDenoisingSIF }}" - name: ARETOMO2_EXECUTABLE value: "{{ .Values.aretomo2Executable }}" - name: ARETOMO3_EXECUTABLE diff --git a/recipes/ispyb/em-tomo-align.json b/recipes/ispyb/em-tomo-align.json index c5708948..57b6a0bd 100644 --- a/recipes/ispyb/em-tomo-align.json +++ b/recipes/ispyb/em-tomo-align.json @@ -23,9 +23,11 @@ "node_creator": 12, "projxy": 6, "projxz": 6, - "success": 7 + "success": 7, + "tomo_align_denoise": 13 }, "parameters": { + "denoise_tilts": 1, "dose_per_frame": "{dose_per_frame}", "frame_count": "{frame_count}", "input_file_list": "{input_file_list}", @@ -127,5 +129,22 @@ "queue": "node_creator", "service": "NodeCreator" }, + "13": { + "parameters": { + "denoise_tilts": 2, + "dose_per_frame": "{dose_per_frame}", + "frame_count": "{frame_count}", + "input_file_list": "{input_file_list}", + "kv": "{kv}", + "manual_tilt_offset": "{manual_tilt_offset}", + "path_pattern": "{path_pattern}", + "pixel_size": "{pixel_size}", + "relion_options": {}, + "stack_file": "{stack_file}", + "tilt_axis": "{tilt_axis}" + }, + "queue": "tomo_align", + "service": "TomoAlign" + }, "start": [[1, []]] } diff --git a/src/cryoemservices/services/tomo_align.py b/src/cryoemservices/services/tomo_align.py index 45a875f3..b7348ad0 100644 --- a/src/cryoemservices/services/tomo_align.py +++ b/src/cryoemservices/services/tomo_align.py @@ -131,6 +131,7 @@ class TomoParameters(BaseModel): interpolation_correction: Optional[int] = None dark_tol: Optional[float] = None manual_tilt_offset: Optional[float] = None + denoise_tilts: int = 0 visits_for_slurm: Optional[list] = ["bi", "cm", "nr", "nt"] relion_options: RelionServiceOptions @@ -222,6 +223,38 @@ def parse_tomo_output(self, tomo_stdout: str): if line.startswith("Best tilt axis"): self.alignment_quality = float(line.split()[5]) + def get_denoised_tilt_name(self, tilt: str) -> str: + denoised_tilt = "/" + "/".join( + "spool" if p == "processed" else p for p in Path(tilt).parts[1:] + ) + denoised_tilt = str( + Path(denoised_tilt).parent / (Path(denoised_tilt).stem + "_denoised.mrc") + ) + Path(denoised_tilt).parent.mkdir(parents=True, exist_ok=True) + return denoised_tilt + + def run_tilt_denoising(self, tilt_list: list[str]) -> bool: + for tilt in tilt_list: + denoised_tilt = self.get_denoised_tilt_name(tilt) + denoise_result = subprocess.run( + [ + "python", + "run_denoiser.py", + "--nimage", + str(tilt), + "--dimage", + str(denoised_tilt), + ], + capture_output=True, + ) + if denoise_result.returncode: + self.log.error(f"Failed to denoise tilt {tilt}") + self.log.error( + f"Denoise reason: {denoise_result.stdout.decode('utf8')} {denoise_result.stderr.decode('utf8')}" + ) + return False + return True + def extract_from_aln(self, tomo_parameters, alignment_output_dir, plot_path): tomo_aln_file = None self.thickness_pixels = None @@ -385,6 +418,30 @@ def _tilt(file_list_for_tilts): for index in sorted(tilts_to_remove, reverse=True): self.input_file_list_of_lists.remove(self.input_file_list_of_lists[index]) + # Decide whether to denoise + if tomo_params.denoise_tilts == 1: + self.log.info("Sending to tilt denoising and alignment re-run") + rw.send_to("tomo_align_denoise", {"denoise_tilts": 2}) + elif tomo_params.denoise_tilts == 2: + self.log.info("Running tilt denoising") + new_input_list_of_lists = [] + tilts_to_denoise = [] + for tname, tangle in self.input_file_list_of_lists: + denoised_tilt = self.get_denoised_tilt_name(tname) + tilts_to_denoise.append(tname) + new_input_list_of_lists.append([denoised_tilt, tangle]) + denoise_success = self.run_tilt_denoising(tilts_to_denoise) + if not denoise_success: + self.log.error("Failed to denoise tilts") + rw.transport.nack(header) + return + self.input_file_list_of_lists = new_input_list_of_lists + tomo_params.stack_file = "/" + "/".join( + "spool" if p == "processed" else p + for p in Path(tomo_params.stack_file).parts[1:] + ) + Path(tomo_params.stack_file).parent.mkdir(parents=True, exist_ok=True) + # Find the input image dimensions with mrcfile.open(self.input_file_list_of_lists[0][0]) as mrc: mrc_header = mrc.header @@ -589,6 +646,10 @@ def _tilt(file_list_for_tilts): } ] + # Write the score somewhere + with open(aretomo_output_path.with_suffix(".com"), "a") as comfile: + comfile.write(f"\n\nAlignment quality {self.alignment_quality}") + # Find the indexes of the dark images removed by AreTomo missing_indices = [] dark_images_file = Path(stack_name + "_DarkImgs.txt") diff --git a/src/cryoemservices/services/tomo_align_slurm.py b/src/cryoemservices/services/tomo_align_slurm.py index 4b8a662a..dd7c0d4e 100644 --- a/src/cryoemservices/services/tomo_align_slurm.py +++ b/src/cryoemservices/services/tomo_align_slurm.py @@ -11,7 +11,11 @@ import requests from cryoemservices.services.tomo_align import TomoAlign, TomoParameters -from cryoemservices.util.slurm_submission import slurm_submission_for_services +from cryoemservices.util.slurm_submission import ( + config_from_file, + slurm_submission_for_services, + wait_for_job_completion, +) def retrieve_files( @@ -104,6 +108,8 @@ def check_visit(tomo_params: TomoParameters): visit_search = re.search( "/[a-z]{2}[0-9]{5}-[0-9]{1,3}/", tomo_params.stack_file ) + if tomo_params.denoise_tilts == 2: + return False if visit_search: visit_name = visit_search[0][1:-1] visit_code = visit_name[:2] @@ -126,6 +132,73 @@ def parse_tomo_output_file(self, tomo_output_file: Path): self.alignment_quality = float(line.split()[5]) tomo_file.close() + def run_tilt_denoising(self, tilt_list: list[str]) -> bool: + transfer_status = transfer_files([Path(tilt) for tilt in tilt_list]) + if len(transfer_status) != len(tilt_list): + self.log.error( + f"Unable to transfer files: desired {tilt_list}, done {transfer_status}" + ) + return False + self.log.info("All files transferred") + + job_ids = [] + final_tilts = [] + for tilt in tilt_list: + denoised_tilt = self.get_denoised_tilt_name(tilt) + command = [ + "python", + "/install/denoiser/run_denoiser.py", + "--nimage", + str(tilt), + "--dimage", + str(denoised_tilt), + ] + tilt_job_id = slurm_submission_for_services( + log=self.log, + service_config_file=self._environment["config"], + slurm_cluster=self._environment["slurm_cluster"], + job_name="TiltDenoise", + command=command, + project_dir=Path(denoised_tilt).parent, + output_file=Path(denoised_tilt), + cpus=1, + use_gpu=True, + use_singularity=True, + cif_name=os.environ["TILT_DENOISING_SIF"], + external_filesystem=True, + wait_for_completion=False, + ) + job_ids.append(tilt_job_id.returncode) + final_tilts.append(denoised_tilt) + + service_config = config_from_file(self._environment["config"]) + self.log.info("Waiting for completion and retrieval of output files...") + for tid, job_id in enumerate(job_ids): + job_state = wait_for_job_completion( + job_id=job_id, + logger=self.log, + service_config=service_config, + cluster_name=self._environment["slurm_cluster"], + ) + retrieve_files( + job_directory=Path(final_tilts[tid]).parent, + files_to_skip=[Path(tilt_list[tid])], + basepath=str(Path(tilt_list[tid]).stem), + ) + if job_state != "COMPLETED": + self.log.error(f"Job {job_id} failed with {job_state}") + self.log.info("All denoising jobs finished and output files retrieved") + + for out_tilt in final_tilts: + if not Path(out_tilt).is_file(): + self.log.info(f"Tilt denoising failed for {out_tilt}") + if Path(out_tilt).with_suffix(".err").is_file(): + with open(Path(out_tilt).with_suffix(".err"), "r") as slurm_stderr: + stderr = slurm_stderr.read() + self.log.error(stderr) + return False + return True + def aretomo( self, tomo_parameters: TomoParameters, diff --git a/src/cryoemservices/util/slurm_submission.py b/src/cryoemservices/util/slurm_submission.py index 4ddcd003..f0a5c65b 100644 --- a/src/cryoemservices/util/slurm_submission.py +++ b/src/cryoemservices/util/slurm_submission.py @@ -274,6 +274,7 @@ def slurm_submission_for_services( memory_request: int = 12000, external_filesystem: bool = False, extra_singularity_directories: Optional[list[str]] = None, + wait_for_completion: bool = True, ) -> subprocess.CompletedProcess: """Submit jobs to a slurm cluster via the RestAPI""" # Load the service config with slurm credentials @@ -371,6 +372,8 @@ def slurm_submission_for_services( # Get the status of the submitted job from the restAPI log.info(f"Submitted job {job_id} for {job_name} to slurm. Waiting...") + if not wait_for_completion: + return subprocess.CompletedProcess(args="", returncode=job_id) slurm_job_state = wait_for_job_completion( job_id=job_id, logger=log, diff --git a/tests/services/test_tomo_align.py b/tests/services/test_tomo_align.py index 5d8a52e7..fe595e79 100644 --- a/tests/services/test_tomo_align.py +++ b/tests/services/test_tomo_align.py @@ -2102,6 +2102,25 @@ def test_parse_tomo_align_output(offline_transport): assert service.alignment_quality == 0.07568 +@mock.patch("cryoemservices.services.tomo_align.subprocess.run") +def test_run_tilt_denoising(mock_subprocess, tmp_path): + mock_subprocess().returncode = 0 + + tilt_in = f"{tmp_path}/processed/relion_murfey/MotionCorr/job002/Movies/tilt.mrc" + tilt_out = ( + f"{tmp_path}/spool/relion_murfey/MotionCorr/job002/Movies/tilt_denoised.mrc" + ) + + denoised_tilt = tomo_align.run_tilt_denoising(tilt_in) + + assert denoised_tilt == tilt_out + + mock_subprocess.assert_called_with( + ["python", "run_denoiser.py", "--nimage", tilt_in, "--dimage", tilt_out], + capture_output=True, + ) + + def test_resize_tomogram(tmp_path): """Test the reshaping of a XZY tomogram""" with mrcfile.new(tmp_path / "test.mrc") as mrc: