Source code for pyabsa.utils.data_utils.dataset_manager

# -*- coding: utf-8 -*-
# file: dataset_manager.py
# time: 2021/6/8 0008
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.

import os
import shutil
import tempfile
import time
import zipfile
from pathlib import Path
from typing import Union

import git
import findfile
import requests
import tqdm

from termcolor import colored

from pyabsa.augmentation import (
    auto_aspect_sentiment_classification_augmentation,
    auto_classification_augmentation,
)
from pyabsa.framework.flag_class import (
    TaskCodeOption,
    PyABSAMaterialHostAddress,
    TaskNameOption,
)
from pyabsa.utils.check_utils.dataset_version_check import check_datasets_version
from pyabsa.utils.data_utils.dataset_item import DatasetItem
from pyabsa.utils.pyabsa_utils import fprint

# File-name fragments that disqualify a file from being treated as a dataset
# split. The list is appended to the ``exclude_key`` argument of the
# ``findfile`` searches in this module.
filter_key_words = [
    ".py",  # source code
    ".md",  # markdown documents
    "readme",
    ".log",  # training logs
    "result",
    ".zip",  # archives
    ".state_dict",  # saved weights / checkpoints
    ".model",
    ".png",  # figures
    "acc_",  # metric-named artefacts
    "f1_",
    ".backup",  # backups
    ".bak",
]
def _detect_dataset_from_exact_path(
    dataset_name_or_path,
    task_code: TaskCodeOption = None,
    load_aug=False,
    config=None,
    **kwargs
):
    """
    Collect the train/test/valid files located under an existing path.

    For each split, ``findfile.find_files`` is queried with the split keyword
    ("train", "test", "valid" or "dev") together with ``task_code``. Files named
    like ``*.dev*`` are collected into the "valid" split as well. Inference
    files, files of the other splits, files matching ``filter_key_words`` and
    ``.ignore`` files are excluded.

    :param dataset_name_or_path: existing path to search
    :param task_code: task code expected in dataset file names, e.g. "apc"
    :param load_aug: not used here; accepted for signature parity with the other
        detectors called from ``detect_dataset``
    :param config: not used here; accepted for the same reason
    :param kwargs: ``logger`` -- optional logger; ``fprint`` is used when absent
    :return: dict with keys "train", "test" and "valid", each a list of paths
    """
    logger = kwargs.get("logger", None)
    # Fall back to fprint so the function also works without a logger.
    log = logger.info if logger else fprint

    log("Detecting dataset from exact path: %s" % dataset_name_or_path)

    # (target split, keyword searched for, keywords of the other splits to exclude)
    search_plan = (
        ("train", "train", ("test.", "valid.")),
        ("test", "test", ("train.", "valid.")),
        ("valid", "valid", ("train.", "test.")),
        ("valid", "dev", ("train.", "test.")),
    )
    dataset_file = {"train": [], "test": [], "valid": []}
    for split, keyword, other_splits in search_plan:
        dataset_file[split] += findfile.find_files(
            dataset_name_or_path,
            [keyword, task_code],
            exclude_key=[".inference", *other_splits]
            + filter_key_words
            + [".ignore"],
        )

    log(
        "Make sure the dataset file names are in the correct format, e.g., train.txt.{}, test.txt.{}, valid.txt.{}".format(
            task_code, task_code, task_code
        ).lower()
    )
    return dataset_file
def detect_dataset(
    dataset_name_or_path,
    task_code: TaskCodeOption = None,
    load_aug=False,
    config=None,
    **kwargs
):
    """
    Detect dataset from dataset_path, you need to specify the task type, which can be
    TaskCodeOption.Aspect_Polarity_Classification, 'atepc' or 'tc', etc.

    Lookup order: an existing local path is searched directly. Otherwise the current
    working directory is searched, and if no training file is found there, the remote
    ABSADatasets repository is searched.

    :param dataset_name_or_path: str, Path or DatasetItem
        The name or path of the dataset.
    :param task_code: str or TaskCodeOption
        The task type, such as "apc" for aspect-polarity classification or "tc" for text classification.
    :param load_aug: bool, default False
        Whether to load the augmented dataset.
    :param config: Config, optional
        The configuration object; its ``logger`` attribute is used when present.
    :param kwargs: dict
        Additional keyword arguments forwarded to the detectors; ``logger`` is used
        when no config is given. Messages go to ``fprint`` when no logger is available.
    :return: dict
        A dictionary with keys "train", "test" and "valid" mapping to de-duplicated
        file paths (first-seen order is kept).
    :raises RuntimeError: if no training file can be located.
    """
    logger = getattr(config, "logger", None) if config else kwargs.get("logger", None)
    log = logger.info if logger else fprint
    # Forward the resolved logger through one mapping: passing ``logger=`` explicitly
    # *and* via **kwargs raises "got multiple values for keyword argument 'logger'".
    search_kwargs = dict(kwargs, logger=logger)

    is_path_like = isinstance(dataset_name_or_path, (str, Path))
    if is_path_like and os.path.exists(dataset_name_or_path):
        log(
            str(dataset_name_or_path)
            + " in the trainer is a exact path, will detect dataset from this path"
        )
        dataset_file = _detect_dataset_from_exact_path(
            dataset_name_or_path, task_code, load_aug, config, **search_kwargs
        )
    else:
        log(
            str(dataset_name_or_path)
            + " in the trainer is not a exact path, will search dataset in current working directory"
        )
        dataset_file = _detect_dataset_from_local_search(
            dataset_name_or_path, task_code, load_aug, config, **search_kwargs
        )
        if not dataset_file["train"]:
            log(
                "No "
                + str(dataset_name_or_path)
                + " found in current working directory, will search dataset from https://github.com/yangheng95/ABSADatasets"
            )
            dataset_file = _detect_dataset_from_remote_search(
                dataset_name_or_path, task_code, load_aug, config, **search_kwargs
            )

    # NOTE: to train a checkpoint on as much data as possible, the train, valid and
    # test files could be merged into dataset_file["train"] at this point.

    if not dataset_file["train"]:
        raise RuntimeError(
            'Fail to locate dataset: {}. Your dataset file names should contain task code, e.g., ".apc" or ".atepc" or "tc". '
            "you may need rename your dataset according to {}".format(
                dataset_name_or_path,
                "https://github.com/yangheng95/ABSADatasets#important-rename-your-dataset-filename-before-use-it-in-pyabsa",
            )
        )
    if not dataset_file["test"]:
        log(
            "Warning! auto_evaluate=True, however cannot find test set using for evaluating!"
        )
    # Only a collection of datasets can mix label sets. A str/Path names a single
    # dataset; its len() is the string length, and a Path has no len() at all.
    if not is_path_like and len(dataset_name_or_path) > 1:
        log(
            "Please DO NOT mix datasets with different sentiment labels for trainer & inference !"
        )

    # De-duplicate while keeping first-seen order (set() would make it arbitrary).
    return {split: list(dict.fromkeys(files)) for split, files in dataset_file.items()}
def detect_infer_dataset(
    dataset_name_or_path, task_code: TaskCodeOption = None, **kwargs
):
    """
    Locate inference files (paths containing ".inference") for one or more datasets.

    A path to an existing file is returned as a one-element list. Otherwise every entry
    of the (DatasetItem-wrapped) input is resolved:

    - An existing path is searched directly for ``.inference`` files of ``task_code``.
    - A dataset name is looked up under the current working directory. When no local
      ``integrated_datasets`` folder exists, the full ABSADatasets collection is
      downloaded first. If that download raises, a per-dataset download is attempted.

    :param dataset_name_or_path: dataset name, path, or DatasetItem
    :param task_code: task code used as a search key
    :param kwargs: other arguments; ``logger`` receives messages when given
    :return: list of inference file paths
    :raises RuntimeError: if no inference file is found
    """
    logger = kwargs.get("logger", None)

    def _notify(message, error=False):
        # Use the caller's logger when available, otherwise print.
        if not logger:
            fprint(message)
        elif error:
            logger.error(message)
        else:
            logger.info(message)

    if isinstance(dataset_name_or_path, (str, Path)) and os.path.isfile(
        dataset_name_or_path
    ):
        return [dataset_name_or_path]

    datasets = (
        dataset_name_or_path
        if isinstance(dataset_name_or_path, DatasetItem)
        else DatasetItem(dataset_name_or_path)
    )

    inference_files = []
    for name in datasets:
        if os.path.exists(name):
            inference_files += findfile.find_files(
                name, [".inference", task_code], exclude_key=["train."] + filter_key_words
            )
            continue

        if os.path.exists("integrated_datasets"):
            _notify("Try to load {} dataset from local disk".format(name))
        else:
            _notify(
                "Try to download {} dataset from https://github.com/yangheng95/ABSADatasets".format(
                    name
                )
            )
            try:
                download_all_available_datasets(logger=logger)
            except Exception:
                _notify(
                    "Fail to download dataset from https://github.com/yangheng95/ABSADatasets, please check your network connection",
                    error=True,
                )
                _notify("Try to load {} dataset from Huggingface".format(name))
                download_dataset_by_name(
                    logger=logger, task_code=task_code, dataset_name=name
                )

        search_path = findfile.find_dir(
            os.getcwd(),
            [name, task_code],
            exclude_key=filter_key_words,
            disable_alert=False,
        )
        inference_files += findfile.find_files(
            search_path,
            [".inference", name],
            exclude_key=["train."] + filter_key_words,
        )

    if not inference_files:
        if os.path.isdir(datasets.dataset_name):
            listing = ", ".join(os.listdir(datasets.dataset_name))
            fprint(
                "No inference set found from: {}, unrecognized files: {}".format(
                    datasets, listing
                )
            )
        raise RuntimeError(
            "Fail to locate dataset: {}. If you are using your own dataset, you may need rename your dataset according to {}".format(
                datasets,
                "https://github.com/yangheng95/ABSADatasets#important-rename-your-dataset-filename-before-use-it-in-pyabsa",
            )
        )

    if len(datasets) > 1:
        fprint(
            colored(
                "Please DO NOT mix datasets with different sentiment labels for trainer & inference !",
                "yellow",
            )
        )
    return inference_files
def download_all_available_datasets(**kwargs):
    """
    Download datasets by cloning the ABSADatasets repository (branch ``v2.0``).

    The repository's ``datasets`` folder is moved to ``<cwd>/integrated_datasets``,
    or to the cwd itself when it already ends with ``integrated_datasets``. GitHub is
    tried first and the Gitee mirror second. Each mirror is cloned into its own
    fresh temporary directory, so a partial clone from one mirror cannot make the
    next clone fail on a non-empty destination.

    :param kwargs: other arguments
        logger -- optional logger; ``fprint`` is used when it is absent.
        force_download -- bool, default False; delete existing datasets and re-download.
        max_retries -- int, default 3; extra rounds over all mirrors after the first
            round fails, each preceded by a 3-second pause.
    :raises RuntimeError: if every mirror failed in every round.
    """
    logger = kwargs.get("logger", None)

    def _log(message):
        # Route through the caller's logger when provided, otherwise print.
        if logger:
            logger.info(message)
        else:
            fprint(message)

    save_path = os.getcwd()
    if not save_path.endswith("integrated_datasets"):
        save_path = os.path.join(save_path, "integrated_datasets")

    if findfile.find_files(save_path, "integrated_datasets", exclude_key=".git"):
        if not kwargs.get("force_download", False):
            _log("Datasets already exist in {}, skip download".format(save_path))
            return
        shutil.rmtree(save_path)
        _log("Force download datasets from https://github.com/yangheng95/ABSADatasets")

    mirrors = (
        "https://github.com/yangheng95/ABSADatasets.git",
        "https://gitee.com/yangheng95/ABSADatasets.git",
    )
    # Bounded retries: unbounded recursion would block callers for a long time and
    # prevent them from reaching their own fallbacks.
    max_retries = max(0, int(kwargs.get("max_retries", 3)))
    last_error = None
    for attempt in range(max_retries + 1):
        if attempt:
            time.sleep(3)
        for mirror in mirrors:
            with tempfile.TemporaryDirectory() as tmpdir:
                fprint("Clone ABSADatasets from {}".format(mirror))
                try:
                    git.Repo.clone_from(mirror, tmpdir, branch="v2.0", depth=1)
                except Exception as e:  # network / git failure: try the next mirror
                    last_error = e
                    continue
                try:
                    shutil.move(os.path.join(tmpdir, "datasets"), save_path)
                except OSError as e:
                    # The clone itself succeeded, so do not re-download; but report
                    # the failed installation instead of silently ignoring it.
                    _log(
                        "Fail to move cloned datasets to {}: {}".format(save_path, e)
                    )
                return
        fprint(
            colored(
                "Exception: {}. Fail to clone ABSADatasets, please check your connection".format(
                    last_error
                ),
                "red",
            )
        )
    raise RuntimeError(
        "Fail to clone ABSADatasets after {} attempt(s), please check your connection".format(
            max_retries + 1
        )
    ) from last_error
# from pyabsa.tasks.AspectPolarityClassification import APCDatasetList
def download_dataset_by_name(
    task_code: Union[
        TaskCodeOption, str
    ] = TaskCodeOption.Aspect_Polarity_Classification,
    dataset_name: Union[DatasetItem, str] = None,
    **kwargs
):
    """
    If download all datasets failed, try to download dataset by name from Huggingface
    Download dataset from Huggingface: https://huggingface.co/spaces/yangheng/PyABSA

    The archive ``{task_code}_datasets.{dataset_name}.zip`` (lower-cased) is fetched
    from ``PyABSAMaterialHostAddress``. It is saved as ``{dataset_name}.zip``
    (lower-cased) in the current working directory and extracted there. Download
    failures are reported through the logger / ``fprint`` and are not raised.

    :param task_code: task code -> e.g., TaskCodeOption.Aspect_Polarity_Classification
    :param dataset_name: dataset name -> e.g, pyabsa.tasks.AspectPolarityClassification.APCDatasetList.Laptop14;
        a DatasetItem is expanded and each of its entries is downloaded separately.
    :param kwargs: ``logger`` -- optional logger; ``fprint`` is used when absent.
    """
    logger = kwargs.get("logger", None)

    if isinstance(dataset_name, DatasetItem):
        for d in dataset_name:
            download_dataset_by_name(task_code=task_code, dataset_name=d, **kwargs)
        # Every member has been handled; the container itself has no archive, so
        # falling through would issue a bogus request and report a spurious failure.
        return

    if logger:
        logger.info("Start {} downloading".format(dataset_name))
    url = (
        PyABSAMaterialHostAddress
        + "resolve/main/integrated_datasets/{}_datasets.{}.zip".format(
            task_code, dataset_name
        ).lower()
    )
    try:  # from Huggingface Space
        with requests.get(url, stream=True) as response:
            # Fail on HTTP errors instead of saving an error page as a ".zip".
            response.raise_for_status()
            save_path = dataset_name.lower() + ".zip"
            # content-length may be absent; tqdm accepts total=None.
            content_length = response.headers.get("content-length")
            total = int(content_length) // 1024 if content_length else None
            with open(save_path, "wb") as f:
                for chunk in tqdm.tqdm(
                    response.iter_content(chunk_size=1024),
                    unit="KiB",
                    total=total,
                    desc="Downloading ({}){} dataset".format(
                        TaskNameOption[task_code], dataset_name
                    ),
                ):
                    f.write(chunk)
        with zipfile.ZipFile(save_path, "r") as zip_ref:
            zip_ref.extractall(os.getcwd())
    except Exception as e:
        if logger:
            logger.info(
                "Exception: {}. Fail to download dataset from {}. Please check your connection".format(
                    e, url
                )
            )
        else:
            fprint(
                colored(
                    "Exception: {}. Fail to download dataset from {}. Please check your connection".format(
                        e, url
                    ),
                    "red",
                )
            )