pyabsa.framework.checkpoint_class.checkpoint_template

Module Contents

Classes

CheckpointManager

ASTECheckpointManager

This class manages the checkpoints for Aspect Sentiment Triplet Extraction.

APCCheckpointManager

ATEPCCheckpointManager

This class manages the checkpoints for Aspect Term Extraction and Polarity Classification.

TADCheckpointManager

This class manages the checkpoints for text adversarial defense.

RNACCheckpointManager

This class manages the checkpoints for RNA sequence classification.

RNARCheckpointManager

This class manages the checkpoints for RNA sequence regression.

TCCheckpointManager

This class manages the checkpoints for text classification.

class pyabsa.framework.checkpoint_class.checkpoint_template.CheckpointManager
parse_checkpoint(checkpoint: str | pathlib.Path = None, task_code: str = TaskCodeOption.Aspect_Polarity_Classification) -> str | pathlib.Path

Parses a given checkpoint file path or name and returns the path of the checkpoint directory.

Parameters:
  • checkpoint (Union[str, Path], optional) – Zipped checkpoint name, checkpoint path, or checkpoint name queried from Google Drive. Defaults to None.

  • task_code (str, optional) – Task code, e.g. apc, atepc, tad, rnac, rnar, tc, etc. Defaults to TaskCodeOption.Aspect_Polarity_Classification.

Returns:

The path of the checkpoint directory.

Return type:

Path

Example

from pyabsa.framework.checkpoint_class.checkpoint_template import CheckpointManager

manager = CheckpointManager()

checkpoint_path = manager.parse_checkpoint("checkpoint.zip", "apc")
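
The three accepted forms of `checkpoint` (a zipped checkpoint, a checkpoint path, or a name to be looked up remotely) can be told apart with a few filesystem checks. The following is a minimal sketch of that idea using only the standard library. `classify_checkpoint_arg` and its return labels are hypothetical and are not part of PyABSA, whose actual resolution logic may differ.

```python
import zipfile
from pathlib import Path
from typing import Optional, Union


def classify_checkpoint_arg(checkpoint: Optional[Union[str, Path]]) -> str:
    """Label a checkpoint argument (hypothetical helper, not PyABSA API)."""
    if checkpoint is None:
        # No argument given: the caller would fall back to a default checkpoint.
        return "default"
    path = Path(checkpoint)
    if path.is_dir():
        # An existing directory is taken to be an unpacked checkpoint.
        return "directory"
    if path.is_file() and zipfile.is_zipfile(path):
        # An existing zip archive is taken to be a zipped checkpoint.
        return "zip"
    # Anything else is treated as a name to be looked up remotely
    # (an assumption made for this illustration).
    return "remote-name"
```

Under these assumptions, a string such as "multilingual" that matches no local file or directory would be labeled "remote-name".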

_get_remote_checkpoint(checkpoint: str = 'multilingual', task_code: str = None) -> str

Downloads a checkpoint file and returns the path of the downloaded checkpoint.

Parameters:
  • checkpoint (str, optional) – Zipped checkpoint name, checkpoint path, or checkpoint name queried from Google Drive. Defaults to “multilingual”.

  • task_code (str, optional) – Task code, e.g. apc, atepc, tad, rnac, rnar, tc, etc. Defaults to None.

Returns:

The path of the downloaded checkpoint.

Return type:

Path

Raises:

SystemExit – If the given checkpoint file is not found.

Example

from pyabsa.framework.checkpoint_class.checkpoint_template import CheckpointManager

manager = CheckpointManager()

checkpoint_path = manager._get_remote_checkpoint("multilingual", "apc")
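
Because a missing checkpoint is reported by raising `SystemExit`, which derives from `BaseException` rather than `Exception`, a plain `except Exception` clause will not intercept it. A minimal sketch of a wrapper that converts it into an ordinary exception follows. `CheckpointNotFound`, `call_without_exit`, and `fake_download` are hypothetical names introduced for illustration; the stub stands in for the real download call.

```python
class CheckpointNotFound(RuntimeError):
    """Hypothetical error type used only in this sketch."""


def call_without_exit(func, *args, **kwargs):
    """Call func and turn a SystemExit into CheckpointNotFound."""
    try:
        return func(*args, **kwargs)
    except SystemExit as exc:
        # SystemExit.code carries the value passed to SystemExit(...).
        raise CheckpointNotFound(
            f"checkpoint lookup exited with code {exc.code!r}"
        ) from exc


def fake_download(checkpoint, task_code=None):
    """Stand-in for a download routine that exits on a missing checkpoint."""
    if checkpoint != "multilingual":
        raise SystemExit(1)
    return f"checkpoints/{checkpoint}"
```

With the real manager, one would presumably pass `manager._get_remote_checkpoint` in place of `fake_download`. The leading underscore marks that method as private by Python convention, so calling it directly may not be supported across versions.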

class pyabsa.framework.checkpoint_class.checkpoint_template.ASTECheckpointManager

Bases: CheckpointManager

This class manages the checkpoints for Aspect Sentiment Triplet Extraction.

static get_aspect_sentiment_triplet_extractor(checkpoint: str | pathlib.Path = None, **kwargs) -> pyabsa.tasks.AspectSentimentTripletExtraction.AspectSentimentTripletExtractor

Get an AspectSentimentTripletExtractor object initialized with the given checkpoint for Aspect Sentiment Triplet Extraction.

Parameters:
  • checkpoint – A string or Path object indicating the path to the checkpoint or a zip file containing the checkpoint. If the checkpoint is not registered in PyABSA, it should be the name of the checkpoint queried from Google Drive.

  • kwargs – Additional keyword arguments to be passed to the AspectSentimentTripletExtractor constructor.

Returns:

An AspectSentimentTripletExtractor object initialized with the given checkpoint.

class pyabsa.framework.checkpoint_class.checkpoint_template.APCCheckpointManager

Bases: CheckpointManager

static get_sentiment_classifier(checkpoint: str | pathlib.Path = None, **kwargs) -> pyabsa.tasks.AspectPolarityClassification.SentimentClassifier

Returns a pre-trained aspect sentiment classification model.

Parameters:
  • checkpoint (Union[str, Path], optional) – A string specifying the path to a checkpoint or the name of a checkpoint registered in PyABSA. If None, the default checkpoint is used.

  • **kwargs – Additional keyword arguments.

Returns:

A pre-trained aspect sentiment classification model.

Return type:

SentimentClassifier

Example

from pyabsa import APCCheckpointManager

sentiment_classifier = APCCheckpointManager.get_sentiment_classifier()

class pyabsa.framework.checkpoint_class.checkpoint_template.ATEPCCheckpointManager

Bases: CheckpointManager

This class manages the checkpoints for Aspect Term Extraction and Polarity Classification.

static get_aspect_extractor(checkpoint: str | pathlib.Path = None, **kwargs) -> pyabsa.tasks.AspectTermExtraction.AspectExtractor

Get an AspectExtractor object initialized with the given checkpoint for Aspect Term Extraction and Polarity Classification.

Parameters:
  • checkpoint – A string or Path object indicating the path to the checkpoint or a zip file containing the checkpoint. If the checkpoint is not registered in PyABSA, it should be the name of the checkpoint queried from Google Drive.

  • kwargs – Additional keyword arguments to be passed to the function.

Returns:

An AspectExtractor object initialized with the given checkpoint.

class pyabsa.framework.checkpoint_class.checkpoint_template.TADCheckpointManager

Bases: CheckpointManager

This class manages the checkpoints for text adversarial defense.

get_tad_text_classifier(**kwargs) -> pyabsa.tasks.TextAdversarialDefense.TADTextClassifier

Return a TADTextClassifier object initialized with the specified checkpoint.

Parameters:

checkpoint (Union[str, Path], optional) – The path to the checkpoint, the name of the zipped checkpoint, or the name of the checkpoint queried from Google Drive. Defaults to None.

Returns:

A TADTextClassifier object initialized with the given checkpoint.

Return type:

TADTextClassifier

class pyabsa.framework.checkpoint_class.checkpoint_template.RNACCheckpointManager

Bases: CheckpointManager

This class manages the checkpoints for RNA sequence classification.

static get_rna_classifier(checkpoint: str | pathlib.Path = None, **kwargs) -> pyabsa.tasks.RNAClassification.RNAClassifier

This method returns an instance of the RNAClassifier class with a parsed checkpoint for RNA sequence classification.

Parameters:
  • checkpoint (Union[str, Path], optional) – The name of the zipped checkpoint or the path to the checkpoint file. If not provided, the default checkpoint will be used. Defaults to None.

  • **kwargs – Additional keyword arguments.

Returns:

An instance of the RNAClassifier class with a parsed checkpoint for RNA sequence classification.

Return type:

RNAClassifier

Raises:

ValueError – If the provided checkpoint is not found.
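
Since a checkpoint that cannot be found is reported as a `ValueError`, a caller can try several candidates in order and fall back to the default (`None`). The sketch below shows that pattern under stated assumptions: `first_loadable` and `stub_loader` are hypothetical helpers, and the stub only imitates a loader that raises `ValueError` for unknown checkpoints.

```python
from typing import Callable, Iterable, Optional, TypeVar

T = TypeVar("T")


def first_loadable(loader: Callable[..., T], candidates: Iterable[Optional[str]]) -> T:
    """Return the result for the first candidate that loads without ValueError."""
    errors = []
    for candidate in candidates:
        try:
            return loader(checkpoint=candidate)
        except ValueError as exc:
            # Remember why this candidate failed and try the next one.
            errors.append(f"{candidate!r}: {exc}")
    raise ValueError("no candidate checkpoint could be loaded: " + "; ".join(errors))


def stub_loader(checkpoint=None):
    """Stand-in loader: only the default (None) checkpoint is available."""
    if checkpoint is not None:
        raise ValueError(f"checkpoint {checkpoint} not found")
    return "classifier:default"
```

In real use, `RNACCheckpointManager.get_rna_classifier` would take the place of `stub_loader`, for example `first_loadable(RNACCheckpointManager.get_rna_classifier, ["my_checkpoint.zip", None])`, where the file name is a placeholder.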

class pyabsa.framework.checkpoint_class.checkpoint_template.RNARCheckpointManager

Bases: CheckpointManager

This class manages the checkpoints for RNA sequence regression.

static get_rna_regressor(checkpoint: str | pathlib.Path = None, **kwargs) -> pyabsa.tasks.RNARegression.RNARegressor

Loads a pre-trained checkpoint for RNA sequence regression and returns an instance of the RNARegressor class that is ready to make predictions.

Parameters:

checkpoint (Union[str, Path], optional) – The name of a zipped checkpoint file, the path to a checkpoint file, or the name of a checkpoint queried from Google Drive. If not provided, the default checkpoint for RNA sequence regression is loaded. Defaults to None.

Returns:

An instance of the RNARegressor class that has been initialized with the specified checkpoint file.

Return type:

RNARegressor

class pyabsa.framework.checkpoint_class.checkpoint_template.TCCheckpointManager

Bases: CheckpointManager

This class manages the checkpoints for text classification.

static get_text_classifier(checkpoint: str | pathlib.Path = None, **kwargs) -> pyabsa.tasks.TextClassification.TextClassifier

Returns a TextClassifier instance loaded with a pre-trained checkpoint for text classification.

Parameters:
  • checkpoint (Union[str, Path], optional) – The name of a zipped checkpoint file, a path to a checkpoint file, or the name of a checkpoint registered in PyABSA. If None, the latest version of the default checkpoint will be used. Defaults to None.

  • **kwargs – Additional keyword arguments. Not used in this method.

Returns:

A TextClassifier instance loaded with the specified checkpoint.

Return type:

TextClassifier
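
Each manager class above exposes one loader method that takes a `checkpoint` and returns a task-specific inference object. A small dispatch table can summarize that mapping. This is a sketch, not a PyABSA feature: the `LOADERS` keys are labels chosen for this example (not necessarily PyABSA task codes), `load_inference_model` is a hypothetical helper, and the module is imported lazily so the table can be inspected without PyABSA installed.

```python
import importlib

# Module path and class/method names are taken from the reference above.
MODULE = "pyabsa.framework.checkpoint_class.checkpoint_template"

# Keys are illustrative labels chosen for this sketch.
LOADERS = {
    "aste": ("ASTECheckpointManager", "get_aspect_sentiment_triplet_extractor"),
    "apc": ("APCCheckpointManager", "get_sentiment_classifier"),
    "atepc": ("ATEPCCheckpointManager", "get_aspect_extractor"),
    "tad": ("TADCheckpointManager", "get_tad_text_classifier"),
    "rnac": ("RNACCheckpointManager", "get_rna_classifier"),
    "rnar": ("RNARCheckpointManager", "get_rna_regressor"),
    "tc": ("TCCheckpointManager", "get_text_classifier"),
}


def load_inference_model(task, checkpoint=None, **kwargs):
    """Look up the manager for `task` and call its loader (hypothetical helper)."""
    if task not in LOADERS:
        raise ValueError(f"unknown task label: {task}")
    class_name, method_name = LOADERS[task]
    # Import only when a model is actually requested.
    module = importlib.import_module(MODULE)
    manager = getattr(module, class_name)
    return getattr(manager, method_name)(checkpoint=checkpoint, **kwargs)
```

Calling `load_inference_model("tc")` would then be equivalent to `TCCheckpointManager.get_text_classifier(checkpoint=None)`, which requires PyABSA and may download a checkpoint.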