import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
RobertaConfig,
RobertaModel,
RobertaTokenizer,
BartConfig,
BartForConditionalGeneration,
BartTokenizer,
T5Config,
T5ForConditionalGeneration,
T5Tokenizer,
)
logger = logging.getLogger(__name__)
MODEL_CLASSES = {
"roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
"t5": (T5Config, T5ForConditionalGeneration, T5Tokenizer),
"codet5": (T5Config, T5ForConditionalGeneration, RobertaTokenizer),
"bart": (BartConfig, BartForConditionalGeneration, BartTokenizer),
}
def get_model_size(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model_size = sum([np.prod(p.size()) for p in model_parameters])
return "{}M".format(round(model_size / 1e6))
def build_or_load_gen_model(args):
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path
)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path
    )
if args.model_type == "roberta":
encoder = model_class.from_pretrained(args.model_name_or_path, config=config)
decoder_layer = nn.TransformerDecoderLayer(
d_model=config.hidden_size, nhead=config.num_attention_heads
)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
model = Seq2Seq(
encoder=encoder,
decoder=decoder,
config=config,
beam_size=args.beam_size,
max_length=args.max_target_length,
sos_id=tokenizer.cls_token_id,
eos_id=tokenizer.sep_token_id,
)
else:
model = model_class.from_pretrained(args.model_name_or_path)
    logger.info(
        "Finished loading model [%s] from %s",
        get_model_size(model),
        args.model_name_or_path,
    )
if args.load_model_path is not None:
logger.info("Reload model from {}".format(args.load_model_path))
model.load_state_dict(torch.load(args.load_model_path), strict=False)
return config, model, tokenizer
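
# A minimal, hedged usage sketch for build_or_load_gen_model. The attribute
# names on `args` mirror exactly what the function reads above; the concrete
# checkpoint string and sizes are illustrative assumptions, not fixed here.
def _demo_build_codet5():
    from argparse import Namespace

    args = Namespace(
        model_type="codet5",
        model_name_or_path="Salesforce/codet5-base",
        config_name="",         # falls back to model_name_or_path
        tokenizer_name="",      # falls back to model_name_or_path
        beam_size=10,
        max_target_length=128,
        load_model_path=None,
    )
    return build_or_load_gen_model(args)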
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, 2)
    def forward(self, x, **kwargs):
        # Concatenate each consecutive pair of feature vectors, turning
        # (2 * batch, hidden) into (batch, 2 * hidden) before projecting.
        x = x.reshape(-1, x.size(-1) * 2)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x
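
# A self-contained sketch of how this head consumes features: it expects an
# even number of hidden vectors and classifies each consecutive pair. The
# batch and hidden sizes below are arbitrary, chosen only for illustration.
def _demo_classification_head():
    cfg = RobertaConfig(hidden_size=768)
    head = RobertaClassificationHead(cfg)
    features = torch.randn(8, 768)  # 8 vectors -> 4 paired examples
    logits = head(features)         # shape: (4, 2)
    return logits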
class FocalLoss(nn.Module):
"""
Reference:
Li et al., Focal Loss for Dense Object Detection. ICCV 2017.
Equation: Loss(x, class) = - (1-sigmoid(p^t))^gamma \log(p^t)
Focal loss tries to make neural networks to pay more attentions on difficult samples.
Args:
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
"""
def __init__(self, gamma=0.75):
super(FocalLoss, self).__init__()
self.gamma = torch.tensor(gamma, dtype=torch.float32)
self.eps = 1e-6
    def forward(self, input, target):
        # `input` holds raw logits (not probabilities); `target` is a
        # multi-hot matrix. Both have shape (batch_size, n_classes).
        probs = torch.sigmoid(input)
        log_probs = -torch.log(probs + self.eps)  # eps guards against log(0)
focal_loss = torch.sum(
torch.pow(1 - probs + self.eps, self.gamma).mul(log_probs).mul(target),
dim=1,
)
# bce_loss = torch.sum(log_probs.mul(target), dim = 1)
return focal_loss.mean() # , bce_loss
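
# A quick sanity check for FocalLoss, assuming the shapes its forward pass
# documents: raw logits and multi-hot targets of shape (batch, n_classes).
# All values are random and purely illustrative.
def _demo_focal_loss():
    criterion = FocalLoss(gamma=0.75)
    logits = torch.randn(4, 2)  # raw scores, not probabilities
    targets = F.one_hot(torch.tensor([0, 1, 1, 0]), num_classes=2).float()
    return criterion(logits, targets)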
class LDAMLoss(nn.Module):
"""
References:
Cao et al., Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss. NeurIPS 2019.
Args:
        s(float, double): the scale of the logits, following the official code.
        max_m(float, double): the maximum margin; see Equations (12) and (13) in the paper.
    Notes: The official code exposes these two hyper-parameters, but the authors
        only provide settings for long-tailed CIFAR; settings for other datasets
        are not available (https://github.com/kaidic/LDAM-DRW/issues/5).
"""
def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
super(LDAMLoss, self).__init__()
        # Per-class margin m_j ∝ n_j^{-1/4}, rescaled so the largest is max_m.
        m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
        m_list = m_list * (max_m / np.max(m_list))
        # Register as a buffer so .to(device) moves it with the module
        # (the original code pinned it to CUDA via torch.cuda.FloatTensor).
        self.register_buffer("m_list", torch.tensor(m_list, dtype=torch.float32))
assert s > 0
self.s = s
self.weight = weight
def forward(self, x, target):
        # torch.where requires a bool mask; uint8 masks are deprecated.
        index = torch.zeros_like(x, dtype=torch.bool)
        index.scatter_(1, target.data.view(-1, 1), True)
        index_float = index.float()
        batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))
        batch_m = batch_m.view((-1, 1))
        x_m = x - batch_m  # subtract the margin from the true-class logit only
        output = torch.where(index, x_m, x)
return F.cross_entropy(self.s * output, target, weight=self.weight)
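
# A hedged example for LDAMLoss on an imbalanced two-class problem;
# cls_num_list holds per-class training counts, and the counts, logits, and
# labels below are made up. Move the criterion with `.to(device)` so the
# margin buffer lives on the same device as the logits.
def _demo_ldam_loss():
    criterion = LDAMLoss(cls_num_list=[900, 100], max_m=0.5, s=30)
    logits = torch.randn(4, 2)
    targets = torch.tensor([0, 0, 1, 0])
    return criterion(logits, targets)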
class ClassBalanceCE(nn.Module):
r"""
Reference:
Cui et al., Class-Balanced Loss Based on Effective Number of Samples. CVPR 2019.
    Equation: Loss(x, c) = \frac{1-\beta}{1-\beta^{n_c}} * CrossEntropy(x, c)
    Class-balanced loss weights each class by its effective number of samples
    rather than by the nominal number of images in the original dataset.
Args:
beta(float, double) : hyper-parameter for class balanced loss to control the cost-sensitive weights.
"""
    def __init__(self, para_dict=None):
        # nn.Module.__init__ takes no arguments; in the original repo this
        # class inherits from a base loss that consumes `para_dict` and sets
        # `num_class_list`, `num_classes`, `device`, `drw`, and
        # `drw_start_epoch`, all of which are used below.
        super(ClassBalanceCE, self).__init__()
        self.para_dict = para_dict
        self.beta = self.para_dict["cfg"].LOSS.ClassBalanceCE.BETA
self.class_balanced_weight = np.array(
[(1 - self.beta) / (1 - self.beta**N) for N in self.num_class_list]
)
self.class_balanced_weight = torch.FloatTensor(
self.class_balanced_weight
/ np.sum(self.class_balanced_weight)
* self.num_classes
).to(self.device)
def update(self, epoch):
"""
Args:
epoch: int. starting from 1.
"""
        if not self.drw:
            self.weight_list = self.class_balanced_weight
        else:
            # Deferred re-weighting (DRW): apply the class-balanced weights
            # only once training reaches drw_start_epoch.
            start = (epoch - 1) // self.drw_start_epoch
            if start:
                self.weight_list = self.class_balanced_weight
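
# The effective-number reweighting from the docstring, worked standalone:
# w_c = (1 - beta) / (1 - beta^{n_c}), normalized to sum to the number of
# classes. The per-class counts below are illustrative.
def _demo_class_balanced_weights(beta=0.999, counts=(900, 100)):
    weights = np.array([(1 - beta) / (1 - beta**n) for n in counts])
    weights = weights / weights.sum() * len(counts)
    return weights  # the rarer class receives the larger weight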
class DefectModel(nn.Module):
MODEL_CLASSES = {
"roberta-base": (RobertaConfig, RobertaModel, RobertaTokenizer),
"t5-base": (T5Config, T5ForConditionalGeneration, T5Tokenizer),
"facebook/bart-base": (BartConfig, BartForConditionalGeneration, BartTokenizer),
"Salesforce/codet5-small": (
T5Config,
T5ForConditionalGeneration,
RobertaTokenizer,
),
"Salesforce/codet5-base": (
T5Config,
T5ForConditionalGeneration,
RobertaTokenizer,
),
}
    def __init__(self, config):
        super(DefectModel, self).__init__()
        self.config = config
        self.tokenizer = config.tokenizer
        # MODEL_CLASSES maps a checkpoint name to (config, model, tokenizer)
        # classes; index 1 selects the model class.
        model_class = self.MODEL_CLASSES[config.pretrained_bert][1]
        self.encoder = model_class.from_pretrained(config.pretrained_bert)
        self.classifier1 = nn.Linear(config.hidden_size, 2)
        self.classifier2 = nn.Linear(config.hidden_size, 2)
def get_t5_vec(self, source_ids):
attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
outputs = self.encoder(
input_ids=source_ids,
attention_mask=attention_mask,
labels=source_ids,
decoder_attention_mask=attention_mask,
output_hidden_states=True,
)
hidden_states = outputs["decoder_hidden_states"][-1]
eos_mask = source_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
vec = hidden_states[eos_mask, :].view(
hidden_states.size(0), -1, hidden_states.size(-1)
)[:, -1, :]
return vec
def get_bart_vec(self, source_ids):
attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
outputs = self.encoder(
input_ids=source_ids,
attention_mask=attention_mask,
labels=source_ids,
decoder_attention_mask=attention_mask,
output_hidden_states=True,
)
hidden_states = outputs["decoder_hidden_states"][-1]
eos_mask = source_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
vec = hidden_states[eos_mask, :].view(
hidden_states.size(0), -1, hidden_states.size(-1)
)[:, -1, :]
return vec
def get_roberta_vec(self, source_ids):
attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
vec = self.encoder(input_ids=source_ids, attention_mask=attention_mask)[0][
:, 0, :
]
return vec
    def forward(self, source_ids=None, labels=None, corrupt_labels=None):
        # `config` is assumed to carry the runtime settings referenced here
        # (max_source_length and model_type), alongside the tokenizer above.
        source_ids = source_ids.view(-1, self.config.max_source_length)
        if self.config.model_type == "codet5":
            vec = self.get_t5_vec(source_ids)
        elif self.config.model_type == "bart":
            vec = self.get_bart_vec(source_ids)
        elif self.config.model_type == "roberta":
            vec = self.get_roberta_vec(source_ids)
        elif self.config.model_type == "t5":
            vec = self.get_t5_vec(source_ids)
        else:
            raise ValueError("Unknown model type: {}".format(self.config.model_type))
logits1 = self.classifier1(vec)
logits2 = self.classifier2(vec)
prob = nn.functional.softmax(logits1, dim=-1)
c_prob = nn.functional.softmax(logits2, dim=-1)
if labels is not None:
loss_fct1 = nn.CrossEntropyLoss()
loss_fct2 = nn.CrossEntropyLoss()
# loss_fct1 = FocalLoss()
# loss_fct2 = FocalLoss()
loss = loss_fct1(logits1, labels) + loss_fct2(logits2, corrupt_labels)
# loss = loss_fct1(logits1, labels)
# loss = loss_fct1(logits2, corrupt_labels)
return loss, prob, c_prob
else:
return prob, c_prob
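
# The eos-pooling trick shared by get_t5_vec and get_bart_vec, demonstrated on
# toy tensors: select the decoder hidden state at the final <eos> position of
# every sequence. Token id 2 stands in for eos; all values are illustrative.
def _demo_eos_pooling():
    hidden_states = torch.randn(2, 5, 8)  # (batch, seq_len, hidden)
    source_ids = torch.tensor([[5, 6, 2, 0, 0],
                               [7, 8, 9, 2, 0]])
    eos_mask = source_ids.eq(2)
    # Every row must contain the same number of eos tokens for the view below.
    vec = hidden_states[eos_mask, :].view(
        hidden_states.size(0), -1, hidden_states.size(-1)
    )[:, -1, :]
    return vec  # shape: (2, 8)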
# https://github.com/microsoft/CodeBERT/blob/master/CodeBERT/code2nl/model.py
class Seq2Seq(nn.Module):
"""
Build Seqence-to-Sequence.
Parameters:
* `encoder`- encoder of seq2seq model. e.g. roberta
* `decoder`- decoder of seq2seq model. e.g. transformer
* `config`- configuration of encoder model.
* `beam_size`- beam size for beam search.
* `max_length`- max length of target for beam search.
* `sos_id`- start of symbol ids in target for beam search.
* `eos_id`- end of symbol ids in target for beam search.
"""
def __init__(
self,
encoder,
decoder,
config,
beam_size=None,
max_length=None,
sos_id=None,
eos_id=None,
):
super(Seq2Seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.config = config
self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.lsm = nn.LogSoftmax(dim=-1)
self.tie_weights()
self.beam_size = beam_size
self.max_length = max_length
self.sos_id = sos_id
self.eos_id = eos_id
def _tie_or_clone_weights(self, first_module, second_module):
"""Tie or clone module weights depending of weither we are using TorchScript or not"""
if self.config.torchscript:
first_module.weight = nn.Parameter(second_module.weight.clone())
else:
first_module.weight = second_module.weight
def tie_weights(self):
"""Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(
self.lm_head, self.encoder.embeddings.word_embeddings
)
def forward(
self,
source_ids=None,
source_mask=None,
target_ids=None,
target_mask=None,
args=None,
):
outputs = self.encoder(source_ids, attention_mask=source_mask)
encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
if target_ids is not None:
attn_mask = -1e4 * (
1 - self.bias[: target_ids.shape[1], : target_ids.shape[1]]
)
tgt_embeddings = (
self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
)
out = self.decoder(
tgt_embeddings,
encoder_output,
tgt_mask=attn_mask,
                # True marks padding positions; robust for bool and 0/1 masks.
                memory_key_padding_mask=(source_mask == 0),
            )
hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
lm_logits = self.lm_head(hidden_states)
# Shift so that tokens < n predict n
            active_loss = target_mask[..., 1:].ne(0).view(-1)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = target_ids[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1))[active_loss],
shift_labels.view(-1)[active_loss],
)
outputs = loss, loss * active_loss.sum(), active_loss.sum()
return outputs
else:
# Predict
preds = []
            zero = torch.zeros(1, dtype=torch.long, device=source_ids.device)
for i in range(source_ids.shape[0]):
context = encoder_output[:, i : i + 1]
context_mask = source_mask[i : i + 1, :]
beam = Beam(self.beam_size, self.sos_id, self.eos_id)
input_ids = beam.getCurrentState()
context = context.repeat(1, self.beam_size, 1)
context_mask = context_mask.repeat(self.beam_size, 1)
for _ in range(self.max_length):
if beam.done():
break
attn_mask = -1e4 * (
1 - self.bias[: input_ids.shape[1], : input_ids.shape[1]]
)
tgt_embeddings = (
self.encoder.embeddings(input_ids)
.permute([1, 0, 2])
.contiguous()
)
out = self.decoder(
tgt_embeddings,
context,
tgt_mask=attn_mask,
                        memory_key_padding_mask=(context_mask == 0),
                    )
out = torch.tanh(self.dense(out))
hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
out = self.lsm(self.lm_head(hidden_states)).data
beam.advance(out)
input_ids.data.copy_(
input_ids.data.index_select(0, beam.getCurrentOrigin())
)
input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
hyp = beam.getHyp(beam.getFinal())
pred = beam.buildTargetTokens(hyp)[: self.beam_size]
pred = [
torch.cat(
[x.view(-1) for x in p] + [zero] * (self.max_length - len(p))
).view(1, -1)
for p in pred
]
preds.append(torch.cat(pred, 0).unsqueeze(0))
preds = torch.cat(preds, 0)
return preds
class Beam(object):
def __init__(self, size, sos, eos):
self.size = size
        self.tt = torch.cuda  # all beam-state tensors are allocated on the GPU
# The score for each translation on the beam.
self.scores = self.tt.FloatTensor(size).zero_()
# The backpointers at each time-step.
self.prevKs = []
# The outputs at each time-step.
self.nextYs = [self.tt.LongTensor(size).fill_(0)]
self.nextYs[0][0] = sos
# Has EOS topped the beam yet.
self._eos = eos
self.eosTop = False
# Time and k pair for finished.
self.finished = []
def getCurrentState(self):
"Get the outputs for the current timestep."
batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
return batch
def getCurrentOrigin(self):
"Get the backpointers for the current timestep."
return self.prevKs[-1]
def advance(self, wordLk):
"""
Given prob over words for every last beam `wordLk` and attention
`attnOut`: Compute and update the beam search.
Parameters:
* `wordLk`- probs of advancing from the last step (K x words)
* `attnOut`- attention at the last step
Returns: True if beam search is complete.
"""
numWords = wordLk.size(1)
# Sum the previous scores.
if len(self.prevKs) > 0:
beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
# Don't let EOS have children.
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
beamLk[i] = -1e20
else:
beamLk = wordLk[0]
flatBeamLk = beamLk.view(-1)
bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
self.scores = bestScores
# bestScoresId is flattened beam x word array, so calculate which
# word and beam each score came from
        prevK = torch.div(bestScoresId, numWords, rounding_mode="floor")
self.prevKs.append(prevK)
self.nextYs.append((bestScoresId - prevK * numWords))
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
s = self.scores[i]
self.finished.append((s, len(self.nextYs) - 1, i))
# End condition is when top-of-beam is EOS and no global score.
if self.nextYs[-1][0] == self._eos:
self.eosTop = True
def done(self):
return self.eosTop and len(self.finished) >= self.size
def getFinal(self):
if len(self.finished) == 0:
self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
self.finished.sort(key=lambda a: -a[0])
if len(self.finished) != self.size:
unfinished = []
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] != self._eos:
s = self.scores[i]
unfinished.append((s, len(self.nextYs) - 1, i))
unfinished.sort(key=lambda a: -a[0])
self.finished += unfinished[: self.size - len(self.finished)]
return self.finished[: self.size]
def getHyp(self, beam_res):
"""
Walk back to construct the full hypothesis.
"""
hyps = []
for _, timestep, k in beam_res:
hyp = []
for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
hyp.append(self.nextYs[j + 1][k])
k = self.prevKs[j][k]
hyps.append(hyp[::-1])
return hyps
    def buildTargetTokens(self, preds):
        sentences = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentences.append(tokens)
        return sentences
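
# A tiny end-to-end walk through Beam with random log-probabilities over a
# toy vocabulary. Beam allocates CUDA tensors internally, so this sketch only
# runs when a GPU is available; the sos/eos ids of 1 and 2 are arbitrary.
def _demo_beam_search(vocab_size=10, steps=5, beam_size=3):
    if not torch.cuda.is_available():
        return None
    beam = Beam(beam_size, sos=1, eos=2)
    for _ in range(steps):
        if beam.done():
            break
        # One row of log-probs per live hypothesis on the beam.
        word_log_probs = torch.randn(beam_size, vocab_size).cuda()
        beam.advance(word_log_probs)
    hyps = beam.getHyp(beam.getFinal())
    return beam.buildTargetTokens(hyps)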