|
62 | 62 | import jax |
63 | 63 | import jax.numpy as jnp |
64 | 64 | import numpy as np |
| 65 | +from packaging import version |
65 | 66 | import tensorflow as tf |
66 | 67 | import tensorflow_datasets as tfds |
67 | 68 | import typing_extensions |
|
73 | 74 |
|
74 | 75 | AUTOTUNE = tf.data.experimental.AUTOTUNE |
75 | 76 |
|
| 77 | +_use_split_info = version.parse("4.4.0") < version.parse( |
| 78 | + tfds.version.__version__) |
| 79 | + |
76 | 80 |
|
77 | 81 | class DatasetBuilder(typing_extensions.Protocol): |
78 | 82 | """Protocol for dataset builders (subset of tfds.core.DatasetBuilder).""" |
@@ -106,15 +110,18 @@ class RemainderOptions(enum.Enum): |
106 | 110 | def _shard_read_instruction( |
107 | 111 | absolute_instruction, |
108 | 112 | *, |
109 | | - split_infos: Dict[str, tfds.core.SplitInfo], |
| 113 | + split_infos: Dict[str, Union[int, tfds.core.SplitInfo]], |
110 | 114 | host_id: int, |
111 | 115 | host_count: int, |
112 | 116 | remainder_options: RemainderOptions, |
113 | 117 | ) -> tfds.core.ReadInstruction: |
114 | 118 | """Shards a single ReadInstruction. See get_read_instruction_for_host().""" |
115 | 119 | start = absolute_instruction.from_ or 0 |
116 | | - end = absolute_instruction.to or ( |
117 | | - split_infos[absolute_instruction.splitname].num_examples) |
| 120 | + if _use_split_info: |
| 121 | + end = absolute_instruction.to or ( |
| 122 | + split_infos[absolute_instruction.splitname].num_examples) # pytype: disable=attribute-error |
| 123 | + else: |
| 124 | + end = absolute_instruction.to or split_infos[absolute_instruction.splitname] |
118 | 125 | assert end >= start, f"start={start}, end={end}" |
119 | 126 | num_examples = end - start |
120 | 127 |
|
@@ -208,16 +215,23 @@ def get_read_instruction_for_host( |
208 | 215 | f"Invalid combination of host_id ({host_id}) and host_count " |
209 | 216 | f"({host_count}).") |
210 | 217 |
|
211 | | - if dataset_info is None: |
212 | | - split_infos = { |
213 | | - split: tfds.core.SplitInfo( |
214 | | - name=split, |
215 | | - shard_lengths=[num_examples], |
216 | | - num_bytes=0, |
217 | | - ), |
218 | | - } |
| 218 | + if _use_split_info: |
| 219 | + if dataset_info is None: |
| 220 | + split_infos = { |
| 221 | + split: tfds.core.SplitInfo( |
| 222 | + name=split, |
| 223 | + shard_lengths=[num_examples], |
| 224 | + num_bytes=0, |
| 225 | + ), |
| 226 | + } |
| 227 | + else: |
| 228 | + split_infos = dataset_info.splits |
219 | 229 | else: |
220 | | - split_infos = dataset_info.splits |
| 230 | + if dataset_info is None: |
| 231 | + split_infos = {split: num_examples} |
| 232 | + else: |
| 233 | + split_infos = {k: v.num_examples for k, v in dataset_info.splits.items()} |
| 234 | + |
221 | 235 | read_instruction = tfds.core.ReadInstruction.from_spec(split) |
222 | 236 | sharded_read_instructions = [] |
223 | 237 | for ri in read_instruction.to_absolute(split_infos): |
|
0 commit comments