diff --git a/data/jmdict_basic_ranker.joblib b/data/jmdict_basic_ranker.joblib new file mode 100644 index 0000000..1611c43 Binary files /dev/null and b/data/jmdict_basic_ranker.joblib differ diff --git a/examples/baseline.ipynb b/examples/baseline.ipynb index 85bfb8f..1b97429 100644 --- a/examples/baseline.ipynb +++ b/examples/baseline.ipynb @@ -11,8 +11,7 @@ "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import train_test_split\n", "\n", - "from wsd.models import JMDict, JMDictWithPointWiseRanking, \\\n", - " JMDictGeminiRanking\n", + "from wsd.models import JMDict, PointWiseRanker, GeminiRanker\n", "from wsd.utils import load_dataset, accuracy" ] }, @@ -23,8 +22,7 @@ "outputs": [], "source": [ "basedir = os.getenv('PJ_DIR')\n", - "X, y = load_dataset(f'{basedir}/data/dataset.xml')\n", - "X = [''.join(x) for x in X]\n", + "X, y = load_dataset(f'{basedir}/data/dataset_.xml')\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.33, random_state=42)" ] @@ -46,7 +44,8 @@ "outputs": [], "source": [ "model = LogisticRegression(C=10, penalty='l1', solver='liblinear')\n", - "jmdict_basic = JMDictWithPointWiseRanking(ranking_model=model)\n", + "ranker = PointWiseRanker(ranking_model=model)\n", + "jmdict_basic = JMDict(ranker=ranker)\n", "jmdict_basic.fit(X_train, y_train)\n", "basic_preds = jmdict_basic.predict(X_test)" ] @@ -57,7 +56,8 @@ "metadata": {}, "outputs": [], "source": [ - "jmdict_gemini = JMDictGeminiRanking(\"gemini-2.5-pro-exp-03-25\")\n", + "ranker = GeminiRanker(\"gemini-2.5-flash-lite\")\n", + "jmdict_gemini = JMDict(ranker=ranker)\n", "gemini_preds = jmdict_gemini.predict(X_test)" ] }, @@ -77,7 +77,27 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "jmdict_basic.fit(X, y)\n", + "pd.DataFrame(\n", + " jmdict_basic.ranker.model.coef_, \n", + " 
columns=jmdict_basic.ranker.vec.get_feature_names_out()\n", + ").T\\\n", + " .sort_values(by=0, ascending=False)\\\n", + " .plot(kind='barh', figsize=(5, 14), title='Basic Ranker Coefficients')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jmdict_basic.ranker.save('../data/jmdict_basic_ranker.joblib')" + ] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 447bb7b..6b0de2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ all = [ "numpy", "pillow", "rich", - "scikit-learn", + "scikit-learn==1.7.1", "typer", ] @@ -50,5 +50,6 @@ Issues = "https://github.com/linalgo/wsd/issues" requires = ["setuptools >= 61.0"] build-backend = "setuptools.build_meta" -[tool.setuptools] -packages = ["wsd"] +[tool.setuptools.packages.find] +where = ["."] +include = ["wsd*"] diff --git a/setup.py b/setup.py deleted file mode 100644 index b613543..0000000 --- a/setup.py +++ /dev/null @@ -1,12 +0,0 @@ -"""WSD package setup script.""" -import re - -from setuptools import find_packages, setup - -with open("pyproject.toml", encoding="utf-8") as file: - setup( - version=re.search( - r'^version\s*=\s*"(.*?)"', file.read(), re.M - ).group(1), - packages=find_packages(), - ) diff --git a/wsd/__init__.py b/wsd/__init__.py index ee57fa9..a37c2eb 100644 --- a/wsd/__init__.py +++ b/wsd/__init__.py @@ -1,4 +1,4 @@ """Init file containing version and build""" __version__ = "0.0.1rc0" -__build__ = "Tue May 6 11:02:04 PM JST 2025" +__build__ = "Sat Aug 23 09:48:45 JST 2025" diff --git a/wsd/models/__init__.py b/wsd/models/__init__.py index 634eb66..48b97fe 100644 --- a/wsd/models/__init__.py +++ b/wsd/models/__init__.py @@ -1,4 +1,4 @@ """A collection of WSD models.""" -from .baseline import JMDict, Token # noqa: F401 -from .basic import JMDictWithPointWiseRanking # noqa: F401 -from .gemini import JMDictGeminiRanking # noqa: F401 +from .dictionary import * # noqa: F401 +from .rankers import * # noqa: 
F401 +from .retrievers import * # noqa: F401 diff --git a/wsd/models/baseline.py b/wsd/models/dictionary.py similarity index 59% rename from wsd/models/baseline.py rename to wsd/models/dictionary.py index 5538e37..f593e74 100644 --- a/wsd/models/baseline.py +++ b/wsd/models/dictionary.py @@ -3,85 +3,64 @@ import json import os import uuid -from abc import ABC, abstractmethod -from collections import defaultdict -from dataclasses import asdict, dataclass -from typing import Any +from dataclasses import asdict +import joblib import tqdm from fugashi import Tagger from linalgo.annotate import Annotation, Annotator, Document, Entity, Task -from wsd.parsers import JMDictParser -from wsd.parsers.jmdict import Entry +from wsd.parsers.jmdict import Entry, Token +from wsd.models.retrievers import Retriever, LocalRetriever +from wsd.models.rankers import Candidate, Ranker, DummyRanker, \ + PointWiseRanker, GeminiRanker -@dataclass -class Token: - """Token dataclass.""" - text: str = '' - lemma: str = '' - pos: str = '' +class Dictionary: + """A simple dictionary base class.""" + def __init__( + self, + retriever: Retriever = None, + ranker: Ranker = None, + annotator: Annotator = None + ): + self.retriever = retriever + self.ranker = ranker or DummyRanker() + self.annotator = annotator or Annotator( + id=uuid.uuid3(uuid.NAMESPACE_URL, 'jmdict-v3').hex, + name='jmdict-v3', + model='MACHINE', + entity=Entity(id=os.getenv('LINHUB_ENTITY')), + task=Task(id=os.getenv('LINHUB_TASK')) + ) -data_dir = os.path.join(os.path.dirname(__file__), '../../data') + def search(self, text: str, context=None) -> list[Entry]: + """Search for an entry by text and rank the results. + Currently returns all entries that contain the text in either the kanji + or reading. 
+ """ + return self.ranker.rank(self.retriever.retrieve(text), context)[0] -class RankingModel(ABC): - """Base class for the ranking models""" + def feeling_lucky(self, text: str, context=None) -> Entry: + """Return the first entry found. - @abstractmethod - def _rank(self, candidates: list[Entry], context: Any): - """Rank results based on the given context. + Currently returns the first entry that contains the text in either the + kanji or reading. Parameters ---------- - candidates : List[Entry] - A list of entries to rank - context : any - The context to use for ranking + text : str + The text to search for Returns ------- - List[Entry] - The ranked list of entries + Entry + The first entry that contains the query. """ - raise NotImplementedError - - -class JMDict(RankingModel): - """A simple dictionary interface for JMDict.""" - - entries = None - indexed = False - - def __init__( - self, - dictionary: str = 'JMdict_en.gz', - annotator: Annotator = None - ): - self._index(os.path.join(data_dir, dictionary)) - self.annotator = annotator or Annotator( - id=uuid.uuid3(uuid.NAMESPACE_URL, 'jmdict-v3').hex, - name='jmdict-v3', - model='MACHINE', - entity=Entity(id=os.getenv('LINHUB_ENTITY')), - task=Task(id=os.getenv('LINHUB_TASK')) - ) - - def _index(self, filename): - """Create an index to speed up lookups.""" - if self.indexed: - return - self.index = defaultdict(set) - self.entries = JMDictParser.parse(filename) - for entry in self.entries: - self.index[entry.ent_seq] = entry - for k_ele in entry.k_ele: - self.index[k_ele.keb].add(entry) - for r_ele in entry.r_ele: - self.index[r_ele.reb].add(entry) - self.indexed = True + entries = self.search(text, context) + return entries[0] if entries else None def annotate( self, @@ -113,7 +92,8 @@ def annotate( if body is not None: a = Annotation( document=doc, - body=json.dumps(body, ensure_ascii=False).encode('utf-8'), + body=json.dumps( + body, ensure_ascii=False).encode('utf-8'), start=start, end=start + 
len(token.surface), annotator=self.annotator, @@ -126,7 +106,7 @@ def annotate( return documents[0] return documents - def tokenize(self, sentence) -> list[Token]: + def tokenize(self, sentence: str) -> list[Token]: """Tokenize a sentence. Parameters ---------- @@ -150,7 +130,34 @@ def tokenize(self, sentence) -> list[Token]: ) return tokens - def predict(self, sentences: list[str]) -> list[str]: + def fit(self, X: list[list[str]], y: list[list[str]]): + """Fit the ranker. + + Parameters + ---------- + X : list[list[Token]] + The tokenized documents to 'featurize'. + y : list[list[str]] + The list of list of labels for each tokens in the X sentences. + + Returns + ------- + self + The fitted dictionary. + """ + XX = [] + for doc in X: + xx = [] + for token in doc: + candidates = [] + for candidate in self.retriever.retrieve(token.lemma): + candidates.append(Candidate(candidate, token)) + xx.append(candidates) + XX.append(xx) + self.ranker.fit(XX, y) + return self + + def predict(self, sentences: list[str | Token]) -> list[str]: """Predict the `ent_seq` for each token in a sentence. 
Parameters @@ -168,7 +175,9 @@ def predict(self, sentences: list[str]) -> list[str]: preds = [] for sentence in tqdm.tqdm(sentences): pred = [] - for token in self.tokenize(sentence): + if isinstance(sentence, str): + sentence = self.tokenize(sentence) + for token in sentence: context = {'sentence': sentence, 'token': token} entry = self.feeling_lucky(token.lemma, context) ent_seq = entry.ent_seq if entry else None @@ -176,6 +185,29 @@ def predict(self, sentences: list[str]) -> list[str]: preds.append(pred) return preds + +class JMDict(Dictionary): + """A simple dictionary interface for JMDict.""" + + def __init__(self, retriever='local', ranker='dummy', file=None, *args, **kwargs): + if retriever == 'local': + retriever = LocalRetriever(file) + elif isinstance(retriever, Retriever): + pass + else: + raise ValueError(f"Invalid retriever: {retriever}") + if ranker == 'dummy': + ranker = DummyRanker() + elif ranker == 'pointwise': + ranker = PointWiseRanker() + elif ranker == 'gemini': + ranker = GeminiRanker() + elif isinstance(ranker, Ranker): + ranker.tokenize = self.tokenize + else: + raise ValueError(f"Invalid ranker: {ranker}") + super().__init__(retriever, ranker, *args, **kwargs) + def get(self, ent_seq: str) -> Entry: """Get an entry by its `ent_seq`. @@ -189,87 +221,22 @@ def get(self, ent_seq: str) -> Entry: Entry The entry with the given `ent_seq`. """ - for entry in self.entries: + for entry in self.retriever.entries: if entry.ent_seq == ent_seq: return entry return None - def _rank(self, candidates, context=None) -> tuple[list[Entry], list[float]]: - """A base ranking function that does nothing. 
+ def save(self, path: str): + """Save the dictionary to a file.""" + joblib.dump({'retriever': self.retriever, 'ranker': self.ranker}, path) - Parameters - ---------- - candidates: List[Entry] - The candidates to rank - context : Any - A contet to inform the ranking - - Returns - ------- - candidates: List[Entry] - The ranked candidates - scores: List[float] - The score of each candidate - """ - if len(candidates) < 1: - return [], [] - return candidates, [1] * len(candidates) - - def _lookup(self, text) -> list[Entry]: - """Lookup an entry by text. - - Currently returns all entries that contain the text in either the kanji - or reading. - - Parameters - ---------- - text : str - The text to search for - - Returns - ------- - List[Entry] - A list of entries that contain the query. - """ - return list(self.index[text]) - - def search(self, text: str, context=None) -> list[Entry]: - """Search for an entry by text and rank the results. - - Currently returns all entries that contain the text in either the kanji - or reading. - - Parameters - ---------- - text : str - The text to search for - - Returns - ------- - List[Entry] - A list of entries that contain the query. - """ - res, _ = self._rank(self._lookup(text), context) - return res - - def feeling_lucky(self, text: str, context=None) -> Entry: - """Return the first entry found. - - Currently returns the first entry that contains the text in either the - kanji or reading. - - Parameters - ---------- - text : str - The text to search for - - Returns - ------- - Entry - The first entry that contains the query. 
- """ - entries = self.search(text, context) - return entries[0] if entries else None + @classmethod + def load(cls, path: str): + """Load the dictionary from a file.""" + o = joblib.load(path) + return cls(retriever=o['retriever'], ranker=o['ranker']) -__all__ = ['JMDict', 'Token', 'RankingModel'] +__all__ = [ + 'Dictionary', 'DummyRanker', 'JMDict', 'Ranker', 'Retriever', 'Token' +] diff --git a/wsd/models/gemini.py b/wsd/models/gemini.py index d5626ea..06bf552 100644 --- a/wsd/models/gemini.py +++ b/wsd/models/gemini.py @@ -5,8 +5,7 @@ from google import genai from google.genai import types -from wsd.models.baseline import JMDict, Token -from wsd.parsers import Entry +from wsd.parsers import Entry, Token SYSTEM = """You are a Japanese dictionary ranking system. When given several candidate definitions for a Japanese word in the context of @@ -95,25 +94,4 @@ def generate(prompt, model_name): return json.loads(response.text) -class JMDictGeminiRanking(JMDict): - """A dictionary using Google's Gemini to rank candidate definitions.""" - - def __init__(self, model_name="gemini-2.5-pro-exp-03-25", **kwargs): - super().__init__(**kwargs) - self.model_name = model_name - - # pylint: disable=signature-differs - def _rank(self, candidates: list[Entry], context): - if len(candidates) < 1: - return [], [] - prompt = get_prompt(context['sentence'], context['token'], candidates) - res = generate(prompt, model_name=self.model_name) - if 'answer' in res: - ans = max(0, min(res['answer'], len(candidates) - 1)) - top = candidates.pop(ans) - candidates.insert(0, top) - scores = [1] + [0] * (len(candidates) - 1) - return candidates, scores - - -__all__ = ['JMDictGeminiRanking'] +__all__ = ['GeminiRanker'] diff --git a/wsd/models/basic.py b/wsd/models/rankers.py similarity index 51% rename from wsd/models/basic.py rename to wsd/models/rankers.py index a70af0a..cef0320 100644 --- a/wsd/models/basic.py +++ b/wsd/models/rankers.py @@ -1,13 +1,57 @@ -# pylint: disable=invalid-name -"""A 
basic dictionary with ranking based on a binary classifier.""" +"""A collection of rankers for WSD.""" + +from abc import ABC +from dataclasses import dataclass + +import joblib from sklearn.feature_extraction import DictVectorizer from sklearn.linear_model import LogisticRegression -from wsd.models.baseline import JMDict, Token -from wsd.parsers import Entry +from wsd.parsers import Entry, Token +from wsd.models.gemini import get_prompt, generate + + +class Ranker(ABC): + """Abstract base class for candidate rankers.""" + + def rank(self, candidates, context=None) -> tuple[list[Entry], list[float]]: + """A base ranking function that does nothing. + + Parameters + ---------- + candidates: List[Entry] + The candidates to rank + context : Any + A context to inform the ranking + + Returns + ------- + candidates: List[Entry] + The ranked candidates + scores: List[float] + The score of each candidate + """ + if len(candidates) < 1: + return [], [] + return candidates, [1] * len(candidates) + + +class DummyRanker(Ranker): + """A no-op ranker that keeps candidates in their original order.""" + + def rank(self, candidates, context=None) -> tuple[list[Entry], list[float]]: + """A base ranking function that does nothing.""" + return candidates, [1] * len(candidates) + + +@dataclass +class Candidate: + """A candidate for a point-wise ranking.""" + entry: Entry + token: Token -class JMDictWithPointWiseRanking(JMDict): +class PointWiseRanker(Ranker): """A dictionary with the ranking function based on Binary Classification.""" def __init__(self, ranking_model=None, **kwargs): @@ -17,7 +61,7 @@ def __init__(self, ranking_model=None, **kwargs): if self.model is None: self.model = LogisticRegression() - def _preprocess(self, X: list[str], y: list[str]): + def _preprocess(self, X: list[list[Candidate]], y: list[list[str]]): """Create features for each candidate In the PointWise Binary Classification, the preprocessing just creates
Parameters ---------- - X : list[str] - A list of sentences to tokenize and 'featurize'. - y : list[str] - The list of labels for each tokens in the X sentences. + X : list[list[Candidate]] + A list of tokenized sentences to 'featurize'. + y : list[list[str]] + The list of list of labels for each tokens in the X sentences. Returns ------- @@ -41,13 +85,12 @@ def _preprocess(self, X: list[str], y: list[str]): """ flat_X, flat_y = [], [] for doc, labels in zip(X, y): - tokens = self.tokenize(doc) - for token, label in zip(tokens, labels): - candidates = self._lookup(token.lemma) + for candidates, label in zip(doc, labels): for candidate in candidates: - feat = self._create_features(candidate, token) + feat = self._create_features( + candidate.entry, candidate.token) flat_X.append(feat) - flat_y.append(label == candidate.ent_seq) + flat_y.append(label == candidate.entry.ent_seq) return flat_X, flat_y def fit(self, X: list[list[Token]], y: list[list[str]]): @@ -62,7 +105,7 @@ def fit(self, X: list[list[Token]], y: list[list[str]]): return self # pylint: disable=signature-differs - def _rank(self, candidates: list[Entry], context): + def rank(self, candidates: list[Entry], context): """A basic ranking function using the score of the binary classifier. 
Parameters @@ -104,6 +147,39 @@ def _create_features(self, candidate, token): features['reb.text'] = token.text in reb features['reb.lemma'] = token.lemma in reb return features + + def save(self, path: str): + """Save the ranker to a file.""" + joblib.dump({'vec': self.vec, 'model': self.model}, path) + + @classmethod + def load(cls, path: str): + """Load the ranker from a file.""" + d = joblib.load(path) + o = cls(ranking_model=d['model']) + o.vec = d['vec'] + return o + + +class GeminiRanker(Ranker): + """A dictionary using Google's Gemini to rank candidate definitions.""" + + def __init__(self, model_name="gemini-2.5-pro-exp-03-25", **kwargs): + super().__init__(**kwargs) + self.model_name = model_name + + # pylint: disable=signature-differs + def rank(self, candidates: list[Entry], context): + if len(candidates) < 1: + return [], [] + prompt = get_prompt(context['sentence'], context['token'], candidates) + res = generate(prompt, model_name=self.model_name) + if 'answer' in res: + ans = max(0, min(res['answer'], len(candidates) - 1)) + top = candidates.pop(ans) + candidates.insert(0, top) + scores = [1] + [0] * (len(candidates) - 1) + return candidates, scores -__all__ = ['JMDictWithPointWiseRanking'] +__all__ = ['DummyRanker', 'Ranker', 'PointWiseRanker', 'GeminiRanker'] diff --git a/wsd/models/retrievers.py b/wsd/models/retrievers.py new file mode 100644 index 0000000..4e1b913 --- /dev/null +++ b/wsd/models/retrievers.py @@ -0,0 +1,87 @@ +"""A collection of retrievers for WSD.""" +from collections import defaultdict +from abc import ABC + +from wsd.parsers import Entry, JMDictParser + + +class Retriever(ABC): + """Base class for the searcher models""" + + def retrieve(self, text: str) -> list[Entry]: + """Search for an entry by text and rank the results. + + Currently returns all entries that contain the text in either the kanji + or reading. 
+ + Parameters + ---------- + text : str + The text to search for + + Returns + ------- + List[Entry] + A list of entries that contain the query. + """ + raise NotImplementedError("Subclasses must implement this method") + +class LocalRetriever(Retriever): + entries = None + indexed = False + + def __init__(self, file=None, **kwargs): + self.file = file + self._index(self.file) + + def _index(self, filename): + """Create an index to speed up lookups.""" + if self.indexed: + return + self.index = defaultdict(set) + self.entries = JMDictParser.parse(filename) + for entry in self.entries: + self.index[entry.ent_seq] = entry + for k_ele in entry.k_ele: + self.index[k_ele.keb].add(entry) + for r_ele in entry.r_ele: + self.index[r_ele.reb].add(entry) + self.indexed = True + + def retrieve(self, text) -> list[Entry]: + """Lookup an entry by text. + + Currently returns all entries that contain the text in either the kanji + or reading. + + Parameters + ---------- + text : str + The text to search for + + Returns + ------- + List[Entry] + A list of entries that contain the query. + """ + return list(self.index[text]) + + def get(self, ent_seq: str) -> Entry: + """Get an entry by its `ent_seq`. + + Parameters + ---------- + ent_seq : str + The `ent_seq` of the entry to get + + Returns + ------- + Entry + The entry with the given `ent_seq`. 
""" + for entry in self.entries: + if entry.ent_seq == ent_seq: + return entry + return None + +__all__ = ['LocalRetriever', 'Retriever'] \ No newline at end of file diff --git a/wsd/models/tests/test_baseline.py b/wsd/models/tests/test_baseline.py index e69de29..c1ef93d 100644 --- a/wsd/models/tests/test_baseline.py +++ b/wsd/models/tests/test_baseline.py @@ -0,0 +1,23 @@ +import unittest + +from wsd.models import JMDict + + +class TestJMDict(unittest.TestCase): + """Test the JMDict model.""" + + def setUp(self): + self.jmdict = JMDict() + + def test_jmdict(self): + entries = self.jmdict.search('日本語') + self.assertEqual(len(entries), 1) + self.assertEqual(entries[0].ent_seq, '1464530') + + def test_no_entry_found(self): + entries = self.jmdict.search('qwefasdfasg') + self.assertEqual(len(entries), 0) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/wsd/parsers/__init__.py b/wsd/parsers/__init__.py index df57f93..6a08f97 100644 --- a/wsd/parsers/__init__.py +++ b/wsd/parsers/__init__.py @@ -1,7 +1,3 @@ """A collection of parsers for various data sources.""" -# from .jmdict import * -# from .xlwsd import * -from .jmdict import Entry, JMDictParser -from .xlwsd import XLWSDParser - -__all__=["Entry", "JMDictParser", "XLWSDParser"] +from .jmdict import * +from .xlwsd import * diff --git a/wsd/parsers/jmdict.py b/wsd/parsers/jmdict.py index 2ef7ff4..0c32e9a 100644 --- a/wsd/parsers/jmdict.py +++ b/wsd/parsers/jmdict.py @@ -1,6 +1,7 @@ # pylint: disable=not-callable """A parser for the JMdict dictionary.""" import gzip +import os import xml.etree.ElementTree as ET from dataclasses import dataclass, field @@ -139,6 +140,9 @@ class JMDictParser: @classmethod def parse(cls, file_path): """Parse a JMdict file.""" + if file_path is None: + file_path = os.path.join( + os.path.dirname(__file__), '../../data/JMdict_en.gz') entries = [] with gzip.open(file_path, "rb") as f: # pylint: disable=invalid-name tree = ET.parse(f) @@
-148,4 +152,14 @@ def parse(cls, file_path): return entries -__all__ = ['JMDictParser', 'Entry', 'Kanji', 'Reading', 'Gloss', 'Sense'] +@dataclass +class Token: + """Token dataclass.""" + text: str = '' + lemma: str = '' + pos: str = '' + + +__all__ = [ + 'JMDictParser', 'Entry', 'Kanji', 'Reading', 'Gloss', 'Sense', 'Token' +] diff --git a/wsd/utils.py b/wsd/utils.py index cbbabb2..b50af3c 100644 --- a/wsd/utils.py +++ b/wsd/utils.py @@ -7,6 +7,8 @@ from linalgo.annotate import Filter, Pipeline, Sequence2SequenceTransformer from linalgo.hub import BQClient +from wsd.parsers.jmdict import Token + tagger = Tagger('-Owakati') @@ -72,10 +74,16 @@ def load_dataset(filename: str) -> tuple[list[list[str]], list[list[str]]]: X, y = [], [] for doc_element in root.findall("document"): - doc_tokens: list[str] = [] + doc_tokens: list[Token] = [] doc_labels: list[str] = [] for token_element in doc_element.findall("token"): - doc_tokens.append(token_element.text) + doc_tokens.append( + Token( + text=token_element.text, + lemma=token_element.get("lemma"), + pos=token_element.get("pos") + ) + ) ent_seq = token_element.get("ent_seq") if ent_seq == '': ent_seq = None