-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdocument_loader.py
More file actions
42 lines (40 loc) · 2 KB
/
document_loader.py
File metadata and controls
42 lines (40 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import pandas as pd
from tqdm import tqdm
from dataprocessor import DataProcessor
from document import Document
from typing import List
from tqdm import tqdm
class DocumentLoader:
    """Turn the rows of a pandas DataFrame into ``Document`` objects.

    Every row yields one ``Document`` built from the row index, the value of
    the category column and a content dict assembled from the configured
    columns (optionally extended with a ``"tokens"`` entry).
    """

    def __init__(self,
                 columns_to_add: dict,
                 tokenizer: "DataProcessor | None" = None,
                 columns_to_tokenize: "list | None" = None,
                 category_column: str = "category"):
        """
        Parameters
        ----------
        columns_to_add: dict
            a dictionary mapping a dataframe column name to the key under
            which its value is stored in the content (possible rename)
        tokenizer: DataProcessor
            a DataProcessor instance to use for tokenization
        columns_to_tokenize: list
            a list of column names to tokenize (have to be string values);
            omitted or None means "no columns"
        category_column: str
            the column name of the category

        Raises
        ------
        ValueError
            if only one of ``tokenizer`` / ``columns_to_tokenize`` is given
        """
        # Build a fresh list per instance: a mutable default (or the caller's
        # own list) must never be shared between loaders.
        if columns_to_tokenize is None:
            columns_to_tokenize = []
        else:
            columns_to_tokenize = list(columns_to_tokenize)

        if tokenizer is not None and not columns_to_tokenize:
            raise ValueError("If tokenizer is provided, columns_to_tokenize must be provided")
        if tokenizer is None and columns_to_tokenize:
            raise ValueError("The user provided columns_to_tokenize but no tokenizer")

        self.columns_to_add = columns_to_add
        self.tokenizer = tokenizer
        self.category_column = category_column
        self.columns_to_tokenize = columns_to_tokenize

    def load_documents(self, dataframe: pd.DataFrame, tokenize: bool = False):
        """Lazily yield one ``Document`` per row of ``dataframe``.

        Parameters
        ----------
        dataframe: pd.DataFrame
            the data to convert; must contain the category column and every
            column named in ``columns_to_add`` / ``columns_to_tokenize``
        tokenize: bool
            when True, the values of ``columns_to_tokenize`` are joined with
            newlines, passed to ``tokenizer.preprocess_text`` and the result
            is stored under the ``"tokens"`` key of the content

        Yields
        ------
        Document
            ``Document(row index, category value, content dict)``

        Raises
        ------
        AssertionError
            if a required column is missing from ``dataframe``
        ValueError
            if ``tokenize`` is True but the loader has no tokenizer

        Notes
        -----
        This is a generator, so the checks above only run once the first
        item is requested.
        """
        # AssertionError is kept for backward compatibility, but it is raised
        # explicitly so the validation is not stripped under ``python -O``.
        if self.category_column not in dataframe.columns:
            raise AssertionError("The correct category column must be provided")
        required_columns = list(self.columns_to_tokenize) + list(self.columns_to_add)
        if not all(col in dataframe.columns for col in required_columns):
            raise AssertionError("The correct columns to include and tokenize must be provided")
        # Fail early with a clear message instead of an AttributeError on None
        # in the middle of the iteration.
        if tokenize and self.tokenizer is None:
            raise ValueError("tokenize=True requires a tokenizer and columns_to_tokenize")

        # iterrows() has no length, so give tqdm the total explicitly to get a
        # percentage / ETA in the progress bar.
        rows = tqdm(dataframe.iterrows(), total=len(dataframe), desc="Loading documents")
        for idx, row in rows:
            content = {save_col: row[col] for col, save_col in self.columns_to_add.items()}
            if tokenize:
                text = "\n".join(row[self.columns_to_tokenize])
                content["tokens"] = self.tokenizer.preprocess_text(text)
            yield Document(idx, row[self.category_column], content)