src/models/data_loader.py (+29, -5)
```diff
@@ -20,12 +20,36 @@
 class BalancedSampler(Sampler):
     def __init__(self, dataset, majority_rate: float=0.5, seed=None):
 
-        if isinstance(dataset, dict):
-            self.labels = dataset["labels"].tolist() if isinstance(dataset["labels"], torch.Tensor) else dataset["labels"]
-        elif isinstance(dataset, Dataset):
-            self.labels = [dataset[i]['labels'].item() for i in range(len(dataset))]
+        # `isinstance(dataset, Dataset)` checks `torch.utils.data.Dataset`,
+        # but a HuggingFace `datasets.arrow_dataset.Dataset` does NOT
+        # inherit from it (they're unrelated classes). The else-branch
+        # below previously expected `torch.Tensor` from
+        # `dataset["labels"]`, but `datasets >= 4` returns a
+        # `datasets.arrow_dataset.Column` instead — `.tolist()` is never
+        # called and `self.labels` ends up as a sequence of 0-d Tensor
+        # objects. The defaultdict then buckets each tensor by `id()`
+        # (every label its own key) and the int-key check
+        # `0 in label_to_indices` fails even when the dataset has both
+        # classes. Normalise once, here, and let downstream code see
+        # plain Python ints.
+        if isinstance(dataset, list):
+            raw = dataset
+        elif isinstance(dataset, dict) or hasattr(dataset, '__getitem__'):
+            try:
+                raw = dataset["labels"]
+            except (KeyError, TypeError):
+                raw = [dataset[i]['labels'] for i in range(len(dataset))]
         else:
-            self.labels = dataset["labels"].tolist() if isinstance(dataset["labels"], torch.Tensor) else dataset["labels"]
+            raw = dataset
+
+        if isinstance(raw, torch.Tensor):
+            self.labels = raw.tolist()
+        elif hasattr(raw, 'tolist'):
+            # numpy array, pandas Series, datasets.arrow_dataset.Column, …
+            self.labels = list(raw.tolist())
+        else:
+            # Per-row tensors → Python ints
+            self.labels = [int(x.item()) if hasattr(x, 'item') else int(x) for x in raw]
         self.majority_rate = majority_rate
         self.seed = seed
 
```
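Reviewer note: a standalone repro of the bucketing failure the new comment describes. The `defaultdict` stands in for the sampler's `label_to_indices`; the snippet is illustrative and not part of the patch.

```python
import torch
from collections import defaultdict

# 0-d tensors hash by object identity, so equal labels land in separate buckets.
labels = [torch.tensor(0), torch.tensor(1), torch.tensor(0)]
label_to_indices = defaultdict(list)
for i, lbl in enumerate(labels):
    label_to_indices[lbl].append(i)

print(len(label_to_indices))   # 3: each 0-d tensor is its own key
print(0 in label_to_indices)   # False: the int-key check never matches

# After the normalisation added in this patch, keys are plain ints:
label_to_indices = defaultdict(list)
for i, lbl in enumerate(int(x.item()) for x in labels):
    label_to_indices[lbl].append(i)
print(0 in label_to_indices)   # True
```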
src/models/dnabert2_module.py (+11, -0)
```diff
@@ -0,0 +1,11 @@
+# Stub for the upstream's dnabert2 path.
+#
+# `run_preprocess.py` imports this unconditionally but never calls into
+# it on the DNABERT (no-2) code path. The upstream commit that should
+# have shipped the real module never materialised in
+# opensensor/CRISPR_DNABERT@bfbeb81e; this stub keeps the import
+# resolvable without changing behaviour for the no-2 path we care about.
+#
+# If a future code path actually invokes anything here, the AttributeError
+# will be loud and obvious — we deliberately do NOT shim functions
+# silently.
```
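The contract the stub relies on, shown with a made-up attribute name (`build_model` is hypothetical; only the module path comes from the patch):

```python
# Importing the stub succeeds, so run_preprocess.py keeps working:
from src.models import dnabert2_module

# Any actual use fails immediately and visibly instead of silently no-oping:
dnabert2_module.build_model
# AttributeError: module 'src.models.dnabert2_module' has no attribute 'build_model'
```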
src/models/dnabert_module.py (+8, -2)
```diff
@@ -488,7 +488,10 @@ def test_scratch(self) -> None:
         # Results processing
         probabilities = inference_results["probability"]
         predictions = inference_results["prediction"]
-        true_labels = test_dataset["labels"].numpy()
+        # datasets >=4 returns a `datasets.arrow_dataset.Column` from
+        # `dataset["labels"]`, not a torch.Tensor; .numpy() is gone.
+        # Pull the column to a list and convert via numpy.
+        true_labels = np.asarray(list(test_dataset["labels"]))
 
         # Save the results
         os.makedirs(os.path.dirname(self.result_path), exist_ok=True)
```
```diff
@@ -750,7 +753,10 @@ def test_transfer_epi(self) -> None:
         # Results processing
         probabilities = inference_results["probability"]
         predictions = inference_results["prediction"]
-        true_labels = test_dataset["labels"].numpy()
+        # datasets >=4 returns a `datasets.arrow_dataset.Column` from
+        # `dataset["labels"]`, not a torch.Tensor; .numpy() is gone.
+        # Pull the column to a list and convert via numpy.
+        true_labels = np.asarray(list(test_dataset["labels"]))
 
         # Save the results
         os.makedirs(os.path.dirname(self.result_path), exist_ok=True)
```
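The same `Column`-to-numpy conversion now appears twice in this file; if it spreads further, a small helper could absorb the version difference. A sketch only; `column_to_numpy` is a hypothetical name, not part of the patch:

```python
import numpy as np
import torch

def column_to_numpy(col) -> np.ndarray:
    """Normalise a dataset 'labels' column to a numpy array.

    Accepts both the torch.Tensor that older `datasets` releases returned
    under torch formatting and the list-like Column that datasets >= 4 returns.
    """
    if isinstance(col, torch.Tensor):
        return col.numpy()
    return np.asarray(list(col))

# Usage at both call sites:
#   true_labels = column_to_numpy(test_dataset["labels"])
```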
src/models/pair_finetuning_dnabert.py (+21, -8)
```diff
@@ -214,7 +214,7 @@ def generate_random_sequence_input(self, rna_seq_list: list, n_samples: int) ->
             dna_seqs.append("-" + rna_seq)
             # Generate off-target pairs with random mutations
             for mismatch_count in range(1, 7):
-                for _ in range((n_samples/n_sgrna)//6):
+                for _ in range(int((n_samples / n_sgrna) // 6)):
                     # Generate random DNA sequence with specified number of mismatches
                     rna_list = list(rna_seq)
                     dna_list = rna_list.copy()
```
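Why the `int(...)` cast matters: `/` is true division in Python 3, floor-dividing a float still yields a float, and `range` rejects floats even when they are integral. A quick illustration with made-up counts:

```python
n_samples, n_sgrna = 600, 10

step = (n_samples / n_sgrna) // 6       # 10.0: still a float after //
try:
    range(step)
except TypeError as exc:
    print(exc)                          # 'float' object cannot be interpreted as an integer

print(list(range(int(step))))           # [0, 1, ..., 9]: the patched form works
```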
```diff
@@ -240,21 +240,33 @@ def generate_random_sequence_input(self, rna_seq_list: list, n_samples: int) ->
         return {"rna_seq": rna_seqs, "dna_seq": dna_seqs}
 
     def load_sequence_data(self, if_test=None) -> Dataset:
+        # Two upstream bugs fixed here:
+        # 1. data_loader.load_dataset_information returns the dataset
+        #    dict with key "sgrna" (lowercase, see data_loader.py:168),
+        #    not "sgRNA". The previous `dataset_dict["sgRNA"]` lookup
+        #    was a hard KeyError on the first dataset.
+        # 2. `sgrna_seqs[dataset_name].extend(...)` treats the lists
+        #    `sgrna_seqs = []` and `dna_seqs = []` as if they were dicts
+        #    keyed by dataset name. That's a TypeError as soon as the
+        #    first iteration runs. The intent is clearly to accumulate
+        #    across all datasets — direct list extend matches that.
         dataset_names = [
             "Lazzarotto_2020_CHANGE_seq", "Lazzarotto_2020_GUIDE_seq", "SchmidBurgk_2020_TTISS",
             "Chen_2017_GUIDE_seq", "Listgarten_2018_GUIDE_seq", "Tsai_2015_GUIDE_seq_1", "Tsai_2015_GUIDE_seq_2"
         ]
-        dna_seqs = []
-        sgrna_seqs = []
+        dna_seqs: list[str] = []
+        sgrna_seqs: list[str] = []
         for dataset_name in dataset_names:
             self.config["dataset_name"]["dataset_current"] = dataset_name
             DataLoaderClass = data_loader.DataLoaderClass(self.config)
             dataset_dict = DataLoaderClass.load_dataset()
-            sgrna_list = dataset_dict["sgRNA"]
+            sgrna_list = dataset_dict["sgrna"]
             # Generate random sequence inputs
-            generated_data = self.generate_random_sequence_input(sgrna_list, n_samples=len(dataset_dict["rna_seq"])//10)
-            sgrna_seqs[dataset_name].extend(generated_data["rna_seq"])
-            dna_seqs[dataset_name].extend(generated_data["dna_seq"])
+            generated_data = self.generate_random_sequence_input(
+                sgrna_list, n_samples=len(dataset_dict["rna_seq"]) // 10
+            )
+            sgrna_seqs.extend(generated_data["rna_seq"])
+            dna_seqs.extend(generated_data["dna_seq"])
         dataset = self.process(sgrna_seqs, dna_seqs)
         if if_test:
             dataset = dataset.select(range(200000))
```
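A two-line repro of bug 2 from the comment above (the sequences are placeholders):

```python
sgrna_seqs: list[str] = []                    # a list, not a dict keyed by dataset name
try:
    sgrna_seqs["Chen_2017_GUIDE_seq"].extend(["ACGU"])
except TypeError as exc:
    print(exc)    # list indices must be integers or slices, not str

# The patched code accumulates across all datasets with a plain extend:
sgrna_seqs.extend(["ACGUACGU", "ACGUACGA"])
```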
```diff
@@ -409,7 +421,8 @@ def train(self) -> None:
             args=training_args,
             train_dataset=train_datasets,
             compute_metrics=compute_metrics,
-            tokenizer=self.tokenizer,
+            # transformers >=5: kwarg renamed tokenizer → processing_class
+            processing_class=self.tokenizer,
             callbacks=[loss_callback],  # Add the loss callback here
         )
 
```
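If the transformers version can't be pinned, the kwarg rename can be absorbed at the call site instead. A sketch assuming only that `Trainer.__init__` exposes one of the two parameter names:

```python
import inspect
from transformers import Trainer

def trainer_tokenizer_kwarg() -> str:
    """Name of the tokenizer kwarg supported by the installed transformers."""
    params = inspect.signature(Trainer.__init__).parameters
    return "processing_class" if "processing_class" in params else "tokenizer"

# Usage inside train(), replacing the hard-coded kwarg:
#   trainer = Trainer(model=model, args=training_args,
#                     **{trainer_tokenizer_kwarg(): self.tokenizer})
```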