Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setuptools.setup(
name="simplet5",
version="0.1.4",
version="0.1.5a",
license="apache-2.0",
author="Shivanand Roy",
author_email="shivanandroy.official@gmail.com",
Expand Down Expand Up @@ -44,7 +44,7 @@
"sentencepiece",
"torch>=1.7.0,!=1.8.0", # excludes torch v1.8.0
"transformers==4.16.2",
"pytorch-lightning==1.5.10",
"pytorch-lightning==2.0.1",
],
classifiers=[
"Intended Audience :: Developers",
Expand Down
62 changes: 40 additions & 22 deletions simplet5/simplet5.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def __init__(
save_only_last_epoch (bool, optional): If True, save just the last epoch else models are saved for every epoch
"""
super().__init__()
self.training_step_outputs = []
self.validation_step_outputs = []

self.model = model
self.tokenizer = tokenizer
self.outputdir = outputdir
Expand Down Expand Up @@ -213,9 +216,9 @@ def training_step(self, batch, batch_size):
decoder_attention_mask=labels_attention_mask,
labels=labels,
)

self.training_step_outputs.append(loss)
self.log(
"train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
"train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True
)
return loss

Expand All @@ -232,9 +235,10 @@ def validation_step(self, batch, batch_size):
decoder_attention_mask=labels_attention_mask,
labels=labels,
)
self.validation_step_outputs.append(loss)

self.log(
"val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
"val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True
)
return loss

Expand All @@ -252,41 +256,38 @@ def test_step(self, batch, batch_size):
labels=labels,
)

self.log("test_loss", loss, prog_bar=True, logger=True)
self.log("test_loss", loss, prog_bar=True, logger=True, sync_dist=True)
return loss

def configure_optimizers(self):
""" configure optimizers """
return AdamW(self.parameters(), lr=0.0001)

def training_epoch_end(self, training_step_outputs):
def on_training_epoch_end(self, training_step_outputs):
""" save tokenizer and model on epoch end """
self.average_training_loss = np.round(
torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
4,
)
self.average_training_loss = torch.stack(self.training_step_outputs).mean()
path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(self.average_training_loss)}-val-loss-{str(self.average_validation_loss)}"
if self.save_only_last_epoch:
if self.current_epoch == self.trainer.max_epochs - 1:
if self.current_epoch == self.max_epochs - 1:
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)
else:
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)
self.training_step_outputs.clear()

def validation_epoch_end(self, validation_step_outputs):
_loss = [x.cpu() for x in validation_step_outputs]
self.average_validation_loss = np.round(
torch.mean(torch.stack(_loss)).item(),
4,
)
def on_validation_epoch_end(self):
epoch_average = torch.stack(self.validation_step_outputs).mean()
self.log("validation_epoch_average", epoch_average, sync_dist=True)
self.validation_step_outputs.clear()


class SimpleT5:
""" Custom SimpleT5 class """

def __init__(self) -> None:
""" initiates SimpleT5 class """
self.device = 'cuda'
pass

def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
Expand All @@ -311,6 +312,7 @@ def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
self.model = T5ForConditionalGeneration.from_pretrained(
f"{model_name}", return_dict=True
)
self.model.to(self.device)

def train(
self,
Expand All @@ -320,13 +322,14 @@ def train(
target_max_token_len: int = 512,
batch_size: int = 8,
max_epochs: int = 5,
use_gpu: bool = True,
use_gpu: int = 1,
outputdir: str = "outputs",
early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
precision=32,
logger="default",
dataloader_num_workers: int = 2,
save_only_last_epoch: bool = False,
strategy: str='ddp'
):
"""
trains T5/MT5 model on custom dataset
Expand All @@ -337,7 +340,7 @@ def train(
target_max_token_len (int, optional): max token length of target text. Defaults to 512.
batch_size (int, optional): batch size. Defaults to 8.
max_epochs (int, optional): max number of epochs. Defaults to 5.
use_gpu (bool, optional): if True, model uses gpu for training. Defaults to True.
use_gpu (int, optional): number of gpus to use. Defaults to 1.
outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
early_stopping_patience_epochs (int, optional): monitors val_loss on epoch end and stops training, if val_loss does not improve after the specied number of epochs. set 0 to disable early stopping. Defaults to 0 (disabled)
precision (int, optional): sets precision training - Double precision (64), full precision (32) or half precision (16). Defaults to 32.
Expand Down Expand Up @@ -376,19 +379,26 @@ def train(
callbacks.append(early_stop_callback)

# add gpu support
gpus = 1 if use_gpu else 0
gpus = use_gpu

# add logger
loggers = True if logger == "default" else logger

# prepare trainer
# prepare





trainer = pl.Trainer(
logger=loggers,
callbacks=callbacks,
max_epochs=max_epochs,
gpus=gpus,
accelerator="gpu",
devices=gpus,
precision=precision,
log_every_n_steps=1,
strategy=strategy
)

# fit trainer
Expand Down Expand Up @@ -438,6 +448,7 @@ def predict(
early_stopping: bool = True,
skip_special_tokens: bool = True,
clean_up_tokenization_spaces: bool = True,
use_gpu=True
):
"""
generates prediction for T5/MT5 model
Expand All @@ -457,6 +468,13 @@ def predict(
Returns:
list[str]: returns predictions
"""
if use_gpu:
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
raise "exception ---> no gpu found. set use_gpu=False, to use CPU"
else:
self.device = torch.device("cpu")
input_ids = self.tokenizer.encode(
source_text, return_tensors="pt", add_special_tokens=True
)
Expand All @@ -480,4 +498,4 @@ def predict(
)
for g in generated_ids
]
return preds
return preds