-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path data.py
More file actions
35 lines (27 loc) · 1.3 KB
/
data.py
File metadata and controls
35 lines (27 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import spacy
from torchtext.datasets import Multi30k
from torchtext.data import Field
def prepare_data(min_freq=2):
    """Load and preprocess the Multi30k German->English translation dataset.

    Builds torchtext ``Field`` objects with spaCy tokenizers for both
    languages, downloads/loads the Multi30k splits, and constructs the
    vocabularies from the training split only.

    Args:
        min_freq: minimum number of occurrences in the training data for a
            token to enter the vocabulary (rarer tokens map to ``<unk>``).
            Defaults to 2, preserving the original behavior.

    Returns:
        Tuple ``(train_data, valid_data, test_data, src_lang, trg_lang)``:
        the three dataset splits plus the source (German) and target
        (English) ``Field`` objects with built vocabularies.

    Raises:
        OSError: if the spaCy models below are not installed.
    """
    # If you get "Can't find model 'en_core_web_sm'", install the models:
    #   $ python3 -m spacy download en_core_web_sm
    #   $ python3 -m spacy download de_core_news_sm
    # (The original hint said to download 'en', which would NOT provide
    # 'en_core_web_sm' and left the error in place.)
    spacy_en = spacy.load('en_core_web_sm')
    spacy_de = spacy.load('de_core_news_sm')

    # Define the tokenizers AFTER loading the models so the closures never
    # capture an unbound name (the original relied on late binding).
    def tokenize_de(text):
        """Tokenize German text from a string into a list of strings."""
        return [tok.text for tok in spacy_de.tokenizer(text)]

    def tokenize_en(text):
        """Tokenize English text from a string into a list of strings."""
        return [tok.text for tok in spacy_en.tokenizer(text)]

    # batch_first=True -> tensors come out as (batch, seq_len); everything
    # is lower-cased and wrapped in <sos>/<eos> markers.
    src_lang = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>',
                     lower=True, batch_first=True)
    trg_lang = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>',
                     lower=True, batch_first=True)

    # exts order must line up with fields: German is the source language.
    # NOTE: this downloads the dataset on first use.
    train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'),
                                                        fields=(src_lang, trg_lang))

    # Vocabularies are built from the training split only to avoid leakage.
    src_lang.build_vocab(train_data, min_freq=min_freq)
    trg_lang.build_vocab(train_data, min_freq=min_freq)

    return train_data, valid_data, test_data, src_lang, trg_lang