diff --git a/README.md b/README.md index b8688c8..8822420 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The following scripts show how to tokenize/index/collate your dataset with `coll ```python import collatable -from collatable import Instance, LabelField, MetadataField, TextField +from collatable import LabelField, MetadataField, TextField from collatable.extras.indexer import LabelIndexer, TokenIndexer dataset = [ @@ -53,7 +53,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True): ) metadata_field = MetadataField({"id": id_}) # Combine these fields into instance - instance = Instance( + instance = dict( text=text_field, label=label_field, metadata=metadata_field, @@ -85,7 +85,7 @@ Execution result: ```python import collatable -from collatable import Instance, SequenceLabelField, TextField +from collatable import SequenceLabelField, TextField from collatable.extras.indexer import LabelIndexer, TokenIndexer dataset = [ @@ -104,7 +104,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True): for tokens, labels in dataset: text_field = TextField(tokens, indexer=token_indexer, padding_value=token_indexer[PAD_TOKEN]) label_field = SequenceLabelField(labels, text_field, indexer=label_indexer) - instance = Instance(text=text_field, label=label_field) + instance = dict(text=text_field, label=label_field) instances.append(instance) output = collatable.collate(instances) @@ -128,7 +128,7 @@ Execution result: ```python import collatable from collatable.extras.indexer import LabelIndexer, TokenIndexer -from collatable import AdjacencyField, Instance, ListField, SpanField, TextField +from collatable import AdjacencyField, ListField, SpanField, TextField PAD_TOKEN = "" token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,)) @@ -143,7 +143,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True): ) spans = ListField([SpanField(0, 2, text), SpanField(5, 7, text), SpanField(11, 12, text)]) relations = AdjacencyField([(0, 1), (0, 2)], spans, labels=["born-in", "lives-in"], indexer=label_indexer) - instance = Instance(text=text, spans=spans, relations=relations) + instance = dict(text=text, spans=spans, relations=relations) instances.append(instance) text = TextField( @@ -153,7 +153,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True): ) spans = ListField([SpanField(0, 1, text), SpanField(5, 6, text)]) relations = AdjacencyField([(0, 1)], spans, labels=["capital-of"], indexer=label_indexer) - instance = Instance(text=text, spans=spans, relations=relations) + instance = dict(text=text, spans=spans, relations=relations) instances.append(instance) output = collatable.collate(instances) diff --git a/collatable/__init__.py b/collatable/__init__.py index 0225b8a..b7b0c6f 100644 --- a/collatable/__init__.py +++ b/collatable/__init__.py @@ -15,7 +15,6 @@ TensorField, TextField, ) -from collatable.instance import Instance # noqa: F401 __version__ = version("collatable") __all__ = [ @@ -30,6 +29,5 @@ "SpanField", "TensorField", "TextField", - "Instance", "collate", ] diff --git a/collatable/collator.py b/collatable/collator.py index 3834827..df32bff 100644 --- a/collatable/collator.py +++ b/collatable/collator.py @@ -1,12 +1,38 @@ -from typing import Dict, Sequence +from typing import Any, Dict, Mapping, Optional, Sequence, Set -from collatable.instance import Instance -from collatable.typing import DataArray +from collatable.fields import Field +from collatable.typing import DataArray, INamedTuple class Collator: - def __call__(self, instances: Sequence["Instance"]) -> Dict[str, DataArray]: - keys = set(instances[0]) + def __init__(self, field_names: Optional[Set[str]] = None) -> None: + self._field_names = field_names + + def _extract_fields(self, instance: Any) -> Mapping[str, Field]: + if not isinstance(instance, Mapping): + if hasattr(instance, "__dict__"): + members = instance.__dict__ + slots = set( + getattr( + instance, + "__slots__", + [key for key in members if not key.startswith("_") or key in (self._field_names or [])], + ) + ) + if self._field_names is not None and not (self._field_names <= slots): + raise ValueError(f"Field names {self._field_names - slots} not found") + instance = {slot: members[slot] for slot in slots if slot in members} + elif isinstance(instance, INamedTuple): + instance = instance._asdict() + return { + key: value + for key, value in instance.items() + if isinstance(value, Field) and (self._field_names is None or key in self._field_names) + } + + def __call__(self, instances: Sequence[Any]) -> Dict[str, DataArray]: + instances = [self._extract_fields(instance) for instance in instances] + keys = set(next(iter(instances), {}).keys()) array: Dict[str, DataArray] = {} for key in keys: values = [instance[key] for instance in instances] @@ -14,5 +40,8 @@ def __call__(self, instances: Sequence["Instance"]) -> Dict[str, DataArray]: return array -def collate(instances: Sequence[Instance]) -> Dict[str, DataArray]: - return Collator()(instances) +def collate( + instances: Sequence[Any], + field_names: Optional[Set[str]] = None, +) -> Dict[str, DataArray]: + return Collator(field_names)(instances) diff --git a/collatable/extras/dataloader.py b/collatable/extras/dataloader.py index 6b1bcca..ee04737 100644 --- a/collatable/extras/dataloader.py +++ b/collatable/extras/dataloader.py @@ -1,26 +1,27 @@ import math import random -from typing import Dict, Iterator, Sequence +from typing import Dict, Iterator, Mapping, Optional, Sequence from collatable.collator import Collator -from collatable.instance import Instance +from collatable.fields import Field from collatable.typing import DataArray class BatchIterator: def __init__( self, - dataset: Sequence[Instance], + dataset: Sequence[Mapping[str, Field]], batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, + collator: Optional[Collator] = None, ) -> None: self._dataset = dataset self._batch_size = batch_size self._shuffle = shuffle self._drop_last = drop_last self._offset = 0 - self._collator = Collator() + self._collator = collator or Collator() self._indices = list(range(len(self._dataset))) if self._shuffle: random.shuffle(self._indices) @@ -48,20 +49,18 @@ def __iter__(self) -> Iterator[Dict[str, DataArray]]: class DataLoader: def __init__( - self, - batch_size: int = 1, - shuffle: bool = False, - drop_last: bool = False, + self, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, collator: Optional[Collator] = None ) -> None: self._batch_size = batch_size self._shuffle = shuffle self._drop_last = drop_last - self._collator = Collator() + self._collator = collator or Collator() - def __call__(self, dataset: Sequence[Instance]) -> BatchIterator: + def __call__(self, dataset: Sequence[Mapping[str, Field]]) -> BatchIterator: return BatchIterator( dataset, batch_size=self._batch_size, shuffle=self._shuffle, drop_last=self._drop_last, + collator=self._collator, ) diff --git a/collatable/instance.py b/collatable/instance.py deleted file mode 100644 index f3ebab4..0000000 --- a/collatable/instance.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Iterator - -from collatable.fields import Field - - -class Instance: - def __init__(self, **fields: Field) -> None: - self._fields = fields - - def __len__(self) -> int: - return len(self._fields) - - def __iter__(self) -> Iterator[str]: - return iter(self._fields) - - def __getitem__(self, name: str) -> Field: - return self._fields[name] diff --git a/collatable/typing.py b/collatable/typing.py index 95c568a..50469b5 100644 --- a/collatable/typing.py +++ b/collatable/typing.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Sequence, TypeVar, Union +import dataclasses +from typing import Any, ClassVar, Dict, Mapping, Protocol, Sequence, Tuple, TypeVar, Union, runtime_checkable import numpy from numpy.typing import ArrayLike, NDArray # noqa: F401 @@ -15,3 +16,21 @@ ScalarT_co = TypeVar("ScalarT_co", bound=Scalar, covariant=True) TensorT_co = TypeVar("TensorT_co", bound=Tensor, covariant=True) DataArrayT_co = TypeVar("DataArrayT_co", bound=DataArray, covariant=True) + + +@runtime_checkable +class IDataclass(Protocol): + __dataclass_fields__: ClassVar[Dict[str, dataclasses.Field]] + + +@runtime_checkable +class INamedTuple(Protocol): + _fields: ClassVar[Tuple[str, ...]] + + def _asdict(self) -> Dict[str, Any]: ... + + def _replace(self: "NamedTupleT", **kwargs: Any) -> "NamedTupleT": ... + + +DataclassT = TypeVar("DataclassT", bound=IDataclass) +NamedTupleT = TypeVar("NamedTupleT", bound=INamedTuple) diff --git a/tests/extras/test_dataloader.py b/tests/extras/test_dataloader.py index 2b18119..03a7ba9 100644 --- a/tests/extras/test_dataloader.py +++ b/tests/extras/test_dataloader.py @@ -1,6 +1,6 @@ from typing import Iterator -from collatable import Instance, LabelField, MetadataField, TextField +from collatable import LabelField, MetadataField, TextField from collatable.extras.dataloader import DataLoader from collatable.extras.dataset import Dataset from collatable.extras.indexer import LabelIndexer, TokenIndexer @@ -12,7 +12,7 @@ def test_dataloader() -> None: token_indexer = TokenIndexer[str](specials=[PAD_TOKEN, UNK_TOKEN], default=UNK_TOKEN) label_indexer = LabelIndexer[str]() - def read_dataset() -> Iterator[Instance]: + def read_dataset() -> Iterator[dict]: dataset = [ ("this is awesome", "positive"), ("this is a bad movie", "negative"), @@ -34,7 +34,7 @@ def read_dataset() -> Iterator[Instance]: ) metadata_field = MetadataField({"id": id_}) # Combine these fields into instance - instance = Instance( + instance = dict( text=text_field, label=label_field, metadata=metadata_field, diff --git a/tests/fields/test_adjacency_field.py b/tests/fields/test_adjacency_field.py index b89f888..1477f51 100644 --- a/tests/fields/test_adjacency_field.py +++ b/tests/fields/test_adjacency_field.py @@ -6,7 +6,6 @@ from collatable.fields.list_field import ListField from collatable.fields.span_field import SpanField from collatable.fields.text_field import TextField -from collatable.instance import Instance def test_adajacency_field() -> None: @@ -23,7 +22,7 @@ def test_adajacency_field() -> None: ) spans = ListField([SpanField(0, 2, text), SpanField(5, 7, text), SpanField(11, 12, text)]) relations = AdjacencyField([(0, 1), (0, 2)], spans, labels=["born-in", "lives-in"], indexer=label_indexer) - instance = Instance(text=text, spans=spans, relations=relations) + instance = dict(text=text, spans=spans, relations=relations) instances.append(instance) text = TextField( @@ -33,7 +32,7 @@ def test_adajacency_field() -> None: ) spans = ListField([SpanField(0, 1, text), SpanField(5, 6, text)]) relations = AdjacencyField([(0, 1)], spans, labels=["capital-of"], indexer=label_indexer) - instance = Instance(text=text, spans=spans, relations=relations) + instance = dict(text=text, spans=spans, relations=relations) instances.append(instance) output = collate(instances)["relations"] diff --git a/tests/test_instance.py b/tests/test_collator.py similarity index 92% rename from tests/test_instance.py rename to tests/test_collator.py index d33d29f..f6d38ab 100644 --- a/tests/test_instance.py +++ b/tests/test_collator.py @@ -5,7 +5,6 @@ from collatable.collator import Collator from collatable.extras.indexer import LabelIndexer, TokenIndexer from collatable.fields import LabelField, MetadataField, TextField -from collatable.instance import Instance def test_instance() -> None: @@ -19,11 +18,11 @@ def test_instance() -> None: token_indexer = TokenIndexer[str]() label_indexer = LabelIndexer[str]() - instances: List[Instance] = [] + instances: List[dict] = [] with token_indexer.context(train=True), label_indexer.context(train=True): for id_, (text, label) in enumerate(dataset): tokens = text.split() - instance = Instance( + instance = dict( text=TextField(tokens, indexer=token_indexer), label=LabelField(label, indexer=label_indexer), metadata=MetadataField({"id": id_}),