# -*- coding: utf-8 -*-
# file: attention.py
# author: songyouwei <youwei0314@gmail.com>
# Copyright (C) 2018. All Rights Reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim=None,
        out_dim=None,
        n_head=1,
        score_function="dot_product",
        dropout=0,
    ):
"""Attention Mechanism
:param embed_dim:
:param hidden_dim:
:param out_dim:
:param n_head: num of head (Multi-Head Attention)
:param score_function: scaled_dot_product / mlp (concat) / bi_linear (general dot)
:return (?, q_len, out_dim,)
"""
        super(Attention, self).__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim // n_head
        if out_dim is None:
            out_dim = embed_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.score_function = score_function
        self.w_k = nn.Linear(embed_dim, n_head * hidden_dim)
        self.w_q = nn.Linear(embed_dim, n_head * hidden_dim)
        self.proj = nn.Linear(n_head * hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        if score_function == "mlp":
            self.weight = nn.Parameter(torch.Tensor(hidden_dim * 2))
        elif score_function == "bi_linear":
            self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        else:  # dot_product / scaled_dot_product
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        if self.weight is not None:
            self.weight.data.uniform_(-stdv, stdv)

    def forward(self, k, q):
        if len(q.shape) == 2:  # q_len missing
            q = torch.unsqueeze(q, dim=1)
        if len(k.shape) == 2:  # k_len missing
            k = torch.unsqueeze(k, dim=1)
        mb_size = k.shape[0]  # batch size (the "?" in the shape comments below)
        k_len = k.shape[1]
        q_len = q.shape[1]
        # k: (?, k_len, embed_dim,)
        # q: (?, q_len, embed_dim,)
        # kx: (n_head*?, k_len, hidden_dim)
        # qx: (n_head*?, q_len, hidden_dim)
        # score: (n_head*?, q_len, k_len,)
        # output: (?, q_len, out_dim,)
        kx = self.w_k(k).view(mb_size, k_len, self.n_head, self.hidden_dim)
        kx = kx.permute(2, 0, 1, 3).contiguous().view(-1, k_len, self.hidden_dim)
        qx = self.w_q(q).view(mb_size, q_len, self.n_head, self.hidden_dim)
        qx = qx.permute(2, 0, 1, 3).contiguous().view(-1, q_len, self.hidden_dim)
        if self.score_function == "dot_product":
            kt = kx.permute(0, 2, 1)
            score = torch.bmm(qx, kt)
        elif self.score_function == "scaled_dot_product":
            kt = kx.permute(0, 2, 1)
            qkt = torch.bmm(qx, kt)
            score = torch.div(qkt, math.sqrt(self.hidden_dim))
        elif self.score_function == "mlp":
            kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1)
            qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1)
            kq = torch.cat((kxx, qxx), dim=-1)  # (n_head*?, q_len, k_len, hidden_dim*2)
            # kq = torch.unsqueeze(kx, dim=1) + torch.unsqueeze(qx, dim=2)
            score = torch.tanh(torch.matmul(kq, self.weight))
elif self.score_function == "bi_linear":
qw = torch.matmul(qx, self.weight)
kt = kx.permute(0, 2, 1)
score = torch.bmm(qw, kt)
else:
raise RuntimeError("invalid score_function")
score = F.softmax(score, dim=-1)
output = torch.bmm(score, kx) # (n_head*?, q_len, hidden_dim)
output = torch.cat(
torch.split(output, mb_size, dim=0), dim=-1
) # (?, q_len, n_head*hidden_dim)
output = self.proj(output) # (?, q_len, out_dim)
output = self.dropout(output)
return output, score
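

# A minimal usage sketch of the Attention module above; the helper name
# _demo_attention and all concrete sizes are illustrative assumptions,
# not part of the original API.
def _demo_attention():
    attn = Attention(embed_dim=300, n_head=4, score_function="scaled_dot_product")
    k = torch.randn(32, 20, 300)  # (batch, k_len, embed_dim)
    q = torch.randn(32, 5, 300)   # (batch, q_len, embed_dim)
    output, score = attn(k, q)
    assert output.shape == (32, 5, 300)    # (batch, q_len, out_dim)
    assert score.shape == (4 * 32, 5, 20)  # (n_head*batch, q_len, k_len)
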
class NoQueryAttention(Attention):
"""q is a parameter"""

    def __init__(
        self,
        embed_dim,
        hidden_dim=None,
        out_dim=None,
        n_head=1,
        score_function="dot_product",
        q_len=1,
        dropout=0,
    ):
        super(NoQueryAttention, self).__init__(
            embed_dim, hidden_dim, out_dim, n_head, score_function, dropout
        )
        self.q_len = q_len
        self.q = nn.Parameter(torch.Tensor(q_len, embed_dim))
        self.reset_q()

    def reset_q(self):
        stdv = 1.0 / math.sqrt(self.embed_dim)
        self.q.data.uniform_(-stdv, stdv)

    def forward(self, k, **kwargs):
        mb_size = k.shape[0]
        q = self.q.expand(mb_size, -1, -1)
        return super(NoQueryAttention, self).forward(k, q)
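

# A minimal usage sketch for NoQueryAttention, assuming a batch of 16 sequences of
# length 40 with embed_dim=300; the concrete sizes are illustrative, not prescribed.
if __name__ == "__main__":
    no_q_attn = NoQueryAttention(embed_dim=300, score_function="bi_linear")
    k = torch.randn(16, 40, 300)  # (batch, k_len, embed_dim)
    output, score = no_q_attn(k)  # the query is the learned parameter self.q (q_len=1)
    assert output.shape == (16, 1, 300)  # (batch, q_len, out_dim)
    assert score.shape == (16, 1, 40)    # (n_head*batch, q_len, k_len)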