Source code for pyabsa.tasks.UniversalSentimentAnalysis.instructor.instructor

# -*- coding: utf-8 -*-
# file: instructor.py
# time: 2021/4/22 0022
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import os
import pickle

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

from pyabsa.framework.instructor_class.instructor_template import BaseTrainingInstructor
from pyabsa.tasks.UniversalSentimentAnalysis.dataset_utils.data_utils_for_training import (
    USATrainingDataset,
)
from pyabsa.utils.pyabsa_utils import fprint, print_args


class USATrainingInstructor(BaseTrainingInstructor):
    def _load_dataset_and_prepare_dataloader(self):
        cache_path = self.load_cache_dataset()

        # init the BERT-based tokenizer and the seq2seq data collator
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained_bert)
        self.data_collator = DataCollatorForSeq2Seq(self.tokenizer)
        self.config.tokenizer = self.tokenizer

        if not os.path.exists(cache_path) or self.config.overwrite_cache:
            # tokenize the train/test/valid splits from scratch, then cache them
            self.train_set = USATrainingDataset(
                self.config, self.tokenizer, dataset_type="train"
            ).tokenized_dataset
            self.test_set = USATrainingDataset(
                self.config, self.tokenizer, dataset_type="test"
            ).tokenized_dataset
            self.valid_set = USATrainingDataset(
                self.config, self.tokenizer, dataset_type="valid"
            ).tokenized_dataset
            self.save_cache_dataset(cache_path)
        else:
            fprint("Loading dataset from cache file: %s" % cache_path)
            with open(cache_path, "rb") as cache_file:
                (
                    self.train_set,
                    self.test_set,
                    self.valid_set,
                    self.config,
                ) = pickle.load(cache_file)

        # collect the tokenized splits into a single dict of datasets
        self.datasets = {
            "train": self.train_set["train"],
            "test": self.test_set["test"],
            "valid": self.valid_set["valid"],
        }

        self.model = self.config.model(config=self.config)

    def __init__(self, config):
        super().__init__(config)
        self._load_dataset_and_prepare_dataloader()
        print_args(self.config)

    def run(self):
        return self.model.train(self.datasets)
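

# Usage sketch (illustrative, not part of this module): assuming `config` is a
# prepared PyABSA training config exposing the attributes read above
# (`pretrained_bert`, `overwrite_cache`, `model`, and the dataset settings
# consumed by USATrainingDataset), the instructor is driven as follows. How the
# config object is obtained is a hypothetical placeholder here; only
# USATrainingInstructor comes from this file.
#
#     config = ...  # a UniversalSentimentAnalysis training config
#     instructor = USATrainingInstructor(config)  # tokenizes or cache-loads the datasets
#     instructor.run()  # delegates to self.model.train(self.datasets)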