# -*- coding: utf-8 -*-
# file: classifier_instructor.py
# time: 2021/4/22 0022
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import os
import random
import shutil
import time
import numpy as np
import pandas
import pytorch_warmup as warmup
import torch
import torch.nn as nn
from findfile import find_file
from sklearn import metrics
from torch import cuda
from torch.utils.data import (
DataLoader,
random_split,
ConcatDataset,
RandomSampler,
SequentialSampler,
)
from tqdm import tqdm
from transformers import AutoModel
from pyabsa.framework.flag_class.flag_template import DeviceTypeOption
from pyabsa.framework.instructor_class.instructor_template import BaseTrainingInstructor
from pyabsa.framework.tokenizer_class.tokenizer_class import (
PretrainedTokenizer,
Tokenizer,
build_embedding_matrix,
)
from pyabsa.utils.file_utils.file_utils import save_model
from pyabsa.utils.pyabsa_utils import init_optimizer, fprint
from ..dataset_utils.__classic__.data_utils_for_training import GloVeTADDataset
from ..dataset_utils.__plm__.data_utils_for_training import BERTTADDataset
from ..models import BERTTADModelList, GloVeTADModelList
[docs]
class TADTrainingInstructor(BaseTrainingInstructor):
[docs]
def _init_misc(self):
random.seed(self.config.seed)
np.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
torch.cuda.manual_seed(self.config.seed)
self.config.inputs_cols = self.model.inputs
self.config.device = torch.device(self.config.device)
# use DataParallel for trainer if device count larger than 1
if self.config.auto_device == DeviceTypeOption.ALL_CUDA:
self.model.to(self.config.device)
self.model = torch.nn.parallel.DataParallel(self.model).module
else:
self.model.to(self.config.device)
self.optimizer = init_optimizer(self.config.optimizer)(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.l2reg,
)
self.train_dataloaders = []
self.valid_dataloaders = []
if os.path.exists("./init_state_dict.bin"):
os.remove("./init_state_dict.bin")
if self.config.cross_validate_fold > 0:
torch.save(self.model.state_dict(), "./init_state_dict.bin")
self.config.device = torch.device(self.config.device)
if self.config.device.type == DeviceTypeOption.CUDA:
self.config.logger.info(
"cuda memory allocated:{}".format(
torch.cuda.memory_allocated(device=self.config.device)
)
)
[docs]
def _cache_or_load_dataset(self):
pass
[docs]
def _load_dataset_and_prepare_dataloader(self):
cache_path = self.load_cache_dataset()
# init BERT-based model and dataset
if hasattr(BERTTADModelList, self.config.model.__name__):
self.tokenizer = PretrainedTokenizer(self.config)
if not os.path.exists(cache_path) or self.config.overwrite_cache:
self.train_set = BERTTADDataset(
self.config, self.tokenizer, dataset_type="train"
)
self.test_set = BERTTADDataset(
self.config, self.tokenizer, dataset_type="test"
)
self.valid_set = BERTTADDataset(
self.config, self.tokenizer, dataset_type="valid"
)
try:
self.bert = AutoModel.from_pretrained(self.config.pretrained_bert)
except ValueError as e:
fprint("Init pretrained model failed, exception: {}".format(e))
# init the model behind the construction of datasets in case of updating output_dim
self.model = self.config.model(self.bert, self.config).to(
self.config.device
)
elif hasattr(GloVeTADModelList, self.config.model.__name__):
# init GloVe-based model and dataset
self.tokenizer = Tokenizer.build_tokenizer(
config=self.config,
cache_path="{0}_tokenizer.dat".format(
os.path.basename(self.config.dataset_name)
),
)
self.embedding_matrix = build_embedding_matrix(
config=self.config,
tokenizer=self.tokenizer,
cache_path="{0}_{1}_embedding_matrix.dat".format(
str(self.config.embed_dim),
os.path.basename(self.config.dataset_name),
),
)
self.train_set = GloVeTADDataset(
self.config, self.tokenizer, dataset_type="train"
)
self.test_set = GloVeTADDataset(
self.config, self.tokenizer, dataset_type="test"
)
self.valid_set = GloVeTADDataset(
self.config, self.tokenizer, dataset_type="valid"
)
self.model = self.config.model(self.embedding_matrix, self.config).to(
self.config.device
)
self.config.embedding_matrix = self.embedding_matrix
self.config.tokenizer = self.tokenizer
self.save_cache_dataset(cache_path)
def __init__(self, config):
super().__init__(config)
self._load_dataset_and_prepare_dataloader()
self._init_misc()
[docs]
def reload_model_state_dict(self, ckpt="./init_state_dict.bin"):
if os.path.exists(ckpt):
self.model.load_state_dict(
torch.load(find_file(ckpt, or_key=[".bin", "state_dict"])), strict=False
)
[docs]
def prepare_dataloader(self, train_set):
if self.config.cross_validate_fold < 1:
train_sampler = RandomSampler(
self.train_set if not self.train_set else self.train_set
)
self.train_dataloaders.append(
DataLoader(
dataset=train_set,
batch_size=self.config.batch_size,
sampler=train_sampler,
pin_memory=True,
)
)
if self.test_set:
self.test_dataloader = DataLoader(
dataset=self.test_set,
batch_size=self.config.batch_size,
shuffle=False,
)
if self.valid_set:
self.valid_dataloader = DataLoader(
dataset=self.valid_set,
batch_size=self.config.batch_size,
shuffle=False,
)
else:
split_dataset = train_set
len_per_fold = len(split_dataset) // self.config.cross_validate_fold + 1
folds = random_split(
split_dataset,
tuple(
[len_per_fold] * (self.config.cross_validate_fold - 1)
+ [
len(split_dataset)
- len_per_fold * (self.config.cross_validate_fold - 1)
]
),
)
for f_idx in range(self.config.cross_validate_fold):
train_set = ConcatDataset(
[x for i, x in enumerate(folds) if i != f_idx]
)
val_set = folds[f_idx]
train_sampler = RandomSampler(train_set if not train_set else train_set)
val_sampler = SequentialSampler(val_set if not val_set else val_set)
self.train_dataloaders.append(
DataLoader(
dataset=train_set,
batch_size=self.config.batch_size,
sampler=train_sampler,
)
)
self.valid_dataloaders.append(
DataLoader(
dataset=val_set,
batch_size=self.config.batch_size,
sampler=val_sampler,
)
)
if self.test_set:
self.test_dataloader = DataLoader(
dataset=self.test_set,
batch_size=self.config.batch_size,
shuffle=False,
)
[docs]
def _train(self, criterion):
self._prepare_dataloader()
if self.config.warmup_step >= 0:
self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
T_max=len(self.train_dataloaders[0]) * self.config.num_epoch,
)
self.warmup_scheduler = warmup.UntunedLinearWarmup(self.optimizer)
if len(self.valid_dataloaders) > 1:
return self._k_fold_train_and_evaluate(criterion)
else:
return self._train_and_evaluate(criterion)
[docs]
def _train_and_evaluate(self, criterion):
global_step = 0
max_label_fold_acc = 0
max_label_fold_f1 = 0
max_adv_det_fold_acc = 0
max_adv_det_fold_f1 = 0
max_adv_tr_fold_acc = 0
max_adv_tr_fold_f1 = 0
losses = []
save_path = "{0}/{1}_{2}".format(
self.config.model_path_to_save,
self.config.model_name,
self.config.dataset_name,
)
self.config.metrics_of_this_checkpoint = {"acc": 0, "f1": 0}
self.config.max_test_metrics = {
"max_cls_test_acc": 0,
"max_cls_test_f1": 0,
"max_adv_det_test_acc": 0,
"max_adv_det_test_f1": 0,
"max_adv_tr_test_acc": 0,
"max_adv_tr_test_f1": 0,
}
self.config.logger.info(
"***** Running trainer for Text Classification with Adversarial Attack Defense *****"
)
self.config.logger.info("Training set examples = %d", len(self.train_set))
if self.test_set:
self.config.logger.info("Test set examples = %d", len(self.test_set))
self.config.logger.info("Batch size = %d", self.config.batch_size)
self.config.logger.info(
"Num steps = %d",
len(self.train_dataloaders[0])
// self.config.batch_size
* self.config.num_epoch,
)
patience = self.config.patience + self.config.evaluate_begin
if self.config.log_step < 0:
self.config.log_step = (
len(self.train_dataloaders[0])
if self.config.log_step < 0
else self.config.log_step
)
for epoch in range(self.config.num_epoch):
patience -= 1
description = "Epoch:{} | Loss:{}".format(epoch, 0)
iterator = tqdm(self.train_dataloaders[0], desc=description)
for i_batch, sample_batched in enumerate(iterator):
global_step += 1
# switch model to train mode, clear gradient accumulators
self.model.train()
self.optimizer.zero_grad()
inputs = [
sample_batched[col].to(self.config.device)
for col in self.config.inputs_cols
]
with torch.cuda.amp.autocast():
outputs = self.model(inputs)
label_targets = sample_batched["label"].to(self.config.device)
adv_tr_targets = sample_batched["adv_train_label"].to(
self.config.device
)
adv_det_targets = sample_batched["is_adv"].to(self.config.device)
sen_logits, advdet_logits, adv_tr_logits = (
outputs["sent_logits"],
outputs["advdet_logits"],
outputs["adv_tr_logits"],
)
sen_loss = criterion(sen_logits, label_targets)
adv_det_loss = criterion(advdet_logits, adv_det_targets)
adv_train_loss = criterion(adv_tr_logits, adv_tr_targets)
loss = (
sen_loss
+ self.config.args.get("adv_det_weight", 5) * adv_det_loss
+ self.config.args.get("adv_train_weight", 5) * adv_train_loss
)
losses.append(loss.item())
if self.scaler:
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()
self.optimizer.step()
if self.config.warmup_step >= 0:
with self.warmup_scheduler.dampening():
self.lr_scheduler.step()
# evaluate if test set is available
if global_step % self.config.log_step == 0:
if self.test_dataloader and epoch >= self.config.evaluate_begin:
if self.valid_dataloader:
(
test_label_acc,
test_label_f1,
test_adv_det_acc,
test_adv_det_f1,
test_adv_tr_acc,
test_adv_tr_f1,
) = self._evaluate_acc_f1(self.valid_dataloader)
else:
(
test_label_acc,
test_label_f1,
test_adv_det_acc,
test_adv_det_f1,
test_adv_tr_acc,
test_adv_tr_f1,
) = self._evaluate_acc_f1(self.test_dataloader)
self.config.metrics_of_this_checkpoint[
"max_cls_test_acc"
] = test_label_acc
self.config.metrics_of_this_checkpoint[
"max_cls_test_f1"
] = test_label_f1
self.config.metrics_of_this_checkpoint[
"max_adv_det_test_acc"
] = test_adv_det_acc
self.config.metrics_of_this_checkpoint[
"max_adv_det_test_f1"
] = test_adv_det_f1
self.config.metrics_of_this_checkpoint[
"max_adv_tr_test_acc"
] = test_adv_tr_acc
self.config.metrics_of_this_checkpoint[
"max_adv_tr_test_f1"
] = test_adv_tr_f1
if (
test_label_acc > max_label_fold_acc
or test_label_acc > max_label_fold_f1
or test_adv_det_acc > max_adv_det_fold_acc
or test_adv_det_f1 > max_adv_det_fold_f1
or test_adv_tr_acc > max_adv_tr_fold_acc
or test_adv_tr_f1 > max_adv_tr_fold_f1
):
if test_label_acc > max_label_fold_acc:
patience = self.config.patience - 1
max_label_fold_acc = test_label_acc
if test_label_f1 > max_label_fold_f1:
patience = self.config.patience - 1
max_label_fold_f1 = test_label_f1
if test_adv_det_acc > max_adv_det_fold_acc:
patience = self.config.patience - 1
max_adv_det_fold_acc = test_adv_det_acc
if test_adv_det_f1 > max_adv_det_fold_f1:
patience = self.config.patience - 1
max_adv_det_fold_f1 = test_adv_det_f1
if test_adv_tr_acc > max_adv_tr_fold_acc:
patience = self.config.patience - 1
max_adv_tr_fold_acc = test_adv_tr_acc
if test_adv_tr_f1 > max_adv_tr_fold_f1:
patience = self.config.patience - 1
max_adv_tr_fold_f1 = test_adv_tr_f1
if self.config.model_path_to_save:
if not os.path.exists(self.config.model_path_to_save):
os.makedirs(self.config.model_path_to_save)
if save_path:
try:
shutil.rmtree(save_path)
# logger.info('Remove sub-configimal trained model:', save_path)
except:
# logger.info('Can not remove sub-configimal trained model:', save_path)
pass
save_path = (
"{0}/{1}_{2}_cls_acc_{3}_cls_f1_{4}_adv_det_acc_{5}_adv_det_f1_{6}"
"_adv_training_acc_{7}_adv_training_f1_{8}/".format(
self.config.model_path_to_save,
self.config.model_name,
self.config.dataset_name,
round(test_label_acc * 100, 2),
round(test_label_f1 * 100, 2),
round(test_adv_det_acc * 100, 2),
round(test_adv_det_f1 * 100, 2),
round(test_adv_tr_acc * 100, 2),
round(test_adv_tr_f1 * 100, 2),
)
)
if (
test_label_acc
> self.config.max_test_metrics["max_cls_test_acc"]
):
self.config.max_test_metrics[
"max_cls_test_acc"
] = test_label_acc
if (
test_label_f1
> self.config.max_test_metrics["max_cls_test_f1"]
):
self.config.max_test_metrics[
"max_cls_test_f1"
] = test_label_f1
if (
test_adv_det_acc
> self.config.max_test_metrics[
"max_adv_det_test_acc"
]
):
self.config.max_test_metrics[
"max_adv_det_test_acc"
] = test_adv_det_acc
if (
test_adv_det_f1
> self.config.max_test_metrics[
"max_adv_det_test_f1"
]
):
self.config.max_test_metrics[
"max_adv_det_test_f1"
] = test_adv_det_f1
if (
test_adv_tr_acc
> self.config.max_test_metrics[
"max_adv_tr_test_acc"
]
):
self.config.max_test_metrics[
"max_adv_tr_test_acc"
] = test_adv_tr_acc
if (
test_adv_tr_f1
> self.config.max_test_metrics["max_adv_tr_test_f1"]
):
self.config.max_test_metrics[
"max_adv_tr_test_f1"
] = test_adv_tr_f1
save_model(
self.config, self.model, self.tokenizer, save_path
)
postfix = (
"Dev CLS ACC:{:>.2f}(max:{:>.2f}) Dev AdvDet ACC:{:>.2f}(max:{:>.2f})"
" Dev AdvCLS ACC:{:>.2f}(max:{:>.2f})".format(
test_label_acc * 100,
max_label_fold_acc * 100,
test_adv_det_acc * 100,
max_adv_det_fold_acc * 100,
test_adv_tr_acc * 100,
max_adv_tr_fold_acc * 100,
)
)
iterator.set_postfix_str(postfix)
elif self.config.save_mode and epoch >= self.config.evaluate_begin:
save_model(
self.config,
self.model,
self.tokenizer,
save_path + "_{}/".format(loss.item()),
)
else:
if self.config.get("loss_display", "smooth") == "smooth":
description = "Epoch:{:>3d} | Smooth Loss: {:>.4f}".format(
epoch, round(np.nanmean(losses), 4)
)
else:
description = "Epoch:{:>3d} | Batch Loss: {:>.4f}".format(
epoch, round(loss.item(), 4)
)
iterator.set_description(description)
iterator.refresh()
if patience == 0:
break
if not self.valid_dataloader:
self.config.MV.log_metric(
self.config.model_name,
"Max-CLS-Acc w/o Valid Set",
max_label_fold_acc * 100,
)
self.config.MV.log_metric(
self.config.model_name,
"Max-CLS-F1 w/o Valid Set",
max_label_fold_f1 * 100,
)
self.config.MV.log_metric(
self.config.model_name,
"Max-AdvDet-Acc w/o Valid Set",
max_adv_det_fold_acc * 100,
)
self.config.MV.log_metric(
self.config.model_name,
"Max-AdvDet-F1 w/o Valid Set",
max_adv_det_fold_f1 * 100,
)
if self.valid_dataloader:
fprint(
"Loading best model: {} and evaluating on test set ".format(save_path)
)
self.reload_model_state_dict(find_file(save_path, ".state_dict"))
(
max_label_fold_acc,
max_label_fold_f1,
max_adv_det_fold_acc,
max_adv_det_fold_f1,
max_adv_tr_fold_acc,
max_adv_tr_fold_f1,
) = self._evaluate_acc_f1(self.test_dataloader)
self.config.MV.log_metric(
self.config.model_name
+ "-"
+ self.config.dataset_name
+ "-"
+ self.config.pretrained_bert,
"Max-CLS-Acc",
max_label_fold_acc * 100,
)
self.config.MV.log_metric(
self.config.model_name
+ "-"
+ self.config.dataset_name
+ "-"
+ self.config.pretrained_bert,
"Max-CLS-F1",
max_label_fold_f1 * 100,
)
self.config.MV.log_metric(
self.config.model_name
+ "-"
+ self.config.dataset_name
+ "-"
+ self.config.pretrained_bert,
"Max-AdvDet-Acc",
max_adv_det_fold_acc * 100,
)
self.config.MV.log_metric(
self.config.model_name
+ "-"
+ self.config.dataset_name
+ "-"
+ self.config.pretrained_bert,
"Max-AdvDet-F1",
max_adv_det_fold_f1 * 100,
)
self.config.MV.log_metric(
self.config.model_name
+ "-"
+ self.config.dataset_name
+ "-"
+ self.config.pretrained_bert,
"Max-AdvCLS-Acc",
max_adv_tr_fold_acc * 100,
)
self.config.MV.log_metric(
self.config.model_name
+ "-"
+ self.config.dataset_name
+ "-"
+ self.config.pretrained_bert,
"Max-AdvCLS-F1",
max_adv_tr_fold_f1 * 100,
)
self.config.logger.info(self.config.MV.summary(no_print=True))
rolling_intv = 5
df = pandas.DataFrame(losses)
losses = list(np.hstack(df.rolling(rolling_intv, min_periods=1).mean().values))
self.config.loss = losses[-1]
# self.config.loss = np.average(losses)
if self.valid_dataloader or self.config.save_mode:
del self.train_dataloaders
del self.test_dataloader
del self.valid_dataloader
del self.model
cuda.empty_cache()
time.sleep(3)
return save_path
else:
del self.train_dataloaders
del self.test_dataloader
del self.valid_dataloader
cuda.empty_cache()
time.sleep(3)
return self.model, self.config, self.tokenizer
[docs]
def _k_fold_train_and_evaluate(self, criterion):
raise NotImplementedError()
[docs]
def _evaluate_acc_f1(self, test_dataloader):
# switch model to evaluation mode
self.model.eval()
n_label_test_correct, n_label_test_total = 1e-10, 1e-10
n_adv_det_test_correct, n_adv_det_test_total = 1e-10, 1e-10
n_adv_tr_test_correct, n_adv_tr_test_total = 1e-10, 1e-10
t_label_targets_all, t_label_outputs_all = None, None
t_adv_det_targets_all, t_adv_det_outputs_all = None, None
t_adv_tr_targets_all, t_adv_tr_outputs_all = None, None
with torch.no_grad():
for t_batch, t_sample_batched in enumerate(test_dataloader):
t_inputs = [
t_sample_batched[col].to(self.config.device)
for col in self.config.inputs_cols
]
t_label_targets = t_sample_batched["label"].to(self.config.device)
t_adv_tr_targets = t_sample_batched["adv_train_label"].to(
self.config.device
)
t_adv_det_targets = t_sample_batched["is_adv"].to(self.config.device)
t_outputs = self.model(t_inputs)
sent_logits, advdet_logits, adv_tr_logits = (
t_outputs["sent_logits"],
t_outputs["advdet_logits"],
t_outputs["adv_tr_logits"],
)
# --------------------------------------------------------------------------------------------#
valid_label_targets = torch.tensor(
[x for x in t_label_targets.cpu() if x != -100]
).to(self.config.device)
if any(valid_label_targets):
valid_label_logit_ids = [
True if x != -100 else False for x in t_label_targets.cpu()
]
valid_label_logits = sent_logits[valid_label_logit_ids]
n_label_test_correct += (
(torch.argmax(valid_label_logits, -1) == valid_label_targets)
.sum()
.item()
)
n_label_test_total += len(valid_label_logits)
if t_label_targets_all is None:
t_label_targets_all = valid_label_targets
t_label_outputs_all = valid_label_logits
else:
t_label_targets_all = torch.cat(
(t_label_targets_all, valid_label_targets), dim=0
)
t_label_outputs_all = torch.cat(
(t_label_outputs_all, valid_label_logits), dim=0
)
# --------------------------------------------------------------------------------------------#
n_adv_det_test_correct += (
(torch.argmax(advdet_logits, -1) == t_adv_det_targets).sum().item()
)
n_adv_det_test_total += len(advdet_logits)
if t_adv_det_targets_all is None:
t_adv_det_targets_all = t_adv_det_targets
t_adv_det_outputs_all = advdet_logits
else:
t_adv_det_targets_all = torch.cat(
(t_adv_det_targets_all, t_adv_det_targets), dim=0
)
t_adv_det_outputs_all = torch.cat(
(t_adv_det_outputs_all, advdet_logits), dim=0
)
# --------------------------------------------------------------------------------------------#
valid_adv_tr_targets = torch.tensor(
[x for x in t_adv_tr_targets.cpu() if x != -100]
).to(self.config.device)
if any(t_adv_tr_targets):
valid_adv_tr_logit_ids = [
True if x != -100 else False for x in t_adv_tr_targets.cpu()
]
valid_adv_tr_logits = adv_tr_logits[valid_adv_tr_logit_ids]
n_adv_tr_test_correct += (
(torch.argmax(valid_adv_tr_logits, -1) == valid_adv_tr_targets)
.sum()
.item()
)
n_adv_tr_test_total += len(valid_adv_tr_logits)
if t_adv_tr_targets_all is None:
t_adv_tr_targets_all = valid_adv_tr_targets
t_adv_tr_outputs_all = valid_adv_tr_logits
else:
t_adv_tr_targets_all = torch.cat(
(t_adv_tr_targets_all, valid_adv_tr_targets), dim=0
)
t_adv_tr_outputs_all = torch.cat(
(t_adv_tr_outputs_all, valid_adv_tr_logits), dim=0
)
label_test_acc = n_label_test_correct / n_label_test_total
label_test_f1 = metrics.f1_score(
t_label_targets_all.cpu(),
torch.argmax(t_label_outputs_all.cpu(), -1),
labels=list(range(self.config.class_dim)),
average=self.config.get("f1_average", "macro"),
)
if self.config.args.get("show_metric", False):
fprint(
"\n---------------------------- Standard Classification Report ----------------------------\n"
)
fprint(
metrics.classification_report(
t_label_targets_all.cpu(),
torch.argmax(t_label_outputs_all.cpu(), -1),
target_names=[
self.config.index_to_label[x]
for x in sorted(self.config.index_to_label.keys())
if x != -100
],
)
)
fprint(
"\n---------------------------- Standard Classification Report ----------------------------\n"
)
adv_det_test_acc = n_adv_det_test_correct / n_adv_det_test_total
adv_det_test_f1 = metrics.f1_score(
t_adv_det_targets_all.cpu(),
torch.argmax(t_adv_det_outputs_all, -1).cpu(),
labels=list(range(self.config.adv_det_dim)),
average=self.config.get("f1_average", "macro"),
)
adv_tr_test_acc = n_adv_tr_test_correct / n_adv_tr_test_total
adv_tr_test_f1 = metrics.f1_score(
t_adv_tr_targets_all.cpu(),
torch.argmax(t_adv_tr_outputs_all, -1).cpu(),
labels=list(range(self.config.class_dim)),
average=self.config.get("f1_average", "macro"),
)
return (
label_test_acc,
label_test_f1,
adv_det_test_acc,
adv_det_test_f1,
adv_tr_test_acc,
adv_tr_test_f1,
)
[docs]
def run(self):
# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
return self._train(criterion)