From fe5e1efe9b74332a94a66f4c0a6b2fffa51c404d Mon Sep 17 00:00:00 2001
From: RuntimeRacer
Date: Wed, 3 May 2023 00:33:13 +0200
Subject: [PATCH 1/4] continue training even if a broken data batch was hit

---
 valle/bin/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/valle/bin/trainer.py b/valle/bin/trainer.py
index 44349a0..069524c 100644
--- a/valle/bin/trainer.py
+++ b/valle/bin/trainer.py
@@ -692,8 +692,9 @@ def train_one_epoch(
                     set_batch_count(model, params.batch_idx_train)
         except:  # noqa
+            logging.warning(f"Hit a broken batch of training data. Cut ID: {batch['utt_id']} Text: {batch['text']} - Skipping...")
             display_and_save_batch(batch, params=params)
-            raise
+            continue
 
         if params.average_period > 0:
             if (

From 234c723e41e0297c1ffe64b742fa6ff3decad34a Mon Sep 17 00:00:00 2001
From: RuntimeRacer
Date: Wed, 3 May 2023 01:02:06 +0200
Subject: [PATCH 2/4] also perform grad scaling in pre-run OOM check

---
 valle/bin/trainer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/valle/bin/trainer.py b/valle/bin/trainer.py
index 069524c..c21ca38 100644
--- a/valle/bin/trainer.py
+++ b/valle/bin/trainer.py
@@ -1102,6 +1102,10 @@ def scan_pessimistic_batches_for_oom(
     elif params.dtype in ["float16", "fp16"]:
         dtype = torch.float16
 
+    scaler = GradScaler(
+        enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
+    )
+
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
@@ -1112,7 +1116,7 @@ def scan_pessimistic_batches_for_oom(
                 batch=batch,
                 is_training=True,
             )
-            loss.backward()
+            scaler.scale(loss).backward()
             optimizer.zero_grad()
         except Exception as e:
             if "CUDA out of memory" in str(e):

From 2541aed7d25b941210061ec31bad6e6073a5b6cf Mon Sep 17 00:00:00 2001
From: RuntimeRacer
Date: Wed, 3 May 2023 15:32:21 +0200
Subject: [PATCH 3/4] try explicit cleanup of everything related to the batch to prevent training from stalling

---
 valle/bin/trainer.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/valle/bin/trainer.py b/valle/bin/trainer.py
index c21ca38..fa69d09 100644
--- a/valle/bin/trainer.py
+++ b/valle/bin/trainer.py
@@ -692,8 +692,19 @@ def train_one_epoch(
                     set_batch_count(model, params.batch_idx_train)
         except:  # noqa
+            # Save the broken batch
             logging.warning(f"Hit a broken batch of training data. Cut ID: {batch['utt_id']} Text: {batch['text']} - Skipping...")
             display_and_save_batch(batch, params=params)
+            # Clean up batch data from Memory and GPU
+            del batch["text_tokens"]
+            del batch["text_tokens_lens"]
+            del batch["audio_features"]
+            del batch["audio_features_lens"]
+            del batch
+            del loss
+            del loss_info
+            torch.cuda.empty_cache()
+            # Continue training
             continue
 
         if params.average_period > 0:

From f9292b2da7f371a51cca5c359a5482c99dcade84 Mon Sep 17 00:00:00 2001
From: RuntimeRacer
Date: Thu, 4 May 2023 10:01:36 +0200
Subject: [PATCH 4/4] avoid error in case we already hit OOM during the forward pass

---
 valle/bin/trainer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/valle/bin/trainer.py b/valle/bin/trainer.py
index fa69d09..8525ad6 100644
--- a/valle/bin/trainer.py
+++ b/valle/bin/trainer.py
@@ -701,8 +701,11 @@ def train_one_epoch(
             del batch["audio_features"]
             del batch["audio_features_lens"]
             del batch
-            del loss
-            del loss_info
+            try:
+                del loss
+                del loss_info
+            except UnboundLocalError:
+                pass
             torch.cuda.empty_cache()
             # Continue training
             continue
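
Note: taken together, patches 1/4, 3/4 and 4/4 change the exception handler in train_one_epoch from "dump the batch and re-raise" into "log, dump, free, and skip". Below is a minimal self-contained sketch of that pattern, not the actual trainer code: the function name train_one_epoch_sketch, the model/optimizer/scaler/train_dl arguments and the single-tensor loss are illustrative assumptions; only the skip-and-clean-up logic mirrors the patches (the real handler also calls display_and_save_batch and deletes the individual batch tensors).

import logging

import torch


def train_one_epoch_sketch(model, optimizer, scaler, train_dl, device="cuda"):
    # Illustrative loop; everything except the cleanup logic is an assumption.
    for batch in train_dl:
        try:
            with torch.cuda.amp.autocast():
                loss = model(
                    batch["text_tokens"].to(device),
                    batch["audio_features"].to(device),
                )
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        except Exception:  # noqa
            # Patch 1/4: log the offending cut and keep training instead of raising.
            logging.warning(
                f"Hit a broken batch of training data. "
                f"Cut ID: {batch.get('utt_id')} - Skipping..."
            )
            # Patch 3/4: drop every reference to the failed batch so its
            # CPU and GPU memory can be reclaimed before the next step.
            del batch
            # Patch 4/4: `loss` is unbound if the failure happened in the
            # forward pass, so guard the deletion.
            try:
                del loss
            except UnboundLocalError:
                pass
            torch.cuda.empty_cache()
            continue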
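
Note: patch 2/4 routes the pre-run OOM probe's backward pass through a GradScaler, presumably so the probe exercises the same code path and memory behaviour as real fp16 training. A minimal sketch of that idea, assuming a generic model/optimizer and a plain batches iterable rather than the real scan_pessimistic_batches_for_oom signature:

import torch
from torch.cuda.amp import GradScaler, autocast


def oom_probe_sketch(model, optimizer, batches, use_fp16):
    # Patch 2/4: scale the loss before backward(), exactly as the fp16
    # training loop does; with use_fp16=False the scaler is a no-op.
    scaler = GradScaler(enabled=use_fp16, init_scale=1.0)
    for batch in batches:
        with autocast(enabled=use_fp16):
            loss = model(batch)
        scaler.scale(loss).backward()
        optimizer.zero_grad()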