pyabsa.tasks.TextAdversarialDefense.instructor.tad_instructor

Module Contents

Classes

TADTrainingInstructor

class pyabsa.tasks.TextAdversarialDefense.instructor.tad_instructor.TADTrainingInstructor(config)

Bases: pyabsa.framework.instructor_class.instructor_template.BaseTrainingInstructor

_init_misc()

Initialize miscellaneous settings specific to text adversarial defense training. Overrides the corresponding hook of BaseTrainingInstructor.

_cache_or_load_dataset()

Cache or load the dataset. Overrides the corresponding hook of BaseTrainingInstructor.

_load_dataset_and_prepare_dataloader()

Load the dataset and prepare the dataloader. Overrides the corresponding hook of BaseTrainingInstructor.
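The three hooks above appear to follow a template-method arrangement: the base class owns the setup sequence and each task-specific instructor supplies the individual steps. The sketch below illustrates that pattern in plain Python. The class names `HypotheticalBaseInstructor` and `HypotheticalTADInstructor`, the `dataset` and `batch_size` config keys, and the assumption that the base constructor calls the hooks in this order are all illustrative; this is not pyabsa's actual code.

```python
from abc import ABC, abstractmethod


class HypotheticalBaseInstructor(ABC):
    """Illustrative stand-in for a base training instructor (not pyabsa code)."""

    def __init__(self, config: dict):
        self.config = config
        # Assumed: the base class fixes the order of setup steps,
        # and subclasses provide the body of each step.
        self._init_misc()
        self._cache_or_load_dataset()
        self._load_dataset_and_prepare_dataloader()

    @abstractmethod
    def _init_misc(self): ...

    @abstractmethod
    def _cache_or_load_dataset(self): ...

    @abstractmethod
    def _load_dataset_and_prepare_dataloader(self): ...


class HypotheticalTADInstructor(HypotheticalBaseInstructor):
    """Hypothetical subclass that fills in every hook."""

    def _init_misc(self):
        self.calls = ["init_misc"]

    def _cache_or_load_dataset(self):
        self.calls.append("cache_or_load_dataset")
        self.dataset = self.config.get("dataset", [])

    def _load_dataset_and_prepare_dataloader(self):
        self.calls.append("load_dataset_and_prepare_dataloader")
        batch_size = self.config.get("batch_size", 2)
        # Split the dataset into fixed-size batches as a toy "dataloader".
        self.train_batches = [
            self.dataset[i : i + batch_size]
            for i in range(0, len(self.dataset), batch_size)
        ]


instructor = HypotheticalTADInstructor({"dataset": [1, 2, 3, 4, 5], "batch_size": 2})
print(instructor.calls)  # ['init_misc', 'cache_or_load_dataset', 'load_dataset_and_prepare_dataloader']
print(instructor.train_batches)  # [[1, 2], [3, 4], [5]]
```

Keeping the sequence in the base class means a subclass cannot accidentally prepare dataloaders before its dataset exists; it only decides what each step does.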

reload_model_state_dict(ckpt='./init_state_dict.bin')

prepare_dataloader(train_set)

_train(criterion)

Train the model using the given criterion.

Parameters:

criterion – The loss function used to train the model.

Returns:

If there is only one validation dataloader, returns the training results. If there is more than one validation dataloader, performs k-fold cross-validation and returns its results.
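The dispatch rule described for `_train` can be sketched as a small function. `train_dispatch` and its callable parameters are hypothetical stand-ins for the instructor's own methods, and treating zero validation dataloaders like the single-dataloader case is an assumption the docstring does not cover.

```python
from typing import Any, Callable, Sequence


def train_dispatch(
    val_dataloaders: Sequence[Any],
    criterion: Any,
    train_and_evaluate: Callable[[Any], Any],
    k_fold_train_and_evaluate: Callable[[Any], Any],
) -> Any:
    """Hypothetical sketch of the dispatch rule documented for ``_train``."""
    if len(val_dataloaders) > 1:
        # More than one validation dataloader: k-fold cross-validation.
        return k_fold_train_and_evaluate(criterion)
    # Zero or one validation dataloader (assumed): a single train/evaluate run.
    return train_and_evaluate(criterion)


def single(criterion):
    return f"single run with {criterion}"


def k_fold(criterion):
    return f"k-fold run with {criterion}"


print(train_dispatch(["val"], "CE", single, k_fold))  # single run with CE
print(train_dispatch(["f1", "f2", "f3"], "CE", single, k_fold))  # k-fold run with CE
```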

_train_and_evaluate(criterion)

Train and evaluate the model. Overrides the corresponding hook of BaseTrainingInstructor.
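A minimal sketch of a train-then-evaluate loop, assuming the common pattern of evaluating after every epoch and remembering the epoch with the best validation score. The function `train_and_evaluate_loop` and its callables are hypothetical; the real method's checkpointing, logging, and stopping behavior are not documented here.

```python
from typing import Callable, List, Tuple


def train_and_evaluate_loop(
    num_epochs: int,
    train_one_epoch: Callable[[int], float],
    evaluate: Callable[[int], float],
) -> Tuple[int, float, List[Tuple[int, float, float]]]:
    """Hypothetical loop: train one epoch, evaluate, remember the best epoch."""
    best_epoch, best_score = -1, float("-inf")
    history: List[Tuple[int, float, float]] = []
    for epoch in range(num_epochs):
        loss = train_one_epoch(epoch)  # training loss for this epoch
        score = evaluate(epoch)  # validation metric, higher is better
        history.append((epoch, loss, score))
        if score > best_score:
            # This is where a checkpoint of the best model could be saved.
            best_epoch, best_score = epoch, score
    return best_epoch, best_score, history


# Toy callables standing in for real training and validation.
val_scores = [0.61, 0.78, 0.74]
best_epoch, best_score, history = train_and_evaluate_loop(
    num_epochs=3,
    train_one_epoch=lambda epoch: 1.0 / (epoch + 1),
    evaluate=lambda epoch: val_scores[epoch],
)
print(best_epoch, best_score)  # 1 0.78
```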

abstract _k_fold_train_and_evaluate(criterion)

Train and evaluate the model using k-fold cross-validation. This method should be implemented in a subclass.
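Because the method is abstract in `TADTrainingInstructor`, this class documents no k-fold procedure of its own. The sketch below shows the general technique only: contiguous, unshuffled folds, one run per fold, and the mean score as the aggregate. `k_fold_indices`, `k_fold_evaluate`, and aggregation by mean are assumptions for illustration.

```python
from typing import Callable, List, Tuple


def k_fold_indices(n_samples: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Split range(n_samples) into k contiguous folds.

    Returns one (train_indices, val_indices) pair per fold; each sample
    appears in exactly one validation fold.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    # Spread the remainder over the first folds so sizes differ by at most 1.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        folds.append((train, val))
        start += size
    return folds


def k_fold_evaluate(
    n_samples: int, k: int, train_and_score: Callable[[List[int], List[int]], float]
) -> float:
    """Run train_and_score on every fold and return the mean score."""
    scores = [train_and_score(train, val) for train, val in k_fold_indices(n_samples, k)]
    return sum(scores) / len(scores)


for train, val in k_fold_indices(5, 2):
    print(train, val)
# [3, 4] [0, 1, 2]
# [0, 1, 2] [3, 4]
```

In a real training setup, a fresh model would be initialized before each fold so that no fold's run starts from weights already fitted on its validation data.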

_evaluate_acc_f1(test_dataloader)

Evaluate the model's accuracy and F1 score on test_dataloader. Overrides the corresponding hook of BaseTrainingInstructor.
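The sketch below computes accuracy and a macro-averaged F1 score (the unweighted mean of per-class F1) from two label sequences. Whether `_evaluate_acc_f1` uses macro averaging, and how it collects predictions from `test_dataloader`, are assumptions; `accuracy_and_macro_f1` is a hypothetical helper in plain Python, not a pyabsa function.

```python
from typing import Hashable, Sequence, Tuple


def accuracy_and_macro_f1(
    y_true: Sequence[Hashable], y_pred: Sequence[Hashable]
) -> Tuple[float, float]:
    """Return (accuracy, macro F1) for equally long label sequences."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    f1_scores = []
    for label in set(y_true) | set(y_pred):
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        # F1 is the harmonic mean of precision and recall (0 if both are 0).
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    # Macro average: every class counts equally, regardless of its frequency.
    return accuracy, sum(f1_scores) / len(f1_scores)


acc, macro_f1 = accuracy_and_macro_f1([0, 1, 1, 0], [0, 1, 0, 0])
print(acc, round(macro_f1, 3))  # 0.75 0.733
```

Macro averaging is a common choice when classes are imbalanced, for example when adversarial examples are rarer than clean ones, because a model that ignores the minority class is penalized even if its accuracy stays high.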

run()