diff --git a/evalution/engines/transformers.py b/evalution/engines/transformers.py index bb9bdf5..d3b2a02 100644 --- a/evalution/engines/transformers.py +++ b/evalution/engines/transformers.py @@ -469,7 +469,7 @@ def _processor_init_with_fa2_cb_graph_fix(self: Any, *args: Any, **kwargs: Any) False, ): @wraps(current_generation_step) - def _generation_step_with_fa2_cb_graph_fix(self: Any, model: Any) -> None: + def _generation_step_with_fa2_cb_graph_fix(self: Any, *args: Any, **kwargs: Any) -> None: """Implement generation step with fa2 cb graph fix for this module.""" original_use_cuda_graph = getattr(self, "use_cuda_graph", False) @@ -479,7 +479,11 @@ def _generation_step_with_fa2_cb_graph_fix(self: Any, model: Any) -> None: else self.use_cuda_graph_varlen ) try: - current_generation_step(self, model) + # Keep compatibility across transformers versions where + # _generation_step signatures differ: + # - old: _generation_step(self, model, logit_processor) + # - new: _generation_step(self, model) + current_generation_step(self, *args, **kwargs) finally: self.use_cuda_graph = original_use_cuda_graph