diff --git a/pyproject.toml b/pyproject.toml index b9d16b85..00448bfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ torch = [ EMISPyB = "cryoemservices.services.ispyb_connector:EMISPyB" Extract = "cryoemservices.services.extract:Extract" ExtractClass = "cryoemservices.services.extract_class:ExtractClass" + ExtractSubTomo = "cryoemservices.services.extract_2d_subtomo:ExtractSubTomoFor2D" IceBreaker = "cryoemservices.services.icebreaker:IceBreaker" Images = "cryoemservices.services.images:Images" MembrainSeg = "cryoemservices.services.membrain_seg:MembrainSeg" diff --git a/src/cryoemservices/pipeliner_plugins/combine_star_files.py b/src/cryoemservices/pipeliner_plugins/combine_star_files.py index 0b62d4c9..c9593b83 100644 --- a/src/cryoemservices/pipeliner_plugins/combine_star_files.py +++ b/src/cryoemservices/pipeliner_plugins/combine_star_files.py @@ -44,7 +44,7 @@ def combine_star_files(files_to_process: List[Path], output_dir: Path): ): for line_counter in range(50): line = full_starfile.readline() - if line.startswith("opticsGroup"): + if "opticsGroup" in line: reference_optics = line.split() if not line: break @@ -71,7 +71,7 @@ def combine_star_files(files_to_process: List[Path], output_dir: Path): optics_line = added_starfile.readline() if not optics_line: raise IndexError(f"Cannot find optics group in {split_file}") - if optics_line.startswith("opticsGroup"): + if "opticsGroup" in optics_line: new_optics = optics_line.split() break @@ -104,8 +104,14 @@ def combine_star_files(files_to_process: List[Path], output_dir: Path): particle_line = added_starfile.readline() if not particle_line: break + if "opticsGroup" in particle_line: + # Skip the optics group + continue + if particle_line.startswith(("_", "loop_", "data_", "#")): + # Skip all block and loop header lines + continue particle_split_line = particle_line.split() - if len(particle_split_line) > 0 and particle_split_line[0][0].isdigit(): + if len(particle_split_line) > 0: file_particles_count += 1 total_particles += 1 particles_file.write(particle_line) diff --git a/src/cryoemservices/services/cryolo.py b/src/cryoemservices/services/cryolo.py index e90e32eb..7a70219b 100644 --- a/src/cryoemservices/services/cryolo.py +++ b/src/cryoemservices/services/cryolo.py @@ -23,7 +23,7 @@ class CryoloParameters(BaseModel): input_path: str = Field(..., min_length=1) output_path: str = Field(..., min_length=1) experiment_type: str - pixel_size: Optional[float] = None + pixel_size: float cryolo_box_size: int = 160 cryolo_model_weights: str = "gmodel_phosnet_202005_N63_c17.h5" cryolo_threshold: float = 0.3 @@ -40,6 +40,7 @@ class CryoloParameters(BaseModel): mc_uuid: Optional[int] = None app_id: Optional[int] = None picker_uuid: Optional[int] = None + raw_tomogram: Optional[str] = None relion_options: RelionServiceOptions ctf_values: dict = {} @@ -53,13 +54,12 @@ def is_spa_or_tomo(cls, experiment): @model_validator(mode="after") def check_spa_has_uuids_and_pixel_size(self): if self.experiment_type == "spa" and ( - self.mc_uuid is None or self.picker_uuid is None or not self.pixel_size + self.mc_uuid is None or self.picker_uuid is None ): raise ValueError( "In SPA mode the following must be provided: " f"mc_uuid (given {self.mc_uuid}), " - f"picker_uuid (given {self.picker_uuid}), " - f"pixel_size (given {self.pixel_size})" + f"picker_uuid (given {self.picker_uuid})" ) return self @@ -329,6 +329,29 @@ def cryolo(self, rw, header: dict, message: dict): } rw.send_to("ispyb_connector", ispyb_parameters_tomo) + # Get the diameters of the particles in Angstroms for Murfey + try: + cbox_file = cif.read_file(cryolo_params.output_path) + cbox_block = cbox_file.find_block("cryolo") + cbox_sizes = ( + np.array(cbox_block.find_loop("_EstWidth"), dtype=float) + + np.array(cbox_block.find_loop("_EstHeight"), dtype=float) + ) / 2 + cryolo_particle_sizes = cbox_sizes * cryolo_params.pixel_size + except (FileNotFoundError, OSError, AttributeError): + cryolo_particle_sizes = [] + + # Send to murfey for extraction + extraction_parameters = { + "register": "picked_tomogram", + "tomogram": cryolo_params.raw_tomogram or cryolo_params.input_path, + "cbox_3d": cryolo_params.output_path, + "pixel_size": cryolo_params.pixel_size, + "particle_diameters": list(cryolo_particle_sizes), + "particle_count": len(cryolo_particle_sizes), + } + rw.send_to("murfey_feedback", extraction_parameters) + self.log.info( f"Done {self.job_type} {cryolo_params.experiment_type} " f"for {cryolo_params.input_path}." @@ -345,14 +368,11 @@ def cryolo(self, rw, header: dict, message: dict): ) ) cbox_block = cbox_file.find_block("cryolo") - cbox_sizes = np.append( - np.array(cbox_block.find_loop("_EstWidth"), dtype=float), - np.array(cbox_block.find_loop("_EstHeight"), dtype=float), - ) - cbox_confidence = np.append( - np.array(cbox_block.find_loop("_Confidence"), dtype=float), - np.array(cbox_block.find_loop("_Confidence"), dtype=float), - ) + cbox_sizes = ( + np.array(cbox_block.find_loop("_EstWidth"), dtype=float) + + np.array(cbox_block.find_loop("_EstHeight"), dtype=float) + ) / 2 + cbox_confidence = np.array(cbox_block.find_loop("_Confidence"), dtype=float) # Select only a fraction of particles based on confidence if requested if ( diff --git a/src/cryoemservices/services/ctffind.py b/src/cryoemservices/services/ctffind.py index 224e77f8..5bbfd7a9 100644 --- a/src/cryoemservices/services/ctffind.py +++ b/src/cryoemservices/services/ctffind.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Any, Optional +import mrcfile +import numpy as np from pydantic import BaseModel, Field, ValidationError, field_validator from workflows.recipe import wrap_subscribe @@ -352,3 +354,60 @@ def ctf_find(self, rw, header: dict, message: dict): self.log.info(f"Done {self.job_type} for {ctf_params.input_image}.") rw.transport.ack(header) + + +def ctf( + num_pixels_x: int, num_pixels_y: int, pixel_size: float, defocus: float +) -> np.array: + C_s: float = 2.7e7 + wavelength: float = 0.0196 + grid = np.meshgrid( + np.fft.fftfreq(num_pixels_x, pixel_size), + np.fft.fftfreq(num_pixels_y, pixel_size), + ) + grid = np.fft.fftshift(grid) + ksq = grid[0] ** 2 + grid[1] ** 2 + w = -defocus * wavelength * ksq / 2 + C_s * wavelength**3 + ksq**2 / 4 + return np.sin(2 * np.pi * w) + + +def ctf_micrograph(mic_file, output_file, defocus_u, defocus_v): + with mrcfile.open(mic_file) as mrc: + im = mrc.data + num_pixels_x = int(mrc.header.mx) + num_pixels_y = int(mrc.header.my) + pixel_size = mrc.header.cella.x / mrc.header.mx + dtype = im.dtype + + fft = np.fft.fft2(im) + fft = np.fft.fftshift(fft) + + ctf_mask = ctf( + num_pixels_x, num_pixels_y, pixel_size, np.sqrt(defocus_u**2 + defocus_v**2) + ) + ctf_mask[ctf_mask < 0] = -1 + ctf_mask[ctf_mask >= 0] = 1 + + fft = fft * ctf_mask + + fft_shifted = np.fft.ifftshift(fft) + corrected_im = np.real(np.fft.ifft2(fft_shifted)) + corrected_im = corrected_im.astype(dtype) + mrcfile.write(output_file, corrected_im, overwrite=True) + + +def ctf_of_tomo_star(root_dir, tomo_star, output_dir): + from gemmi import cif + + data = cif.read_file(str(Path(root_dir) / tomo_star)) + mics = list(data.sole_block().find_loop("_rlnMicrographName")) + defocus_u = list(data.sole_block().find_loop("_rlnDefocusU")) + defocus_v = list(data.sole_block().find_loop("_rlnDefocusV")) + + for i, mic in enumerate(mics): + ctf_micrograph( + Path(root_dir) / mic, + Path(root_dir) / output_dir / Path(mic).name, + float(defocus_u[i]), + float(defocus_v[i]), + ) diff --git a/src/cryoemservices/services/denoise.py b/src/cryoemservices/services/denoise.py index 60d8a0ee..55ebb7ec 100644 --- a/src/cryoemservices/services/denoise.py +++ b/src/cryoemservices/services/denoise.py @@ -297,10 +297,12 @@ def denoise(self, rw, header: dict, message: dict): "relion_options": dict(denoise_params.relion_options), } cryolo_parameters = { + "raw_tomogram": denoise_params.volume, "input_path": str(denoised_full_path), "output_path": str(cryolo_dir / f"CBOX_3D/{denoised_full_path.stem}.cbox"), "experiment_type": "tomography", "cryolo_box_size": 40, + "pixel_size": str(denoise_params.relion_options.pixel_size_downscaled), "relion_options": dict(denoise_params.relion_options), } rw.send_to("segmentation", segmentation_parameters) diff --git a/src/cryoemservices/services/extract.py b/src/cryoemservices/services/extract.py index 00ab828b..ceb1621e 100644 --- a/src/cryoemservices/services/extract.py +++ b/src/cryoemservices/services/extract.py @@ -134,6 +134,13 @@ def extract(self, rw, header: dict, message: dict): ) ) + if extract_params.downscale: + extract_params.relion_options.pixel_size_downscaled = ( + extract_params.pixel_size + * extract_params.relion_options.boxsize + / extract_params.relion_options.small_boxsize + ) + # Find the locations of the particles cbox_name = Path( extract_params.coord_list_file.replace("STAR", "CBOX") @@ -223,139 +230,39 @@ def extract(self, rw, header: dict, message: dict): # Pixel locations are from bottom left, need to flip the image later pixel_location_x = round(float(particles_x[particle])) pixel_location_y = round(float(particles_y[particle])) - - # Extract the particle image and pad the edges if it is not square - x_left_pad = 0 - x_right_pad = 0 - y_top_pad = 0 - y_bot_pad = 0 - extract_width = round(extract_params.relion_options.boxsize / 2) - x_left = pixel_location_x - extract_width - if x_left < 0: - x_left_pad = -x_left - x_left = 0 - x_right = pixel_location_x + extract_width - if x_right >= image_size[1]: - x_right_pad = x_right - image_size[1] - x_right = image_size[1] - y_top = pixel_location_y - extract_width - if y_top < 0: - y_top_pad = -y_top - y_top = 0 - y_bot = pixel_location_y + extract_width - if y_bot >= image_size[0]: - y_bot_pad = y_bot - image_size[0] - y_bot = image_size[0] - - particle_subimage = input_micrograph_image[y_top:y_bot, x_left:x_right] - particle_subimage = np.pad( - particle_subimage, - ((y_bot_pad, y_top_pad), (x_left_pad, x_right_pad)), - mode="edge", - ) - - # Flip all the values on inversion - if extract_params.invert_contrast: - particle_subimage = -1 * particle_subimage - - # Downscale the image size - if extract_params.downscale: - extract_params.relion_options.pixel_size_downscaled = ( - extract_params.pixel_size - * extract_params.relion_options.boxsize - / extract_params.relion_options.small_boxsize - ) - subimage_ft = np.fft.fftshift(np.fft.fft2(particle_subimage)) - deltax = ( - subimage_ft.shape[0] - extract_params.relion_options.small_boxsize - ) - deltay = ( - subimage_ft.shape[1] - extract_params.relion_options.small_boxsize - ) - particle_subimage = np.real( - np.fft.ifft2( - np.fft.ifftshift( - subimage_ft[ - deltax // 2 : subimage_ft.shape[0] - deltax // 2, - deltay // 2 : subimage_ft.shape[1] - deltay // 2, - ] - ) - ) - ) - extract_width = round(extract_params.relion_options.small_boxsize / 2) - - # Distance of each pixel from the centre, compared to background radius - grid_indexes = np.meshgrid( - np.arange(2 * extract_width), - np.arange(2 * extract_width), - ) - distance_from_centre = np.sqrt( - (grid_indexes[0] - extract_width + 0.5) ** 2 - + (grid_indexes[1] - extract_width + 0.5) ** 2 - ) - bg_region = ( - distance_from_centre - > np.ones(np.shape(particle_subimage)) * extract_params.bg_radius + full_particle_subimage, failure_reason = extract_single_particle( + input_image=input_micrograph_image, + x_coord=pixel_location_x, + y_coord=pixel_location_y, + extract_width=extract_width, + shape=[image_size[1], image_size[0]], ) - - # Fit background to a plane and subtract the plane from the image - positions = [grid_indexes[0][bg_region], grid_indexes[1][bg_region]] - # needs to create a matrix of the correct shape for a*x + b*y + c plane fit - if not len(positions[0]) == len(positions[1]): + if failure_reason: self.log.warning( - f"Particle image {particle} in {extract_params.micrographs_file} is not square" + f"Extraction failed for {particle} " + f"in {extract_params.micrographs_file}. " + f"Reason was {failure_reason}." ) continue - data_size = len(positions[0]) - positions_matrix = np.hstack( - ( - np.reshape(positions[0], (data_size, 1)), - np.reshape(positions[1], (data_size, 1)), - ) + particle_subimage, failure_reason = enhance_single_particle( + particle_subimage=full_particle_subimage, + extract_width=extract_width, + small_boxsize=extract_params.relion_options.small_boxsize, + bg_radius=extract_params.bg_radius, + invert_contrast=extract_params.invert_contrast, + downscale=extract_params.downscale, + norm=extract_params.norm, + plane_fit=True, ) - # this ones for c - positions_matrix = np.hstack((np.ones((data_size, 1)), positions_matrix)) - values = particle_subimage[bg_region] - # normal equation - try: - theta = np.dot( - np.dot( - np.linalg.inv( - np.dot(positions_matrix.transpose(), positions_matrix) - ), - positions_matrix.transpose(), - ), - values, - ) - except np.linalg.LinAlgError: + if failure_reason: self.log.warning( - f"Could not fit image plane for particle {particle} in {extract_params.micrographs_file}" + f"Extraction failed for {particle} " + f"in {extract_params.micrographs_file}. " + f"Reason was {failure_reason}." ) continue - # now we need the full grid across the image - positions_matrix = np.hstack( - ( - np.reshape(grid_indexes[0], (4 * extract_width**2, 1)), - np.reshape(grid_indexes[1], (4 * extract_width**2, 1)), - ) - ) - positions_matrix = np.hstack( - (np.ones((4 * extract_width**2, 1)), positions_matrix) - ) - plane = np.reshape( - np.dot(positions_matrix, theta), (2 * extract_width, 2 * extract_width) - ) - - particle_subimage -= plane - - # Background normalisation - if extract_params.norm: - # Standardise the values using the background - bg_mean = np.mean(particle_subimage[bg_region]) - bg_std = np.std(particle_subimage[bg_region]) - particle_subimage = (particle_subimage - bg_mean) / bg_std # Add to output stack if len(output_mrc_stack): @@ -413,3 +320,150 @@ def extract(self, rw, header: dict, message: dict): self.log.info(f"Done {self.job_type} for {extract_params.coord_list_file}.") rw.transport.ack(header) + + +def extract_single_particle( + input_image: np.ndarray, + x_coord: float, + y_coord: float, + extract_width: float, + shape: list[int], +) -> tuple[np.ndarray, str]: + """ + A function which can extract a single particle in a micrograph + """ + # Extract the particle image and pad the edges if it is not square + x_left_pad = 0 + x_right_pad = 0 + y_top_pad = 0 + y_bot_pad = 0 + + x_left = round(x_coord - extract_width) + if x_left < 0: + x_left_pad = -x_left + x_left = 0 + x_right = round(x_coord + extract_width) + if x_right >= shape[0]: + x_right_pad = x_right - shape[0] + x_right = shape[0] + y_top = round(y_coord - extract_width) + if y_top < 0: + y_top_pad = -y_top + y_top = 0 + y_bot = round(y_coord + extract_width) + if y_bot >= shape[1]: + y_bot_pad = y_bot - shape[1] + y_bot = shape[1] + + if y_bot <= y_top or x_left >= x_right: + return np.array([]), "Particle is outside image" + else: + particle_subimage = input_image[y_top:y_bot, x_left:x_right] + particle_subimage = np.pad( + particle_subimage, + ((y_bot_pad, y_top_pad), (x_left_pad, x_right_pad)), + mode="edge", + ) + return particle_subimage, "" + + +def enhance_single_particle( + particle_subimage: np.ndarray, + extract_width: float, + small_boxsize: int, + bg_radius: float, + invert_contrast: bool = True, + downscale: bool = True, + norm: bool = True, + plane_fit: bool = True, +): + """ + A function which runs enhancement on an extracted particle in a micrograph + or a flattened particle from a tomogram volume + """ + # Flip all the values on inversion + if invert_contrast: + particle_subimage = -1 * particle_subimage + + if downscale: + # Downscale the image size + subimage_ft = np.fft.fftshift(np.fft.fft2(particle_subimage)) + deltax = subimage_ft.shape[0] - small_boxsize + deltay = subimage_ft.shape[1] - small_boxsize + particle_subimage = np.real( + np.fft.ifft2( + np.fft.ifftshift( + subimage_ft[ + deltax // 2 : subimage_ft.shape[0] - deltax // 2, + deltay // 2 : subimage_ft.shape[1] - deltay // 2, + ] + ) + ) + ) + extract_width = round(small_boxsize / 2) + + # Distance of each pixel from the centre for background normalization + grid_indexes = np.meshgrid( + np.arange(2 * extract_width), + np.arange(2 * extract_width), + ) + distance_from_centre = np.sqrt( + (grid_indexes[0] - extract_width + 0.5) ** 2 + + (grid_indexes[1] - extract_width + 0.5) ** 2 + ) + bg_region = distance_from_centre > np.ones(np.shape(particle_subimage)) * bg_radius + + # Fitting of plane to the ice + if plane_fit: + # Fit background to a plane and subtract the plane from the image + positions = [grid_indexes[0][bg_region], grid_indexes[1][bg_region]] + # needs to create a matrix of the correct shape for a*x + b*y + c plane fit + if not len(positions[0]) == len(positions[1]): + return np.array([]), "Particle image is not square" + data_size = len(positions[0]) + positions_matrix = np.hstack( + ( + np.reshape(positions[0], (data_size, 1)), + np.reshape(positions[1], (data_size, 1)), + ) + ) + # this ones for c + positions_matrix = np.hstack((np.ones((data_size, 1)), positions_matrix)) + values = particle_subimage[bg_region] + # normal equation + try: + theta = np.dot( + np.dot( + np.linalg.inv( + np.dot(positions_matrix.transpose(), positions_matrix) + ), + positions_matrix.transpose(), + ), + values, + ) + except np.linalg.LinAlgError: + return np.array([]), "Could not fit image plane" + # now we need the full grid across the image + positions_matrix = np.hstack( + ( + np.reshape(grid_indexes[0], (4 * extract_width**2, 1)), + np.reshape(grid_indexes[1], (4 * extract_width**2, 1)), + ) + ) + positions_matrix = np.hstack( + (np.ones((4 * extract_width**2, 1)), positions_matrix) + ) + plane = np.reshape( + np.dot(positions_matrix, theta), (2 * extract_width, 2 * extract_width) + ) + + particle_subimage -= plane + + # Background normalisation + if norm: + # Standardise the values using the background + bg_mean = np.mean(particle_subimage[bg_region]) + bg_std = np.std(particle_subimage[bg_region]) + particle_subimage = (particle_subimage - bg_mean) / bg_std + + return particle_subimage, "" diff --git a/src/cryoemservices/services/extract_2d_subtomo.py b/src/cryoemservices/services/extract_2d_subtomo.py new file mode 100644 index 00000000..dc823ff2 --- /dev/null +++ b/src/cryoemservices/services/extract_2d_subtomo.py @@ -0,0 +1,290 @@ +from pathlib import Path + +import mrcfile +import numpy as np +from gemmi import cif +from pydantic import BaseModel, Field, ValidationError +from tqdm import tqdm +from workflows.recipe import wrap_subscribe + +from cryoemservices.services.common_service import CommonService +from cryoemservices.services.extract import enhance_single_particle +from cryoemservices.util.models import MockRW +from cryoemservices.util.relion_service_options import ( + RelionServiceOptions, + update_relion_options, +) +from cryoemservices.util.tomo_output_files import _get_tilt_name_v5_12 + + +class ExtractSubTomoParameters2D(BaseModel): + cbox_3d_file: str = Field(..., min_length=1) + tomogram: str = Field(..., min_length=1) + output_star: str = Field(..., min_length=1) + pixel_size: float + particle_diameter: float = 0 + boxsize: int = 256 + batch_size: int = 5000 + relion_options: RelionServiceOptions + + +class ExtractSubTomoFor2D(CommonService): + """ + A service for extracting 2D particles from cryolo autopicking for tomograms + This extracts the particle in 3D, projects it to 2D + and then processes it ready for SPA-like 2D classification + """ + + # Human readable service name + _service_name = "ExtractSubTomo" + + # Logger name + _logger_name = "cryoemservices.services.extract_subtomo" + + # Job name + job_type = "relion.pseudosubtomo" + + def initializing(self): + """Subscribe to a queue. Received messages must be acknowledged.""" + self.log.info("Sub-tomogram extraction service starting") + wrap_subscribe( + self._transport, + self._environment["queue"] or "extract_subtomo", + self.extract_subtomo_for_2d, + acknowledgement=True, + allow_non_recipe_messages=True, + ) + + def extract_subtomo_for_2d(self, rw, header: dict, message: dict): + """Main function which interprets and processes received messages""" + if not rw: + self.log.info("Received a simple message") + if not isinstance(message, dict): + self.log.error("Rejected invalid simple message") + self._transport.nack(header) + return + + # Create a wrapper-like object that can be passed to functions + # as if a recipe wrapper was present. + rw = MockRW(self._transport) + rw.recipe_step = {"parameters": message} + + try: + if isinstance(message, dict): + extract_subtomo_params = ExtractSubTomoParameters2D( + **{**rw.recipe_step.get("parameters", {}), **message} + ) + else: + extract_subtomo_params = ExtractSubTomoParameters2D( + **{**rw.recipe_step.get("parameters", {})} + ) + except (ValidationError, TypeError) as e: + self.log.warning( + f"Extraction parameter validation failed for message: {message} " + f"and recipe parameters: {rw.recipe_step.get('parameters', {})} " + f"with exception: {e}" + ) + rw.transport.nack(header) + return + + self.log.info( + f"Inputs: {extract_subtomo_params.tomogram}, " + f"{extract_subtomo_params.cbox_3d_file} " + f"Output: {extract_subtomo_params.output_star}" + ) + + # Update the relion options and get box sizes + extract_subtomo_params.relion_options = update_relion_options( + extract_subtomo_params.relion_options, dict(extract_subtomo_params) + ) + if extract_subtomo_params.particle_diameter: + extract_subtomo_params.boxsize = ( + extract_subtomo_params.relion_options.boxsize + ) + + # Make sure the output directory exists + if not Path(extract_subtomo_params.output_star).parent.exists(): + Path(extract_subtomo_params.output_star).parent.mkdir(parents=True) + + # Find the locations of the particles + coords_file = cif.read(extract_subtomo_params.cbox_3d_file) + coords_block = coords_file.find_block("cryolo") + pick_radius = float(coords_block.find_loop("_Width")[0]) / 2 + particles_x = ( + np.array(coords_block.find_loop("_CoordinateX"), dtype=float) + pick_radius + ) + particles_y = ( + np.array(coords_block.find_loop("_CoordinateY"), dtype=float) + pick_radius + ) + particles_z = np.array(coords_block.find_loop("_CoordinateZ"), dtype=float) + + # Read tomogram + with mrcfile.open(extract_subtomo_params.tomogram) as mrc: + tomogram_data = mrc.data + + # Extract at the same downscaling as during tomogram reconstruction + extract_width = round(extract_subtomo_params.relion_options.boxsize / 2) + + output_mrc_stack = np.array([]) + for particle in tqdm(range(len(particles_x))): + if ( + particles_y[particle] - extract_width < 0 + or particles_y[particle] + extract_width > tomogram_data.shape[1] + or particles_x[particle] - extract_width < 0 + or particles_x[particle] + extract_width > tomogram_data.shape[2] + ): + self.log.info( + f"Skipping particle {particle} runs over the edge of the volume" + ) + continue + + min_z = particles_z[particle] - extract_width + max_z = particles_z[particle] + extract_width + if min_z < 0: + min_z = 0 + if max_z >= tomogram_data.shape[0]: + max_z = tomogram_data.shape[0] - 1 + extract_vol = tomogram_data[ + round(float(min_z)) : round(float(max_z)), + round(float(particles_y[particle] - extract_width)) : round( + float(particles_y[particle] + extract_width) + ), + round(float(particles_x[particle] - extract_width)) : round( + float(particles_x[particle] + extract_width) + ), + ] + + # Run projection along x axis + flat_particle = extract_vol.mean(axis=0) + particle_subimage, failure_reason = enhance_single_particle( + particle_subimage=flat_particle, + extract_width=extract_width, + small_boxsize=extract_subtomo_params.boxsize, + bg_radius=round(0.375 * extract_subtomo_params.boxsize), + invert_contrast=True, + downscale=False, + norm=True, + plane_fit=True, + ) + if failure_reason: + self.log.warning( + f"Extraction failed for {particle}. Reason was {failure_reason}." + ) + continue + + # Add to output stack + if len(output_mrc_stack): + output_mrc_stack = np.append( + output_mrc_stack, [particle_subimage], axis=0 + ) + else: + output_mrc_stack = np.array([particle_subimage], dtype=np.float32) + + # Produce the mrc file of the extracted particles + output_mrc_file = ( + Path(extract_subtomo_params.output_star).parent + / Path(extract_subtomo_params.tomogram).with_suffix(".mrcs").name + ) + particle_count = np.shape(output_mrc_stack)[0] + self.log.info(f"Extracted {particle_count} particles") + with mrcfile.new(str(output_mrc_file), overwrite=True) as mrc: + mrc.set_data(output_mrc_stack.astype(np.float32)) + mrc.header.mx = extract_subtomo_params.relion_options.boxsize + mrc.header.my = extract_subtomo_params.relion_options.boxsize + mrc.header.mz = 1 + mrc.header.cella.x = ( + extract_subtomo_params.pixel_size + * extract_subtomo_params.relion_options.boxsize + ) + mrc.header.cella.y = ( + extract_subtomo_params.pixel_size + * extract_subtomo_params.relion_options.boxsize + ) + mrc.header.cella.z = 1 + + # Construct the output star file + if not Path(extract_subtomo_params.output_star).is_file(): + extracted_parts_doc = cif.Document() + optics_block = extracted_parts_doc.add_new_block("optics") + optics_loop = optics_block.init_loop( + "_rln", + [ + "Voltage", + "SphericalAberration", + "AmplitudeContrast", + "OpticsGroup", + "OpticsGroupName", + "CtfDataAreCtfPremultiplied", + "ImageDimensionality", + "ImagePixelSize", + "ImageSize", + ], + ) + optics_loop.add_row( + [ + "300.00", + "2.70", + "0.10", + "1", + "opticsGroup1", + "1", + "2", + str(extract_subtomo_params.pixel_size), + str(extract_subtomo_params.boxsize), + ] + ) + extracted_parts_block = extracted_parts_doc.add_new_block("particles") + extracted_parts_loop = extracted_parts_block.init_loop( + "_rln", + [ + "TomoName", + "OpticsGroup", + "TomoParticleName", + "ImageName", + ], + ) + else: + extracted_parts_doc = cif.read_file(extract_subtomo_params.output_star) + extracted_parts_block = extracted_parts_doc.find_block("particles") + extracted_parts_loop = extracted_parts_block.find_loop( + "_rlnTomoName" + ).get_loop() + for particle in range(particle_count): + extracted_parts_loop.add_row( + [ + _get_tilt_name_v5_12(Path(extract_subtomo_params.tomogram)), + "1", + f"{particle}@{output_mrc_file}", + f"{particle}@{output_mrc_file}", + ] + ) + extracted_parts_doc.write_file( + extract_subtomo_params.output_star, style=cif.Style.Simple + ) + + # Register the extract job with the node creator + self.log.info(f"Sending {self.job_type} to node creator") + node_creator_parameters = { + "job_type": self.job_type, + "input_file": extract_subtomo_params.cbox_3d_file, + "output_file": extract_subtomo_params.output_star, + "relion_options": dict(extract_subtomo_params.relion_options), + "command": "", + "stdout": "", + "stderr": "", + } + rw.send_to("node_creator", node_creator_parameters) + + # Register the files needed for selection and batching + self.log.info("Sending to particle selection") + select_params = { + "input_file": extract_subtomo_params.output_star, + "batch_size": extract_subtomo_params.batch_size, + "image_size": extract_subtomo_params.boxsize, + "tomo": True, + "relion_options": dict(extract_subtomo_params.relion_options), + } + rw.send_to("select_particles", select_params) + + self.log.info(f"Done {self.job_type} for {extract_subtomo_params.cbox_3d_file}") + rw.transport.ack(header) diff --git a/src/cryoemservices/services/extract_subtomo.py b/src/cryoemservices/services/extract_subtomo.py new file mode 100644 index 00000000..d6d289f7 --- /dev/null +++ b/src/cryoemservices/services/extract_subtomo.py @@ -0,0 +1,590 @@ +from __future__ import annotations + +import ast +from pathlib import Path + +import matplotlib.pyplot as plt +import mrcfile +import numpy as np +import workflows.transport.pika_transport as pt +from gemmi import cif +from pydantic import BaseModel, Field, ValidationError, field_validator +from tqdm import tqdm +from workflows.recipe import wrap_subscribe + +from cryoemservices.services.common_service import CommonService +from cryoemservices.services.extract import ( + enhance_single_particle, + extract_single_particle, +) +from cryoemservices.util.models import MockRW +from cryoemservices.util.relion_service_options import ( + RelionServiceOptions, + update_relion_options, +) +from cryoemservices.util.tomo_output_files import ( + _get_tilt_name_v5_12, + _get_tilt_number_v5_12, +) + +transport = pt.PikaTransport() +transport.load_configuration_file( + "/dls_sw/apps/murfey/config/rmq-connection-creds-pollux.yml" +) +transport.connect() + +root_dir = Path("/dls/m07/data/2025/cm40593-13/spool/ctfcorrect") +for tomo in root_dir.glob("Tomograms/job006/tomograms/*_aretomo.mrc"): + transport.send( + "segmentation", + { + "cbox_3d_file": f"{root_dir}/AutoPick/job009/CBOX_3D/{tomo.stem}.denoised.cbox", + "tomogram": str(tomo), + "output_star": f"{root_dir}/Extract/job010/{tomo.stem}.star", + "pixel_size": 5.36, + "particle_diameter": 250, + "relion_options": {}, + }, + ) + +inptus = { + "cbox_3d_file": "/scratch/yxd92326/data/tomo-extract/2_2_Ribosome_Pos_1_stack_aretomo.denoised.cbox", + "tilt_alignment_file": "/scratch/yxd92326/data/tomo-extract/2_2_Ribosome_Pos_1_stack.aln", + "newstack_file": "/scratch/yxd92326/data/tomo-extract/2_2_Ribosome_Pos_1_stack_newstack.txt", + "output_star": "/scratch/yxd92326/data/tomo-extract/extracted/extract.star", + "pixel_size": 1.34, + "dose_per_tilt": 4, + "tilt_offset": 0, + "scaled_tomogram_shape": [1440, 1023, 400], + "relion_options": {}, +} +in2 = { + "cbox_3d_file": "/scratch/yxd92326/data/tomo-extract/2_1_ApoF_Pos_13_9_test.cbox", + "tilt_alignment_file": "/scratch/yxd92326/data/tomo-extract/2_1_ApoF_Pos_13_9_stack_aretomo.aln", + "newstack_file": "/scratch/yxd92326/data/tomo-extract/2_1_ApoF_Pos_13_9_stack_newstack.txt", + "output_star": "/scratch/yxd92326/data/tomo-extract/extracted_apof/extract.star", + "pixel_size": 1.34, + "dose_per_tilt": 4, + "tilt_offset": 0, + "scaled_tomogram_shape": [1440, 1023, 400], + "relion_options": {}, + "particle_diameter": 500, +} + +in3 = { + "cbox_3d_file": "/dls/m06/data/2025/bi37708-55/tmp/extract-test/AutoPick/job009/CBOX_3D/Tomo_position3_stack_Vol.denoised.cbox", + "tomogram": "/dls/m06/data/2025/bi37708-55/tmp/extract-test/Tomograms/job006/tomograms/Tomo_position3_stack_Vol.mrc", + "output_star": "/dls/m06/data/2025/bi37708-55/tmp/extract-test/Extract/class2d/Tomo_position3_stack_Vol.star", + "pixel_size": 7.76, + "particle_diameter": 225, + "relion_options": {}, +} + +transport.send( + "segmentation", + { + "cbox_3d_file": "/dls/m06/data/2025/bi38637-22/spool/extract-test/AutoPick/job009/CBOX_3D/Position_002_stack_Vol.denoised.cbox", + "tomogram": "/dls/m06/data/2025/bi38637-22/spool/extract-test/CtfFind/corrected/aretomo2_compare/Position_002_stack_aretomo2_raw.mrc", + "output_star": "/dls/m06/data/2025/bi38637-22/spool/extract-test/Extract/ctftest/Position_002_stack_raw.star", + "pixel_size": 7.76, + "particle_diameter": 250, + "relion_options": {}, + }, +) + + +class ExtractSubTomoParameters3D(BaseModel): + cbox_3d_file: str = Field(..., min_length=1) + tilt_alignment_file: str = Field(..., min_length=1) + newstack_file: str = Field(..., min_length=1) + output_star: str = Field(..., min_length=1) + scaled_tomogram_shape: list[int] | str + pixel_size: float + dose_per_tilt: float + tilt_offset: float + particle_diameter: float = 0 + boxsize: int = 256 + min_frames: int = 1 + maximum_dose: int = -1 + tomogram_binning: int = 4 + relion_options: RelionServiceOptions + + @field_validator("scaled_tomogram_shape") + @classmethod + def check_shape_is_3d(cls, v): + if not len(v): + raise ValueError("Tomogram shape not given") + if type(v) is str: + shape_list = ast.literal_eval(v) + else: + shape_list = v + if len(shape_list) != 3: + raise ValueError("Tomogram shape must be 3D") + return shape_list + + +class ExtractSubTomoFor3D(CommonService): + """ + A service for extracting particles from cryolo autopicking for tomograms + """ + + # Human readable service name + _service_name = "ExtractSubTomo" + + # Logger name + _logger_name = "cryoemservices.services.extract_subtomo" + + # Job name + job_type = "relion.pseudosubtomo" + + def initializing(self): + """Subscribe to a queue. Received messages must be acknowledged.""" + self.log.info("Sub-tomogram extraction service starting") + wrap_subscribe( + self._transport, + self._environment["queue"] or "extract_subtomo", + self.extract_subtomo, + acknowledgement=True, + allow_non_recipe_messages=True, + ) + + def extract_subtomo(self, rw, header: dict, message: dict): + """Main function which interprets and processes received messages""" + if not rw: + self.log.info("Received a simple message") + if not isinstance(message, dict): + self.log.error("Rejected invalid simple message") + self._transport.nack(header) + return + + # Create a wrapper-like object that can be passed to functions + # as if a recipe wrapper was present. + rw = MockRW(self._transport) + rw.recipe_step = {"parameters": message} + + try: + if isinstance(message, dict): + extract_subtomo_params = ExtractSubTomoParameters3D( + **{**rw.recipe_step.get("parameters", {}), **message} + ) + else: + extract_subtomo_params = ExtractSubTomoParameters3D( + **{**rw.recipe_step.get("parameters", {})} + ) + except (ValidationError, TypeError) as e: + self.log.warning( + f"Extraction parameter validation failed for message: {message} " + f"and recipe parameters: {rw.recipe_step.get('parameters', {})} " + f"with exception: {e}" + ) + rw.transport.nack(header) + return + + self.log.info( + f"Inputs: {extract_subtomo_params.tilt_alignment_file}, " + f"{extract_subtomo_params.cbox_3d_file} " + f"Output: {extract_subtomo_params.output_star}" + ) + + # Update the relion options and get box sizes + extract_subtomo_params.relion_options = update_relion_options( + extract_subtomo_params.relion_options, dict(extract_subtomo_params) + ) + + # Make sure the output directory exists + if not Path(extract_subtomo_params.output_star).parent.exists(): + Path(extract_subtomo_params.output_star).parent.mkdir(parents=True) + + # Find the locations of the particles + coords_file = cif.read(extract_subtomo_params.cbox_3d_file) + coords_block = coords_file.find_block("cryolo") + pick_radius = float(coords_block.find_loop("_Width")[0]) / 2 + particles_x = ( + np.array(coords_block.find_loop("_CoordinateX"), dtype=float) + pick_radius + ) + particles_y = ( + np.array(coords_block.find_loop("_CoordinateY"), dtype=float) + pick_radius + ) + particles_z = np.array(coords_block.find_loop("_CoordinateZ"), dtype=float) + + # Get the shifts between tilts + shift_data = np.genfromtxt(extract_subtomo_params.tilt_alignment_file) + refined_tilt_axis = float(shift_data[0, 1]) + x_shifts = shift_data[:, 3].astype(float) + y_shifts = shift_data[:, 4].astype(float) + tilt_angles = shift_data[:, 9].astype(float) * np.pi / 180 + tilt_count = len(x_shifts) + + # Rotation around the tilt axis is about (0, height/2) + # Or possibly not, sometimes seems to be (width/2, height/2), needs exploration + centre_x = float(extract_subtomo_params.scaled_tomogram_shape[0]) / 2 + centre_y = float(extract_subtomo_params.scaled_tomogram_shape[1]) / 2 + centre_z = float(extract_subtomo_params.scaled_tomogram_shape[2]) / 2 + tilt_axis_radians = (refined_tilt_axis - 90) * np.pi / 180 + + # Downscaling dimensions + extract_subtomo_params.relion_options.pixel_size_downscaled = ( + extract_subtomo_params.pixel_size + * extract_subtomo_params.relion_options.boxsize + / extract_subtomo_params.relion_options.small_boxsize + ) + self.log.info( + f"Downscaling to {extract_subtomo_params.relion_options.pixel_size_downscaled}" + ) + extract_width = round(extract_subtomo_params.relion_options.boxsize / 2) + + pixel_size = extract_subtomo_params.relion_options.pixel_size_downscaled + + # Read in tilt images + self.log.info("Reading tilt images") + tilt_png_names = [] + tilt_images = [] + tilt_numbers = [] + tid = 0 + with open(extract_subtomo_params.newstack_file) as ns_file: + while True: + tid += 1 + line = ns_file.readline() + if not line: + break + elif line.startswith("/"): + tilt_name = line.strip() + tilt_png = Path("/home/yxd92326/Pictures/picking/tomo/") / ( + f"{tid}T_" + Path(tilt_name).with_suffix(".png").name + ) + tilt_png_names.append(tilt_png) + tilt_png.unlink(missing_ok=True) + + tilt_numbers.append(_get_tilt_number_v5_12(Path(tilt_name))) + with mrcfile.open(tilt_name) as mrc: + tilt_images.append(mrc.data) + + for tilt in range(tilt_count): + if extract_subtomo_params.maximum_dose > 0 and ( + extract_subtomo_params.dose_per_tilt * tilt_numbers[tilt] + > extract_subtomo_params.maximum_dose + ): + self.log.info(f"Skipping tilt {tilt} due to dose limit") + + frames = np.zeros((len(particles_x), tilt_count), dtype=int) + tilt_coords: list = [[] for tilt in range(tilt_count)] + for particle in tqdm(range(len(particles_x))): + output_mrc_stack = np.array([]) + for tilt in range(tilt_count): + if extract_subtomo_params.maximum_dose > 0 and ( + extract_subtomo_params.dose_per_tilt * tilt_numbers[tilt] + > extract_subtomo_params.maximum_dose + ): + continue + + x_in_tilt, y_in_tilt = get_coord_in_tilt( + x=particles_x[particle], + y=particles_y[particle], + z=particles_z[particle], + cen_x=centre_x, + cen_y=centre_y, + cen_z=centre_z, + theta_y=tilt_angles[tilt], + theta_z=tilt_axis_radians, + delta_x=x_shifts[tilt], + delta_y=y_shifts[tilt], + binning=extract_subtomo_params.tomogram_binning, + ) + tilt_coords[tilt].append([x_in_tilt, y_in_tilt]) + # print(x_in_tilt, y_in_tilt, tilt_angles[tilt]) + + particle_subimage, failure_reason1 = extract_single_particle( + input_image=tilt_images[tilt], + x_coord=x_in_tilt, + y_coord=y_in_tilt, + extract_width=extract_width, + shape=[ + int(i * extract_subtomo_params.tomogram_binning) + for i in extract_subtomo_params.scaled_tomogram_shape + ], + ) + particle_subimage, failure_reason2 = enhance_single_particle( + particle_subimage=particle_subimage, + extract_width=extract_width, + small_boxsize=extract_subtomo_params.small_boxsize, + bg_radius=round(0.375 * extract_subtomo_params.small_boxsize), + invert_contrast=True, + downscale=True, + norm=True, + plane_fit=True, + ) + + if failure_reason1 or failure_reason2: + self.log.warning( + f"Extraction failed for {particle} in {tilt}. " + f"Reason was {failure_reason1} {failure_reason2}." + ) + particle_subimage = np.zeros( + ( + extract_subtomo_params.small_boxsize, + extract_subtomo_params.small_boxsize, + ) + ) + + # Add to output stack + if len(output_mrc_stack): + output_mrc_stack = np.append( + output_mrc_stack, [particle_subimage], axis=0 + ) + else: + output_mrc_stack = np.array([particle_subimage], dtype=np.float32) + frames[particle, tilt] = 1 + + if not len(output_mrc_stack): + self.log.warning(f"Could not extract particle {particle}") + continue + + # Produce the mrc file of the extracted particles + output_mrc_file = ( + Path(extract_subtomo_params.output_star).parent + / f"{particle}_stack2d.mrcs" + ) + with mrcfile.new(str(output_mrc_file), overwrite=True) as mrc: + mrc.set_data(output_mrc_stack.astype(np.float32)) + mrc.header.mx = extract_subtomo_params.relion_options.small_boxsize + mrc.header.my = extract_subtomo_params.relion_options.small_boxsize + mrc.header.mz = 1 + mrc.header.cella.x = ( + pixel_size * extract_subtomo_params.relion_options.small_boxsize + ) + mrc.header.cella.y = ( + pixel_size * extract_subtomo_params.relion_options.small_boxsize + ) + mrc.header.cella.z = 1 + + for tilt in tqdm(range(tilt_count)): + plt.imshow(tilt_images[tilt], vmin=0.5, vmax=1.7) + for loc in tilt_coords[tilt]: + plt.scatter(loc[0], loc[1], s=2, color="red") + plt.savefig(tilt_png_names[tilt]) + plt.close() + + # Construct the output star file + extracted_parts_doc = cif.Document() + extracted_parts_block = extracted_parts_doc.add_new_block("particles") + extracted_parts_loop = extracted_parts_block.init_loop( + "_rln", + [ + "TomoName", + "OpticsGroup", + "TomoParticleName", + "TomoVisibleFrames", + "ImageName", + "OriginXAngst", + "OriginYAngst", + "OriginZAngst", + "CenteredCoordinateXAngst", + "CenteredCoordinateYAngst", + "CenteredCoordinateZAngst", + ], + ) + for particle in range(len(particles_x)): + extracted_parts_loop.add_row( + [ + _get_tilt_name_v5_12( + Path(extract_subtomo_params.tilt_alignment_file) + ), + "1", + f"{_get_tilt_name_v5_12(Path(extract_subtomo_params.tilt_alignment_file))}/{particle}", + f"[{','.join([str(frm) for frm in frames[particle]])}]", + f"{Path(extract_subtomo_params.output_star).parent}/{particle}_stack2d.mrcs", + str(centre_x * extract_subtomo_params.tomogram_binning), + str(centre_y * extract_subtomo_params.tomogram_binning), + str(centre_z * extract_subtomo_params.tomogram_binning), + str( + float(particles_x[particle]) + - centre_x * extract_subtomo_params.tomogram_binning + ), + str( + float(particles_y[particle]) + - centre_y * extract_subtomo_params.tomogram_binning + ), + str( + float(particles_z[particle]) + - centre_z * extract_subtomo_params.tomogram_binning + ), + ] + ) + extracted_parts_doc.write_file( + extract_subtomo_params.output_star, style=cif.Style.Simple + ) + + # Register the extract job with the node creator + self.log.info(f"Sending {self.job_type} to node creator") + node_creator_parameters = { + "job_type": self.job_type, + "input_file": extract_subtomo_params.cbox_3d_file, + "output_file": extract_subtomo_params.output_star, + "relion_options": dict(extract_subtomo_params.relion_options), + "command": "", + "stdout": "", + "stderr": "", + "results": { + "box_size": extract_subtomo_params.relion_options.small_boxsize + }, + } + rw.send_to("node_creator", node_creator_parameters) + + self.log.info(f"Done {self.job_type} for {extract_subtomo_params.cbox_3d_file}") + rw.transport.ack(header) + + +def get_coord_in_tilt( + x: float, + y: float, + z: float, + cen_x: float, + cen_y: float, + cen_z: float, + theta_y: float, + theta_z: float, + delta_x: float, + delta_y: float, + binning: int, +): + # Translation raw to aligned tilt is subtract shift then rotate around centre + # In binned coordinates here + x_centred = x - cen_x + y_centred = y - cen_y # + cen_x * np.tan(theta_z) TODO: last factor depends on rot + x_2d = x_centred * np.cos(theta_z) - y_centred * np.sin(theta_z) + y_2d = x_centred * np.sin(theta_z) + y_centred * np.cos(theta_z) + # Un-bin and apply shifts + x_tilt = (cen_x + x_2d) * binning + delta_x + y_flat = (cen_y + y_2d) * binning + delta_y + y_tilt = (y_flat - cen_y * binning) * np.cos(theta_y) + cen_y * binning + # print(cen_x, x, x_2d, delta_x, cen_y, y, y_2d, delta_y) + + z_centred = z - cen_z + y_tilt += z_centred * np.sin(theta_y) * binning + # print(z_centred * np.sin(theta_y), x_tilt, y_tilt) + return x_tilt, y_tilt + + +""" +with open("Extract_maxdose/particles.star", "a") as partstar: + for tomo in Path("Extract_maxdose").glob("2_1*"): + with open(tomo / "extract.star") as tomostar: + while True: + line=tomostar.readline() + if not line: + break + if line.startswith("2_1"): + partstar.write(line) +""" + + +def cbox_to_star(name, max_subsample): + import pandas as pd + import starfile + + # make data_particles + cbox = starfile.read(f"AutoPick/job009/CBOX_3D/{name}_stack_aretomo.denoised.cbox") + all_particles = cbox["cryolo"] + new_particles = pd.DataFrame() + new_particles["rlnTomoName"] = [f"{name}" for i in range(len(all_particles))] + new_particles["rlnCenteredCoordinateXAngst"] = ( + all_particles["CoordinateY"] * 4 + 80 - 4092 / 2 + ) * 1.34 + new_particles["rlnCenteredCoordinateYAngst"] = ( + 5760 - all_particles["CoordinateX"] * 4 - 80 - 5760 / 2 + ) * 1.34 + new_particles["rlnCenteredCoordinateZAngst"] = ( + all_particles["CoordinateZ"] * 4 - 1600 / 2 + ) * 1.34 + for subsamp in range(2, max_subsample + 1): + if Path( + f"AutoPick/job009/CBOX_3D/{name}_{subsamp}_stack_aretomo.denoised.cbox" + ).is_file(): + cbox = starfile.read( + f"AutoPick/job009/CBOX_3D/{name}_{subsamp}_stack_aretomo.denoised.cbox" + ) + particles = cbox["cryolo"] + add_particles = pd.DataFrame() + add_particles["rlnTomoName"] = [ + f"{name}_{subsamp}" for i in range(len(particles)) + ] + add_particles["rlnCenteredCoordinateXAngst"] = ( + particles["CoordinateY"] * 4 + 80 - 4092 / 2 + ) * 1.34 + add_particles["rlnCenteredCoordinateYAngst"] = ( + 5760 - particles["CoordinateX"] * 4 - 80 - 5760 / 2 + ) * 1.34 + add_particles["rlnCenteredCoordinateZAngst"] = ( + particles["CoordinateZ"] * 4 - 1600 / 2 + ) * 1.34 + new_particles = pd.concat((new_particles, add_particles)) + starfile.write(new_particles, f"AutoPick/job009/{name}_all_particles_centered.star") + + +def cbox_to_star_whole_dir(pixel_size=1.63, xdim=5760, ydim=4092, zdim=1600): + import pandas as pd + import starfile + + # make data_particles + new_particles = None + + for cbox in Path("AutoPick/job009/CBOX_3D").glob("*.cbox"): + all_particles = starfile.read(cbox)["cryolo"] + + particles_to_drop = [] + for pindex, particle in all_particles.iterrows(): + if ( + particle["CoordinateZ"] < particle["Depth"] / 2 + or zdim / 4 - particle["CoordinateZ"] < particle["Depth"] / 2 + ): + particles_to_drop.append(pindex) + print(cbox, particles_to_drop) + all_particles.drop(labels=particles_to_drop, axis=0, inplace=True) + + add_particles = pd.DataFrame() + add_particles["rlnCenteredCoordinateXAngst"] = ( + all_particles["CoordinateY"] * 4 + 80 - ydim / 2 + ) * pixel_size + add_particles["rlnCenteredCoordinateYAngst"] = ( + xdim / 2 - all_particles["CoordinateX"] * 4 - 80 + ) * pixel_size + add_particles["rlnCenteredCoordinateZAngst"] = ( + all_particles["CoordinateZ"] * 4 - zdim / 2 + ) * pixel_size + add_particles["rlnTomoName"] = [ + cbox.name.split("_stack_")[0] for i in range(len(all_particles)) + ] + if new_particles is None: + new_particles = add_particles + else: + new_particles = pd.concat((new_particles, add_particles)) + starfile.write(new_particles, "AutoPick/job009/all_particles_centered.star") + + +def cbox_to_star_whole_dir_noflip(): + import pandas as pd + import starfile + + # make data_particles + new_particles = None + + for cbox in Path("AutoPick/job009/CBOX_3D").glob("*.cbox"): + all_particles = starfile.read(cbox)["cryolo"] + add_particles = pd.DataFrame() + add_particles["rlnCenteredCoordinateXAngst"] = ( + all_particles["CoordinateX"] * 4 + 80 - 5760 / 2 + ) * 1.63 + add_particles["rlnCenteredCoordinateYAngst"] = ( + all_particles["CoordinateY"] * 4 + 80 - 4092 / 2 + ) * 1.63 + add_particles["rlnCenteredCoordinateZAngst"] = ( + all_particles["CoordinateZ"] * 4 - 1600 / 2 + ) * 1.63 + add_particles["rlnTomoName"] = [ + cbox.name.split("_stack_")[0] for i in range(len(all_particles)) + ] + if new_particles is None: + new_particles = add_particles + else: + new_particles = pd.concat((new_particles, add_particles)) + starfile.write(new_particles, "AutoPick/job009/all_particles_centered_noflip.star") diff --git a/src/cryoemservices/services/select_particles.py b/src/cryoemservices/services/select_particles.py index d47f18c2..d701b1ce 100644 --- a/src/cryoemservices/services/select_particles.py +++ b/src/cryoemservices/services/select_particles.py @@ -18,6 +18,7 @@ class SelectParticlesParameters(BaseModel): batch_size: int image_size: int incomplete_batch_size: int = 10000 + tomo: bool = False relion_options: RelionServiceOptions @@ -92,7 +93,10 @@ def select_particles(self, rw, header: dict, message: dict): select_dir.mkdir(parents=True, exist_ok=True) extracted_parts_file = cif.read_file(select_params.input_file) - extracted_parts_block = extracted_parts_file.sole_block() + try: + extracted_parts_block = extracted_parts_file.sole_block() + except RuntimeError: + extracted_parts_block = extracted_parts_file.find_block("particles") extracted_parts_loop = extracted_parts_block.find_loop( "_rlnCoordinateX" ).get_loop() @@ -160,24 +164,35 @@ def select_particles(self, rw, header: dict, message: dict): ) new_split_block = new_particles_cif.add_new_block("particles") - new_split_loop = new_split_block.init_loop( - "_rln", - [ - "CoordinateX", - "CoordinateY", - "ImageName", - "MicrographName", - "OpticsGroup", - "CtfMaxResolution", - "CtfFigureOfMerit", - "DefocusU", - "DefocusV", - "DefocusAngle", - "CtfBfactor", - "CtfScalefactor", - "PhaseShift", - ], - ) + if select_params.tomo: + new_split_loop = new_split_block.init_loop( + "_rln", + [ + "TomoName", + "OpticsGroup", + "TomoParticleName", + "ImageName", + ], + ) + else: + new_split_loop = new_split_block.init_loop( + "_rln", + [ + "CoordinateX", + "CoordinateY", + "ImageName", + "MicrographName", + "OpticsGroup", + "CtfMaxResolution", + "CtfFigureOfMerit", + "DefocusU", + "DefocusV", + "DefocusAngle", + "CtfBfactor", + "CtfScalefactor", + "PhaseShift", + ], + ) num_prev_parts = 0 # While we have particles to add and the file is not full diff --git a/src/cryoemservices/util/tomo_output_files.py b/src/cryoemservices/util/tomo_output_files.py index 8f60a425..d9b892e8 100644 --- a/src/cryoemservices/util/tomo_output_files.py +++ b/src/cryoemservices/util/tomo_output_files.py @@ -68,7 +68,7 @@ def _global_tilt_series_file( str(relion_options.ampl_contrast), str(relion_options.pixel_size), str(relion_options.invert_hand), - "optics1", + "opticsGroup1", str(relion_options.pixel_size), ] @@ -561,7 +561,7 @@ def _tomogram_output_files( str(relion_options.ampl_contrast), str(relion_options.pixel_size), str(relion_options.invert_hand), - "optics1", + "opticsGroup1", str(relion_options.pixel_size), f"AlignTiltSeries/job005/tilt_series/{tilt_series_name}.star", str(relion_options.pixel_size_downscaled / relion_options.pixel_size), @@ -622,7 +622,7 @@ def _denoising_output_files( str(relion_options.ampl_contrast), str(relion_options.pixel_size), str(relion_options.invert_hand), - "optics1", + "opticsGroup1", str(relion_options.pixel_size), f"AlignTiltSeries/job005/tilt_series/{tilt_series_name}.star", str(relion_options.pixel_size_downscaled / relion_options.pixel_size), @@ -685,7 +685,7 @@ def _membrain_output_files( str(relion_options.ampl_contrast), str(relion_options.pixel_size), str(relion_options.invert_hand), - "optics1", + "opticsGroup1", str(relion_options.pixel_size), f"AlignTiltSeries/job005/tilt_series/{tilt_series_name}.star", str(relion_options.pixel_size_downscaled / relion_options.pixel_size), diff --git a/tests/services/test_cryolo_service.py b/tests/services/test_cryolo_service.py index 3a015888..8a786db2 100644 --- a/tests/services/test_cryolo_service.py +++ b/tests/services/test_cryolo_service.py @@ -175,8 +175,8 @@ def write_cbox_file(command, cwd, capture_output): "register": "picked_particles", "motion_correction_id": cryolo_test_message["mc_uuid"], "micrograph": cryolo_test_message["input_path"], - "particle_diameters": [10.0, 20.0], - "particle_count": 2, + "particle_diameters": [15.0], + "particle_count": 1, "resolution": ctf_test_values["CtfMaxResolution"], "astigmatism": ctf_test_values["DefocusU"] - ctf_test_values["DefocusV"], "defocus": (ctf_test_values["DefocusU"] + ctf_test_values["DefocusV"]) / 2, @@ -213,20 +213,18 @@ def test_cryolo_service_tomography(mock_subprocess, offline_transport, tmp_path) This should call the mock subprocess then send messages on to the node_creator, murfey_feedback, ispyb_connector and images services """ - mock_subprocess().returncode = 0 - mock_subprocess().stdout = "stdout".encode("ascii") - mock_subprocess().stderr = "stderr".encode("ascii") - header = { "message-id": mock.sentinel, "subscription": mock.sentinel, } - output_path = tmp_path / "AutoPick/job007/STAR/sample.star" + output_path = tmp_path / "AutoPick/job007/CBOX_3D/sample.cbox" cryolo_test_message = { - "input_path": "MotionCorr/job002/sample.mrc", + "input_path": "Denoise/job007/tomograms/sample_denoised.mrc", "output_path": str(output_path), + "raw_tomogram": f"{tmp_path}/Tomogram/job006/tomograms/sample.mrc", "experiment_type": "tomography", + "pixel_size": 5.3, "cryolo_box_size": 40, "cryolo_model_weights": "sample_weights", "cryolo_threshold": 0.15, @@ -246,17 +244,34 @@ def test_cryolo_service_tomography(mock_subprocess, offline_transport, tmp_path) ) output_relion_options["cryolo_box_size"] = 40 + def write_cbox_file(command, cwd, capture_output): + # Write star co-ordinate file in the format cryolo will output + (cwd / "CBOX").mkdir() + with open(cwd / "CBOX_3D/sample.cbox", "w") as f: + f.write( + "data_cryolo\n\nloop_\n\n_EstWidth\n_EstHeight\n_Confidence\n" + "_CoordinateX\n_CoordinateY\n_Width\n_Height\n" + "100 200 0.6 0.1 0.2 2 4\n150 250 0.5 0.3 0.4 6 8" + ) + return CompletedProcess( + "", + returncode=0, + stdout="stdout".encode("ascii"), + stderr="stderr".encode("ascii"), + ) + + mock_subprocess.side_effect = write_cbox_file + # Set up the mock service and send the message to it service = cryolo.CrYOLO(environment={"queue": ""}, transport=offline_transport) service.initializing() service.cryolo(None, header=header, message=cryolo_test_message) - assert mock_subprocess.call_count == 4 - mock_subprocess.assert_called_with( + mock_subprocess.assert_called_once_with( [ "cryolo_predict.py", "-i", - "MotionCorr/job002/sample.mrc", + "Denoise/job007/tomograms/sample_denoised.mrc", "--conf", str(tmp_path / "AutoPick/job007/cryolo_config.json"), "-o", @@ -299,7 +314,7 @@ def test_cryolo_service_tomography(mock_subprocess, offline_transport, tmp_path) assert config_values["other"] == {"log_path": "logs/"} # Check that the correct messages were sent - assert offline_transport.send.call_count == 4 + assert offline_transport.send.call_count == 5 offline_transport.send.assert_any_call( "ispyb_connector", { @@ -334,7 +349,7 @@ def test_cryolo_service_tomography(mock_subprocess, offline_transport, tmp_path) "output_file": str(output_path), "relion_options": output_relion_options, "command": ( - "cryolo_predict.py -i MotionCorr/job002/sample.mrc " + "cryolo_predict.py -i Denoise/job007/tomograms/sample_denoised.mrc " f"--conf {tmp_path}/AutoPick/job007/cryolo_config.json " f"-o {tmp_path}/AutoPick/job007 " f"--tomogram -tsr -1 -tmem 0 -tmin 5 --gpus 0 " @@ -347,6 +362,17 @@ def test_cryolo_service_tomography(mock_subprocess, offline_transport, tmp_path) "success": True, }, ) + offline_transport.send.assert_any_call( + "murfey_feedback", + { + "cbox_3d": f"{tmp_path}/AutoPick/job007/CBOX_3D/sample.cbox", + "particle_count": 2, + "particle_diameters": [150 * 5.3, 200 * 5.3], + "pixel_size": 5.3, + "register": "picked_tomogram", + "tomogram": f"{tmp_path}/Tomogram/job006/tomograms/sample.mrc", + }, + ) @pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")