pyabsa.framework.checkpoint_class.checkpoint_template
Module Contents
Classes
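- CheckpointManager – Base class that parses local, zipped, or remote checkpoints.
- ASTECheckpointManager – This class manages the checkpoints for Aspect Sentiment Triplet Extraction.
- APCCheckpointManager – Manages the checkpoints for aspect polarity classification.
- ATEPCCheckpointManager – This class manages the checkpoints for Aspect Term Extraction and Polarity Classification.
- TADCheckpointManager – This class manages the checkpoints for text adversarial defense.
- RNACCheckpointManager – This class manages the checkpoints for RNA sequence classification.
- RNARCheckpointManager – This class manages the checkpoints for RNA sequence regression.
- TCCheckpointManager – This class manages the checkpoints for text classification.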
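Each manager above serves one task and exposes one factory method. The following minimal sketch collects those pairings into a lookup table keyed by task code. TASK_FACTORIES and factory_for are hypothetical helpers, not part of PyABSA; "rnac" is assumed as the RNA classification code, and ASTE is omitted because its task code is not given in this module.

```python
# Hypothetical lookup helper (not part of PyABSA): maps the task codes used in
# this module's docstrings to the manager class and factory method documented
# for that task. "rnac" is an assumed code for RNA sequence classification.
TASK_FACTORIES = {
    "apc": ("APCCheckpointManager", "get_sentiment_classifier"),
    "atepc": ("ATEPCCheckpointManager", "get_aspect_extractor"),
    "tad": ("TADCheckpointManager", "get_tad_text_classifier"),
    "rnac": ("RNACCheckpointManager", "get_rna_classifier"),
    "rnar": ("RNARCheckpointManager", "get_rna_regressor"),
    "tc": ("TCCheckpointManager", "get_text_classifier"),
}


def factory_for(task_code: str) -> str:
    """Return 'Manager.method' for a task code, ignoring case."""
    try:
        manager, method = TASK_FACTORIES[task_code.lower()]
    except KeyError:
        raise ValueError(f"unknown task code: {task_code!r}") from None
    return f"{manager}.{method}"


print(factory_for("APC"))  # APCCheckpointManager.get_sentiment_classifier
```

The case-insensitive lookup is a design choice of this sketch, since the docstrings write task codes in lowercase.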
- class pyabsa.framework.checkpoint_class.checkpoint_template.CheckpointManager
- parse_checkpoint(checkpoint: Union[str, pathlib.Path] = None, task_code: str = TaskCodeOption.Aspect_Polarity_Classification) → Union[str, pathlib.Path]
Parses a given checkpoint file path or name and returns the path of the checkpoint directory.
- Parameters:
checkpoint (Union[str, Path], optional) – Zipped checkpoint name, checkpoint path, or checkpoint name queried from Google Drive. Defaults to None.
task_code (str, optional) – Task code, e.g. apc, atepc, tad, rnac, rnar, or tc. Defaults to TaskCodeOption.Aspect_Polarity_Classification.
- Returns:
The path of the checkpoint directory.
- Return type:
Path
Example
from pyabsa.framework.checkpoint_class.checkpoint_template import CheckpointManager
manager = CheckpointManager()
checkpoint_path = manager.parse_checkpoint("checkpoint.zip", "apc")
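The checkpoint argument can take three forms: a checkpoint directory, a zipped checkpoint, or a name to look up remotely. As a minimal, hypothetical sketch (describe_checkpoint_arg is not part of PyABSA, and this is not its actual resolution logic), a caller could tell the forms apart with pathlib:

```python
from pathlib import Path
from typing import Union


def describe_checkpoint_arg(checkpoint: Union[str, Path]) -> str:
    """Hypothetical: classify a checkpoint argument into one of three forms.

    This only inspects the local filesystem; it does not unzip or download.
    """
    path = Path(checkpoint)
    if path.is_dir():
        return "directory"    # an already-extracted checkpoint directory
    if path.is_file() and path.suffix.lower() == ".zip":
        return "zip"          # a zipped checkpoint that would need extracting
    return "remote-name"      # anything else is treated as a name to look up


print(describe_checkpoint_arg("multilingual"))  # remote-name (if no such local path exists)
```

Both str and pathlib.Path inputs are accepted because Path() normalizes either, matching the Union[str, Path] annotation in the signature.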
- _get_remote_checkpoint(checkpoint: str = 'multilingual', task_code: str = None) → str
Downloads a checkpoint file and returns the path of the downloaded checkpoint.
- Parameters:
checkpoint (str, optional) – Zipped checkpoint name, checkpoint path, or checkpoint name queried from Google Drive. Defaults to “multilingual”.
task_code (str, optional) – Task code, e.g. apc, atepc, tad, rnac, rnar, or tc. Defaults to None.
- Returns:
The path of the downloaded checkpoint.
- Return type:
Path
- Raises:
SystemExit – If the given checkpoint file is not found.
Example
from pyabsa.framework.checkpoint_class.checkpoint_template import CheckpointManager
manager = CheckpointManager()
checkpoint_path = manager._get_remote_checkpoint("multilingual", "apc")
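Because a missing checkpoint raises SystemExit rather than an ordinary exception, callers that want to recover need to catch it explicitly. The sketch below is hypothetical: REGISTRY, its names and ids, and lookup_remote_checkpoint are made up for illustration, and the search-all-tasks behaviour for task_code=None is an assumption. Per the docstring, the real method downloads from Google Drive instead of reading a dict.

```python
from typing import Dict, Optional

# Hypothetical registry (made-up names and ids): task code -> {checkpoint name -> file id}.
REGISTRY: Dict[str, Dict[str, str]] = {
    "apc": {"multilingual": "apc-multilingual-id"},
    "atepc": {"multilingual": "atepc-multilingual-id"},
}


def lookup_remote_checkpoint(
    checkpoint: str = "multilingual", task_code: Optional[str] = None
) -> str:
    """Return the file id for a checkpoint name, or exit if it is unknown."""
    if task_code is None:
        # Assumption: with no task code, search every task's entries in order.
        for entries in REGISTRY.values():
            if checkpoint in entries:
                return entries[checkpoint]
    else:
        entries = REGISTRY.get(task_code.lower(), {})
        if checkpoint in entries:
            return entries[checkpoint]
    # Mirrors the documented behaviour: an unknown checkpoint raises SystemExit.
    raise SystemExit(f"Checkpoint {checkpoint!r} not found for task {task_code!r}")


try:
    lookup_remote_checkpoint("missing", "apc")
except SystemExit as exc:  # a plain `except Exception` would not catch this
    print("lookup failed:", exc)
```

SystemExit derives from BaseException, not Exception, so a generic `except Exception` handler lets it propagate and terminate the program.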
- class pyabsa.framework.checkpoint_class.checkpoint_template.ASTECheckpointManager
Bases:
CheckpointManager
This class manages the checkpoints for Aspect Sentiment Triplet Extraction.
- static get_aspect_sentiment_triplet_extractor(checkpoint: Union[str, pathlib.Path] = None, **kwargs) → pyabsa.tasks.AspectSentimentTripletExtraction.AspectSentimentTripletExtractor
Returns an AspectSentimentTripletExtractor object initialized with the given checkpoint for Aspect Sentiment Triplet Extraction.
- Parameters:
checkpoint – A string or Path object indicating the path to the checkpoint or a zip file containing the checkpoint. If the checkpoint is not registered in PyABSA, it should be the name of the checkpoint queried from Google Drive.
kwargs – Additional keyword arguments to be passed to the AspectSentimentTripletExtractor constructor.
- Returns:
An AspectSentimentTripletExtractor object initialized with the given checkpoint.
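Each task-specific manager subclasses CheckpointManager and exposes a static factory that returns a ready-to-use model. The sketch below shows one plausible shape of that pattern with stand-in classes. CheckpointManagerSketch, ExtractorStub, ASTEManagerSketch, the "aste" task code, and the batch_size keyword are all hypothetical; the real factories may, for example, delegate checkpoint parsing to the model constructor instead.

```python
from pathlib import Path
from typing import Any, Dict, Union


class CheckpointManagerSketch:
    """Stand-in for CheckpointManager; only normalizes the checkpoint to a Path."""

    def parse_checkpoint(self, checkpoint: Union[str, Path], task_code: str) -> Path:
        # Hypothetical: the real parse_checkpoint may also unzip or download.
        return Path(checkpoint)


class ExtractorStub:
    """Stand-in for a task model such as AspectSentimentTripletExtractor."""

    def __init__(self, checkpoint: Path, **kwargs: Any) -> None:
        self.checkpoint = checkpoint
        self.options: Dict[str, Any] = kwargs


class ASTEManagerSketch(CheckpointManagerSketch):
    @staticmethod
    def get_aspect_sentiment_triplet_extractor(
        checkpoint: Union[str, Path], **kwargs: Any
    ) -> ExtractorStub:
        # A static method has no `self`, so create a manager to use the inherited
        # parser, then forward any extra keyword arguments to the model constructor.
        path = ASTEManagerSketch().parse_checkpoint(checkpoint, task_code="aste")
        return ExtractorStub(path, **kwargs)


extractor = ASTEManagerSketch.get_aspect_sentiment_triplet_extractor("ckpt_dir", batch_size=8)
print(extractor.checkpoint, extractor.options)  # ckpt_dir {'batch_size': 8}
```

Making the factory static lets callers write Manager.get_...() without instantiating the manager, which matches how the APC example in this module calls get_sentiment_classifier.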
- class pyabsa.framework.checkpoint_class.checkpoint_template.APCCheckpointManager
Bases:
CheckpointManager
- static get_sentiment_classifier(checkpoint: Union[str, pathlib.Path] = None, **kwargs) → pyabsa.tasks.AspectPolarityClassification.SentimentClassifier
Returns a pre-trained aspect sentiment classification model.
- Parameters:
checkpoint (Union[str, Path], optional) – A string specifying the path to a checkpoint or the name of a checkpoint registered in PyABSA. If None, the default checkpoint is used.
**kwargs – Additional keyword arguments.
- Returns:
A pre-trained aspect sentiment classification model.
- Return type:
pyabsa.tasks.AspectPolarityClassification.SentimentClassifier
Example
from pyabsa import APCCheckpointManager
sentiment_classifier = APCCheckpointManager.get_sentiment_classifier()
- class pyabsa.framework.checkpoint_class.checkpoint_template.ATEPCCheckpointManager
Bases:
CheckpointManager
This class manages the checkpoints for Aspect Term Extraction and Polarity Classification.
- static get_aspect_extractor(checkpoint: Union[str, pathlib.Path] = None, **kwargs) → pyabsa.tasks.AspectTermExtraction.AspectExtractor
Get an AspectExtractor object initialized with the given checkpoint for Aspect Term Extraction and Polarity Classification.
- Parameters:
checkpoint – A string or Path object indicating the path to the checkpoint or a zip file containing the checkpoint. If the checkpoint is not registered in PyABSA, it should be the name of the checkpoint queried from Google Drive.
kwargs – Additional keyword arguments to be passed to the function.
- Returns:
An AspectExtractor object initialized with the given checkpoint.
- class pyabsa.framework.checkpoint_class.checkpoint_template.TADCheckpointManager
Bases:
CheckpointManager
This class manages the checkpoints for text adversarial defense.
- get_tad_text_classifier(**kwargs) → pyabsa.tasks.TextAdversarialDefense.TADTextClassifier
Return a TADTextClassifier object initialized with the specified checkpoint.
- Parameters:
checkpoint (Union[str, Path], optional) – The path to the checkpoint, the name of the zipped checkpoint, or the name of the checkpoint queried from Google Drive. Defaults to None.
- Returns:
A TADTextClassifier object initialized with the given checkpoint.
- Return type:
pyabsa.tasks.TextAdversarialDefense.TADTextClassifier
- class pyabsa.framework.checkpoint_class.checkpoint_template.RNACCheckpointManager
Bases:
CheckpointManager
This class manages the checkpoints for RNA sequence classification.
- static get_rna_classifier(checkpoint: Union[str, pathlib.Path] = None, **kwargs) → pyabsa.tasks.RNAClassification.RNAClassifier
Returns an RNAClassifier instance initialized from the parsed checkpoint for RNA sequence classification.
- Parameters:
checkpoint (Union[str, Path], optional) – The name of the zipped checkpoint or the path to the checkpoint file. If not provided, the default checkpoint will be used. Defaults to None.
**kwargs – Additional keyword arguments.
- Returns:
An instance of the RNAClassifier class with a parsed checkpoint for RNA sequence classification.
- Return type:
pyabsa.tasks.RNAClassification.RNAClassifier
- Raises:
ValueError – If the provided checkpoint is not found.
- class pyabsa.framework.checkpoint_class.checkpoint_template.RNARCheckpointManager
Bases:
CheckpointManager
This class manages the checkpoints for RNA sequence regression.
- static get_rna_regressor(checkpoint: Union[str, pathlib.Path] = None, **kwargs) → pyabsa.tasks.RNARegression.RNARegressor
Loads a pre-trained checkpoint for RNA sequence regression and returns an instance of the RNARegressor class that is ready to make predictions.
- Parameters:
checkpoint (Union[str, Path], optional) – The name of a zipped checkpoint file, the path to a checkpoint file, or the name of a checkpoint file that can be found in Google Drive. If not provided, the default checkpoint for RNA sequence regression is loaded. Defaults to None.
- Returns:
An instance of the RNARegressor class that has been initialized with the specified checkpoint file.
- Return type:
pyabsa.tasks.RNARegression.RNARegressor
- class pyabsa.framework.checkpoint_class.checkpoint_template.TCCheckpointManager
Bases:
CheckpointManager
This class manages the checkpoints for text classification.
- static get_text_classifier(checkpoint: Union[str, pathlib.Path] = None, **kwargs) → pyabsa.tasks.TextClassification.TextClassifier
Returns a TextClassifier instance loaded with a pre-trained checkpoint for text classification.
- Parameters:
checkpoint (Union[str, Path], optional) – The name of a zipped checkpoint file, a path to a checkpoint file, or the name of a checkpoint registered in PyABSA. If None, the latest version of the default checkpoint will be used. Defaults to None.
**kwargs – Additional keyword arguments. Not used in this method.
- Returns:
A TextClassifier instance loaded with the specified checkpoint.
- Return type:
pyabsa.tasks.TextClassification.TextClassifier