Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
83bfea7
Updated README.md
kzvdar42 May 5, 2020
53b6329
Update README.md
kzvdar42 May 5, 2020
ed73205
Fixed error in dataset
kzvdar42 May 5, 2020
cd9126b
Updated start message.
kzvdar42 May 5, 2020
47e6694
Suppressed info from the `transformers` logger.
kzvdar42 May 5, 2020
838d3ed
Added message for the `\help` command.
kzvdar42 May 5, 2020
7edde12
Small refactoring.
kzvdar42 May 5, 2020
b9e071d
More refactoring and small bug fix.
kzvdar42 May 5, 2020
c89c95d
Restructured the code and improved logging.
kzvdar42 May 5, 2020
a654734
Code style fixed.
kzvdar42 May 5, 2020
3085eab
Updated README.md
kzvdar42 May 5, 2020
b4ea43f
Removed `@syncronized` for methods which don't call the models.
kzvdar42 May 8, 2020
4274130
Removed redundant word joke in arguments.
kzvdar42 May 9, 2020
f684de4
Moved generator arguments to config field and fixed bugs.
kzvdar42 May 9, 2020
14a5fd2
Added Russian model training & usage.
kzvdar42 May 9, 2020
7205173
Fixed bold formatting.
kzvdar42 May 9, 2020
53a42ba
Finalized training code.
kzvdar42 May 11, 2020
b3cef85
Merge pull request #5 from kzvdar42/rus_master
kzvdar42 May 11, 2020
4f54c50
Removed old train code.
kzvdar42 May 11, 2020
3acd171
Fixed the error with the `_fill_buffer` function.
kzvdar42 May 11, 2020
e6ed8a6
Fixed error in `_prettify_result`. Fix #2
kzvdar42 May 11, 2020
9f19664
Update README.md
kzvdar42 May 11, 2020
22b2b2a
Update requirements.txt
kzvdar42 May 11, 2020
3616dec
Added images and log example to README.
kzvdar42 May 11, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ dmypy.json

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
Expand All @@ -229,8 +229,11 @@ model
train/runs
train/output*
train/models

train/ru_gpt2
# prep files
data/prep
# jokes database
jokes.db
!jokes.db
*.db
# log files
*.log
363 changes: 109 additions & 254 deletions README.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions bot/bot.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ token = ...
ab_test = false
[model]
model_paths = model
rus_model_path = rus_model
dataset_paths = data/qa_jokes.csv
max_joke_len = 40
buffer_size = 16
Expand Down
38 changes: 38 additions & 0 deletions bot/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import logging
import pandas as pd


class Dataset:
    """Wrapper for the DataFrame to return values similar
    to `AbstractJokeGenerator` output.
    """

    def __init__(self, dataset_path, promt_token, answer_token):
        # Special tokens used to format each joke entry.
        # NOTE: `promt_token` spelling is kept for caller compatibility.
        self.promt_token = promt_token
        self.answer_token = answer_token
        # Dataset is named after its file (basename of the path).
        self.name = os.path.split(dataset_path)[1]
        self.data = pd.read_csv(dataset_path)
        self.logger = logging.getLogger("DS: " + self.name)

    def __getitem__(self, idx):
        """Return the joke at row `idx` formatted with the special tokens."""
        row = self.data.iloc[idx]
        question = row['Question'].strip()
        answer = row['Answer'].strip()
        joke_text = '{}{}\n{} {}'.format(
            self.promt_token, question, self.answer_token, answer)
        self.logger.info('Got joke from dataset')
        return {
            'text': joke_text,
            'generated_by': self.name,
        }

    def __len__(self):
        """Number of jokes in the underlying DataFrame."""
        return len(self.data)


class Joke:
    """An interface class for a Joke.

    A plain holder pairing a joke's text with the identifier
    it is tracked by.
    """

    def __init__(self, text, id):
        # `id` shadows the builtin, but the parameter name is kept
        # for backward compatibility with existing callers.
        self.id = id
        self.text = text

    def __repr__(self):
        # Debug-friendly representation; truncate long joke texts.
        return 'Joke(id={!r}, text={!r})'.format(self.id, self.text[:60])
39 changes: 25 additions & 14 deletions bot/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import torch
from torch import LongTensor
import logging

from transformers import (
CTRLLMHeadModel,
Expand All @@ -14,48 +16,57 @@
XLNetLMHeadModel,
XLNetTokenizer,
)
from yt_encoder import YTEncoder

MODEL_CLASSES = {
"gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
"gpt2-yttm": (GPT2LMHeadModel, YTEncoder),
"ctrl": (CTRLLMHeadModel, CTRLTokenizer),
"openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
"xlnet": (XLNetLMHeadModel, XLNetTokenizer),
"transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Don't show warnings.
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR)


class ModelWrapper:
def __init__(self, model_path, model_name,
device='cpu',
model_type='gpt2',
max_length=40,
max_len=40,
temperature=0.9,
num_return_sequences=1,
n_return_sequences=1,
repetition_penalty=1.0,
k=50,
p=0.95,):
self.num_return_sequences = num_return_sequences
p=0.95,
**kwargs):
self.n_return_sequences = n_return_sequences
self.repetition_penalty = repetition_penalty
self.k = k
self.p = p
self.temperature = temperature
self.device = device = torch.device(device)
self.max_length = max_length
self.device = torch.device(device)
self.max_length = max_len
model_class, tokenizer_class = MODEL_CLASSES[model_type]
self.tokenizer = tokenizer_class.from_pretrained(model_path)
self.model = model_class.from_pretrained(model_path)
self.name = model_name
self.model.to(device)
self.logger = logging.getLogger('ML: ' + self.name)
self.model.to(self.device)

def __encode(self, text):
encoded_prompt = self.tokenizer.encode(
text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = self.tokenizer.encode(text, add_special_tokens=False)
encoded_prompt = LongTensor(encoded_prompt).unsqueeze(0)
return encoded_prompt.to(self.device)

def generate(self, beginning, num_return_sequences=None):
if num_return_sequences is None:
num_return_sequences = self.num_return_sequences
def generate(self, beginning, n_return_sequences=None):
if n_return_sequences is None:
n_return_sequences = self.n_return_sequences
encoded_prompt = self.__encode(beginning)
output_sequences = self.model.generate(
input_ids=encoded_prompt,
Expand All @@ -65,7 +76,7 @@ def generate(self, beginning, num_return_sequences=None):
top_p=self.p,
repetition_penalty=self.repetition_penalty,
do_sample=True,
num_return_sequences=num_return_sequences,
num_return_sequences=n_return_sequences,
)

# Remove the batch dimension when returning multiple sequences
Expand All @@ -86,7 +97,7 @@ def generate(self, beginning, num_return_sequences=None):
# dt = datetime.datetime.now() - start
# print("\t", dt)
# Batch processing test
res = m.generate("[QUESTION] ", num_return_sequences=4)
res = m.generate("[QUESTION] ", n_return_sequences=4)
for j in res:
print(j)

Expand Down
4 changes: 0 additions & 4 deletions bot/joke.py

This file was deleted.

Loading