From 5224c232049ad91ddc54386f1b75443c38507e11 Mon Sep 17 00:00:00 2001 From: yxd92326 Date: Fri, 25 Jul 2025 15:31:19 +0100 Subject: [PATCH 1/3] Request 3D classification again for more than 20000 particles using cryodann scores --- src/cryoemservices/services/select_classes.py | 90 ++++++++++++++----- tests/services/test_select_classes_service.py | 63 ++++++++++++- 2 files changed, 127 insertions(+), 26 deletions(-) diff --git a/src/cryoemservices/services/select_classes.py b/src/cryoemservices/services/select_classes.py index 566c5386..ac54f545 100644 --- a/src/cryoemservices/services/select_classes.py +++ b/src/cryoemservices/services/select_classes.py @@ -400,36 +400,43 @@ def select_classes(self, rw, header: dict, message: dict): # Determine the next split size to use and whether to run 3D classification send_to_3d_classification = False if self.previous_total_count == 0: - # First run of this job, use class3d_max_size + # First run of this job, use class3d_batch_size next_batch_size = autoselect_params.class3d_batch_size if self.total_count > autoselect_params.class3d_batch_size: # Do 3D classification if there are more particles than the batch size send_to_3d_classification = True - elif self.previous_total_count >= autoselect_params.class3d_max_size: - # Iterations beyond those where 3D classification is run - next_batch_size = autoselect_params.class3d_max_size else: - # Re-runs with fewer particles than the maximum - previous_batch_multiple = ( - self.previous_total_count // autoselect_params.class3d_batch_size - ) - new_batch_multiple = ( - self.total_count // autoselect_params.class3d_batch_size - ) - if new_batch_multiple > previous_batch_multiple: + # Re-runs with doubling particle count each time + if self.previous_total_count >= autoselect_params.class3d_batch_size: + previous_batch_power = int( + np.log( + self.previous_total_count + // autoselect_params.class3d_batch_size + ) + / np.log(2) + ) + else: + previous_batch_power = 0 + if self.total_count >= autoselect_params.class3d_batch_size: + new_batch_power = int( + np.log(self.total_count // autoselect_params.class3d_batch_size) + / np.log(2) + ) + else: + new_batch_power = 0 + if new_batch_power > previous_batch_power: # Do 3D classification if a batch threshold has been crossed send_to_3d_classification = True # Set the batch size from the total count, but do not exceed the maximum - next_batch_size = ( - new_batch_multiple * autoselect_params.class3d_batch_size + next_batch_size = int( + np.power(2, new_batch_power) * autoselect_params.class3d_batch_size ) - if next_batch_size > autoselect_params.class3d_max_size: - next_batch_size = autoselect_params.class3d_max_size else: # Otherwise just get the next threshold - next_batch_size = ( - previous_batch_multiple + 1 - ) * autoselect_params.class3d_batch_size + next_batch_size = int( + np.power(2, (previous_batch_power + 1)) + * autoselect_params.class3d_batch_size + ) # Run the combine star files job to split particles_all.star into batches split_node_creator_params: dict[str, Any] = { @@ -440,7 +447,7 @@ def select_classes(self, rw, header: dict, message: dict): "command": ( f"combine_star_files {combine_star_dir}/particles_all.star " f"--output_dir {combine_star_dir} " - f"--split --split_size {next_batch_size}" + f"--split --split_size {min(next_batch_size, autoselect_params.class3d_max_size)}" ), "stdout": "", "stderr": "", @@ -454,7 +461,7 @@ def select_classes(self, rw, header: dict, message: dict): split_star_file( file_to_process=combine_star_dir / "particles_all.star", output_dir=combine_star_dir, - split_size=next_batch_size, + split_size=min(next_batch_size, autoselect_params.class3d_max_size), ) split_node_creator_params["success"] = True except (IndexError, KeyError): @@ -557,7 +564,10 @@ def select_classes(self, rw, header: dict, message: dict): ) # Create 3D classification jobs - if send_to_3d_classification: + if ( + send_to_3d_classification + and next_batch_size <= autoselect_params.class3d_max_size + ): # Only send to 3D if a new multiple of the batch threshold is crossed # and the count has not passed the maximum self.log.info("Sending to Murfey for Class3D") @@ -579,6 +589,28 @@ def select_classes(self, rw, header: dict, message: dict): "class3d_message": class3d_params, } rw.send_to("murfey_feedback", murfey_3d_params) + elif send_to_3d_classification: + found_sample = resample_best_particles( + combine_star_dir / "particles_all.star", + combine_star_dir / f"particles_best_{next_batch_size}.star", + ) + if found_sample: + # Tell Murfey to do Class3D + class3d_params = { + "particles_file": f"{combine_star_dir}/particles_best_{next_batch_size}.star", + "class3d_dir": f"{project_dir}/Class3D/job", + "batch_size": next_batch_size, + } + murfey_3d_params = { + "register": "run_class3d", + "class3d_message": class3d_params, + } + rw.send_to("murfey_feedback", murfey_3d_params) + else: + self.log.warning( + "Cannot rerun Class3D as no scores available in " + f"{combine_star_dir}/particles_all.star" + ) murfey_confirmation = { "register": "done_class_selection", @@ -592,3 +624,17 @@ def select_classes(self, rw, header: dict, message: dict): self.log.info(f"Done {self.job_type} for {autoselect_params.input_file}.") rw.transport.ack(header) + + +def resample_best_particles(particles_all: Path, output_star: Path): + data = starfile.read(particles_all) + if "rlnCryodannScore" in data["particles"].keys(): + cutoff_score = sorted(data["particles"]["rlnCryodannScore"], reverse=True)[ + 200000 + ] + data["particles"] = data["particles"][ + data["particles"]["rlnCryodannScore"] > cutoff_score + ] + starfile.write(data, output_star) + return True + return False diff --git a/tests/services/test_select_classes_service.py b/tests/services/test_select_classes_service.py index 4e023298..2a4ecdf5 100644 --- a/tests/services/test_select_classes_service.py +++ b/tests/services/test_select_classes_service.py @@ -29,11 +29,13 @@ def select_classes_common_setup( particles_file.parent.mkdir(parents=True) with open(particles_file, "w") as f: f.write("data_optics\n\nloop_\n_group\nopticsGroup1\n\n") - f.write("data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n") + f.write( + "data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n_rlnCryodannScore\n" + ) for i in range(particles_to_add): f.write( f"{i / 100} {i / 100} {i}@Extract/job008/classes.mrcs " - f"MotionCorr/job002/Movies/movie.mrc\n" + f"MotionCorr/job002/Movies/movie.mrc {np.random.random()}\n" ) if initial_particle_count: @@ -41,11 +43,13 @@ def select_classes_common_setup( particles_file.parent.mkdir(parents=True) with open(particles_file, "w") as f: f.write("data_optics\n\nloop_\n_group\nopticsGroup1\n\n") - f.write("data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n") + f.write( + "data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n_rlnCryodannScore\n" + ) for i in range(initial_particle_count): f.write( f"{i/100} {i/100} {i}@Extract/job008/classes.mrcs " - f"MotionCorr/job002/Movies/movie.mrc\n" + f"MotionCorr/job002/Movies/movie.mrc {np.random.random()}\n" ) Path(job_dir / "MotionCorr/job002/Movies").mkdir(parents=True, exist_ok=True) @@ -609,6 +613,57 @@ def test_select_classes_service_past_maximum( assert len(offline_transport.send.call_args_list) == 7 +@mock.patch("cryoemservices.services.select_classes.subprocess.run") +def test_select_classes_service_do_batch_past_maximum( + mock_subprocess, offline_transport, tmp_path +): + """ + Test the service for the case where the existing particle count exceeds the maximum. + In this case the next power (400000) is crossed so + 3D classification should be run for a subset + """ + mock_subprocess().returncode = 0 + mock_subprocess().stdout = "stdout".encode("ascii") + mock_subprocess().stderr = "stderr".encode("ascii") + + header = { + "message-id": mock.sentinel, + "subscription": mock.sentinel, + } + select_test_message, relion_options = select_classes_common_setup( + tmp_path, initial_particle_count=390000, particles_to_add=20000 + ) + + # Set up the mock service and send the message to it + service = select_classes.SelectClasses( + environment={"queue": ""}, transport=offline_transport + ) + service.initializing() + service.select_classes(None, header=header, message=select_test_message) + + # Check the correct particle counts were found and split files made + assert service.previous_total_count == 390000 + assert service.total_count == 410000 + assert (tmp_path / "Select/job013/particles_split1.star").is_file() + assert (tmp_path / "Select/job013/particles_split2.star").is_file() + assert len(list(tmp_path.glob("Select/job013/particles_batch_*"))) == 0 + + # Don't bother to check the auto-selection calls here, they are checked above + # Do check the Murfey 3D calls + assert len(offline_transport.send.call_args_list) == 8 + offline_transport.send.assert_any_call( + "murfey_feedback", + { + "register": "run_class3d", + "class3d_message": { + "particles_file": f"{tmp_path}/Select/job013/particles_best_400000.star", + "class3d_dir": f"{tmp_path}/Class3D/job", + "batch_size": 400000, + }, + }, + ) + + def test_parse_combiner_output(offline_transport): """ Send test lines to the output parser From 7804684c8018ccaf62760c98779fbffdeb7609dc Mon Sep 17 00:00:00 2001 From: yxd92326 Date: Wed, 30 Jul 2025 16:10:10 +0100 Subject: [PATCH 2/3] Allow refinement as long as batch size at least 200000 --- src/cryoemservices/util/relion_service_options.py | 2 -- src/cryoemservices/wrappers/class3d_wrapper.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/cryoemservices/util/relion_service_options.py b/src/cryoemservices/util/relion_service_options.py index 2b953499..d4648eb4 100644 --- a/src/cryoemservices/util/relion_service_options.py +++ b/src/cryoemservices/util/relion_service_options.py @@ -123,8 +123,6 @@ class RelionServiceOptions(BaseModel): autoselect_min_score: float = 0.7 # 2D classification particle batch size batch_size: int = 50000 - # Maximum batch size for the single batch of 3D classification - class3d_max_size: int = 200000 # Initial lowpass filter on 3D reference initial_lowpass: int = 40 diff --git a/src/cryoemservices/wrappers/class3d_wrapper.py b/src/cryoemservices/wrappers/class3d_wrapper.py index 53dc97e3..5afe7cd7 100644 --- a/src/cryoemservices/wrappers/class3d_wrapper.py +++ b/src/cryoemservices/wrappers/class3d_wrapper.py @@ -580,7 +580,7 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) - ) for cid in class_sorting: if ( - class3d_params.batch_size == 200000 + class3d_params.batch_size >= 200000 and class_resolutions[cid] < 11 and (class_efficiencies[cid] > 0.65 or class3d_params.symmetry != "C1") ): From 0c01f5c1691df6bedac0347e228401d851960065 Mon Sep 17 00:00:00 2001 From: yxd92326 Date: Wed, 30 Jul 2025 16:42:37 +0100 Subject: [PATCH 3/3] Work out true class sizes if supplied with over 200000 particles --- src/cryoemservices/wrappers/class3d_wrapper.py | 8 +++++--- tests/services/test_class3d.py | 16 ++++++++-------- tests/wrappers/test_class3d_wrapper.py | 16 ++++++++-------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/cryoemservices/wrappers/class3d_wrapper.py b/src/cryoemservices/wrappers/class3d_wrapper.py index 5afe7cd7..4b56c87b 100644 --- a/src/cryoemservices/wrappers/class3d_wrapper.py +++ b/src/cryoemservices/wrappers/class3d_wrapper.py @@ -401,6 +401,9 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) - ) return False + # Actual batch size may not be the same as the input value, check it in the data + true_batch_size = class3d_params.batch_size + # Generate healpix image of the particle distribution logger.info("Generating healpix angular distribution image") data = cif.read_file( @@ -417,6 +420,7 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) - class_numbers = np.array( particles_block.find_loop("_rlnClassNumber"), dtype=int ) + true_batch_size = len(angles_rot) for class_id in range(class3d_params.class3d_nr_classes): if not len(angles_tilt[class_numbers == class_id + 1]): @@ -492,9 +496,7 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) - f"{class3d_params.class3d_dir}/" f"run_it{class3d_params.class3d_nr_iter:03}_class{class_id + 1:03}.mrc" ), - "particles_per_class": ( - float(classes_loop[class_id, 1]) * class3d_params.batch_size - ), + "particles_per_class": float(classes_loop[class_id, 1]) * true_batch_size, "class_distribution": classes_loop[class_id, 1], "rotation_accuracy": classes_loop[class_id, 2], "translation_accuracy": classes_loop[class_id, 3], diff --git a/tests/services/test_class3d.py b/tests/services/test_class3d.py index 544592ab..740cb3d1 100644 --- a/tests/services/test_class3d.py +++ b/tests/services/test_class3d.py @@ -86,8 +86,8 @@ def test_class3d_service_has_initial_model( "data_model_classes\nloop_\n" "_rlnReferenceImage\n_Fraction\n_Rotation\n_Translation\n" "_Resolution\n_Completeness\n_OffsetX\n_OffsetY\n" - "1@Class3D/job015/run_it020_classes.mrcs 0.4 30.3 33.3 12.2 1.0 0.6 0.01\n" - "2@Class3D/job015/run_it020_classes.mrcs 0.6 20.2 22.2 10.0 0.9 -0.5 -0.02" + "1@Class3D/job015/run_it020_classes.mrcs 0.2 30.3 33.3 12.2 1.0 0.6 0.01\n" + "2@Class3D/job015/run_it020_classes.mrcs 0.8 20.2 22.2 10.0 0.9 -0.5 -0.02" ) # Create a recipe wrapper with the test message @@ -196,7 +196,7 @@ def test_class3d_service_has_initial_model( }, "buffer_lookup": {"particle_classification_group_id": 5}, "buffer_store": 10, - "class_distribution": "0.4", + "class_distribution": "0.2", "class_image_full_path": ( f"{tmp_path}/Class3D/job015/run_it020_class001.mrc" ), @@ -204,7 +204,7 @@ def test_class3d_service_has_initial_model( "estimated_resolution": 12.2, "ispyb_command": "buffer", "overall_fourier_completeness": 1.0, - "particles_per_class": 40000.0, + "particles_per_class": 0.8, "rotation_accuracy": "30.3", "translation_accuracy": "33.3", "angular_efficiency": 0.6, @@ -216,7 +216,7 @@ def test_class3d_service_has_initial_model( }, "buffer_lookup": {"particle_classification_group_id": 5}, "buffer_store": 11, - "class_distribution": "0.6", + "class_distribution": "0.8", "class_image_full_path": ( f"{tmp_path}/Class3D/job015/run_it020_class002.mrc" ), @@ -224,7 +224,7 @@ def test_class3d_service_has_initial_model( "estimated_resolution": 10.0, "ispyb_command": "buffer", "overall_fourier_completeness": 0.9, - "particles_per_class": 60000.0, + "particles_per_class": 3.2, "rotation_accuracy": "20.2", "translation_accuracy": "22.2", "angular_efficiency": 0.6, @@ -436,7 +436,7 @@ def test_class3d_service_rerun( "estimated_resolution": 12.2, "ispyb_command": "buffer", "overall_fourier_completeness": 1.0, - "particles_per_class": 40000.0, + "particles_per_class": 1.6, "rotation_accuracy": "30.3", "translation_accuracy": "33.3", "angular_efficiency": 0.6, @@ -458,7 +458,7 @@ def test_class3d_service_rerun( "estimated_resolution": 10.0, "ispyb_command": "buffer", "overall_fourier_completeness": 0.9, - "particles_per_class": 60000.0, + "particles_per_class": 2.4, "rotation_accuracy": "20.2", "translation_accuracy": "22.2", "angular_efficiency": 0.6, diff --git a/tests/wrappers/test_class3d_wrapper.py b/tests/wrappers/test_class3d_wrapper.py index 53789a05..0653b501 100644 --- a/tests/wrappers/test_class3d_wrapper.py +++ b/tests/wrappers/test_class3d_wrapper.py @@ -347,7 +347,7 @@ def test_class3d_wrapper_do_initial_model( "estimated_resolution": 12.2, "ispyb_command": "buffer", "overall_fourier_completeness": 1.0, - "particles_per_class": 20000.0, + "particles_per_class": 1.6, "rotation_accuracy": "30.3", "translation_accuracy": "33.3", "angular_efficiency": 0.7, @@ -367,7 +367,7 @@ def test_class3d_wrapper_do_initial_model( "estimated_resolution": 10.0, "ispyb_command": "buffer", "overall_fourier_completeness": 0.9, - "particles_per_class": 30000.0, + "particles_per_class": 2.4, "rotation_accuracy": "20.2", "translation_accuracy": "22.2", "angular_efficiency": 0.7, @@ -483,8 +483,8 @@ def test_class3d_wrapper_has_initial_model( "data_model_classes\nloop_\n" "_rlnReferenceImage\n_Fraction\n_Rotation\n_Translation\n" "_Resolution\n_Completeness\n_OffsetX\n_OffsetY\n" - "1@Class3D/job015/run_it020_classes.mrcs 0.4 30.3 33.3 12.2 1.0 0.6 0.01\n" - "2@Class3D/job015/run_it020_classes.mrcs 0.6 20.2 22.2 10.0 0.9 -0.5 -0.02" + "1@Class3D/job015/run_it020_classes.mrcs 0.2 30.3 33.3 12.2 1.0 0.6 0.01\n" + "2@Class3D/job015/run_it020_classes.mrcs 0.8 20.2 22.2 10.0 0.9 -0.5 -0.02" ) # Create a recipe wrapper with the test message @@ -592,7 +592,7 @@ def test_class3d_wrapper_has_initial_model( }, "buffer_lookup": {"particle_classification_group_id": 5}, "buffer_store": 10, - "class_distribution": "0.4", + "class_distribution": "0.2", "class_image_full_path": ( f"{tmp_path}/Class3D/job015/run_it020_class001.mrc" ), @@ -600,7 +600,7 @@ def test_class3d_wrapper_has_initial_model( "estimated_resolution": 12.2, "ispyb_command": "buffer", "overall_fourier_completeness": 1.0, - "particles_per_class": 40000.0, + "particles_per_class": 0.8, "rotation_accuracy": "30.3", "translation_accuracy": "33.3", "angular_efficiency": 0.6, @@ -612,7 +612,7 @@ def test_class3d_wrapper_has_initial_model( }, "buffer_lookup": {"particle_classification_group_id": 5}, "buffer_store": 11, - "class_distribution": "0.6", + "class_distribution": "0.8", "class_image_full_path": ( f"{tmp_path}/Class3D/job015/run_it020_class002.mrc" ), @@ -620,7 +620,7 @@ def test_class3d_wrapper_has_initial_model( "estimated_resolution": 10.0, "ispyb_command": "buffer", "overall_fourier_completeness": 0.9, - "particles_per_class": 60000.0, + "particles_per_class": 3.2, "rotation_accuracy": "20.2", "translation_accuracy": "22.2", "angular_efficiency": 0.6,