Source code for pyabsa.utils.text_utils.mlm

# -*- coding: utf-8 -*-
# file: mlm.py
# time: 03/11/2022 15:37
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# GScholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# ResearchGate: https://www.researchgate.net/profile/Heng-Yang-17/research
# Copyright (C) 2022. All Rights Reserved.
from transformers import (
    BertForMaskedLM,
    RobertaForMaskedLM,
    DebertaV2ForMaskedLM,
    AutoConfig,
)
from transformers import AutoTokenizer


def get_mlm_and_tokenizer(model, config):
    """
    Build a masked language model (MLM) that reuses the backbone held by ``model``.

    The first child of ``model`` that exposes a ``base_model`` attribute supplies
    the encoder. A fresh MLM wrapper is constructed from the configuration of
    ``config.pretrained_bert`` (the wrapper itself is built from the configuration
    only; no checkpoint is loaded for it here) and its encoder attribute is then
    replaced by that backbone object, so the returned MLM and ``model`` reference
    the same encoder instance.

    The wrapper class is selected by substring match on ``config.pretrained_bert``:
    ``"deberta-v3"`` -> ``DebertaV2ForMaskedLM``, ``"roberta"`` ->
    ``RobertaForMaskedLM``, anything else -> ``BertForMaskedLM``.

    Args:
        model: The BERT-like model to use. Must provide ``children()``, and at
            least one child must expose a ``base_model`` attribute.
        config: The configuration object to use. Only ``config.pretrained_bert``
            (a model name or path) is read.

    Returns:
        A tuple containing the MLM and the tokenizer.

    Raises:
        ValueError: If no child of ``model`` provides a usable ``base_model``.
    """
    # Take the backbone from the first child that has one (same child the
    # original break-on-first-match loop would have picked).
    base_model = next(
        (
            child.base_model
            for child in model.children()
            if hasattr(child, "base_model")
        ),
        None,
    )
    if base_model is None:
        # Fail early: otherwise the MLM would be returned with its encoder set
        # to None and only break later, far from the real cause.
        raise ValueError(
            "Cannot build the MLM: no child of the given model exposes a "
            "usable 'base_model' attribute."
        )

    pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)

    # Pick the MLM wrapper matching the checkpoint family, then swap in the
    # shared backbone under the attribute name that wrapper uses for its encoder.
    if "deberta-v3" in config.pretrained_bert:
        MLM = DebertaV2ForMaskedLM(pretrained_config)
        MLM.deberta = base_model
    elif "roberta" in config.pretrained_bert:
        MLM = RobertaForMaskedLM(pretrained_config)
        MLM.roberta = base_model
    else:
        MLM = BertForMaskedLM(pretrained_config)
        MLM.bert = base_model

    return MLM, AutoTokenizer.from_pretrained(config.pretrained_bert)