pyabsa.framework.checkpoint_class.checkpoint_utils

Module Contents

Functions

parse_checkpoint_info(t_checkpoint_map, task_code[, ...])

Prints available model checkpoints for a given task and version.

available_checkpoints([task_code, show_ckpts]) → Union[Dict[str, Any], Dict[str, Dict[str, Any]]]

Retrieves the available checkpoints for a given task.

download_checkpoint(task, language, checkpoint) → str

Download a pretrained checkpoint for a given task and language.

pyabsa.framework.checkpoint_class.checkpoint_utils.parse_checkpoint_info(t_checkpoint_map, task_code, show_ckpts=False)

Prints available model checkpoints for a given task and version.

Parameters:
  • t_checkpoint_map – A dictionary of checkpoint information.

  • task_code – A string representing the task code (e.g. apc, atepc, tad, rnac, rnar, tc).

  • show_ckpts – A boolean flag indicating whether to show checkpoint information.

Returns:

A dictionary of checkpoint information.
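The print-and-return behavior described above can be sketched as follows. This is a minimal, self-contained illustration, not PyABSA's implementation: `parse_checkpoint_info_sketch` and the layout of the example map (checkpoint name mapped to a dictionary of metadata fields) are hypothetical, and the version filtering mentioned in the summary is omitted.

```python
from typing import Any, Dict


def parse_checkpoint_info_sketch(
    t_checkpoint_map: Dict[str, Dict[str, Any]],
    task_code: str,
    show_ckpts: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical stand-in: list checkpoint names, optionally print their
    metadata, and return the map unchanged."""
    print(f"Available {task_code} model checkpoints:")
    for name, info in t_checkpoint_map.items():
        print(f"  {name}")
        if show_ckpts:
            # Print every metadata field of this checkpoint.
            for key, value in info.items():
                print(f"    {key}: {value}")
    return t_checkpoint_map


# Hypothetical checkpoint map: checkpoint name -> metadata fields.
example_map = {
    "english": {"Language": "English", "Description": "example entry"},
    "multilingual": {"Language": "Multilingual", "Description": "example entry"},
}
result = parse_checkpoint_info_sketch(example_map, "apc", show_ckpts=True)
```

With show_ckpts=False only the checkpoint names are printed; the returned dictionary is the same object that was passed in.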

pyabsa.framework.checkpoint_class.checkpoint_utils.available_checkpoints(task_code: pyabsa.framework.flag_class.TaskCodeOption = None, show_ckpts: bool = False) → Dict[str, Any] | Dict[str, Dict[str, Any]]

Retrieves the available checkpoints for a given task.

Parameters:
  • task_code – The code of the task. It should be one of the constants in TaskCodeOption, e.g. TaskCodeOption.Aspect_Polarity_Classification. See TaskCodeOption, for example:

      from pyabsa import TaskCodeOption
      TaskCodeOption.Aspect_Polarity_Classification
      TaskCodeOption.Aspect_Term_Extraction_and_Classification
      TaskCodeOption.Sentiment_Analysis
      TaskCodeOption.Text_Classification
      TaskCodeOption.Text_Adversarial_Defense

  • show_ckpts – A flag indicating whether to show detailed information about all checkpoints.

Returns:

A dictionary with the available checkpoints for the specified task. If no task code is provided, a dictionary with all available checkpoints is returned.
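The two return shapes in the signature (one task's checkpoints versus all tasks' checkpoints) can be illustrated with an in-memory registry. This is a hedged sketch: `available_checkpoints_sketch`, `CHECKPOINT_REGISTRY`, and the plain-string task keys are hypothetical; the real function takes a TaskCodeOption value and retrieves the checkpoint information itself.

```python
from typing import Any, Dict, Optional, Union

# Hypothetical registry: task code -> {checkpoint name -> metadata}.
CHECKPOINT_REGISTRY: Dict[str, Dict[str, Any]] = {
    "APC": {"english": {"Language": "English"}},
    "ATEPC": {"multilingual": {"Language": "Multilingual"}},
}


def available_checkpoints_sketch(
    task_code: Optional[str] = None,
    show_ckpts: bool = False,
) -> Union[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    """Return one task's checkpoints, or every task's when task_code is None."""
    if task_code is None:
        selected = CHECKPOINT_REGISTRY  # all tasks, keyed by task code
    else:
        selected = CHECKPOINT_REGISTRY[task_code]  # KeyError if unknown (sketch only)
    if show_ckpts:
        # Print the selected checkpoint information when requested.
        print(selected)
    return selected


apc_ckpts = available_checkpoints_sketch("APC")
all_ckpts = available_checkpoints_sketch()
```

The union return type is why callers should check whether they passed a task code before indexing into the result.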

pyabsa.framework.checkpoint_class.checkpoint_utils.download_checkpoint(task: str, language: str, checkpoint: dict) → str

Download a pretrained checkpoint for a given task and language.

The download_checkpoint() function downloads a checkpoint from a URL using the requests library and saves it to a temporary directory whose name corresponds to the task and language. If the checkpoint has already been downloaded to that directory, the function simply returns the directory path. Otherwise, it unzips the downloaded checkpoint file, removes the zip file, and returns the path of the directory containing the unzipped checkpoint. If the download is unsuccessful, a ConnectionError is raised.

Parameters:
  • task – A string representing the task to download the checkpoint for (e.g. “sentiment_analysis”).

  • language – A string representing the language to download the checkpoint for (e.g. “english”).

  • checkpoint – A dictionary containing the information about the checkpoint to download.

Returns:

A string representing the path to the downloaded checkpoint.
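The download, cache, and unzip steps described above can be sketched with the standard library alone. This is a hedged sketch, not the library's code: `download_checkpoint_sketch`, its `url`, `cache_root`, and `fetch` parameters, and the `{task}_{language}` directory name are hypothetical. It fetches with `urllib.request` instead of requests so it has no third-party dependency, and it takes a URL directly rather than a checkpoint dictionary. Passing the download function in as `fetch` lets the usage at the end run without network access.

```python
import io
import os
import tempfile
import urllib.request
import zipfile
from typing import Callable


def _fetch_with_urllib(url: str) -> bytes:
    """Default fetcher (assumption: urllib instead of requests)."""
    try:
        with urllib.request.urlopen(url) as response:
            return response.read()
    except OSError as exc:  # URLError and HTTPError are OSError subclasses
        raise ConnectionError(f"Failed to download {url}") from exc


def download_checkpoint_sketch(
    task: str,
    language: str,
    url: str,
    cache_root: str,
    fetch: Callable[[str], bytes] = _fetch_with_urllib,
) -> str:
    """Download and unzip a checkpoint once, returning its directory path."""
    # Hypothetical naming scheme: one directory per (task, language) pair.
    dest_dir = os.path.join(cache_root, f"{task}_{language}")
    if os.path.isdir(dest_dir):
        return dest_dir  # already downloaded: reuse the cached copy

    data = fetch(url)  # may raise ConnectionError before anything is written
    os.makedirs(dest_dir)
    zip_path = os.path.join(dest_dir, "checkpoint.zip")
    with open(zip_path, "wb") as f:
        f.write(data)

    # Unzip the archive in place, then delete the zip file.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)
    os.remove(zip_path)
    return dest_dir


# Usage with an in-memory archive instead of a real network download.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("model.config", "dummy")
calls = []


def fake_fetch(url: str) -> bytes:
    calls.append(url)  # record each download attempt
    return buffer.getvalue()


cache_root = tempfile.mkdtemp()
url = "https://example.com/ckpt.zip"  # placeholder URL, never contacted
path = download_checkpoint_sketch("apc", "english", url, cache_root, fake_fetch)
again = download_checkpoint_sketch("apc", "english", url, cache_root, fake_fetch)
```

Fetching before creating the destination directory means a failed download leaves no empty directory behind that a later call would mistake for a cached checkpoint.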