Skip to content

Commit b537b7f

Browse files
committed
Fix missing label handling
1 parent aa59606 commit b537b7f

4 files changed

Lines changed: 27 additions & 8 deletions

File tree

chebai/models/electra.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,13 @@ def _process_for_loss(
287287
tuple: A tuple containing the processed model output, labels, and loss arguments.
288288
"""
289289
kwargs_copy = dict(loss_kwargs)
290+
output = model_output["logits"]
290291
if labels is not None:
291292
labels = labels.float()
292-
return model_output["logits"], labels, kwargs_copy
293+
if "missing_labels" in kwargs_copy:
294+
missing_labels = kwargs_copy.pop("missing_labels")
295+
output = output * (~missing_labels).int()
296+
return output, labels, kwargs_copy
293297

294298
def _get_prediction_and_labels(
295299
self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor]
@@ -310,6 +314,11 @@ def _get_prediction_and_labels(
310314
if "non_null_labels" in loss_kwargs:
311315
n = loss_kwargs["non_null_labels"]
312316
d = d[n]
317+
318+
if "missing_labels" in loss_kwargs:
319+
missing_labels = loss_kwargs["missing_labels"]
320+
labels = labels * (~missing_labels).int()
321+
313322
return torch.sigmoid(d), labels.int() if labels is not None else None
314323

315324
def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:

chebai/preprocessing/collate.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
6464
Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices
6565
of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for
6666
unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method
67-
ensures alignment between features and labels.
67+
ensures alignment between features and labels. Missing labels are passed as a loss keyword.
6868
6969
Args:
7070
data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple
@@ -81,10 +81,13 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
8181
if isinstance(data[0], tuple):
8282
# For legacy data
8383
x, y, idents = zip(*data)
84+
missing_labels = None
8485
else:
8586
x, y, idents = zip(
8687
*((d["features"], d["labels"], d.get("ident")) for d in data)
8788
)
89+
missing_labels = [d.get("missing_labels", [False for _ in y[0]]) for d in data]
90+
8891
if any(x is not None for x in y):
8992
# If any label is not None: (None, None, `1`, None)
9093
if any(x is None for x in y):
@@ -97,11 +100,13 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
97100
else:
98101
# If all labels are not None: (`0`, `2`, `1`, `3`)
99102
y = self.process_label_rows(y)
103+
100104
else:
101105
# If all labels are None : (`None`, `None`, `None`, `None`)
102106
y = None
103107
loss_kwargs["non_null_labels"] = []
104108

109+
loss_kwargs["missing_labels"] = torch.tensor(missing_labels)
105110
# Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions
106111
lens = torch.tensor(list(map(len, x)))
107112
model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None]

chebai/preprocessing/datasets/tox21.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def download(self) -> None:
6868
def setup_processed(self) -> None:
6969
"""Processes and splits the dataset."""
7070
print("Create splits")
71-
data = self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv"))
72-
groups = np.array([d["group"] for d in data])
71+
data = list(self._load_data_from_file(os.path.join(self.raw_dir, f"tox21.csv")))
72+
groups = np.array([d.get("group") for d in data])
7373
if not all(g is None for g in groups):
7474
split_size = int(len(set(groups)) * self.train_split)
7575
os.makedirs(self.processed_dir, exist_ok=True)
@@ -129,7 +129,7 @@ def setup(self, **kwargs) -> None:
129129
):
130130
self.setup_processed()
131131

132-
def _load_data_from_file(self, input_file_path: str) -> List[Dict]:
132+
def _load_dict(self, input_file_path: str) -> List[Dict]:
133133
"""Loads data from a CSV file.
134134
135135
Args:

chebai/preprocessing/reader.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,18 @@ def _read_group(self, raw: Any) -> Any:
9292
return raw
9393

9494
def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]:
95-
"""Read and return components from the row."""
95+
"""Read and return components from the row. If the data contains any missing labels (`None`), they are tracked
96+
under the additional `missing_labels` keyword."""
97+
labels = self._get_raw_label(row)
98+
additional_kwargs = self._get_additional_kwargs(row)
99+
if any(l is None for l in labels):
100+
additional_kwargs["missing_labels"] = [l is None for l in labels]
96101
return dict(
97102
features=self._get_raw_data(row),
98-
labels=self._get_raw_label(row),
103+
labels=labels,
99104
ident=self._get_raw_id(row),
100105
group=self._get_raw_group(row),
101-
additional_kwargs=self._get_additional_kwargs(row),
106+
additional_kwargs=additional_kwargs,
102107
)
103108

104109
def to_data(self, row: Dict[str, Any]) -> Dict[str, Any]:

0 commit comments

Comments (0)