Skip to content

Commit 1904494

Browse files
iindykcopybara-github
authored andcommitted
Internal change.
PiperOrigin-RevId: 879130679
1 parent 057c938 commit 1904494

1 file changed

Lines changed: 88 additions & 0 deletions

File tree

grain/_src/python/dataset/transformations/testing_util.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,94 @@ def test_meta_features(self, convert_input_to_np: bool):
10451045
convert_input_to_np=convert_input_to_np,
10461046
)
10471047

1048+
@parameterized.product(
1049+
convert_input_to_np=[True, False],
1050+
meta_feature_values=[
1051+
[3.5, 7.5, 5.5],
1052+
[np.asarray(3.5), np.asarray(7.5), np.asarray(5.5)],
1053+
[
1054+
np.asarray(3.5, dtype=np.int32),
1055+
np.asarray(7, dtype=np.int32),
1056+
np.asarray(5, dtype=np.int32),
1057+
],
1058+
],
1059+
)
1060+
def test_nonint64_meta_feature(
1061+
self,
1062+
convert_input_to_np: bool,
1063+
meta_feature_values: list[np.ndarray | float],
1064+
):
1065+
input_elements = [
1066+
{
1067+
"inputs": np.asarray([1, 2, 3]),
1068+
"targets": np.asarray([10]),
1069+
"meta_feature": meta_feature_values[0],
1070+
},
1071+
{
1072+
"inputs": np.asarray([4, 5]),
1073+
"targets": np.asarray([20, 30, 40]),
1074+
"meta_feature": meta_feature_values[1],
1075+
},
1076+
{
1077+
"inputs": np.asarray([6]),
1078+
"targets": np.asarray([50, 60]),
1079+
"meta_feature": meta_feature_values[2],
1080+
},
1081+
]
1082+
length_struct = {"inputs": 3, "targets": 4, "meta_feature": 3}
1083+
1084+
expected_elements = [
1085+
{
1086+
"inputs": [1, 2, 3],
1087+
"targets": [10, 0, 0, 0],
1088+
"inputs_segment_ids": [1, 1, 1],
1089+
"targets_segment_ids": [1, 0, 0, 0],
1090+
"inputs_positions": [0, 1, 2],
1091+
"targets_positions": [0, 0, 0, 0],
1092+
"meta_feature": [
1093+
np.asarray(meta_feature_values[0]).item(),
1094+
0,
1095+
0,
1096+
],
1097+
},
1098+
{
1099+
"inputs": [4, 5, 0],
1100+
"targets": [20, 30, 40, 0],
1101+
"inputs_segment_ids": [1, 1, 0],
1102+
"targets_segment_ids": [1, 1, 1, 0],
1103+
"inputs_positions": [0, 1, 0],
1104+
"targets_positions": [0, 1, 2, 0],
1105+
"meta_feature": [
1106+
np.asarray(meta_feature_values[1]).item(),
1107+
0,
1108+
0,
1109+
],
1110+
},
1111+
{
1112+
"inputs": [6, 0, 0],
1113+
"targets": [50, 60, 0, 0],
1114+
"inputs_segment_ids": [1, 0, 0],
1115+
"targets_segment_ids": [1, 1, 0, 0],
1116+
"inputs_positions": [0, 0, 0],
1117+
"targets_positions": [0, 1, 0, 0],
1118+
"meta_feature": [
1119+
np.asarray(meta_feature_values[2]).item(),
1120+
0,
1121+
0,
1122+
],
1123+
},
1124+
]
1125+
_common_test_body(
1126+
self.packer_cls,
1127+
input_elements,
1128+
expected_elements,
1129+
length_struct,
1130+
kwargs=self.kwargs,
1131+
num_packing_bins=3,
1132+
meta_features=["meta_feature"],
1133+
convert_input_to_np=convert_input_to_np,
1134+
)
1135+
10481136
@parameterized.parameters(
10491137
{"restore_at_step": 0},
10501138
{"restore_at_step": 1},

0 commit comments

Comments
 (0)