From a1cdee3486b3a5b49b24d35c53e1261d1229c84b Mon Sep 17 00:00:00 2001 From: Mohamed Laarej Date: Tue, 15 Apr 2025 02:56:11 +0100 Subject: [PATCH 01/10] Modernize NumPy random functions, fix mypy errors for issue#756 --- malariagen_data/anoph/sample_metadata.py | 21 +++- malariagen_data/anoph/snp_frq.py | 5 +- tests/anoph/conftest.py | 116 +++++++++++------------ tests/anoph/test_base.py | 5 +- tests/anoph/test_cnv_data.py | 11 ++- tests/anoph/test_g123.py | 31 +++++- tests/anoph/test_genome_features.py | 9 +- tests/anoph/test_genome_sequence.py | 7 +- tests/anoph/test_h12.py | 30 +++++- tests/anoph/test_hap_data.py | 10 +- tests/anoph/test_sample_metadata.py | 5 +- tests/anoph/test_snp_data.py | 11 ++- 12 files changed, 170 insertions(+), 91 deletions(-) diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index 3088508ba..f64b1b06c 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -1,6 +1,17 @@ import io from itertools import cycle -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) import ipyleaflet # type: ignore import numpy as np @@ -11,6 +22,7 @@ from ..util import check_types from . import base_params, map_params, plotly_params from .base import AnophelesBase +from numpy.typing import NDArray class AnophelesSampleMetadata(AnophelesBase): @@ -891,8 +903,11 @@ def _prep_sample_selection_cache_params( # integer indices instead. 
df_samples = self.sample_metadata(sample_sets=sample_sets) sample_query_options = sample_query_options or {} - loc_samples = df_samples.eval(sample_query, **sample_query_options).values - sample_indices = np.nonzero(loc_samples)[0].tolist() + loc_samples = cast( + NDArray[Any], + df_samples.eval(sample_query, **sample_query_options).values, + ) + sample_indices = cast(List[int], np.nonzero(loc_samples)[0].tolist()) return sample_sets, sample_indices diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py index 497b8234a..6d312b8f3 100644 --- a/malariagen_data/anoph/snp_frq.py +++ b/malariagen_data/anoph/snp_frq.py @@ -7,7 +7,6 @@ from numpydoc_decorator import doc # type: ignore import xarray as xr import numba # type: ignore - from .. import veff from ..util import ( check_types, @@ -576,8 +575,8 @@ def snp_allele_frequencies_advanced( raise ValueError("No SNPs remaining after dropping invariant SNPs.") df_variants = df_variants.loc[loc_variant].reset_index(drop=True) - count = np.compress(loc_variant, count, axis=0) - nobs = np.compress(loc_variant, nobs, axis=0) + count = np.compress(loc_variant, count, axis=0).reshape(-1, count.shape[1]) + nobs = np.compress(loc_variant, nobs, axis=0).reshape(-1, nobs.shape[1]) frequency = np.compress(loc_variant, frequency, axis=0) # Set up variant effect annotator. diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index 9f258c295..a6c357f8f 100644 --- a/tests/anoph/conftest.py +++ b/tests/anoph/conftest.py @@ -29,6 +29,9 @@ # real data in GCS, but which is much smaller and so can be used # for faster test runs. 
+# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture(scope="session") def fixture_dir(): @@ -37,10 +40,10 @@ def fixture_dir(): def simulate_contig(*, low, high, base_composition): - size = np.random.randint(low=low, high=high) + size = rng.integers(low=low, high=high) bases = np.array([b"a", b"c", b"g", b"t", b"n", b"A", b"C", b"G", b"T", b"N"]) p = np.array([base_composition[b] for b in bases]) - seq = np.random.choice(bases, size=size, replace=True, p=p) + seq = rng.choice(bases, size=size, replace=True, p=p) return seq @@ -363,7 +366,7 @@ def simulate_site_filters(path, contigs, p_pass, n_sites): for contig in contigs: variants = root.require_group(contig).require_group("variants") size = n_sites[contig] - filter_pass = np.random.choice([False, True], size=size, p=p) + filter_pass = rng.choice([False, True], size=size, p=p) variants.create_dataset(name="filter_pass", data=filter_pass) zarr.consolidate_metadata(path) @@ -386,7 +389,7 @@ def simulate_snp_genotypes( contig_n_sites = n_sites[contig] # Simulate genotype calls. - gt = np.random.choice( + gt = rng.choice( np.arange(4, dtype="i1"), size=(contig_n_sites, n_samples, 2), replace=True, @@ -395,9 +398,7 @@ def simulate_snp_genotypes( # Simulate missing calls. n_calls = contig_n_sites * n_samples - loc_missing = np.random.choice( - [False, True], size=n_calls, replace=True, p=p_missing - ) + loc_missing = rng.choice([False, True], size=n_calls, replace=True, p=p_missing) gt.reshape(-1, 2)[loc_missing] = -1 # Store genotype calls. 
@@ -438,7 +439,7 @@ def simulate_site_annotations(path, genome): p = [0.897754, 0.0, 0.060577, 0.014287, 0.011096, 0.016286] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # codon_nonsyn @@ -447,7 +448,7 @@ def simulate_site_annotations(path, genome): p = [0.91404, 0.001646, 0.018698, 0.065616] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # codon_position @@ -456,7 +457,7 @@ def simulate_site_annotations(path, genome): p = [0.897754, 0.034082, 0.034082, 0.034082] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # seq_cls @@ -477,28 +478,28 @@ def simulate_site_annotations(path, genome): ] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # seq_flen grp = root.require_group("seq_flen") for contig in contigs: size = genome[contig].shape[0] - x = np.random.randint(low=0, high=40_000, size=size) + x = rng.integers(low=0, high=40_000, size=size) grp.create_dataset(name=contig, data=x) # seq_relpos_start grp = root.require_group("seq_relpos_start") for contig in contigs: size = genome[contig].shape[0] - x = np.random.beta(a=0.4, b=4, size=size) * 40_000 + x = rng.beta(a=0.4, b=4, size=size) * 40_000 grp.create_dataset(name=contig, data=x) # seq_relpos_stop grp = root.require_group("seq_relpos_stop") for contig in contigs: size = genome[contig].shape[0] - x = np.random.beta(a=0.4, b=4, size=size) * 40_000 + x = rng.beta(a=0.4, b=4, size=size) * 40_000 
grp.create_dataset(name=contig, data=x) zarr.consolidate_metadata(path) @@ -514,7 +515,7 @@ def simulate_hap_sites(path, contigs, snp_sites, p_site): # Simulate POS. snp_pos = snp_sites[f"{contig}/variants/POS"][:] - loc_hap_sites = np.random.choice( + loc_hap_sites = rng.choice( [False, True], size=snp_pos.shape[0], p=[1 - p_site, p_site] ) pos = snp_pos[loc_hap_sites] @@ -527,7 +528,7 @@ def simulate_hap_sites(path, contigs, snp_sites, p_site): # Simulate ALT. snp_alt = snp_sites[f"{contig}/variants/ALT"][:] - sim_alt_choice = np.random.choice(3, size=pos.shape[0]) + sim_alt_choice = rng.choice(3, size=pos.shape[0]) alt = np.take_along_axis( snp_alt[loc_hap_sites], indices=sim_alt_choice[:, None], axis=1 )[:, 0] @@ -547,8 +548,8 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high): for contig_index, contig in enumerate(contigs): # Simulate AIM positions variable. snp_pos = snp_sites[f"{contig}/variants/POS"][:] - loc_aim_sites = np.random.choice( - snp_pos.shape[0], size=np.random.randint(n_sites_low, n_sites_high) + loc_aim_sites = rng.choice( + snp_pos.shape[0], size=rng.integers(n_sites_low, n_sites_high) ) loc_aim_sites.sort() aim_pos = snp_pos[loc_aim_sites] @@ -564,10 +565,7 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high): snp_alleles = np.concatenate([snp_ref[:, None], snp_alt], axis=1) aim_site_snp_alleles = snp_alleles[loc_aim_sites] sim_allele_choice = np.vstack( - [ - np.random.choice(4, size=2, replace=False) - for _ in range(len(loc_aim_sites)) - ] + [rng.choice(4, size=2, replace=False) for _ in range(len(loc_aim_sites))] ) aim_alleles = np.take_along_axis( aim_site_snp_alleles, indices=sim_allele_choice, axis=1 @@ -612,7 +610,7 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes): # - samples [1D array] [str] # Get a random probability for a sample being high variance, between 0 and 1. 
- p_variance = np.random.random() + p_variance = rng.random() # Open a zarr at the specified path. root = zarr.open(zarr_path, mode="w") @@ -626,11 +624,11 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes): n_samples = len(df_samples) # Simulate sample_coverage_variance array. - sample_coverage_variance = np.random.uniform(low=0, high=0.5, size=n_samples) + sample_coverage_variance = rng.uniform(low=0, high=0.5, size=n_samples) root.create_dataset(name="sample_coverage_variance", data=sample_coverage_variance) # Simulate sample_is_high_variance array. - sample_is_high_variance = np.random.choice( + sample_is_high_variance = rng.choice( [False, True], size=n_samples, p=[1 - p_variance, p_variance] ) root.create_dataset(name="sample_is_high_variance", data=sample_is_high_variance) @@ -661,9 +659,9 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes): ) # Simulate CN, NormCov, RawCov under calldata. - cn = np.random.randint(low=-1, high=12, size=(n_windows, n_samples)) - normCov = np.random.randint(low=0, high=356, size=(n_windows, n_samples)) - rawCov = np.random.randint(low=-1, high=18465, size=(n_windows, n_samples)) + cn = rng.integers(low=-1, high=12, size=(n_windows, n_samples)) + normCov = rng.integers(low=0, high=356, size=(n_windows, n_samples)) + rawCov = rng.integers(low=-1, high=18465, size=(n_windows, n_samples)) calldata_grp.create_dataset(name="CN", data=cn) calldata_grp.create_dataset(name="NormCov", data=normCov) calldata_grp.create_dataset(name="RawCov", data=rawCov) @@ -705,13 +703,13 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) # - POS [1D array] [int for n_variants] # Get a random probability for choosing allele 1, between 0 and 1. - p_allele = np.random.random() + p_allele = rng.random() # Get a random probability for passing a particular SNP site (position), between 0 and 1. 
- p_filter_pass = np.random.random() + p_filter_pass = rng.random() # Get a random probability for applying qMerge filter to a particular SNP site (position), between 0 and 1. - p_filter_qMerge = np.random.random() + p_filter_qMerge = rng.random() # Open a zarr at the specified path. root = zarr.open(zarr_path, mode="w") @@ -733,17 +731,15 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) contig_length_bp = contig_sizes[contig] # Get a random number of CNV alleles ("variants") to simulate. - n_cnv_alleles = np.random.randint(1, 5_000) + n_cnv_alleles = rng.integers(1, 5_000) # Produce a set of random start positions for each allele as a sorted list. - allele_start_pos = sorted( - np.random.randint(1, contig_length_bp, size=n_cnv_alleles) - ) + allele_start_pos = sorted(rng.integers(1, contig_length_bp, size=n_cnv_alleles)) # Produce a set of random allele lengths for each allele, according to a range. allele_length_bp_min = 100 allele_length_bp_max = 100_000 - allele_lengths_bp = np.random.randint( + allele_lengths_bp = rng.integers( allele_length_bp_min, allele_length_bp_max, size=n_cnv_alleles ) @@ -755,7 +751,7 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) # Simulate the genotype calls. # Note: this is only 2D, unlike SNP, HAP, AIM GT which are 3D - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_cnv_alleles, n_samples), replace=True, @@ -772,8 +768,8 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) variants_grp = contig_grp.require_group("variants") # Simulate the CIEND and CIPOS arrays under variants. 
- ciend = np.random.randint(low=0, high=13200, size=n_cnv_alleles) - cipos = np.random.randint(low=0, high=37200, size=n_cnv_alleles) + ciend = rng.integers(low=0, high=13200, size=n_cnv_alleles) + cipos = rng.integers(low=0, high=37200, size=n_cnv_alleles) variants_grp.create_dataset(name="CIEND", data=ciend) variants_grp.create_dataset(name="CIPOS", data=cipos) @@ -787,10 +783,10 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) variants_grp.create_dataset(name="ID", data=variant_IDs) # Simulate the filters under variants. - filter_pass = np.random.choice( + filter_pass = rng.choice( [False, True], size=n_cnv_alleles, p=[1 - p_filter_pass, p_filter_pass] ) - filter_qMerge = np.random.choice( + filter_qMerge = rng.choice( [False, True], size=n_cnv_alleles, p=[1 - p_filter_qMerge, p_filter_qMerge] ) variants_grp.create_dataset(name="FILTER_PASS", data=filter_pass) @@ -806,6 +802,8 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig_sizes): + # Initialize a default RNG with a fixed seed for general random calls + default_rng = np.random.default_rng(seed=123) # Arbitrary seed for reproducibility # zarr_path is the output path to the zarr store # metadata_path is the input path for the sample metadata # contigs is the list of contigs, e.g. Ag has ('2R', '3R', 'X') @@ -828,10 +826,10 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig # - samples [1D array] [str for n_samples] # Get a random probability for a sample being high variance, between 0 and 1. - p_variance = np.random.random() + p_variance = default_rng.random() # Get a random probability for choosing allele 1, between 0 and 1. - p_allele = np.random.random() + p_allele = default_rng.random() # Open a zarr at the specified path. 
root = zarr.open(zarr_path, mode="w") @@ -845,11 +843,11 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig n_samples = len(df_samples) # Simulate sample_coverage_variance array. - sample_coverage_variance = np.random.uniform(low=0, high=0.5, size=n_samples) + sample_coverage_variance = default_rng.uniform(low=0, high=0.5, size=n_samples) root.create_dataset(name="sample_coverage_variance", data=sample_coverage_variance) # Simulate sample_is_high_variance array. - sample_is_high_variance = np.random.choice( + sample_is_high_variance = default_rng.choice( [False, True], size=n_samples, p=[1 - p_variance, p_variance] ) root.create_dataset(name="sample_is_high_variance", data=sample_is_high_variance) @@ -864,7 +862,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig for i, contig in enumerate(contigs): # Use the same random seed per contig, otherwise n_cnv_variants (and shapes) will not align. unique_seed = fixed_seed + i - np.random.seed(unique_seed) + rng = np.random.default_rng(seed=unique_seed) # Create the contig group. contig_grp = root.require_group(contig) @@ -876,17 +874,17 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig contig_length_bp = contig_sizes[contig] # Get a random number of CNV variants to simulate. - n_cnv_variants = np.random.randint(1, 100) + n_cnv_variants = rng.integers(1, 100) # Produce a set of random start positions for each variant as a sorted list. variant_start_pos = sorted( - np.random.randint(1, contig_length_bp, size=n_cnv_variants) + rng.integers(1, contig_length_bp, size=n_cnv_variants) ) # Produce a set of random lengths for each variant, according to a range. 
variant_length_bp_min = 100 variant_length_bp_max = 100_000 - variant_lengths_bp = np.random.randint( + variant_lengths_bp = rng.integers( variant_length_bp_min, variant_length_bp_max, size=n_cnv_variants ) @@ -898,7 +896,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig # Simulate the genotype calls. # Note: this is only 2D, unlike SNP, HAP, AIM GT which are 3D - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_cnv_variants, n_samples), replace=True, @@ -915,8 +913,8 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig variants_grp = contig_grp.require_group("variants") # Simulate the StartBreakpointMethod and EndBreakpointMethod arrays. - startBreakpointMethod = np.random.randint(low=-1, high=1, size=n_cnv_variants) - endBreakpointMethod = np.random.randint(low=-1, high=2, size=n_cnv_variants) + startBreakpointMethod = rng.integers(low=-1, high=1, size=n_cnv_variants) + endBreakpointMethod = rng.integers(low=-1, high=2, size=n_cnv_variants) variants_grp.create_dataset( name="StartBreakpointMethod", data=startBreakpointMethod ) @@ -1567,7 +1565,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_sites, n_samples, 2), replace=True, @@ -1598,7 +1596,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_sites, n_samples, 2), replace=True, @@ -1629,7 +1627,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], 
dtype="i1"), size=(n_sites, n_samples, 2), replace=True, @@ -1695,7 +1693,7 @@ def init_aim_calls(self): ds["sample_id"] = ("samples",), df_samples["sample_id"] # Add call_genotype variable. - gt = np.random.choice( + gt = rng.choice( np.arange(2, dtype="i1"), size=(ds.sizes["variants"], ds.sizes["samples"], 2), replace=True, @@ -2190,7 +2188,7 @@ def init_hap_sites(self): path=path, contigs=self.contigs, snp_sites=self.snp_sites, - p_site=np.random.random(), + p_site=rng.random(), ) def init_haplotypes(self): @@ -2217,7 +2215,7 @@ def init_haplotypes(self): # Simulate haplotypes. analysis = "funestus" - p_1 = np.random.random() + p_1 = rng.random() samples = df_samples["sample_id"].values self.phasing_samples[sample_set, analysis] = samples n_samples = len(samples) @@ -2233,7 +2231,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_sites, n_samples, 2), replace=True, diff --git a/tests/anoph/test_base.py b/tests/anoph/test_base.py index 016757011..7ca1eb730 100644 --- a/tests/anoph/test_base.py +++ b/tests/anoph/test_base.py @@ -8,6 +8,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.base import AnophelesBase +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -208,7 +211,7 @@ def test_lookup_study(fixture, api): # Set up test. 
df_sample_sets = api.sample_sets() all_sample_sets = df_sample_sets["sample_set"].values - sample_set = np.random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) study_rec_by_sample_set = api.lookup_study(sample_set) df_sample_set = df_sample_sets.set_index("sample_set").loc[sample_set] diff --git a/tests/anoph/test_cnv_data.py b/tests/anoph/test_cnv_data.py index fb960c93b..36f4f1476 100644 --- a/tests/anoph/test_cnv_data.py +++ b/tests/anoph/test_cnv_data.py @@ -13,6 +13,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.cnv_data import AnophelesCnvData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -423,7 +426,7 @@ def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData): region = fixture.random_contig() # Parametrize max_coverage_variance. - parametrize_max_coverage_variance = np.random.uniform(low=0, high=1, size=4) + parametrize_max_coverage_variance = rng.uniform(low=0, high=1, size=4) for max_coverage_variance in parametrize_max_coverage_variance: ds = api.cnv_hmm( @@ -808,7 +811,7 @@ def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData): region = fixture.random_contig() df_samples = api.sample_metadata(sample_sets=sample_set) all_sample_ids = df_samples["sample_id"].values - sample_id = np.random.choice(all_sample_ids) + sample_id = rng.choice(all_sample_ids) fig = api.plot_cnv_hmm_coverage_track( sample=sample_id, @@ -857,11 +860,11 @@ def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData): def test_plot_cnv_hmm_coverage(fixture, api: AnophelesCnvData): # Set up test. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) region = fixture.random_contig() df_samples = api.sample_metadata(sample_sets=sample_set) all_sample_ids = df_samples["sample_id"].values - sample_id = np.random.choice(all_sample_ids) + sample_id = rng.choice(all_sample_ids) fig = api.plot_cnv_hmm_coverage( sample=sample_id, diff --git a/tests/anoph/test_g123.py b/tests/anoph/test_g123.py index 59b5936ca..1b567226f 100644 --- a/tests/anoph/test_g123.py +++ b/tests/anoph/test_g123.py @@ -8,6 +8,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.g123 import AnophelesG123Analysis +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -178,16 +181,34 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis): api.g123_gwss(**g123_params) +def ensure_int_list(value): + """Convert a value to a list of integers, flattening nested structures if needed.""" + if isinstance(value, int): + return [value] + + result = [] + + def extract_ints(item): + if isinstance(item, int): + result.append(item) + elif isinstance(item, (list, tuple)): + for subitem in item: + extract_ints(subitem) + + extract_ints(value) + return result + + @parametrize_with_cases("fixture,api", cases=".") def test_g123_calibration(fixture, api: AnophelesG123Analysis): # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = np.random.randint(100, 500, size=random.randint(2, 5)).tolist() - window_sizes = sorted([int(x) for x in window_sizes]) + window_sizes = rng.integers(100, 500, size=random.randint(2, 5)).tolist() + window_sizes = sorted(ensure_int_list(window_sizes)) g123_params = dict( - contig=random.choice(api.contigs), - sites=random.choice(api.phasing_analysis_ids), - sample_sets=[random.choice(all_sample_sets)], + contig=rng.choice(api.contigs), + sites=rng.choice(api.phasing_analysis_ids), + sample_sets=[rng.choice(all_sample_sets)], min_cohort_size=10, window_sizes=window_sizes, ) diff --git a/tests/anoph/test_genome_features.py b/tests/anoph/test_genome_features.py index 1492af613..712742619 100644 --- a/tests/anoph/test_genome_features.py +++ b/tests/anoph/test_genome_features.py @@ -10,6 +10,9 @@ from malariagen_data.anoph.genome_features import AnophelesGenomeFeaturesData from malariagen_data.util import Region, resolve_region +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -143,7 +146,7 @@ def test_plot_genes_with_gene_labels(fixture, api: AnophelesGenomeFeaturesData): # If there are no genes, we cannot label them. if not genes_df.empty: # Get a random number of genes to sample. - random_genes_n = np.random.randint(low=1, high=len(genes_df) + 1) + random_genes_n = rng.integers(low=1, high=len(genes_df) + 1) # Get a random sample of genes. 
random_sample_genes_df = genes_df.sample(n=random_genes_n) @@ -164,7 +167,7 @@ def test_plot_genes_with_gene_labels(fixture, api: AnophelesGenomeFeaturesData): def test_plot_transcript(fixture, api: AnophelesGenomeFeaturesData): for contig in fixture.contigs: df_transcripts = api.genome_features(region=contig).query("type == 'mRNA'") - transcript = np.random.choice(df_transcripts["ID"].values) + transcript = rng.choice(df_transcripts["ID"].values) fig = api.plot_transcript(transcript=transcript, show=False) assert isinstance(fig, bokeh.plotting.figure) @@ -209,7 +212,7 @@ def test_genome_features_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" df = api.genome_features(region=region) assert isinstance(df, pd.DataFrame) diff --git a/tests/anoph/test_genome_sequence.py b/tests/anoph/test_genome_sequence.py index 75d80017b..a69bdf8f6 100644 --- a/tests/anoph/test_genome_sequence.py +++ b/tests/anoph/test_genome_sequence.py @@ -10,6 +10,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.genome_sequence import AnophelesGenomeSequenceData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -77,7 +80,7 @@ def test_genome_sequence_region(fixture, api): for contig in fixture.contigs: contig_seq = api.genome_sequence(region=contig) # Pick a random start and stop position. 
-    start, stop = sorted(np.random.randint(low=1, high=len(contig_seq), size=2))
+    start, stop = sorted(rng.integers(low=1, high=len(contig_seq), size=2))
     region = f"{contig}:{start:,}-{stop:,}"
     seq = api.genome_sequence(region=region)
     assert isinstance(seq, da.Array)
@@ -116,7 +119,7 @@ def test_genome_sequence_virtual_contigs(ag3_sim_api, chrom):
     )
 
     # Test with region.
-    start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2))
+    start, stop = sorted(rng.integers(low=1, high=len(seq), size=2))
     region = f"{chrom}:{start:,}-{stop:,}"
     seq_region = api.genome_sequence(region=region)
     assert isinstance(seq_region, da.Array)
diff --git a/tests/anoph/test_h12.py b/tests/anoph/test_h12.py
index 26ff07147..78a5d0360 100644
--- a/tests/anoph/test_h12.py
+++ b/tests/anoph/test_h12.py
@@ -9,6 +9,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.h12 import AnophelesH12Analysis, haplotype_frequencies
 
+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+
 
 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -98,15 +101,34 @@ def test_haplotype_frequencies():
     assert_allclose(vals, np.array([0.2, 0.2, 0.2, 0.4]))
 
 
+def ensure_int_list(value):
+    """Convert a value to a list of integers, flattening nested structures if needed."""
+    if isinstance(value, int):
+        return [value]
+
+    result = []
+
+    def extract_ints(item):
+        if isinstance(item, int):
+            result.append(item)
+        elif isinstance(item, (list, tuple)):
+            for subitem in item:
+                extract_ints(subitem)
+
+    extract_ints(value)
+    return result
+
+
 @parametrize_with_cases("fixture,api", cases=".")
 def test_h12_calibration(fixture, api: AnophelesH12Analysis):
     # Set up test parameters.
all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = np.random.randint(100, 500, size=random.randint(2, 5)).tolist() - window_sizes = sorted(set([int(x) for x in window_sizes])) + window_sizes = rng.integers(100, 500, size=random.randint(2, 5)).tolist() + # Convert window_sizes to a flattened list of integers + window_sizes = sorted(set(ensure_int_list(window_sizes))) h12_params = dict( - contig=random.choice(api.contigs), - sample_sets=[random.choice(all_sample_sets)], + contig=rng.choice(api.contigs), + sample_sets=[rng.choice(all_sample_sets)], window_sizes=window_sizes, min_cohort_size=5, ) diff --git a/tests/anoph/test_hap_data.py b/tests/anoph/test_hap_data.py index 27d71775e..f054eb243 100644 --- a/tests/anoph/test_hap_data.py +++ b/tests/anoph/test_hap_data.py @@ -12,6 +12,10 @@ from malariagen_data.anoph.hap_data import AnophelesHapData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): return AnophelesHapData( @@ -603,7 +607,9 @@ def test_haplotypes_virtual_contigs( # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted( + [int(x) for x in rng.integers(low=1, high=len(seq), size=2)] + ) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. @@ -675,7 +681,7 @@ def test_haplotype_sites_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. 
diff --git a/tests/anoph/test_sample_metadata.py b/tests/anoph/test_sample_metadata.py index e5b8ec8eb..c0ef29f80 100644 --- a/tests/anoph/test_sample_metadata.py +++ b/tests/anoph/test_sample_metadata.py @@ -14,6 +14,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.sample_metadata import AnophelesSampleMetadata +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -972,7 +975,7 @@ def test_lookup_sample(fixture, api): # Set up test. df_samples = api.sample_metadata() all_sample_ids = df_samples["sample_id"].values - sample_id = np.random.choice(all_sample_ids) + sample_id = rng.choice(all_sample_ids) # Check we get the same sample_id back. sample_rec_by_sample_id = api.lookup_sample(sample_id) diff --git a/tests/anoph/test_snp_data.py b/tests/anoph/test_snp_data.py index c4a611576..aa3c5ad3e 100644 --- a/tests/anoph/test_snp_data.py +++ b/tests/anoph/test_snp_data.py @@ -16,6 +16,9 @@ from malariagen_data.anoph.base_params import DEFAULT from malariagen_data.anoph.snp_data import AnophelesSnpData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -254,7 +257,7 @@ def test_snp_sites_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. @@ -563,7 +566,7 @@ def test_snp_genotypes_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. 
seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. check_snp_genotypes(api, region=region) @@ -588,7 +591,7 @@ def test_snp_variants_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" pos = api.snp_sites(region=region, field="POS").compute() ds_region = api.snp_variants(region=region) @@ -914,7 +917,7 @@ def test_snp_calls_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. 
From e7ef1206e03c2c5575b3033bc343f1f5f03852a3 Mon Sep 17 00:00:00 2001 From: Mohamed Laarej Date: Thu, 1 May 2025 17:07:54 +0100 Subject: [PATCH 02/10] Updates tests to consistently use the seeded NumPy random number generator (rng) instead of legacy np.random or Python's random module and unpins the NumPy version in pyproject.toml --- pyproject.toml | 2 +- tests/anoph/test_cnv_frq.py | 12 +++++++----- tests/anoph/test_distance.py | 9 ++++++--- tests/anoph/test_frq.py | 6 ++++-- tests/anoph/test_fst.py | 14 ++++++++------ tests/anoph/test_g123.py | 12 ++++++------ tests/anoph/test_h12.py | 14 +++++++------- tests/anoph/test_h1x.py | 6 ++++-- tests/anoph/test_hap_data.py | 10 +++++----- tests/anoph/test_hap_frq.py | 6 ++++-- tests/anoph/test_pca.py | 22 ++++++++++++---------- tests/anoph/test_plink_converter.py | 8 +++++--- tests/anoph/test_snp_data.py | 12 ++++++------ tests/anoph/test_snp_frq.py | 21 ++++++++++++--------- 14 files changed, 87 insertions(+), 67 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e908a1d8..e83c9e80d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.10,<3.13" -numpy = "<2.2" +numpy = "*" numba = ">=0.60.0" llvmlite = "*" scipy = "*" diff --git a/tests/anoph/test_cnv_frq.py b/tests/anoph/test_cnv_frq.py index c90dd49f4..cd7a851cf 100644 --- a/tests/anoph/test_cnv_frq.py +++ b/tests/anoph/test_cnv_frq.py @@ -19,6 +19,8 @@ check_plot_frequencies_interactive_map, ) +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -97,7 +99,7 @@ def test_gene_cnv_frequencies_with_str_cohorts( region = random.choice(api.contigs) all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) # Set up call params. 
params = dict( @@ -302,7 +304,7 @@ def test_gene_cnv_frequencies_with_dict_cohorts( ): # Pick test parameters at random. sample_sets = None # all sample sets - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = random.choice(api.contigs) # Create cohorts by country. @@ -343,7 +345,7 @@ def test_gene_cnv_frequencies_without_drop_invariant( # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = random.choice(api.contigs) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -398,7 +400,7 @@ def test_gene_cnv_frequencies_with_bad_region( # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) # Set up call params. @@ -718,7 +720,7 @@ def check_gene_cnv_frequencies_advanced( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) # Run function under test. 
ds = api.gene_cnv_frequencies_advanced( diff --git a/tests/anoph/test_distance.py b/tests/anoph/test_distance.py index 9091ee454..c8695855f 100644 --- a/tests/anoph/test_distance.py +++ b/tests/anoph/test_distance.py @@ -11,6 +11,9 @@ from malariagen_data.anoph import pca_params +rng = np.random.default_rng(seed=42) + + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): return AnophelesDistanceAnalysis( @@ -81,7 +84,7 @@ def check_biallelic_diplotype_pairwise_distance(*, api, data_params, metric): ds = api.biallelic_snp_calls(**data_params) n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # Run the distance computation. dist, samples, n_snps_used = api.biallelic_diplotype_pairwise_distances( @@ -143,7 +146,7 @@ def check_njt(*, api, data_params, metric, algorithm): ds = api.biallelic_snp_calls(**data_params) n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # Run the distance computation. Z, samples, n_snps_used = api.njt( @@ -232,7 +235,7 @@ def test_plot_njt(fixture, api: AnophelesDistanceAnalysis): # Check available data. ds = api.biallelic_snp_calls(**data_params) n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # Exercise the function. 
for color, symbol in zip(colors, symbols): diff --git a/tests/anoph/test_frq.py b/tests/anoph/test_frq.py index cf972c83b..1f390e334 100644 --- a/tests/anoph/test_frq.py +++ b/tests/anoph/test_frq.py @@ -1,8 +1,10 @@ import pytest import plotly.graph_objects as go # type: ignore - +import numpy as np import random +rng = np.random.default_rng(seed=42) + def check_plot_frequencies_heatmap(api, frq_df): fig = api.plot_frequencies_heatmap(frq_df, show=False, max_len=None) @@ -65,7 +67,7 @@ def check_plot_frequencies_time_series_with_areas(api, ds): # Pick a random area and areas from valid areas. cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist() area = random.choice(cohorts_areas) - areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas))) + areas = random.sample(cohorts_areas, rng.integers(1, len(cohorts_areas))) # Plot with area. fig = api.plot_frequencies_time_series(ds, show=False, areas=area) diff --git a/tests/anoph/test_fst.py b/tests/anoph/test_fst.py index 098f08538..520f4f4b5 100644 --- a/tests/anoph/test_fst.py +++ b/tests/anoph/test_fst.py @@ -11,6 +11,8 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.fst import AnophelesFstAnalysis +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -91,7 +93,7 @@ def test_fst_gwss(fixture, api: AnophelesFstAnalysis): cohort1_query=cohort1_query, cohort2_query=cohort2_query, site_mask=random.choice(api.site_mask_ids), - window_size=random.randint(10, 50), + window_size=rng.integers(10, 50), min_cohort_size=1, ) @@ -131,7 +133,7 @@ def test_average_fst(fixture, api: AnophelesFstAnalysis): cohort2_query=cohort2_query, site_mask=random.choice(api.site_mask_ids), min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run main gwss function under test. 
@@ -229,7 +231,7 @@ def test_pairwise_average_fst_with_str_cohorts( sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. @@ -249,7 +251,7 @@ def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAna sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=15, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. @@ -270,7 +272,7 @@ def test_pairwise_average_fst_with_dict_cohorts(fixture, api: AnophelesFstAnalys sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. @@ -294,7 +296,7 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalys sample_query=sample_query, site_mask=site_mask, min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. diff --git a/tests/anoph/test_g123.py b/tests/anoph/test_g123.py index 26fb82ef6..5e609d816 100644 --- a/tests/anoph/test_g123.py +++ b/tests/anoph/test_g123.py @@ -108,7 +108,7 @@ def test_g123_gwss_with_default_sites(fixture, api: AnophelesG123Analysis): g123_params = dict( contig=random.choice(api.contigs), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -124,7 +124,7 @@ def test_g123_gwss_with_phased_sites(fixture, api: AnophelesG123Analysis): contig=random.choice(api.contigs), sites=random.choice(api.phasing_analysis_ids), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -141,7 +141,7 @@ def test_g123_gwss_with_segregating_sites(fixture, api: AnophelesG123Analysis): sites="segregating", site_mask=random.choice(api.site_mask_ids), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 
500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -158,7 +158,7 @@ def test_g123_gwss_with_all_sites(fixture, api: AnophelesG123Analysis): sites="all", site_mask=None, sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -173,7 +173,7 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis): g123_params = dict( contig=random.choice(api.contigs), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, sites="foobar", ) @@ -205,7 +205,7 @@ def extract_ints(item): def test_g123_calibration(fixture, api: AnophelesG123Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = rng.integers(100, 500, size=random.randint(2, 5)).tolist() + window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist() window_sizes = sorted(ensure_int_list(window_sizes)) g123_params = dict( contig=rng.choice(api.contigs), diff --git a/tests/anoph/test_h12.py b/tests/anoph/test_h12.py index a834ebf49..b410d08c4 100644 --- a/tests/anoph/test_h12.py +++ b/tests/anoph/test_h12.py @@ -125,7 +125,7 @@ def extract_ints(item): def test_h12_calibration(fixture, api: AnophelesH12Analysis): # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = rng.integers(100, 500, size=random.randint(2, 5)).tolist() + window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist() # Convert window_sizes to a flattened list of integers window_sizes = sorted(set(ensure_int_list(window_sizes))) h12_params = dict( @@ -194,7 +194,7 @@ def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis): h12_params = dict( contig=random.choice(api.contigs), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=5, ) @@ -208,7 +208,7 @@ def test_h12_gwss_with_analysis(fixture, api: AnophelesH12Analysis): all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = [random.choice(all_sample_sets)] contig = random.choice(api.contigs) - window_size = random.randint(100, 500) + window_size = rng.integers(100, 500) for analysis in api.phasing_analysis_ids: # Check if any samples available for the given phasing analysis. 
@@ -262,7 +262,7 @@ def test_h12_gwss_multi_with_default_analysis(fixture, api: AnophelesH12Analysis h12_params = dict( contig=random.choice(api.contigs), sample_sets=all_sample_sets, - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=1, cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, ) @@ -283,8 +283,8 @@ def test_h12_gwss_multi_with_window_size_dict(fixture, api: AnophelesH12Analysis contig=random.choice(api.contigs), sample_sets=all_sample_sets, window_size={ - "cohort1": random.randint(100, 500), - "cohort2": random.randint(100, 500), + "cohort1": rng.integers(100, 500), + "cohort2": rng.integers(100, 500), }, min_cohort_size=1, cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, @@ -335,7 +335,7 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis): analysis=analysis, contig=contig, sample_sets=all_sample_sets, - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=min(n1, n2), cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, ) diff --git a/tests/anoph/test_h1x.py b/tests/anoph/test_h1x.py index 627717b57..dc66ebc42 100644 --- a/tests/anoph/test_h1x.py +++ b/tests/anoph/test_h1x.py @@ -9,6 +9,8 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.h1x import AnophelesH1XAnalysis, haplotype_joint_frequencies +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -147,7 +149,7 @@ def test_h1x_gwss_with_default_analysis(fixture, api: AnophelesH1XAnalysis): h1x_params = dict( contig=random.choice(api.contigs), sample_sets=all_sample_sets, - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=1, cohort1_query=cohort1_query, cohort2_query=cohort2_query, @@ -198,7 +200,7 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis): analysis=analysis, contig=contig, sample_sets=all_sample_sets, - 
window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=min(n1, n2), cohort1_query=cohort1_query, cohort2_query=cohort2_query, diff --git a/tests/anoph/test_hap_data.py b/tests/anoph/test_hap_data.py index 3d8303028..16c411154 100644 --- a/tests/anoph/test_hap_data.py +++ b/tests/anoph/test_hap_data.py @@ -470,7 +470,7 @@ def test_haplotypes_with_cohort_size_param( analysis = api.phasing_analysis_ids[0] # Parametrize over cohort_size. - parametrize_cohort_size = [random.randint(1, 10), random.randint(10, 50), 1_000] + parametrize_cohort_size = [rng.integers(1, 10), rng.integers(10, 50), 1_000] for cohort_size in parametrize_cohort_size: check_haplotypes( fixture=fixture, @@ -497,8 +497,8 @@ def test_haplotypes_with_min_cohort_size_param( # Parametrize over min_cohort_size. parametrize_min_cohort_size = [ - random.randint(1, 10), - random.randint(10, 50), + rng.integers(1, 10), + rng.integers(10, 50), 1_000, ] for min_cohort_size in parametrize_min_cohort_size: @@ -527,8 +527,8 @@ def test_haplotypes_with_max_cohort_size_param( # Parametrize over max_cohort_size. parametrize_max_cohort_size = [ - random.randint(1, 10), - random.randint(10, 50), + rng.integers(1, 10), + rng.integers(10, 50), 1_000, ] for max_cohort_size in parametrize_max_cohort_size: diff --git a/tests/anoph/test_hap_frq.py b/tests/anoph/test_hap_frq.py index 97212087a..689583010 100644 --- a/tests/anoph/test_hap_frq.py +++ b/tests/anoph/test_hap_frq.py @@ -17,6 +17,8 @@ check_plot_frequencies_interactive_map, ) +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -168,7 +170,7 @@ def test_hap_frequencies_with_str_cohorts( # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = fixture.random_region_str() # Set up call params. 
@@ -210,7 +212,7 @@ def test_hap_frequencies_advanced( ): all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = fixture.random_region_str() # Set up call params. diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py index e5fa667a6..9ce044e45 100644 --- a/tests/anoph/test_pca.py +++ b/tests/anoph/test_pca.py @@ -11,6 +11,8 @@ from malariagen_data.anoph.pca import AnophelesPca from malariagen_data.anoph import pca_params +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -95,10 +97,10 @@ def test_pca_plotting(fixture, api: AnophelesPca): # PCA parameters. n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # PC3 required for plot_pca_coords_3d() assert min(n_samples, n_snps) > 3 - n_components = random.randint(3, min(n_samples, n_snps, 10)) + n_components = rng.integers(3, min(n_samples, n_snps, 10)) # Run the PCA. pca_df, pca_evr = api.pca( @@ -179,15 +181,15 @@ def test_pca_exclude_samples(fixture, api: AnophelesPca): ) # Exclusion parameters. - n_samples_excluded = random.randint(1, 5) + n_samples_excluded = rng.integers(1, 5) samples = ds["sample_id"].values.tolist() - exclude_samples = random.sample(samples, n_samples_excluded) + exclude_samples = random.sample(samples, int(n_samples_excluded)) # convert to int # PCA parameters. n_samples = ds.sizes["samples"] - n_samples_excluded n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) - n_components = random.randint(2, min(n_samples, n_snps, 10)) + n_snps = rng.integers(4, n_snps_available) + n_components = rng.integers(2, min(n_samples, n_snps, 10)) # Run the PCA. 
pca_df, pca_evr = api.pca( @@ -240,15 +242,15 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca): ) # Exclusion parameters. - n_samples_excluded = random.randint(1, 5) + n_samples_excluded = rng.integers(1, 5) samples = ds["sample_id"].values.tolist() - exclude_samples = random.sample(samples, n_samples_excluded) + exclude_samples = random.sample(samples, int(n_samples_excluded)) # convert to int # PCA parameters. n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) - n_components = random.randint(2, min(n_samples, n_snps, 10)) + n_snps = rng.integers(4, n_snps_available) + n_components = rng.integers(2, min(n_samples, n_snps, 10)) # Run the PCA. pca_df, pca_evr = api.pca( diff --git a/tests/anoph/test_plink_converter.py b/tests/anoph/test_plink_converter.py index 44e476cd9..eb75e2a3f 100644 --- a/tests/anoph/test_plink_converter.py +++ b/tests/anoph/test_plink_converter.py @@ -8,9 +8,11 @@ import os import bed_reader - +import numpy as np from numpy.testing import assert_array_equal +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -89,7 +91,7 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path): min_minor_ac=1, max_missing_an=1, thin_offset=1, - random_seed=random.randint(1, 2000), + random_seed=rng.integers(1, 2000), ) # Load a ds containing the randomly generated samples and regions to get the number of available snps to subset from. @@ -98,7 +100,7 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path): ) n_snps_available = ds.sizes["variants"] - n_snps = random.randint(1, n_snps_available) + n_snps = rng.integers(1, n_snps_available) # Define plink params. 
plink_params = dict(output_dir=str(tmp_path), n_snps=n_snps, **data_params) diff --git a/tests/anoph/test_snp_data.py b/tests/anoph/test_snp_data.py index 4d3145a5f..0afc9f905 100644 --- a/tests/anoph/test_snp_data.py +++ b/tests/anoph/test_snp_data.py @@ -860,7 +860,7 @@ def test_snp_calls_with_cohort_size_param(fixture, api: AnophelesSnpData): region = fixture.random_region_str() # Test with specific cohort size. - cohort_size = random.randint(1, 10) + cohort_size = rng.integers(1, 10) ds = api.snp_calls( sample_sets=sample_sets, region=region, @@ -1471,7 +1471,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_cohort_size_param( region = fixture.random_region_str() # Test with specific cohort size. - cohort_size = random.randint(1, 10) + cohort_size = rng.integers(1, 10) ds = api.biallelic_snp_calls( sample_sets=sample_sets, region=region, @@ -1525,8 +1525,8 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( site_mask = random.choice((None,) + api.site_mask_ids) # Parametrise conditions. - min_minor_ac = random.randint(1, 3) - max_missing_an = random.randint(5, 10) + min_minor_ac = rng.integers(1, 3) + max_missing_an = rng.integers(5, 10) # Run tests. ds = check_biallelic_snp_calls_and_diplotypes( @@ -1554,7 +1554,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( # This should always be true, although depends on min_minor_ac and max_missing_an, # so the range of values for those parameters needs to be chosen with some care. assert n_snps_available > 2 - n_snps_requested = random.randint(1, n_snps_available // 2) + n_snps_requested = rng.integers(1, n_snps_available // 2) ds_thinned = check_biallelic_snp_calls_and_diplotypes( api=api, sample_sets=sample_sets, @@ -1620,7 +1620,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional( # This should always be true, although depends on min_minor_ac and max_missing_an, # so the range of values for those parameters needs to be chosen with some care. 
assert n_snps_available > 2 - n_snps_requested = random.randint(1, n_snps_available // 2) + n_snps_requested = rng.integers(1, n_snps_available // 2) ds_thinned = check_biallelic_snp_calls_and_diplotypes( api=api, sample_sets=sample_sets, diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py index aa803f031..d8bfb07ab 100644 --- a/tests/anoph/test_snp_frq.py +++ b/tests/anoph/test_snp_frq.py @@ -21,6 +21,9 @@ ) +rng = np.random.default_rng(seed=42) + + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): return AnophelesSnpFrequencyAnalysis( @@ -301,7 +304,7 @@ def test_allele_frequencies_with_str_cohorts( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) # Set up call params. @@ -561,7 +564,7 @@ def test_allele_frequencies_with_dict_cohorts( # Pick test parameters at random. sample_sets = None # all sample sets site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) # Create cohorts by country. 
@@ -615,7 +618,7 @@ def test_allele_frequencies_without_drop_invariant( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -671,7 +674,7 @@ def test_allele_frequencies_without_effects( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -753,7 +756,7 @@ def test_allele_frequencies_with_bad_transcript( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) # Set up call params. @@ -780,7 +783,7 @@ def test_allele_frequencies_with_region( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) # This should work, as long as effects=False - i.e., can get frequencies # for any genome region. 
@@ -838,7 +841,7 @@ def test_allele_frequencies_with_dup_samples( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_set = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -895,7 +898,7 @@ def check_snp_allele_frequencies_advanced( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) if site_mask is None: site_mask = random.choice(api.site_mask_ids + (None,)) @@ -1089,7 +1092,7 @@ def check_aa_allele_frequencies_advanced( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) # Run function under test. 
ds = api.aa_allele_frequencies_advanced( From 3feb26eb6bf52ac208c4539d891712d2ad93c9ab Mon Sep 17 00:00:00 2001 From: Mohamed Laarej Date: Fri, 16 May 2025 16:08:21 +0100 Subject: [PATCH 03/10] Fix random number generation and type casting issues - Replaced all Python random.choice() with rng.choice() for consistency - Replaced random.sample() with rng.choice(..., replace=False) - Added .tolist() to convert NumPy arrays to Python lists where needed - Added str() casting for np.str_ values to ensure Python string compatibility - Fixed 'low >= high' errors in rng.integers() calls by ensuring high > low - Specifically fixed tests/anoph/test_frq.py by changing rng.integers(1, len(cohorts_areas)) to rng.integers(1, len(cohorts_areas)+1) to avoid invalid ranges - Applied int() casting to NumPy integer types where Python int was expected - Fixed site_mask selection to ensure only valid masks are used for each test context Addresses feedback from PR #760 and resolves test failures. --- tests/anoph/conftest.py | 49 ++++--- tests/anoph/test_aim_data.py | 18 +-- tests/anoph/test_cnv_data.py | 42 +++--- tests/anoph/test_cnv_frq.py | 74 +++++------ tests/anoph/test_dipclust.py | 35 ++--- tests/anoph/test_distance.py | 60 ++++++--- tests/anoph/test_frq.py | 9 +- tests/anoph/test_fst.py | 53 ++++---- tests/anoph/test_g123.py | 39 +++--- tests/anoph/test_genome_features.py | 2 +- tests/anoph/test_h12.py | 35 +++-- tests/anoph/test_h1x.py | 13 +- tests/anoph/test_hap_data.py | 40 +++--- tests/anoph/test_hap_frq.py | 10 +- tests/anoph/test_hapclust.py | 12 +- tests/anoph/test_igv.py | 7 +- tests/anoph/test_pca.py | 58 +++++--- tests/anoph/test_plink_converter.py | 11 +- tests/anoph/test_sample_metadata.py | 40 +++--- tests/anoph/test_snp_data.py | 199 ++++++++++++++++++---------- tests/anoph/test_snp_frq.py | 165 +++++++++++++++-------- 21 files changed, 561 insertions(+), 410 deletions(-) diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index a6c357f8f..a79f1ded4 
100644 --- a/tests/anoph/conftest.py +++ b/tests/anoph/conftest.py @@ -2,7 +2,6 @@ import shutil import string from pathlib import Path -from random import choice, choices, randint from typing import Any, Dict, Tuple import numpy as np @@ -40,7 +39,7 @@ def fixture_dir(): def simulate_contig(*, low, high, base_composition): - size = rng.integers(low=low, high=high) + size = int(rng.integers(low=low, high=high)) bases = np.array([b"a", b"c", b"g", b"t", b"n", b"A", b"C", b"G", b"T", b"N"]) p = np.array([base_composition[b] for b in bases]) seq = rng.choice(bases, size=size, replace=True, p=p) @@ -151,9 +150,9 @@ def simulate_genes(self, *, contig, contig_size): # Simulate genes. for gene_ix in range(self.max_genes): gene_id = f"gene-{contig}-{gene_ix}" - strand = choice(["+", "-"]) - inter_size = randint(self.inter_size_low, self.inter_size_high) - gene_size = randint(self.gene_size_low, self.gene_size_high) + strand = rng.choice(["+", "-"]) + inter_size = int(rng.integers(self.inter_size_low, self.inter_size_high)) + gene_size = int(rng.integers(self.gene_size_low, self.gene_size_high)) if strand == "+": gene_start = cur_fwd + inter_size else: @@ -166,7 +165,11 @@ def simulate_genes(self, *, contig, contig_size): gene_attrs = f"ID={gene_id}" for attr in self.attrs: random_str = "".join( - choices(string.ascii_uppercase + string.digits, k=5) + rng.choice( + list(string.ascii_uppercase + string.digits), + size=5, + replace=True, + ) ) gene_attrs += f";{attr}={random_str}" gene = ( @@ -212,7 +215,7 @@ def simulate_transcripts( # accurate in real data. 
for transcript_ix in range( - randint(self.n_transcripts_low, self.n_transcripts_high) + int(rng.integers(self.n_transcripts_low, self.n_transcripts_high)) ): transcript_id = f"transcript-{contig}-{gene_ix}-{transcript_ix}" transcript_start = gene_start @@ -260,13 +263,16 @@ def simulate_exons( transcript_size = transcript_end - transcript_start exons = [] exon_end = transcript_start - n_exons = randint(self.n_exons_low, self.n_exons_high) + n_exons = int(rng.integers(self.n_exons_low, self.n_exons_high)) for exon_ix in range(n_exons): exon_id = f"exon-{contig}-{gene_ix}-{transcript_ix}-{exon_ix}" if exon_ix > 0: # Insert an intron between this exon and the previous one. - intron_size = randint( - self.intron_size_low, min(transcript_size, self.intron_size_high) + intron_size = int( + rng.integers( + self.intron_size_low, + min(transcript_size, self.intron_size_high), + ) ) exon_start = exon_end + intron_size if exon_start >= transcript_end: @@ -275,7 +281,7 @@ def simulate_exons( else: # First exon, assume exon starts where the transcript starts. exon_start = transcript_start - exon_size = randint(self.exon_size_low, self.exon_size_high) + exon_size = int(rng.integers(self.exon_size_low, self.exon_size_high)) exon_end = min(exon_start + exon_size, transcript_end) assert exon_end > exon_start exon = ( @@ -311,7 +317,7 @@ def simulate_exons( else: feature_type = self.cds_type # Cheat a little, random phase. - phase = choice([1, 2, 3]) + phase = rng.choice([1, 2, 3]) feature = ( contig, self.source, @@ -549,7 +555,7 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high): # Simulate AIM positions variable. 
snp_pos = snp_sites[f"{contig}/variants/POS"][:] loc_aim_sites = rng.choice( - snp_pos.shape[0], size=rng.integers(n_sites_low, n_sites_high) + snp_pos.shape[0], size=int(rng.integers(n_sites_low, n_sites_high)) ) loc_aim_sites.sort() aim_pos = snp_pos[loc_aim_sites] @@ -731,11 +737,10 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) contig_length_bp = contig_sizes[contig] # Get a random number of CNV alleles ("variants") to simulate. - n_cnv_alleles = rng.integers(1, 5_000) + n_cnv_alleles = int(rng.integers(1, 5_000)) # Produce a set of random start positions for each allele as a sorted list. allele_start_pos = sorted(rng.integers(1, contig_length_bp, size=n_cnv_alleles)) - # Produce a set of random allele lengths for each allele, according to a range. allele_length_bp_min = 100 allele_length_bp_max = 100_000 @@ -874,7 +879,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig contig_length_bp = contig_sizes[contig] # Get a random number of CNV variants to simulate. - n_cnv_variants = rng.integers(1, 100) + n_cnv_variants = int(rng.integers(1, 100)) # Produce a set of random start positions for each variant as a sorted list. 
variant_start_pos = sorted( @@ -1010,20 +1015,20 @@ def contigs(self) -> Tuple[str, ...]: return tuple(self.config["CONTIGS"]) def random_contig(self): - return choice(self.contigs) + return rng.choice(self.contigs) def random_transcript_id(self): df_transcripts = self.genome_features.query("type == 'mRNA'") transcript_ids = [ gff3_parse_attributes(t)["ID"] for t in df_transcripts.loc[:, "attributes"] ] - transcript_id = choice(transcript_ids) + transcript_id = rng.choice(transcript_ids) return transcript_id def random_region_str(self, region_size=None): contig = self.random_contig() contig_size = self.contig_sizes[contig] - region_start = randint(1, contig_size) + region_start = int(rng.integers(1, contig_size)) if region_size: # Ensure we the region span doesn't exceed the contig size. if contig_size - region_start < region_size: @@ -1031,7 +1036,7 @@ def random_region_str(self, region_size=None): region_end = region_start + region_size else: - region_end = randint(region_start, contig_size) + region_end = int(rng.integers(region_start, contig_size)) region = f"{contig}:{region_start:,}-{region_end:,}" return region @@ -1133,7 +1138,7 @@ def init_public_release_manifest(self): manifest = pd.DataFrame( { "sample_set": ["AG1000G-AO", "AG1000G-BF-A"], - "sample_count": [randint(10, 50), randint(10, 40)], + "sample_count": [int(rng.integers(10, 50)), int(rng.integers(10, 40))], "study_id": ["AG1000G-AO", "AG1000G-BF-1"], "study_url": [ "https://www.malariagen.net/network/where-we-work/AG1000G-AO", @@ -1165,7 +1170,7 @@ def init_pre_release_manifest(self): "1177-VO-ML-LEHMANN-VMF00004", ], # Make sure we have some gambiae, coluzzii and arabiensis. 
- "sample_count": [randint(20, 60)], + "sample_count": [int(rng.integers(20, 60))], "study_id": ["1177-VO-ML-LEHMANN"], "study_url": [ "https://www.malariagen.net/network/where-we-work/1177-VO-ML-LEHMANN" diff --git a/tests/anoph/test_aim_data.py b/tests/anoph/test_aim_data.py index 8a1c76b34..4c4e37698 100644 --- a/tests/anoph/test_aim_data.py +++ b/tests/anoph/test_aim_data.py @@ -1,14 +1,14 @@ import itertools -import random - import plotly.graph_objects as go import pytest import xarray as xr from numpy.testing import assert_array_equal - +import numpy as np from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.aim_data import AnophelesAimData +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -88,9 +88,9 @@ def test_aim_calls(aims, ag3_sim_api): all_releases = api.releases parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Parametrize sample_query. @@ -179,9 +179,9 @@ def test_plot_aim_heatmap(aims, ag3_sim_api): all_releases = api.releases parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Parametrize sample_query. 
diff --git a/tests/anoph/test_cnv_data.py b/tests/anoph/test_cnv_data.py index 383c4248c..15bb229b9 100644 --- a/tests/anoph/test_cnv_data.py +++ b/tests/anoph/test_cnv_data.py @@ -1,5 +1,3 @@ -import random - import bokeh.models import dask.array as da import numpy as np @@ -139,14 +137,14 @@ def test_open_cnv_coverage_calls(fixture, api: AnophelesCnvData): # Check with a sample set that should not exist with pytest.raises(ValueError): root = api.open_cnv_coverage_calls( - sample_set="foobar", analysis=random.choice(api.coverage_calls_analysis_ids) + sample_set="foobar", analysis=rng.choice(api.coverage_calls_analysis_ids) ) # Check with an analysis that should not exist all_sample_sets = api.sample_sets()["sample_set"].to_list() with pytest.raises(ValueError): root = api.open_cnv_coverage_calls( - sample_set=random.choice(all_sample_sets), analysis="foobar" + sample_set=rng.choice(all_sample_sets), analysis="foobar" ) # Check with a sample set and analysis that should not exist @@ -346,15 +344,15 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData): all_sample_sets = api.sample_sets()["sample_set"].to_list() parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Parametrize region. parametrize_region = [ fixture.random_contig(), - random.sample(api.contigs, 2), + rng.choice(api.contigs, 2, replace=False).tolist(), fixture.random_region_str(), ] @@ -424,7 +422,7 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData): def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData): # Set up test. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) region = fixture.random_contig() # Parametrize max_coverage_variance. 
@@ -468,7 +466,7 @@ def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData): def test_cnv_coverage_calls(fixture, api: AnophelesCnvData): # Parametrize sample_sets. all_sample_sets = api.sample_sets()["sample_set"].to_list() - parametrize_sample_sets = random.sample(all_sample_sets, 3) + parametrize_sample_sets = rng.choice(all_sample_sets, 3, replace=False).tolist() # Parametrize analysis. parametrize_analysis = api.coverage_calls_analysis_ids @@ -476,7 +474,7 @@ def test_cnv_coverage_calls(fixture, api: AnophelesCnvData): # Parametrize region. parametrize_region = [ fixture.random_contig(), - random.sample(api.contigs, 2), + rng.choice(api.contigs, 2, replace=False).tolist(), fixture.random_region_str(), ] @@ -554,15 +552,15 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData): all_sample_sets = api.sample_sets()["sample_set"].to_list() parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Parametrize contig. 
parametrize_contig = [ - random.choice(api.contigs), - random.sample(api.contigs, 2), + rng.choice(api.contigs), + rng.choice(api.contigs, 2, replace=False).tolist(), ] for sample_sets in parametrize_sample_sets: @@ -631,13 +629,13 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData): # Check with a contig that should not exist with pytest.raises(ValueError): api.cnv_discordant_read_calls( - contig="foobar", sample_sets=random.choice(all_sample_sets) + contig="foobar", sample_sets=rng.choice(all_sample_sets) ) # Check with a sample set that should not exist with pytest.raises(ValueError): api.cnv_discordant_read_calls( - contig=random.choice(api.contigs), sample_sets="foobar" + contig=rng.choice(api.contigs), sample_sets="foobar" ) # Check with a contig and sample set that should not exist @@ -809,7 +807,7 @@ def test_cnv_discordant_read_calls__sample_query_options( def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData): # Set up test. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) region = fixture.random_contig() df_samples = api.sample_metadata(sample_sets=sample_set) all_sample_ids = df_samples["sample_id"].values @@ -916,9 +914,9 @@ def test_plot_cnv_hmm_heatmap_track(fixture, api: AnophelesCnvData): all_sample_sets = api.sample_sets()["sample_set"].to_list() parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] for region in parametrize_region: diff --git a/tests/anoph/test_cnv_frq.py b/tests/anoph/test_cnv_frq.py index cd7a851cf..3ea50e495 100644 --- a/tests/anoph/test_cnv_frq.py +++ b/tests/anoph/test_cnv_frq.py @@ -1,5 +1,3 @@ -import random - import numpy as np import pandas as pd import xarray as xr @@ -96,10 +94,10 @@ def 
test_gene_cnv_frequencies_with_str_cohorts( api: AnophelesCnvFrequencyAnalysis, cohorts, ): - region = random.choice(api.contigs) + region = rng.choice(api.contigs) all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - min_cohort_size = rng.integers(0, 2) + sample_sets = rng.choice(all_sample_sets) + min_cohort_size = int(rng.integers(0, 2)) # Set up call params. params = dict( @@ -149,8 +147,8 @@ def test_gene_cnv_frequencies_with_min_cohort_size( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - region = random.choice(api.contigs) + sample_sets = rng.choice(all_sample_sets) + region = rng.choice(api.contigs) cohorts = "admin1_year" # Set up call params. @@ -200,13 +198,11 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query( # Pick test parameters at random. sample_sets = None min_cohort_size = 0 - region = random.choice(api.contigs) - cohorts = random.choice( - ["admin1_year", "admin1_month", "admin2_year", "admin2_month"] - ) + region = rng.choice(api.contigs) + cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"]) df_samples = api.sample_metadata(sample_sets=sample_sets) countries = df_samples["country"].unique() - country = random.choice(countries) + country = rng.choice(countries) sample_query = f"country == '{country}'" # Figure out expected cohort labels. @@ -248,13 +244,11 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query_options( # Pick test parameters at random. 
sample_sets = None min_cohort_size = 0 - region = random.choice(api.contigs) - cohorts = random.choice( - ["admin1_year", "admin1_month", "admin2_year", "admin2_month"] - ) + region = rng.choice(api.contigs) + cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"]) df_samples = api.sample_metadata(sample_sets=sample_sets) countries = df_samples["country"].unique().tolist() - countries_list = random.sample(countries, 2) + countries_list = rng.choice(countries, 2, replace=False).tolist() sample_query_options = { "local_dict": { "countries_list": countries_list, @@ -304,8 +298,8 @@ def test_gene_cnv_frequencies_with_dict_cohorts( ): # Pick test parameters at random. sample_sets = None # all sample sets - min_cohort_size = rng.integers(0, 2) - region = random.choice(api.contigs) + min_cohort_size = int(rng.integers(0, 2)) + region = rng.choice(api.contigs) # Create cohorts by country. df_samples = api.sample_metadata(sample_sets=sample_sets) @@ -344,10 +338,10 @@ def test_gene_cnv_frequencies_without_drop_invariant( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - min_cohort_size = rng.integers(0, 2) - region = random.choice(api.contigs) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + sample_sets = rng.choice(all_sample_sets) + min_cohort_size = int(rng.integers(0, 2)) + region = rng.choice(api.contigs) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) # Figure out expected cohort labels. df_samples = api.sample_metadata(sample_sets=sample_sets) @@ -399,9 +393,9 @@ def test_gene_cnv_frequencies_with_bad_region( ): # Pick test parameters at random. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - min_cohort_size = rng.integers(0, 2) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + sample_sets = rng.choice(all_sample_sets) + min_cohort_size = int(rng.integers(0, 2)) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) # Set up call params. params = dict( @@ -425,9 +419,9 @@ def test_gene_cnv_frequencies_with_max_coverage_variance( max_coverage_variance, ): all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) - region = random.choice(api.contigs) + sample_sets = rng.choice(all_sample_sets) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) + region = rng.choice(api.contigs) params = dict( region=region, @@ -504,7 +498,7 @@ def test_gene_cnv_frequencies_advanced_with_sample_query( all_sample_sets = api.sample_sets()["sample_set"].to_list() df_samples = api.sample_metadata(sample_sets=all_sample_sets) countries = df_samples["country"].unique() - country = random.choice(countries) + country = rng.choice(countries) sample_query = f"country == '{country}'" check_gene_cnv_frequencies_advanced( @@ -523,7 +517,7 @@ def test_gene_cnv_frequencies_advanced_with_sample_query_options( all_sample_sets = api.sample_sets()["sample_set"].to_list() df_samples = api.sample_metadata(sample_sets=all_sample_sets) countries = df_samples["country"].unique().tolist() - countries_list = random.sample(countries, 2) + countries_list = rng.choice(countries, 2, replace=False).tolist() sample_query_options = { "local_dict": { "countries_list": countries_list, @@ -550,7 +544,7 @@ def test_gene_cnv_frequencies_advanced_with_min_cohort_size( all_sample_sets = api.sample_sets()["sample_set"].to_list() area_by = "admin1_iso" period_by = "year" - region = random.choice(api.contigs) + region = 
rng.choice(api.contigs) if min_cohort_size <= 10: # Expect this to find at least one cohort, so go ahead with full @@ -586,7 +580,7 @@ def test_gene_cnv_frequencies_advanced_with_max_coverage_variance( all_sample_sets = api.sample_sets()["sample_set"].to_list() area_by = "admin1_iso" period_by = "year" - region = random.choice(api.contigs) + region = rng.choice(api.contigs) if max_coverage_variance >= 0.4: # Expect this to find at least one cohort, so go ahead with full @@ -621,7 +615,7 @@ def test_gene_cnv_frequencies_advanced_with_nobs_mode( all_sample_sets = api.sample_sets()["sample_set"].to_list() area_by = "admin1_iso" period_by = "year" - region = random.choice(api.contigs) + region = rng.choice(api.contigs) check_gene_cnv_frequencies_advanced( api=api, @@ -643,7 +637,7 @@ def test_gene_cnv_frequencies_advanced_with_variant_query( all_sample_sets = api.sample_sets()["sample_set"].to_list() area_by = "admin1_iso" period_by = "year" - region = random.choice(api.contigs) + region = rng.choice(api.contigs) variant_query = "cnv_type == '{variant_query_option}'" check_gene_cnv_frequencies_advanced( @@ -711,16 +705,16 @@ def check_gene_cnv_frequencies_advanced( ): # Pick test parameters at random. if region is None: - region = random.choice(api.contigs) + region = rng.choice(api.contigs) if area_by is None: - area_by = random.choice(["country", "admin1_iso", "admin2_name"]) + area_by = rng.choice(["country", "admin1_iso", "admin2_name"]) if period_by is None: - period_by = random.choice(["year", "quarter", "month"]) + period_by = rng.choice(["year", "quarter", "month"]) if sample_sets is None: all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = rng.integers(0, 2) + min_cohort_size = int(rng.integers(0, 2)) # Run function under test. 
ds = api.gene_cnv_frequencies_advanced( diff --git a/tests/anoph/test_dipclust.py b/tests/anoph/test_dipclust.py index c0bbad033..23b04a411 100644 --- a/tests/anoph/test_dipclust.py +++ b/tests/anoph/test_dipclust.py @@ -1,17 +1,18 @@ -import random import pytest from pytest_cases import parametrize_with_cases - +import numpy as np from malariagen_data import af1 as _af1 from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.dipclust import AnophelesDipClustAnalysis +rng = np.random.default_rng(seed=42) + def random_transcripts_contig(*, api, contig, n): df_gff = api.genome_features(attributes=["ID", "Parent"]) df_transcripts = df_gff.query(f"type == 'mRNA' and contig == '{contig}'") transcript_ids = df_transcripts["ID"].dropna().to_list() - transcripts = random.sample(transcript_ids, n) + transcripts = rng.choice(transcript_ids, n, replace=False).tolist() return transcripts @@ -97,12 +98,13 @@ def test_plot_diplotype_clustering( "ward", ) sample_queries = (None, "sex_call == 'F'") + idx = rng.choice([0, 1]) dipclust_params = dict( region=fixture.random_region_str(region_size=5000), - sample_sets=[random.choice(all_sample_sets)], - linkage_method=random.choice(linkage_methods), + sample_sets=[rng.choice(all_sample_sets)], + linkage_method=str(rng.choice(linkage_methods)), distance_metric=distance_metric, - sample_query=random.choice(sample_queries), + sample_query=sample_queries[idx], show=False, ) @@ -127,12 +129,13 @@ def test_plot_diplotype_clustering_advanced( "ward", ) sample_queries = (None, "sex_call == 'F'") + idx = rng.choice([0, 1]) dipclust_params = dict( region=fixture.random_region_str(region_size=5000), - sample_sets=[random.choice(all_sample_sets)], - linkage_method=random.choice(linkage_methods), + sample_sets=[rng.choice(all_sample_sets)], + linkage_method=str(rng.choice(linkage_methods)), distance_metric=distance_metric, - sample_query=random.choice(sample_queries), + sample_query=sample_queries[idx], show=False, ) @@ -159,13 +162,14 
@@ def test_plot_diplotype_clustering_advanced_with_transcript( "ward", ) sample_queries = (None, "sex_call == 'F'") + idx = rng.choice([0, 1]) dipclust_params = dict( region=contig, snp_transcript=transcripts, - sample_sets=[random.choice(all_sample_sets)], - linkage_method=random.choice(linkage_methods), + sample_sets=[rng.choice(all_sample_sets)], + linkage_method=str(rng.choice(linkage_methods)), distance_metric="cityblock", - sample_query=random.choice(sample_queries), + sample_query=sample_queries[idx], show=False, ) @@ -190,13 +194,14 @@ def test_plot_diplotype_clustering_advanced_with_cnv_region( "ward", ) sample_queries = (None, "sex_call == 'F'") + idx = rng.choice([0, 1]) dipclust_params = dict( region=region, cnv_region=region, - sample_sets=[random.choice(all_sample_sets)], - linkage_method=random.choice(linkage_methods), + sample_sets=[rng.choice(all_sample_sets)], + linkage_method=str(rng.choice(linkage_methods)), distance_metric="cityblock", - sample_query=random.choice(sample_queries), + sample_query=sample_queries[idx], show=False, ) diff --git a/tests/anoph/test_distance.py b/tests/anoph/test_distance.py index c8695855f..2694b492a 100644 --- a/tests/anoph/test_distance.py +++ b/tests/anoph/test_distance.py @@ -1,5 +1,3 @@ -import random - import numpy as np import plotly.graph_objects as go # type: ignore import pytest @@ -9,6 +7,7 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.distance import AnophelesDistanceAnalysis from malariagen_data.anoph import pca_params +from .conftest import Af1Simulator, Ag3Simulator # Import the simulator classes rng = np.random.default_rng(seed=42) @@ -84,7 +83,7 @@ def check_biallelic_diplotype_pairwise_distance(*, api, data_params, metric): ds = api.biallelic_snp_calls(**data_params) n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = rng.integers(4, n_snps_available) + n_snps = int(rng.integers(4, n_snps_available)) # Run the distance computation. 
dist, samples, n_snps_used = api.biallelic_diplotype_pairwise_distances( @@ -126,9 +125,9 @@ def test_biallelic_diplotype_pairwise_distance_with_metric( ): all_sample_sets = api.sample_sets()["sample_set"].to_list() data_params = dict( - region=random.choice(api.contigs), - sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=rng.choice(np.array([""] + list(api.site_mask_ids), dtype=object)), min_minor_ac=pca_params.min_minor_ac_default, max_missing_an=pca_params.max_missing_an_default, ) @@ -146,7 +145,7 @@ def check_njt(*, api, data_params, metric, algorithm): ds = api.biallelic_snp_calls(**data_params) n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = rng.integers(4, n_snps_available) + n_snps = int(rng.integers(4, n_snps_available)) # Run the distance computation. Z, samples, n_snps_used = api.njt( @@ -174,15 +173,21 @@ def check_njt(*, api, data_params, metric, algorithm): @parametrize_with_cases("fixture,api", cases=".") def test_njt_with_metric(fixture, api: AnophelesDistanceAnalysis): all_sample_sets = api.sample_sets()["sample_set"].to_list() + if isinstance(fixture, Af1Simulator): + expected_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + expected_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + expected_site_masks = [""] + list(api.site_mask_ids) data_params = dict( - region=random.choice(api.contigs), - sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=str(rng.choice(np.array(expected_site_masks, dtype=object))), min_minor_ac=pca_params.min_minor_ac_default, max_missing_an=pca_params.max_missing_an_default, ) parametrize_metric = "cityblock", "euclidean", 
"sqeuclidean" - algorithm = random.choice(["dynamic", "rapid", "canonical"]) + algorithm = str(rng.choice(["dynamic", "rapid", "canonical"])) for metric in parametrize_metric: check_njt( api=api, @@ -195,14 +200,20 @@ def test_njt_with_metric(fixture, api: AnophelesDistanceAnalysis): @parametrize_with_cases("fixture,api", cases=".") def test_njt_with_algorithm(fixture, api: AnophelesDistanceAnalysis): all_sample_sets = api.sample_sets()["sample_set"].to_list() + if isinstance(fixture, Af1Simulator): + expected_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + expected_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + expected_site_masks = [""] + list(api.site_mask_ids) data_params = dict( - region=random.choice(api.contigs), - sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=str(rng.choice(np.array(expected_site_masks, dtype=object))), min_minor_ac=pca_params.min_minor_ac_default, max_missing_an=pca_params.max_missing_an_default, ) - metric = random.choice(["cityblock", "euclidean", "sqeuclidean"]) + metric = str(rng.choice(["cityblock", "euclidean", "sqeuclidean"])) parametrize_algorithm = "dynamic", "rapid", "canonical" for algorithm in parametrize_algorithm: check_njt( @@ -216,15 +227,21 @@ def test_njt_with_algorithm(fixture, api: AnophelesDistanceAnalysis): @parametrize_with_cases("fixture,api", cases=".") def test_plot_njt(fixture, api: AnophelesDistanceAnalysis): all_sample_sets = api.sample_sets()["sample_set"].to_list() + if isinstance(fixture, Af1Simulator): + expected_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + expected_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + expected_site_masks = [""] + list(api.site_mask_ids) data_params = dict( - region=random.choice(api.contigs), - 
sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=str(rng.choice(np.array(expected_site_masks, dtype=object))), min_minor_ac=pca_params.min_minor_ac_default, max_missing_an=pca_params.max_missing_an_default, ) - metric = random.choice(["cityblock", "euclidean", "sqeuclidean"]) - algorithm = random.choice(["dynamic", "rapid", "canonical"]) + metric = str(rng.choice(["cityblock", "euclidean", "sqeuclidean"])) + algorithm = str(rng.choice(["dynamic", "rapid", "canonical"])) custom_cohorts = { "male": "sex_call == 'M'", "female": "sex_call == 'F'", @@ -235,8 +252,7 @@ def test_plot_njt(fixture, api: AnophelesDistanceAnalysis): # Check available data. ds = api.biallelic_snp_calls(**data_params) n_snps_available = ds.sizes["variants"] - n_snps = rng.integers(4, n_snps_available) - + n_snps = int(rng.integers(4, n_snps_available)) # Exercise the function. for color, symbol in zip(colors, symbols): fig = api.plot_njt( diff --git a/tests/anoph/test_frq.py b/tests/anoph/test_frq.py index 1f390e334..1b2e6647b 100644 --- a/tests/anoph/test_frq.py +++ b/tests/anoph/test_frq.py @@ -1,7 +1,6 @@ import pytest import plotly.graph_objects as go # type: ignore import numpy as np -import random rng = np.random.default_rng(seed=42) @@ -41,7 +40,7 @@ def check_plot_frequencies_time_series_with_taxa(api, ds): ds = ds.isel(variants=slice(0, 100)) taxa = list(ds.cohort_taxon.to_dataframe()["cohort_taxon"].unique()) - taxon = random.choice(taxa) + taxon = rng.choice(taxa) # Plot with taxon. fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon) @@ -66,8 +65,10 @@ def check_plot_frequencies_time_series_with_areas(api, ds): # Pick a random area and areas from valid areas. 
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist() - area = random.choice(cohorts_areas) - areas = random.sample(cohorts_areas, rng.integers(1, len(cohorts_areas))) + area = rng.choice(cohorts_areas) + areas = rng.choice( + cohorts_areas, int(rng.integers(1, len(cohorts_areas) + 1)), replace=False + ).tolist() # Plot with area. fig = api.plot_frequencies_time_series(ds, show=False, areas=area) diff --git a/tests/anoph/test_fst.py b/tests/anoph/test_fst.py index 520f4f4b5..6ba6908fe 100644 --- a/tests/anoph/test_fst.py +++ b/tests/anoph/test_fst.py @@ -1,5 +1,4 @@ import itertools -import random import pytest from pytest_cases import parametrize_with_cases import numpy as np @@ -84,16 +83,16 @@ def test_fst_gwss(fixture, api: AnophelesFstAnalysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].dropna().unique().tolist() - countries = random.sample(all_countries, 2) + countries = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == {countries[0]!r}" cohort2_query = f"country == {countries[1]!r}" fst_params = dict( - contig=random.choice(api.contigs), + contig=rng.choice(api.contigs), sample_sets=all_sample_sets, cohort1_query=cohort1_query, cohort2_query=cohort2_query, - site_mask=random.choice(api.site_mask_ids), - window_size=rng.integers(10, 50), + site_mask=rng.choice(api.site_mask_ids), + window_size=int(rng.integers(10, 50)), min_cohort_size=1, ) @@ -123,17 +122,17 @@ def test_average_fst(fixture, api: AnophelesFstAnalysis): # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].dropna().unique().tolist() - countries = random.sample(all_countries, 2) + countries = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == {countries[0]!r}" cohort2_query = f"country == {countries[1]!r}" fst_params = dict( - region=random.choice(api.contigs), + region=rng.choice(api.contigs), sample_sets=all_sample_sets, cohort1_query=cohort1_query, cohort2_query=cohort2_query, - site_mask=random.choice(api.site_mask_ids), + site_mask=rng.choice(api.site_mask_ids), min_cohort_size=1, - n_jack=rng.integers(10, 200), + n_jack=int(rng.integers(10, 200)), ) # Run main gwss function under test. @@ -151,15 +150,15 @@ def test_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAnalysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].dropna().unique().tolist() - countries = random.sample(all_countries, 2) + countries = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == {countries[0]!r}" cohort2_query = f"country == {countries[1]!r}" fst_params = dict( - region=random.choice(api.contigs), + region=rng.choice(api.contigs), sample_sets=all_sample_sets, cohort1_query=cohort1_query, cohort2_query=cohort2_query, - site_mask=random.choice(api.site_mask_ids), + site_mask=rng.choice(api.site_mask_ids), min_cohort_size=1000, ) @@ -223,15 +222,15 @@ def test_pairwise_average_fst_with_str_cohorts( ): # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - region = random.choice(api.contigs) - site_mask = random.choice(api.site_mask_ids) + region = rng.choice(api.contigs) + site_mask = rng.choice(api.site_mask_ids) fst_params = dict( region=region, cohorts=cohorts, sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=1, - n_jack=rng.integers(10, 200), + n_jack=int(rng.integers(10, 200)), ) # Run checks. @@ -242,8 +241,8 @@ def test_pairwise_average_fst_with_str_cohorts( def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAnalysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - region = random.choice(api.contigs) - site_mask = random.choice(api.site_mask_ids) + region = rng.choice(api.contigs) + site_mask = rng.choice(api.site_mask_ids) cohorts = "admin1_year" fst_params = dict( region=region, @@ -251,7 +250,7 @@ def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAna sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=15, - n_jack=rng.integers(10, 200), + n_jack=int(rng.integers(10, 200)), ) # Run checks. @@ -264,15 +263,15 @@ def test_pairwise_average_fst_with_dict_cohorts(fixture, api: AnophelesFstAnalys all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].dropna().unique().tolist() cohorts = {country: f"country == '{country}'" for country in all_countries} - region = random.choice(api.contigs) - site_mask = random.choice(api.site_mask_ids) + region = rng.choice(api.contigs) + site_mask = rng.choice(api.site_mask_ids) fst_params = dict( region=region, cohorts=cohorts, sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=1, - n_jack=rng.integers(10, 200), + n_jack=int(rng.integers(10, 200)), ) # Run checks. 
@@ -283,12 +282,12 @@ def test_pairwise_average_fst_with_dict_cohorts(fixture, api: AnophelesFstAnalys def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalysis): # Set up test parameters. all_taxa = api.sample_metadata()["taxon"].dropna().unique().tolist() - taxon = random.choice(all_taxa) + taxon = rng.choice(all_taxa) sample_query = f"taxon == '{taxon}'" all_sample_sets = api.sample_sets()["sample_set"].to_list() cohorts = "admin2_month" - region = random.choice(api.contigs) - site_mask = random.choice(api.site_mask_ids) + region = rng.choice(api.contigs) + site_mask = rng.choice(api.site_mask_ids) fst_params = dict( region=region, cohorts=cohorts, @@ -296,7 +295,7 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalys sample_query=sample_query, site_mask=site_mask, min_cohort_size=1, - n_jack=rng.integers(10, 200), + n_jack=int(rng.integers(10, 200)), ) # Run checks. @@ -308,8 +307,8 @@ def test_pairwise_average_fst_with_bad_cohorts(fixture, api: AnophelesFstAnalysi # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() cohorts = "foobar" - region = random.choice(api.contigs) - site_mask = random.choice(api.site_mask_ids) + region = rng.choice(api.contigs) + site_mask = rng.choice(api.site_mask_ids) fst_params = dict( region=region, cohorts=cohorts, diff --git a/tests/anoph/test_g123.py b/tests/anoph/test_g123.py index f65c1a3b7..a2a68e478 100644 --- a/tests/anoph/test_g123.py +++ b/tests/anoph/test_g123.py @@ -1,4 +1,3 @@ -import random import pytest from pytest_cases import parametrize_with_cases import numpy as np @@ -106,9 +105,9 @@ def test_g123_gwss_with_default_sites(fixture, api: AnophelesG123Analysis): # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() g123_params = dict( - contig=random.choice(api.contigs), - sample_sets=[random.choice(all_sample_sets)], - window_size=rng.integers(100, 500), + contig=rng.choice(api.contigs), + sample_sets=[rng.choice(all_sample_sets)], + window_size=int(rng.integers(100, 500)), min_cohort_size=10, ) @@ -121,10 +120,10 @@ def test_g123_gwss_with_phased_sites(fixture, api: AnophelesG123Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() g123_params = dict( - contig=random.choice(api.contigs), - sites=random.choice(api.phasing_analysis_ids), - sample_sets=[random.choice(all_sample_sets)], - window_size=rng.integers(100, 500), + contig=rng.choice(api.contigs), + sites=rng.choice(api.phasing_analysis_ids), + sample_sets=[rng.choice(all_sample_sets)], + window_size=int(rng.integers(100, 500)), min_cohort_size=10, ) @@ -137,11 +136,11 @@ def test_g123_gwss_with_segregating_sites(fixture, api: AnophelesG123Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() g123_params = dict( - contig=random.choice(api.contigs), + contig=rng.choice(api.contigs), sites="segregating", - site_mask=random.choice(api.site_mask_ids), - sample_sets=[random.choice(all_sample_sets)], - window_size=rng.integers(100, 500), + site_mask=rng.choice(api.site_mask_ids), + sample_sets=[rng.choice(all_sample_sets)], + window_size=int(rng.integers(100, 500)), min_cohort_size=10, ) @@ -154,11 +153,11 @@ def test_g123_gwss_with_all_sites(fixture, api: AnophelesG123Analysis): # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() g123_params = dict( - contig=random.choice(api.contigs), + contig=rng.choice(api.contigs), sites="all", site_mask=None, - sample_sets=[random.choice(all_sample_sets)], - window_size=rng.integers(100, 500), + sample_sets=[rng.choice(all_sample_sets)], + window_size=int(rng.integers(100, 500)), min_cohort_size=10, ) @@ -171,9 +170,9 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() g123_params = dict( - contig=random.choice(api.contigs), - sample_sets=[random.choice(all_sample_sets)], - window_size=rng.integers(100, 500), + contig=rng.choice(api.contigs), + sample_sets=[rng.choice(all_sample_sets)], + window_size=int(rng.integers(100, 500)), min_cohort_size=10, sites="foobar", ) @@ -187,8 +186,8 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis): def test_g123_calibration(fixture, api: AnophelesG123Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist() - window_sizes = sorted(int(window_sizes)) + window_sizes = rng.integers(100, 500, size=int(rng.integers(2, 5))).tolist() + window_sizes = sorted(window_sizes) g123_params = dict( contig=rng.choice(api.contigs), sites=rng.choice(api.phasing_analysis_ids), diff --git a/tests/anoph/test_genome_features.py b/tests/anoph/test_genome_features.py index d055c0c76..1f12dac40 100644 --- a/tests/anoph/test_genome_features.py +++ b/tests/anoph/test_genome_features.py @@ -148,7 +148,7 @@ def test_plot_genes_with_gene_labels(fixture, api: AnophelesGenomeFeaturesData): # If there are no genes, we cannot label them. if not genes_df.empty: # Get a random number of genes to sample. 
- random_genes_n = rng.integers(low=1, high=len(genes_df) + 1) + random_genes_n = int(rng.integers(low=1, high=len(genes_df) + 1)) # Get a random sample of genes. random_sample_genes_df = genes_df.sample(n=random_genes_n) diff --git a/tests/anoph/test_h12.py b/tests/anoph/test_h12.py index bd1d4233d..6cda7e028 100644 --- a/tests/anoph/test_h12.py +++ b/tests/anoph/test_h12.py @@ -1,4 +1,3 @@ -import random import pytest from pytest_cases import parametrize_with_cases import numpy as np @@ -107,7 +106,7 @@ def test_haplotype_frequencies(): def test_h12_calibration(fixture, api: AnophelesH12Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist() + window_sizes = rng.integers(100, 500, size=int(rng.integers(2, 5))).tolist() # Convert window_sizes to a flattened list of integers window_sizes = sorted(set([int(x) for x in window_sizes])) h12_params = dict( @@ -174,9 +173,9 @@ def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() h12_params = dict( - contig=random.choice(api.contigs), - sample_sets=[random.choice(all_sample_sets)], - window_size=rng.integers(100, 500), + contig=rng.choice(api.contigs), + sample_sets=[rng.choice(all_sample_sets)], + window_size=int(rng.integers(100, 500)), min_cohort_size=5, ) @@ -188,9 +187,9 @@ def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis): def test_h12_gwss_with_analysis(fixture, api: AnophelesH12Analysis): # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = [random.choice(all_sample_sets)] - contig = random.choice(api.contigs) - window_size = rng.integers(100, 500) + sample_sets = [rng.choice(all_sample_sets)] + contig = rng.choice(api.contigs) + window_size = int(rng.integers(100, 500)) for analysis in api.phasing_analysis_ids: # Check if any samples available for the given phasing analysis. @@ -238,13 +237,13 @@ def test_h12_gwss_multi_with_default_analysis(fixture, api: AnophelesH12Analysis # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].unique().tolist() - country1, country2 = random.sample(all_countries, 2) + country1, country2 = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == '{country1}'" cohort2_query = f"country == '{country2}'" h12_params = dict( - contig=random.choice(api.contigs), + contig=rng.choice(api.contigs), sample_sets=all_sample_sets, - window_size=rng.integers(100, 500), + window_size=int(rng.integers(100, 500)), min_cohort_size=1, cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, ) @@ -258,15 +257,15 @@ def test_h12_gwss_multi_with_window_size_dict(fixture, api: AnophelesH12Analysis # Set up test parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].unique().tolist() - country1, country2 = random.sample(all_countries, 2) + country1, country2 = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == '{country1}'" cohort2_query = f"country == '{country2}'" h12_params = dict( - contig=random.choice(api.contigs), + contig=rng.choice(api.contigs), sample_sets=all_sample_sets, window_size={ - "cohort1": rng.integers(100, 500), - "cohort2": rng.integers(100, 500), + "cohort1": int(rng.integers(100, 500)), + "cohort2": int(rng.integers(100, 500)), }, min_cohort_size=1, cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, @@ -281,10 +280,10 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].unique().tolist() - country1, country2 = random.sample(all_countries, 2) + country1, country2 = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == '{country1}'" cohort2_query = f"country == '{country2}'" - contig = random.choice(api.contigs) + contig = rng.choice(api.contigs) for analysis in api.phasing_analysis_ids: # Check if any samples available for the given phasing analysis. 
@@ -317,7 +316,7 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis): analysis=analysis, contig=contig, sample_sets=all_sample_sets, - window_size=rng.integers(100, 500), + window_size=int(rng.integers(100, 500)), min_cohort_size=min(n1, n2), cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, ) diff --git a/tests/anoph/test_h1x.py b/tests/anoph/test_h1x.py index dc66ebc42..5c8528e0e 100644 --- a/tests/anoph/test_h1x.py +++ b/tests/anoph/test_h1x.py @@ -1,4 +1,3 @@ -import random import pytest from pytest_cases import parametrize_with_cases import numpy as np @@ -143,13 +142,13 @@ def test_h1x_gwss_with_default_analysis(fixture, api: AnophelesH1XAnalysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].unique().tolist() - country1, country2 = random.sample(all_countries, 2) + country1, country2 = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == '{country1}'" cohort2_query = f"country == '{country2}'" h1x_params = dict( - contig=random.choice(api.contigs), + contig=rng.choice(api.contigs), sample_sets=all_sample_sets, - window_size=rng.integers(100, 500), + window_size=int(rng.integers(100, 500)), min_cohort_size=1, cohort1_query=cohort1_query, cohort2_query=cohort2_query, @@ -164,10 +163,10 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_countries = api.sample_metadata()["country"].unique().tolist() - country1, country2 = random.sample(all_countries, 2) + country1, country2 = rng.choice(all_countries, 2, replace=False).tolist() cohort1_query = f"country == '{country1}'" cohort2_query = f"country == '{country2}'" - contig = random.choice(api.contigs) + contig = rng.choice(api.contigs) for analysis in api.phasing_analysis_ids: # Check if any samples available for the given phasing analysis. 
@@ -200,7 +199,7 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis): analysis=analysis, contig=contig, sample_sets=all_sample_sets, - window_size=rng.integers(100, 500), + window_size=int(rng.integers(100, 500)), min_cohort_size=min(n1, n2), cohort1_query=cohort1_query, cohort2_query=cohort2_query, diff --git a/tests/anoph/test_hap_data.py b/tests/anoph/test_hap_data.py index 16c411154..f8b9ea7bb 100644 --- a/tests/anoph/test_hap_data.py +++ b/tests/anoph/test_hap_data.py @@ -1,5 +1,3 @@ -import random - import dask.array as da import numpy as np import pytest @@ -326,9 +324,9 @@ def test_haplotypes_with_sample_sets_param(fixture, api: AnophelesHapData): all_releases = api.releases parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Run tests. @@ -346,7 +344,7 @@ def test_haplotypes_with_sample_sets_param(fixture, api: AnophelesHapData): def test_haplotypes_with_region_param(fixture, api: AnophelesHapData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) analysis = api.phasing_analysis_ids[0] # Parametrize region. @@ -356,7 +354,7 @@ def test_haplotypes_with_region_param(fixture, api: AnophelesHapData): contig, fixture.random_region_str(), [fixture.random_region_str(), fixture.random_region_str()], - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Run tests. @@ -374,7 +372,7 @@ def test_haplotypes_with_region_param(fixture, api: AnophelesHapData): def test_haplotypes_with_analysis_param(fixture, api: AnophelesHapData): # Fixed parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Parametrize analysis. @@ -399,7 +397,7 @@ def test_haplotypes_with_sample_query_param( # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() analysis = api.phasing_analysis_ids[0] @@ -426,7 +424,7 @@ def test_haplotypes_with_sample_query_options_param( # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() analysis = api.phasing_analysis_ids[0] sample_query_options = { @@ -465,12 +463,16 @@ def test_haplotypes_with_cohort_size_param( # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() analysis = api.phasing_analysis_ids[0] # Parametrize over cohort_size. - parametrize_cohort_size = [rng.integers(1, 10), rng.integers(10, 50), 1_000] + parametrize_cohort_size = [ + int(rng.integers(1, 10)), + int(rng.integers(10, 50)), + 1_000, + ] for cohort_size in parametrize_cohort_size: check_haplotypes( fixture=fixture, @@ -491,14 +493,14 @@ def test_haplotypes_with_min_cohort_size_param( # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() analysis = api.phasing_analysis_ids[0] # Parametrize over min_cohort_size. 
parametrize_min_cohort_size = [ - rng.integers(1, 10), - rng.integers(10, 50), + int(rng.integers(1, 10)), + int(rng.integers(10, 50)), 1_000, ] for min_cohort_size in parametrize_min_cohort_size: @@ -521,14 +523,14 @@ def test_haplotypes_with_max_cohort_size_param( # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() analysis = api.phasing_analysis_ids[0] # Parametrize over max_cohort_size. parametrize_max_cohort_size = [ - rng.integers(1, 10), - rng.integers(10, 50), + int(rng.integers(1, 10)), + int(rng.integers(10, 50)), 1_000, ] for max_cohort_size in parametrize_max_cohort_size: @@ -661,7 +663,7 @@ def test_haplotype_sites(fixture, api: AnophelesHapData): # Test with genome feature ID. df_gff = api.genome_features(attributes=["ID"]) - region = random.choice(df_gff["ID"].dropna().to_list()) + region = rng.choice(df_gff["ID"].dropna().to_list()) check_haplotype_sites(api=api, region=region) diff --git a/tests/anoph/test_hap_frq.py b/tests/anoph/test_hap_frq.py index 689583010..91ad8ccf2 100644 --- a/tests/anoph/test_hap_frq.py +++ b/tests/anoph/test_hap_frq.py @@ -1,5 +1,3 @@ -import random - import pandas as pd import numpy as np import xarray as xr @@ -169,8 +167,8 @@ def test_hap_frequencies_with_str_cohorts( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - min_cohort_size = rng.integers(0, 2) + sample_sets = rng.choice(all_sample_sets) + min_cohort_size = int(rng.integers(0, 2)) region = fixture.random_region_str() # Set up call params. 
@@ -211,8 +209,8 @@ def test_hap_frequencies_advanced( fixture, api: AnophelesHapFrequencyAnalysis, area_by, period_by ): all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - min_cohort_size = rng.integers(0, 2) + sample_sets = rng.choice(all_sample_sets) + min_cohort_size = int(rng.integers(0, 2)) region = fixture.random_region_str() # Set up call params. diff --git a/tests/anoph/test_hapclust.py b/tests/anoph/test_hapclust.py index 454b6e40c..068c229eb 100644 --- a/tests/anoph/test_hapclust.py +++ b/tests/anoph/test_hapclust.py @@ -1,11 +1,12 @@ -import random import pytest from pytest_cases import parametrize_with_cases - +import numpy as np from malariagen_data import af1 as _af1 from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.hapclust import AnophelesHapClustAnalysis +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -86,11 +87,12 @@ def test_plot_haplotype_clustering(fixture, api: AnophelesHapClustAnalysis): "ward", ) sample_queries = (None, "sex_call == 'F'") + idx = rng.choice([0, 1]) # to generate a random index hapclust_params = dict( region=fixture.random_region_str(region_size=5000), - sample_sets=[random.choice(all_sample_sets)], - linkage_method=random.choice(linkage_methods), - sample_query=random.choice(sample_queries), + sample_sets=[rng.choice(all_sample_sets)], + linkage_method=str(rng.choice(linkage_methods)), + sample_query=sample_queries[idx], show=False, ) diff --git a/tests/anoph/test_igv.py b/tests/anoph/test_igv.py index a468af725..854b4b0fd 100644 --- a/tests/anoph/test_igv.py +++ b/tests/anoph/test_igv.py @@ -1,5 +1,4 @@ -import random - +import numpy as np import igv_notebook # type: ignore import pytest from pytest_cases import parametrize_with_cases @@ -8,6 +7,8 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.igv import AnophelesIgv +rng = np.random.default_rng(seed=42) + @pytest.fixture def 
ag3_sim_api(ag3_sim_fixture): @@ -81,7 +82,7 @@ def test_igv(fixture, api: AnophelesIgv): @parametrize_with_cases("fixture,api", cases=".") def test_view_alignments(fixture, api: AnophelesIgv): region = fixture.random_region_str() - sample = random.choice(api.sample_metadata()["sample_id"]) + sample = rng.choice(api.sample_metadata()["sample_id"]) ret = api.view_alignments(region=region, sample=sample, init=False) # No return value to avoid cluttering notebook output. assert ret is None diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py index 9ce044e45..54dafd671 100644 --- a/tests/anoph/test_pca.py +++ b/tests/anoph/test_pca.py @@ -1,5 +1,3 @@ -import random - import numpy as np import pandas as pd import plotly.graph_objects as go # type: ignore @@ -10,6 +8,7 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.pca import AnophelesPca from malariagen_data.anoph import pca_params +from .conftest import Af1Simulator, Ag3Simulator rng = np.random.default_rng(seed=42) @@ -84,9 +83,9 @@ def test_pca_plotting(fixture, api: AnophelesPca): # Parameters for selecting input data. all_sample_sets = api.sample_sets()["sample_set"].to_list() data_params = dict( - region=random.choice(api.contigs), - sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=rng.choice(np.array([""] + list(api.site_mask_ids), dtype=object)), ) ds = api.biallelic_snp_calls( min_minor_ac=pca_params.min_minor_ac_default, @@ -97,10 +96,10 @@ def test_pca_plotting(fixture, api: AnophelesPca): # PCA parameters. 
n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = rng.integers(4, n_snps_available) + n_snps = int(rng.integers(4, n_snps_available)) # PC3 required for plot_pca_coords_3d() assert min(n_samples, n_snps) > 3 - n_components = rng.integers(3, min(n_samples, n_snps, 10)) + n_components = int(rng.integers(3, min(n_samples, n_snps, 10))) # Run the PCA. pca_df, pca_evr = api.pca( @@ -169,10 +168,17 @@ def test_pca_plotting(fixture, api: AnophelesPca): def test_pca_exclude_samples(fixture, api: AnophelesPca): # Parameters for selecting input data. all_sample_sets = api.sample_sets()["sample_set"].to_list() + + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) data_params = dict( - region=random.choice(api.contigs), - sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=rng.choice(np.array(valid_site_masks, dtype=object)), ) ds = api.biallelic_snp_calls( min_minor_ac=pca_params.min_minor_ac_default, @@ -181,15 +187,17 @@ def test_pca_exclude_samples(fixture, api: AnophelesPca): ) # Exclusion parameters. - n_samples_excluded = rng.integers(1, 5) + n_samples_excluded = int(rng.integers(1, 5)) samples = ds["sample_id"].values.tolist() - exclude_samples = random.sample(samples, int(n_samples_excluded)) # convert to int + exclude_samples = rng.choice( + samples, int(n_samples_excluded), replace=False + ).tolist() # PCA parameters. 
n_samples = ds.sizes["samples"] - n_samples_excluded n_snps_available = ds.sizes["variants"] - n_snps = rng.integers(4, n_snps_available) - n_components = rng.integers(2, min(n_samples, n_snps, 10)) + n_snps = int(rng.integers(4, n_snps_available)) + n_components = int(rng.integers(2, min(n_samples, n_snps, 10))) # Run the PCA. pca_df, pca_evr = api.pca( @@ -230,10 +238,16 @@ def test_pca_exclude_samples(fixture, api: AnophelesPca): def test_pca_fit_exclude_samples(fixture, api: AnophelesPca): # Parameters for selecting input data. all_sample_sets = api.sample_sets()["sample_set"].to_list() + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) data_params = dict( - region=random.choice(api.contigs), - sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=rng.choice(np.array(valid_site_masks, dtype=object)), ) ds = api.biallelic_snp_calls( min_minor_ac=pca_params.min_minor_ac_default, @@ -242,15 +256,17 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca): ) # Exclusion parameters. - n_samples_excluded = rng.integers(1, 5) + n_samples_excluded = int(rng.integers(1, 5)) samples = ds["sample_id"].values.tolist() - exclude_samples = random.sample(samples, int(n_samples_excluded)) # convert to int + exclude_samples = rng.choice( + samples, int(n_samples_excluded), replace=False + ).tolist() # PCA parameters. n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = rng.integers(4, n_snps_available) - n_components = rng.integers(2, min(n_samples, n_snps, 10)) + n_snps = int(rng.integers(4, n_snps_available)) + n_components = int(rng.integers(2, min(n_samples, n_snps, 10))) # Run the PCA. 
pca_df, pca_evr = api.pca( diff --git a/tests/anoph/test_plink_converter.py b/tests/anoph/test_plink_converter.py index eb75e2a3f..323af84c5 100644 --- a/tests/anoph/test_plink_converter.py +++ b/tests/anoph/test_plink_converter.py @@ -1,4 +1,3 @@ -import random import pytest from pytest_cases import parametrize_with_cases @@ -85,13 +84,13 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path): all_sample_sets = api.sample_sets()["sample_set"].to_list() data_params = dict( - region=random.choice(api.contigs), - sample_sets=random.sample(all_sample_sets, 2), - site_mask=random.choice((None,) + api.site_mask_ids), + region=rng.choice(api.contigs), + sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(), + site_mask=rng.choice(np.array([""] + list(api.site_mask_ids), dtype=object)), min_minor_ac=1, max_missing_an=1, thin_offset=1, - random_seed=rng.integers(1, 2000), + random_seed=int(rng.integers(1, 2000)), ) # Load a ds containing the randomly generated samples and regions to get the number of available snps to subset from. @@ -100,7 +99,7 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path): ) n_snps_available = ds.sizes["variants"] - n_snps = rng.integers(1, n_snps_available) + n_snps = int(rng.integers(1, n_snps_available)) # Define plink params. 
plink_params = dict(output_dir=str(tmp_path), n_snps=n_snps, **data_params) diff --git a/tests/anoph/test_sample_metadata.py b/tests/anoph/test_sample_metadata.py index 0f34eea8c..89e6b2249 100644 --- a/tests/anoph/test_sample_metadata.py +++ b/tests/anoph/test_sample_metadata.py @@ -1,5 +1,3 @@ -import random - import ipyleaflet # type: ignore import numpy as np import pandas as pd @@ -123,7 +121,7 @@ def test_general_metadata_with_single_sample_set(fixture, api: AnophelesSampleMe df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) # Call function to be tested. df = api.general_metadata(sample_sets=sample_set) @@ -142,7 +140,7 @@ def test_general_metadata_with_multiple_sample_sets( df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_sets = random.sample(all_sample_sets, 2) + sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist() # Call function to be tested. df = api.general_metadata(sample_sets=sample_sets) @@ -156,7 +154,7 @@ def test_general_metadata_with_multiple_sample_sets( @parametrize_with_cases("fixture,api", cases=".") def test_general_metadata_with_release(fixture, api: AnophelesSampleMetadata): # Set up the test. - release = random.choice(api.releases) + release = rng.choice(api.releases) # Call function to be tested. df = api.general_metadata(sample_sets=release) @@ -203,7 +201,7 @@ def test_sequence_qc_metadata_with_single_sample_set( df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) # Call function to be tested. 
df = api.sequence_qc_metadata(sample_sets=sample_set) @@ -224,7 +222,7 @@ def test_sequence_qc_metadata_with_multiple_sample_sets( df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_sets = random.sample(all_sample_sets, 2) + sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist() # Call function to be tested. df = api.sequence_qc_metadata(sample_sets=sample_sets) @@ -240,7 +238,7 @@ def test_sequence_qc_metadata_with_multiple_sample_sets( @parametrize_with_cases("fixture,api", cases=".") def test_sequence_qc_metadata_with_release(fixture, api: AnophelesSampleMetadata): # Set up the test. - release = random.choice(api.releases) + release = rng.choice(api.releases) # Call function to be tested. df = api.sequence_qc_metadata(sample_sets=release) @@ -314,7 +312,7 @@ def test_aim_metadata_with_single_sample_set(ag3_sim_api): df_sample_sets = ag3_sim_api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) # Call function to be tested. df = ag3_sim_api.aim_metadata(sample_sets=sample_set) @@ -332,7 +330,7 @@ def test_aim_metadata_with_multiple_sample_sets(ag3_sim_api): df_sample_sets = ag3_sim_api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_sets = random.sample(all_sample_sets, 2) + sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist() # Call function to be tested. df = ag3_sim_api.aim_metadata(sample_sets=sample_sets) @@ -347,7 +345,7 @@ def test_aim_metadata_with_release(ag3_sim_api): # N.B., only Ag3 has AIM data. # Set up the test. - release = random.choice(ag3_sim_api.releases) + release = rng.choice(ag3_sim_api.releases) # Call function to be tested. 
df = ag3_sim_api.aim_metadata(sample_sets=release) @@ -426,7 +424,7 @@ def test_cohorts_metadata_with_single_sample_set(fixture, api: AnophelesSampleMe df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) # Call function to be tested. df = api.cohorts_metadata(sample_sets=sample_set) @@ -445,7 +443,7 @@ def test_cohorts_metadata_with_multiple_sample_sets( df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_sets = random.sample(all_sample_sets, 2) + sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist() # Call function to be tested. df = api.cohorts_metadata(sample_sets=sample_sets) @@ -459,7 +457,7 @@ def test_cohorts_metadata_with_multiple_sample_sets( @parametrize_with_cases("fixture,api", cases=".") def test_cohorts_metadata_with_release(fixture, api: AnophelesSampleMetadata): # Set up test. - release = random.choice(api.releases) + release = rng.choice(api.releases) # Call function to be tested. df = api.cohorts_metadata(sample_sets=release) @@ -520,7 +518,7 @@ def test_sample_metadata_with_single_sample_set(fixture, api: AnophelesSampleMet df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) # Call function to be tested. 
df = api.sample_metadata(sample_sets=sample_set) @@ -547,7 +545,7 @@ def test_sample_metadata_with_multiple_sample_sets( df_sample_sets = api.sample_sets().set_index("sample_set") sample_count = df_sample_sets["sample_count"] all_sample_sets = df_sample_sets.index.to_list() - sample_sets = random.sample(all_sample_sets, 2) + sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist() # Call function to be tested. df = api.sample_metadata(sample_sets=sample_sets) @@ -569,7 +567,7 @@ def test_sample_metadata_with_multiple_sample_sets( @parametrize_with_cases("fixture,api", cases=".") def test_sample_metadata_with_release(fixture, api: AnophelesSampleMetadata): # Set up test. - release = random.choice(api.releases) + release = rng.choice(api.releases) # Call function to be tested. df = api.sample_metadata(sample_sets=release) @@ -593,10 +591,10 @@ def test_sample_metadata_with_duplicate_sample_sets( fixture, api: AnophelesSampleMetadata ): # Set up test. - release = random.choice(api.releases) + release = rng.choice(api.releases) df_sample_sets = api.sample_sets(release=release).set_index("sample_set") all_sample_sets = df_sample_sets.index.to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) # Call function to be tested. assert_frame_equal( @@ -948,7 +946,7 @@ def test_plot_sample_location_mapbox(fixture, api): # Get test sample_sets. df_sample_sets = api.sample_sets().set_index("sample_set") all_sample_sets = df_sample_sets.index.to_list() - sample_sets = random.sample(all_sample_sets, 2) + sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist() fig = api.plot_sample_location_mapbox( sample_sets=sample_sets, @@ -963,7 +961,7 @@ def test_plot_sample_location_geo(fixture, api): # Get test sample_sets. 
df_sample_sets = api.sample_sets().set_index("sample_set") all_sample_sets = df_sample_sets.index.to_list() - sample_sets = random.sample(all_sample_sets, 2) + sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist() fig = api.plot_sample_location_geo( sample_sets=sample_sets, diff --git a/tests/anoph/test_snp_data.py b/tests/anoph/test_snp_data.py index 0afc9f905..23607aa38 100644 --- a/tests/anoph/test_snp_data.py +++ b/tests/anoph/test_snp_data.py @@ -1,4 +1,3 @@ -import random from itertools import product import allel # type: ignore @@ -15,6 +14,7 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.base_params import DEFAULT from malariagen_data.anoph.snp_data import AnophelesSnpData +from .conftest import Af1Simulator, Ag3Simulator # Global RNG for test file; functions may override with local RNG for reproducibility rng = np.random.default_rng(seed=42) @@ -181,7 +181,7 @@ def test_site_filters(fixture, api: AnophelesSnpData): # Test with genome feature ID. df_gff = api.genome_features(attributes=["ID"]) - region = random.choice(df_gff["ID"].dropna().to_list()) + region = rng.choice(df_gff["ID"].dropna().to_list()) check_site_filters(api, mask=mask, region=region) @@ -201,7 +201,7 @@ def check_snp_sites(api: AnophelesSnpData, region): assert pos.shape[0] == ref.shape[0] == alt.shape[0] # Apply site mask. - mask = random.choice(api.site_mask_ids) + mask = rng.choice(api.site_mask_ids) filter_pass = api.site_filters(region=region, mask=mask).compute() n_pass = np.count_nonzero(filter_pass) pos_pass = api.snp_sites( @@ -237,7 +237,7 @@ def test_snp_sites(fixture, api: AnophelesSnpData): # Test with genome feature ID. 
df_gff = api.genome_features(attributes=["ID"]) - region = random.choice(df_gff["ID"].dropna().to_list()) + region = rng.choice(df_gff["ID"].dropna().to_list()) check_snp_sites(api=api, region=region) @@ -325,11 +325,11 @@ def test_site_annotations(fixture, api): parametrize_region = [ contig, fixture.random_region_str(), - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Parametrize site_mask. - parametrize_site_mask = (None, random.choice(api.site_mask_ids)) + parametrize_site_mask = (None, rng.choice(api.site_mask_ids)) # Run tests. for region, site_mask in product( @@ -418,7 +418,7 @@ def check_snp_genotypes( assert ad.shape[2] == 4 # Check with site mask. - mask = random.choice(api.site_mask_ids) + mask = rng.choice(api.site_mask_ids) filter_pass = api.site_filters(region=region, mask=mask).compute() gt_pass = api.snp_genotypes( region=region, @@ -462,9 +462,9 @@ def test_snp_genotypes_with_sample_sets_param(fixture, api: AnophelesSnpData): all_sample_sets = api.sample_sets()["sample_set"].to_list() parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Run tests. @@ -476,7 +476,7 @@ def test_snp_genotypes_with_sample_sets_param(fixture, api: AnophelesSnpData): def test_snp_genotypes_with_region_param(fixture, api: AnophelesSnpData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) # Parametrize region. 
contig = fixture.random_contig() @@ -485,7 +485,7 @@ def test_snp_genotypes_with_region_param(fixture, api: AnophelesSnpData): contig, fixture.random_region_str(), [fixture.random_region_str(), fixture.random_region_str()], - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Run tests. @@ -500,7 +500,7 @@ def test_snp_genotypes_with_region_param(fixture, api: AnophelesSnpData): def test_snp_genotypes_with_sample_query_param( ag3_sim_api: AnophelesSnpData, sample_query ): - contig = random.choice(ag3_sim_api.contigs) + contig = rng.choice(ag3_sim_api.contigs) df_samples = ag3_sim_api.sample_metadata().query(sample_query) if len(df_samples) == 0: @@ -529,7 +529,7 @@ def test_snp_genotypes_with_sample_query_param( def test_snp_genotypes_with_sample_query_options_param( ag3_sim_api: AnophelesSnpData, sample_query, sample_query_options ): - contig = random.choice(ag3_sim_api.contigs) + contig = rng.choice(ag3_sim_api.contigs) df_samples = ag3_sim_api.sample_metadata().query( sample_query, **sample_query_options ) @@ -695,16 +695,22 @@ def check_snp_calls(api, sample_sets, region, site_mask): def test_snp_calls_with_sample_sets_param(fixture, api: AnophelesSnpData): # Fixed parameters. region = fixture.random_region_str() - site_mask = random.choice((None,) + api.site_mask_ids) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize sample_sets. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() all_releases = api.releases parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Run tests. @@ -718,8 +724,14 @@ def test_snp_calls_with_sample_sets_param(fixture, api: AnophelesSnpData): def test_snp_calls_with_region_param(fixture, api: AnophelesSnpData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice((None,) + api.site_mask_ids) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize region. contig = fixture.random_contig() @@ -728,7 +740,7 @@ def test_snp_calls_with_region_param(fixture, api: AnophelesSnpData): contig, fixture.random_region_str(), [fixture.random_region_str(), fixture.random_region_str()], - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Run tests. @@ -742,7 +754,7 @@ def test_snp_calls_with_region_param(fixture, api: AnophelesSnpData): def test_snp_calls_with_site_mask_param(fixture, api: AnophelesSnpData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Parametrize site_mask. @@ -816,7 +828,7 @@ def test_snp_calls_with_sample_query_options_param( def test_snp_calls_with_min_cohort_size_param(fixture, api: AnophelesSnpData): # Randomly fix some input parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Test with minimum cohort size. @@ -839,7 +851,7 @@ def test_snp_calls_with_min_cohort_size_param(fixture, api: AnophelesSnpData): def test_snp_calls_with_max_cohort_size_param(fixture, api: AnophelesSnpData): # Randomly fix some input parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Test with maximum cohort size. @@ -856,11 +868,11 @@ def test_snp_calls_with_max_cohort_size_param(fixture, api: AnophelesSnpData): def test_snp_calls_with_cohort_size_param(fixture, api: AnophelesSnpData): # Randomly fix some input parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Test with specific cohort size. - cohort_size = rng.integers(1, 10) + cohort_size = int(rng.integers(1, 10)) ds = api.snp_calls( sample_sets=sample_sets, region=region, @@ -967,7 +979,8 @@ def check_snp_allele_counts( assert ac.shape == (pos.shape[0], 4) assert np.all(ac >= 0) an = ac.sum(axis=1) - assert an.max() <= 2 * n_samples + if an.size > 0: # Check if 'an' is not empty + assert an.max() <= 2 * n_samples # Run again to ensure loading from results cache produces the same result. ac2 = api.snp_allele_counts( @@ -984,16 +997,21 @@ def check_snp_allele_counts( def test_snp_allele_counts_with_sample_sets_param(fixture, api: AnophelesSnpData): # Fixed parameters. 
region = fixture.random_region_str() - site_mask = random.choice((None,) + api.site_mask_ids) - + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize sample_sets. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_releases = api.releases parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Run tests. @@ -1011,8 +1029,14 @@ def test_snp_allele_counts_with_sample_sets_param(fixture, api: AnophelesSnpData def test_snp_allele_counts_with_region_param(fixture, api: AnophelesSnpData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice((None,) + api.site_mask_ids) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize region. contig = fixture.random_contig() @@ -1021,7 +1045,7 @@ def test_snp_allele_counts_with_region_param(fixture, api: AnophelesSnpData): contig, fixture.random_region_str(), [fixture.random_region_str(), fixture.random_region_str()], - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Run tests. 
@@ -1039,7 +1063,7 @@ def test_snp_allele_counts_with_region_param(fixture, api: AnophelesSnpData): def test_snp_allele_counts_with_site_mask_param(fixture, api: AnophelesSnpData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Parametrize site_mask. @@ -1060,9 +1084,15 @@ def test_snp_allele_counts_with_site_mask_param(fixture, api: AnophelesSnpData): def test_snp_allele_counts_with_sample_query_param(fixture, api: AnophelesSnpData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() - site_mask = random.choice((None,) + api.site_mask_ids) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize sample_query. parametrize_sample_query = (None, "sex_call == 'F'") @@ -1084,9 +1114,15 @@ def test_snp_allele_counts_with_sample_query_options_param( ): # Fixed parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() - site_mask = random.choice((None,) + api.site_mask_ids) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) sample_query_options = { "local_dict": { "sex_call_list": ["F", "M"], @@ -1124,7 +1160,7 @@ def test_is_accessible(fixture, api: AnophelesSnpData): parametrize_region = [ contig, fixture.random_region_str(), - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Parametrize site_mask. @@ -1146,9 +1182,9 @@ def test_is_accessible(fixture, api: AnophelesSnpData): def test_plot_snps(fixture, api: AnophelesSnpData): # Randomly choose parameter values. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() - site_mask = random.choice(api.site_mask_ids) + site_mask = rng.choice(api.site_mask_ids) # Exercise the function. fig = api.plot_snps( @@ -1291,18 +1327,25 @@ def check_biallelic_snp_calls_and_diplotypes( def test_biallelic_snp_calls_and_diplotypes_with_sample_sets_param( fixture, api: AnophelesSnpData ): + all_sample_sets = api.sample_sets()["sample_set"].to_list() # Fixed parameters. 
region = fixture.random_region_str() - site_mask = random.choice((None,) + api.site_mask_ids) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize sample_sets. - all_sample_sets = api.sample_sets()["sample_set"].to_list() all_releases = api.releases parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Run tests. @@ -1318,8 +1361,14 @@ def test_biallelic_snp_calls_and_diplotypes_with_region_param( ): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice((None,) + api.site_mask_ids) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize region. contig = fixture.random_contig() @@ -1328,7 +1377,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_region_param( contig, fixture.random_region_str(), [fixture.random_region_str(), fixture.random_region_str()], - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Run tests. @@ -1344,7 +1393,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_site_mask_param( ): # Fixed parameters. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Parametrize site_mask. @@ -1423,7 +1472,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_min_cohort_size_param( ): # Randomly fix some input parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Test with minimum cohort size. @@ -1448,7 +1497,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_max_cohort_size_param( ): # Randomly fix some input parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Test with maximum cohort size. @@ -1467,11 +1516,11 @@ def test_biallelic_snp_calls_and_diplotypes_with_cohort_size_param( ): # Randomly fix some input parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) region = fixture.random_region_str() # Test with specific cohort size. 
- cohort_size = rng.integers(1, 10) + cohort_size = int(rng.integers(1, 10)) ds = api.biallelic_snp_calls( sample_sets=sample_sets, region=region, @@ -1505,7 +1554,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_cohort_size_param( def test_biallelic_snp_calls_and_diplotypes_with_site_class_param( ag3_sim_api: AnophelesSnpData, site_class ): - contig = random.choice(ag3_sim_api.contigs) + contig = rng.choice(ag3_sim_api.contigs) ds1 = ag3_sim_api.biallelic_snp_calls(region=contig) ds2 = ag3_sim_api.biallelic_snp_calls(region=contig, site_class=site_class) assert ds2.sizes["variants"] < ds1.sizes["variants"] @@ -1519,14 +1568,20 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( fixture, api: AnophelesSnpData ): # Fixed parameters. - contig = random.choice(api.contigs) + contig = rng.choice(api.contigs) all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice((None,) + api.site_mask_ids) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrise conditions. - min_minor_ac = rng.integers(1, 3) - max_missing_an = rng.integers(5, 10) + min_minor_ac = int(rng.integers(1, 3)) + max_missing_an = int(rng.integers(5, 10)) # Run tests. ds = check_biallelic_snp_calls_and_diplotypes( @@ -1554,7 +1609,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( # This should always be true, although depends on min_minor_ac and max_missing_an, # so the range of values for those parameters needs to be chosen with some care. 
assert n_snps_available > 2 - n_snps_requested = rng.integers(1, n_snps_available // 2) + n_snps_requested = int(rng.integers(1, n_snps_available // 2)) ds_thinned = check_biallelic_snp_calls_and_diplotypes( api=api, sample_sets=sample_sets, @@ -1585,14 +1640,20 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional( fixture, api: AnophelesSnpData ): # Fixed parameters. - contig = random.choice(api.contigs) + contig = rng.choice(api.contigs) all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice((None,) + api.site_mask_ids) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrise conditions. - min_minor_ac = random.uniform(0, 0.05) - max_missing_an = random.uniform(0.05, 0.2) + min_minor_ac = rng.uniform(0, 0.05) + max_missing_an = rng.uniform(0.05, 0.2) # Run tests. ds = check_biallelic_snp_calls_and_diplotypes( @@ -1620,7 +1681,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional( # This should always be true, although depends on min_minor_ac and max_missing_an, # so the range of values for those parameters needs to be chosen with some care. 
assert n_snps_available > 2 - n_snps_requested = rng.integers(1, n_snps_available // 2) + n_snps_requested = int(rng.integers(1, n_snps_available // 2)) ds_thinned = check_biallelic_snp_calls_and_diplotypes( api=api, sample_sets=sample_sets, diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py index d8bfb07ab..a09a5c3d8 100644 --- a/tests/anoph/test_snp_frq.py +++ b/tests/anoph/test_snp_frq.py @@ -1,5 +1,3 @@ -import random - import numpy as np import pandas as pd from pandas.testing import assert_frame_equal @@ -7,6 +5,7 @@ from pytest_cases import parametrize_with_cases import xarray as xr from numpy.testing import assert_allclose, assert_array_equal +from .conftest import Af1Simulator, Ag3Simulator from malariagen_data import af1 as _af1 from malariagen_data import ag3 as _ag3 @@ -114,7 +113,7 @@ def random_transcript(*, api): df_gff = api.genome_features(attributes=["ID", "Parent"]) df_transcripts = df_gff.query("type == 'mRNA'") transcript_ids = df_transcripts["ID"].dropna().to_list() - transcript_id = random.choice(transcript_ids) + transcript_id = rng.choice(transcript_ids) transcript = df_transcripts.set_index("ID").loc[transcript_id] return transcript @@ -125,7 +124,13 @@ def test_snp_effects(fixture, api: AnophelesSnpFrequencyAnalysis): transcript = random_transcript(api=api) # Pick a random site mask. - site_mask = random.choice(api.site_mask_ids + (None,)) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Compute effects. df = api.snp_effects(transcript=transcript.name, site_mask=site_mask) @@ -302,9 +307,16 @@ def test_allele_frequencies_with_str_cohorts( ): # Pick test parameters at random. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = rng.integers(0, 2) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) + + min_cohort_size = int(rng.integers(0, 2)) transcript = random_transcript(api=api) # Set up call params. @@ -367,8 +379,14 @@ def test_allele_frequencies_with_min_cohort_size( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) transcript = random_transcript(api=api) cohorts = "admin1_year" @@ -432,15 +450,19 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query( ): # Pick test parameters at random. 
sample_sets = None - site_mask = random.choice(api.site_mask_ids + (None,)) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) min_cohort_size = 0 transcript = random_transcript(api=api) - cohorts = random.choice( - ["admin1_year", "admin1_month", "admin2_year", "admin2_month"] - ) + cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"]) df_samples = api.sample_metadata(sample_sets=sample_sets) countries = df_samples["country"].unique() - country = random.choice(countries) + country = rng.choice(countries) sample_query = f"country == '{country}'" # Figure out expected cohort labels. @@ -493,15 +515,19 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query_options( ): # Pick test parameters at random. sample_sets = None - site_mask = random.choice(api.site_mask_ids + (None,)) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) min_cohort_size = 0 transcript = random_transcript(api=api) - cohorts = random.choice( - ["admin1_year", "admin1_month", "admin2_year", "admin2_month"] - ) + cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"]) df_samples = api.sample_metadata(sample_sets=sample_sets) countries = df_samples["country"].unique().tolist() - countries_list = random.sample(countries, 2) + countries_list = rng.choice(countries, 2, replace=False).tolist() sample_query_options = { "local_dict": { "countries_list": countries_list, @@ -563,8 +589,8 @@ def test_allele_frequencies_with_dict_cohorts( ): # Pick test parameters at random. 
sample_sets = None # all sample sets - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = rng.integers(0, 2) + site_mask = rng.choice(list(api.site_mask_ids) + [""]) + min_cohort_size = int(rng.integers(0, 2)) transcript = random_transcript(api=api) # Create cohorts by country. @@ -616,11 +642,11 @@ def test_allele_frequencies_without_drop_invariant( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = rng.integers(0, 2) + sample_sets = rng.choice(all_sample_sets) + site_mask = rng.choice(list(api.site_mask_ids) + [""]) + min_cohort_size = int(rng.integers(0, 2)) transcript = random_transcript(api=api) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) # Figure out expected cohort labels. df_samples = api.sample_metadata(sample_sets=sample_sets) @@ -672,11 +698,17 @@ def test_allele_frequencies_without_effects( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = rng.integers(0, 2) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) + min_cohort_size = int(rng.integers(0, 2)) transcript = random_transcript(api=api) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) # Figure out expected cohort labels. 
df_samples = api.sample_metadata(sample_sets=sample_sets) @@ -754,10 +786,10 @@ def test_allele_frequencies_with_bad_transcript( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = rng.integers(0, 2) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + sample_sets = rng.choice(all_sample_sets) + site_mask = rng.choice(list(api.site_mask_ids) + [""]) + min_cohort_size = int(rng.integers(0, 2)) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) # Set up call params. params = dict( @@ -781,10 +813,16 @@ def test_allele_frequencies_with_region( ): # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = rng.integers(0, 2) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) + min_cohort_size = int(rng.integers(0, 2)) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) # This should work, as long as effects=False - i.e., can get frequencies # for any genome region. transcript = fixture.random_region_str(region_size=500) @@ -839,11 +877,17 @@ def test_allele_frequencies_with_dup_samples( ): # Pick test parameters at random. 
all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_set = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = rng.integers(0, 2) + sample_set = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) + min_cohort_size = int(rng.integers(0, 2)) transcript = random_transcript(api=api) - cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + cohorts = rng.choice(["admin1_year", "admin2_month", "country"]) # Set up call params. params = dict( @@ -875,6 +919,7 @@ def test_allele_frequencies_with_dup_samples( def check_snp_allele_frequencies_advanced( *, + fixture, api: AnophelesSnpFrequencyAnalysis, transcript=None, area_by="admin1_iso", @@ -891,16 +936,22 @@ def check_snp_allele_frequencies_advanced( if transcript is None: transcript = random_transcript(api=api).name if area_by is None: - area_by = random.choice(["country", "admin1_iso", "admin2_name"]) + area_by = rng.choice(["country", "admin1_iso", "admin2_name"]) if period_by is None: - period_by = random.choice(["year", "quarter", "month"]) + period_by = rng.choice(["year", "quarter", "month"]) if sample_sets is None: all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = rng.integers(0, 2) + min_cohort_size = int(rng.integers(0, 2)) if site_mask is None: - site_mask = random.choice(api.site_mask_ids + (None,)) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + 
list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Run function under test. ds = api.snp_allele_frequencies_advanced( @@ -1085,14 +1136,14 @@ def check_aa_allele_frequencies_advanced( if transcript is None: transcript = random_transcript(api=api).name if area_by is None: - area_by = random.choice(["country", "admin1_iso", "admin2_name"]) + area_by = rng.choice(["country", "admin1_iso", "admin2_name"]) if period_by is None: - period_by = random.choice(["year", "quarter", "month"]) + period_by = rng.choice(["year", "quarter", "month"]) if sample_sets is None: all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) + sample_sets = rng.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = rng.integers(0, 2) + min_cohort_size = int(rng.integers(0, 2)) # Run function under test. ds = api.aa_allele_frequencies_advanced( @@ -1262,6 +1313,7 @@ def test_allele_frequencies_advanced_with_area_by( area_by, ): check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, area_by=area_by, ) @@ -1279,6 +1331,7 @@ def test_allele_frequencies_advanced_with_period_by( period_by, ): check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, period_by=period_by, ) @@ -1296,10 +1349,11 @@ def test_allele_frequencies_advanced_with_sample_query( all_sample_sets = api.sample_sets()["sample_set"].to_list() df_samples = api.sample_metadata(sample_sets=all_sample_sets) countries = df_samples["country"].unique() - country = random.choice(countries) + country = rng.choice(countries) sample_query = f"country == '{country}'" check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, sample_sets=all_sample_sets, sample_query=sample_query, @@ -1321,7 +1375,7 @@ def test_allele_frequencies_advanced_with_sample_query_options( all_sample_sets = api.sample_sets()["sample_set"].to_list() df_samples = api.sample_metadata(sample_sets=all_sample_sets) countries = 
df_samples["country"].unique().tolist() - countries_list = random.sample(countries, 2) + countries_list = rng.choice(countries, 2, replace=False).tolist() sample_query_options = { "local_dict": { "countries_list": countries_list, @@ -1330,6 +1384,7 @@ def test_allele_frequencies_advanced_with_sample_query_options( sample_query = "country in @countries_list" check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, sample_sets=all_sample_sets, sample_query=sample_query, @@ -1361,6 +1416,7 @@ def test_allele_frequencies_advanced_with_min_cohort_size( # Expect this to find at least one cohort, so go ahead with full # checks. check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, transcript=transcript, sample_sets=all_sample_sets, @@ -1409,6 +1465,7 @@ def test_allele_frequencies_advanced_with_variant_query( # Test a query that should succeed. variant_query = "effect == 'NON_SYNONYMOUS_CODING'" check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, transcript=transcript, sample_sets=all_sample_sets, @@ -1453,6 +1510,7 @@ def test_allele_frequencies_advanced_with_nobs_mode( nobs_mode, ): check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, nobs_mode=nobs_mode, ) @@ -1468,10 +1526,11 @@ def test_allele_frequencies_advanced_with_dup_samples( api: AnophelesSnpFrequencyAnalysis, ): all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) sample_sets = [sample_set, sample_set] check_snp_allele_frequencies_advanced( + fixture=fixture, api=api, sample_sets=sample_sets, ) From 504a73dbb89af5ecd2e2484e8f9ccea288dcf626 Mon Sep 17 00:00:00 2001 From: Mohamed Laarej Date: Fri, 16 May 2025 17:16:50 +0100 Subject: [PATCH 04/10] Fix test_frq.py to handle single-row dataframes in CI environment --- tests/anoph/test_frq.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/anoph/test_frq.py b/tests/anoph/test_frq.py index 
1b2e6647b..8b837965c 100644 --- a/tests/anoph/test_frq.py +++ b/tests/anoph/test_frq.py @@ -10,8 +10,15 @@ def check_plot_frequencies_heatmap(api, frq_df): assert isinstance(fig, go.Figure) # Test max_len behaviour. + # Only test if we have more than 1 row, otherwise set max_len to 0 + # should still raise ValueError + if len(frq_df) > 1: + test_max_len = len(frq_df) - 1 + else: + test_max_len = 0 + with pytest.raises(ValueError): - api.plot_frequencies_heatmap(frq_df, show=False, max_len=len(frq_df) - 1) + api.plot_frequencies_heatmap(frq_df, show=False, max_len=test_max_len) # Test index parameter - if None, should use dataframe index. fig = api.plot_frequencies_heatmap(frq_df, show=False, index=None, max_len=None) From 78e8a31bd4840cc58110f080651f0b196166df8c Mon Sep 17 00:00:00 2001 From: Mohamed Laarej Date: Sat, 17 May 2025 17:28:05 +0100 Subject: [PATCH 05/10] Lowering n_snps from 50_000 to 10_000 in notebooks/karyotype.ipynb to fix CI notebook execution --- notebooks/karyotype.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/karyotype.ipynb b/notebooks/karyotype.ipynb index bd26b8a8d..5571e7660 100644 --- a/notebooks/karyotype.ipynb +++ b/notebooks/karyotype.ipynb @@ -93,7 +93,7 @@ "pca_df_2la, pca_evr_2la = ag3.pca(\n", " region=region_2la,\n", " sample_sets=sample_sets,\n", - " n_snps=50_000,\n", + " n_snps=10_000,\n", ")\n", "pca_df_2la = pca_df_2la.merge(kt_df_2la, on=\"sample_id\")\n", "pca_df_2la.head()" @@ -180,7 +180,7 @@ "pca_df_2rb, pca_evr_2rb = ag3.pca(\n", " region=region_2rb,\n", " sample_sets=sample_sets,\n", - " n_snps=50_000,\n", + " n_snps=10_000,\n", ")\n", "pca_df_2rb = pca_df_2rb.merge(kt_df_2rb, on=\"sample_id\")\n", "pca_df_2rb.head()" @@ -262,7 +262,7 @@ " region=region_2rc,\n", " sample_sets=sample_sets,\n", " sample_query=\"taxon == 'gambiae'\",\n", - " n_snps=50_000,\n", + " n_snps=10_000,\n", ")\n", "pca_df_2rc_gam = pca_df_2rc_gam.merge(kt_df_2rc_gam, on=\"sample_id\")\n", 
"pca_df_2rc_gam.head()" @@ -342,7 +342,7 @@ " region=region_2rc,\n", " sample_sets=sample_sets,\n", " sample_query=\"taxon == 'coluzzii'\",\n", - " n_snps=50_000,\n", + " n_snps=10_000,\n", ")\n", "pca_df_2rc_col = pca_df_2rc_col.merge(kt_df_2rc_col, on=\"sample_id\")\n", "pca_df_2rc_col.head()" From 194f2c1bf5ceffb9c5c5be8184ff484d9b9eb106 Mon Sep 17 00:00:00 2001 From: jonbrenas <51911846+jonbrenas@users.noreply.github.com> Date: Fri, 30 May 2025 15:58:55 +0100 Subject: [PATCH 06/10] Solving linting issue --- malariagen_data/anoph/sample_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index 70e61334d..7cab8804a 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -11,7 +11,7 @@ Tuple, Union, Hashable, - cast + cast, ) import ipyleaflet # type: ignore From 6a23218ee07abbe1d40d2e85c477ce8c8f93dddc Mon Sep 17 00:00:00 2001 From: Jon Brenas <51911846+jonbrenas@users.noreply.github.com> Date: Fri, 30 May 2025 19:54:11 +0100 Subject: [PATCH 07/10] Update sample_metadata.py Trying to solve the notebook issue --- malariagen_data/anoph/sample_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index 7cab8804a..43873bcca 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -921,7 +921,7 @@ def _prep_sample_selection_cache_params( df_samples = self.sample_metadata(sample_sets=sample_sets) sample_query_options = sample_query_options or {} loc_samples = ( - df_samples.eval(sample_query, **sample_query_options).values, + df_samples.eval(sample_query, **sample_query_options).values ) sample_indices = np.nonzero(loc_samples)[0].tolist() From a84edcfb52329ae23808726572e33d0031242388 Mon Sep 17 00:00:00 2001 From: Jon Brenas 
<51911846+jonbrenas@users.noreply.github.com> Date: Fri, 30 May 2025 19:56:15 +0100 Subject: [PATCH 08/10] Update sample_metadata.py Solving the linting issue --- malariagen_data/anoph/sample_metadata.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index 43873bcca..62baa6870 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -920,9 +920,7 @@ def _prep_sample_selection_cache_params( # integer indices instead. df_samples = self.sample_metadata(sample_sets=sample_sets) sample_query_options = sample_query_options or {} - loc_samples = ( - df_samples.eval(sample_query, **sample_query_options).values - ) + loc_samples = df_samples.eval(sample_query, **sample_query_options).values sample_indices = np.nonzero(loc_samples)[0].tolist() return sample_sets, sample_indices From 0ddf4ca405197b42a42c16ccadd7cfdc52c64f28 Mon Sep 17 00:00:00 2001 From: Jon Brenas <51911846+jonbrenas@users.noreply.github.com> Date: Fri, 30 May 2025 20:59:20 +0100 Subject: [PATCH 09/10] Update karyotype.ipynb Rolling back changes --- notebooks/karyotype.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/karyotype.ipynb b/notebooks/karyotype.ipynb index 5571e7660..bd26b8a8d 100644 --- a/notebooks/karyotype.ipynb +++ b/notebooks/karyotype.ipynb @@ -93,7 +93,7 @@ "pca_df_2la, pca_evr_2la = ag3.pca(\n", " region=region_2la,\n", " sample_sets=sample_sets,\n", - " n_snps=10_000,\n", + " n_snps=50_000,\n", ")\n", "pca_df_2la = pca_df_2la.merge(kt_df_2la, on=\"sample_id\")\n", "pca_df_2la.head()" @@ -180,7 +180,7 @@ "pca_df_2rb, pca_evr_2rb = ag3.pca(\n", " region=region_2rb,\n", " sample_sets=sample_sets,\n", - " n_snps=10_000,\n", + " n_snps=50_000,\n", ")\n", "pca_df_2rb = pca_df_2rb.merge(kt_df_2rb, on=\"sample_id\")\n", "pca_df_2rb.head()" @@ -262,7 +262,7 @@ " region=region_2rc,\n", " 
sample_sets=sample_sets,\n", "    sample_query=\"taxon == 'gambiae'\",\n", -    "    n_snps=10_000,\n", +    "    n_snps=50_000,\n", ")\n", "pca_df_2rc_gam = pca_df_2rc_gam.merge(kt_df_2rc_gam, on=\"sample_id\")\n", "pca_df_2rc_gam.head()" @@ -342,7 +342,7 @@ "    region=region_2rc,\n", "    sample_sets=sample_sets,\n", "    sample_query=\"taxon == 'coluzzii'\",\n", -    "    n_snps=10_000,\n", +    "    n_snps=50_000,\n", ")\n", "pca_df_2rc_col = pca_df_2rc_col.merge(kt_df_2rc_col, on=\"sample_id\")\n", "pca_df_2rc_col.head()" From ee3f1445cd0b2e07e71729145c3167d21a004c36 Mon Sep 17 00:00:00 2001 From: Mohamed Laarej Date: Mon, 4 Aug 2025 16:45:30 +0100 Subject: [PATCH 10/10] refactor: use private column name '_partition' in plot_haplotype_network --- malariagen_data/anopheles.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index ed9cb65cf..0de03c734 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -2240,18 +2240,15 @@ def plot_haplotype_network( color_discrete_map_display = None ht_color_counts = None if color is not None: -        # sanitise color column - necessary to avoid grey pie chart segments -        df_haps["partition"] = df_haps[color].str.replace(r"\W", "", regex=True) - -        # extract all unique values of the color column -        color_values = df_haps["partition"].fillna("").unique() -        color_values_mapping = dict(zip(df_haps["partition"], df_haps[color])) +        df_haps["_partition"] = df_haps[color].str.replace(r"\W", "", regex=True) +        color_values = df_haps["_partition"].fillna("").unique() +        color_values_mapping = dict(zip(df_haps["_partition"], df_haps[color])) color_values_mapping[""] = "black" color_values_display = [color_values_mapping[c] for c in color_values] # count color values for each distinct haplotype ht_color_counts = [ -            df_haps.iloc[list(s)]["partition"].value_counts().to_dict() +            df_haps.iloc[list(s)]["_partition"].value_counts().to_dict() for s in ht_distinct_sets ]
@@ -2262,7 +2259,7 @@ def plot_haplotype_network( category_orders_prepped, ) = self._setup_sample_colors_plotly( data=df_haps, - color="partition", + color="_partition", color_discrete_map=color_discrete_map, color_discrete_sequence=color_discrete_sequence, category_orders=category_orders,