Skip to content

Commit a42e1d2

Browse files
committed
Spectrum augmentation: Avoid creating TF global step variable statically.
PiperOrigin-RevId: 737138292
1 parent 92ac12c commit a42e1d2

5 files changed

Lines changed: 29 additions & 3 deletions

File tree

protoscribe/corpus/reader/tasks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from protoscribe.glyphs import glyph_vocab as glyph_lib
2929
from protoscribe.sketches.utils import stroke_tokenizer as tokenizer_lib
3030
from protoscribe.speech import audio_tokenizer
31+
from protoscribe.speech import augmentation
3132
import seqio
3233
import tensorflow as tf
3334

@@ -454,6 +455,10 @@ def register(
454455
sketch_stroke_stats = ds_lib.get_sketch_stroke_stats(config)
455456
stroke_tokenizer = ds_lib.get_stroke_tokenizer(config)
456457

458+
if is_training and speech_spectrum_augmentation:
459+
# Initialize spectrum augmentation.
460+
augmentation.tf_spec_augment_init()
461+
457462
speech_tokenizer = None
458463
if speech_tokenizer_name_or_path:
459464
speech_tokenizer = audio_tokenizer.get_tokenizer(

protoscribe/speech/augmentation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ def _default_spec_augment_config() -> ml_collections.FrozenConfigDict:
3232
})
3333

3434

35+
def tf_spec_augment_init() -> None:
36+
"""Global initialization for spectrum augmenter.
37+
38+
Ideally, this should be called only once, to create the global step variable.
39+
"""
40+
impl_lib.tf_spec_augment_init()
41+
42+
3543
def tf_spec_augment(
3644
spectrum: tf.Tensor,
3745
config: ml_collections.FrozenConfigDict | None = None,

protoscribe/speech/augmentation_lingvo.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828
from protoscribe.speech import tf_utils
2929
import tensorflow as tf
3030

31-
# This creates global step variable.
32-
tf_utils.get_or_create_global_step_var()
33-
3431

3532
@dataclasses.dataclass
3633
class AugmenterConfig:
@@ -1028,6 +1025,14 @@ def _augmentation_network(
10281025
return inputs
10291026

10301027

1028+
def tf_spec_augment_init() -> None:
1029+
"""Global initialization for spectrum augmenter.
1030+
1031+
Ideally, this should be called only once, to create the global step variable.
1032+
"""
1033+
tf_utils.get_or_create_global_step_var()
1034+
1035+
10311036
def tf_spec_augment_lingvo(
10321037
config: AugmenterConfig,
10331038
inputs: tf.Tensor,

protoscribe/speech/augmentation_lingvo_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121

2222
class AugmentationTest(tf.test.TestCase):
2323

24+
def setUp(self):
25+
super().setUp()
26+
lib.tf_spec_augment_init()
27+
2428
def test_with_time_mask(self):
2529
tf.random.set_seed(127)
2630
batch_size = 5

protoscribe/speech/augmentation_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525

2626
class AugmentationTest(tf.test.TestCase):
2727

28+
def setUp(self):
29+
super().setUp()
30+
lib.tf_spec_augment_init()
31+
2832
def test_with_default_config(self):
2933
input_shape = [_NUM_TIME_STEPS, _NUM_FREQ_BINS]
3034
inputs = tf.ones(input_shape, dtype=tf.float32)

0 commit comments

Comments
 (0)