Source code for pyabsa.tasks.UniversalSentimentAnalysis.prediction.predictor

# -*- coding: utf-8 -*-
# file:
# author: YANG, HENG <> (杨恒)
# Copyright (C) 2020. All Rights Reserved.
import os
import pickle
from typing import Union

import torch
from findfile import find_file
from torch.nn.utils.rnn import pad_sequence
from import DataLoader
from tqdm import tqdm

from pyabsa.framework.flag_class import TaskCodeOption
from pyabsa.framework.prediction_class.predictor_template import InferenceModel
from pyabsa.tasks.UniversalSentimentAnalysis.dataset_utils.data_utils_for_inference import (
from pyabsa.utils.data_utils.dataset_manager import detect_infer_dataset
from pyabsa.utils.pyabsa_utils import fprint, set_device

class USAPredictor(InferenceModel):
    """Inference wrapper for the Universal Sentiment Analysis (USA) task.

    The predictor restores a pickled config and a model state dict from a
    checkpoint, tokenizes the inputs through ``USAInferenceDataset`` and
    returns the text generated by the underlying model for each input.
    """

    task_code = TaskCodeOption.Universal_Sentiment_Analysis

    def __init__(self, checkpoint=None, **kwargs):
        """Build the predictor and, for a string checkpoint, load it.

        Args:
            checkpoint: checkpoint handed to ``InferenceModel``. When
                ``self.checkpoint`` is a non-empty string after the base
                initializer ran, the ``.config`` and ``.state_dict`` files
                located by ``find_file`` are loaded from it.
            **kwargs: forwarded to ``InferenceModel``. ``auto_device``
                (default ``True``) is stored on the config and passed to
                ``set_device``.

        Raises:
            RuntimeError: if any step of loading the checkpoint fails; the
                original exception is attached as ``__cause__``.
        """
        super().__init__(checkpoint, task_code=self.task_code, **kwargs)

        # A string checkpoint is loaded from its files.
        if self.checkpoint and isinstance(self.checkpoint, str):
            fprint("Load sentiment classifier from trainer")
            try:
                fprint("Load text classifier from", self.checkpoint)
                state_dict_path = find_file(
                    self.checkpoint, key=".state_dict", exclude_key=["__MACOSX"]
                )
                model_path = find_file(
                    self.checkpoint, key=".model", exclude_key=["__MACOSX"]
                )
                tokenizer_path = find_file(
                    self.checkpoint, key=".tokenizer", exclude_key=["__MACOSX"]
                )
                config_path = find_file(
                    self.checkpoint, key=".config", exclude_key=["__MACOSX"]
                )

                # model_path and tokenizer_path are only reported; the model
                # is rebuilt from the config and the state dict below.
                fprint("config: {}".format(config_path))
                fprint("state_dict: {}".format(state_dict_path))
                fprint("model: {}".format(model_path))
                fprint("tokenizer: {}".format(tokenizer_path))

                # NOTE(security): pickle.load can execute arbitrary code
                # embedded in the file -- only load trusted checkpoints.
                with open(config_path, mode="rb") as f:
                    self.config = pickle.load(f)

                self.config.auto_device = kwargs.get("auto_device", True)
                set_device(self.config, self.config.auto_device)

                self.model = self.config.model(config=self.config)
                # Weights are read onto the CPU first; strict=False tolerates
                # missing/unexpected keys in the state dict.
                self.model.model.load_state_dict(
                    torch.load(state_dict_path, map_location="cpu"), strict=False
                )
            except Exception as e:
                raise RuntimeError(
                    "Fail to load the model from {}! "
                    "Please make sure the version of checkpoint and PyABSA are compatible."
                    " Try to remove the checkpoint and download again"
                    " \nException: {} ".format(checkpoint, e)
                ) from e

        # NOTE(review): assumed to run for every kind of checkpoint (i.e. at
        # method level, not inside the branch above) -- confirm.
        self.dataset = USAInferenceDataset(
            self.config, self.config.tokenizer, dataset_type="test"
        )

    def batch_infer(
        self,
        target_file=None,
        print_result=True,
        save_result=False,
        ignore_error=True,
        **kwargs
    ):
        """Deprecated alias of :meth:`batch_predict`.

        Args:
            target_file (str): the path to the target file for inference
            print_result (bool): whether to print the result
            save_result (bool): whether to save the result
            ignore_error (bool): whether to ignore the error
            **kwargs: forwarded to :meth:`batch_predict`

        Returns:
            list: whatever :meth:`batch_predict` returns (one generated
            string per input example).
        """
        return self.batch_predict(
            target_file=target_file,
            print_result=print_result,
            save_result=save_result,
            ignore_error=ignore_error,
            **kwargs
        )

    def infer(self, text: str = None, print_result=True, ignore_error=True, **kwargs):
        """Deprecated alias of :meth:`predict`.

        Args:
            text (str): the text to predict
            print_result (bool): whether to print the result
            ignore_error (bool): whether to ignore the error
            **kwargs: forwarded to :meth:`predict`

        Returns:
            Whatever :meth:`predict` returns for the same arguments.
        """
        return self.predict(
            text=text, print_result=print_result, ignore_error=ignore_error, **kwargs
        )

    def batch_predict(
        self,
        target_file=None,
        print_result=True,
        save_result=False,
        ignore_error=True,
        **kwargs
    ):
        """Predict the sentiment for every sentence of an inference file.

        Args:
            target_file: file (or dataset) reference resolved through
                ``detect_infer_dataset``.
            print_result: whether to print the result.
            save_result: whether to save the result to
                ``<cwd>/<task_name>.<model class name>.result.json``.
            ignore_error: forwarded to the dataset when preparing the inputs.
            **kwargs: ``eval_batch_size`` (default 32) sets the inference
                batch size.

        Returns:
            list: one generated string per input example.

        Raises:
            FileNotFoundError: if no inference dataset can be resolved.
        """
        self.config.eval_batch_size = kwargs.get("eval_batch_size", 32)

        save_path = os.path.join(
            os.getcwd(),
            "{}.{}.result.json".format(
                self.config.task_name, self.config.model.__name__
            ),
        )

        # NOTE(review): the dataset lookup uses the ASTE task code although
        # this predictor's task_code is Universal_Sentiment_Analysis -- looks
        # like a copy-paste; confirm against detect_infer_dataset before
        # changing it.
        target_file = detect_infer_dataset(
            target_file, task_code=TaskCodeOption.Aspect_Sentiment_Triplet_Extraction
        )
        if not target_file:
            raise FileNotFoundError("Can not find inference datasets!")

        self.dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)

        return self._run_prediction(
            save_path=save_path if save_result else None, print_result=print_result
        )

    def predict(
        self,
        text: Union[str, list] = None,
        print_result=True,
        ignore_error=True,
        **kwargs
    ):
        """Predict the sentiment from a sentence or a list of sentences.

        Args:
            text: the sentence, or list of sentences, to be predicted.
            print_result: whether to print the result.
            ignore_error: forwarded to the dataset when preparing the inputs.
            **kwargs: ``eval_batch_size`` (default 32) sets the inference
                batch size; everything is also forwarded to
                :meth:`_run_prediction`.

        Returns:
            For a single string: the generated string, or -- if generation
            raises -- a dict with the keys ``text``, ``output``, ``error``
            and ``error_type``. For a list: a list of generated strings.

        Raises:
            RuntimeError: if ``text`` is empty or ``None``.
        """
        self.config.eval_batch_size = kwargs.get("eval_batch_size", 32)

        if not text:
            raise RuntimeError("Please specify your datasets path!")
        self.dataset.prepare_infer_dataset(text, ignore_error=ignore_error)

        if not isinstance(text, str):
            return self._run_prediction(print_result=print_result, **kwargs)

        # Single-sentence mode is best-effort: a failure is reported in the
        # returned dict instead of being raised.
        try:
            return self._run_prediction(print_result=print_result, **kwargs)[0]
        except Exception as e:
            return {
                "text": text,
                "output": None,
                "error": str(e),
                "error_type": "RuntimeError",
            }

    def _run_prediction(self, save_path=None, print_result=True, **kwargs):
        """Generate an output string for every prepared example.

        Args:
            save_path: when given, the outputs are written to this path as a
                JSON list.
            print_result: whether to print each generated string.
            **kwargs: accepted for call compatibility; not used here.

        Returns:
            list: the decoded model outputs, in dataset order. The raw
            decoded text is returned without further post-processing.
        """
        self.model.model.eval()
        tokenizer = self.config.tokenizer
        pad_token_id = tokenizer.pad_token_id

        def collate_fn(batch):
            # Pad the variable-length id sequences of a batch to equal length.
            sequences = [torch.tensor(example["input_ids"]) for example in batch]
            return pad_sequence(
                sequences, batch_first=True, padding_value=pad_token_id
            )

        # predict()/batch_predict() publish the requested size as
        # eval_batch_size; fall back to batch_size when it is absent.
        batch_size = (
            getattr(self.config, "eval_batch_size", None) or self.config.batch_size
        )
        dataloader = DataLoader(
            self.dataset.tokenized_dataset["test"],
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        # Inputs must live on the same device as the model weights; use the
        # device the parameters are actually on, whatever the config says.
        first_param = next(self.model.model.parameters(), None)
        device = first_param.device if first_param is not None else self.config.device

        predicted_output = []
        print("Model loaded to: ", self.config.device)
        with torch.no_grad():
            for batch in tqdm(dataloader):
                output_ids = self.model.model.generate(batch.to(device))
                predicted_output.extend(
                    tokenizer.batch_decode(output_ids, skip_special_tokens=True)
                )

        if print_result:
            for output_text in predicted_output:
                fprint(output_text)

        if save_path:
            import json

            with open(save_path, "w", encoding="utf-8") as f:
                json.dump(predicted_output, f, ensure_ascii=False)
            fprint("inference result saved in: {}".format(save_path))

        return predicted_output

    def clear_input_samples(self):
        """Drop all samples currently stored on the inference dataset."""
        self.dataset.all_data = []
class Predictor(USAPredictor):
    """Alias of :class:`USAPredictor`; adds no behavior of its own."""