# -*- coding: utf-8 -*-
# file: data_utils_for_inferring.py
# time: 2021/4/22 0022
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import re
from typing import Union, List
import numpy as np
from pyabsa import LabelPaddingOption
from pyabsa.framework.dataset_class.dataset_template import PyABSADataset
from pyabsa.utils.file_utils.file_utils import load_dataset_from_file
from torch.utils.data import Dataset
import tqdm
from pyabsa.utils.pyabsa_utils import fprint
from .apc_utils import (
build_sentiment_window,
build_spc_mask_vec,
prepare_input_for_apc,
configure_spacy_model,
)
from .apc_utils_for_dlcf_dca import (
prepare_input_for_dlcf_dca,
configure_dlcf_spacy_model,
)
[docs]
def parse_sample(text):
    """Split an annotated inference sentence into one sample per aspect.

    Aspects are annotated either as ``[B-ASP]term[E-ASP]`` or with legacy
    ``[ASP]term[ASP]`` pairs. A sentence without any aspect tag gets a
    pseudo-aspect ``Global Sentiment`` prepended. Optional reference
    sentiments follow a ``$LABEL$`` marker as a comma-separated list, one per
    aspect in order of appearance.

    Every returned sample is the whole sentence with exactly one aspect
    wrapped in ``[ASP]...[ASP]`` (all other aspect tags removed). When the
    number of reference labels equals the number of aspects, the matching
    label is appended after ``$LABEL$``; otherwise the labels are ignored and
    a warning is printed once.

    :param text: the raw annotated sentence.
    :return: a list of per-aspect sample strings (empty if no complete aspect
        annotation is found).
    """
    if "[B-ASP]" not in text and "[ASP]" not in text:
        text = " [B-ASP]Global Sentiment[E-ASP] " + text
    _text = text
    samples = []
    # partition yields an empty ref_sent when no "$LABEL$" marker is present
    text, _, ref_sent = text.partition("$LABEL$")
    ref_sent = ref_sent.split(",") if ref_sent else []

    if "[B-ASP]" in text:
        matches = list(re.finditer(r"\[B\-ASP\](.*?)\[E\-ASP\]", text))
        with_labels = len(matches) == len(ref_sent)
        if matches and not with_labels:
            fprint(
                f"Warning: reference sentiment does not exist or its number {len(ref_sent)} "
                f"is not equal to aspect number {len(matches)}, text: {_text}"
            )
        for i, match in enumerate(matches):
            # Splice at this match's own position so that repeated aspect terms
            # each get their own sample (replace(..., 1) always hit the first).
            sample = (
                (
                    text[: match.start()]
                    + "[TEMP]"
                    + match.group(1)
                    + "[TEMP]"
                    + text[match.end():]
                )
                .replace("[B-ASP]", "")
                .replace("[E-ASP]", "")
            )
            if with_labels:
                sample += f"$LABEL${ref_sent[i]}"
            samples.append(sample.replace("[TEMP]", "[ASP]"))
    else:
        fprint(
            "[ASP] tag is detected, please use [B-ASP] and [E-ASP] to annotate aspect terms."
        )
        splits = text.split("[ASP]")
        # odd indices of `splits` are aspect terms; only complete pairs count
        num_aspects = (len(splits) - 1) // 2
        if (len(splits) - 1) % 2:
            fprint(_text, " -> Unpaired [ASP] tag is ignored.")
        with_labels = bool(ref_sent) and num_aspects == len(ref_sent)
        if ref_sent and not with_labels:
            fprint(
                _text,
                " -> Unequal length of reference sentiment and aspects, ignore the reference sentiment.",
            )
        for j in range(num_aspects):
            aspect_idx = 2 * j + 1
            # rebuild by position: keep tags only around this aspect
            sample = (
                "".join(splits[:aspect_idx])
                + "[ASP]"
                + splits[aspect_idx]
                + "[ASP]"
                + "".join(splits[aspect_idx + 1:])
            )
            if with_labels:
                sample += " $LABEL$ " + str(ref_sent[j])
            samples.append(sample)
    return samples
[docs]
class ABSAInferenceDataset(Dataset):
    """Inference-time dataset for aspect polarity classification (APC).

    Raw sentences are split into per-aspect samples by ``parse_sample`` and
    encoded with ``prepare_input_for_apc``. For the ``dlcf_dca_bert`` and
    ``dlcfs_dca_bert`` models, ``prepare_input_for_dlcf_dca`` is also used.
    The encoded samples are then converted with
    ``PyABSADataset.covert_to_tensor``. Only the feature columns listed in
    ``config.inputs_cols`` carry real values; the others are filled with ``0``.
    """

    def __init__(self, config, tokenizer):
        configure_spacy_model(config)
        self.tokenizer = tokenizer
        self.config = config
        self.data = []

    def prepare_infer_sample(self, text: Union[str, List[str]], ignore_error=True):
        """Parse and encode one sentence or a list/tuple of sentences into ``self.data``.

        :param text: an annotated sentence, or a list/tuple of them. Values of
            any other type are ignored.
        :param ignore_error: forwarded to :meth:`process_data`.
        """
        if isinstance(text, str):
            self.process_data(parse_sample(text), ignore_error=ignore_error)
        elif isinstance(text, (list, tuple)):
            examples = []
            for sample in text:
                examples.extend(parse_sample(sample))
            self.process_data(examples, ignore_error=ignore_error)

    def prepare_infer_dataset(self, infer_file, ignore_error=True):
        """Load ``infer_file`` and encode every non-empty line into ``self.data``.

        :param infer_file: passed to ``load_dataset_from_file`` with the config.
        :param ignore_error: forwarded to :meth:`process_data`; defaults to
            ``True`` like the other entry points.
        """
        lines = load_dataset_from_file(infer_file, config=self.config)
        samples = []
        for line in lines:
            if line:
                samples.extend(parse_sample(line))
        self.process_data(samples, ignore_error=ignore_error)

    def process_data(self, samples, ignore_error=True):
        """Encode parsed samples and store the tensorised result in ``self.data``.

        :param samples: strings produced by ``parse_sample``. Each must contain
            exactly one ``[ASP]...[ASP]`` span and may end with
            ``$LABEL$ <polarity>``.
        :param ignore_error: if True, a sample that fails to encode is reported
            via ``fprint`` and skipped. Otherwise processing stops.
        :return: ``self.data``.
        :raises RuntimeError: for the first failing sample when ``ignore_error``
            is False; the original exception is chained as its cause.
        """
        all_data = []
        if len(samples) > 100:
            it = tqdm.tqdm(samples, desc="preparing apc inference dataloader")
        else:
            it = samples
        for text in it:
            try:
                # ex_id counts successfully encoded samples only
                data = self._build_example(text, ex_id=len(all_data))
            except Exception as e:
                if not ignore_error:
                    raise RuntimeError(
                        "Error while processing: {} Catch Exception: {}, use ignore_error=True to remove error samples.".format(
                            text, e
                        )
                    ) from e
                fprint(
                    "Ignore error while processing: {} Error info:{}".format(text, e)
                )
                continue
            all_data.append(data)

        all_data = build_sentiment_window(
            all_data,
            self.tokenizer,
            self.config.similarity_threshold,
            input_demands=self.config.inputs_cols,
        )
        self._fill_cluster_ids(all_data)
        self.data = all_data
        self.data = PyABSADataset.covert_to_tensor(self.data)
        return self.data

    def _build_example(self, text, ex_id):
        # Encode one "[ASP]"-annotated sample string into the feature dict used by APC models.
        # handle empty lines in the inference dataset
        if text is None or "" == text.strip():
            raise RuntimeError("Invalid Input!")
        # check for a given polarity
        if "$LABEL$" in text:
            label_parts = text.split("$LABEL$")
            text, polarity = label_parts[0].strip(), label_parts[1].strip()
            polarity = polarity if polarity else LabelPaddingOption.LABEL_PADDING
            text = text.replace("[PADDING]", "")
        else:
            polarity = str(LabelPaddingOption.LABEL_PADDING)
        # strip padding that may surround an aspect at the start/end of a sentence
        text_left, aspect, text_right = text.split("[ASP]")
        text_left = text_left.replace("[PADDING] ", "")
        text_right = text_right.replace(" [PADDING]", "")

        prepared_inputs = prepare_input_for_apc(
            self.config,
            self.tokenizer,
            text_left,
            text_right,
            aspect,
            input_demands=self.config.inputs_cols,
        )
        text_raw = prepared_inputs["text_raw"]
        # the prepared aspect (not the raw split) is what the DLCF-DCA step receives
        aspect = prepared_inputs["aspect"]
        aspect_position = prepared_inputs["aspect_position"]
        text_indices = prepared_inputs["text_indices"]
        text_raw_bert_indices = prepared_inputs["text_raw_bert_indices"]
        aspect_bert_indices = prepared_inputs["aspect_bert_indices"]
        lcf_cdw_vec = prepared_inputs["lcf_cdw_vec"]
        lcf_cdm_vec = prepared_inputs["lcf_cdm_vec"]
        lcf_vec = prepared_inputs["lcf_vec"]
        lcfs_cdw_vec = prepared_inputs["lcfs_cdw_vec"]
        lcfs_cdm_vec = prepared_inputs["lcfs_cdm_vec"]
        lcfs_vec = prepared_inputs["lcfs_vec"]

        # dlcf_vec / dlcfs_vec / depend_vec / depended_vec exist only for the
        # DLCF-DCA models; the dict below reads them only if requested in inputs_cols
        if (
            self.config.model_name == "dlcf_dca_bert"
            or self.config.model_name == "dlcfs_dca_bert"
        ):
            configure_dlcf_spacy_model(self.config)
            prepared_inputs = prepare_input_for_dlcf_dca(
                self.config, self.tokenizer, text_left, text_right, aspect
            )
            dlcf_vec = (
                prepared_inputs["dlcf_cdm_vec"]
                if self.config.lcf == "cdm"
                else prepared_inputs["dlcf_cdw_vec"]
            )
            dlcfs_vec = (
                prepared_inputs["dlcfs_cdm_vec"]
                if self.config.lcf == "cdm"
                else prepared_inputs["dlcfs_cdw_vec"]
            )
            depend_vec = prepared_inputs["depend_vec"]
            depended_vec = prepared_inputs["depended_vec"]

        return {
            "ex_id": ex_id,
            "text_raw": text_raw,
            "aspect": aspect,
            "aspect_position": aspect_position,
            "lca_ids": lcf_vec,
            # the lca indices are the same as the refactored CDM (lcf != CDW or Fusion) lcf vec
            "lcf_vec": lcf_vec if "lcf_vec" in self.config.inputs_cols else 0,
            "lcf_cdw_vec": lcf_cdw_vec
            if "lcf_cdw_vec" in self.config.inputs_cols
            else 0,
            "lcf_cdm_vec": lcf_cdm_vec
            if "lcf_cdm_vec" in self.config.inputs_cols
            else 0,
            "lcfs_vec": lcfs_vec if "lcfs_vec" in self.config.inputs_cols else 0,
            "lcfs_cdw_vec": lcfs_cdw_vec
            if "lcfs_cdw_vec" in self.config.inputs_cols
            else 0,
            "lcfs_cdm_vec": lcfs_cdm_vec
            if "lcfs_cdm_vec" in self.config.inputs_cols
            else 0,
            "dlcf_vec": dlcf_vec if "dlcf_vec" in self.config.inputs_cols else 0,
            "dlcfs_vec": dlcfs_vec if "dlcfs_vec" in self.config.inputs_cols else 0,
            "depend_vec": depend_vec
            if "depend_vec" in self.config.inputs_cols
            else 0,
            "depended_vec": depended_vec
            if "depended_vec" in self.config.inputs_cols
            else 0,
            "spc_mask_vec": build_spc_mask_vec(self.config, text_raw_bert_indices)
            if "spc_mask_vec" in self.config.inputs_cols
            else 0,
            "text_indices": text_indices
            if "text_indices" in self.config.inputs_cols
            else 0,
            "aspect_bert_indices": aspect_bert_indices
            if "aspect_bert_indices" in self.config.inputs_cols
            else 0,
            "text_raw_bert_indices": text_raw_bert_indices
            if "text_raw_bert_indices" in self.config.inputs_cols
            else 0,
            "polarity": polarity,
        }

    def _fill_cluster_ids(self, all_data):
        # Expand each example's cluster membership into a max_seq_len int64 label array.
        for data in all_data:
            members = data["cluster_ids"]
            # the label is the same for every member position, so look it up once
            cluster_label = self.config.label_to_index.get(
                self.config.index_to_label.get(data["polarity"], "N.A."),
                LabelPaddingOption.SENTIMENT_PADDING,
            )
            data["cluster_ids"] = np.asarray(
                [
                    cluster_label if pad_idx in members else -100
                    for pad_idx in range(self.config.max_seq_len)
                ],
                dtype=np.int64,
            )
            data["side_ex_ids"] = np.array(0)
            # NOTE(review): this overwrites the aspect_position produced by
            # prepare_input_for_apc with 0; kept as-is, confirm it is intended.
            data["aspect_position"] = np.array(0)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)