diff --git a/ci/tests/test_scgpt/test_binning.py b/ci/tests/test_scgpt/test_binning.py index b93526a7..524b7ad0 100644 --- a/ci/tests/test_scgpt/test_binning.py +++ b/ci/tests/test_scgpt/test_binning.py @@ -38,7 +38,7 @@ def test_digitize_identical_bins(): x = np.array([1, 2, 3, 4, 5]) bins = np.array([2, 2, 4, 4]) result = _digitize(x, bins) - expected = np.array([0, 1, 2, 3, 4]) + expected = np.array([0, 2, 2, 3, 4]) assert np.array_equal(result, expected) @@ -51,7 +51,7 @@ def test_digitize_identical_bins(): np.array([1, 1, 1, 2, 2, 2, 3, 3, 4]), ), # distrubution of the bins depends on the distribution of the data - (np.array([1, 1, 1, 1, 1, 1, 1, 8, 9]), np.array([2, 1, 1, 2, 3, 3, 3, 3, 4])), + (np.array([1, 1, 1, 1, 1, 1, 1, 8, 9]), np.array([1, 2, 1, 1, 3, 1, 1, 3, 4])), (np.array([1, 2, 1, 1, 9, 6, 7, 8, 9]), np.array([1, 2, 1, 1, 4, 2, 2, 3, 4])), ), ) diff --git a/helical/models/scgpt/binning.py b/helical/models/scgpt/binning.py index 9c457fb1..d029da6c 100644 --- a/helical/models/scgpt/binning.py +++ b/helical/models/scgpt/binning.py @@ -32,7 +32,8 @@ def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray: right_digits = np.digitize(x, bins, right=True) - rands = np.random.rand(len(x)) # uniform random numbers + rng = np.random.default_rng(42) + rands = rng.random(len(x)) digits = rands * (right_digits - left_digits) + left_digits digits = np.ceil(digits).astype(np.int64) diff --git a/helical/models/scgpt/data_collator.py b/helical/models/scgpt/data_collator.py index 5dd3a634..5fdc1882 100644 --- a/helical/models/scgpt/data_collator.py +++ b/helical/models/scgpt/data_collator.py @@ -163,7 +163,10 @@ def _sample( # keep the first n tokens unchanged _n = self.keep_first_n_tokens - indices = torch.randperm(len(genes) - _n, device=device)[: max_length - _n] + g = torch.Generator().manual_seed(0) + indices = torch.randperm(len(genes) - _n, device=device, generator=g)[ + : max_length - _n + ] indices = torch.cat([torch.arange(_n), indices + _n], dim=0) return genes[indices], expressions[indices] diff --git a/pyproject.toml b/pyproject.toml index f6747baa..3c3b9be7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "1.4.4" +version = "1.4.5" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ]