File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2828from protoscribe .glyphs import glyph_vocab as glyph_lib
2929from protoscribe .sketches .utils import stroke_tokenizer as tokenizer_lib
3030from protoscribe .speech import audio_tokenizer
31+ from protoscribe .speech import augmentation
3132import seqio
3233import tensorflow as tf
3334
@@ -454,6 +455,10 @@ def register(
454455 sketch_stroke_stats = ds_lib .get_sketch_stroke_stats (config )
455456 stroke_tokenizer = ds_lib .get_stroke_tokenizer (config )
456457
458+ if is_training and speech_spectrum_augmentation :
459+ # Initialize spectrum augmentation.
460+ augmentation .tf_spec_augment_init ()
461+
457462 speech_tokenizer = None
458463 if speech_tokenizer_name_or_path :
459464 speech_tokenizer = audio_tokenizer .get_tokenizer (
Original file line number Diff line number Diff line change @@ -32,6 +32,14 @@ def _default_spec_augment_config() -> ml_collections.FrozenConfigDict:
3232 })
3333
3434
35+ def tf_spec_augment_init () -> None :
36+ """Global initialization for spectrum augmenter.
37+
38+ Should be called once ideally to create a global step variable.
39+ """
40+ impl_lib .tf_spec_augment_init ()
41+
42+
3543def tf_spec_augment (
3644 spectrum : tf .Tensor ,
3745 config : ml_collections .FrozenConfigDict | None = None ,
Original file line number Diff line number Diff line change 2828from protoscribe .speech import tf_utils
2929import tensorflow as tf
3030
31- # This creates global step variable.
32- tf_utils .get_or_create_global_step_var ()
33-
3431
3532@dataclasses .dataclass
3633class AugmenterConfig :
@@ -1028,6 +1025,14 @@ def _augmentation_network(
10281025 return inputs
10291026
10301027
1028+ def tf_spec_augment_init () -> None :
1029+ """Global initialization for spectrum augmenter.
1030+
1031+ Should be called once ideally to create a global step variable.
1032+ """
1033+ tf_utils .get_or_create_global_step_var ()
1034+
1035+
10311036def tf_spec_augment_lingvo (
10321037 config : AugmenterConfig ,
10331038 inputs : tf .Tensor ,
Original file line number Diff line number Diff line change 2121
2222class AugmentationTest (tf .test .TestCase ):
2323
24+ def setUp (self ):
25+ super ().setUp ()
26+ lib .tf_spec_augment_init ()
27+
2428 def test_with_time_mask (self ):
2529 tf .random .set_seed (127 )
2630 batch_size = 5
Original file line number Diff line number Diff line change 2525
2626class AugmentationTest (tf .test .TestCase ):
2727
28+ def setUp (self ):
29+ super ().setUp ()
30+ lib .tf_spec_augment_init ()
31+
2832 def test_with_default_config (self ):
2933 input_shape = [_NUM_TIME_STEPS , _NUM_FREQ_BINS ]
3034 inputs = tf .ones (input_shape , dtype = tf .float32 )
You can’t perform that action at this time.
0 commit comments