@@ -1045,6 +1045,94 @@ def test_meta_features(self, convert_input_to_np: bool):
10451045 convert_input_to_np = convert_input_to_np ,
10461046 )
10471047
1048+ @parameterized .product (
1049+ convert_input_to_np = [True , False ],
1050+ meta_feature_values = [
1051+ [3.5 , 7.5 , 5.5 ],
1052+ [np .asarray (3.5 ), np .asarray (7.5 ), np .asarray (5.5 )],
1053+ [
1054+ np .asarray (3.5 , dtype = np .int32 ),
1055+ np .asarray (7 , dtype = np .int32 ),
1056+ np .asarray (5 , dtype = np .int32 ),
1057+ ],
1058+ ],
1059+ )
1060+ def test_nonint64_meta_feature (
1061+ self ,
1062+ convert_input_to_np : bool ,
1063+ meta_feature_values : list [np .ndarray | float ],
1064+ ):
1065+ input_elements = [
1066+ {
1067+ "inputs" : np .asarray ([1 , 2 , 3 ]),
1068+ "targets" : np .asarray ([10 ]),
1069+ "meta_feature" : meta_feature_values [0 ],
1070+ },
1071+ {
1072+ "inputs" : np .asarray ([4 , 5 ]),
1073+ "targets" : np .asarray ([20 , 30 , 40 ]),
1074+ "meta_feature" : meta_feature_values [1 ],
1075+ },
1076+ {
1077+ "inputs" : np .asarray ([6 ]),
1078+ "targets" : np .asarray ([50 , 60 ]),
1079+ "meta_feature" : meta_feature_values [2 ],
1080+ },
1081+ ]
1082+ length_struct = {"inputs" : 3 , "targets" : 4 , "meta_feature" : 3 }
1083+
1084+ expected_elements = [
1085+ {
1086+ "inputs" : [1 , 2 , 3 ],
1087+ "targets" : [10 , 0 , 0 , 0 ],
1088+ "inputs_segment_ids" : [1 , 1 , 1 ],
1089+ "targets_segment_ids" : [1 , 0 , 0 , 0 ],
1090+ "inputs_positions" : [0 , 1 , 2 ],
1091+ "targets_positions" : [0 , 0 , 0 , 0 ],
1092+ "meta_feature" : [
1093+ np .asarray (meta_feature_values [0 ]).item (),
1094+ 0 ,
1095+ 0 ,
1096+ ],
1097+ },
1098+ {
1099+ "inputs" : [4 , 5 , 0 ],
1100+ "targets" : [20 , 30 , 40 , 0 ],
1101+ "inputs_segment_ids" : [1 , 1 , 0 ],
1102+ "targets_segment_ids" : [1 , 1 , 1 , 0 ],
1103+ "inputs_positions" : [0 , 1 , 0 ],
1104+ "targets_positions" : [0 , 1 , 2 , 0 ],
1105+ "meta_feature" : [
1106+ np .asarray (meta_feature_values [1 ]).item (),
1107+ 0 ,
1108+ 0 ,
1109+ ],
1110+ },
1111+ {
1112+ "inputs" : [6 , 0 , 0 ],
1113+ "targets" : [50 , 60 , 0 , 0 ],
1114+ "inputs_segment_ids" : [1 , 0 , 0 ],
1115+ "targets_segment_ids" : [1 , 1 , 0 , 0 ],
1116+ "inputs_positions" : [0 , 0 , 0 ],
1117+ "targets_positions" : [0 , 1 , 0 , 0 ],
1118+ "meta_feature" : [
1119+ np .asarray (meta_feature_values [2 ]).item (),
1120+ 0 ,
1121+ 0 ,
1122+ ],
1123+ },
1124+ ]
1125+ _common_test_body (
1126+ self .packer_cls ,
1127+ input_elements ,
1128+ expected_elements ,
1129+ length_struct ,
1130+ kwargs = self .kwargs ,
1131+ num_packing_bins = 3 ,
1132+ meta_features = ["meta_feature" ],
1133+ convert_input_to_np = convert_input_to_np ,
1134+ )
1135+
10481136 @parameterized .parameters (
10491137 {"restore_at_step" : 0 },
10501138 {"restore_at_step" : 1 },
0 commit comments