pyabsa.tasks.TextClassification.prediction.text_classifier

Classes

TextClassifier

High-level predictor for Text Classification.

Predictor

High-level predictor for Text Classification.

Module Contents

class pyabsa.tasks.TextClassification.prediction.text_classifier.TextClassifier(checkpoint=None, cal_perplexity=False, **kwargs)

Bases: pyabsa.framework.prediction_class.predictor_template.InferenceModel

High-level predictor for Text Classification.

Loads a trained text classification checkpoint (BERT-based or GloVe-based) and provides convenient inference APIs for single sentences and batch datasets. When gold labels are present, evaluation metrics are printed.

task_code = 'TC'
_log_write_args()
batch_infer(target_file=None, print_result=True, save_result=False, ignore_error=True, defense: str = None, **kwargs)

Deprecated alias of batch_predict.

Parameters:
  • target_file – Path to the input file or directory.

  • print_result – Whether to print formatted results.

  • save_result – Whether to save JSON results.

  • ignore_error – Skip malformed lines instead of raising errors.

  • defense – Optional adversarial defense strategy.

  • **kwargs – Additional inference options.

Returns:

Prediction results.

Return type:

List[dict]

infer(text: str | list = None, print_result=True, ignore_error=True, defense: str = None, **kwargs)

Deprecated alias of predict for single or multiple inputs.

Parameters:
  • text – A single string or a list of strings to classify.

  • print_result – Whether to print formatted results.

  • ignore_error – Skip malformed inputs instead of raising errors.

  • defense – Optional adversarial defense strategy.

  • **kwargs – Additional inference options.

Returns:

Prediction results.

Return type:

dict or List[dict]
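Because infer and batch_infer are documented as deprecated aliases, new code should call predict and batch_predict directly. If a caller must also handle an object that exposes only the deprecated names, a small duck-typed shim can prefer the new methods and fall back to the aliases. This is a hypothetical sketch: predict_compat and batch_predict_compat are not part of pyabsa, and the assumption that some objects lack predict is illustrative only.

```python
def predict_compat(classifier, text, **kwargs):
    """Hypothetical helper: prefer predict(), fall back to the deprecated infer()."""
    method = getattr(classifier, "predict", None) or getattr(classifier, "infer")
    # print_result=False keeps stdout quiet; the caller handles the returned results.
    return method(text, print_result=False, **kwargs)


def batch_predict_compat(classifier, target_file, **kwargs):
    """Hypothetical helper: prefer batch_predict(), fall back to batch_infer()."""
    method = getattr(classifier, "batch_predict", None) or getattr(classifier, "batch_infer")
    return method(target_file=target_file, print_result=False, **kwargs)
```

Any object with the documented method names works here, so the shim can be tested with stand-ins instead of a loaded checkpoint.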

batch_predict(target_file=None, print_result=True, save_result=False, ignore_error=True, **kwargs)

Run text classification inference on a dataset file or directory.

Parameters:
  • target_file – Path to a file or directory containing inputs.

  • print_result – Print formatted results to stdout.

  • save_result – Save JSON results to the working directory.

  • ignore_error – Skip malformed lines instead of raising errors.

  • **kwargs – Additional options, e.g., eval_batch_size.

Returns:

Inference results.

Return type:

List[dict]
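A minimal sketch of preparing an input file and calling batch_predict with only the parameters documented above (eval_batch_size is passed through **kwargs). It assumes one sample per line and the "$LABEL$" separator between text and an optional gold label; both are assumptions about the dataset format, so verify them against your pyabsa version. write_inference_file and run_batch are hypothetical helper names.

```python
# Assumed separator between the text and its gold label; not confirmed by this page.
LABEL_SEP = "$LABEL$"


def write_inference_file(samples, path):
    """Write (text, label) pairs, one per line; label may be None for unlabeled text."""
    with open(path, "w", encoding="utf-8") as f:
        for text, label in samples:
            text = text.replace("\n", " ")  # keep each sample on a single line
            line = text if label is None else f"{text}{LABEL_SEP}{label}"
            f.write(line + "\n")
    return path


def run_batch(classifier, path, batch_size=32):
    """Run batch inference quietly and return the List[dict] results."""
    return classifier.batch_predict(
        target_file=path,
        print_result=False,
        save_result=False,
        ignore_error=True,           # skip malformed lines
        eval_batch_size=batch_size,  # forwarded via **kwargs
    )
```

Lines that carry a gold label let the predictor print evaluation metrics; unlabeled lines are still classified.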

predict(text: str | list = None, print_result=True, ignore_error=True, **kwargs)

Predict labels for a string or a list of strings.

Parameters:
  • text – Single text or a list of texts to classify.

  • print_result – Print formatted results to stdout.

  • ignore_error – Skip malformed inputs instead of raising errors.

  • **kwargs – Additional options, e.g., eval_batch_size.

Returns:

A single result for string input, otherwise a list of results.

Return type:

dict or List[dict]
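Since predict returns a dict for a string and a list for a list of strings, callers often normalize the result before iterating. The sketch below does that and extracts predicted labels. It assumes each result dict stores the prediction under a "label" key; that key name and the helper names as_result_list and predict_labels are assumptions, not documented here.

```python
from typing import Any, Dict, List, Union

Result = Dict[str, Any]


def as_result_list(result: Union[Result, List[Result]]) -> List[Result]:
    """Wrap a single-result dict in a list; pass lists through unchanged."""
    if isinstance(result, dict):
        return [result]
    return list(result)


def predict_labels(classifier, texts) -> List[Any]:
    """Classify one text or many and return only the predicted labels.

    Assumes the prediction is stored under the "label" key of each result.
    """
    raw = classifier.predict(texts, print_result=False, ignore_error=True)
    return [r["label"] for r in as_result_list(raw)]
```

The same code path then serves both single-sentence and multi-sentence calls.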

_run_prediction(save_path=None, print_result=True)

Internal prediction loop for text classification.

Executes the model over self.infer_dataloader, collects logits, computes predictions, and optionally prints and saves results. When reference labels are present, prints a classification report and confusion matrix.

Parameters:
  • save_path – Optional path to save JSON results.

  • print_result – Whether to print formatted results to stdout.

Returns:

Inference results.

Return type:

List[dict]
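When reference labels are present, the loop prints a classification report and a confusion matrix. To recompute a simple summary from the returned results, a pure-Python sketch follows. It assumes the keys "label" (prediction) and "ref_label" (gold label), and that unlabeled rows carry a sentinel passed in via missing; the key names and sentinel values are assumptions, and summarize is a hypothetical helper.

```python
from collections import Counter
from typing import Any, Dict, List, Tuple


def summarize(results: List[Dict[str, Any]], missing=(None,)) -> Tuple[float, Counter]:
    """Return accuracy and (gold, predicted) confusion counts over labeled rows.

    Rows whose "ref_label" is in `missing` (or absent) are skipped.
    """
    pairs = [
        (r["ref_label"], r["label"])
        for r in results
        if r.get("ref_label") not in missing
    ]
    if not pairs:
        return 0.0, Counter()
    confusion = Counter(pairs)  # keys are (gold, predicted) tuples
    correct = sum(n for (gold, pred), n in confusion.items() if gold == pred)
    return correct / len(pairs), confusion
```

A Counter keyed by (gold, predicted) is a sparse confusion matrix; off-diagonal keys count misclassifications.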

clear_input_samples()

Clear any previously prepared inference samples/dataset cache.

class pyabsa.tasks.TextClassification.prediction.text_classifier.Predictor(checkpoint=None, cal_perplexity=False, **kwargs)

Bases: TextClassifier

High-level predictor for Text Classification.

Loads a trained text classification checkpoint (BERT-based or GloVe-based) and provides convenient inference APIs for single sentences and batch datasets. When gold labels are present, evaluation metrics are printed.