Source code for pyabsa.framework.trainer_class.trainer_template

# -*- coding: utf-8 -*-
# file: trainer.py
# time: 02/11/2022 21:15
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# GScholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# ResearchGate: https://www.researchgate.net/profile/Heng-Yang-17/research
# Copyright (C) 2022. All Rights Reserved.
# Copyright (C) 2021. All Rights Reserved.
import os
import time
import warnings
from pathlib import Path
from typing import Union

import torch
import transformers
from metric_visualizer import MetricVisualizer
from torch import cuda
from transformers import AutoConfig

from pyabsa import __version__ as PyABSAVersion
from pyabsa.utils.logger.logger import get_logger
from pyabsa.utils.pyabsa_utils import set_device, fprint
from ..configuration_class.config_verification import config_check
from ..configuration_class.configuration_template import ConfigManager
from ..dataset_class.dataset_dict_class import DatasetDict
from ..flag_class.flag_template import DeviceTypeOption, ModelSaveOption
from ...utils.check_utils import query_local_datasets_version
from ...utils.data_utils.dataset_item import DatasetItem
from ...utils.data_utils.dataset_manager import detect_dataset

warnings.filterwarnings("once")


[docs]
def init_config(config):
    # set device to be used for training and inference
    if (
        not torch.cuda.device_count() > 1
        and config.auto_device == DeviceTypeOption.ALL_CUDA
    ):
        fprint(
            "Cuda devices count <= 1, so reset auto_device=True to auto specify device"
        )
        config.auto_device = True

    set_device(config, config.auto_device)

    # set model name
    config.model_name = (
        config.model.__name__.lower()
        if not isinstance(config.model, list)
        else "ensemble_model"
    )

    # if using a pretrained BERT model, set hidden_dim and embed_dim from the model's configuration
    if config.get("pretrained_bert", None):
        try:
            pretrain_config = AutoConfig.from_pretrained(
                config.pretrained_bert, trust_remote_code=True
            )
            config.hidden_dim = pretrain_config.hidden_size
            config.embed_dim = pretrain_config.hidden_size
        except Exception as e:
            print(e)
    # if hidden_dim or embed_dim are not set, use default values of 768
    elif not config.get("hidden_dim", None) or not config.get("embed_dim", None):
        if config.get("hidden_dim", None):
            config.embed_dim = config.hidden_dim
        elif config.get("embed_dim", None):
            config.hidden_dim = config.embed_dim
        else:
            config.hidden_dim = 768
            config.embed_dim = 768

    # set versions of PyABSA, Transformers, and Torch being used
    config.ABSADatasetsVersion = query_local_datasets_version()
    config.PyABSAVersion = PyABSAVersion
    config.TransformersVersion = transformers.__version__
    config.TorchVersion = "{}+cuda{}".format(
        torch.version.__version__, torch.version.cuda
    )

    # set dataset name based on the dataset object passed to the configuration
    if isinstance(config.dataset, DatasetItem):
        config.dataset_name = config.dataset.dataset_name
    elif isinstance(config.dataset, DatasetDict):
        config.dataset_name = config.dataset["dataset_name"]
    elif isinstance(config.dataset, str):
        dataset = DatasetItem("custom_dataset", config.dataset)
        config.dataset_name = dataset.dataset_name

    # create a MetricVisualizer object for logging metrics during training
    if "MV" not in config.args:
        config.MV = MetricVisualizer(
            name=config.model.__name__ + "-" + config.dataset_name,
            trial_tag="Model & Dataset",
            trial_tag_list=[config.model.__name__ + "-" + config.dataset_name],
        )

    # set checkpoint save mode and run config checks
    checkpoint_save_mode = config.checkpoint_save_mode
    config.save_mode = checkpoint_save_mode
    config_check(config)

    # set up logging
    config.logger = get_logger(
        os.getcwd(), log_name=config.model_name, log_type="trainer"
    )
    config.logger.info("PyABSA version: {}".format(config.PyABSAVersion))
    config.logger.info("Transformers version: {}".format(config.TransformersVersion))
    config.logger.info("Torch version: {}".format(config.TorchVersion))
    config.logger.info("Device: {}".format(config.device_name))

    # return the updated configuration object
    return config
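# Illustrative note (not part of the original module): init_config() is normally invoked from
# Trainer.__init__ below, but it can be read as "take a ConfigManager and fill in the derived
# fields": device, model_name, hidden_dim/embed_dim, library versions, dataset_name, a
# MetricVisualizer and a logger. A minimal, commented-out sketch follows; the config template
# getter and the model/dataset names are assumptions for illustration only.
#
#     from pyabsa import AspectPolarityClassification as APC
#
#     config = APC.APCConfigManager.get_apc_config_english()  # assumed starting template
#     config.model = APC.APCModelList.FAST_LSA_T_V2           # assumed model choice
#     config.dataset = "Laptop14"        # a string dataset is wrapped in DatasetItem("custom_dataset", ...)
#     config.auto_device = True
#     config.checkpoint_save_mode = 1
#     config = init_config(config)       # config.device, config.model_name, config.MV, config.logger are now set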
[docs]
class Trainer:
    """
    Trainer class for training PyABSA models
    """

    def __init__(
        self,
        config: ConfigManager = None,
        dataset: Union[DatasetItem, Path, str, DatasetDict] = None,
        from_checkpoint: Union[Path, str] = None,
        checkpoint_save_mode: Union[
            ModelSaveOption, int
        ] = ModelSaveOption.SAVE_MODEL_STATE_DICT,
        auto_device: Union[str, bool] = DeviceTypeOption.AUTO,
        path_to_save: Union[Path, str] = None,
        load_aug=False,
    ):
        """
        Initialize a trainer for training an APC, ATEPC, TC or TAD model. After training,
        call load_trained_model() to get the trained model for inference.

        :param config: PyABSA.config.ConfigManager
            Configuration for training the model
        :param dataset: Union[DatasetItem, Path, str, DatasetDict]
            Name of the dataset, a dataset_manager path, or a list of dataset_manager paths
        :param from_checkpoint: Union[Path, str]
            A checkpoint path to train based on
        :param checkpoint_save_mode: Union[ModelSaveOption, int]
            How to save the trained model: "checkpoint_save_mode=1" saves the state_dict,
            "checkpoint_save_mode=2" saves the whole model, "checkpoint_save_mode=3" saves the
            fine-tuned BERT; otherwise no checkpoint is saved and the trained model is returned
            after training
        :param auto_device: Union[str, bool]
            True or False; the strings 'allcuda', 'cuda:1' and 'cpu' also work
        :param path_to_save: Union[Path, str], optional
            Path to save checkpoints to
        :param load_aug: bool, optional
            Load the available augmentation dataset, if any
        """
        self.config = config  # Configuration for training the model
        self.config.dataset = dataset  # Name of the dataset, a dataset_manager path, or a list of dataset_manager paths
        self.config.from_checkpoint = (
            from_checkpoint  # A checkpoint path to train based on
        )
        self.config.checkpoint_save_mode = (
            checkpoint_save_mode  # Save trained model to checkpoint
        )
        self.config.auto_device = (
            auto_device  # True or False, otherwise 'allcuda', 'cuda:1', 'cpu' works
        )
        self.config.path_to_save = path_to_save  # Specify path to save checkpoints
        self.config.load_aug = (
            load_aug  # Load the available augmentation dataset if any
        )
        self.config.inference_model = None  # Inference model
        self.config = init_config(self.config)  # Initialize configuration

        self.config.task_code = None  # Task code
        self.config.task_name = None  # Task name

        self.training_instructor = None  # Training instructor
        self.inference_model_class = None  # Inference model class
        self.inference_model = None  # Inference model
[docs]
    def _run(self):
        """
        Run training and set the trained model (or checkpoint path) to be used for inference
        (e.g., polarity classification, aspect-term extraction)
        """
        if isinstance(self.config.dataset, DatasetDict):
            self.config.dataset_dict = self.config.dataset
        else:
            # detect dataset
            dataset_file = detect_dataset(
                self.config.dataset,
                task_code=self.config.task_code,
                load_aug=self.config.load_aug,
                config=self.config,
            )
            self.config.dataset_file = dataset_file

        if (
            self.config.checkpoint_save_mode
            or self.config.dataset_file["valid"]
            or (self.config.get("data_dict") and self.config.dataset_dict["test"])
        ):
            if self.config.path_to_save:
                self.config.model_path_to_save = self.config.path_to_save
            elif (
                (
                    hasattr(self.config, "dataset_file")
                    and "valid" in self.config.dataset_file
                )
                or (
                    hasattr(self.config, "dataset_dict")
                    and "valid" in self.config.dataset_dict
                )
            ) and not self.config.checkpoint_save_mode:
                fprint(
                    "Using Validation set needs to save checkpoint, turn on checkpoint-saving "
                )
                self.config.model_path_to_save = "checkpoints"
                self.config.save_mode = 1
            else:
                self.config.model_path_to_save = "checkpoints"
        else:
            self.config.model_path_to_save = None

        # set random seed
        if isinstance(self.config.seed, int):
            self.config.seed = [self.config.seed]
        seeds = self.config.seed

        # trainer using all random seeds
        model_path = []
        model = None
        for i, s in enumerate(seeds):
            self.config.seed = s
            if self.config.checkpoint_save_mode:
                model_path.append(self.training_instructor(self.config).run())
            else:
                # always return the last trained model if you don't save trained model
                model = self.inference_model_class(
                    checkpoint=self.training_instructor(self.config).run()
                )
        self.config.seed = seeds

        # remove logger
        while self.config.logger.handlers:
            self.config.logger.removeHandler(self.config.logger.handlers[0])

        # set inference model load path
        if self.config.checkpoint_save_mode:
            if os.path.exists(max(model_path)):
                self.inference_model = max(model_path)
        else:
            self.inference_model = model
[docs]
    def load_trained_model(self):
        """
        Load the trained model for inference.

        Returns:
            Inference model for the trained model.
        """
        if not self.inference_model:
            fprint(
                "No trained model found, this could happen while trainer only using trainer set."
            )
        elif isinstance(self.inference_model, str):
            # If the trained model is a path to a checkpoint, load the model using the checkpoint
            self.inference_model = self.inference_model_class(
                checkpoint=self.inference_model
            )
        else:
            fprint("Trained model already loaded.")
        return self.inference_model
[docs]
    def destroy(self):
        """
        Clear the inference model from memory and empty the CUDA cache.
        """
        del self.inference_model
        cuda.empty_cache()
        time.sleep(3)
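# Illustrative usage sketch (not part of the original module). The base Trainer leaves
# training_instructor and inference_model_class unset, so in practice a task-specific
# subclass (e.g. an APC trainer) is used; the module, class, config template and dataset
# names below are assumptions for illustration, not a guaranteed API.
if __name__ == "__main__":
    from pyabsa import AspectPolarityClassification as APC  # assumed task module

    apc_config = APC.APCConfigManager.get_apc_config_english()  # assumed config template
    apc_config.num_epoch = 1  # keep the sketch cheap to run

    trainer = APC.APCTrainer(  # assumed Trainer subclass that sets the training instructor
        config=apc_config,
        dataset="Laptop14",  # dataset name resolved by detect_dataset()
        from_checkpoint=None,  # or a checkpoint path to continue training from
        checkpoint_save_mode=ModelSaveOption.SAVE_MODEL_STATE_DICT,
        auto_device=DeviceTypeOption.AUTO,
        load_aug=False,
    )

    classifier = trainer.load_trained_model()  # loads the saved checkpoint into an inference model
    trainer.destroy()  # free the model and empty the CUDA cache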