# Source code for pyabsa.tasks.ABSAInstruction.multitask_train

# -*- coding: utf-8 -*-
# file: train.py
# time: 11:30 2023/3/13
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2023. All Rights Reserved.
import os
import warnings

import findfile

warnings.filterwarnings("ignore")
import pandas as pd

from .model import T5Generator, T5Classifier
from .data_utils import InstructDatasetLoader, read_json

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# These module-level names identify the run and determine where checkpoints
# are written.
task_name = "multitask"
experiment_name = "instruction"

# Pretrained checkpoint to fine-tune. Alternatives are kept commented out so
# they can be swapped in quickly.
# model_checkpoint = 'allenai/tk-instruct-base-def-pos'
model_checkpoint = "kevinscaria/ate_tk-instruct-base-def-pos-neg-neut-combined"
# model_checkpoint = 'allenai/tk-instruct-large-def-pos'
# model_checkpoint = 'allenai/tk-instruct-3b-def-pos'
# model_checkpoint = 'google/mt5-base'

print("Experiment Name: ", experiment_name)

# Output layout: checkpoints/<task_name>/<checkpoint-without-slashes>-<experiment_name>
# The "/" in the checkpoint id is removed (not replaced), so the organisation
# and model name are fused into a single directory name.
model_out_path = "checkpoints"
model_out_path = os.path.join(
    model_out_path,
    task_name,
    f"{model_checkpoint.replace('/', '')}-{experiment_name}",
)
print("Model output path: ", model_out_path)

# Load the data
# id_train_file_path = './integrated_datasets'
# id_test_file_path = './integrated_datasets'
# Locations searched for the train / test splits. Narrower alternatives are
# kept commented out for single-dataset runs.
id_train_file_path = "./integrated_datasets/acos_datasets/"
id_test_file_path = "./integrated_datasets/acos_datasets"
# id_train_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
# id_test_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
# id_train_file_path = './integrated_datasets/acos_datasets/504.Restaurant16'
# id_test_file_path = './integrated_datasets/acos_datasets/504.Restaurant16'

# read_json is a project helper; the second argument selects the split.
# NOTE(review): the exact record schema it returns is not visible here --
# it only needs to be something pd.DataFrame accepts.
id_tr_df = read_json(id_train_file_path, "train")
id_te_df = read_json(id_test_file_path, "test")

id_tr_df = pd.DataFrame(id_tr_df)
id_te_df = pd.DataFrame(id_te_df)

# Only the first two (train/test "id") frames are supplied; the "ood" frames
# are left to the loader's defaults.
# NOTE(review): "id"/"ood" presumably mean in-domain / out-of-domain -- confirm.
loader = InstructDatasetLoader(id_tr_df, id_te_df)

# Convert every split that is present into its instruction-formatted form.
# Splits that are None are skipped.
if loader.train_df_id is not None:
    loader.train_df_id = loader.prepare_instruction_dataloader(loader.train_df_id)
if loader.test_df_id is not None:
    loader.test_df_id = loader.prepare_instruction_dataloader(loader.test_df_id)
if loader.train_df_ood is not None:
    loader.train_df_ood = loader.prepare_instruction_dataloader(loader.train_df_ood)
if loader.test_df_ood is not None:
    loader.test_df_ood = loader.prepare_instruction_dataloader(loader.test_df_ood)

# Create T5 utils object
# Wrapper around the pretrained checkpoint; it provides the tokenization,
# training, inference and metric helpers used in this script.
t5_exp = T5Generator(model_checkpoint)

# Tokenize Dataset
# create_datasets returns four values: the raw and tokenized "id" datasets
# followed by the raw and tokenized "ood" datasets. Only the "id" pair is used
# in this script. The spelling `ood_tokenzed_ds` is kept because it is a
# module-level name.
id_ds, id_tokenized_ds, ood_ds, ood_tokenzed_ds = loader.create_datasets(
    t5_exp.tokenize_function_inputs
)

# Training arguments
# Passed to t5_exp.train as keyword arguments.
# NOTE(review): these keys look like Hugging Face Seq2SeqTrainingArguments
# fields -- confirm against T5Generator.train.
training_args = {
    "output_dir": model_out_path,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "num_train_epochs": 6,
    "weight_decay": 0.01,
    "warmup_ratio": 0.1,
    "load_best_model_at_end": True,
    "push_to_hub": False,
    "eval_accumulation_steps": 1,
    "predict_with_generate": True,
    # Presumably set this high to effectively silence per-step logging --
    # TODO confirm.
    "logging_steps": 1000000000,
    "use_mps_device": False,
    # 'fp16': True,
    "fp16": False,
}

# Train model
model_trainer = t5_exp.train(id_tokenized_ds, **training_args)
# Model inference - Trainer object - (Pass model trainer as predictor)
# model_checkpoint = findfile.find_cwd_dir('tk-instruct-base-def-pos')
# t5_exp = T5Generator(model_checkpoint)

# Get prediction labels - Training set
id_tr_pred_labels = t5_exp.get_labels(
    predictor=model_trainer,
    tokenized_dataset=id_tokenized_ds,
    sample_set="train",
    batch_size=16,
)
# Reference labels come from the untokenized dataset; surrounding whitespace
# is stripped from each label string.
id_tr_labels = [i.strip() for i in id_ds["train"]["labels"]]

# Get prediction labels - Testing set
id_te_pred_labels = t5_exp.get_labels(
    predictor=model_trainer,
    tokenized_dataset=id_tokenized_ds,
    sample_set="test",
    batch_size=16,
)
id_te_labels = [i.strip() for i in id_ds["test"]["labels"]]

# # Compute Metrics
# metrics = t5_exp.get_metrics(id_tr_labels, id_tr_pred_labels)
# print('----------------------- Training Set Metrics -----------------------')
# print(metrics)
#
# metrics = t5_exp.get_metrics(id_te_labels, id_te_pred_labels)
# print('----------------------- Testing Set Metrics -----------------------')
# print(metrics)

# Compute Metrics
# `metrics` is rebound for the test set below, so after the script finishes the
# module-level name holds the testing-set metrics.
metrics = t5_exp.get_classic_metrics(id_tr_labels, id_tr_pred_labels)
print("----------------------- Classic Training Set Metrics -----------------------")
print(metrics)

metrics = t5_exp.get_classic_metrics(id_te_labels, id_te_pred_labels)
print("----------------------- Classic Testing Set Metrics -----------------------")
print(metrics)