Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 68 additions & 22 deletions src/cryoemservices/services/select_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,36 +400,43 @@ def select_classes(self, rw, header: dict, message: dict):
# Determine the next split size to use and whether to run 3D classification
send_to_3d_classification = False
if self.previous_total_count == 0:
# First run of this job, use class3d_max_size
# First run of this job, use class3d_batch_size
next_batch_size = autoselect_params.class3d_batch_size
if self.total_count > autoselect_params.class3d_batch_size:
# Do 3D classification if there are more particles than the batch size
send_to_3d_classification = True
elif self.previous_total_count >= autoselect_params.class3d_max_size:
# Iterations beyond those where 3D classification is run
next_batch_size = autoselect_params.class3d_max_size
else:
# Re-runs with fewer particles than the maximum
previous_batch_multiple = (
self.previous_total_count // autoselect_params.class3d_batch_size
)
new_batch_multiple = (
self.total_count // autoselect_params.class3d_batch_size
)
if new_batch_multiple > previous_batch_multiple:
# Re-runs with doubling particle count each time
if self.previous_total_count >= autoselect_params.class3d_batch_size:
previous_batch_power = int(
np.log(
self.previous_total_count
// autoselect_params.class3d_batch_size
)
/ np.log(2)
)
else:
previous_batch_power = 0
if self.total_count >= autoselect_params.class3d_batch_size:
new_batch_power = int(
np.log(self.total_count // autoselect_params.class3d_batch_size)
/ np.log(2)
)
else:
new_batch_power = 0
if new_batch_power > previous_batch_power:
# Do 3D classification if a batch threshold has been crossed
send_to_3d_classification = True
# Set the batch size from the total count, but do not exceed the maximum
next_batch_size = (
new_batch_multiple * autoselect_params.class3d_batch_size
next_batch_size = int(
np.power(2, new_batch_power) * autoselect_params.class3d_batch_size
)
if next_batch_size > autoselect_params.class3d_max_size:
next_batch_size = autoselect_params.class3d_max_size
else:
# Otherwise just get the next threshold
next_batch_size = (
previous_batch_multiple + 1
) * autoselect_params.class3d_batch_size
next_batch_size = int(
np.power(2, (previous_batch_power + 1))
* autoselect_params.class3d_batch_size
)

# Run the combine star files job to split particles_all.star into batches
split_node_creator_params: dict[str, Any] = {
Expand All @@ -440,7 +447,7 @@ def select_classes(self, rw, header: dict, message: dict):
"command": (
f"combine_star_files {combine_star_dir}/particles_all.star "
f"--output_dir {combine_star_dir} "
f"--split --split_size {next_batch_size}"
f"--split --split_size {min(next_batch_size, autoselect_params.class3d_max_size)}"
),
"stdout": "",
"stderr": "",
Expand All @@ -454,7 +461,7 @@ def select_classes(self, rw, header: dict, message: dict):
split_star_file(
file_to_process=combine_star_dir / "particles_all.star",
output_dir=combine_star_dir,
split_size=next_batch_size,
split_size=min(next_batch_size, autoselect_params.class3d_max_size),
)
split_node_creator_params["success"] = True
except (IndexError, KeyError):
Expand Down Expand Up @@ -557,7 +564,10 @@ def select_classes(self, rw, header: dict, message: dict):
)

# Create 3D classification jobs
if send_to_3d_classification:
if (
send_to_3d_classification
and next_batch_size <= autoselect_params.class3d_max_size
):
# Only send to 3D if a new multiple of the batch threshold is crossed
# and the count has not passed the maximum
self.log.info("Sending to Murfey for Class3D")
Expand All @@ -579,6 +589,28 @@ def select_classes(self, rw, header: dict, message: dict):
"class3d_message": class3d_params,
}
rw.send_to("murfey_feedback", murfey_3d_params)
elif send_to_3d_classification:
found_sample = resample_best_particles(
combine_star_dir / "particles_all.star",
combine_star_dir / f"particles_best_{next_batch_size}.star",
)
if found_sample:
# Tell Murfey to do Class3D
class3d_params = {
"particles_file": f"{combine_star_dir}/particles_best_{next_batch_size}.star",
"class3d_dir": f"{project_dir}/Class3D/job",
"batch_size": next_batch_size,
}
murfey_3d_params = {
"register": "run_class3d",
"class3d_message": class3d_params,
}
rw.send_to("murfey_feedback", murfey_3d_params)
else:
self.log.warning(
"Cannot rerun Class3D as no scores available in "
f"{combine_star_dir}/particles_all.star"
)

murfey_confirmation = {
"register": "done_class_selection",
Expand All @@ -592,3 +624,17 @@ def select_classes(self, rw, header: dict, message: dict):

self.log.info(f"Done {self.job_type} for {autoselect_params.input_file}.")
rw.transport.ack(header)


def resample_best_particles(
    particles_all: Path, output_star: Path, n_best: int = 200000
) -> bool:
    """Write a star file containing the best-scoring particles.

    Reads the particle table from *particles_all* and, when cryoDANN scores
    are present, keeps only the particles scoring strictly above the
    n_best-th highest score. Ties at the cutoff are dropped, so slightly
    fewer than *n_best* particles may be written.

    Args:
        particles_all: Star file holding a "particles" data block with all
            candidate particles.
        output_star: Destination star file for the selected subset.
        n_best: Target number of particles to keep. Defaults to 200000
            (the historical class3d_max_size), so existing callers are
            unchanged.

    Returns:
        True if an output star file was written, False if there is no
        rlnCryodannScore column to rank the particles by.
    """
    data = starfile.read(particles_all)
    particles = data["particles"]
    if "rlnCryodannScore" not in particles.keys():
        # No scores available, so a "best" subset cannot be chosen
        return False

    scores = sorted(particles["rlnCryodannScore"], reverse=True)
    if len(scores) > n_best:
        # Keep particles strictly above the n_best-th highest score.
        # Guarding on the length fixes an IndexError that the unguarded
        # scores[n_best] lookup raised for files with n_best or fewer rows;
        # in that case all particles already fit and are written unchanged.
        cutoff_score = scores[n_best]
        data["particles"] = particles[particles["rlnCryodannScore"] > cutoff_score]
    starfile.write(data, output_star)
    return True
2 changes: 0 additions & 2 deletions src/cryoemservices/util/relion_service_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ class RelionServiceOptions(BaseModel):
autoselect_min_score: float = 0.7
# 2D classification particle batch size
batch_size: int = 50000
# Maximum batch size for the single batch of 3D classification
class3d_max_size: int = 200000
# Initial lowpass filter on 3D reference
initial_lowpass: int = 40

Expand Down
10 changes: 6 additions & 4 deletions src/cryoemservices/wrappers/class3d_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) -
)
return False

# Actual batch size may not be the same as the input value, check it in the data
true_batch_size = class3d_params.batch_size

# Generate healpix image of the particle distribution
logger.info("Generating healpix angular distribution image")
data = cif.read_file(
Expand All @@ -417,6 +420,7 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) -
class_numbers = np.array(
particles_block.find_loop("_rlnClassNumber"), dtype=int
)
true_batch_size = len(angles_rot)

for class_id in range(class3d_params.class3d_nr_classes):
if not len(angles_tilt[class_numbers == class_id + 1]):
Expand Down Expand Up @@ -492,9 +496,7 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) -
f"{class3d_params.class3d_dir}/"
f"run_it{class3d_params.class3d_nr_iter:03}_class{class_id + 1:03}.mrc"
),
"particles_per_class": (
float(classes_loop[class_id, 1]) * class3d_params.batch_size
),
"particles_per_class": float(classes_loop[class_id, 1]) * true_batch_size,
"class_distribution": classes_loop[class_id, 1],
"rotation_accuracy": classes_loop[class_id, 2],
"translation_accuracy": classes_loop[class_id, 3],
Expand Down Expand Up @@ -580,7 +582,7 @@ def run_class3d(class3d_params: Class3DParameters, send_to_rabbitmq: Callable) -
)
for cid in class_sorting:
if (
class3d_params.batch_size == 200000
class3d_params.batch_size >= 200000
and class_resolutions[cid] < 11
and (class_efficiencies[cid] > 0.65 or class3d_params.symmetry != "C1")
):
Expand Down
16 changes: 8 additions & 8 deletions tests/services/test_class3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_class3d_service_has_initial_model(
"data_model_classes\nloop_\n"
"_rlnReferenceImage\n_Fraction\n_Rotation\n_Translation\n"
"_Resolution\n_Completeness\n_OffsetX\n_OffsetY\n"
"1@Class3D/job015/run_it020_classes.mrcs 0.4 30.3 33.3 12.2 1.0 0.6 0.01\n"
"2@Class3D/job015/run_it020_classes.mrcs 0.6 20.2 22.2 10.0 0.9 -0.5 -0.02"
"1@Class3D/job015/run_it020_classes.mrcs 0.2 30.3 33.3 12.2 1.0 0.6 0.01\n"
"2@Class3D/job015/run_it020_classes.mrcs 0.8 20.2 22.2 10.0 0.9 -0.5 -0.02"
)

# Create a recipe wrapper with the test message
Expand Down Expand Up @@ -196,15 +196,15 @@ def test_class3d_service_has_initial_model(
},
"buffer_lookup": {"particle_classification_group_id": 5},
"buffer_store": 10,
"class_distribution": "0.4",
"class_distribution": "0.2",
"class_image_full_path": (
f"{tmp_path}/Class3D/job015/run_it020_class001.mrc"
),
"class_number": 1,
"estimated_resolution": 12.2,
"ispyb_command": "buffer",
"overall_fourier_completeness": 1.0,
"particles_per_class": 40000.0,
"particles_per_class": 0.8,
"rotation_accuracy": "30.3",
"translation_accuracy": "33.3",
"angular_efficiency": 0.6,
Expand All @@ -216,15 +216,15 @@ def test_class3d_service_has_initial_model(
},
"buffer_lookup": {"particle_classification_group_id": 5},
"buffer_store": 11,
"class_distribution": "0.6",
"class_distribution": "0.8",
"class_image_full_path": (
f"{tmp_path}/Class3D/job015/run_it020_class002.mrc"
),
"class_number": 2,
"estimated_resolution": 10.0,
"ispyb_command": "buffer",
"overall_fourier_completeness": 0.9,
"particles_per_class": 60000.0,
"particles_per_class": 3.2,
"rotation_accuracy": "20.2",
"translation_accuracy": "22.2",
"angular_efficiency": 0.6,
Expand Down Expand Up @@ -436,7 +436,7 @@ def test_class3d_service_rerun(
"estimated_resolution": 12.2,
"ispyb_command": "buffer",
"overall_fourier_completeness": 1.0,
"particles_per_class": 40000.0,
"particles_per_class": 1.6,
"rotation_accuracy": "30.3",
"translation_accuracy": "33.3",
"angular_efficiency": 0.6,
Expand All @@ -458,7 +458,7 @@ def test_class3d_service_rerun(
"estimated_resolution": 10.0,
"ispyb_command": "buffer",
"overall_fourier_completeness": 0.9,
"particles_per_class": 60000.0,
"particles_per_class": 2.4,
"rotation_accuracy": "20.2",
"translation_accuracy": "22.2",
"angular_efficiency": 0.6,
Expand Down
63 changes: 59 additions & 4 deletions tests/services/test_select_classes_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,27 @@ def select_classes_common_setup(
particles_file.parent.mkdir(parents=True)
with open(particles_file, "w") as f:
f.write("data_optics\n\nloop_\n_group\nopticsGroup1\n\n")
f.write("data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n")
f.write(
"data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n_rlnCryodannScore\n"
)
for i in range(particles_to_add):
f.write(
f"{i / 100} {i / 100} {i}@Extract/job008/classes.mrcs "
f"MotionCorr/job002/Movies/movie.mrc\n"
f"MotionCorr/job002/Movies/movie.mrc {np.random.random()}\n"
)

if initial_particle_count:
particles_file = job_dir / "Select/job013/particles_all.star"
particles_file.parent.mkdir(parents=True)
with open(particles_file, "w") as f:
f.write("data_optics\n\nloop_\n_group\nopticsGroup1\n\n")
f.write("data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n")
f.write(
"data_particles\n\nloop_\n_x\n_y\n_particle\n_movie\n_rlnCryodannScore\n"
)
for i in range(initial_particle_count):
f.write(
f"{i/100} {i/100} {i}@Extract/job008/classes.mrcs "
f"MotionCorr/job002/Movies/movie.mrc\n"
f"MotionCorr/job002/Movies/movie.mrc {np.random.random()}\n"
)

Path(job_dir / "MotionCorr/job002/Movies").mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -609,6 +613,57 @@ def test_select_classes_service_past_maximum(
assert len(offline_transport.send.call_args_list) == 7


@mock.patch("cryoemservices.services.select_classes.subprocess.run")
def test_select_classes_service_do_batch_past_maximum(
    mock_subprocess, offline_transport, tmp_path
):
    """
    Test the service for the case where the existing particle count exceeds the maximum.
    In this case the next power (400000) is crossed so
    3D classification should be run for a subset
    """
    # Make every subprocess call appear to succeed with benign output
    mock_subprocess().returncode = 0
    mock_subprocess().stdout = "stdout".encode("ascii")
    mock_subprocess().stderr = "stderr".encode("ascii")

    header = {
        "message-id": mock.sentinel,
        "subscription": mock.sentinel,
    }
    # 390000 existing + 20000 new particles crosses the 400000 batch power
    select_test_message, relion_options = select_classes_common_setup(
        tmp_path, initial_particle_count=390000, particles_to_add=20000
    )

    # Set up the mock service and send the message to it
    service = select_classes.SelectClasses(
        environment={"queue": ""}, transport=offline_transport
    )
    service.initializing()
    service.select_classes(None, header=header, message=select_test_message)

    # Check the correct particle counts were found and split files made
    assert service.previous_total_count == 390000
    assert service.total_count == 410000
    assert (tmp_path / "Select/job013/particles_split1.star").is_file()
    assert (tmp_path / "Select/job013/particles_split2.star").is_file()
    # No per-batch files expected once past the maximum batch size
    assert len(list(tmp_path.glob("Select/job013/particles_batch_*"))) == 0

    # Don't bother to check the auto-selection calls here, they are checked above
    # Do check the Murfey 3D calls
    assert len(offline_transport.send.call_args_list) == 8
    # Class3D should be requested on the resampled best-particles subset
    offline_transport.send.assert_any_call(
        "murfey_feedback",
        {
            "register": "run_class3d",
            "class3d_message": {
                "particles_file": f"{tmp_path}/Select/job013/particles_best_400000.star",
                "class3d_dir": f"{tmp_path}/Class3D/job",
                "batch_size": 400000,
            },
        },
    )


def test_parse_combiner_output(offline_transport):
"""
Send test lines to the output parser
Expand Down
Loading