Source code for pyabsa.tasks.AspectPolarityClassification.models.__plm__.tnet_lf_bert

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertPooler

from pyabsa.networks.dynamic_rnn import DynamicLSTM


class Absolute_Position_Embedding(nn.Module):
    def __init__(self, config, size=None, mode="sum"):
        super(Absolute_Position_Embedding, self).__init__()
        self.config = config
        self.size = size  # must be an even number
        self.mode = mode
    def forward(self, x, pos_inx):
        if (self.size is None) or (self.mode == "sum"):
            self.size = int(x.size(-1))
        batch_size, seq_len = x.size()[0], x.size()[1]
        weight = self.weight_matrix(pos_inx, batch_size, seq_len).to(
            self.config.device
        )
        x = weight.unsqueeze(2) * x
        return x
    def weight_matrix(self, pos_inx, batch_size, seq_len):
        # Proximity weights: each token is down-weighted linearly by its
        # distance to the aspect span (pos_inx holds [start, end] per sample).
        pos_inx = pos_inx.cpu().numpy()
        weight = [[] for i in range(batch_size)]
        for i in range(batch_size):
            for j in range(pos_inx[i][1]):
                relative_pos = pos_inx[i][1] - j
                weight[i].append(1 - relative_pos / 40)
            for j in range(pos_inx[i][1], seq_len):
                relative_pos = j - pos_inx[i][0]
                weight[i].append(1 - relative_pos / 40)
        weight = torch.tensor(weight)
        return weight
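
# A minimal, hedged usage sketch of the proximity weighting above (assumes a
# hypothetical config object that only needs a `device` attribute; the scaling
# constant 40 is hard-coded in weight_matrix):
#
#     from types import SimpleNamespace
#     ape = Absolute_Position_Embedding(SimpleNamespace(device="cpu"))
#     x = torch.randn(1, 6, 4)                  # (batch, seq_len, embed_dim)
#     pos_inx = torch.tensor([[2, 3]])          # aspect [start, end] = [2, 3]
#     print(ape.weight_matrix(pos_inx, 1, 6))
#     # tensor([[0.9250, 0.9500, 0.9750, 0.9750, 0.9500, 0.9250]])
#     weighted = ape(x, pos_inx)                # same shape as x
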
class TNet_LF_BERT_Unit(nn.Module):
    def __init__(self, bert, config):
        super(TNet_LF_BERT_Unit, self).__init__()
        self.embed = bert
        self.position = Absolute_Position_Embedding(config)
        self.config = config
        D = config.embed_dim  # word embedding dimension of the model
        C = config.output_dim  # number of polarity classes
        L = config.max_seq_len
        HD = config.hidden_dim
        self.lstm1 = DynamicLSTM(
            config.embed_dim,
            config.hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.lstm2 = DynamicLSTM(
            config.embed_dim,
            config.hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.convs3 = nn.Conv1d(2 * HD, 50, 3, padding=1)
        self.fc1 = nn.Linear(4 * HD, 2 * HD)
        self.fc = nn.Linear(50, C)
    def forward(self, inputs):
        text_raw_indices, aspect_indices, aspect_in_text = (
            inputs[0],
            inputs[1],
            inputs[2],
        )
        feature_len = torch.sum(text_raw_indices != 0, dim=-1)
        aspect_len = torch.sum(aspect_indices != 0, dim=-1)
        feature = self.embed(text_raw_indices)["last_hidden_state"]
        aspect = self.embed(aspect_indices)["last_hidden_state"]
        v, (_, _) = self.lstm1(feature, feature_len)
        e, (_, _) = self.lstm2(aspect, aspect_len)
        v = v.transpose(1, 2)  # (batch, 2 * hidden_dim, context_len)
        e = e.transpose(1, 2)  # (batch, 2 * hidden_dim, aspect_len)
        for i in range(2):
            a = torch.bmm(e.transpose(1, 2), v)
            a = F.softmax(a, 1)  # (aspect_len, context_len)
            aspect_mid = torch.bmm(e, a)
            aspect_mid = torch.cat((aspect_mid, v), dim=1).transpose(1, 2)
            aspect_mid = F.relu(self.fc1(aspect_mid).transpose(1, 2))
            v = aspect_mid + v
            v = self.position(v.transpose(1, 2), aspect_in_text).transpose(1, 2)
        e = e.float()
        v = v.float()
        z = F.relu(self.convs3(v))  # [(N, Co, L), ...] * len(Ks)
        z = F.max_pool1d(z, z.size(2)).squeeze(2)
        return z
        # out = self.fc(z)
        # return {'logits': out}
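
# A hedged sketch of running a single unit (assumptions: a locally available
# "bert-base-uncased" checkpoint and a hypothetical SimpleNamespace config whose
# embed_dim matches the 768-dim BERT hidden size; note the unit returns a
# (batch, 50) pooled feature, not logits):
#
#     from types import SimpleNamespace
#     from transformers import AutoModel
#     bert = AutoModel.from_pretrained("bert-base-uncased")
#     config = SimpleNamespace(embed_dim=768, hidden_dim=768, output_dim=3,
#                              max_seq_len=80, device="cpu")
#     unit = TNet_LF_BERT_Unit(bert, config)
#     text = torch.randint(1, 1000, (2, 80))     # (batch, seq_len) token ids
#     aspect = torch.randint(1, 1000, (2, 5))    # aspect token ids
#     bounds = torch.tensor([[3, 5], [10, 12]])  # aspect [start, end] per sample
#     z = unit([text, aspect, bounds])           # -> torch.Size([2, 50])
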
class TNet_LF_BERT(nn.Module):
    inputs = [
        "text_indices",
        "aspect_indices",
        "aspect_boundary",
        "left_aspect_indices",
        "left_aspect_boundary",
        "right_aspect_indices",
        "right_aspect_boundary",
    ]
    def __init__(self, bert, config):
        super(TNet_LF_BERT, self).__init__()
        self.config = config
        # The left/right units are only instantiated when config.lsa is enabled.
        self.asgcn_left = TNet_LF_BERT_Unit(bert, config) if self.config.lsa else None
        self.asgcn_central = TNet_LF_BERT_Unit(bert, config)
        self.asgcn_right = TNet_LF_BERT_Unit(bert, config) if self.config.lsa else None
        self.dropout = nn.Dropout(config.dropout)
        self.pooler = BertPooler(bert.config)
        self.linear = nn.Linear(50 * 3, self.config.output_dim)
        self.dense = nn.Linear(50, self.config.output_dim)
    def forward(self, inputs):
        res = {"logits": None}
        if self.config.lsa:
            cat_feat = torch.cat(
                (
                    self.asgcn_left(
                        [
                            inputs["text_indices"],
                            inputs["left_aspect_indices"],
                            inputs["left_aspect_boundary"],
                        ]
                    ),
                    self.asgcn_central(
                        [
                            inputs["text_indices"],
                            inputs["aspect_indices"],
                            inputs["aspect_boundary"],
                        ]
                    ),
                    self.asgcn_right(
                        [
                            inputs["text_indices"],
                            inputs["right_aspect_indices"],
                            inputs["right_aspect_boundary"],
                        ]
                    ),
                ),
                -1,
            )
            cat_feat = self.dropout(cat_feat)
            res["logits"] = self.linear(cat_feat)
        else:
            cat_feat = self.asgcn_central(
                [
                    inputs["text_indices"],
                    inputs["aspect_indices"],
                    inputs["aspect_boundary"],
                ]
            )
            cat_feat = self.dropout(cat_feat)
            res["logits"] = self.dense(cat_feat)
        return res
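
# A hedged end-to-end sketch (not part of the module): how the full model might
# be driven with the dict-style `inputs` keys declared above. The config here is
# a hypothetical stand-in; pyabsa normally builds it from its own config manager.
#
#     from types import SimpleNamespace
#     from transformers import AutoModel
#     bert = AutoModel.from_pretrained("bert-base-uncased")
#     config = SimpleNamespace(embed_dim=768, hidden_dim=768, output_dim=3,
#                              max_seq_len=80, dropout=0.5, lsa=False,
#                              device="cpu")
#     model = TNet_LF_BERT(bert, config)
#     batch = {
#         "text_indices": torch.randint(1, 1000, (2, 80)),
#         "aspect_indices": torch.randint(1, 1000, (2, 5)),
#         "aspect_boundary": torch.tensor([[3, 5], [10, 12]]),
#     }
#     out = model(batch)
#     print(out["logits"].shape)  # torch.Size([2, 3]) with lsa=False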