Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions quality/evaluator/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,16 +400,31 @@ def _stem_japanese_segments(segments: Iterable[str]) -> list[str]:


def _sentencepiece_tokenizer(sp_tokenizer: Any) -> Callable[[str], list[str]]:
    """Wrap a SentencePiece-like object as a ``text -> tokens`` callable.

    Prefers the modern API ``encode(text, out_type=str)``; when that raises
    ``TypeError`` (older processors whose ``encode`` does not accept the
    keyword), falls back to the legacy ``encode_as_pieces``.  Whatever the
    backend returns is normalized to ``list[str]`` and then passed through
    Japanese stemming.
    """

    def _coerce_tokens(raw_tokens: Any) -> list[str]:
        # Normalize the backend's return value into a list of strings.
        if isinstance(raw_tokens, (list, tuple)):
            return [str(token) for token in raw_tokens]
        # Some encoders return an object exposing a ``tokens`` attribute.
        tokens_attr = getattr(raw_tokens, "tokens", None)
        if tokens_attr is not None:
            return [str(token) for token in tokens_attr]
        return [str(raw_tokens)]

    def _tokenize(text: str) -> list[str]:
        encode_fn = getattr(sp_tokenizer, "encode", None)
        encode_as_pieces_fn = getattr(sp_tokenizer, "encode_as_pieces", None)
        tokens_source: Any
        if callable(encode_fn):
            try:
                # Modern SentencePiece: pieces come back directly as strings.
                tokens_source = encode_fn(text, out_type=str)
            except TypeError:
                # ``encode`` rejected the keyword; try the legacy API first,
                # then a plain ``encode`` call as a last resort.
                if callable(encode_as_pieces_fn):
                    tokens_source = encode_as_pieces_fn(text)
                else:
                    tokens_source = encode_fn(text)
        elif callable(encode_as_pieces_fn):
            tokens_source = encode_as_pieces_fn(text)
        else:
            # No recognized encode API: treat the raw text as one token.
            tokens_source = [text]
        tokens = _coerce_tokens(tokens_source)
        return _stem_japanese_segments(tokens)

    return _tokenize
Expand Down
68 changes: 65 additions & 3 deletions tests/quality/evaluator/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def score(self, reference: str, prediction: str) -> dict[str, _FakeRougeScore]:
class _FakeSentencePieceProcessor:
last_loaded: Path | None = None
last_encoded: list[str] = []
last_encode_kwargs: list[dict[str, Any]] = []
encode_as_pieces_inputs: list[str] = []
require_out_type: bool = True
allow_out_type: bool = True

def __init__(self, *, model_file: str | None = None) -> None:
if model_file:
Expand All @@ -68,13 +72,32 @@ def load(self, model_path: str) -> bool:
type(self).last_loaded = Path(model_path)
return True

def encode(self, text: str) -> list[str]:
type(self).last_encoded.append(text)
@staticmethod
def _to_pieces(text: str) -> list[str]:
    """Produce fake SentencePiece pieces: one ▁-prefixed piece per input."""
    collapsed = "".join(text.split(" "))
    if collapsed:
        return [f"▁{collapsed}"]
    # Whitespace-only / empty input still yields the bare boundary marker.
    return ["▁"]

def encode(self, text: Any, **kwargs: Any) -> list[str]:
    """Record the call, enforce the configured out_type contract, tokenize."""
    cls = type(self)
    cls.last_encoded.append(text)
    cls.last_encode_kwargs.append(dict(kwargs))
    if not isinstance(text, str):
        raise TypeError("encode() expects text as str")
    # Strict mode: callers must pass the modern out_type=str keyword.
    if cls.require_out_type and kwargs.get("out_type") is not str:
        raise AssertionError("encode() must be called with out_type=str")
    # Legacy-simulation mode: reject the keyword like an old processor would.
    if not cls.allow_out_type and "out_type" in kwargs:
        raise TypeError("encode() got an unexpected keyword argument 'out_type'")
    return self._to_pieces(text)

def encode_as_pieces(self, text: Any) -> list[str]:
    """Legacy SentencePiece API: record the input and return fake pieces."""
    type(self).encode_as_pieces_inputs.append(text)
    if isinstance(text, str):
        return self._to_pieces(text)
    raise TypeError("encode_as_pieces() expects text as str")


class _FakeJanomeToken:
def __init__(self, surface: str) -> None:
Expand Down Expand Up @@ -236,10 +259,44 @@ def _fail_fallback() -> Callable[[str], list[str]]:
assert fallback_called is False
assert _FakeSentencePieceProcessor.last_loaded == model_path
assert _FakeSentencePieceProcessor.last_encoded[-1] == "テスト"
assert _FakeSentencePieceProcessor.last_encode_kwargs[-1] == {"out_type": str}
assert not _FakeSentencePieceProcessor.encode_as_pieces_inputs
assert tokens == ["stem:テスト"]
assert _FakeJanomeTokenizer.last_inputs[-1] == "テスト"


def test_sentencepiece_tokenizer_falls_back_to_encode_as_pieces(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """When encode() rejects out_type=str, tokenization must use encode_as_pieces."""
    module = import_module("quality.evaluator.cli")

    fallback_called = False

    def _fail_fallback() -> Callable[[str], list[str]]:
        # Sentinel: building the tokenizer must NOT resort to the surface fallback.
        nonlocal fallback_called
        fallback_called = True
        return lambda text: ["fallback"]

    monkeypatch.setattr(module, "_fallback_surface_tokenizer", _fail_fallback)

    # Simulate a legacy processor whose encode() rejects the out_type kwarg.
    _FakeSentencePieceProcessor.allow_out_type = False

    model_path = tmp_path / "sp.model"
    model_path.write_text("dummy", encoding="utf-8")

    tokenizer = module._build_surface_tokenizer(model_path)

    tokens = tokenizer("サンプル")

    assert fallback_called is False
    assert _FakeSentencePieceProcessor.last_loaded == model_path
    assert _FakeSentencePieceProcessor.last_encoded[-1] == "サンプル"
    # The modern keyword was still attempted first...
    assert _FakeSentencePieceProcessor.last_encode_kwargs[-1] == {"out_type": str}
    # ...and the legacy API actually produced the pieces.
    assert _FakeSentencePieceProcessor.encode_as_pieces_inputs[-1] == "サンプル"
    assert tokens == ["stem:サンプル"]
    assert _FakeJanomeTokenizer.last_inputs[-1] == "サンプル"


@pytest.fixture(autouse=True)
def _stub_third_party(monkeypatch: pytest.MonkeyPatch) -> None:
bert_score_module = ModuleType("bert_score")
Expand Down Expand Up @@ -272,6 +329,10 @@ def _stub_third_party(monkeypatch: pytest.MonkeyPatch) -> None:

_FakeSentencePieceProcessor.last_loaded = None
_FakeSentencePieceProcessor.last_encoded = []
_FakeSentencePieceProcessor.last_encode_kwargs = []
_FakeSentencePieceProcessor.encode_as_pieces_inputs = []
_FakeSentencePieceProcessor.require_out_type = True
_FakeSentencePieceProcessor.allow_out_type = True
_FakeJanomeTokenizer.last_inputs = []


Expand Down Expand Up @@ -431,7 +492,8 @@ def test_evaluate_surface_prefers_sentencepiece_tokenizer(monkeypatch: pytest.Mo
metrics = module._evaluate_surface(["alpha"], ["alpha"], sentencepiece_model=sentencepiece_model)

assert metrics == {"rouge1": 0.78, "rougeL": 0.72}
assert _FakeSentencePieceTokenizer.last_model == sentencepiece_model
assert _FakeSentencePieceProcessor.last_encode_kwargs[-1] == {"out_type": str}
assert not _FakeSentencePieceProcessor.encode_as_pieces_inputs

def test_collect_pairs_skips_duplicate_identifiers(tmp_path: Path) -> None:
module = import_module("quality.evaluator.cli")
Expand Down