pyabsa.framework.instructor_class.instructor_template

Classes

BaseTrainingInstructor

Functions

get_resume_checkpoint(config)

Module Contents

class pyabsa.framework.instructor_class.instructor_template.BaseTrainingInstructor(config)
config
logger
model = None
tokenizer = None
train_dataloader = None
valid_dataloader = None
test_dataloader = None
train_dataloaders = []
valid_dataloaders = []
test_dataloaders = []
train_set = None
valid_set = None
test_set = None
optimizer = None
initializer = None
lr_scheduler = None
warmup_scheduler = None
embedding_matrix = None

_reset_params()

Reset the parameters of the model before training.
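The sketch below shows one common way to reset parameters: redraw every value from a uniform distribution U(-b, b) with b = 1/sqrt(fan_in). It is framework-free. Parameters are plain nested lists keyed by name, and `reset_params` is a hypothetical helper. Neither reflects PyABSA's actual initializer or which layers it skips.

```python
import math
import random


def reset_params(params, rng=None):
    """Re-initialize parameters in place (hypothetical helper, not PyABSA's API).

    params maps a name to a 1-D list (e.g. a bias) or a 2-D list of rows
    (e.g. a weight matrix stored as out_features x in_features).
    """
    rng = rng or random.Random()
    for name in list(params):
        value = params[name]
        if value and isinstance(value[0], list):
            # 2-D weight: fan_in is the row length (number of input features).
            fan_in = len(value[0])
            if fan_in == 0:
                continue  # nothing to initialize
            bound = 1.0 / math.sqrt(fan_in)
            params[name] = [[rng.uniform(-bound, bound) for _ in row] for row in value]
        elif value:
            # 1-D vector: use its own length as the scale.
            bound = 1.0 / math.sqrt(len(value))
            params[name] = [rng.uniform(-bound, bound) for _ in value]
    return params


params = {"weight": [[0.0] * 4 for _ in range(3)], "bias": [0.0] * 3}
reset_params(params, random.Random(0))
```

Because a seeded `random.Random` is passed in, repeated resets are reproducible.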

_reload_model_state_dict(ckpt='./init_state_dict.bin')

Reload the model state dictionary from a checkpoint file.

Parameters:

ckpt – The path to the checkpoint file.
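The snapshot-and-restore pattern behind reloading an initial state dictionary can be sketched as follows. This is useful, for example, when each cross-validation fold should start from identical weights. The sketch assumes the state is a plain dict and uses `pickle` in place of `torch.save`/`model.load_state_dict(torch.load(...))`. Both helper names are hypothetical. Only the default path `./init_state_dict.bin` comes from the signature above.

```python
import os
import pickle


def save_init_state(state_dict, ckpt="./init_state_dict.bin"):
    """Snapshot a state dict to disk (pickle stands in for torch.save)."""
    with open(ckpt, "wb") as f:
        pickle.dump(state_dict, f)
    return ckpt


def reload_state(model_state, ckpt="./init_state_dict.bin"):
    """Overwrite model_state in place with the snapshot stored at ckpt."""
    if not os.path.exists(ckpt):
        raise FileNotFoundError(ckpt)
    with open(ckpt, "rb") as f:
        saved = pickle.load(f)
    model_state.clear()
    model_state.update(saved)
    return model_state
```

Typical use: call `save_init_state` once before training, then `reload_state` whenever training must restart from the initial weights.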

load_cache_dataset(**kwargs)

Load the dataset from the cache if a cache file exists and the configuration does not request overwriting it.

Parameters:

kwargs – Additional keyword arguments.

Returns:

The path to the cache file if it exists; otherwise, None.

save_cache_dataset(cache_path=None, **kwargs)

Save the dataset to the cache so that later runs can load it faster.

Parameters:

cache_path – The path to the cache file.
kwargs – Additional keyword arguments for saving the dataset cache.

Returns:

The path to the saved cache file.
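Saving and loading together form a round trip: derive a cache path from the configuration, pickle the prepared datasets there, and reuse them on later runs unless an overwrite flag is set. This is a sketch under stated assumptions. The hashing scheme, the file naming, and the `dataset_name`/`overwrite_cache` keys are hypothetical and are not taken from PyABSA. The config is a plain dict rather than PyABSA's config object.

```python
import hashlib
import json
import os
import pickle


def cache_path_for(config, cache_dir="."):
    """Derive a cache filename from a hash of the config (hypothetical scheme)."""
    # Exclude the overwrite flag so toggling it does not change the path.
    relevant = {k: v for k, v in config.items() if k != "overwrite_cache"}
    key = json.dumps(relevant, sort_keys=True, default=str)
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
    name = config.get("dataset_name", "dataset")
    return os.path.join(cache_dir, f"{name}.{digest}.cache")


def save_dataset_cache(datasets, config, cache_path=None, cache_dir="."):
    """Pickle the prepared datasets and return the path they were written to."""
    cache_path = cache_path or cache_path_for(config, cache_dir)
    with open(cache_path, "wb") as f:
        pickle.dump(datasets, f)
    return cache_path


def load_dataset_cache(config, cache_dir="."):
    """Return (path, datasets) on a cache hit, or (None, None) on a miss."""
    cache_path = cache_path_for(config, cache_dir)
    if os.path.exists(cache_path) and not config.get("overwrite_cache", False):
        with open(cache_path, "rb") as f:
            return cache_path, pickle.load(f)
    return None, None
```

Because the cache key hashes the configuration, changing any preprocessing setting yields a different file instead of silently reusing stale data.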

_prepare_dataloader()

Prepare the data loaders for training, validation, and testing.
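The class keeps both single loaders and the lists `train_dataloaders`/`valid_dataloaders`, which suggests one pair per cross-validation fold. The sketch below shows how such parallel lists could be built. Plain lists of mini-batches stand in for `torch.utils.data.DataLoader`. The `k_fold` parameter and the fold-splitting strategy are assumptions for illustration.

```python
import random


def batches(items, batch_size):
    """Split a list into consecutive mini-batches (stand-in for a DataLoader)."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def prepare_dataloaders(train_set, valid_set=None, batch_size=16, k_fold=1, seed=0):
    """Return (train_dataloaders, valid_dataloaders) as parallel lists."""
    if k_fold < 2:
        # Single split: one train loader plus one loader for the given valid set.
        train_loaders = [batches(list(train_set), batch_size)]
        valid_loaders = [] if valid_set is None else [batches(list(valid_set), batch_size)]
        return train_loaders, valid_loaders
    # k-fold: valid_set is ignored; folds are carved out of the training data.
    data = list(train_set)
    random.Random(seed).shuffle(data)
    folds = [data[i::k_fold] for i in range(k_fold)]  # roughly equal folds
    train_loaders, valid_loaders = [], []
    for i, held_out in enumerate(folds):
        rest = [x for j, fold in enumerate(folds) if j != i for x in fold]
        train_loaders.append(batches(rest, batch_size))
        valid_loaders.append(batches(held_out, batch_size))
    return train_loaders, valid_loaders
```

With `k_fold=5` the function returns five train/validation pairs, so more than one validation loader is present.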

_prepare_env()

Prepare the training environment: set the tokenizer and embedding matrix, remove the initial state-dictionary file if it exists, and set up the model on the appropriate device.

_train(criterion)

Train the model using the given loss criterion.

Parameters:

criterion – The loss function used to train the model.

Returns:

If there is only one validation dataloader, the training results. If there is more than one, k-fold cross-validation is performed and its results are returned.
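The dispatch described for `_train` can be sketched with a hypothetical stand-in class. It assumes, based on the hook names listed for this class, that the single-split path goes through `_train_and_evaluate` and the multi-fold path through `_k_fold_train_and_evaluate`. The return values are placeholders, not PyABSA's result format.

```python
class SketchInstructor:
    """Hypothetical stand-in showing only the dispatch described for _train."""

    def __init__(self, valid_dataloaders):
        self.valid_dataloaders = valid_dataloaders

    def _train_and_evaluate(self, criterion):
        return {"mode": "single", "criterion": criterion}

    def _k_fold_train_and_evaluate(self, criterion):
        return {"mode": "k-fold", "folds": len(self.valid_dataloaders), "criterion": criterion}

    def _train(self, criterion):
        # Assumed rule: more than one validation loader means k-fold CV;
        # zero or one loader falls through to a single train/evaluate run.
        if len(self.valid_dataloaders) > 1:
            return self._k_fold_train_and_evaluate(criterion)
        return self._train_and_evaluate(criterion)


print(SketchInstructor([["batch"]])._train("ce")["mode"])  # single
```

Keeping the branch in the base class while leaving both branches abstract is a template-method design: subclasses supply the training logic, and the base class decides which variant runs.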

abstract _init_misc()

Initialize miscellaneous settings specific to the subclass implementation. This method should be implemented in a subclass.

abstract _cache_or_load_dataset()

Cache or load the dataset. This method should be implemented in a subclass.

abstract _train_and_evaluate(criterion)

Train and evaluate the model. This method should be implemented in a subclass.

abstract _k_fold_train_and_evaluate(criterion)

Train and evaluate the model using k-fold cross-validation. This method should be implemented in a subclass.

abstract _evaluate_acc_f1(test_dataloader)

Evaluate the accuracy and F1 score of the model on test_dataloader. This method should be implemented in a subclass.

abstract _load_dataset_and_prepare_dataloader()

Load the dataset and prepare the dataloader. This method should be implemented in a subclass.
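The abstract methods above define the hooks a concrete instructor must supply. The sketch below shows how such a contract is typically enforced. It assumes the hooks are declared with `abc.abstractmethod` on an `ABC` subclass, which this page does not confirm. `TemplateInstructor` and `ToyInstructor` are hypothetical and reuse only the hook names listed here.

```python
from abc import ABC, abstractmethod


class TemplateInstructor(ABC):
    """Hypothetical base: same hook names as BaseTrainingInstructor, nothing else."""

    @abstractmethod
    def _init_misc(self): ...

    @abstractmethod
    def _cache_or_load_dataset(self): ...

    @abstractmethod
    def _train_and_evaluate(self, criterion): ...

    @abstractmethod
    def _k_fold_train_and_evaluate(self, criterion): ...

    @abstractmethod
    def _evaluate_acc_f1(self, test_dataloader): ...

    @abstractmethod
    def _load_dataset_and_prepare_dataloader(self): ...


class ToyInstructor(TemplateInstructor):
    """Implements every hook with placeholder behaviour."""

    def _init_misc(self):
        self.ready = True

    def _cache_or_load_dataset(self):
        return None  # no cache in this toy

    def _train_and_evaluate(self, criterion):
        return self._evaluate_acc_f1(test_dataloader=[])

    def _k_fold_train_and_evaluate(self, criterion):
        return self._train_and_evaluate(criterion)

    def _evaluate_acc_f1(self, test_dataloader):
        return 0.0, 0.0  # (accuracy, F1) placeholders

    def _load_dataset_and_prepare_dataloader(self):
        self._init_misc()
```

Under this assumption, instantiating the base class, or a subclass that leaves any hook unimplemented, raises `TypeError` at construction time rather than failing later during training.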

_resume_from_checkpoint()

If the configuration provides a valid checkpoint path, resume training from that checkpoint by loading the model, its state dictionary, and the configuration from the checkpoint files.
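Resuming depends on first locating the model, state-dictionary, and configuration files in the checkpoint directory. The sketch below shows that lookup step. The file suffixes `.model`, `.state_dict`, and `.config` are assumptions for illustration, not a documented PyABSA format. `find_checkpoint_files` is a hypothetical helper.

```python
import os


def find_checkpoint_files(ckpt_dir, suffixes=(".model", ".state_dict", ".config")):
    """Map each (assumed) suffix to the first matching file in ckpt_dir, or None.

    Returns None when ckpt_dir is missing or not a directory, meaning
    training should start from scratch.
    """
    if not ckpt_dir or not os.path.isdir(ckpt_dir):
        return None
    names = sorted(os.listdir(ckpt_dir))
    return {
        suffix: next((os.path.join(ckpt_dir, n) for n in names if n.endswith(suffix)), None)
        for suffix in suffixes
    }
```

A caller would then load each non-None entry (for example with `torch.load` in a PyTorch setup) and fall back to training from scratch when the function returns None.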

pyabsa.framework.instructor_class.instructor_template.get_resume_checkpoint(config)