diff --git a/metrics/bleurt/README.md b/metrics/bleurt/README.md
index b14094e5..b8846d5b 100644
--- a/metrics/bleurt/README.md
+++ b/metrics/bleurt/README.md
@@ -38,7 +38,7 @@ This metric takes as input lists of predicted sentences and reference sentences:
 >>> predictions = ["hello there", "general kenobi"]
 >>> references = ["hello there", "general kenobi"]
 >>> bleurt = load("bleurt", module_type="metric")
->>> results = bleurt.compute(predictions=predictions, references=references)
+>>> results = bleurt.compute(predictions=predictions, references=references, batch_size=32)
 ```
 
 ### Inputs
@@ -76,7 +76,7 @@ Example with the default model (`"bleurt-base-128"`):
 >>> predictions = ["hello there", "general kenobi"]
 >>> references = ["hello there", "general kenobi"]
 >>> bleurt = load("bleurt", module_type="metric")
->>> results = bleurt.compute(predictions=predictions, references=references)
+>>> results = bleurt.compute(predictions=predictions, references=references, batch_size=32)
 >>> print(results)
 {'scores': [1.0295498371124268, 1.0445425510406494]}
 ```
@@ -86,7 +86,7 @@ Example with the full `"BLEURT-20"` model checkpoint:
 >>> predictions = ["hello there", "general kenobi"]
 >>> references = ["hello there", "general kenobi"]
 >>> bleurt = load("bleurt", module_type="metric", config_name="BLEURT-20")
->>> results = bleurt.compute(predictions=predictions, references=references)
+>>> results = bleurt.compute(predictions=predictions, references=references, batch_size=32)
 >>> print(results)
 {'scores': [1.015415906906128, 0.9985226988792419]}
 ```
diff --git a/metrics/bleurt/bleurt.py b/metrics/bleurt/bleurt.py
index 11ad20a7..c98cc6c8 100644
--- a/metrics/bleurt/bleurt.py
+++ b/metrics/bleurt/bleurt.py
@@ -120,6 +120,6 @@ def _download_and_prepare(self, dl_manager):
         model_path = dl_manager.download_and_extract(CHECKPOINT_URLS[checkpoint_name])
         self.scorer = score.BleurtScorer(os.path.join(model_path, checkpoint_name))
 
-    def _compute(self, predictions, references):
-        scores = self.scorer.score(references=references, candidates=predictions)
+    def _compute(self, predictions, references, batch_size=None):
+        scores = self.scorer.score(references=references, candidates=predictions, batch_size=batch_size)
         return {"scores": scores}
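For reviewers who want to exercise the change end to end, here is a minimal usage sketch. It assumes the `evaluate` library's `load` entry point shown in the README (plus the BLEURT package it depends on); `batch_size=32` is an arbitrary illustrative value, not a recommended default.

```python
from evaluate import load

# Loads the BLEURT metric; by default this fetches the "bleurt-base-128" checkpoint.
bleurt = load("bleurt", module_type="metric")

predictions = ["hello there", "general kenobi"]
references = ["hello there", "general kenobi"]

# With this patch, batch_size is forwarded to BleurtScorer.score.
# Omitting it leaves batch_size=None, which keeps the scorer's own
# default batching behavior, so existing callers are unaffected.
results = bleurt.compute(predictions=predictions, references=references, batch_size=32)
print(results["scores"])
```

Defaulting `batch_size` to `None` and passing it straight through keeps the metric's signature backward compatible while exposing the underlying scorer's batching control.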