Source code for pyabsa.tasks.AspectPolarityClassification.models.__lcf__.dlcf_dca_bert

# -*- coding: utf-8 -*-
# file: dlcf_dca_bert.py
# time: 2021/5/23 0023
# author: xumayi <xumayi@m.scnu.edu.cn>
# github: https://github.com/XuMayi
# Copyright (C) 2021. All Rights Reserved.

import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertPooler

from pyabsa.networks.sa_encoder import Encoder


def weight_distrubute_local(
    bert_local_out, depend_weight, depended_weight, depend_vec, depended_vec, config
):
    bert_local_out2 = torch.zeros_like(bert_local_out)
    depend_vec2 = torch.mul(depend_vec, depend_weight.unsqueeze(2))
    depended_vec2 = torch.mul(depended_vec, depended_weight.unsqueeze(2))
    bert_local_out2 = (
        bert_local_out2
        + torch.mul(bert_local_out, depend_vec2)
        + torch.mul(bert_local_out, depended_vec2)
    )
    for j in range(depend_weight.size()[0]):
        bert_local_out2[j][0] = bert_local_out[j][0]
    return bert_local_out2

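# Minimal shape sketch (illustrative only, not part of the original module): how
# weight_distrubute_local re-weights the LCF-masked local features. All names
# prefixed with "_" and all sizes below are hypothetical example values.
_B, _L, _H = 2, 8, 768
_local = torch.rand(_B, _L, _H)  # LCF-masked local BERT features
_dep_w = torch.rand(_B, 1)  # per-sample weight for dependent tokens
_ded_w = torch.rand(_B, 1)  # per-sample weight for depended tokens
_dep_v = torch.randint(0, 2, (_B, _L, 1)).float()  # dependent-token mask
_ded_v = torch.randint(0, 2, (_B, _L, 1)).float()  # depended-token mask
_out = weight_distrubute_local(_local, _dep_w, _ded_w, _dep_v, _ded_v, None)
assert _out.shape == (_B, _L, _H)  # the [CLS] position is copied through unchanged
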
class PointwiseFeedForward(nn.Module):
    """A two-feed-forward-layer module"""

    def __init__(self, d_hid, d_inner_hid=None, d_out=None, dropout=0):
        super(PointwiseFeedForward, self).__init__()
        if d_inner_hid is None:
            d_inner_hid = d_hid
        if d_out is None:
            d_out = d_inner_hid
        self.w_1 = nn.Conv1d(d_hid, d_inner_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_inner_hid, d_out, 1)  # position-wise
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
    def forward(self, x):
        output = self.relu(self.w_1(x.transpose(1, 2)))
        output = self.w_2(output).transpose(2, 1)
        output = self.dropout(output)
        return output

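# Illustrative only: PointwiseFeedForward maps (batch, seq_len, d_hid) to
# (batch, seq_len, d_out) via two 1x1 convolutions. The model below uses it as
# mean_pooling_double to fuse the concatenated local/global features
# (2 * hidden -> hidden). The sizes used here are arbitrary example values.
_pff = PointwiseFeedForward(d_hid=1536, d_inner_hid=768, d_out=768)
_fused = _pff(torch.rand(2, 8, 1536))  # -> torch.Size([2, 8, 768])
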
class DLCF_DCA_BERT(nn.Module):
    inputs = [
        "text_indices",
        "text_raw_bert_indices",
        "dlcf_vec",
        "depend_vec",
        "depended_vec",
    ]

    def __init__(self, bert, config):
        super(DLCF_DCA_BERT, self).__init__()
        self.bert4global = bert
        self.bert4local = self.bert4global
        self.hidden = config.embed_dim
        self.config = config
        self.config.bert_dim = config.embed_dim
        self.dropout = nn.Dropout(config.dropout)
        self.bert_SA_ = Encoder(bert.config, config)
        self.mean_pooling_double = PointwiseFeedForward(
            self.hidden * 2, self.hidden, self.hidden
        )
        self.bert_pooler = BertPooler(bert.config)
        self.dense = nn.Linear(self.hidden, config.output_dim)

        self.dca_sa = nn.ModuleList()
        self.dca_pool = nn.ModuleList()
        self.dca_lin = nn.ModuleList()
        for i in range(config.dca_layer):
            self.dca_sa.append(Encoder(bert.config, config))
            self.dca_pool.append(BertPooler(bert.config))
            self.dca_lin.append(
                nn.Sequential(
                    nn.Linear(config.bert_dim, config.bert_dim * 2),
                    nn.GELU(),
                    nn.Linear(config.bert_dim * 2, 1),
                    nn.Sigmoid(),
                )
            )
    def weight_calculate(self, sa, pool, lin, d_w, ded_w, depend_out, depended_out):
        depend_sa_out = sa(depend_out)
        depend_sa_out = self.dropout(depend_sa_out)
        depended_sa_out = sa(depended_out)
        depended_sa_out = self.dropout(depended_sa_out)

        depend_pool_out = pool(depend_sa_out)
        depend_pool_out = self.dropout(depend_pool_out)
        depended_pool_out = pool(depended_sa_out)
        depended_pool_out = self.dropout(depended_pool_out)

        depend_weight = lin(depend_pool_out)
        depend_weight = self.dropout(depend_weight)
        depended_weight = lin(depended_pool_out)
        depended_weight = self.dropout(depended_weight)

        for i in range(depend_weight.size()[0]):
            depend_weight[i] = depend_weight[i].item() * d_w[i].item()
            depended_weight[i] = depended_weight[i].item() * ded_w[i].item()
            weight_sum = depend_weight[i].item() + depended_weight[i].item()
            if weight_sum != 0:
                depend_weight[i] = (
                    2 * depend_weight[i] / weight_sum
                ) ** self.config.dca_p
                if depend_weight[i] > 2:
                    depend_weight[i] = 2
                depended_weight[i] = (
                    2 * depended_weight[i] / weight_sum
                ) ** self.config.dca_p
                if depended_weight[i] > 2:
                    depended_weight[i] = 2
            else:
                depend_weight[i] = 1
                depended_weight[i] = 1
        return depend_weight, depended_weight
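    # Note on the update above (worked numbers are illustrative, assuming
    # dca_p = 1): if the scaled sigmoid scores for one sample are 0.6 (depend)
    # and 0.2 (depended), then weight_sum = 0.8, and the normalised weights
    # become (2 * 0.6 / 0.8) ** 1 = 1.5 and (2 * 0.2 / 0.8) ** 1 = 0.5, each
    # capped at 2. With dca_p = 1 the two weights always average to 1, so each
    # DCA layer redistributes emphasis between dependent and depended tokens
    # rather than rescaling the local features overall.
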
    def forward(self, inputs):
        if self.config.use_bert_spc:
            text_indices = inputs["text_indices"]
        else:
            text_indices = inputs["text_raw_bert_indices"]
        text_local_indices = inputs["text_raw_bert_indices"]
        lcf_matrix = inputs["dlcf_vec"].unsqueeze(2)
        depend_vec = inputs["depend_vec"].unsqueeze(2)
        depended_vec = inputs["depended_vec"].unsqueeze(2)

        global_context_features = self.bert4global(text_indices)["last_hidden_state"]
        local_context_features = self.bert4local(text_local_indices)[
            "last_hidden_state"
        ]

        bert_local_out = torch.mul(local_context_features, lcf_matrix)

        depend_weight = torch.ones(bert_local_out.size()[0])
        depended_weight = torch.ones(bert_local_out.size()[0])

        for i in range(self.config.dca_layer):
            depend_out = torch.mul(bert_local_out, depend_vec)
            depended_out = torch.mul(bert_local_out, depended_vec)
            depend_weight, depended_weight = self.weight_calculate(
                self.dca_sa[i],
                self.dca_pool[i],
                self.dca_lin[i],
                depend_weight,
                depended_weight,
                depend_out,
                depended_out,
            )
            bert_local_out = weight_distrubute_local(
                bert_local_out,
                depend_weight,
                depended_weight,
                depend_vec,
                depended_vec,
                self.config,
            )

        out_cat = torch.cat((bert_local_out, global_context_features), dim=-1)
        out_cat = self.mean_pooling_double(out_cat)
        out_cat = self.bert_SA_(out_cat)
        out_cat = self.bert_pooler(out_cat)
        dense_out = self.dense(out_cat)

        return {"logits": dense_out, "hidden_state": out_cat}
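
# A minimal sketch of the input contract expected by DLCF_DCA_BERT.forward.
# This is not part of the original module: all names and sizes below are
# illustrative assumptions, and in practice pyabsa's data utilities build this
# dict and the trainer constructs the model with a prepared BERT backbone and
# config.
_batch, _seq_len = 2, 80
_dummy_inputs = {
    "text_indices": torch.randint(0, 30000, (_batch, _seq_len)),  # token ids (BERT-SPC format when use_bert_spc)
    "text_raw_bert_indices": torch.randint(0, 30000, (_batch, _seq_len)),  # raw-text token ids
    "dlcf_vec": torch.rand(_batch, _seq_len),  # dynamic local-context-focus weight per token
    "depend_vec": torch.randint(0, 2, (_batch, _seq_len)).float(),  # mask of dependent tokens
    "depended_vec": torch.randint(0, 2, (_batch, _seq_len)).float(),  # mask of depended tokens
}
# e.g. model(_dummy_inputs) would return a dict with
# "logits" of shape (batch, output_dim) and "hidden_state" of shape (batch, hidden).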