diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a67d723a..0d78ce85 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -70,11 +70,6 @@ repos: - "#" - --allow-past-years types: [python] -- repo: https://github.com/PyCQA/docformatter - rev: v1.5.0 - hooks: - - id: docformatter - args: [--in-place, --wrap-summaries=80, --wrap-descriptions=80] - repo: https://github.com/PyCQA/pydocstyle hooks: - id: pydocstyle diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index acd00f4e..04a8970f 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -247,7 +247,11 @@ def policy_loss( logits=gen_logits, ) assert token_entropies.shape == batch['action_mask'].shape, ( - f'Token entropies shape {token_entropies.shape} does not match action mask shape {batch["action_mask"].shape}.', + 'Token entropies shape {token_entropies_shape} does not match action mask shape {action_mask_shape}.' + .format( + token_entropies_shape=token_entropies.shape, + action_mask_shape=batch['action_mask'].shape, + ), ) seq_entropies = utils.get_sequence_entropies( token_entropies=token_entropies, diff --git a/compose_rl/algorithms/reward_modeling/model.py b/compose_rl/algorithms/reward_modeling/model.py index 2d3637ee..cb6dc449 100644 --- a/compose_rl/algorithms/reward_modeling/model.py +++ b/compose_rl/algorithms/reward_modeling/model.py @@ -329,7 +329,7 @@ def forward( batch['input_ids'], attention_mask=batch['attention_mask'], ).logits - logits = logits[:, :, self.eos_token_id] + logits = logits[:, :, self.eos_token_id] # type: ignore if self.min_threshold is not None and self.max_threshold is not None: logits: torch.Tensor = torch.clamp( logits,