# Source code for pyabsa.tasks.AspectTermExtraction.dataset_utils.__lcf__.atepc_utils

# -*- coding: utf-8 -*-
# file: atepc_utils.py
# time: 2021/5/27 0027
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.

# from transformers import AutoTokenizer
import re
import string

from pyabsa.tasks.AspectPolarityClassification.dataset_utils.__lcf__.apc_utils import (
    get_syntax_distance,
    get_lca_ids_and_cdm_vec,
    get_cdw_vec,
)
from pyabsa.utils.pyabsa_utils import fprint


# Tokenizing multilingual text is hard; one option is a pretrained tokenizer (see the commented-out line below). Adapt this to your needs.
# tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')


def simple_split_text(text):
    """Split ``text`` into a list of word / character tokens.

    Every punctuation character (``string.punctuation`` plus the extra set
    below) is padded with spaces so that it becomes a token of its own.

    * If the text contains no character from the non-Latin ranges listed
      below, the result is simply ``text.split()``.
    * Otherwise the text is scanned left to right: whitespace is skipped,
      each non-Latin character is its own token, each run of ASCII letters
      and digits is one token, and any other character is a one-character
      token.

    :param text: the raw input string.
    :return: a list of token strings (empty for empty/whitespace-only input).
    """
    text = text.strip()
    # "\\]" is the same value the original "\]" produced (a literal backslash
    # followed by "]"), spelled without an invalid escape sequence.
    chinese_punctuation = "#$%&'()*+,-/:;<=>@[\\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。"
    punctuation = string.punctuation + chinese_punctuation
    for p in punctuation:
        text = text.replace(p, " {} ".format(p))

    # Unicode ranges treated as "non-Latin" (one token per character).
    non_latin_ranges = [
        "\u4e00-\u9fa5",  # Chinese
        "\u0800-\u4e00",  # Japanese
        "\uac00-\ud7a3",  # Korean
        "\u0e00-\u0e7f",  # Thai
        "\u1000-\u109F",  # Myanmar
    ]
    non_latin_char = re.compile("[{}]".format("".join(non_latin_ranges)))
    if not non_latin_char.search(text):
        return text.split()

    alnum_run = re.compile(r"[a-zA-Z\d]+")
    tokens = []
    pos, end = 0, len(text)
    while pos < end:
        # Skip any whitespace (not only " "), matching the str.split() path.
        if text[pos].isspace():
            pos += 1
            continue
        if non_latin_char.match(text, pos):
            token = text[pos]
        else:
            run = alnum_run.match(text, pos)
            # Not an ASCII word: take just the current character.
            token = run.group(0) if run else text[pos]
        tokens.append(token)
        # Advance by index instead of rebuilding the rest of the string.
        pos += len(token)
    return tokens
def process_iob_tags(iob_tags: list) -> list:
    """Normalize aspect IOB tags in place, scanning left to right.

    For every tag after the first:

    * if the previous tag is exactly ``"O"`` and this tag contains ``"ASP"``,
      it becomes ``"B-ASP"``;
    * otherwise, if the previous tag contains ``"ASP"`` and this tag contains
      ``"B-ASP"``, it becomes ``"I-ASP"``.

    The previous tag is read after any rewrite made on the preceding step.
    The first tag is never modified.

    :param iob_tags: list of tag strings; mutated in place.
    :return: the same list object.
    """
    for idx in range(1, len(iob_tags)):
        prev_tag = iob_tags[idx - 1]
        cur_tag = iob_tags[idx]
        if prev_tag == "O" and "ASP" in cur_tag:
            iob_tags[idx] = "B-ASP"
        elif "ASP" in prev_tag and "B-ASP" in cur_tag:
            # An aspect tag directly after another aspect tag continues it.
            iob_tags[idx] = "I-ASP"
    return iob_tags
def _truncate_context(text_left, text_right, max_len):
    """Trim left/right context so their combined word count is within ``max_len``.

    Words come from ``split(" ")``. Any overflow is removed from the longer
    side only. The left side loses words from its start; the right side loses
    words from its end (ties trim the right side).

    :return: the (possibly shortened) ``(text_left, text_right)`` strings.
    """
    left_words = text_left.split(" ")
    right_words = text_right.split(" ")
    overflow = len(left_words) + len(right_words) - max_len
    if overflow > 0:
        if len(left_words) > len(right_words):
            left_words = left_words[overflow:]
        else:
            # Clamp at 0: a negative stop index would keep words from the
            # front instead of dropping them all.
            right_words = right_words[: max(len(right_words) - overflow, 0)]
    return " ".join(left_words), " ".join(right_words)


def prepare_input_for_atepc(config, tokenizer, text_left, text_right, aspect):
    """Build tokenized model inputs for one aspect of an ATEPC example.

    :param config: configuration object. Reads ``dynamic_truncate``
        (optional), ``max_seq_len`` (when truncating), ``model_name`` and
        ``use_syntax_based_SRD``. It is also forwarded to the APC helpers.
    :param tokenizer: object providing ``bos_token``, ``eos_token``,
        ``tokenize`` and ``convert_tokens_to_ids``.
    :param text_left: context text before the aspect.
    :param text_right: context text after the aspect.
    :param aspect: the aspect term text.
    :return: dict with keys ``text_raw``, ``text_spc``, ``aspect``,
        ``text_indices``, ``text_raw_bert_indices``, ``aspect_bert_indices``,
        ``lcf_cdm_vec`` and ``lcf_cdw_vec``. The last two are whatever
        ``get_lca_ids_and_cdm_vec`` / ``get_cdw_vec`` return (presumably
        local-context mask / weight vectors; see apc_utils).
    """
    if getattr(config, "dynamic_truncate", False):
        # Reserve room for the aspect's words within max_seq_len.
        text_left, text_right = _truncate_context(
            text_left, text_right, config.max_seq_len - len(aspect.split())
        )

    # Fall back to BERT-style special tokens when the tokenizer has none.
    bos_token = tokenizer.bos_token or "[CLS]"
    eos_token = tokenizer.eos_token or "[SEP]"

    text_raw = text_left + " " + aspect + " " + text_right
    # Pair layout: <bos> sentence <eos> aspect <eos>
    text_spc = (
        bos_token + " " + text_raw + " " + eos_token + " " + aspect + " " + eos_token
    )
    text_bert_tokens = tokenizer.tokenize(text_spc)
    text_raw_bert_tokens = tokenizer.tokenize(
        bos_token + " " + text_raw + " " + eos_token
    )
    aspect_bert_tokens = tokenizer.tokenize(aspect)

    text_indices = tokenizer.convert_tokens_to_ids(text_bert_tokens)
    text_raw_bert_indices = tokenizer.convert_tokens_to_ids(text_raw_bert_tokens)
    aspect_bert_indices = tokenizer.convert_tokens_to_ids(aspect_bert_tokens)

    # Token offset where the aspect starts: length of "<bos> left-context".
    aspect_begin = len(tokenizer.tokenize(bos_token + " " + text_left))

    if "lcfs" in config.model_name or config.use_syntax_based_SRD:
        syntactical_dist, _ = get_syntax_distance(text_raw, aspect, tokenizer, config)
    else:
        syntactical_dist = None

    lcf_cdm_vec = get_lca_ids_and_cdm_vec(
        config, text_indices, aspect_bert_indices, aspect_begin, syntactical_dist
    )
    lcf_cdw_vec = get_cdw_vec(
        config, text_indices, aspect_bert_indices, aspect_begin, syntactical_dist
    )

    return {
        "text_raw": text_raw,
        "text_spc": text_spc,
        "aspect": aspect,
        "text_indices": text_indices,
        "text_raw_bert_indices": text_raw_bert_indices,
        "aspect_bert_indices": aspect_bert_indices,
        "lcf_cdm_vec": lcf_cdm_vec,
        "lcf_cdw_vec": lcf_cdw_vec,
    }
def load_atepc_inference_datasets(fname):
    """Load raw sentences for ATEPC inference from one or more text files.

    Each line is cut at the first ``$LABEL$`` marker (when present). The
    markers ``[ASP]``, ``[B-ASP]`` and ``[E-ASP]`` are removed, and the result
    is stripped. Duplicate lines are dropped, keeping first-occurrence order.

    :param fname: a path string, or an iterable of paths, to UTF-8 text files.
    :return: list of unique cleaned lines in first-seen order.
    """
    if isinstance(fname, str):
        fname = [fname]
    lines = []
    for f in fname:
        fprint("loading: {}".format(f))
        # ``with`` closes the file even if reading fails.
        with open(f, "r", encoding="utf-8") as fin:
            lines.extend(fin.readlines())

    cleaned = []
    for line in lines:
        # partition() leaves lines without "$LABEL$" intact. The previous
        # find()-based slice used -1 there and dropped the last character.
        text = line.partition("$LABEL$")[0]
        text = (
            text.replace("[ASP]", "")
            .replace("[B-ASP]", "")
            .replace("[E-ASP]", "")
            .strip()
        )
        cleaned.append(text)
    # dict keeps insertion order: O(n) de-duplication keeping first occurrences.
    return list(dict.fromkeys(cleaned))