diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 6d6e308ee..8d9ea4819 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -576,31 +576,72 @@ def roh_hmm( debug = self._log.debug resolved_region: Region = parse_single_region(self, region) - del region - debug("compute windowed heterozygosity") - sample_id, sample_set, windows, counts = self._sample_count_het( + # Create params for hashing. + name = "roh_v1" + + params = dict( sample=sample, - region=resolved_region, - site_mask=site_mask, + region=region, window_size=window_size, + site_mask=site_mask, sample_set=sample_set, - chunks=chunks, - inline_array=inline_array, - ) - - debug("compute runs of homozygosity") - df_roh = self._roh_hmm_predict( - windows=windows, - counts=counts, phet_roh=phet_roh, phet_nonroh=phet_nonroh, transition=transition, - window_size=window_size, - sample_id=sample_id, - contig=resolved_region.contig, + chunks=chunks, + inline_array=inline_array, ) + del region + + # The caching struggles with saving variable length strings, so we can just load/save the numeric data, and + # add the strings (sample ID and contig) from user input. + try: + results = self.results_cache_get(name=name, params=params) + df_roh = pd.DataFrame(results) + df_roh["sample_id"] = sample + df_roh["contig"] = resolved_region.contig + + except CacheMiss: + debug("compute windowed heterozygosity") + sample_id, sample_set, windows, counts = self._sample_count_het( + sample=sample, + region=resolved_region, + site_mask=site_mask, + window_size=window_size, + sample_set=sample_set, + chunks=chunks, + inline_array=inline_array, + ) + + debug("compute runs of homozygosity") + df_roh = self._roh_hmm_predict( + windows=windows, + counts=counts, + phet_roh=phet_roh, + phet_nonroh=phet_nonroh, + transition=transition, + window_size=window_size, + sample_id=sample_id, + contig=resolved_region.contig, + ) + + # Specify numeric columns to save to cache. (See above - variable length strings can break the save). + columns_to_save = [ + "roh_start", + "roh_stop", + "roh_length", + "roh_is_marginal", + ] + + # Save cache + self.results_cache_set( + name=name, + params=params, + results={col: df_roh[col].to_numpy() for col in columns_to_save}, + ) + return df_roh @check_types @@ -1308,7 +1349,7 @@ def ihs_gwss( ) -> Tuple[np.ndarray, np.ndarray]: # change this name if you ever change the behaviour of this function, to # invalidate any previously cached data - name = self._ihs_gwss_cache_name + name = "roh" params = dict( contig=contig,