# Source code for pyabsa.tasks.TextClassification.instructor.classifier_instructor

# -*- coding: utf-8 -*-
# file: classifier_instructor.py
# time: 2021/4/22 0022
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.

import os
import random
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
from findfile import find_file
from sklearn import metrics
from torch import cuda

from tqdm import tqdm
from transformers import AutoModel

from pyabsa import DeviceTypeOption
from pyabsa.framework.instructor_class.instructor_template import BaseTrainingInstructor
from ..dataset_utils.__classic__.data_utils_for_training import GloVeTCDataset
from ..dataset_utils.__plm__.data_utils_for_training import BERTTCDataset
from ..models import GloVeTCModelList, BERTTCModelList

import pytorch_warmup as warmup

from pyabsa.utils.file_utils.file_utils import save_model
from pyabsa.utils.pyabsa_utils import init_optimizer, fprint, rprint
from pyabsa.framework.tokenizer_class.tokenizer_class import (
    PretrainedTokenizer,
    Tokenizer,
    build_embedding_matrix,
)


class TCTrainingInstructor(BaseTrainingInstructor):
    """Training driver for text-classification models.

    The instructor builds the tokenizer, datasets and model for either a
    pretrained-transformer model (``BERTTCModelList``) or a GloVe model
    (``GloVeTCModelList``), then trains it either once or once per
    cross-validation fold.

    Several attributes used below (``self.logger``, ``self.scaler``,
    ``self.test_dataloader``, ``self.valid_dataloader``) and several methods
    (``load_cache_dataset``, ``save_cache_dataset``, ``_prepare_dataloader``)
    are not defined in this class; presumably they come from
    ``BaseTrainingInstructor`` -- confirm against the base class.
    """

    def __init__(self, config):
        """Build datasets and model from ``config``, then finish the setup.

        Args:
            config: Training configuration object; it is read and mutated
                throughout this class.
        """
        super().__init__(config)

        self._load_dataset_and_prepare_dataloader()

        self._init_misc()

    def _init_misc(self):
        """Seed the RNGs, place the model on its device and build the optimizer.

        Also snapshots the initial weights to ``./init_state_dict.bin`` when
        cross-validation is requested, so every fold can restart from them.
        """
        seed = self.config.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.config.inputs_cols = self.model.inputs

        self.config.device = torch.device(self.config.device)

        self.model.to(self.config.device)
        if self.config.auto_device == DeviceTypeOption.ALL_CUDA:
            # NOTE(review): wrapping in DataParallel and immediately taking
            # ``.module`` hands back the unwrapped model, so no multi-GPU
            # parallelism is actually in effect. Kept as-is because the rest
            # of the class reads attributes straight off ``self.model``.
            self.model = torch.nn.parallel.DataParallel(self.model).module

        self.optimizer = init_optimizer(self.config.optimizer)(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.l2reg,
        )

        self.train_dataloaders = []
        self.valid_dataloaders = []

        # Drop any stale snapshot from a previous run before (maybe) writing
        # a fresh one.
        if os.path.exists("./init_state_dict.bin"):
            os.remove("./init_state_dict.bin")
        if self.config.cross_validate_fold > 0:
            torch.save(self.model.state_dict(), "./init_state_dict.bin")

        if self.config.device.type == DeviceTypeOption.CUDA:
            self.logger.info(
                "cuda memory allocated:{}".format(
                    torch.cuda.memory_allocated(device=self.config.device)
                )
            )

    def _load_dataset_and_prepare_dataloader(self):
        """Create the tokenizer, the train/test/valid datasets and the model.

        The branch taken depends on which model list contains
        ``config.model.__name__``. The tokenizer is stored on the config and
        the datasets are handed to ``save_cache_dataset`` at the end.
        """
        cache_path = self.load_cache_dataset()

        # init BERT-based model and dataset
        if hasattr(BERTTCModelList, self.config.model.__name__):
            self.tokenizer = PretrainedTokenizer(self.config)
            # Datasets are only rebuilt when there is no cache file or the
            # cache must be overwritten; otherwise they are presumably
            # restored by ``load_cache_dataset`` -- confirm in the base class.
            if not os.path.exists(cache_path) or self.config.overwrite_cache:
                self.train_set = BERTTCDataset(
                    self.config, self.tokenizer, dataset_type="train"
                )
                self.test_set = BERTTCDataset(
                    self.config, self.tokenizer, dataset_type="test"
                )
                self.valid_set = BERTTCDataset(
                    self.config, self.tokenizer, dataset_type="valid"
                )
            try:
                self.bert = AutoModel.from_pretrained(self.config.pretrained_bert)
            except ValueError as e:
                # NOTE(review): the failure is only printed. Unless
                # ``self.bert`` was set elsewhere, the model construction
                # below will then fail on the missing attribute.
                fprint("Init pretrained model failed, exception: {}".format(e))

            # Build the model after the datasets, in case dataset
            # construction updates output_dim on the config.
            self.model = self.config.model(self.bert, self.config).to(
                self.config.device
            )

        elif hasattr(GloVeTCModelList, self.config.model.__name__):
            # init GloVe-based model and dataset
            dataset_basename = os.path.basename(self.config.dataset_name)
            self.tokenizer = Tokenizer.build_tokenizer(
                config=self.config,
                cache_path="{0}_tokenizer.dat".format(dataset_basename),
            )
            self.embedding_matrix = build_embedding_matrix(
                config=self.config,
                tokenizer=self.tokenizer,
                cache_path="{0}_{1}_embedding_matrix.dat".format(
                    str(self.config.embed_dim),
                    dataset_basename,
                ),
            )
            self.train_set = GloVeTCDataset(
                self.config, self.tokenizer, dataset_type="train"
            )
            self.test_set = GloVeTCDataset(
                self.config, self.tokenizer, dataset_type="test"
            )
            self.valid_set = GloVeTCDataset(
                self.config, self.tokenizer, dataset_type="valid"
            )

            self.model = self.config.model(self.embedding_matrix, self.config).to(
                self.config.device
            )
            self.config.embedding_matrix = self.embedding_matrix

        self.config.tokenizer = self.tokenizer
        self.save_cache_dataset(cache_path)

    def reload_model(self, ckpt="./init_state_dict.bin"):
        """Load model weights from ``ckpt`` if that path exists.

        Args:
            ckpt: Path of a state-dict file or of a directory containing
                one. A falsy or non-existent path makes this a no-op.
        """
        # ``ckpt`` can be empty/None when no checkpoint was ever written;
        # os.path.exists(None) would raise TypeError, so test truthiness first.
        if ckpt and os.path.exists(ckpt):
            self.model.load_state_dict(
                torch.load(find_file(ckpt, or_key=[".bin", "state_dict"]))
            )

    def _train(self, criterion):
        """Prepare dataloaders and schedulers, then dispatch to a training loop.

        Args:
            criterion: Loss callable applied as ``criterion(outputs, targets)``.

        Returns:
            Whatever the selected training loop returns: a checkpoint path,
            or a ``(model, config, tokenizer)`` tuple.
        """
        self._prepare_dataloader()

        if self.config.warmup_step >= 0:
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=len(self.train_dataloaders[0]) * self.config.num_epoch,
            )
            self.warmup_scheduler = warmup.UntunedLinearWarmup(self.optimizer)

        # More than one validation dataloader means cross-validation.
        if len(self.valid_dataloaders) > 1:
            return self._k_fold_train_and_evaluate(criterion)
        return self._train_and_evaluate(criterion)

    def _log_run_header(self, batches_per_epoch):
        """Log dataset sizes, batch size and the total number of training steps.

        Args:
            batches_per_epoch: ``len()`` of the training dataloader, i.e. the
                number of batches (not samples) in one epoch.
        """
        self.logger.info(
            "***** Running training for {} *****".format(self.config.task_name)
        )
        self.logger.info("Training set examples = %d", len(self.train_set))
        if self.valid_set:
            self.logger.info("Valid set examples = %d", len(self.valid_set))
        if self.test_set:
            self.logger.info("Test set examples = %d", len(self.test_set))
        self.logger.info("Batch size = %d", self.config.batch_size)
        # len(dataloader) already counts batches, so it must not be divided
        # by the batch size a second time.
        self.logger.info(
            "Num steps = %d", batches_per_epoch * self.config.num_epoch
        )

    def _resolve_log_step(self, batches_per_epoch):
        """Default ``config.log_step`` to one evaluation per epoch.

        A non-positive ``log_step`` is replaced by ``batches_per_epoch``.
        Zero is included because it would otherwise raise ZeroDivisionError
        in the ``global_step % log_step`` check.
        """
        if self.config.log_step <= 0:
            self.config.log_step = batches_per_epoch

    def _loss_description(self, epoch, losses, loss):
        """Return the progress-bar description for the current step.

        Shows the mean of all recorded losses unless ``config.loss_display``
        is set to something other than ``"smooth"``, in which case the loss
        of the current batch is shown.
        """
        if self.config.get("loss_display", "smooth") == "smooth":
            return "Epoch:{:>3d} | Smooth Loss: {:>.4f}".format(
                epoch, round(np.nanmean(losses), 4)
            )
        return "Epoch:{:>3d} | Batch Loss: {:>.4f}".format(
            epoch, round(loss.item(), 4)
        )

    def _metric_tag(self):
        """Return the ``model-dataset-backbone`` name used for final metrics."""
        return (
            self.config.model_name
            + "-"
            + self.config.dataset_name
            + "-"
            + self.config.pretrained_bert
        )

    def _checkpoint_best_model(self, save_path, test_acc, f1):
        """Replace the previous best checkpoint with the current model.

        Does nothing when ``config.model_path_to_save`` is falsy.

        Args:
            save_path: Directory of the previous checkpoint (removed if set).
            test_acc: Accuracy of the current model, as a fraction.
            f1: F1 score of the current model, as a fraction.

        Returns:
            The directory the model was saved to, or ``save_path`` unchanged
            when saving is disabled.
        """
        if not self.config.model_path_to_save:
            return save_path

        os.makedirs(self.config.model_path_to_save, exist_ok=True)
        if save_path:
            # Best-effort removal of the superseded checkpoint. The directory
            # may not exist yet, which is fine; unlike a bare ``except`` this
            # does not swallow KeyboardInterrupt.
            shutil.rmtree(save_path, ignore_errors=True)

        save_path = "{0}/{1}_{2}_acc_{3}_f1_{4}/".format(
            self.config.model_path_to_save,
            self.config.model_name,
            self.config.dataset_name,
            round(test_acc * 100, 2),
            round(f1 * 100, 2),
        )

        best = self.config.max_test_metrics
        if test_acc > best["max_test_acc"]:
            best["max_test_acc"] = test_acc
        if f1 > best["max_test_f1"]:
            best["max_test_f1"] = f1

        save_model(self.config, self.model, self.tokenizer, save_path)
        return save_path

    def _train_and_evaluate(self, criterion):
        """Train once on the first training dataloader with periodic evaluation.

        Every ``config.log_step`` steps the model is evaluated on the
        validation dataloader (or the test dataloader when there is no
        validation one) and checkpointed whenever accuracy or F1 improves.
        Training stops early when the patience counter reaches zero.

        Args:
            criterion: Loss callable, used when the model output is not a
                dict carrying its own ``"loss"``.

        Returns:
            The checkpoint path when a validation dataloader exists or
            ``config.save_mode`` is set; otherwise
            ``(model, config, tokenizer)``.
        """
        global_step = 0
        max_fold_acc = 0
        max_fold_f1 = 0
        save_path = "{0}/{1}_{2}".format(
            self.config.model_path_to_save,
            self.config.model_name,
            self.config.dataset_name,
        )

        losses = []

        self.config.metrics_of_this_checkpoint = {"acc": 0, "f1": 0}
        self.config.max_test_metrics = {"max_test_acc": 0, "max_test_f1": 0}

        train_dataloader = self.train_dataloaders[0]
        self._log_run_header(len(train_dataloader))

        patience = self.config.patience + self.config.evaluate_begin
        self._resolve_log_step(len(train_dataloader))

        for epoch in range(self.config.num_epoch):
            patience -= 1
            description = "Epoch:{} | Loss:{}".format(epoch, 0)
            iterator = tqdm(train_dataloader, desc=description)
            for i_batch, sample_batched in enumerate(iterator):
                global_step += 1
                # switch model to train mode, clear gradient accumulators
                self.model.train()
                self.optimizer.zero_grad()
                inputs = [
                    sample_batched[col].to(self.config.device)
                    for col in self.config.inputs_cols
                ]

                # NOTE(review): unlike the k-fold loop, this path never
                # enters autocast, so a configured scaler only scales the
                # loss -- confirm whether mixed precision is intended here.
                outputs = self.model(inputs)

                targets = sample_batched["label"].to(self.config.device)

                if isinstance(outputs, dict) and "loss" in outputs:
                    loss = outputs["loss"]
                else:
                    loss = criterion(outputs, targets)

                losses.append(loss.item())

                if self.scaler:
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

                if self.config.warmup_step >= 0:
                    with self.warmup_scheduler.dampening():
                        self.lr_scheduler.step()

                # evaluate if test set is available
                if global_step % self.config.log_step == 0:
                    if self.test_dataloader and epoch >= self.config.evaluate_begin:
                        if self.valid_dataloader:
                            eval_dataloader = self.valid_dataloader
                        else:
                            eval_dataloader = self.test_dataloader
                        test_acc, f1 = self._evaluate_acc_f1(eval_dataloader)

                        self.config.metrics_of_this_checkpoint["acc"] = test_acc
                        self.config.metrics_of_this_checkpoint["f1"] = f1

                        if test_acc > max_fold_acc or f1 > max_fold_f1:
                            # Any improvement resets the early-stop counter.
                            patience = self.config.patience - 1
                            if test_acc > max_fold_acc:
                                max_fold_acc = test_acc
                            if f1 > max_fold_f1:
                                max_fold_f1 = f1
                            save_path = self._checkpoint_best_model(
                                save_path, test_acc, f1
                            )

                        postfix = "Dev Acc:{:>.2f}(max:{:>.2f}) Dev F1:{:>.2f}(max:{:>.2f})".format(
                            test_acc * 100,
                            max_fold_acc * 100,
                            f1 * 100,
                            max_fold_f1 * 100,
                        )
                        iterator.set_postfix_str(postfix)
                    elif self.config.save_mode and epoch >= self.config.evaluate_begin:
                        save_model(
                            self.config,
                            self.model,
                            self.tokenizer,
                            save_path + "_{}/".format(loss.item()),
                        )
                else:
                    description = self._loss_description(epoch, losses, loss)

                iterator.set_description(description)
                iterator.refresh()
            if patience == 0:
                break

        if self.valid_dataloader:
            # The validation set drove model selection; report the selected
            # checkpoint on the held-out test set.
            fprint(
                "Loading best model: {} and evaluating on test set ".format(save_path)
            )
            self.reload_model(find_file(save_path, ".state_dict"))
            max_fold_acc, max_fold_f1 = self._evaluate_acc_f1(self.test_dataloader)

            self.config.MV.log_metric(
                self._metric_tag(),
                "Max-Test-Acc",
                max_fold_acc * 100,
            )
            self.config.MV.log_metric(
                self._metric_tag(),
                "Max-Test-F1",
                max_fold_f1 * 100,
            )
        else:
            self.config.MV.log_metric(
                self._metric_tag(),
                "Max-Test-Acc w/o Valid Set",
                max_fold_acc * 100,
            )
            self.config.MV.log_metric(
                self._metric_tag(),
                "Max-Test-F1 w/o Valid Set",
                max_fold_f1 * 100,
            )

        self.logger.info(self.config.MV.summary(no_print=True))

        # Decide the return shape before the dataloaders are deleted.
        return_checkpoint_path = bool(self.valid_dataloader or self.config.save_mode)

        del self.train_dataloaders
        del self.test_dataloader
        del self.valid_dataloader
        if return_checkpoint_path:
            del self.model
        cuda.empty_cache()
        time.sleep(3)

        if return_checkpoint_path:
            return save_path
        return self.model, self.config, self.tokenizer

    def _k_fold_train_and_evaluate(self, criterion):
        """Train one model per cross-validation fold and keep the best fold.

        Each fold trains on its own train/valid dataloader pair, is scored on
        the test dataloader afterwards, and the model is then reset to the
        initial weights saved in ``./init_state_dict.bin`` (when present).

        Args:
            criterion: Loss callable, used when the model output is not a
                dict carrying its own ``"loss"``.

        Returns:
            The checkpoint path of the best fold when
            ``self.valid_dataloader`` or ``config.save_mode`` is truthy;
            otherwise ``(model, config, tokenizer)``.
        """
        fold_test_acc = []
        fold_test_f1 = []

        save_path_k_fold = ""
        max_fold_acc_k_fold = 0

        losses = []

        self.config.metrics_of_this_checkpoint = {"acc": 0, "f1": 0}
        self.config.max_test_metrics = {"max_test_acc": 0, "max_test_f1": 0}

        use_amp = bool(self.config.use_amp)

        # NOTE(review): the optimizer and LR scheduler are shared by all
        # folds; only the model weights are reset between folds.
        for f, (train_dataloader, valid_dataloader) in enumerate(
            zip(self.train_dataloaders, self.valid_dataloaders)
        ):
            patience = self.config.patience + self.config.evaluate_begin
            self._resolve_log_step(len(self.train_dataloaders[0]))

            self._log_run_header(len(train_dataloader))
            if len(self.train_dataloaders) > 1:
                self.logger.info(
                    "No. {} trainer in {} folds".format(
                        f + 1, self.config.cross_validate_fold
                    )
                )

            global_step = 0
            max_fold_acc = 0
            max_fold_f1 = 0
            save_path = "{0}/{1}_{2}".format(
                self.config.model_path_to_save,
                self.config.model_name,
                self.config.dataset_name,
            )

            for epoch in range(self.config.num_epoch):
                patience -= 1
                description = "Epoch:{} | Loss:{}".format(epoch, 0)
                iterator = tqdm(train_dataloader, desc=description)
                for i_batch, sample_batched in enumerate(iterator):
                    global_step += 1
                    # switch model to train mode, clear gradient accumulators
                    self.model.train()
                    self.optimizer.zero_grad()
                    inputs = [
                        sample_batched[col].to(self.config.device)
                        for col in self.config.inputs_cols
                    ]
                    targets = sample_batched["label"].to(self.config.device)

                    # Autocast only when AMP is requested. It used to be
                    # entered unconditionally, which ran a reduced-precision
                    # forward pass without loss scaling when AMP was off.
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        outputs = self.model(inputs)
                        if isinstance(outputs, dict) and "loss" in outputs:
                            loss = outputs["loss"]
                        else:
                            loss = criterion(outputs, targets)

                    losses.append(loss.item())

                    if use_amp and self.scaler:
                        self.scaler.scale(loss).backward()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        loss.backward()
                        self.optimizer.step()

                    if self.config.warmup_step >= 0:
                        with self.warmup_scheduler.dampening():
                            self.lr_scheduler.step()

                    # evaluate if test set is available
                    if global_step % self.config.log_step == 0:
                        if self.test_dataloader and epoch >= self.config.evaluate_begin:
                            test_acc, f1 = self._evaluate_acc_f1(valid_dataloader)

                            self.config.metrics_of_this_checkpoint["acc"] = test_acc
                            self.config.metrics_of_this_checkpoint["f1"] = f1

                            if test_acc > max_fold_acc or f1 > max_fold_f1:
                                # Any improvement resets the early-stop counter.
                                patience = self.config.patience - 1
                                if test_acc > max_fold_acc:
                                    max_fold_acc = test_acc
                                if f1 > max_fold_f1:
                                    max_fold_f1 = f1
                                save_path = self._checkpoint_best_model(
                                    save_path, test_acc, f1
                                )

                            postfix = "Dev Acc:{:>.2f}(max:{:>.2f}) Dev F1:{:>.2f}(max:{:>.2f})".format(
                                test_acc * 100,
                                max_fold_acc * 100,
                                f1 * 100,
                                max_fold_f1 * 100,
                            )
                            iterator.set_postfix_str(postfix)
                        if self.config.save_mode and epoch >= self.config.evaluate_begin:
                            save_model(
                                self.config,
                                self.model,
                                self.tokenizer,
                                save_path + "_{}/".format(loss.item()),
                            )
                    else:
                        description = self._loss_description(epoch, losses, loss)

                    iterator.set_description(description)
                    iterator.refresh()
                if patience == 0:
                    break

            # Score this fold on the test set. This uses the model as it
            # stands at the end of training, not the reloaded checkpoint.
            max_fold_acc, max_fold_f1 = self._evaluate_acc_f1(self.test_dataloader)
            if max_fold_acc > max_fold_acc_k_fold:
                # Track the running best; without this update every fold
                # with non-zero accuracy would overwrite the chosen path.
                max_fold_acc_k_fold = max_fold_acc
                save_path_k_fold = save_path
            fold_test_acc.append(max_fold_acc)
            fold_test_f1.append(max_fold_f1)

            self.config.MV.log_metric(
                self.config.model_name,
                "Fold{}-Max-Valid-Acc".format(f),
                max_fold_acc * 100,
            )
            self.config.MV.log_metric(
                self.config.model_name,
                "Fold{}-Max-Valid-F1".format(f),
                max_fold_f1 * 100,
            )

            self.logger.info(self.config.MV.raw_summary(no_print=True))
            # Restart the next fold from the initial weights.
            if os.path.exists("./init_state_dict.bin"):
                self.reload_model()

        max_test_acc = np.max(fold_test_acc)
        # Logged as "Max-Test-F1", so take the max across folds to match the
        # label and the accuracy metric (this used to be the mean).
        max_test_f1 = np.max(fold_test_f1)

        self.config.MV.log_metric(
            self._metric_tag(),
            "Max-Test-Acc",
            max_test_acc * 100,
        )
        self.config.MV.log_metric(
            self._metric_tag(),
            "Max-Test-F1",
            max_test_f1 * 100,
        )

        if self.config.cross_validate_fold > 0:
            self.logger.info(self.config.MV.raw_summary(no_print=True))

        self.reload_model(save_path_k_fold)

        # Decide the return shape before the dataloaders are deleted.
        return_checkpoint_path = bool(self.valid_dataloader or self.config.save_mode)

        if not return_checkpoint_path and self.config.model_path_to_save:
            # direct return model if do not evaluate
            save_path_k_fold = "{0}/{1}/".format(
                self.config.model_path_to_save,
                self.config.model_name,
            )
            save_model(self.config, self.model, self.tokenizer, save_path_k_fold)

        del self.train_dataloaders
        del self.test_dataloader
        del self.valid_dataloaders
        if return_checkpoint_path:
            del self.model
        cuda.empty_cache()
        time.sleep(3)

        if return_checkpoint_path:
            return save_path_k_fold
        return self.model, self.config, self.tokenizer

    def _evaluate_acc_f1(self, test_dataloader):
        """Compute accuracy and F1 of the current model on ``test_dataloader``.

        When ``config.args`` has a truthy ``"show_metric"``, a classification
        report and a confusion matrix are printed as well.

        Args:
            test_dataloader: Iterable of batches; each batch is indexed by
                the names in ``config.inputs_cols`` and by ``"label"``.

        Returns:
            ``(accuracy, f1)``, where ``f1`` uses the averaging mode in
            ``config.f1_average`` (default ``"macro"``).

        Raises:
            ZeroDivisionError: If the dataloader yields no samples.
        """
        # switch model to evaluation mode
        self.model.eval()
        n_test_correct, n_test_total = 0, 0
        # Collect per-batch tensors and concatenate once after the loop;
        # concatenating inside the loop re-copies everything seen so far.
        target_chunks, output_chunks = [], []
        with torch.no_grad():
            for t_sample_batched in test_dataloader:
                t_inputs = [
                    t_sample_batched[col].to(self.config.device)
                    for col in self.config.inputs_cols
                ]
                t_targets = t_sample_batched["label"].to(self.config.device)

                sen_outputs = self.model(t_inputs)

                n_test_correct += (
                    (torch.argmax(sen_outputs, -1) == t_targets).sum().item()
                )
                n_test_total += len(sen_outputs)

                target_chunks.append(t_targets.cpu())
                output_chunks.append(sen_outputs.cpu())

        test_acc = n_test_correct / n_test_total

        t_targets_all = torch.cat(target_chunks, dim=0)
        t_preds_all = torch.argmax(torch.cat(output_chunks, dim=0), -1)

        f1 = metrics.f1_score(
            t_targets_all,
            t_preds_all,
            labels=list(range(self.config.output_dim)),
            average=self.config.get("f1_average", "macro"),
        )

        if self.config.args.get("show_metric", False):
            report = metrics.classification_report(
                t_targets_all,
                t_preds_all,
                digits=4,
                target_names=[
                    self.config.index_to_label[x]
                    for x in sorted(self.config.index_to_label.keys())
                ],
            )
            fprint(
                "\n---------------------------- Classification Report ----------------------------\n"
            )
            rprint(report)
            fprint(
                "\n---------------------------- Classification Report ----------------------------\n"
            )

            report = metrics.confusion_matrix(
                t_targets_all,
                t_preds_all,
                labels=[
                    self.config.label_to_index[x] for x in self.config.label_to_index
                ],
            )
            fprint(
                "\n---------------------------- Confusion Matrix ----------------------------\n"
            )
            rprint(report)
            fprint(
                "\n---------------------------- Confusion Matrix ----------------------------\n"
            )

        return test_acc, f1

    def run(self):
        """Train with a cross-entropy loss and return the training result."""
        # Loss and Optimizer
        criterion = nn.CrossEntropyLoss()

        return self._train(criterion)