pyabsa.tasks.TextClassification.prediction.text_classifier
Classes
TextClassifier | High-level predictor for Text Classification.
Predictor | High-level predictor for Text Classification.
Module Contents
- class pyabsa.tasks.TextClassification.prediction.text_classifier.TextClassifier(checkpoint=None, cal_perplexity=False, **kwargs)
Bases: pyabsa.framework.prediction_class.predictor_template.InferenceModel
High-level predictor for Text Classification.
Loads a trained text classification checkpoint (BERT-based or GloVe-based) and provides convenient inference APIs for single sentences and batch datasets. When gold labels are present, evaluation metrics are printed.
- task_code = 'TC'
- _log_write_args()
- batch_infer(target_file=None, print_result=True, save_result=False, ignore_error=True, defense: str = None, **kwargs)
Deprecated alias of batch_predict.
- Parameters:
target_file – Path to the input file or directory.
print_result – Whether to print formatted results.
save_result – Whether to save JSON results.
ignore_error – Skip malformed lines instead of raising errors.
defense – Optional adversarial defense strategy.
**kwargs – Additional inference options.
- Returns:
Prediction results.
- Return type:
List[dict]
- infer(text: str | list = None, print_result=True, ignore_error=True, defense: str = None, **kwargs)
Deprecated alias of predict for single or multiple inputs.
- Parameters:
text – A string or list of strings to classify.
print_result – Whether to print formatted results.
ignore_error – Skip malformed inputs instead of raising errors.
defense – Optional adversarial defense strategy.
**kwargs – Additional inference options.
- Returns:
Prediction results.
- Return type:
dict or List[dict]
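batch_infer and infer are documented only as deprecated aliases of batch_predict and predict. The sketch below shows one common way to write such aliases: emit a DeprecationWarning, then forward every argument to the newer method. PredictorSketch is a hypothetical stand-in whose methods echo their arguments instead of running a model. Whether pyabsa actually emits a warning, and how the defense argument reaches the new methods, are assumptions that this page does not confirm.

```python
import warnings
from typing import Any, Dict, Optional


class PredictorSketch:
    """Hypothetical stand-in: methods echo their arguments instead of predicting."""

    def predict(self, text=None, print_result=True, ignore_error=True,
                **kwargs: Any) -> Dict[str, Any]:
        return {"text": text, "print_result": print_result,
                "ignore_error": ignore_error, "extra": kwargs}

    def batch_predict(self, target_file=None, print_result=True,
                      save_result=False, ignore_error=True,
                      **kwargs: Any) -> Dict[str, Any]:
        return {"target_file": target_file, "print_result": print_result,
                "save_result": save_result, "ignore_error": ignore_error,
                "extra": kwargs}

    def infer(self, text=None, print_result=True, ignore_error=True,
              defense: Optional[str] = None, **kwargs: Any) -> Dict[str, Any]:
        # Deprecated alias: warn the caller, then delegate to predict().
        warnings.warn("infer() is deprecated; use predict() instead",
                      DeprecationWarning, stacklevel=2)
        # Assumption: defense travels to the new method through **kwargs.
        return self.predict(text=text, print_result=print_result,
                            ignore_error=ignore_error, defense=defense, **kwargs)

    def batch_infer(self, target_file=None, print_result=True,
                    save_result=False, ignore_error=True,
                    defense: Optional[str] = None, **kwargs: Any) -> Dict[str, Any]:
        # Deprecated alias: warn the caller, then delegate to batch_predict().
        warnings.warn("batch_infer() is deprecated; use batch_predict() instead",
                      DeprecationWarning, stacklevel=2)
        return self.batch_predict(target_file=target_file,
                                  print_result=print_result,
                                  save_result=save_result,
                                  ignore_error=ignore_error,
                                  defense=defense, **kwargs)
```

With this pattern, old call sites keep working, while warnings.catch_warnings or running python -W error::DeprecationWarning surfaces them for migration.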
- batch_predict(target_file=None, print_result=True, save_result=False, ignore_error=True, **kwargs)
Run text classification inference on a dataset file or directory.
- Parameters:
target_file – Path to a file or directory containing inputs.
print_result – Print formatted results to stdout.
save_result – Save JSON results to the working directory.
ignore_error – Skip malformed lines instead of raising errors.
**kwargs – Additional options, e.g., eval_batch_size.
- Returns:
Inference results.
- Return type:
List[dict]
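This page does not specify the input file format or what counts as a malformed line. The sketch below illustrates the documented file-or-directory input and the ignore_error switch under stated assumptions: each line holds a text optionally followed by a gold label after a $LABEL$ separator (an assumed format; check the dataset conventions of your checkpoint), and a line is treated as malformed when its text is empty or it contains more than one separator. collect_input_files, parse_line, and load_samples are hypothetical helpers, not pyabsa functions.

```python
from pathlib import Path
from typing import List, Optional, Tuple, Union

# Assumed text/label delimiter for this sketch; not confirmed by this page.
SEPARATOR = "$LABEL$"


def collect_input_files(target: Union[str, Path]) -> List[Path]:
    """Return the target itself, or every file under it if it is a directory."""
    path = Path(target)
    if path.is_dir():
        return sorted(p for p in path.rglob("*") if p.is_file())
    return [path]


def parse_line(line: str) -> Tuple[str, Optional[str]]:
    """Split a line into (text, gold label or None); raise ValueError if malformed."""
    parts = line.split(SEPARATOR)
    text = parts[0].strip()
    if not text or len(parts) > 2:
        raise ValueError(f"malformed line: {line!r}")
    label = (parts[1].strip() or None) if len(parts) == 2 else None
    return text, label


def load_samples(target: Union[str, Path],
                 ignore_error: bool = True) -> List[Tuple[str, Optional[str]]]:
    """Read (text, label) samples from a file or from every file in a directory."""
    samples = []
    for file in collect_input_files(target):
        for line in file.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue  # blank lines carry no sample
            try:
                samples.append(parse_line(line))
            except ValueError:
                if not ignore_error:
                    raise  # strict mode: surface the first malformed line
                # lenient mode (the documented default): skip the line
    return samples
```

Skipping rather than raising mirrors the documented default ignore_error=True, so one bad line does not abort a long batch run.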
- predict(text: str | list = None, print_result=True, ignore_error=True, **kwargs)
Predict labels for a string or a list of strings.
- Parameters:
text – Single text or a list of texts to classify.
print_result – Print formatted results to stdout.
ignore_error – Skip malformed inputs instead of raising errors.
**kwargs – Additional options, e.g., eval_batch_size.
- Returns:
A single result for string input, otherwise a list of results.
- Return type:
dict or List[dict]
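The return contract above (a dict for a string input, a list of dicts for a list input) and the eval_batch_size option can be sketched as follows. predict_sketch, the classify_batch callable, the default batch size of 16, and the "text"/"label" result keys are all hypothetical; the real result dictionaries may carry other fields.

```python
from typing import Callable, Dict, List, Union

Result = Dict[str, str]


def predict_sketch(text: Union[str, List[str]],
                   classify_batch: Callable[[List[str]], List[str]],
                   eval_batch_size: int = 16) -> Union[Result, List[Result]]:
    """Return one dict for a string input, a list of dicts for a list input."""
    single = isinstance(text, str)
    inputs = [text] if single else list(text)
    results: List[Result] = []
    # Feed the model in chunks of at most eval_batch_size texts.
    for start in range(0, len(inputs), eval_batch_size):
        batch = inputs[start:start + eval_batch_size]
        labels = classify_batch(batch)
        results.extend({"text": t, "label": lab} for t, lab in zip(batch, labels))
    # Unwrap the single result so that str in gives dict out.
    return results[0] if single else results
```

Because the return type depends on the input type, callers that accept both forms may find it simpler to always pass a list and index the result.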
- _run_prediction(save_path=None, print_result=True)
Internal prediction loop for text classification.
Executes the model over self.infer_dataloader, collects logits, computes predictions, and optionally prints and saves results. When reference labels are present, prints a classification report and confusion matrix.
- Parameters:
save_path – Optional path to save JSON results.
print_result – Whether to print formatted results to stdout.
- Returns:
Inference results.
- Return type:
List[dict]
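The logits-to-predictions step and the optional metrics can be sketched in pure Python. This assumes a softmax over each logits row followed by an argmax decision, the usual choice for single-label classification but not confirmed by this page. decode, confusion_matrix, accuracy, and the result keys are hypothetical; scikit-learn's sklearn.metrics.classification_report and sklearn.metrics.confusion_matrix produce the same kind of report in practice.

```python
import math
from typing import Dict, List, Optional, Sequence

Result = Dict[str, object]


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert one row of logits into probabilities."""
    m = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits_rows: Sequence[Sequence[float]],
           index_to_label: Sequence[str],
           ref_labels: Optional[Sequence[str]] = None) -> List[Result]:
    """Pick the argmax class per row and attach its probability as confidence."""
    results: List[Result] = []
    for i, row in enumerate(logits_rows):
        probs = softmax(row)
        best = max(range(len(probs)), key=probs.__getitem__)
        result: Result = {"label": index_to_label[best], "confidence": probs[best]}
        if ref_labels is not None:
            result["ref_label"] = ref_labels[i]
        results.append(result)
    return results


def confusion_matrix(results: Sequence[Result],
                     labels: Sequence[str]) -> List[List[int]]:
    """Rows index the gold label, columns the predicted label, both in labels order."""
    pos = {lab: k for k, lab in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for r in results:
        matrix[pos[r["ref_label"]]][pos[r["label"]]] += 1
    return matrix


def accuracy(results: Sequence[Result]) -> float:
    """Fraction of results whose predicted label matches the gold label."""
    return sum(r["label"] == r["ref_label"] for r in results) / len(results)
```

The metric helpers expect every result to carry a ref_label, which matches the documented rule that the report is printed only when reference labels are present.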
- clear_input_samples()
Clear any previously prepared inference samples/dataset cache.
- class pyabsa.tasks.TextClassification.prediction.text_classifier.Predictor(checkpoint=None, cal_perplexity=False, **kwargs)
Bases: TextClassifier
High-level predictor for Text Classification.
Loads a trained text classification checkpoint (BERT-based or GloVe-based) and provides convenient inference APIs for single sentences and batch datasets. When gold labels are present, evaluation metrics are printed.