diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index 3088508ba..f64b1b06c 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -1,6 +1,17 @@ import io from itertools import cycle -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) import ipyleaflet # type: ignore import numpy as np @@ -11,6 +22,7 @@ from ..util import check_types from . import base_params, map_params, plotly_params from .base import AnophelesBase +from numpy.typing import NDArray class AnophelesSampleMetadata(AnophelesBase): @@ -891,8 +903,11 @@ def _prep_sample_selection_cache_params( # integer indices instead. df_samples = self.sample_metadata(sample_sets=sample_sets) sample_query_options = sample_query_options or {} - loc_samples = df_samples.eval(sample_query, **sample_query_options).values - sample_indices = np.nonzero(loc_samples)[0].tolist() + loc_samples = cast( + NDArray[Any], + df_samples.eval(sample_query, **sample_query_options).values, + ) + sample_indices = cast(List[int], np.nonzero(loc_samples)[0].tolist()) return sample_sets, sample_indices diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py index 497b8234a..6d312b8f3 100644 --- a/malariagen_data/anoph/snp_frq.py +++ b/malariagen_data/anoph/snp_frq.py @@ -7,7 +7,6 @@ from numpydoc_decorator import doc # type: ignore import xarray as xr import numba # type: ignore - from .. import veff from ..util import ( check_types, @@ -576,8 +575,8 @@ def snp_allele_frequencies_advanced( raise ValueError("No SNPs remaining after dropping invariant SNPs.") df_variants = df_variants.loc[loc_variant].reset_index(drop=True) - count = np.compress(loc_variant, count, axis=0) - nobs = np.compress(loc_variant, nobs, axis=0) + count = np.compress(loc_variant, count, axis=0).reshape(-1, count.shape[1]) + nobs = np.compress(loc_variant, nobs, axis=0).reshape(-1, nobs.shape[1]) frequency = np.compress(loc_variant, frequency, axis=0) # Set up variant effect annotator. diff --git a/pyproject.toml b/pyproject.toml index 8e908a1d8..e83c9e80d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.10,<3.13" -numpy = "<2.2" +numpy = "*" numba = ">=0.60.0" llvmlite = "*" scipy = "*" diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index 9f258c295..a6c357f8f 100644 --- a/tests/anoph/conftest.py +++ b/tests/anoph/conftest.py @@ -29,6 +29,9 @@ # real data in GCS, but which is much smaller and so can be used # for faster test runs. +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture(scope="session") def fixture_dir(): @@ -37,10 +40,10 @@ def fixture_dir(): def simulate_contig(*, low, high, base_composition): - size = np.random.randint(low=low, high=high) + size = rng.integers(low=low, high=high) bases = np.array([b"a", b"c", b"g", b"t", b"n", b"A", b"C", b"G", b"T", b"N"]) p = np.array([base_composition[b] for b in bases]) - seq = np.random.choice(bases, size=size, replace=True, p=p) + seq = rng.choice(bases, size=size, replace=True, p=p) return seq @@ -363,7 +366,7 @@ def simulate_site_filters(path, contigs, p_pass, n_sites): for contig in contigs: variants = root.require_group(contig).require_group("variants") size = n_sites[contig] - filter_pass = np.random.choice([False, True], size=size, p=p) + filter_pass = rng.choice([False, True], size=size, p=p) variants.create_dataset(name="filter_pass", data=filter_pass) zarr.consolidate_metadata(path) @@ -386,7 +389,7 @@ def simulate_snp_genotypes( contig_n_sites = n_sites[contig] # Simulate genotype calls. - gt = np.random.choice( + gt = rng.choice( np.arange(4, dtype="i1"), size=(contig_n_sites, n_samples, 2), replace=True, @@ -395,9 +398,7 @@ def simulate_snp_genotypes( # Simulate missing calls. n_calls = contig_n_sites * n_samples - loc_missing = np.random.choice( - [False, True], size=n_calls, replace=True, p=p_missing - ) + loc_missing = rng.choice([False, True], size=n_calls, replace=True, p=p_missing) gt.reshape(-1, 2)[loc_missing] = -1 # Store genotype calls. @@ -438,7 +439,7 @@ def simulate_site_annotations(path, genome): p = [0.897754, 0.0, 0.060577, 0.014287, 0.011096, 0.016286] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # codon_nonsyn @@ -447,7 +448,7 @@ def simulate_site_annotations(path, genome): p = [0.91404, 0.001646, 0.018698, 0.065616] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # codon_position @@ -456,7 +457,7 @@ def simulate_site_annotations(path, genome): p = [0.897754, 0.034082, 0.034082, 0.034082] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # seq_cls @@ -477,28 +478,28 @@ def simulate_site_annotations(path, genome): ] for contig in contigs: size = genome[contig].shape[0] - x = np.random.choice(vals, size=size, replace=True, p=p) + x = rng.choice(vals, size=size, replace=True, p=p) grp.create_dataset(name=contig, data=x) # seq_flen grp = root.require_group("seq_flen") for contig in contigs: size = genome[contig].shape[0] - x = np.random.randint(low=0, high=40_000, size=size) + x = rng.integers(low=0, high=40_000, size=size) grp.create_dataset(name=contig, data=x) # seq_relpos_start grp = root.require_group("seq_relpos_start") for contig in contigs: size = genome[contig].shape[0] - x = np.random.beta(a=0.4, b=4, size=size) * 40_000 + x = rng.beta(a=0.4, b=4, size=size) * 40_000 grp.create_dataset(name=contig, data=x) # seq_relpos_stop grp = root.require_group("seq_relpos_stop") for contig in contigs: size = genome[contig].shape[0] - x = np.random.beta(a=0.4, b=4, size=size) * 40_000 + x = rng.beta(a=0.4, b=4, size=size) * 40_000 grp.create_dataset(name=contig, data=x) zarr.consolidate_metadata(path) @@ -514,7 +515,7 @@ def simulate_hap_sites(path, contigs, snp_sites, p_site): # Simulate POS. snp_pos = snp_sites[f"{contig}/variants/POS"][:] - loc_hap_sites = np.random.choice( + loc_hap_sites = rng.choice( [False, True], size=snp_pos.shape[0], p=[1 - p_site, p_site] ) pos = snp_pos[loc_hap_sites] @@ -527,7 +528,7 @@ def simulate_hap_sites(path, contigs, snp_sites, p_site): # Simulate ALT. snp_alt = snp_sites[f"{contig}/variants/ALT"][:] - sim_alt_choice = np.random.choice(3, size=pos.shape[0]) + sim_alt_choice = rng.choice(3, size=pos.shape[0]) alt = np.take_along_axis( snp_alt[loc_hap_sites], indices=sim_alt_choice[:, None], axis=1 )[:, 0] @@ -547,8 +548,8 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high): for contig_index, contig in enumerate(contigs): # Simulate AIM positions variable. snp_pos = snp_sites[f"{contig}/variants/POS"][:] - loc_aim_sites = np.random.choice( - snp_pos.shape[0], size=np.random.randint(n_sites_low, n_sites_high) + loc_aim_sites = rng.choice( + snp_pos.shape[0], size=rng.integers(n_sites_low, n_sites_high) ) loc_aim_sites.sort() aim_pos = snp_pos[loc_aim_sites] @@ -564,10 +565,7 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high): snp_alleles = np.concatenate([snp_ref[:, None], snp_alt], axis=1) aim_site_snp_alleles = snp_alleles[loc_aim_sites] sim_allele_choice = np.vstack( - [ - np.random.choice(4, size=2, replace=False) - for _ in range(len(loc_aim_sites)) - ] + [rng.choice(4, size=2, replace=False) for _ in range(len(loc_aim_sites))] ) aim_alleles = np.take_along_axis( aim_site_snp_alleles, indices=sim_allele_choice, axis=1 @@ -612,7 +610,7 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes): # - samples [1D array] [str] # Get a random probability for a sample being high variance, between 0 and 1. - p_variance = np.random.random() + p_variance = rng.random() # Open a zarr at the specified path. root = zarr.open(zarr_path, mode="w") @@ -626,11 +624,11 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes): n_samples = len(df_samples) # Simulate sample_coverage_variance array. - sample_coverage_variance = np.random.uniform(low=0, high=0.5, size=n_samples) + sample_coverage_variance = rng.uniform(low=0, high=0.5, size=n_samples) root.create_dataset(name="sample_coverage_variance", data=sample_coverage_variance) # Simulate sample_is_high_variance array. - sample_is_high_variance = np.random.choice( + sample_is_high_variance = rng.choice( [False, True], size=n_samples, p=[1 - p_variance, p_variance] ) root.create_dataset(name="sample_is_high_variance", data=sample_is_high_variance) @@ -661,9 +659,9 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes): ) # Simulate CN, NormCov, RawCov under calldata. - cn = np.random.randint(low=-1, high=12, size=(n_windows, n_samples)) - normCov = np.random.randint(low=0, high=356, size=(n_windows, n_samples)) - rawCov = np.random.randint(low=-1, high=18465, size=(n_windows, n_samples)) + cn = rng.integers(low=-1, high=12, size=(n_windows, n_samples)) + normCov = rng.integers(low=0, high=356, size=(n_windows, n_samples)) + rawCov = rng.integers(low=-1, high=18465, size=(n_windows, n_samples)) calldata_grp.create_dataset(name="CN", data=cn) calldata_grp.create_dataset(name="NormCov", data=normCov) calldata_grp.create_dataset(name="RawCov", data=rawCov) @@ -705,13 +703,13 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) # - POS [1D array] [int for n_variants] # Get a random probability for choosing allele 1, between 0 and 1. - p_allele = np.random.random() + p_allele = rng.random() # Get a random probability for passing a particular SNP site (position), between 0 and 1. - p_filter_pass = np.random.random() + p_filter_pass = rng.random() # Get a random probability for applying qMerge filter to a particular SNP site (position), between 0 and 1. - p_filter_qMerge = np.random.random() + p_filter_qMerge = rng.random() # Open a zarr at the specified path. root = zarr.open(zarr_path, mode="w") @@ -733,17 +731,15 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) contig_length_bp = contig_sizes[contig] # Get a random number of CNV alleles ("variants") to simulate. - n_cnv_alleles = np.random.randint(1, 5_000) + n_cnv_alleles = rng.integers(1, 5_000) # Produce a set of random start positions for each allele as a sorted list. - allele_start_pos = sorted( - np.random.randint(1, contig_length_bp, size=n_cnv_alleles) - ) + allele_start_pos = sorted(rng.integers(1, contig_length_bp, size=n_cnv_alleles)) # Produce a set of random allele lengths for each allele, according to a range. allele_length_bp_min = 100 allele_length_bp_max = 100_000 - allele_lengths_bp = np.random.randint( + allele_lengths_bp = rng.integers( allele_length_bp_min, allele_length_bp_max, size=n_cnv_alleles ) @@ -755,7 +751,7 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) # Simulate the genotype calls. # Note: this is only 2D, unlike SNP, HAP, AIM GT which are 3D - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_cnv_alleles, n_samples), replace=True, @@ -772,8 +768,8 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) variants_grp = contig_grp.require_group("variants") # Simulate the CIEND and CIPOS arrays under variants. - ciend = np.random.randint(low=0, high=13200, size=n_cnv_alleles) - cipos = np.random.randint(low=0, high=37200, size=n_cnv_alleles) + ciend = rng.integers(low=0, high=13200, size=n_cnv_alleles) + cipos = rng.integers(low=0, high=37200, size=n_cnv_alleles) variants_grp.create_dataset(name="CIEND", data=ciend) variants_grp.create_dataset(name="CIPOS", data=cipos) @@ -787,10 +783,10 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) variants_grp.create_dataset(name="ID", data=variant_IDs) # Simulate the filters under variants. - filter_pass = np.random.choice( + filter_pass = rng.choice( [False, True], size=n_cnv_alleles, p=[1 - p_filter_pass, p_filter_pass] ) - filter_qMerge = np.random.choice( + filter_qMerge = rng.choice( [False, True], size=n_cnv_alleles, p=[1 - p_filter_qMerge, p_filter_qMerge] ) variants_grp.create_dataset(name="FILTER_PASS", data=filter_pass) @@ -806,6 +802,8 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes) def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig_sizes): + # Initialize a default RNG with a fixed seed for general random calls + default_rng = np.random.default_rng(seed=123) # Arbitrary seed for reproducibility # zarr_path is the output path to the zarr store # metadata_path is the input path for the sample metadata # contigs is the list of contigs, e.g. Ag has ('2R', '3R', 'X') @@ -828,10 +826,10 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig # - samples [1D array] [str for n_samples] # Get a random probability for a sample being high variance, between 0 and 1. - p_variance = np.random.random() + p_variance = default_rng.random() # Get a random probability for choosing allele 1, between 0 and 1. - p_allele = np.random.random() + p_allele = default_rng.random() # Open a zarr at the specified path. root = zarr.open(zarr_path, mode="w") @@ -845,11 +843,11 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig n_samples = len(df_samples) # Simulate sample_coverage_variance array. - sample_coverage_variance = np.random.uniform(low=0, high=0.5, size=n_samples) + sample_coverage_variance = default_rng.uniform(low=0, high=0.5, size=n_samples) root.create_dataset(name="sample_coverage_variance", data=sample_coverage_variance) # Simulate sample_is_high_variance array. - sample_is_high_variance = np.random.choice( + sample_is_high_variance = default_rng.choice( [False, True], size=n_samples, p=[1 - p_variance, p_variance] ) root.create_dataset(name="sample_is_high_variance", data=sample_is_high_variance) @@ -864,7 +862,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig for i, contig in enumerate(contigs): # Use the same random seed per contig, otherwise n_cnv_variants (and shapes) will not align. unique_seed = fixed_seed + i - np.random.seed(unique_seed) + rng = np.random.default_rng(seed=unique_seed) # Create the contig group. contig_grp = root.require_group(contig) @@ -876,17 +874,17 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig contig_length_bp = contig_sizes[contig] # Get a random number of CNV variants to simulate. - n_cnv_variants = np.random.randint(1, 100) + n_cnv_variants = rng.integers(1, 100) # Produce a set of random start positions for each variant as a sorted list. variant_start_pos = sorted( - np.random.randint(1, contig_length_bp, size=n_cnv_variants) + rng.integers(1, contig_length_bp, size=n_cnv_variants) ) # Produce a set of random lengths for each variant, according to a range. variant_length_bp_min = 100 variant_length_bp_max = 100_000 - variant_lengths_bp = np.random.randint( + variant_lengths_bp = rng.integers( variant_length_bp_min, variant_length_bp_max, size=n_cnv_variants ) @@ -898,7 +896,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig # Simulate the genotype calls. # Note: this is only 2D, unlike SNP, HAP, AIM GT which are 3D - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_cnv_variants, n_samples), replace=True, @@ -915,8 +913,8 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig variants_grp = contig_grp.require_group("variants") # Simulate the StartBreakpointMethod and EndBreakpointMethod arrays. - startBreakpointMethod = np.random.randint(low=-1, high=1, size=n_cnv_variants) - endBreakpointMethod = np.random.randint(low=-1, high=2, size=n_cnv_variants) + startBreakpointMethod = rng.integers(low=-1, high=1, size=n_cnv_variants) + endBreakpointMethod = rng.integers(low=-1, high=2, size=n_cnv_variants) variants_grp.create_dataset( name="StartBreakpointMethod", data=startBreakpointMethod ) @@ -1567,7 +1565,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_sites, n_samples, 2), replace=True, @@ -1598,7 +1596,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_sites, n_samples, 2), replace=True, @@ -1629,7 +1627,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_sites, n_samples, 2), replace=True, @@ -1695,7 +1693,7 @@ def init_aim_calls(self): ds["sample_id"] = ("samples",), df_samples["sample_id"] # Add call_genotype variable. - gt = np.random.choice( + gt = rng.choice( np.arange(2, dtype="i1"), size=(ds.sizes["variants"], ds.sizes["samples"], 2), replace=True, @@ -2190,7 +2188,7 @@ def init_hap_sites(self): path=path, contigs=self.contigs, snp_sites=self.snp_sites, - p_site=np.random.random(), + p_site=rng.random(), ) def init_haplotypes(self): @@ -2217,7 +2215,7 @@ def init_haplotypes(self): # Simulate haplotypes. analysis = "funestus" - p_1 = np.random.random() + p_1 = rng.random() samples = df_samples["sample_id"].values self.phasing_samples[sample_set, analysis] = samples n_samples = len(samples) @@ -2233,7 +2231,7 @@ def init_haplotypes(self): root.create_dataset(name="samples", data=samples, dtype=str) for contig in self.contigs: n_sites = self.n_hap_sites[analysis][contig] - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_sites, n_samples, 2), replace=True, diff --git a/tests/anoph/test_base.py b/tests/anoph/test_base.py index 8d7e22499..43bcfbd43 100644 --- a/tests/anoph/test_base.py +++ b/tests/anoph/test_base.py @@ -8,6 +8,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.base import AnophelesBase +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -210,7 +213,7 @@ def test_lookup_study(fixture, api): # Set up test. df_sample_sets = api.sample_sets() all_sample_sets = df_sample_sets["sample_set"].values - sample_set = np.random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) study_rec_by_sample_set = api.lookup_study(sample_set) df_sample_set = df_sample_sets.set_index("sample_set").loc[sample_set] diff --git a/tests/anoph/test_cnv_data.py b/tests/anoph/test_cnv_data.py index 54c1ddf12..383c4248c 100644 --- a/tests/anoph/test_cnv_data.py +++ b/tests/anoph/test_cnv_data.py @@ -13,6 +13,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.cnv_data import AnophelesCnvData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -425,7 +428,7 @@ def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData): region = fixture.random_contig() # Parametrize max_coverage_variance. - parametrize_max_coverage_variance = np.random.uniform(low=0, high=1, size=4) + parametrize_max_coverage_variance = rng.uniform(low=0, high=1, size=4) for max_coverage_variance in parametrize_max_coverage_variance: ds = api.cnv_hmm( @@ -810,7 +813,7 @@ def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData): region = fixture.random_contig() df_samples = api.sample_metadata(sample_sets=sample_set) all_sample_ids = df_samples["sample_id"].values - sample_id = np.random.choice(all_sample_ids) + sample_id = rng.choice(all_sample_ids) fig = api.plot_cnv_hmm_coverage_track( sample=sample_id, @@ -859,11 +862,11 @@ def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData): def test_plot_cnv_hmm_coverage(fixture, api: AnophelesCnvData): # Set up test. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_set = random.choice(all_sample_sets) + sample_set = rng.choice(all_sample_sets) region = fixture.random_contig() df_samples = api.sample_metadata(sample_sets=sample_set) all_sample_ids = df_samples["sample_id"].values - sample_id = np.random.choice(all_sample_ids) + sample_id = rng.choice(all_sample_ids) fig = api.plot_cnv_hmm_coverage( sample=sample_id, diff --git a/tests/anoph/test_cnv_frq.py b/tests/anoph/test_cnv_frq.py index c90dd49f4..cd7a851cf 100644 --- a/tests/anoph/test_cnv_frq.py +++ b/tests/anoph/test_cnv_frq.py @@ -19,6 +19,8 @@ check_plot_frequencies_interactive_map, ) +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -97,7 +99,7 @@ def test_gene_cnv_frequencies_with_str_cohorts( region = random.choice(api.contigs) all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) # Set up call params. params = dict( @@ -302,7 +304,7 @@ def test_gene_cnv_frequencies_with_dict_cohorts( ): # Pick test parameters at random. sample_sets = None # all sample sets - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = random.choice(api.contigs) # Create cohorts by country. @@ -343,7 +345,7 @@ def test_gene_cnv_frequencies_without_drop_invariant( # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = random.choice(api.contigs) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -398,7 +400,7 @@ def test_gene_cnv_frequencies_with_bad_region( # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) # Set up call params. @@ -718,7 +720,7 @@ def check_gene_cnv_frequencies_advanced( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) # Run function under test. ds = api.gene_cnv_frequencies_advanced( diff --git a/tests/anoph/test_distance.py b/tests/anoph/test_distance.py index 9091ee454..c8695855f 100644 --- a/tests/anoph/test_distance.py +++ b/tests/anoph/test_distance.py @@ -11,6 +11,9 @@ from malariagen_data.anoph import pca_params +rng = np.random.default_rng(seed=42) + + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): return AnophelesDistanceAnalysis( @@ -81,7 +84,7 @@ def check_biallelic_diplotype_pairwise_distance(*, api, data_params, metric): ds = api.biallelic_snp_calls(**data_params) n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # Run the distance computation. dist, samples, n_snps_used = api.biallelic_diplotype_pairwise_distances( @@ -143,7 +146,7 @@ def check_njt(*, api, data_params, metric, algorithm): ds = api.biallelic_snp_calls(**data_params) n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # Run the distance computation. Z, samples, n_snps_used = api.njt( @@ -232,7 +235,7 @@ def test_plot_njt(fixture, api: AnophelesDistanceAnalysis): # Check available data. ds = api.biallelic_snp_calls(**data_params) n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # Exercise the function. for color, symbol in zip(colors, symbols): diff --git a/tests/anoph/test_frq.py b/tests/anoph/test_frq.py index cf972c83b..1f390e334 100644 --- a/tests/anoph/test_frq.py +++ b/tests/anoph/test_frq.py @@ -1,8 +1,10 @@ import pytest import plotly.graph_objects as go # type: ignore - +import numpy as np import random +rng = np.random.default_rng(seed=42) + def check_plot_frequencies_heatmap(api, frq_df): fig = api.plot_frequencies_heatmap(frq_df, show=False, max_len=None) @@ -65,7 +67,7 @@ def check_plot_frequencies_time_series_with_areas(api, ds): # Pick a random area and areas from valid areas. cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist() area = random.choice(cohorts_areas) - areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas))) + areas = random.sample(cohorts_areas, rng.integers(1, len(cohorts_areas))) # Plot with area. fig = api.plot_frequencies_time_series(ds, show=False, areas=area) diff --git a/tests/anoph/test_fst.py b/tests/anoph/test_fst.py index 098f08538..520f4f4b5 100644 --- a/tests/anoph/test_fst.py +++ b/tests/anoph/test_fst.py @@ -11,6 +11,8 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.fst import AnophelesFstAnalysis +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -91,7 +93,7 @@ def test_fst_gwss(fixture, api: AnophelesFstAnalysis): cohort1_query=cohort1_query, cohort2_query=cohort2_query, site_mask=random.choice(api.site_mask_ids), - window_size=random.randint(10, 50), + window_size=rng.integers(10, 50), min_cohort_size=1, ) @@ -131,7 +133,7 @@ def test_average_fst(fixture, api: AnophelesFstAnalysis): cohort2_query=cohort2_query, site_mask=random.choice(api.site_mask_ids), min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run main gwss function under test. @@ -229,7 +231,7 @@ def test_pairwise_average_fst_with_str_cohorts( sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. @@ -249,7 +251,7 @@ def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAna sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=15, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. @@ -270,7 +272,7 @@ def test_pairwise_average_fst_with_dict_cohorts(fixture, api: AnophelesFstAnalys sample_sets=all_sample_sets, site_mask=site_mask, min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. @@ -294,7 +296,7 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalys sample_query=sample_query, site_mask=site_mask, min_cohort_size=1, - n_jack=random.randint(10, 200), + n_jack=rng.integers(10, 200), ) # Run checks. diff --git a/tests/anoph/test_g123.py b/tests/anoph/test_g123.py index ab5e25537..5e609d816 100644 --- a/tests/anoph/test_g123.py +++ b/tests/anoph/test_g123.py @@ -8,6 +8,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.g123 import AnophelesG123Analysis +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -105,7 +108,7 @@ def test_g123_gwss_with_default_sites(fixture, api: AnophelesG123Analysis): g123_params = dict( contig=random.choice(api.contigs), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -121,7 +124,7 @@ def test_g123_gwss_with_phased_sites(fixture, api: AnophelesG123Analysis): contig=random.choice(api.contigs), sites=random.choice(api.phasing_analysis_ids), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -138,7 +141,7 @@ def test_g123_gwss_with_segregating_sites(fixture, api: AnophelesG123Analysis): sites="segregating", site_mask=random.choice(api.site_mask_ids), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -155,7 +158,7 @@ def test_g123_gwss_with_all_sites(fixture, api: AnophelesG123Analysis): sites="all", site_mask=None, sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, ) @@ -170,7 +173,7 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis): g123_params = dict( contig=random.choice(api.contigs), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=10, sites="foobar", ) @@ -180,16 +183,34 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis): api.g123_gwss(**g123_params) +def ensure_int_list(value): + """Convert a value to a list of integers, flattening nested structures if needed.""" + if isinstance(value, int): + return [value] + + result = [] + + def extract_ints(item): + if isinstance(item, int): + result.append(item) + elif isinstance(item, (list, tuple)): + for subitem in item: + extract_ints(subitem) + + extract_ints(value) + return result + + @parametrize_with_cases("fixture,api", cases=".") def test_g123_calibration(fixture, api: AnophelesG123Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = np.random.randint(100, 500, size=random.randint(2, 5)).tolist() - window_sizes = sorted([int(x) for x in window_sizes]) + window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist() + window_sizes = sorted(ensure_int_list(window_sizes)) g123_params = dict( - contig=random.choice(api.contigs), - sites=random.choice(api.phasing_analysis_ids), - sample_sets=[random.choice(all_sample_sets)], + contig=rng.choice(api.contigs), + sites=rng.choice(api.phasing_analysis_ids), + sample_sets=[rng.choice(all_sample_sets)], min_cohort_size=10, window_sizes=window_sizes, ) diff --git a/tests/anoph/test_genome_features.py b/tests/anoph/test_genome_features.py index eaef4ee4e..d055c0c76 100644 --- a/tests/anoph/test_genome_features.py +++ b/tests/anoph/test_genome_features.py @@ -10,6 +10,9 @@ from malariagen_data.anoph.genome_features import AnophelesGenomeFeaturesData from malariagen_data.util import Region, resolve_region +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -145,7 +148,7 @@ def test_plot_genes_with_gene_labels(fixture, api: AnophelesGenomeFeaturesData): # If there are no genes, we cannot label them. if not genes_df.empty: # Get a random number of genes to sample. - random_genes_n = np.random.randint(low=1, high=len(genes_df) + 1) + random_genes_n = rng.integers(low=1, high=len(genes_df) + 1) # Get a random sample of genes. random_sample_genes_df = genes_df.sample(n=random_genes_n) @@ -166,7 +169,7 @@ def test_plot_genes_with_gene_labels(fixture, api: AnophelesGenomeFeaturesData): def test_plot_transcript(fixture, api: AnophelesGenomeFeaturesData): for contig in fixture.contigs: df_transcripts = api.genome_features(region=contig).query("type == 'mRNA'") - transcript = np.random.choice(df_transcripts["ID"].values) + transcript = rng.choice(df_transcripts["ID"].values) fig = api.plot_transcript(transcript=transcript, show=False) assert isinstance(fig, bokeh.plotting.figure) @@ -212,7 +215,7 @@ def test_genome_features_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" df = api.genome_features(region=region) assert isinstance(df, pd.DataFrame) diff --git a/tests/anoph/test_genome_sequence.py b/tests/anoph/test_genome_sequence.py index d2eca6bb6..6dafd030e 100644 --- a/tests/anoph/test_genome_sequence.py +++ b/tests/anoph/test_genome_sequence.py @@ -10,6 +10,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.genome_sequence import AnophelesGenomeSequenceData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -79,7 +82,7 @@ def test_genome_sequence_region(fixture, api): for contig in fixture.contigs: contig_seq = api.genome_sequence(region=contig) # Pick a random start and stop position. - start, stop = sorted(np.random.randint(low=1, high=len(contig_seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(contig_seq), size=2)) region = f"{contig}:{start:,}-{stop:,}" seq = api.genome_sequence(region=region) assert isinstance(seq, da.Array) @@ -118,7 +121,7 @@ def test_genome_sequence_virtual_contigs(ag3_sim_api, chrom): ) # Test with region. - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" seq_region = api.genome_sequence(region=region) assert isinstance(seq_region, da.Array) diff --git a/tests/anoph/test_h12.py b/tests/anoph/test_h12.py index 29262e942..b410d08c4 100644 --- a/tests/anoph/test_h12.py +++ b/tests/anoph/test_h12.py @@ -9,6 +9,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.h12 import AnophelesH12Analysis, haplotype_frequencies +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -100,15 +103,34 @@ def test_haplotype_frequencies(): assert_allclose(vals, np.array([0.2, 0.2, 0.2, 0.4])) +def ensure_int_list(value): + """Convert a value to a list of integers, and flattening nested structures if needed.""" + if isinstance(value, int): + return [value] + + result = [] + + def extract_ints(item): + if isinstance(item, int): + result.append(item) + elif isinstance(item, (list, tuple)): + for subitem in item: + extract_ints(subitem) + + extract_ints(value) + return result + + @parametrize_with_cases("fixture,api", cases=".") def test_h12_calibration(fixture, api: AnophelesH12Analysis): # Set up test parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - window_sizes = np.random.randint(100, 500, size=random.randint(2, 5)).tolist() - window_sizes = sorted(set([int(x) for x in window_sizes])) + window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist() + # Convert window_sizes to a flattened list of integers + window_sizes = sorted(set(ensure_int_list(window_sizes))) h12_params = dict( - contig=random.choice(api.contigs), - sample_sets=[random.choice(all_sample_sets)], + contig=rng.choice(api.contigs), + sample_sets=[rng.choice(all_sample_sets)], window_sizes=window_sizes, min_cohort_size=5, ) @@ -172,7 +194,7 @@ def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis): h12_params = dict( contig=random.choice(api.contigs), sample_sets=[random.choice(all_sample_sets)], - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=5, ) @@ -186,7 +208,7 @@ def test_h12_gwss_with_analysis(fixture, api: AnophelesH12Analysis): all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = [random.choice(all_sample_sets)] contig = random.choice(api.contigs) - window_size = random.randint(100, 500) + window_size = rng.integers(100, 500) for analysis in api.phasing_analysis_ids: # Check if any samples available for the given phasing analysis. @@ -240,7 +262,7 @@ def test_h12_gwss_multi_with_default_analysis(fixture, api: AnophelesH12Analysis h12_params = dict( contig=random.choice(api.contigs), sample_sets=all_sample_sets, - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=1, cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, ) @@ -261,8 +283,8 @@ def test_h12_gwss_multi_with_window_size_dict(fixture, api: AnophelesH12Analysis contig=random.choice(api.contigs), sample_sets=all_sample_sets, window_size={ - "cohort1": random.randint(100, 500), - "cohort2": random.randint(100, 500), + "cohort1": rng.integers(100, 500), + "cohort2": rng.integers(100, 500), }, min_cohort_size=1, cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, @@ -313,7 +335,7 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis): analysis=analysis, contig=contig, sample_sets=all_sample_sets, - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=min(n1, n2), cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query}, ) diff --git a/tests/anoph/test_h1x.py b/tests/anoph/test_h1x.py index 627717b57..dc66ebc42 100644 --- a/tests/anoph/test_h1x.py +++ b/tests/anoph/test_h1x.py @@ -9,6 +9,8 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.h1x import AnophelesH1XAnalysis, haplotype_joint_frequencies +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -147,7 +149,7 @@ def test_h1x_gwss_with_default_analysis(fixture, api: AnophelesH1XAnalysis): h1x_params = dict( contig=random.choice(api.contigs), sample_sets=all_sample_sets, - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=1, cohort1_query=cohort1_query, cohort2_query=cohort2_query, @@ -198,7 +200,7 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis): analysis=analysis, contig=contig, sample_sets=all_sample_sets, - window_size=random.randint(100, 500), + window_size=rng.integers(100, 500), min_cohort_size=min(n1, n2), cohort1_query=cohort1_query, cohort2_query=cohort2_query, diff --git a/tests/anoph/test_hap_data.py b/tests/anoph/test_hap_data.py index c42f4d38c..16c411154 100644 --- a/tests/anoph/test_hap_data.py +++ b/tests/anoph/test_hap_data.py @@ -12,6 +12,10 @@ from malariagen_data.anoph.hap_data import AnophelesHapData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): return AnophelesHapData( @@ -466,7 +470,7 @@ def test_haplotypes_with_cohort_size_param( analysis = api.phasing_analysis_ids[0] # Parametrize over cohort_size. - parametrize_cohort_size = [random.randint(1, 10), random.randint(10, 50), 1_000] + parametrize_cohort_size = [rng.integers(1, 10), rng.integers(10, 50), 1_000] for cohort_size in parametrize_cohort_size: check_haplotypes( fixture=fixture, @@ -493,8 +497,8 @@ def test_haplotypes_with_min_cohort_size_param( # Parametrize over min_cohort_size. parametrize_min_cohort_size = [ - random.randint(1, 10), - random.randint(10, 50), + rng.integers(1, 10), + rng.integers(10, 50), 1_000, ] for min_cohort_size in parametrize_min_cohort_size: @@ -523,8 +527,8 @@ def test_haplotypes_with_max_cohort_size_param( # Parametrize over max_cohort_size. parametrize_max_cohort_size = [ - random.randint(1, 10), - random.randint(10, 50), + rng.integers(1, 10), + rng.integers(10, 50), 1_000, ] for max_cohort_size in parametrize_max_cohort_size: @@ -605,7 +609,9 @@ def test_haplotypes_virtual_contigs( # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted( + [int(x) for x in rng.integers(low=1, high=len(seq), size=2)] + ) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. @@ -677,7 +683,7 @@ def test_haplotype_sites_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. diff --git a/tests/anoph/test_hap_frq.py b/tests/anoph/test_hap_frq.py index 97212087a..689583010 100644 --- a/tests/anoph/test_hap_frq.py +++ b/tests/anoph/test_hap_frq.py @@ -17,6 +17,8 @@ check_plot_frequencies_interactive_map, ) +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -168,7 +170,7 @@ def test_hap_frequencies_with_str_cohorts( # Pick test parameters at random. all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = fixture.random_region_str() # Set up call params. @@ -210,7 +212,7 @@ def test_hap_frequencies_advanced( ): all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) region = fixture.random_region_str() # Set up call params. diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py index e5fa667a6..9ce044e45 100644 --- a/tests/anoph/test_pca.py +++ b/tests/anoph/test_pca.py @@ -11,6 +11,8 @@ from malariagen_data.anoph.pca import AnophelesPca from malariagen_data.anoph import pca_params +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -95,10 +97,10 @@ def test_pca_plotting(fixture, api: AnophelesPca): # PCA parameters. n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) + n_snps = rng.integers(4, n_snps_available) # PC3 required for plot_pca_coords_3d() assert min(n_samples, n_snps) > 3 - n_components = random.randint(3, min(n_samples, n_snps, 10)) + n_components = rng.integers(3, min(n_samples, n_snps, 10)) # Run the PCA. pca_df, pca_evr = api.pca( @@ -179,15 +181,15 @@ def test_pca_exclude_samples(fixture, api: AnophelesPca): ) # Exclusion parameters. - n_samples_excluded = random.randint(1, 5) + n_samples_excluded = rng.integers(1, 5) samples = ds["sample_id"].values.tolist() - exclude_samples = random.sample(samples, n_samples_excluded) + exclude_samples = random.sample(samples, int(n_samples_excluded)) # convert to int # PCA parameters. n_samples = ds.sizes["samples"] - n_samples_excluded n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) - n_components = random.randint(2, min(n_samples, n_snps, 10)) + n_snps = rng.integers(4, n_snps_available) + n_components = rng.integers(2, min(n_samples, n_snps, 10)) # Run the PCA. pca_df, pca_evr = api.pca( @@ -240,15 +242,15 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca): ) # Exclusion parameters. - n_samples_excluded = random.randint(1, 5) + n_samples_excluded = rng.integers(1, 5) samples = ds["sample_id"].values.tolist() - exclude_samples = random.sample(samples, n_samples_excluded) + exclude_samples = random.sample(samples, int(n_samples_excluded)) # convert to int # PCA parameters. n_samples = ds.sizes["samples"] n_snps_available = ds.sizes["variants"] - n_snps = random.randint(4, n_snps_available) - n_components = random.randint(2, min(n_samples, n_snps, 10)) + n_snps = rng.integers(4, n_snps_available) + n_components = rng.integers(2, min(n_samples, n_snps, 10)) # Run the PCA. pca_df, pca_evr = api.pca( diff --git a/tests/anoph/test_plink_converter.py b/tests/anoph/test_plink_converter.py index 44e476cd9..eb75e2a3f 100644 --- a/tests/anoph/test_plink_converter.py +++ b/tests/anoph/test_plink_converter.py @@ -8,9 +8,11 @@ import os import bed_reader - +import numpy as np from numpy.testing import assert_array_equal +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -89,7 +91,7 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path): min_minor_ac=1, max_missing_an=1, thin_offset=1, - random_seed=random.randint(1, 2000), + random_seed=rng.integers(1, 2000), ) # Load a ds containing the randomly generated samples and regions to get the number of available snps to subset from. @@ -98,7 +100,7 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path): ) n_snps_available = ds.sizes["variants"] - n_snps = random.randint(1, n_snps_available) + n_snps = rng.integers(1, n_snps_available) # Define plink params. plink_params = dict(output_dir=str(tmp_path), n_snps=n_snps, **data_params) diff --git a/tests/anoph/test_sample_metadata.py b/tests/anoph/test_sample_metadata.py index 6df290118..0f34eea8c 100644 --- a/tests/anoph/test_sample_metadata.py +++ b/tests/anoph/test_sample_metadata.py @@ -14,6 +14,9 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.sample_metadata import AnophelesSampleMetadata +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -975,7 +978,7 @@ def test_lookup_sample(fixture, api): # Set up test. df_samples = api.sample_metadata() all_sample_ids = df_samples["sample_id"].values - sample_id = np.random.choice(all_sample_ids) + sample_id = rng.choice(all_sample_ids) # Check we get the same sample_id back. sample_rec_by_sample_id = api.lookup_sample(sample_id) diff --git a/tests/anoph/test_snp_data.py b/tests/anoph/test_snp_data.py index ba4f1b690..0afc9f905 100644 --- a/tests/anoph/test_snp_data.py +++ b/tests/anoph/test_snp_data.py @@ -16,6 +16,9 @@ from malariagen_data.anoph.base_params import DEFAULT from malariagen_data.anoph.snp_data import AnophelesSnpData +# Global RNG for test file; functions may override with local RNG for reproducibility +rng = np.random.default_rng(seed=42) + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): @@ -256,7 +259,7 @@ def test_snp_sites_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. @@ -565,7 +568,7 @@ def test_snp_genotypes_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. check_snp_genotypes(api, region=region) @@ -590,7 +593,7 @@ def test_snp_variants_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" pos = api.snp_sites(region=region, field="POS").compute() ds_region = api.snp_variants(region=region) @@ -857,7 +860,7 @@ def test_snp_calls_with_cohort_size_param(fixture, api: AnophelesSnpData): region = fixture.random_region_str() # Test with specific cohort size. - cohort_size = random.randint(1, 10) + cohort_size = rng.integers(1, 10) ds = api.snp_calls( sample_sets=sample_sets, region=region, @@ -916,7 +919,7 @@ def test_snp_calls_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. @@ -1468,7 +1471,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_cohort_size_param( region = fixture.random_region_str() # Test with specific cohort size. - cohort_size = random.randint(1, 10) + cohort_size = rng.integers(1, 10) ds = api.biallelic_snp_calls( sample_sets=sample_sets, region=region, @@ -1522,8 +1525,8 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( site_mask = random.choice((None,) + api.site_mask_ids) # Parametrise conditions. - min_minor_ac = random.randint(1, 3) - max_missing_an = random.randint(5, 10) + min_minor_ac = rng.integers(1, 3) + max_missing_an = rng.integers(5, 10) # Run tests. ds = check_biallelic_snp_calls_and_diplotypes( @@ -1551,7 +1554,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( # This should always be true, although depends on min_minor_ac and max_missing_an, # so the range of values for those parameters needs to be chosen with some care. assert n_snps_available > 2 - n_snps_requested = random.randint(1, n_snps_available // 2) + n_snps_requested = rng.integers(1, n_snps_available // 2) ds_thinned = check_biallelic_snp_calls_and_diplotypes( api=api, sample_sets=sample_sets, @@ -1617,7 +1620,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional( # This should always be true, although depends on min_minor_ac and max_missing_an, # so the range of values for those parameters needs to be chosen with some care. assert n_snps_available > 2 - n_snps_requested = random.randint(1, n_snps_available // 2) + n_snps_requested = rng.integers(1, n_snps_available // 2) ds_thinned = check_biallelic_snp_calls_and_diplotypes( api=api, sample_sets=sample_sets, diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py index aa803f031..d8bfb07ab 100644 --- a/tests/anoph/test_snp_frq.py +++ b/tests/anoph/test_snp_frq.py @@ -21,6 +21,9 @@ ) +rng = np.random.default_rng(seed=42) + + @pytest.fixture def ag3_sim_api(ag3_sim_fixture): return AnophelesSnpFrequencyAnalysis( @@ -301,7 +304,7 @@ def test_allele_frequencies_with_str_cohorts( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) # Set up call params. @@ -561,7 +564,7 @@ def test_allele_frequencies_with_dict_cohorts( # Pick test parameters at random. sample_sets = None # all sample sets site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) # Create cohorts by country. @@ -615,7 +618,7 @@ def test_allele_frequencies_without_drop_invariant( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -671,7 +674,7 @@ def test_allele_frequencies_without_effects( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -753,7 +756,7 @@ def test_allele_frequencies_with_bad_transcript( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) # Set up call params. @@ -780,7 +783,7 @@ def test_allele_frequencies_with_region( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) # This should work, as long as effects=False - i.e., can get frequencies # for any genome region. @@ -838,7 +841,7 @@ def test_allele_frequencies_with_dup_samples( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_set = random.choice(all_sample_sets) site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) transcript = random_transcript(api=api) cohorts = random.choice(["admin1_year", "admin2_month", "country"]) @@ -895,7 +898,7 @@ def check_snp_allele_frequencies_advanced( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) if site_mask is None: site_mask = random.choice(api.site_mask_ids + (None,)) @@ -1089,7 +1092,7 @@ def check_aa_allele_frequencies_advanced( all_sample_sets = api.sample_sets()["sample_set"].to_list() sample_sets = random.choice(all_sample_sets) if min_cohort_size is None: - min_cohort_size = random.randint(0, 2) + min_cohort_size = rng.integers(0, 2) # Run function under test. ds = api.aa_allele_frequencies_advanced(