# -*- coding: utf-8 -*-
# file: ensembler.py
# time: 2021/11/17
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
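"""Ensemble wrapper for aspect polarity classification (APC) models: shares
the tokenizer, pretrained backbone (or embedding matrix), and datasets
across ensemble members."""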
import copy
import os
import pickle
import re
import sys
from hashlib import sha256
from findfile import find_cwd_dir
from termcolor import colored
from torch import nn
from torch.nn import ModuleList
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, AutoModel
from pyabsa.utils.pyabsa_utils import fprint
from ..models.__classic__ import GloVeAPCModelList
from ..models.__lcf__ import APCModelList
from ..models.__plm__ import BERTBaselineAPCModelList
from ..dataset_utils.__classic__.data_utils_for_training import GloVeABSADataset
from ..dataset_utils.__lcf__.data_utils_for_training import ABSADataset
from ..dataset_utils.__plm__.data_utils_for_training import BERTBaselineABSADataset
from pyabsa.framework.tokenizer_class.tokenizer_class import (
PretrainedTokenizer,
Tokenizer,
build_embedding_matrix,
)
def model_pool_check(models):
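    """Reject ensembles that mix model families.

    LCF (APCModelList), BERT-baseline, and GloVe models use different
    tokenizers and input pipelines, so an ensemble must draw all of its
    members from a single family.
    """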
    set1 = {model for model in models if hasattr(APCModelList, model.__name__)}
    set2 = {
        model
        for model in models
        if hasattr(BERTBaselineAPCModelList, model.__name__)
    }
    set3 = {model for model in models if hasattr(GloVeAPCModelList, model.__name__)}
    if (set1 and set2) or (set1 and set3) or (set2 and set3):
        raise RuntimeError("APCEnsembler only supports models of the same type.")
class APCEnsembler(nn.Module):
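    """Ensemble of one or more APC models behind a single ``nn.Module``.

    Members share the tokenizer, the pretrained backbone (or the GloVe
    embedding matrix), and the train/valid/test datasets; ``forward`` fuses
    the member logits by concatenation or averaging.

    A minimal usage sketch (the config setup is illustrative, not verified
    against a specific pyabsa version)::

        config.model = [APCModelList.FAST_LSA_T_V2, APCModelList.FAST_LSA_S_V2]
        ensembler = APCEnsembler(config)
        out = ensembler(batch)  # {"logits": ..., "loss": ...}
    """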
def __init__(self, config, load_dataset=True, **kwargs):
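        """
        :param config: pyabsa config; ``config.model`` may be a single model
            class or a list of model classes from the same family.
        :param load_dataset: if True, build (or load cached) train/valid/test
            sets and their dataloaders.
        """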
super(APCEnsembler, self).__init__()
self.config = config
models = [config.model] if not isinstance(config.model, list) else config.model
model_pool_check(models)
self.config.inputs_cols = set()
for model in models:
self.config.inputs_cols |= set(model.inputs)
self.config.inputs_cols = sorted(self.config.inputs_cols)
self.inputs_cols = self.config.inputs_cols
self.models = ModuleList()
self.tokenizer = None
self.bert = None
self.embedding_matrix = None
self.train_set = None
self.test_set = None
self.valid_set = None
self.test_dataloader = None
self.valid_dataloader = None
for i in range(len(models)):
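            # Key the dataset cache on a hash of the config (minus the seed and
            # any "<object repr>" values), so identical settings reuse the
            # preprocessed datasets.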
config_str = re.sub(
r"<.*?>",
"",
str(
sorted(
[
str(self.config.args[k])
for k in self.config.args
if k != "seed"
]
)
),
)
hash_tag = sha256(config_str.encode()).hexdigest()
cache_path = "{}.{}.dataset.{}.cache".format(
self.config.model_name, self.config.dataset_name, hash_tag
)
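            # Reuse a previously pickled (train, valid, test, config) tuple when
            # the cache exists and overwriting is not requested.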
if (
load_dataset
and os.path.exists(cache_path)
and not self.config.overwrite_cache
):
fprint(colored("Loading dataset cache: {}".format(cache_path), "green"))
with open(cache_path, mode="rb") as f_cache:
(
self.train_set,
self.valid_set,
self.test_set,
self.config,
) = pickle.load(f_cache)
config.update(self.config)
config.args_call_count.update(self.config.args_call_count)
if hasattr(APCModelList, models[i].__name__):
try:
if kwargs.get("offline", False):
self.tokenizer = AutoTokenizer.from_pretrained(
find_cwd_dir(self.config.pretrained_bert.split("/")[-1]),
do_lower_case="uncased" in self.config.pretrained_bert,
)
self.bert = (
AutoModel.from_pretrained(
find_cwd_dir(self.config.pretrained_bert.split("/")[-1])
)
if not self.bert
else self.bert
) # share the underlying bert between models
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.pretrained_bert,
do_lower_case="uncased" in self.config.pretrained_bert,
)
self.bert = (
AutoModel.from_pretrained(self.config.pretrained_bert)
if not self.bert
else self.bert
)
                except ValueError as e:
                    fprint("Failed to initialize the pretrained model: {}".format(e))
                    sys.exit(-1)
                # Build the datasets only when loading is requested and the
                # cache is missing or must be overwritten.
                if load_dataset and (
                    not os.path.exists(cache_path) or self.config.overwrite_cache
                ):
self.train_set = (
ABSADataset(self.config, self.tokenizer, dataset_type="train")
if not self.train_set
else self.train_set
)
self.test_set = (
ABSADataset(self.config, self.tokenizer, dataset_type="test")
if not self.test_set
else self.test_set
)
self.valid_set = (
ABSADataset(self.config, self.tokenizer, dataset_type="valid")
if not self.valid_set
else self.valid_set
)
self.models.append(models[i](self.bert, self.config))
elif hasattr(BERTBaselineAPCModelList, models[i].__name__):
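                # BERT-baseline models: wrap the tokenizer with pyabsa's
                # PretrainedTokenizer and share one AutoModel backbone unless
                # deep_ensemble requests per-member copies.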
self.tokenizer = (
PretrainedTokenizer(self.config)
if not self.tokenizer
else self.tokenizer
)
self.bert = (
AutoModel.from_pretrained(self.config.pretrained_bert)
if not self.bert
else self.bert
)
                if load_dataset and (
                    not os.path.exists(cache_path) or self.config.overwrite_cache
                ):
self.train_set = (
BERTBaselineABSADataset(
self.config, self.tokenizer, dataset_type="train"
)
if not self.train_set
else self.train_set
)
self.test_set = (
BERTBaselineABSADataset(
self.config, self.tokenizer, dataset_type="test"
)
if not self.test_set
else self.test_set
)
self.valid_set = (
BERTBaselineABSADataset(
self.config, self.tokenizer, dataset_type="valid"
)
if not self.valid_set
else self.valid_set
)
self.models.append(
models[i](
copy.deepcopy(self.bert)
if self.config.deep_ensemble
else self.bert,
self.config,
)
)
elif hasattr(GloVeAPCModelList, models[i].__name__):
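                # GloVe-family models: build (or load) a word-level tokenizer and
                # a static embedding matrix instead of a transformer backbone.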
self.tokenizer = (
Tokenizer.build_tokenizer(
config=self.config,
cache_path="{0}_tokenizer.dat".format(
os.path.basename(config.dataset_name)
),
)
if not self.tokenizer
else self.tokenizer
)
self.embedding_matrix = (
build_embedding_matrix(
config=self.config,
tokenizer=self.tokenizer,
cache_path="{0}_{1}_embedding_matrix.dat".format(
str(config.embed_dim), os.path.basename(config.dataset_name)
),
)
if not self.embedding_matrix
else self.embedding_matrix
)
                if load_dataset and (
                    not os.path.exists(cache_path) or self.config.overwrite_cache
                ):
self.train_set = (
GloVeABSADataset(
self.config, self.tokenizer, dataset_type="train"
)
if not self.train_set
else self.train_set
)
self.test_set = (
GloVeABSADataset(
self.config, self.tokenizer, dataset_type="test"
)
if not self.test_set
else self.test_set
)
self.valid_set = (
GloVeABSADataset(
self.config, self.tokenizer, dataset_type="valid"
)
if not self.valid_set
else self.valid_set
)
self.models.append(
models[i](
copy.deepcopy(self.embedding_matrix)
if self.config.deep_ensemble
else self.embedding_matrix,
self.config,
)
)
self.config.embedding_matrix = self.embedding_matrix
            # Write the cache when it is missing, or refresh it when
            # overwrite_cache is set.
            if self.config.cache_dataset and (
                not os.path.exists(cache_path) or self.config.overwrite_cache
            ):
                fprint(
                    colored(
                        "Caching dataset... please remove the cached dataset if any problem occurs.",
                        "red",
                    )
                )
with open(cache_path, mode="wb") as f_cache:
pickle.dump(
(self.train_set, self.valid_set, self.test_set, self.config),
f_cache,
)
if load_dataset:
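            # Random sampling for training; sequential sampling for evaluation.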
train_sampler = RandomSampler(self.train_set)
self.train_dataloader = DataLoader(
self.train_set,
batch_size=self.config.batch_size,
pin_memory=True,
sampler=train_sampler,
)
if self.test_set:
test_sampler = SequentialSampler(self.test_set)
self.test_dataloader = DataLoader(
self.test_set,
batch_size=self.config.batch_size,
pin_memory=True,
sampler=test_sampler,
)
if self.valid_set:
valid_sampler = SequentialSampler(self.valid_set)
self.valid_dataloader = DataLoader(
self.valid_set,
batch_size=self.config.batch_size,
pin_memory=True,
sampler=valid_sampler,
)
self.config.tokenizer = self.tokenizer
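        # For ensemble_mode == "cat", project the concatenated member logits
        # (output_dim * n_models) back down to output_dim.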
self.dense = nn.Linear(config.output_dim * len(models), config.output_dim)
def forward(self, inputs):
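        """Run every member model on ``inputs`` and fuse their outputs.

        ``config.ensemble_mode`` selects the fusion: ``"cat"`` concatenates
        the member logits and projects them through ``self.dense``; ``"mean"``
        averages them. Member losses, when present, are summed.

        :return: dict with ``"logits"`` and ``"loss"``.
        """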
outputs = [self.models[i](inputs) for i in range(len(self.models))]
loss = torch.tensor(0.0, requires_grad=True)
if "ensemble_mode" not in self.config:
self.config.ensemble_mode = "cat"
logits = None
if len(outputs) > 1:
for i, out in enumerate(outputs):
if self.config.ensemble_mode == "cat":
logits = (
torch.cat((logits, out["logits"]), dim=-1)
if i != 0
else out["logits"]
)
elif self.config.ensemble_mode == "mean":
logits = logits + out["logits"] if i != 0 else out["logits"]
else:
raise KeyError("Invalid ensemble_mode!")
if "loss" in out:
loss = loss + out["loss"] if i != 0 else out["loss"]
if "ensemble_mode" not in self.config or self.config.ensemble_mode == "cat":
logits = self.dense(logits)
elif self.config.ensemble_mode == "mean":
logits = logits / len(self.models)
else:
logits = outputs[0]["logits"]
loss = outputs[0]["loss"] if "loss" in outputs[0] else loss
return {"logits": logits, "loss": loss.to(logits.device)}