Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The following scripts show how to tokenize/index/collate your dataset with `coll

```python
import collatable
from collatable import Instance, LabelField, MetadataField, TextField
from collatable import LabelField, MetadataField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
Expand Down Expand Up @@ -53,7 +53,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True):
)
metadata_field = MetadataField({"id": id_})
# Combine these fields into instance
instance = Instance(
instance = dict(
text=text_field,
label=label_field,
metadata=metadata_field,
Expand Down Expand Up @@ -85,7 +85,7 @@ Execution result:

```python
import collatable
from collatable import Instance, SequenceLabelField, TextField
from collatable import SequenceLabelField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
Expand All @@ -104,7 +104,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True):
for tokens, labels in dataset:
text_field = TextField(tokens, indexer=token_indexer, padding_value=token_indexer[PAD_TOKEN])
label_field = SequenceLabelField(labels, text_field, indexer=label_indexer)
instance = Instance(text=text_field, label=label_field)
instance = dict(text=text_field, label=label_field)
instances.append(instance)

output = collatable.collate(instances)
Expand All @@ -128,7 +128,7 @@ Execution result:
```python
import collatable
from collatable.extras.indexer import LabelIndexer, TokenIndexer
from collatable import AdjacencyField, Instance, ListField, SpanField, TextField
from collatable import AdjacencyField, ListField, SpanField, TextField

PAD_TOKEN = "<PAD>"
token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,))
Expand All @@ -143,7 +143,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True):
)
spans = ListField([SpanField(0, 2, text), SpanField(5, 7, text), SpanField(11, 12, text)])
relations = AdjacencyField([(0, 1), (0, 2)], spans, labels=["born-in", "lives-in"], indexer=label_indexer)
instance = Instance(text=text, spans=spans, relations=relations)
instance = dict(text=text, spans=spans, relations=relations)
instances.append(instance)

text = TextField(
Expand All @@ -153,7 +153,7 @@ with token_indexer.context(train=True), label_indexer.context(train=True):
)
spans = ListField([SpanField(0, 1, text), SpanField(5, 6, text)])
relations = AdjacencyField([(0, 1)], spans, labels=["capital-of"], indexer=label_indexer)
instance = Instance(text=text, spans=spans, relations=relations)
instance = dict(text=text, spans=spans, relations=relations)
instances.append(instance)

output = collatable.collate(instances)
Expand Down
2 changes: 0 additions & 2 deletions collatable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
TensorField,
TextField,
)
from collatable.instance import Instance # noqa: F401

__version__ = version("collatable")
__all__ = [
Expand All @@ -30,6 +29,5 @@
"SpanField",
"TensorField",
"TextField",
"Instance",
"collate",
]
43 changes: 36 additions & 7 deletions collatable/collator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,47 @@
from typing import Dict, Sequence
from typing import Any, Dict, Mapping, Optional, Sequence, Set

from collatable.instance import Instance
from collatable.typing import DataArray
from collatable.fields import Field
from collatable.typing import DataArray, INamedTuple


class Collator:
def __call__(self, instances: Sequence["Instance"]) -> Dict[str, DataArray]:
keys = set(instances[0])
def __init__(self, field_names: Optional[Set[str]] = None) -> None:
self._field_names = field_names

def _extract_fields(self, instance: Any) -> Mapping[str, Field]:
if not isinstance(instance, Mapping):
if hasattr(instance, "__dict__"):
members = instance.__dict__
slots = set(
getattr(
instance,
"__slots__",
[key for key in members if not key.startswith("_") or key in (self._field_names or [])],
)
)
if self._field_names is not None and not (self._field_names <= slots):
raise ValueError(f"Field names {self._field_names - slots} not found")
instance = {slot: members[slot] for slot in slots if slot in members}
elif isinstance(instance, INamedTuple):
instance = instance._asdict()
return {
key: value
for key, value in instance.items()
if isinstance(value, Field) and (self._field_names is None or key in self._field_names)
}

def __call__(self, instances: Sequence[Any]) -> Dict[str, DataArray]:
instances = [self._extract_fields(instance) for instance in instances]
keys = set(next(iter(instances), {}).keys())
array: Dict[str, DataArray] = {}
for key in keys:
values = [instance[key] for instance in instances]
array[key] = values[0].collate(values)
return array


def collate(instances: Sequence[Instance]) -> Dict[str, DataArray]:
return Collator()(instances)
def collate(
instances: Sequence[Any],
field_names: Optional[Set[str]] = None,
) -> Dict[str, DataArray]:
return Collator(field_names)(instances)
19 changes: 9 additions & 10 deletions collatable/extras/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import math
import random
from typing import Dict, Iterator, Sequence
from typing import Dict, Iterator, Mapping, Optional, Sequence

from collatable.collator import Collator
from collatable.instance import Instance
from collatable.fields import Field
from collatable.typing import DataArray


class BatchIterator:
def __init__(
self,
dataset: Sequence[Instance],
dataset: Sequence[Mapping[str, Field]],
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
collator: Optional[Collator] = None,
) -> None:
self._dataset = dataset
self._batch_size = batch_size
self._shuffle = shuffle
self._drop_last = drop_last
self._offset = 0
self._collator = Collator()
self._collator = collator or Collator()
self._indices = list(range(len(self._dataset)))
if self._shuffle:
random.shuffle(self._indices)
Expand Down Expand Up @@ -48,20 +49,18 @@ def __iter__(self) -> Iterator[Dict[str, DataArray]]:

class DataLoader:
def __init__(
self,
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
self, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, collator: Optional[Collator] = None
) -> None:
self._batch_size = batch_size
self._shuffle = shuffle
self._drop_last = drop_last
self._collator = Collator()
self._collator = collator or Collator()

def __call__(self, dataset: Sequence[Instance]) -> BatchIterator:
def __call__(self, dataset: Sequence[Mapping[str, Field]]) -> BatchIterator:
return BatchIterator(
dataset,
batch_size=self._batch_size,
shuffle=self._shuffle,
drop_last=self._drop_last,
collator=self._collator,
)
17 changes: 0 additions & 17 deletions collatable/instance.py

This file was deleted.

21 changes: 20 additions & 1 deletion collatable/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Mapping, Sequence, TypeVar, Union
import dataclasses
from typing import Any, ClassVar, Dict, Mapping, Protocol, Sequence, Tuple, TypeVar, Union, runtime_checkable

import numpy
from numpy.typing import ArrayLike, NDArray # noqa: F401
Expand All @@ -15,3 +16,21 @@
ScalarT_co = TypeVar("ScalarT_co", bound=Scalar, covariant=True)
TensorT_co = TypeVar("TensorT_co", bound=Tensor, covariant=True)
DataArrayT_co = TypeVar("DataArrayT_co", bound=DataArray, covariant=True)


@runtime_checkable
class IDataclass(Protocol):
__dataclass_fields__: ClassVar[Dict[str, dataclasses.Field]]


@runtime_checkable
class INamedTuple(Protocol):
_fields: ClassVar[Tuple[str, ...]]

def _asdict(self) -> Dict[str, Any]: ...

def _replace(self: "NamedTupleT", **kwargs: Any) -> "NamedTupleT": ...


DataclassT = TypeVar("DataclassT", bound=IDataclass)
NamedTupleT = TypeVar("NamedTupleT", bound=INamedTuple)
6 changes: 3 additions & 3 deletions tests/extras/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Iterator

from collatable import Instance, LabelField, MetadataField, TextField
from collatable import LabelField, MetadataField, TextField
from collatable.extras.dataloader import DataLoader
from collatable.extras.dataset import Dataset
from collatable.extras.indexer import LabelIndexer, TokenIndexer
Expand All @@ -12,7 +12,7 @@ def test_dataloader() -> None:
token_indexer = TokenIndexer[str](specials=[PAD_TOKEN, UNK_TOKEN], default=UNK_TOKEN)
label_indexer = LabelIndexer[str]()

def read_dataset() -> Iterator[Instance]:
def read_dataset() -> Iterator[dict]:
dataset = [
("this is awesome", "positive"),
("this is a bad movie", "negative"),
Expand All @@ -34,7 +34,7 @@ def read_dataset() -> Iterator[Instance]:
)
metadata_field = MetadataField({"id": id_})
# Combine these fields into instance
instance = Instance(
instance = dict(
text=text_field,
label=label_field,
metadata=metadata_field,
Expand Down
5 changes: 2 additions & 3 deletions tests/fields/test_adjacency_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from collatable.fields.list_field import ListField
from collatable.fields.span_field import SpanField
from collatable.fields.text_field import TextField
from collatable.instance import Instance


def test_adajacency_field() -> None:
Expand All @@ -23,7 +22,7 @@ def test_adajacency_field() -> None:
)
spans = ListField([SpanField(0, 2, text), SpanField(5, 7, text), SpanField(11, 12, text)])
relations = AdjacencyField([(0, 1), (0, 2)], spans, labels=["born-in", "lives-in"], indexer=label_indexer)
instance = Instance(text=text, spans=spans, relations=relations)
instance = dict(text=text, spans=spans, relations=relations)
instances.append(instance)

text = TextField(
Expand All @@ -33,7 +32,7 @@ def test_adajacency_field() -> None:
)
spans = ListField([SpanField(0, 1, text), SpanField(5, 6, text)])
relations = AdjacencyField([(0, 1)], spans, labels=["capital-of"], indexer=label_indexer)
instance = Instance(text=text, spans=spans, relations=relations)
instance = dict(text=text, spans=spans, relations=relations)
instances.append(instance)

output = collate(instances)["relations"]
Expand Down
5 changes: 2 additions & 3 deletions tests/test_instance.py → tests/test_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collatable.collator import Collator
from collatable.extras.indexer import LabelIndexer, TokenIndexer
from collatable.fields import LabelField, MetadataField, TextField
from collatable.instance import Instance


def test_instance() -> None:
Expand All @@ -19,11 +18,11 @@ def test_instance() -> None:
token_indexer = TokenIndexer[str]()
label_indexer = LabelIndexer[str]()

instances: List[Instance] = []
instances: List[dict] = []
with token_indexer.context(train=True), label_indexer.context(train=True):
for id_, (text, label) in enumerate(dataset):
tokens = text.split()
instance = Instance(
instance = dict(
text=TextField(tokens, indexer=token_indexer),
label=LabelField(label, indexer=label_indexer),
metadata=MetadataField({"id": id_}),
Expand Down