# -*- coding: utf-8 -*-
# file: checkpoint_template.py
# time: 2021/6/11 0011
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import os
from pathlib import Path
from typing import Union
from findfile import find_file
from termcolor import colored
from pyabsa import TaskCodeOption
from pyabsa.framework.checkpoint_class.checkpoint_utils import (
available_checkpoints,
download_checkpoint,
)
from pyabsa.utils.file_utils.file_utils import unzip_checkpoint
from pyabsa.utils.pyabsa_utils import fprint
from pyabsa.tasks.AspectPolarityClassification import SentimentClassifier
from pyabsa.tasks.AspectTermExtraction import AspectExtractor
from pyabsa.tasks.TextAdversarialDefense import TADTextClassifier
from pyabsa.tasks.RNAClassification import RNAClassifier
from pyabsa.tasks.RNARegression import RNARegressor
from pyabsa.tasks.TextClassification import TextClassifier
from pyabsa.tasks.AspectSentimentTripletExtraction import (
AspectSentimentTripletExtractor,
)
[docs]
class CheckpointManager:
[docs]
def parse_checkpoint(
self,
checkpoint: Union[str, Path] = None,
task_code: str = TaskCodeOption.Aspect_Polarity_Classification,
) -> Union[str, Path]:
"""
Parse a given checkpoint file path or name and returns the path of the checkpoint directory.
Args:
checkpoint (Union[str, Path], optional): Zipped checkpoint name, checkpoint path, or checkpoint name queried from Google Drive. Defaults to None.
task_code (str, optional): Task code, e.g. apc, atepc, tad, rnac_datasets, rnar, tc, etc. Defaults to TaskCodeOption.Aspect_Polarity_Classification.
Returns:
Path: The path of the checkpoint directory.
Example:
```
manager = CheckpointManager()
checkpoint_path = manager.parse_checkpoint("checkpoint.zip", "apc")
```
"""
if isinstance(checkpoint, str) or isinstance(checkpoint, Path):
# directly load checkpoint from local path
if os.path.exists(checkpoint):
return checkpoint
try:
self._get_remote_checkpoint(checkpoint, task_code)
except Exception as e:
fprint(
"No checkpoint found in Model Hub for task: {}".format(checkpoint)
)
if find_file(os.getcwd(), [checkpoint, task_code, ".config"]):
# load checkpoint from current working directory with task specified
checkpoint_config = find_file(
os.getcwd(), [checkpoint, task_code, ".config"]
)
else:
# load checkpoint from current working directory without task specified
checkpoint_config = find_file(os.getcwd(), [checkpoint, ".config"])
if checkpoint_config:
# locate the checkpoint directory
checkpoint = os.path.dirname(checkpoint_config)
elif isinstance(checkpoint, str) and checkpoint.endswith(".zip"):
checkpoint = unzip_checkpoint(
checkpoint
if os.path.exists(checkpoint)
else find_file(os.getcwd(), checkpoint)
)
return checkpoint
[docs]
def _get_remote_checkpoint(
self, checkpoint: str = "multilingual", task_code: str = None
) -> str:
"""
Downloads a checkpoint file and returns the path of the downloaded checkpoint.
Args:
checkpoint (str, optional): Zipped checkpoint name, checkpoint path, or checkpoint name queried from Google Drive. Defaults to "multilingual".
task_code (str, optional): Task code, e.g. apc, atepc, tad, rnac_datasets, rnar, tc, etc. Defaults to None.
Returns:
Path: The path of the downloaded checkpoint.
Raises:
SystemExit: If the given checkpoint file is not found.
Example:
```
manager = CheckpointManager()
checkpoint_path = manager._get_remote_checkpoint("multilingual", "apc")
```
"""
available_checkpoint_by_task = available_checkpoints(task_code)
if checkpoint.lower() in [
k.lower() for k in available_checkpoint_by_task.keys()
]:
fprint(colored("Downloading checkpoint:{} ".format(checkpoint), "green"))
else:
fprint(
colored(
"Checkpoint:{} is not found, you can raise an issue for requesting shares of checkpoints".format(
checkpoint
),
"red",
)
)
return download_checkpoint(
task=task_code,
language=checkpoint.lower(),
checkpoint=available_checkpoint_by_task[checkpoint.lower()],
)
[docs]
class ASTECheckpointManager(CheckpointManager):
"""
This class manages the checkpoints for Aspect Sentiment Term Extraction.
"""
def __init__(self):
"""
Initializes an instance of the ASTECheckpointManager class.
"""
super(ASTECheckpointManager, self).__init__()
@staticmethod
[docs]
class APCCheckpointManager(CheckpointManager):
def __init__(self):
"""
Initializes an instance of the APCCheckpointManager class.
"""
super(APCCheckpointManager, self).__init__()
@staticmethod
[docs]
def get_sentiment_classifier(
checkpoint: Union[str, Path] = None, **kwargs
) -> "SentimentClassifier":
"""
Returns a pre-trained aspect sentiment classification model.
Args:
checkpoint (Union[str, Path], optional): A string specifying the path to a checkpoint or the name of a
checkpoint registered in PyABSA. If `None`, the default checkpoint is used.
**kwargs: Additional keyword arguments.
Returns:
SentimentClassifier: A pre-trained aspect sentiment classification model.
Example:
from pyabsa import APCCheckpointManager
sentiment_classifier = APCCheckpointManager.get_sentiment_classifier()
"""
return SentimentClassifier(
CheckpointManager().parse_checkpoint(
checkpoint, TaskCodeOption.Aspect_Polarity_Classification
)
)
[docs]
class ATEPCCheckpointManager(CheckpointManager):
"""
This class manages the checkpoints for Aspect Term Extraction and Polarity Classification.
"""
def __init__(self):
"""
Initializes an instance of the ATEPCCheckpointManager class.
"""
super(ATEPCCheckpointManager, self).__init__()
@staticmethod
[docs]
class TADCheckpointManager(CheckpointManager):
"""
This class manages the checkpoints for text adversarial defense.
"""
def __init__(self):
"""
Initializes an instance of the TADCheckpointManager class.
"""
super(TADCheckpointManager, self).__init__()
[docs]
def get_tad_text_classifier(
checkpoint: Union[str, Path] = None, **kwargs
) -> "TADTextClassifier":
"""
Return a TADTextClassifier object initialized with the specified checkpoint.
Args:
checkpoint (Union[str, Path], optional): The path to the checkpoint, the name of the zipped checkpoint, or
the name of the checkpoint queried from Google Drive. Defaults to None.
Returns:
TADTextClassifier: A TADTextClassifier object initialized with the given checkpoint.
"""
return TADTextClassifier(
CheckpointManager().parse_checkpoint(
checkpoint, TaskCodeOption.Text_Adversarial_Defense
)
)
[docs]
class RNACCheckpointManager(CheckpointManager):
"""
This class manages the checkpoints for RNA sequence classification.
"""
def __init__(self):
"""
Initializes an instance of the RNACCheckpointManager class.
"""
super(RNACCheckpointManager, self).__init__()
@staticmethod
[docs]
def get_rna_classifier(
checkpoint: Union[str, Path] = None, **kwargs
) -> "RNAClassifier":
"""
This method returns an instance of the RNAClassifier class with a parsed checkpoint for RNA sequence classification.
Args:
checkpoint (Union[str, Path], optional): The name of the zipped checkpoint or the path to the checkpoint file. If not provided, the default checkpoint will be used. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
RNAClassifier: An instance of the RNAClassifier class with a parsed checkpoint for RNA sequence classification.
Raises:
ValueError: If the provided checkpoint is not found.
"""
return RNAClassifier(
CheckpointManager().parse_checkpoint(
checkpoint, TaskCodeOption.RNASequenceClassification
)
)
[docs]
class RNARCheckpointManager(CheckpointManager):
"""
This class manages the checkpoints for RNA sequence regression.
"""
def __init__(self):
"""
Initializes an instance of the RNARCheckpointManager class.
"""
super(RNARCheckpointManager, self).__init__()
@staticmethod
[docs]
def get_rna_regressor(
checkpoint: Union[str, Path] = None, **kwargs
) -> "RNARegressor":
"""
Loads a pre-trained checkpoint for RNA sequence regression and returns an instance of the RNARegressor class
that is ready to make predictions.
:param checkpoint: (Optional) The name of a zipped checkpoint file, the path to a checkpoint file, or the name of
a checkpoint file that can be found in Google Drive. If `checkpoint` is not provided, the default checkpoint
for RNA sequence regression will be loaded.
:type checkpoint: Union[str, Path]
:return: An instance of the RNARegressor class that has been initialized with the specified checkpoint file.
:rtype: RNARegressor
"""
return RNARegressor(
CheckpointManager().parse_checkpoint(
checkpoint, TaskCodeOption.RNASequenceRegression
)
)
[docs]
class TCCheckpointManager(CheckpointManager):
"""
This class manages the checkpoints for text classification.
"""
def __init__(self):
"""
Initializes an instance of the TCCheckpointManager class.
"""
super(TCCheckpointManager, self).__init__()
@staticmethod
[docs]
def get_text_classifier(
checkpoint: Union[str, Path] = None, **kwargs
) -> "TextClassifier":
"""
Returns a TextClassifier instance loaded with a pre-trained checkpoint for text classification.
Args:
checkpoint (Union[str, Path], optional): The name of a zipped checkpoint file, a path to a checkpoint file,
or the name of a checkpoint registered in PyABSA. If None, the latest version of the default checkpoint
will be used. Defaults to None.
**kwargs: Additional keyword arguments. Not used in this method.
Returns:
TextClassifier: A TextClassifier instance loaded with the specified checkpoint.
"""
return TextClassifier(
CheckpointManager().parse_checkpoint(
checkpoint, TaskCodeOption.Text_Classification
)
)