Skip to content
19 changes: 19 additions & 0 deletions configs/train_drone.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Train the GEN-SCL-NAT model (t5-base backbone) on the binary drone ACOS
# dataset, then evaluate the final training state directly on the test split
# (--do_direct_eval). Outputs (checkpoint, tokenizer, args.json, results JSON)
# are written under train_outputs/ with the "drone_output" prefix.
# NOTE(review): --cont_loss / --cont_temp are the supervised-contrastive loss
# weight and temperature; presumably tuned for this dataset — confirm against
# the argparse defaults in source/gen_scl_nat_main.py before changing.
python3 source/gen_scl_nat_main.py \
--task gen_scl_nat \
--do_train \
--do_direct_eval \
--model_name_or_path t5-base \
--dataset acos_drone_binary \
--output_folder train_outputs \
--n_gpu 1 \
--train_batch_size 32 \
--eval_batch_size 32 \
--learning_rate 9e-5 \
--gradient_accumulation_steps 1 \
--num_train_epochs 45 \
--num_beams 5 \
--weight_decay 0.0 \
--seed 42 \
--cont_loss 0.05 \
--cont_temp 0.25 \
--model_prefix drone_output
18 changes: 18 additions & 0 deletions configs/train_drone_asqp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Train the ASQP-task variant on the drone ACOS dataset and evaluate the final
# state on the test split. Outputs go under train_drone_asqp/ with the
# "drone_asqp" prefix.
# NOTE(review): unlike train_drone.sh, this script does NOT pass
# --model_name_or_path, so it relies on the argparse default in
# source/gen_scl_nat_main.py. The LinearModel input size is selected from that
# value (t5-small/t5-base/other), so verify the default backbone is the one
# intended here before training.
python3 source/gen_scl_nat_main.py \
--task asqp \
--do_train \
--do_direct_eval \
--dataset acos_drone_data \
--output_folder train_drone_asqp \
--n_gpu 1 \
--train_batch_size 16 \
--eval_batch_size 16 \
--learning_rate 9e-5 \
--gradient_accumulation_steps 1 \
--num_train_epochs 45 \
--num_beams 5 \
--weight_decay 0.0 \
--seed 42 \
--cont_loss 0.05 \
--cont_temp 0.25 \
--model_prefix drone_asqp
116 changes: 116 additions & 0 deletions data/acos_drone_binary/dev.txt

Large diffs are not rendered by default.

116 changes: 116 additions & 0 deletions data/acos_drone_binary/test.txt

Large diffs are not rendered by default.

462 changes: 462 additions & 0 deletions data/acos_drone_binary/train.txt

Large diffs are not rendered by default.

116 changes: 116 additions & 0 deletions data/acos_drone_data/dev.txt

Large diffs are not rendered by default.

116 changes: 116 additions & 0 deletions data/acos_drone_data/test.txt

Large diffs are not rendered by default.

462 changes: 462 additions & 0 deletions data/acos_drone_data/train.txt

Large diffs are not rendered by default.

238 changes: 124 additions & 114 deletions source/gen_scl_nat_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def get_dataset(tokenizer, type_path, args):

"""
Uncomment for tsne logging
"""
tsne_dict = {
'sentiment_vecs': [],
'opinion_vecs': [],
Expand All @@ -140,15 +141,19 @@ def get_dataset(tokenizer, type_path, args):
'opinion_labels': [],
'aspect_labels': []
}
"""

class LinearModel(nn.Module):
"""
Linear models used for the aspect/opinion/sentiment-specific representations
"""
def __init__(self):
def __init__(self, model_path):
super().__init__()
self.layer_1 = nn.Linear(1024, 1024)
if model_path == 't5-small':
self.layer_1 = nn.Linear(512, 1024)
elif model_path == 't5-base':
self.layer_1 = nn.Linear(768, 1024)
else:
self.layer_1 = nn.Linear(1024, 1024)
self.dropout = nn.Dropout(0.1)

def forward(self, x, attention_mask):
Expand Down Expand Up @@ -242,6 +247,7 @@ def _step(self, batch):

"""
Uncomment this section to extract the tsne encodings/labels used for Figure 2 in paper
"""

# Use these for generating the 'w/ SCL' figures
sentiment_encs = cont_normed.detach().numpy()[:,0].tolist()
Expand All @@ -251,6 +257,7 @@ def _step(self, batch):
aspect_labs = aspect_labels.detach().tolist()
opinion_labs = opinion_labels.detach().tolist()

"""
# Use these for the version without SCL (no characteristic-specific representations)

sentiment_encs = pooled_encoder_layer.detach().numpy().tolist()
Expand All @@ -260,13 +267,13 @@ def _step(self, batch):
aspect_labs = aspect_labels.detach().tolist()
opinion_labs = opinion_labels.detach().tolist()

"""
tsne_dict['sentiment_vecs'] += sentiment_encs
tsne_dict['aspect_vecs'] += aspect_encs
tsne_dict['opinion_vecs'] += opinion_encs
tsne_dict['sentiment_labels'] += sentiment_labs
tsne_dict['aspect_labels'] += aspect_labs
tsne_dict['opinion_labels'] += opinion_labs
"""

# return original loss plus the characteristic-specific SCL losses
loss = outputs[0] + opinion_contrastive_loss + sentiment_contrastive_loss + aspect_contrastive_loss
Expand Down Expand Up @@ -425,115 +432,118 @@ def evaluate(data_loader, model, sents, task):
json.dump(results, open(f"{args.output_dir}/results-{args.dataset}.json", 'w'), indent=2, sort_keys=True)
return scores

# initialization
args = init_args()
seed_everything(args.seed, workers=True)

tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)
tokenizer.add_tokens(['[SSEP]'])




# Get example from the train set
dataset = GenSCLNatDataset(tokenizer=tokenizer, data_dir=args.dataset,
data_type='train', max_len=args.max_seq_length, task=args.task, truncate=args.truncate)
data_sample = dataset[0]

# sanity check
# show one sample to check the code and the expected output format are correct
print(f"Here is an example (from the train set):")
print('Input :', tokenizer.decode(data_sample['source_ids'], skip_special_tokens=True))
print(data_sample['source_ids'])
print('Output:', tokenizer.decode(data_sample['target_ids'], skip_special_tokens=True))
print(data_sample['target_ids'])

# training process
if args.do_train:
print("\n****** Conducting Training ******")

# initialize the T5 model
tfm_model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path)
tfm_model.resize_token_embeddings(len(tokenizer))
# initialize characteristic-specific representation models
cont_model = LinearModel()
op_model = LinearModel()
as_model = LinearModel()
cat_model = LinearModel()
model = T5FineTuner(args, tfm_model, tokenizer, cont_model, op_model, as_model, cat_model)

if args.early_stopping:
checkpoint_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(
dirpath=args.output_dir, monitor='val_loss', mode='min', save_top_k=1

# check for top-level environment
if __name__ == '__main__':
# initialization
args = init_args()
seed_everything(args.seed, workers=True)

tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)
tokenizer.add_tokens(['[SSEP]'])




# Get example from the train set
dataset = GenSCLNatDataset(tokenizer=tokenizer, data_dir=args.dataset,
data_type='train', max_len=args.max_seq_length, task=args.task, truncate=args.truncate)
data_sample = dataset[0]

# sanity check
# show one sample to check the code and the expected output format are correct
print(f"Here is an example (from the train set):")
print('Input :', tokenizer.decode(data_sample['source_ids'], skip_special_tokens=True))
print(data_sample['source_ids'])
print('Output:', tokenizer.decode(data_sample['target_ids'], skip_special_tokens=True))
print(data_sample['target_ids'])

# training process
if args.do_train:
print("\n****** Conducting Training ******")

# initialize the T5 model
tfm_model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path)
tfm_model.resize_token_embeddings(len(tokenizer))
# initialize characteristic-specific representation models
cont_model = LinearModel(args.model_name_or_path)
op_model = LinearModel(args.model_name_or_path)
as_model = LinearModel(args.model_name_or_path)
cat_model = LinearModel(args.model_name_or_path)
model = T5FineTuner(args, tfm_model, tokenizer, cont_model, op_model, as_model, cat_model)

if args.early_stopping:
checkpoint_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(
dirpath=args.output_dir, monitor='val_loss', mode='min', save_top_k=1
)
callback_list = [checkpoint_callback, LoggingCallback(), EarlyStopping(monitor="val_loss", mode='min', patience=3)]
else:
callback_list = [LoggingCallback()]

# prepare trainer args
train_params = dict(
default_root_dir=args.output_dir,
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.n_gpu,
gradient_clip_val=1.0,
max_epochs=args.num_train_epochs,
auto_lr_find=False,
deterministic=True,
#auto_scale_batch_size=True,
#callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode='min'), LoggingCallback()],
callbacks=callback_list
)
callback_list = [checkpoint_callback, LoggingCallback(), EarlyStopping(monitor="val_loss", mode='min', patience=3)]
else:
callback_list = [LoggingCallback()]

# prepare trainer args
train_params = dict(
default_root_dir=args.output_dir,
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.n_gpu,
gradient_clip_val=1.0,
max_epochs=args.num_train_epochs,
auto_lr_find=False,
deterministic=True,
#auto_scale_batch_size=True,
#callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode='min'), LoggingCallback()],
callbacks=callback_list
)
trainer = pl.Trainer(**train_params)
trainer.fit(model)

if args.early_stopping:
ex_weights = torch.load(checkpoint_callback.best_model_path)['state_dict']
model.load_state_dict(ex_weights)

model.model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
with open(os.path.join(args.output_dir, 'args.json'), 'w') as f:
json.dump(args.__dict__, f, indent=2)

print("Finish training and saving the model!")

# evaluation
if args.do_direct_eval:
print("\n****** Conduct Evaluating with the last state ******")

sents, _ = read_line_examples_from_file(f'data/{args.dataset}/test.txt')

print()
test_dataset = GenSCLNatDataset(tokenizer, data_dir=args.dataset,
data_type='test', max_len=args.max_seq_length, task=args.task, truncate=args.truncate)
test_loader = DataLoader(test_dataset, args.eval_batch_size, num_workers=4)

# compute the performance scores
evaluate(test_loader, model, test_dataset.sentence_strings, args.task)

if args.do_inference:
print("\n****** Conduct inference on trained checkpoint ******")

# initialize the T5 model from previous checkpoint
model_path = args.model_name_or_path
print(f"Loading trained model from {model_path}")
tokenizer = T5Tokenizer.from_pretrained(model_path)
tfm_model = T5ForConditionalGeneration.from_pretrained(model_path)

# representations are only used during loss calculation
cont_model = LinearModel()
op_model = LinearModel()
as_model = LinearModel()
cat_model = LinearModel()
model = T5FineTuner(args, tfm_model, tokenizer, cont_model, op_model, as_model, cat_model)

sents, _ = read_line_examples_from_file(f'data/{args.dataset}/test.txt')

print()
test_dataset = GenSCLNatDataset(tokenizer, data_dir=args.dataset,
data_type='test', max_len=args.max_seq_length, task=args.task, truncate=args.truncate)
test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, num_workers=4)

# compute the performance scores
evaluate(test_loader, model, test_dataset.sentence_strings, args.task)
trainer = pl.Trainer(**train_params)
trainer.fit(model)

if args.early_stopping:
ex_weights = torch.load(checkpoint_callback.best_model_path)['state_dict']
model.load_state_dict(ex_weights)

model.model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
with open(os.path.join(args.output_dir, 'args.json'), 'w') as f:
json.dump(args.__dict__, f, indent=2)

print("Finish training and saving the model!")

# evaluation
if args.do_direct_eval:
print("\n****** Conduct Evaluating with the last state ******")

sents, _ = read_line_examples_from_file(f'data/{args.dataset}/test.txt')

print()
test_dataset = GenSCLNatDataset(tokenizer, data_dir=args.dataset,
data_type='test', max_len=args.max_seq_length, task=args.task, truncate=args.truncate)
test_loader = DataLoader(test_dataset, args.eval_batch_size, num_workers=4)

# compute the performance scores
evaluate(test_loader, model, test_dataset.sentence_strings, args.task)

if args.do_inference:
print("\n****** Conduct inference on trained checkpoint ******")

# initialize the T5 model from previous checkpoint
model_path = args.model_name_or_path
print(f"Loading trained model from {model_path}")
tokenizer = T5Tokenizer.from_pretrained(model_path)
tfm_model = T5ForConditionalGeneration.from_pretrained(model_path)

# representations are only used during loss calculation
cont_model = LinearModel(args.model_name_or_path)
op_model = LinearModel(args.model_name_or_path)
as_model = LinearModel(args.model_name_or_path)
cat_model = LinearModel(args.model_name_or_path)
model = T5FineTuner(args, tfm_model, tokenizer, cont_model, op_model, as_model, cat_model)

sents, _ = read_line_examples_from_file(f'data/{args.dataset}/test.txt')

print()
test_dataset = GenSCLNatDataset(tokenizer, data_dir=args.dataset,
data_type='test', max_len=args.max_seq_length, task=args.task, truncate=args.truncate)
test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, num_workers=4)

# compute the performance scores
evaluate(test_loader, model, test_dataset.sentence_strings, args.task)

Loading