From ef9e8f65dfe8698506c89da6afd0b4278f5e9ad9 Mon Sep 17 00:00:00 2001 From: Mannat Singh Date: Wed, 4 Nov 2020 11:22:55 -0800 Subject: [PATCH 1/2] Create function to see if loss has learable parameters Differential Revision: D24729686 fbshipit-source-id: e3ca6077de362ddc6b13b407a6cecd3af844e92f --- classy_vision/tasks/classification_task.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 443ce61dcb..526ae8d58d 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -719,10 +719,7 @@ def init_distributed_data_parallel_model(self): broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, ) - if ( - isinstance(self.base_loss, ClassyLoss) - and self.base_loss.has_learned_parameters() - ): + if self._loss_has_learnable_params(): logging.info("Initializing distributed loss") self.distributed_loss = init_distributed_data_parallel_model( self.base_loss, @@ -1014,6 +1011,13 @@ def _broadcast_buffers(self): for buffer in buffers: broadcast(buffer, 0, group=self.distributed_model.process_group) + def _loss_has_learnable_params(self): + """Returns True if the loss has any learnable parameters""" + return ( + isinstance(self.base_loss, ClassyLoss) + and self.base_loss.has_learned_parameters() + ) + # TODO: Functions below should be better abstracted into the dataloader # abstraction def get_batchsize_per_replica(self): From af87428c2cdb7784eaeb09ceb4b7f75ec56bb6ed Mon Sep 17 00:00:00 2001 From: Mannat Singh Date: Wed, 4 Nov 2020 11:23:33 -0800 Subject: [PATCH 2/2] Support clipping the gradient norm Differential Revision: D24731449 fbshipit-source-id: 7e4253bb547c44286b78919f6bc92d8caf2d8aec --- classy_vision/tasks/classification_task.py | 40 ++++++++++++++++++++++ test/tasks_classification_task_test.py | 22 ++++++++++++ 2 files changed, 62 insertions(+) diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 526ae8d58d..a99fffcf38 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -6,6 +6,7 @@ import copy import enum +import itertools import json import logging import math @@ -157,6 +158,7 @@ def __init__(self): ) self.amp_args = None self.mixup_transform = None + self.grad_norm_clip = None self.perf_log = [] self.last_batch = None self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED @@ -412,6 +414,24 @@ def set_optimizer_schedulers(self, schedulers): self.optimizer_schedulers = schedulers return self + def set_grad_norm_clip( + self, + grad_norm_clip: Optional[float], + ) -> "ClassificationTask": + """Enable / disable clipping the gradient norm + + Args: + grad_norm_clip: The value to clip the gradient by, set to None to disable + """ + if grad_norm_clip is None: + logging.info(f"Disabled gradient norm clipping: {grad_norm_clip}") + else: + logging.info( + f"Enabled gradient norm clipping with threshold: {grad_norm_clip}" + ) + self.grad_norm_clip = grad_norm_clip + return self + @classmethod def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": """Instantiates a ClassificationTask from a configuration. @@ -489,6 +509,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": .set_distributed_options(**distributed_options) .set_hooks(hooks) .set_bn_weight_decay(config.get("bn_weight_decay", False)) + .set_grad_norm_clip(config.get("grad_norm_clip")) ) if not test_only: @@ -916,6 +937,9 @@ def train_step(self): else: self.optimizer.backward(local_loss) + if self.grad_norm_clip is not None: + self._clip_grad_norm() + self.check_inf_nan(loss) self.optimizer.step(where=self.where) @@ -989,6 +1013,22 @@ def create_data_iterators(self): del self.data_iterator self.data_iterator = iter(self.dataloader) + def _clip_grad_norm(self): + """Clip the gradient norms based on self.grad_norm_clip""" + model_params = ( + self.base_model.parameters() + if self.amp_args is None + else apex.amp.master_params(self.optimizer.optimizer) + ) + loss_params = ( + self.base_loss.parameters() + if self._loss_has_learnable_params() + else iter(()) + ) + nn.utils.clip_grad_norm_( + itertools.chain(model_params, loss_params), self.grad_norm_clip + ) + def _set_model_train_mode(self): """Set train mode for model""" phase = self.phases[self.phase_idx] diff --git a/test/tasks_classification_task_test.py b/test/tasks_classification_task_test.py index b1fbbbdf92..b5547c539c 100644 --- a/test/tasks_classification_task_test.py +++ b/test/tasks_classification_task_test.py @@ -7,6 +7,7 @@ import copy import shutil import tempfile +import itertools import unittest from test.generic.config_utils import get_fast_test_task_config, get_test_task_config from test.generic.utils import ( @@ -284,3 +285,24 @@ def test_get_classy_state_on_loss(self): task = build_task(config) task.prepare() self.assertIn("alpha", task.get_classy_state()["loss"]) + + def test_grad_norm_clip(self): + config = get_fast_test_task_config() + config["loss"] = {"name": "test_stateful_loss", "in_plane": 256} + config["grad_norm_clip"] = grad_norm_clip = 1 + task = build_task(config) + task.prepare() + + # set fake gradients with norm > grad_norm_clip + for param in itertools.chain( + task.base_model.parameters(), task.base_loss.parameters() + ): + param.grad = 1.1 + torch.rand(param.shape) + self.assertGreater(param.grad.norm(), grad_norm_clip) + + task._clip_grad_norm() + + for param in itertools.chain( + task.base_model.parameters(), task.base_loss.parameters() + ): + self.assertLessEqual(param.grad.norm(), grad_norm_clip)