-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
62 lines (53 loc) · 1.85 KB
/
inference.py
File metadata and controls
62 lines (53 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import torch
import pandas as pd
import re
from transformers import AutoTokenizer
from model import SOTAModel
def heavy_clean(text):
    """Aggressively normalize a raw message string for tokenization.

    Lowercases, strips HTML tags, squashes characters repeated 3+ times
    down to a single occurrence (e.g. "soooo" -> "so"), replaces every
    character that is not word/whitespace/'?'/'!' with a space, and
    collapses runs of whitespace. Non-string inputs are coerced via str().
    """
    # Ordered substitution rules; order matters (tags must go before the
    # punctuation pass, whitespace collapse runs last).
    rules = (
        (r'<.*?>', ''),          # drop HTML/XML tags
        (r'(.)\1{2,}', r'\1'),   # collapse 3+ repeated chars to one
        (r'[^\w\s\?\!]', ' '),   # non word/space/?/! chars become spaces
        (r'\s+', ' '),           # squeeze whitespace runs
    )
    cleaned = str(text).lower()
    for pattern, replacement in rules:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def run_inference(model_path, input_csv, output_csv, model_name):
    """Classify every row of a CSV as EXTREMIST / NON_EXTREMIST.

    Reads `input_csv` (must contain an 'Original_Message' column), cleans
    each message with heavy_clean, runs the fine-tuned SOTAModel checkpoint
    at `model_path`, and writes the input frame plus 'cleaned' and
    'Prediction' columns to `output_csv`.

    Args:
        model_path: path to a saved state_dict checkpoint (torch.save).
        input_csv:  input CSV with an 'Original_Message' column.
        output_csv: destination CSV path (overwritten, no index column).
        model_name: HF hub id used for both the tokenizer and SOTAModel.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    df = pd.read_csv(input_csv)
    df['cleaned'] = df['Original_Message'].apply(heavy_clean)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = SOTAModel(model_name)
    # weights_only=True: the checkpoint is a plain state_dict, so restrict
    # unpickling to tensors — torch.load on an untrusted file can otherwise
    # execute arbitrary code (fixed default only from torch 2.6 onward).
    model.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=True)
    )
    model.to(device)
    model.eval()

    preds = []
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for text in df['cleaned']:
            enc = tokenizer(
                text,
                max_length=192,          # must match training-time length
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            # NOTE(review): SOTAModel is assumed to return raw class logits
            # of shape (1, num_classes) — confirm against model.py.
            logits = model(
                enc['input_ids'].to(device),
                enc['attention_mask'].to(device)
            )
            preds.append(torch.argmax(logits, dim=1).item())

    label_map = {0: 'NON_EXTREMIST', 1: 'EXTREMIST'}
    df['Prediction'] = [label_map[p] for p in preds]
    df.to_csv(output_csv, index=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", required=True)
parser.add_argument("--input_csv", required=True)
parser.add_argument("--output_csv", required=True)
parser.add_argument("--model_name", default="microsoft/deberta-v3-large")
args = parser.parse_args()
run_inference(
args.model_path,
args.input_csv,
args.output_csv,
args.model_name
)