# Source code for pyabsa.tasks.AspectPolarityClassification.models.__lcf__.bert_base

# -*- coding: utf-8 -*-
# file: bert_base.py
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# Copyright (C) 2020. All Rights Reserved.

import torch.nn as nn
from transformers.models.bert.modeling_bert import BertPooler


class BERT_MLP(nn.Module):
    """BERT encoder followed by a pooler, dropout and a single linear classifier.

    The wrapped ``bert`` module is called on the token indices only; its
    ``"last_hidden_state"`` output is pooled with a ``BertPooler``, passed
    through dropout, and projected to ``config.output_dim`` logits.
    """

    # Names of the batch fields this model reads in ``forward``.
    # Presumably consumed by the surrounding framework to assemble the
    # ``inputs`` mapping -- verify against the caller.
    inputs = ["text_indices"]

    def __init__(self, bert, config):
        """Build the classification head around an existing encoder.

        Args:
            bert: Encoder module. It must expose a ``config`` attribute
                accepted by ``BertPooler`` and, when called on the token
                indices, return a mapping containing ``"last_hidden_state"``.
            config: Model configuration providing ``dropout``, ``embed_dim``
                and ``output_dim``.
        """
        super().__init__()
        self.bert = bert
        self.config = config
        self.dropout = nn.Dropout(config.dropout)
        self.pooler = BertPooler(bert.config)
        # NOTE(review): the pooler is sized from ``bert.config`` while this
        # layer is sized from ``config.embed_dim``; the two presumably have to
        # agree for ``forward`` to work -- confirm against the configs in use.
        self.dense = nn.Linear(config.embed_dim, config.output_dim)

    def forward(self, inputs):
        """Compute classification logits for a batch.

        Args:
            inputs: Mapping that contains the key ``"text_indices"``; its
                value is passed unchanged to the encoder.

        Returns:
            dict: ``"logits"`` holds the output of the linear classifier and
            ``"hidden_state"`` holds the pooled representation *after*
            dropout has been applied (the tensor fed to the classifier).
        """
        text_indices = inputs["text_indices"]
        text_features = self.bert(text_indices)["last_hidden_state"]
        pooled_output = self.pooler(text_features)
        pooled_output = self.dropout(pooled_output)
        logits = self.dense(pooled_output)
        return {"logits": logits, "hidden_state": pooled_output}