@@ -64,7 +64,7 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
6464 Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices
6565 of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for
6666 unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method
67- ensures alignment between features and labels.
67+ ensures alignment between features and labels. Missing labels are passed as a loss keyword.
6868
6969 Args:
7070 data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple
@@ -81,10 +81,13 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
8181 if isinstance (data [0 ], tuple ):
8282 # For legacy data
8383 x , y , idents = zip (* data )
84+ missing_labels = None
8485 else :
8586 x , y , idents = zip (
8687 * ((d ["features" ], d ["labels" ], d .get ("ident" )) for d in data )
8788 )
89+ missing_labels = [d .get ("missing_labels" , [False for _ in y [0 ]]) for d in data ]
90+
8891 if any (x is not None for x in y ):
8992 # If any label is not None: (None, None, `1`, None)
9093 if any (x is None for x in y ):
@@ -97,11 +100,13 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
97100 else :
98101 # If all labels are not None: (`0`, `2`, `1`, `3`)
99102 y = self .process_label_rows (y )
103+
100104 else :
101105 # If all labels are None : (`None`, `None`, `None`, `None`)
102106 y = None
103107 loss_kwargs ["non_null_labels" ] = []
104108
109+ loss_kwargs ["missing_labels" ] = torch .tensor (missing_labels )
105110 # Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions
106111 lens = torch .tensor (list (map (len , x )))
107112 model_kwargs ["mask" ] = torch .arange (max (lens ))[None , :] < lens [:, None ]
0 commit comments