Source code for pyabsa.utils.check_utils.dataset_version_check

# -*- coding: utf-8 -*-
# file: dataset_version_check.py
# time: 02/11/2022 15:51
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# GScholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# ResearchGate: https://www.researchgate.net/profile/Heng-Yang-17/research
# Copyright (C) 2022. All Rights Reserved.

from packaging import version
import requests

from findfile import find_cwd_file
from termcolor import colored

from pyabsa.utils.exception_utils.exception_utils import time_out
from pyabsa.utils.pyabsa_utils import fprint


@time_out(10)
[docs] def query_local_datasets_version(**kwargs): try: fin = open(find_cwd_file(["__init__.py", "integrated_datasets"])) local_version = fin.read().split("'")[-2] fin.close() except: return None return local_version
@time_out(10)
[docs] def query_remote_datasets_version(**kwargs): logger = kwargs.get("logger", None) try: dataset_url = "https://raw.githubusercontent.com/yangheng95/ABSADatasets/v1.2/datasets/__init__.py" content = requests.get(dataset_url, timeout=5) remote_version = content.text.split("'")[-2] except Exception as e: try: dataset_url = "https://gitee.com/yangheng95/ABSADatasets/raw/v1.2/datasets/__init__.py" content = requests.get(dataset_url, timeout=5) remote_version = content.text.split("'")[-2] except Exception as e: if logger: logger.warning("Failed to query remote version") else: fprint(colored("Failed to query remote version", "red")) return None return remote_version
@time_out(10)
[docs] def check_datasets_version(**kwargs): """ Check if the local dataset version is the same as the remote dataset version. """ logger = kwargs.get("logger", None) try: local_version = query_local_datasets_version() remote_version = query_remote_datasets_version() if logger is not None: logger.info(f"Local dataset version: {local_version}") logger.info(f"Remote dataset version: {remote_version}") else: fprint(f"Local dataset version: {local_version}") fprint(f"Remote dataset version: {remote_version}") if not remote_version: if logger: logger.warning( "Failed to check ABSADatasets version, please" "check the latest version of ABSADatasets at https://github.com/yangheng95/ABSADatasets" ) else: fprint( colored( "Failed to check ABSADatasets version, please" "check the latest version of ABSADatasets at https://github.com/yangheng95/ABSADatasets", "red", ) ) if not local_version: if logger: logger.warning( "Failed to check local ABSADatasets version, please make sure you have downloaded the latest version of ABSADatasets." ) else: fprint( colored( "Failed to check local ABSADatasets version, please make sure you have downloaded the latest version of ABSADatasets.", "red", ) ) if version.parse(local_version) < version.parse(remote_version): if logger: logger.warning( "Local ABSADatasets version is lower than remote ABSADatasets version, please upgrade your ABSADatasets." ) else: fprint( colored( "Local ABSADatasets version is lower than remote ABSADatasets version, please upgrade your ABSADatasets.", "red", ) ) except Exception as e: if logger: logger.warning( "ABSADatasets version check failed: {}, please check the latest datasets at https://github.com/yangheng95/ABSADatasets manually.".format( e ) ) else: fprint( colored( "ABSADatasets version check failed: {}, please check the latest datasets at https://github.com/yangheng95/ABSADatasets manually.".format( e ), "red", ) )