# -*- coding: utf-8 -*-
# file: aste_utils.py
# time: 05/03/2023
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2021. All Rights Reserved.
import json
import math
import pickle
import torch
from collections import OrderedDict, defaultdict
from sklearn import metrics
from transformers import BertTokenizer
# Label inventory of the tagging grid; a label's position in this list is its integer id.
label = [
    "N",
    "B-A",
    "I-A",
    "A",
    "B-O",
    "I-O",
    "O",
    "Negative",
    "Neutral",
    "Positive",
]
# Forward (name -> id) and reverse (id -> name) lookup tables, kept in list order.
label2id, id2label = OrderedDict(), OrderedDict()
for i, v in enumerate(label):
    label2id[v] = i
    id2label[i] = v
def get_spans(tags):
    """Extract inclusive ``[start, end]`` index spans from a BIO-tagged string.

    Args:
        tags: Whitespace-separated items.  Each item is classified by its
            final character only: ``B`` opens a span (closing any span that
            is still open), ``O`` closes the open span, anything else leaves
            the current span untouched.

    Returns:
        A list of ``[start, end]`` pairs of item indices, in order of
        appearance.  A span still open at the end of the string is closed at
        the last item.
    """
    items = tags.split()
    spans = []
    open_at = -1  # index where the currently open span began; -1 when none
    for idx, item in enumerate(items):
        if item.endswith("B"):
            if open_at != -1:
                spans.append([open_at, idx - 1])
            open_at = idx
        elif item.endswith("O"):
            if open_at != -1:
                spans.append([open_at, idx - 1])
                open_at = -1
    if open_at != -1:
        spans.append([open_at, len(items) - 1])
    return spans
def get_evaluate_spans(tags, length, token_range):
    """Extract inclusive word-level spans from a numeric tag sequence.

    Args:
        tags: Indexable sequence of tag values.  Only the entry at the first
            position of each word's range is read.  ``1`` opens a span,
            ``0`` closes the open span, ``-1`` is skipped, and any other
            value leaves the current span untouched.
        length: Number of words to scan.
        token_range: Per-word ``(first, last)`` positions into ``tags``.

    Returns:
        A list of ``[start, end]`` word-index pairs.  A span still open after
        the last word is closed at ``length - 1``.
    """
    spans = []
    open_at = -1  # word index where the currently open span began
    for word_idx in range(length):
        first_piece, _ = token_range[word_idx]
        tag = tags[first_piece]
        if tag == -1:
            continue
        if tag == 1:
            if open_at != -1:
                spans.append([open_at, word_idx - 1])
            open_at = word_idx
        elif tag == 0:
            if open_at != -1:
                spans.append([open_at, word_idx - 1])
                open_at = -1
    if open_at != -1:
        spans.append([open_at, length - 1])
    return spans
class Instance(object):
    """Padded tensor features for one sentence of the grid-tagging task.

    The constructor reads ``id``, ``sentence``, ``postag``, ``head``,
    ``deprel`` and ``triples`` from ``sentence_pack`` and builds:

    * ``token_range``: inclusive ``[first, last]`` word-piece positions of
      every whitespace-separated word,
    * ``bert_tokens_padding`` / ``mask``: encoded ids and their mask, padded
      to ``config.max_seq_len``,
    * ``aspect_tags`` / ``opinion_tags``: per-position sequence tags
      (1 = first word of a span, 2 = later word, -1 = ignored),
    * ``tags`` / ``tags_symmetry``: the upper-triangular label grid and its
      mirrored copy (-1 = ignored cell),
    * ``word_pair_position`` / ``word_pair_deprel`` / ``word_pair_pos`` /
      ``word_pair_synpost``: vocabulary indices for every word-piece pair.
    """

    def __init__(
        self,
        tokenizer,
        sentence_pack,
        post_vocab,
        deprel_vocab,
        postag_vocab,
        synpost_vocab,
        config,
    ):
        """Build all tensors for ``sentence_pack``.

        Args:
            tokenizer: Object whose ``encode`` returns a list of ids and
                accepts the ``padding``, ``max_length``, ``truncation`` and
                ``add_special_tokens`` keyword arguments.
            sentence_pack: Dict describing one sentence (see class docstring).
            post_vocab, deprel_vocab, postag_vocab, synpost_vocab: Vocabulary
                objects exposing ``stoi`` (a dict) and ``unk_index``.
            config: Object exposing ``max_seq_len`` and ``task``.

        Raises:
            AssertionError: If the encoded sentence length does not equal
                the sum of the per-word piece counts plus two.
        """
        self.id = sentence_pack["id"]
        self.sentence = sentence_pack["sentence"]
        self.tokens = self.sentence.strip().split()
        self.postag = sentence_pack["postag"]
        self.head = sentence_pack["head"]
        self.deprel = sentence_pack["deprel"]
        self.sen_length = len(self.tokens)
        self.token_range = []
        self.text_ids = tokenizer.encode(
            self.sentence,
            padding="do_not_pad",
            max_length=config.max_seq_len,
            truncation=True,
        )
        self.length = len(self.text_ids)

        max_len = config.max_seq_len
        self.bert_tokens_padding = torch.zeros(max_len).long()
        self.aspect_tags = torch.zeros(max_len).long()
        self.opinion_tags = torch.zeros(max_len).long()
        self.tags = torch.zeros(max_len, max_len).long()
        self.tags_symmetry = torch.zeros(max_len, max_len).long()
        self.mask = torch.zeros(max_len)

        self.bert_tokens_padding[: self.length] = torch.tensor(
            self.text_ids, dtype=torch.long
        )
        self.mask[: self.length] = 1

        self._build_token_ranges(tokenizer, max_len)
        self._init_tag_tensors()

        for triple in sentence_pack["triples"]:
            aspect_span = get_spans(triple["target_tags"])
            opinion_span = get_spans(triple["opinion_tags"])
            self._tag_spans(
                aspect_span,
                self.aspect_tags,
                label2id["B-A"],
                label2id["I-A"],
                label2id["A"],
            )
            self._tag_spans(
                opinion_span,
                self.opinion_tags,
                label2id["B-O"],
                label2id["I-O"],
                label2id["O"],
            )
            self._tag_relations(aspect_span, opinion_span, triple, config.task)

        self._mirror_tags()

        self.word_pair_position = self._position_features(post_vocab, max_len)
        self.word_pair_deprel = self._deprel_features(deprel_vocab, max_len)
        self.word_pair_pos = self._postag_features(postag_vocab, max_len)
        self.word_pair_synpost = self._synpost_features(synpost_vocab, max_len)

    def _build_token_ranges(self, tokenizer, max_len):
        """Fill ``token_range`` with the inclusive word-piece span of each word."""
        # Position 0 is left free for the leading token of the full encoding
        # (presumably [CLS] -- the assertion below expects two extra ids).
        piece_start = 1
        for word in self.tokens:
            pieces = tokenizer.encode(
                word,
                padding="do_not_pad",
                max_length=max_len,
                truncation=True,
                add_special_tokens=False,
            )
            piece_end = piece_start + len(pieces)
            self.token_range.append([piece_start, piece_end - 1])
            piece_start = piece_end
        # Fails e.g. when the sentence-level encoding was truncated.
        assert self.length == self.token_range[-1][-1] + 2, "length error"

    def _init_tag_tensors(self):
        """Mark ignored positions with -1 and clear the valid upper triangle to 0."""
        last = self.length - 1
        for sequence_tags in (self.aspect_tags, self.opinion_tags):
            sequence_tags[self.length :] = -1
            sequence_tags[0] = -1
            sequence_tags[last] = -1
        self.tags[:, :] = -1
        self.tags_symmetry[:, :] = -1
        for row in range(1, last):
            self.tags[row, row:last] = 0

    def _tag_spans(self, spans, sequence_tags, begin_id, inside_id, span_id):
        """Write one entity type (aspect or opinion) into grid and sequence tags.

        Inside each span the first diagonal cell gets ``begin_id``, every
        other diagonal cell gets ``inside_id`` and every cell above the
        diagonal gets ``span_id``.
        """
        for l, r in spans:
            start = self.token_range[l][0]
            end = self.token_range[r][1]
            for i in range(start, end + 1):
                self.tags[i][i] = begin_id if i == start else inside_id
                self.tags[i, i + 1 : end + 1] = span_id
            for word in range(l, r + 1):
                first, last = self.token_range[word]
                sequence_tags[first] = 1 if word == l else 2
                sequence_tags[first + 1 : last + 1] = -1
                # Mask the rows and columns of non-initial word pieces.
                self.tags[first + 1 : last + 1, :] = -1
                self.tags[:, first + 1 : last + 1] = -1

    def _tag_relations(self, aspect_span, opinion_span, triple, task):
        """Write the aspect-opinion relation label of ``triple`` into the grid."""
        for al, ar in aspect_span:
            for pl, pr in opinion_span:
                for i in range(al, ar + 1):
                    for j in range(pl, pr + 1):
                        sal, sar = self.token_range[i]
                        spl, spr = self.token_range[j]
                        self.tags[sal : sar + 1, spl : spr + 1] = -1
                        if task == "pair":
                            # Pair mode reuses id 7 as the generic relation tag.
                            relation = 7
                        elif task == "triplet":
                            relation = label2id[triple["sentiment"]]
                        else:
                            continue
                        # Keep the label in the upper triangle.
                        if i > j:
                            self.tags[spl][sal] = relation
                        else:
                            self.tags[sal][spl] = relation

    def _mirror_tags(self):
        """Copy the valid upper triangle of ``tags`` symmetrically into ``tags_symmetry``."""
        stop = self.length - 1
        for i in range(1, stop):
            upper = self.tags[i, i:stop]
            self.tags_symmetry[i, i:stop] = upper
            self.tags_symmetry[i:stop, i] = upper

    def _position_features(self, post_vocab, max_len):
        """Return the index of ``|row - col|`` for every pair of word pieces."""
        feats = torch.zeros(max_len, max_len).long()
        # Word ranges are contiguous, so together they cover first..last.
        first = self.token_range[0][0]
        last = self.token_range[-1][1]
        index_of = [
            post_vocab.stoi.get(dist, post_vocab.unk_index)
            for dist in range(last - first + 1)
        ]
        for row in range(first, last + 1):
            for col in range(first, last + 1):
                feats[row][col] = index_of[abs(row - col)]
        return feats

    def _deprel_features(self, deprel_vocab, max_len):
        """Return relation indices between each word and its head, both ways.

        Diagonal cells of a word's pieces get the index of ``"self"``.
        Relations missing from the vocabulary map to ``unk_index``.
        """
        feats = torch.zeros(max_len, max_len).long()
        stoi = deprel_vocab.stoi
        unk = deprel_vocab.unk_index
        self_index = stoi.get("self", unk)
        for word, (start, end) in enumerate(self.token_range):
            if start > end:  # word produced no pieces
                continue
            head = self.head[word]
            # head == 0 (no head word) is attached to position 0.
            s, e = self.token_range[head - 1] if head != 0 else (0, 0)
            relation = stoi.get(self.deprel[word], unk)
            feats[start : end + 1, s : e + 1] = relation
            feats[s : e + 1, start : end + 1] = relation
            for piece in range(start, end + 1):
                feats[piece][piece] = self_index
        return feats

    def _postag_features(self, postag_vocab, max_len):
        """Return the index of the sorted POS-tag pair for every word pair.

        Pairs missing from the vocabulary map to ``unk_index``.
        """
        feats = torch.zeros(max_len, max_len).long()
        stoi = postag_vocab.stoi
        unk = postag_vocab.unk_index
        for i, (start, end) in enumerate(self.token_range):
            for j, (s, e) in enumerate(self.token_range):
                if start > end or s > e:
                    continue
                pair = tuple(sorted([self.postag[i], self.postag[j]]))
                feats[start : end + 1, s : e + 1] = stoi.get(pair, unk)
        return feats

    def _synpost_features(self, synpost_vocab, max_len):
        """Return the index of the dependency-graph distance of every word pair.

        Distances are found by expanding head links up to three steps;
        unreached pairs keep the value 4.
        """
        n = len(self.tokens)
        neighbours = defaultdict(set)
        for word in range(n):
            head = self.head[word]
            if head == 0:
                continue
            neighbours[word].add(head - 1)
            neighbours[head - 1].add(word)
        # Visit neighbours in ascending index order.
        ordered = {word: sorted(neighbours[word]) for word in range(n)}

        degree = [[4] * n for _ in range(n)]
        for src in range(n):
            seen = {src}
            degree[src][src] = 0
            for one in ordered[src]:
                if one not in seen:
                    degree[src][one] = 1
                    seen.add(one)
                for two in ordered[one]:
                    if two not in seen:
                        degree[src][two] = 2
                        seen.add(two)
                        for three in ordered[two]:
                            if three not in seen:
                                degree[src][three] = 3
                                seen.add(three)

        feats = torch.zeros(max_len, max_len).long()
        for i, (start, end) in enumerate(self.token_range):
            for j, (s, e) in enumerate(self.token_range):
                feats[start : end + 1, s : e + 1] = synpost_vocab.stoi.get(
                    degree[i][j], synpost_vocab.unk_index
                )
        return feats

    def get_data(self):
        """Return the instance's fields and tensors as a plain dict."""
        return {
            "id": self.id,
            "sentence": self.sentence,
            "sen_length": self.sen_length,
            "token_range": self.token_range,
            "bert_tokens_padding": self.bert_tokens_padding,
            "length": self.length,
            "mask": self.mask,
            "aspect_tags": self.aspect_tags,
            "opinion_tags": self.opinion_tags,
            "tags": self.tags,
            "tags_symmetry": self.tags_symmetry,
            "word_pair_position": self.word_pair_position,
            "word_pair_deprel": self.word_pair_deprel,
            "word_pair_pos": self.word_pair_pos,
            "word_pair_synpost": self.word_pair_synpost,
        }
def load_data_instances(
    sentence_packs, post_vocab, deprel_vocab, postag_vocab, synpost_vocab, config
):
    """Build one ``Instance`` per sentence pack.

    A single ``BertTokenizer`` is loaded from ``config.pretrained_bert`` and
    shared by all instances.

    Args:
        sentence_packs: Iterable of per-sentence dicts.
        post_vocab, deprel_vocab, postag_vocab, synpost_vocab: Vocabularies
            forwarded unchanged to ``Instance``.
        config: Configuration object; also forwarded to ``Instance``.

    Returns:
        A list of ``Instance`` objects in input order.
    """
    tokenizer = BertTokenizer.from_pretrained(config.pretrained_bert)
    return [
        Instance(
            tokenizer,
            sentence_pack,
            post_vocab,
            deprel_vocab,
            postag_vocab,
            synpost_vocab,
            config,
        )
        for sentence_pack in sentence_packs
    ]
def load_tokens(filename):
    """Collect vocabulary material from a JSON dataset file.

    The file must hold a list of dicts with the keys ``sentence``,
    ``deprel`` and ``postag``.

    Args:
        filename: Path of the UTF-8 encoded JSON file.

    Returns:
        A tuple ``(tokens, deprel, postag, postag_ca, max_len)``:
        all whitespace-separated words, all dependency labels, every sorted
        pair of POS tags within each sentence (all ``n * n`` combinations),
        all single POS tags, and the largest word count of any sentence.
    """
    with open(filename, encoding="utf-8") as infile:
        data = json.load(infile)

    tokens, deprel, postag, postag_ca = [], [], [], []
    max_len = 0
    for entry in data:
        words = entry["sentence"].split()
        pos_tags = entry["postag"]
        tokens.extend(words)
        deprel.extend(entry["deprel"])
        postag_ca.extend(pos_tags)
        postag.extend(
            tuple(sorted([left, right])) for left in pos_tags for right in pos_tags
        )
        max_len = max(len(words), max_len)

    print(
        "{} tokens from {} examples loaded from {}.".format(
            len(tokens), len(data), filename
        )
    )
    return tokens, deprel, postag, postag_ca, max_len
class Metric:
    """Span, pair and triplet evaluation over predicted and golden tag grids.

    ``predictions`` and ``goldens`` are per-sentence square grids indexed as
    ``grid[row][col]``; ``tokens_ranges`` holds, per sentence, the
    ``(first, last)`` grid position of every word.
    """

    # Tag id marking an aspect-opinion relation in "pair" mode.  It matches
    # the value written for pair-task relations when instances are built in
    # this module.
    PAIR_TAG = 7
    # Sentiment tag ids in tie-break priority order (9, then 8, then 7).
    SENTIMENT_TAGS = (9, 8, 7)

    def __init__(
        self,
        config,
        predictions,
        goldens,
        bert_lengths,
        sen_lengths,
        tokens_ranges,
    ):
        """Store the evaluation inputs; nothing is computed here."""
        self.config = config
        self.predictions = predictions
        self.goldens = goldens
        self.bert_lengths = bert_lengths
        self.sen_lengths = sen_lengths
        self.tokens_ranges = tokens_ranges
        self.ignore_index = -1
        self.data_num = len(self.predictions)

    def get_spans(self, tags, length, token_range, type):
        """Return maximal runs of words whose diagonal tag equals ``type``.

        Words whose diagonal tag equals ``ignore_index`` are skipped: they
        neither open nor close a run.

        Returns:
            A list of inclusive ``[start, end]`` word-index pairs.
        """
        spans = []
        open_at = -1
        for word in range(length):
            first_piece, _ = token_range[word]
            tag = tags[first_piece][first_piece]
            if tag == self.ignore_index:
                continue
            if tag == type:
                if open_at == -1:
                    open_at = word
            elif open_at != -1:
                spans.append([open_at, word - 1])
                open_at = -1
        if open_at != -1:
            spans.append([open_at, length - 1])
        return spans

    @staticmethod
    def _count_relation_tags(tags, aspect, opinion, token_ranges):
        """Count the grid tags linking every word of ``aspect`` to ``opinion``.

        Cells are read from the upper triangle: ``tags[a][o]`` when the
        aspect span starts before the opinion span, ``tags[o][a]`` otherwise.
        Negative tags (the ignore index) are not counted.

        Returns:
            A dict mapping tag id to its number of occurrences.
        """
        al, ar = aspect
        pl, pr = opinion
        counts = {}
        for i in range(al, ar + 1):
            a_start = token_ranges[i][0]
            for j in range(pl, pr + 1):
                o_start = token_ranges[j][0]
                if al < pl:
                    tag = int(tags[a_start][o_start])
                else:
                    tag = int(tags[o_start][a_start])
                if tag < 0:
                    continue
                counts[tag] = counts.get(tag, 0) + 1
        return counts

    def find_pair(self, tags, aspect_spans, opinion_spans, token_ranges):
        """Return ``[al, ar, pl, pr, -1]`` for every linked aspect/opinion pair.

        A pair is linked when at least one of its cells carries ``PAIR_TAG``.
        """
        pairs = []
        for al, ar in aspect_spans:
            for pl, pr in opinion_spans:
                counts = self._count_relation_tags(
                    tags, (al, ar), (pl, pr), token_ranges
                )
                if counts.get(self.PAIR_TAG, 0) == 0:
                    continue
                pairs.append([al, ar, pl, pr, -1])
        return pairs

    def find_triplet(self, tags, aspect_spans, opinion_spans, token_ranges):
        """Return ``[al, ar, pl, pr, sentiment]`` for every linked pair.

        The sentiment is the most frequent of the tag ids 7, 8 and 9 among
        the pair's cells; ties are resolved in favour of 9, then 8.  Pairs
        without any such tag are dropped.
        """
        triplets = []
        for al, ar in aspect_spans:
            for pl, pr in opinion_spans:
                counts = self._count_relation_tags(
                    tags, (al, ar), (pl, pr), token_ranges
                )
                if not any(counts.get(tag, 0) for tag in self.SENTIMENT_TAGS):
                    continue
                # max() keeps the first maximal item, giving the priority order.
                sentiment = max(
                    self.SENTIMENT_TAGS, key=lambda tag: counts.get(tag, 0)
                )
                triplets.append([al, ar, pl, pr, sentiment])
        return triplets

    @staticmethod
    def _precision_recall_f1(correct, predicted, golden):
        """Return ``(precision, recall, f1)``; any undefined ratio is 0."""
        precision = correct / predicted if predicted > 0 else 0
        recall = correct / golden if golden > 0 else 0
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0
        return precision, recall, f1

    def _score_spans(self, extract_spans):
        """Score exact-match spans produced by ``extract_spans`` over all sentences."""
        assert len(self.predictions) == len(self.goldens)
        golden_set = set()
        predicted_set = set()
        for i in range(self.data_num):
            for grid, bucket in (
                (self.goldens[i], golden_set),
                (self.predictions[i], predicted_set),
            ):
                spans = extract_spans(
                    grid, self.sen_lengths[i], self.tokens_ranges[i], self.config
                )
                for span in spans:
                    # Prefix with the sentence index so spans stay distinct.
                    bucket.add(str(i) + "-" + "-".join(map(str, span)))
        return self._precision_recall_f1(
            len(golden_set & predicted_set), len(predicted_set), len(golden_set)
        )

    def score_aspect(self):
        """Return ``(precision, recall, f1)`` of exact-match aspect spans."""
        return self._score_spans(get_aspects)

    def score_opinion(self):
        """Return ``(precision, recall, f1)`` of exact-match opinion spans."""
        return self._score_spans(get_opinions)

    def _extract_tuples(self, grid, index, unknown_task_message):
        """Extract pairs or triplets (per ``config.task``) from one grid."""
        aspect_spans = get_aspects(
            grid, self.sen_lengths[index], self.tokens_ranges[index], self.config
        )
        opinion_spans = get_opinions(
            grid, self.sen_lengths[index], self.tokens_ranges[index], self.config
        )
        if self.config.task == "pair":
            finder = self.find_pair
        elif self.config.task == "triplet":
            finder = self.find_triplet
        else:
            raise ValueError(unknown_task_message)
        return finder(grid, aspect_spans, opinion_spans, self.tokens_ranges[index])

    def parse_triplet(self, golden=True):
        """Extract tuples for every sentence.

        Args:
            golden: When true, tuples are also extracted from ``goldens``;
                otherwise the first returned list stays empty.

        Returns:
            ``(all_golden_tuples, all_predicted_tuples)``, one inner list per
            sentence.

        Raises:
            ValueError: If ``config.task`` is neither ``"pair"`` nor
                ``"triplet"``.
        """
        all_golden_tuples = []
        all_predicted_tuples = []
        for i in range(self.data_num):
            if golden:
                assert len(self.predictions) == len(self.goldens)
                all_golden_tuples.append(
                    self._extract_tuples(
                        self.goldens[i],
                        i,
                        "Unknown task type: {}".format(self.config.task),
                    )
                )
            all_predicted_tuples.append(
                self._extract_tuples(
                    self.predictions[i], i, "task must be pair or triplet"
                )
            )
        return all_golden_tuples, all_predicted_tuples

    def tagReport(self):
        """Print a per-tag classification report over all non-ignored cells.

        Only the upper triangle of each golden grid is read, and the grid
        size is taken from the grid itself.
        """
        print(len(self.predictions))
        print(len(self.goldens))
        golden_tags = []
        predict_tags = []
        for i in range(self.data_num):
            golden_grid = self.goldens[i]
            predicted_grid = self.predictions[i]
            size = len(golden_grid)
            for r in range(size):
                for c in range(r, size):
                    if golden_grid[r][c] == self.ignore_index:
                        continue
                    golden_tags.append(int(golden_grid[r][c]))
                    predict_tags.append(int(predicted_grid[r][c]))
        print(len(golden_tags))
        print(len(predict_tags))
        target_names = [
            "N",
            "B-A",
            "I-A",
            "A",
            "B-O",
            "I-O",
            "O",
            "negative",
            "neutral",
            "positive",
        ]
        # Explicit labels keep the report valid when some tag never occurs.
        print(
            metrics.classification_report(
                golden_tags,
                predict_tags,
                labels=list(range(len(target_names))),
                target_names=target_names,
                digits=4,
            )
        )
def get_aspects(tags, length, token_range, config=None):
    """Decode aspect spans from the diagonal of a tag grid.

    For every word the cell ``tags[p][p]`` at the word's first position ``p``
    is mapped through ``config.index_to_label``: ``"B-A"`` opens a span,
    ``"I-A"`` extends it, any other label closes it.  Cells equal to -1 are
    skipped and never extend a span.

    Args:
        tags: Square grid indexed as ``tags[row][col]``.
        length: Number of words to scan.
        token_range: Per-word ``(first, last)`` grid positions.
        config: Object exposing ``index_to_label``.  Required in practice
            despite the ``None`` default.

    Returns:
        A list of inclusive ``[start, end]`` word-index pairs.
    """
    ignore_index = -1
    spans = []
    start, end = -1, -1
    for word in range(length):
        first_piece, _ = token_range[word]
        tag = tags[first_piece][first_piece]
        if tag == ignore_index:
            continue
        name = config.index_to_label[int(tag)]
        if name == "B-A":
            if start != -1:
                spans.append([start, end])
            start, end = word, word
        elif name == "I-A":
            end = word
        else:
            if start != -1:
                spans.append([start, end])
            start, end = -1, -1
    if start != -1:
        # Close at the last tagged word, not at length - 1, so that trailing
        # ignored words are not absorbed into the span.
        spans.append([start, end])
    return spans
def get_opinions(tags, length, token_range, config=None):
    """Decode opinion spans from the diagonal of a tag grid.

    For every word the cell ``tags[p][p]`` at the word's first position ``p``
    is mapped through ``config.index_to_label``: ``"B-O"`` opens a span,
    ``"I-O"`` extends it, any other label closes it.  Cells equal to -1 are
    skipped and never extend a span.

    Args:
        tags: Square grid indexed as ``tags[row][col]``.
        length: Number of words to scan.
        token_range: Per-word ``(first, last)`` grid positions.
        config: Object exposing ``index_to_label``.  Required in practice
            despite the ``None`` default.

    Returns:
        A list of inclusive ``[start, end]`` word-index pairs.
    """
    ignore_index = -1
    spans = []
    start, end = -1, -1
    for word in range(length):
        first_piece, _ = token_range[word]
        tag = tags[first_piece][first_piece]
        if tag == ignore_index:
            continue
        name = config.index_to_label[int(tag)]
        if name == "B-O":
            if start != -1:
                spans.append([start, end])
            start, end = word, word
        elif name == "I-O":
            end = word
        else:
            if start != -1:
                spans.append([start, end])
            start, end = -1, -1
    if start != -1:
        # Close at the last tagged word, not at length - 1, so that trailing
        # ignored words are not absorbed into the span.
        spans.append([start, end])
    return spans
class DataIterator(object):
    """Serve a list of instances as fixed-size batches on ``config.device``."""

    def __init__(self, instances, config):
        """Store the instances and compute the number of batches.

        Args:
            instances: Sequence of objects exposing the per-sentence fields
                read in ``get_batch``.
            config: Object exposing ``batch_size`` and ``device``.
        """
        self.instances = instances
        self.config = config
        self.batch_count = math.ceil(len(instances) / config.batch_size)

    def get_batch(self, index):
        """Assemble batch number ``index``.

        Tensor fields are stacked along a new first dimension and moved to
        ``config.device``; the remaining fields are returned as Python lists.

        Returns:
            A 14-tuple ``(sentence_ids, sentences, bert_tokens, lengths,
            masks, sens_lens, token_ranges, aspect_tags, tags,
            word_pair_position, word_pair_deprel, word_pair_pos,
            word_pair_synpost, tags_symmetry)``.
        """
        batch_size = self.config.batch_size
        device = self.config.device
        first = index * batch_size
        stop = min((index + 1) * batch_size, len(self.instances))
        batch = [self.instances[i] for i in range(first, stop)]

        def stacked(attribute):
            return torch.stack([getattr(item, attribute) for item in batch]).to(
                device
            )

        # ``opinion_tags`` is not part of the returned tuple, so it is not
        # stacked or moved to the device.
        return (
            [item.id for item in batch],
            [item.sentence for item in batch],
            stacked("bert_tokens_padding"),
            torch.tensor([item.length for item in batch]).to(device),
            stacked("mask"),
            [item.sen_length for item in batch],
            [item.token_range for item in batch],
            stacked("aspect_tags"),
            stacked("tags"),
            stacked("word_pair_position"),
            stacked("word_pair_deprel"),
            stacked("word_pair_pos"),
            stacked("word_pair_synpost"),
            stacked("tags_symmetry"),
        )

    def __len__(self):
        """Return the number of batches."""
        return self.batch_count

    def __iter__(self):
        """Yield every batch in order."""
        for index in range(self.batch_count):
            yield self.get_batch(index)
class VocabHelp(object):
    """Bidirectional token/index vocabulary built from frequency counts.

    ``itos`` lists the tokens by index and ``stoi`` maps each token back to
    its index.  The special tokens come first, followed by the counted tokens
    ordered by descending frequency and, within equal frequency, ascending
    token order.
    """

    def __init__(self, counter, specials=("<pad>", "<unk>")):
        """Build the vocabulary.

        Args:
            counter: Mapping from token to frequency.  It is copied, never
                modified.
            specials: Tokens placed at the front of the vocabulary; any
                count they have in ``counter`` is discarded.
        """
        self.pad_index = 0
        self.unk_index = 1
        counter = counter.copy()
        self.itos = list(specials)
        for tok in specials:
            # pop() also works for plain dicts that do not contain the token.
            counter.pop(tok, None)
        # Frequency descending, then token ascending.
        ranked = sorted(counter.items(), key=lambda item: item[0])
        ranked.sort(key=lambda item: item[1], reverse=True)
        self.itos.extend(word for word, _ in ranked)
        # stoi is simply the reverse mapping of itos.
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def __eq__(self, other):
        """Return whether both vocabularies have identical ``stoi`` and ``itos``."""
        if not isinstance(other, VocabHelp):
            return NotImplemented
        return self.stoi == other.stoi and self.itos == other.itos

    def __len__(self):
        """Return the number of tokens, special tokens included."""
        return len(self.itos)

    def extend(self, v):
        """Append the tokens of vocabulary ``v`` that are not yet known.

        Returns:
            ``self``, to allow chaining.
        """
        for word in v.itos:
            if word not in self.stoi:
                self.itos.append(word)
                self.stoi[word] = len(self.itos) - 1
        return self

    @staticmethod
    def load_vocab(vocab_path: str):
        """Load a pickled vocabulary from ``vocab_path``.

        NOTE(security): unpickling can execute arbitrary code; only load
        files from a trusted source.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary to ``vocab_path``."""
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)