Source code for pyabsa.tasks.AspectPolarityClassification.dataset_utils.__lcf__.apc_utils_for_dlcf_dca

# -*- coding: utf-8 -*-
# file: apc_utils_for_dlcf_dca.py
# time: 2021/5/23 0023
# author: xumayi <xumayi@m.scnu.edu.cn>
# github: https://github.com/XuMayi
# Copyright (C) 2021. All Rights Reserved.

import math
import os

import networkx as nx
import numpy as np
import spacy
import termcolor

from pyabsa.utils.pyabsa_utils import fprint
from .apc_utils import text_to_sequence, get_syntax_distance


def prepare_input_for_dlcf_dca(config, tokenizer, text_left, text_right, aspect):
    if hasattr(config, "dynamic_truncate") and config.dynamic_truncate:
        _max_seq_len = config.max_seq_len - len(aspect.split(" "))
        text_left = text_left.split(" ")
        text_right = text_right.split(" ")
        if _max_seq_len < (len(text_left) + len(text_right)):
            cut_len = len(text_left) + len(text_right) - _max_seq_len
            if len(text_left) > len(text_right):
                text_left = text_left[cut_len:]
            else:
                text_right = text_right[: len(text_right) - cut_len]
        text_left = " ".join(text_left)
        text_right = " ".join(text_right)

        # secondary truncation: keep roughly half of the remaining token budget
        # on each side of the aspect so the joined text fits in max_seq_len
        text_left = " ".join(
            text_left.split(" ")[
                int(-(config.max_seq_len - len(aspect.split())) / 2) - 1 :
            ]
        )
        text_right = " ".join(
            text_right.split(" ")[
                : int((config.max_seq_len - len(aspect.split())) / 2) + 1
            ]
        )

    bos_token = tokenizer.bos_token if tokenizer.bos_token else "[CLS]"
    eos_token = tokenizer.eos_token if tokenizer.eos_token else "[SEP]"

    text_raw = text_left + " " + aspect + " " + text_right
    text_spc = (
        bos_token + " " + text_raw + " " + eos_token + " " + aspect + " " + eos_token
    )
    text_indices = text_to_sequence(tokenizer, text_spc, config.max_seq_len)
    aspect_bert_indices = text_to_sequence(tokenizer, aspect, config.max_seq_len)

    aspect_begin = len(tokenizer.tokenize(bos_token + " " + text_left))
    # if 'dlcf' in config.model_name or config.use_syntax_based_SRD:
    #     syntactical_dist, max_dist = get_syntax_distance(text_raw, aspect, tokenizer, config)
    # else:
    #     syntactical_dist = None
    syntactical_dist, max_dist = get_syntax_distance(
        text_raw, aspect, tokenizer, config
    )

    # dlcf_* vectors ignore the syntactic distances; dlcfs_* vectors use them
    dlcf_cdm_vec = get_dynamic_cdm_vec(
        config,
        max_dist,
        text_indices,
        aspect_bert_indices,
        aspect_begin,
        syntactical_dist=None,
    )
    dlcf_cdw_vec = get_dynamic_cdw_vec(
        config,
        max_dist,
        text_indices,
        aspect_bert_indices,
        aspect_begin,
        syntactical_dist=None,
    )
    dlcfs_cdm_vec = get_dynamic_cdm_vec(
        config,
        max_dist,
        text_indices,
        aspect_bert_indices,
        aspect_begin,
        syntactical_dist,
    )
    dlcfs_cdw_vec = get_dynamic_cdw_vec(
        config,
        max_dist,
        text_indices,
        aspect_bert_indices,
        aspect_begin,
        syntactical_dist,
    )
    depend_vec, depended_vec = calculate_cluster(text_raw, aspect, config)

    inputs = {
        "dlcf_cdm_vec": dlcf_cdm_vec,
        "dlcf_cdw_vec": dlcf_cdw_vec,
        "dlcfs_cdm_vec": dlcfs_cdm_vec,
        "dlcfs_cdw_vec": dlcfs_cdw_vec,
        "depend_vec": depend_vec,
        "depended_vec": depended_vec,
    }

    return inputs
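# Illustrative usage sketch (not part of the original module). It assumes a
# transformers tokenizer and a minimal config namespace; any attribute beyond
# max_seq_len, dynamic_truncate, dlcf_a and spacy_model is hypothetical and may
# differ from the real pyabsa config object.
#
#     from types import SimpleNamespace
#     from transformers import AutoTokenizer
#
#     config = SimpleNamespace(
#         max_seq_len=80, dynamic_truncate=True, dlcf_a=2, spacy_model="en_core_web_sm"
#     )
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     configure_dlcf_spacy_model(config)  # sets the module-level `nlp` used by calculate_cluster
#     inputs = prepare_input_for_dlcf_dca(
#         config, tokenizer, "the", "was friendly and fast", "service"
#     )
#     # every value in `inputs` is a float32 vector of length config.max_seq_len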
def get_dynamic_cdw_vec(
    config,
    max_dist,
    bert_spc_indices,
    aspect_indices,
    aspect_begin,
    syntactical_dist=None,
):
    # set the dynamic threshold and calculate the CDW vector for DLCF_DCA_BERT
    a = config.dlcf_a
    if max_dist > 0:
        dynamic_threshold = math.log(max_dist, a) + a - 1
    else:
        dynamic_threshold = 3
    cdw_vec = np.zeros((config.max_seq_len), dtype=np.float32)
    aspect_len = np.count_nonzero(aspect_indices)
    text_len = (
        np.count_nonzero(bert_spc_indices) - np.count_nonzero(aspect_indices) - 1
    )
    if syntactical_dist is not None:
        for i in range(min(text_len, config.max_seq_len)):
            if max_dist > 0:
                if syntactical_dist[i] > dynamic_threshold:
                    w = 1 - syntactical_dist[i] / max_dist
                    cdw_vec[i] = w
                else:
                    cdw_vec[i] = 1
            else:
                cdw_vec[i] = 1
    else:
        local_context_begin = max(0, aspect_begin - dynamic_threshold)
        local_context_end = min(
            aspect_begin + aspect_len + dynamic_threshold - 1, config.max_seq_len
        )
        for i in range(min(text_len, config.max_seq_len)):
            if i < local_context_begin:
                w = 1 - (local_context_begin - i) / text_len
            elif local_context_begin <= i <= local_context_end:
                w = 1
            else:
                w = 1 - (i - local_context_end) / text_len
            try:
                assert 0 <= w <= 1  # exception
            except:
                fprint("Warning! invalid CDW weight:", w)
            cdw_vec[i] = w  # apply the computed weight instead of discarding it
    return cdw_vec
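# Worked example of the dynamic threshold above (illustrative numbers): with
# config.dlcf_a = 2 and a maximum syntactic distance max_dist = 8,
#     dynamic_threshold = log_2(8) + 2 - 1 = 4
# so a token at syntactic distance <= 4 from the aspect keeps weight 1, while a
# token at distance 6 is down-weighted to 1 - 6 / 8 = 0.25.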
def get_dynamic_cdm_vec(
    config,
    max_dist,
    bert_spc_indices,
    aspect_indices,
    aspect_begin,
    syntactical_dist=None,
):
    # set the dynamic threshold and calculate the CDM vector for DLCF_DCA_BERT
    a = config.dlcf_a
    if max_dist > 0:
        dynamic_threshold = math.log(max_dist, a) + a - 1
    else:
        dynamic_threshold = 3
    cdm_vec = np.zeros((config.max_seq_len), dtype=np.float32)
    aspect_len = np.count_nonzero(aspect_indices)
    text_len = (
        np.count_nonzero(bert_spc_indices) - np.count_nonzero(aspect_indices) - 1
    )
    if syntactical_dist is not None:
        for i in range(min(text_len, config.max_seq_len)):
            if syntactical_dist[i] <= dynamic_threshold:
                cdm_vec[i] = 1
    else:
        local_context_begin = max(0, aspect_begin - dynamic_threshold)
        local_context_end = min(
            aspect_begin + aspect_len + dynamic_threshold - 1, config.max_seq_len
        )
        for i in range(min(text_len, config.max_seq_len)):
            if local_context_begin <= i <= local_context_end:
                cdm_vec[i] = 1
    return cdm_vec
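# By contrast with the graded CDW weights, CDM is a hard mask: with the same
# illustrative threshold of 4, a token at syntactic distance 3 gets
# cdm_vec[i] = 1 and a token at distance 6 keeps the initial value 0, where
# CDW would assign it the weight 0.25.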
def configure_dlcf_spacy_model(config):
    if not hasattr(config, "spacy_model"):
        config.spacy_model = "en_core_web_sm"
    global nlp
    try:
        nlp = spacy.load(config.spacy_model)
    except:
        fprint(
            "Cannot load {} from spaCy, trying to download it in order to parse the syntax tree:".format(
                config.spacy_model
            ),
            termcolor.colored(
                "\npython -m spacy download {}".format(config.spacy_model), "green"
            ),
        )
        try:
            os.system("python -m spacy download {}".format(config.spacy_model))
            nlp = spacy.load(config.spacy_model)
        except:
            raise RuntimeError(
                "Download failed, please download {} manually.".format(
                    config.spacy_model
                )
            )
    return nlp
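# Illustrative call order (not part of the original module): calculate_cluster
# below reads the module-level `nlp` set here, so the spaCy model should be
# configured before clustering. The config attributes shown are assumptions.
#
#     from types import SimpleNamespace
#
#     config = SimpleNamespace(max_seq_len=80, spacy_model="en_core_web_sm")
#     configure_dlcf_spacy_model(config)  # loads (or downloads) en_core_web_sm
#     depend_vec, depended_vec = calculate_cluster(
#         "the service was friendly", "service", config
#     )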
def calculate_cluster(sentence, aspect, config):
    terms = [a.lower() for a in aspect.split()]
    doc_list = []
    doc = [a.lower() for a in sentence.split()]
    for i in range(len(doc)):
        doc_list.append(i)
    doc = nlp(sentence.strip())

    # Load spacy's dependency tree into a networkx graph
    edges = []
    cnt = 0
    term_ids = [0] * len(terms)
    for token in doc:
        # Record the position of aspect terms
        if cnt < len(terms) and token.lower_ == terms[cnt]:
            term_ids[cnt] = token.i
            cnt += 1
        for child in token.children:
            edges.append((token.i, child.i))
    graph = nx.DiGraph(edges)
    graph2 = nx.Graph(edges)

    # Record tokens without a path to an aspect term in the undirected graph
    no_connect = []
    for i, word in enumerate(doc):
        source = i
        for j in term_ids:
            target = j
            try:
                nx.shortest_path_length(graph2, source=source, target=target)
            except:
                if (i not in no_connect) and (i not in term_ids):
                    no_connect.append(i)

    # Tokens reachable from the aspect terms in the directed dependency graph
    depend_ids = []
    depended_ids = doc_list
    for k in range(len(terms)):
        temp_aspect_ids = term_ids[k]
        try:
            temp_nodes = list(nx.dfs_preorder_nodes(graph, source=temp_aspect_ids))
        except:
            temp_nodes = [temp_aspect_ids]
        for i in range(len(temp_nodes)):
            if temp_nodes[i] not in depend_ids:
                depend_ids.append(temp_nodes[i])

    # The remaining tokens (minus the aspect terms and unconnected tokens)
    # form the depended cluster
    for i in range(len(depend_ids)):
        s = depend_ids[i]
        if s in depended_ids:
            depended_ids.remove(s)
    for i in range(len(terms)):
        temp_aspect_ids = term_ids[i]
        if temp_aspect_ids in depend_ids:
            depend_ids.remove(temp_aspect_ids)
    for i in range(len(terms)):
        temp_aspect_ids = term_ids[i]
        if temp_aspect_ids in depended_ids:
            depended_ids.remove(temp_aspect_ids)
    for i in range(len(no_connect)):
        if no_connect[i] in depended_ids:
            depended_ids.remove(no_connect[i])

    # Convert the two clusters into indicator vectors, shifted by one position
    # so that index 0 lines up with the prepended BOS/[CLS] token
    depend_vec = np.zeros((config.max_seq_len), dtype=np.float32)
    depended_vec = np.zeros((config.max_seq_len), dtype=np.float32)
    depended_vec[0] = 1
    depend_vec[0] = 1
    for i in range(len(depend_ids)):
        if depend_ids[i] < (config.max_seq_len - 1):
            depend_vec[depend_ids[i] + 1] = 1
    for i in range(len(depended_ids)):
        if depended_ids[i] < (config.max_seq_len - 1):
            depended_vec[depended_ids[i] + 1] = 1
    return depend_vec, depended_vec
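# Illustrative result (assuming en_core_web_sm and max_seq_len = 10): for the
# sentence "the service was friendly" with aspect "service", the determiner
# "the" is typically parsed as a dependent of "service" and so lands in
# depend_vec, while "was" and "friendly" fall into depended_vec, roughly
#     depend_vec   = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
#     depended_vec = [1, 0, 0, 1, 1, 0, 0, 0, 0, 0]
# The exact split depends on the spaCy parse, so treat these values as a
# sketch rather than a guaranteed output.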