Source code for pyabsa.tasks.AspectPolarityClassification.models.__plm__.mgan_bert

# -*- coding: utf-8 -*-
# file: mgan_bert.py
# author: gene_zc <gene_zhangchen@163.com>
# Copyright (C) 2018. All Rights Reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F
from pyabsa.networks.dynamic_rnn import DynamicLSTM


class LocationEncoding(nn.Module):
    """Scale token features by their distance from the aspect span.

    Tokens inside the aspect span get weight 0. Every other token gets
    ``1 - distance / context_len``, where ``distance`` is the number of
    positions to the nearest aspect boundary and ``context_len`` is the
    sequence length minus the aspect length.
    """

    def __init__(self, config):
        """Store ``config``; only ``config.device`` is read by this module."""
        super(LocationEncoding, self).__init__()
        self.config = config

    def forward(self, x, pos_inx):
        """Multiply ``x`` by the per-token location weights.

        Args:
            x: tensor whose first two dimensions are batch and sequence; the
                weights are broadcast over a third (feature) dimension.
            pos_inx: integer tensor holding, per row, the inclusive
                ``[start, end]`` positions of the aspect span.

        Returns:
            ``x`` scaled position-wise, on ``config.device``.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        weight = self.weight_matrix(pos_inx, batch_size, seq_len)
        # Keep the product in x's dtype so the weighting never silently
        # promotes the activations to another floating-point type.
        target_dtype = x.dtype if x.is_floating_point() else weight.dtype
        weight = weight.to(device=self.config.device, dtype=target_dtype)
        return weight.unsqueeze(2) * x

    def weight_matrix(self, pos_inx, batch_size, seq_len):
        """Build the ``batch_size x seq_len`` matrix of location weights.

        Args:
            pos_inx: integer tensor with at least ``batch_size`` rows of
                inclusive ``[start, end]`` aspect positions.
            batch_size: number of rows of ``pos_inx`` to use.
            seq_len: number of positions to produce per row.

        Returns:
            Floating-point tensor (default dtype) on ``pos_inx``'s device.

        Raises:
            IndexError: if ``pos_inx`` has fewer than ``batch_size`` rows.
        """
        if pos_inx.size(0) < batch_size:
            raise IndexError(
                "pos_inx has {} rows but batch_size is {}".format(
                    pos_inx.size(0), batch_size
                )
            )
        span = pos_inx[:batch_size]
        start = span[:, 0].unsqueeze(1)  # batch_size x 1
        end = span[:, 1].unsqueeze(1)  # batch_size x 1
        positions = torch.arange(seq_len, device=span.device).unsqueeze(0)

        # Distance to the aspect: measured from `start` on the left side and
        # from `end` on the right side.
        distance = torch.where(positions < start, start - positions, positions - end)

        # Number of non-aspect tokens. Clamped so a sequence that is entirely
        # aspect (context_len == 0) does not divide by zero; all of its
        # positions are masked to 0 below anyway.
        context_len = (seq_len - (end - start + 1)).clamp(min=1)

        dtype = torch.get_default_dtype()
        weight = 1 - distance.to(dtype) / context_len.to(dtype)

        in_aspect = (positions >= start) & (positions <= end)
        return weight.masked_fill(in_aspect, 0.0)
class AlignmentMatrix(nn.Module):
    """Pairwise context/aspect alignment scores.

    For every context position ``i`` and aspect position ``j`` the score is
    ``[ctx_i ; asp_j ; ctx_i * asp_j] . w_u``, where ``w_u`` is a learnable
    vector of length ``6 * config.hidden_dim``.
    """

    def __init__(self, config):
        """Create the scoring vector ``w_u`` of shape ``(6 * hidden_dim, 1)``."""
        super(AlignmentMatrix, self).__init__()
        self.config = config
        self.w_u = nn.Parameter(torch.empty(6 * config.hidden_dim, 1))
        # torch.empty (like the legacy torch.Tensor constructor) returns
        # uninitialised memory, so give the parameter well-defined starting
        # values. An external initializer may still overwrite them.
        nn.init.xavier_uniform_(self.w_u)

    def forward(self, batch_size, ctx, asp):
        """Compute the alignment matrix between ``ctx`` and ``asp``.

        Args:
            batch_size: unused; kept so existing callers keep working (the
                batch size is taken from the inputs themselves).
            ctx: context features, batch x ctx_len x 2*hidden_dim.
            asp: aspect features, batch x asp_len x 2*hidden_dim.

        Returns:
            Tensor of shape batch x ctx_len x asp_len on ``config.device``.
        """
        # Split w_u into the parts that multiply ctx, asp and ctx * asp.
        # The dot product with the concatenated feature then decomposes as
        #   ctx_i . w_ctx + asp_j . w_asp + sum_d ctx_i[d] * w_prod[d] * asp_j[d]
        # which replaces a ctx_len x asp_len Python loop with three matmuls.
        w_ctx, w_asp, w_prod = self.w_u.chunk(3, dim=0)

        ctx_term = ctx.matmul(w_ctx)  # batch x ctx_len x 1
        asp_term = asp.matmul(w_asp).transpose(1, 2)  # batch x 1 x asp_len
        prod_term = (ctx * w_prod.squeeze(-1)).matmul(
            asp.transpose(1, 2)
        )  # batch x ctx_len x asp_len

        alignment_mat = ctx_term + asp_term + prod_term
        return alignment_mat.to(self.config.device)
class MGAN_BERT(nn.Module):
    """Multi-grained attention classifier on top of a BERT encoder.

    Context and aspect tokens are embedded with ``bert``, encoded by two
    bidirectional ``DynamicLSTM`` layers, and combined through coarse-grained
    (pooled-vector) and fine-grained (alignment-matrix) attention in both
    directions. The four resulting vectors are concatenated and projected to
    ``config.output_dim`` logits.
    """

    # Keys this model reads from the ``inputs`` dict passed to ``forward``.
    inputs = ["text_indices", "aspect_indices", "left_indices"]

    def __init__(self, bert, config):
        """Build the sub-modules.

        Args:
            bert: encoder called on index tensors; its result is indexed with
                ``"last_hidden_state"``.
            config: object providing ``embed_dim``, ``hidden_dim``,
                ``output_dim`` and ``device``.
        """
        super(MGAN_BERT, self).__init__()
        self.config = config
        self.embed = bert
        self.ctx_lstm = DynamicLSTM(
            config.embed_dim,
            config.hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.asp_lstm = DynamicLSTM(
            config.embed_dim,
            config.hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.location = LocationEncoding(config)
        self.w_a2c = nn.Parameter(
            torch.empty(2 * config.hidden_dim, 2 * config.hidden_dim)
        )
        self.w_c2a = nn.Parameter(
            torch.empty(2 * config.hidden_dim, 2 * config.hidden_dim)
        )
        # torch.empty (like the legacy torch.Tensor constructor) returns
        # uninitialised memory, so give both bilinear matrices well-defined
        # starting values. An external initializer may still overwrite them.
        nn.init.xavier_uniform_(self.w_a2c)
        nn.init.xavier_uniform_(self.w_c2a)
        self.alignment = AlignmentMatrix(config)
        self.dense = nn.Linear(8 * config.hidden_dim, config.output_dim)

    def forward(self, inputs):
        """Compute polarity logits.

        Args:
            inputs: dict with ``"text_indices"``, ``"aspect_indices"`` and
                ``"left_indices"`` index tensors; index 0 is treated as
                padding when lengths are counted.

        Returns:
            ``{"logits": tensor}`` with one row of ``config.output_dim``
            scores per batch element.
        """
        text_raw_indices = inputs["text_indices"]  # batch_size x seq_len
        aspect_indices = inputs["aspect_indices"]
        text_left_indices = inputs["left_indices"]
        batch_size = text_raw_indices.size(0)

        # Lengths are the counts of non-zero indices.
        ctx_len = torch.sum(text_raw_indices != 0, dim=1)
        asp_len = torch.sum(aspect_indices != 0, dim=1)
        left_len = torch.sum(text_left_indices != 0, dim=-1)
        # Inclusive [start, end] of the aspect inside the text.
        aspect_in_text = torch.cat(
            [left_len.unsqueeze(-1), (left_len + asp_len - 1).unsqueeze(-1)], dim=-1
        )

        ctx = self.embed(text_raw_indices)[
            "last_hidden_state"
        ]  # batch_size x seq_len x embed_dim
        asp = self.embed(aspect_indices)[
            "last_hidden_state"
        ]  # batch_size x seq_len x embed_dim

        ctx_out, (_, _) = self.ctx_lstm(ctx, ctx_len)
        ctx_out = self.location(
            ctx_out, aspect_in_text
        ).float()  # batch_size x (ctx)seq_len x 2*hidden_dim
        # Sum over positions divided by the non-padding length.
        ctx_pool = torch.sum(ctx_out, dim=1)
        ctx_pool = torch.div(ctx_pool, ctx_len.float().unsqueeze(-1)).unsqueeze(
            -1
        )  # batch_size x 2*hidden_dim x 1

        asp_out, (_, _) = self.asp_lstm(
            asp, asp_len
        )  # batch_size x (asp)seq_len x 2*hidden_dim
        asp_pool = torch.sum(asp_out, dim=1)
        asp_pool = torch.div(asp_pool, asp_len.float().unsqueeze(-1)).unsqueeze(
            -1
        )  # batch_size x 2*hidden_dim x 1

        alignment_mat = self.alignment(
            batch_size, ctx_out, asp_out.float()
        )  # batch_size x (ctx)seq_len x (asp)seq_len

        # Fine-grained attention: weights come from the best alignment score
        # of each position against the other sequence.
        # Each result is batch_size x 2*hidden_dim.
        f_asp2ctx = torch.matmul(
            ctx_out.transpose(1, 2),
            F.softmax(alignment_mat.max(2, keepdim=True)[0], dim=1),
        ).squeeze(-1)
        f_ctx2asp = (
            torch.matmul(
                F.softmax(alignment_mat.max(1, keepdim=True)[0], dim=2), asp_out
            )
            .transpose(1, 2)
            .squeeze(-1)
        )

        # Coarse-grained attention: bilinear score of every position against
        # the pooled vector of the other sequence. torch.matmul broadcasts the
        # 2-D weight over the batch, so no per-batch expand is needed.
        c_asp2ctx_alpha = F.softmax(
            ctx_out.matmul(self.w_a2c).matmul(asp_pool), dim=1
        )
        c_asp2ctx = torch.matmul(ctx_out.transpose(1, 2), c_asp2ctx_alpha).squeeze(-1)
        c_ctx2asp_alpha = F.softmax(
            asp_out.matmul(self.w_c2a).matmul(ctx_pool), dim=1
        )
        c_ctx2asp = torch.matmul(asp_out.transpose(1, 2), c_ctx2asp_alpha).squeeze(-1)

        feat = torch.cat([c_asp2ctx, f_asp2ctx, f_ctx2asp, c_ctx2asp], dim=1)
        out = self.dense(feat)  # batch_size x polarity_dim
        return {"logits": out}