Source code for pyabsa.tasks.__SubtaskTemplate__.dataset_utils.data_utils_for_training

# -*- coding: utf-8 -*-
# file: data_utils_for_training.py
# time: 2021/5/31 0031
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import numpy as np
import tqdm
from termcolor import colored

from pyabsa.framework.dataset_class.dataset_template import PyABSADataset
from pyabsa.utils.file_utils.file_utils import load_dataset_from_file
from pyabsa.utils.pyabsa_utils import check_and_fix_labels, fprint
from .apc_utils import (
    build_sentiment_window,
    build_spc_mask_vec,
    prepare_input_for_apc,
    configure_spacy_model,
)
from .apc_utils_for_dlcf_dca import (
    prepare_input_for_dlcf_dca,
    configure_dlcf_spacy_model,
)
from pyabsa.utils.pyabsa_utils import validate_absa_example


class ABSADataset(PyABSADataset):
    """Dataset of ABSA training examples read from three-line records.

    Each record in the raw dataset consists of three consecutive lines:

    1. the sentence, with the aspect term replaced by the ``$T$`` placeholder,
    2. the aspect term,
    3. the polarity label.

    After loading, ``self.data`` is a list of dicts (one per example) that
    ``__getitem__`` returns as-is.
    """

    def __init__(self, config, tokenizer, dataset_type="train"):
        """Forward ``config``, ``tokenizer`` and ``dataset_type`` to the base class."""
        super().__init__(config=config, tokenizer=tokenizer, dataset_type=dataset_type)

    def load_data_from_dict(self, data_dict, **kwargs):
        """Do nothing; loading from a dict is not implemented for this dataset."""
        pass

    def load_data_from_file(self, file_path, **kwargs):
        """Read the dataset files, build every example and store them in ``self.data``.

        NOTE(review): ``file_path`` is not used; the files are taken from
        ``self.config.dataset_file[self.dataset_type]``.

        Side effects on ``self.config``: ``output_dim`` is set to the number of
        distinct polarity labels seen in the kept examples.

        Args:
            file_path: Unused (see note above).
            **kwargs: Unused.
        """
        cfg = self.config
        configure_spacy_model(cfg)

        lines = load_dataset_from_file(
            cfg.dataset_file[self.dataset_type], config=cfg
        )

        if len(lines) % 3 != 0 or len(lines) == 0:
            fprint(
                colored(
                    "ERROR: one or more datasets are corrupted, make sure the number of lines in a dataset should be multiples of 3.",
                    "red",
                )
            )

        uses_dlcf = cfg.model_name in ("dlcf_dca_bert", "dlcfs_dca_bert")
        if uses_dlcf:
            # Hoisted out of the per-example loop: the call depends only on the
            # config, not on the example.
            # NOTE(review): assumes the helper is idempotent for a fixed
            # config -- confirm against its implementation.
            configure_dlcf_spacy_model(cfg)

        all_data = []
        # record polarities type to update output_dim
        label_set = set()
        ex_id = 0

        # Only iterate over complete (sentence, aspect, polarity) triples. The
        # corruption error above is reported but not raised, so a trailing
        # partial record must not be indexed past the end of ``lines``.
        complete_len = len(lines) - len(lines) % 3

        for i in tqdm.tqdm(range(0, complete_len, 3), desc="preparing dataloader"):
            # Sentences with more than one placeholder are skipped.
            if lines[i].count("$T$") > 1:
                continue

            text_left, _, text_right = [s.strip() for s in lines[i].partition("$T$")]
            aspect = lines[i + 1].strip()
            polarity = lines[i + 2].strip()

            prepared_inputs = prepare_input_for_apc(
                cfg,
                self.tokenizer,
                text_left,
                text_right,
                aspect,
                input_demands=cfg.inputs_cols,
            )

            text_raw = prepared_inputs["text_raw"]
            text_spc = prepared_inputs["text_spc"]
            aspect = prepared_inputs["aspect"]
            aspect_position = prepared_inputs["aspect_position"]
            text_indices = prepared_inputs["text_indices"]
            text_raw_bert_indices = prepared_inputs["text_raw_bert_indices"]
            aspect_bert_indices = prepared_inputs["aspect_bert_indices"]
            lcf_cdw_vec = prepared_inputs["lcf_cdw_vec"]
            lcf_cdm_vec = prepared_inputs["lcf_cdm_vec"]
            lcf_vec = prepared_inputs["lcf_vec"]
            lcfs_cdw_vec = prepared_inputs["lcfs_cdw_vec"]
            lcfs_cdm_vec = prepared_inputs["lcfs_cdm_vec"]
            lcfs_vec = prepared_inputs["lcfs_vec"]

            # A truthy result means the example is dropped.
            if validate_absa_example(text_raw, aspect, polarity, cfg):
                continue

            if uses_dlcf:
                dlcf_inputs = prepare_input_for_dlcf_dca(
                    cfg, self.tokenizer, text_left, text_right, aspect
                )
                dlcf_vec = (
                    dlcf_inputs["dlcf_cdm_vec"]
                    if cfg.lcf == "cdm"
                    else dlcf_inputs["dlcf_cdw_vec"]
                )
                dlcfs_vec = (
                    dlcf_inputs["dlcfs_cdm_vec"]
                    if cfg.lcf == "cdm"
                    else dlcf_inputs["dlcfs_cdw_vec"]
                )
                depend_vec = dlcf_inputs["depend_vec"]
                depended_vec = dlcf_inputs["depended_vec"]

            # Columns not listed in ``inputs_cols`` are stored as the scalar 0.
            # The conditional expressions are kept inline on purpose: the
            # DLCF-only names are bound only when ``uses_dlcf`` is true, and
            # the value side is evaluated only when the column is requested.
            data = {
                "ex_id": ex_id,
                "text_raw": text_raw,
                "text_spc": text_spc,
                "aspect": aspect,
                "aspect_position": aspect_position,
                # the lca indices are the same as the refactored CDM
                # (lcf != CDW or Fusion) lcf vec
                "lca_ids": lcf_vec,
                "lcf_vec": lcf_vec if "lcf_vec" in cfg.inputs_cols else 0,
                "lcf_cdw_vec": lcf_cdw_vec
                if "lcf_cdw_vec" in cfg.inputs_cols
                else 0,
                "lcf_cdm_vec": lcf_cdm_vec
                if "lcf_cdm_vec" in cfg.inputs_cols
                else 0,
                "lcfs_vec": lcfs_vec if "lcfs_vec" in cfg.inputs_cols else 0,
                "lcfs_cdw_vec": lcfs_cdw_vec
                if "lcfs_cdw_vec" in cfg.inputs_cols
                else 0,
                "lcfs_cdm_vec": lcfs_cdm_vec
                if "lcfs_cdm_vec" in cfg.inputs_cols
                else 0,
                "dlcf_vec": dlcf_vec if "dlcf_vec" in cfg.inputs_cols else 0,
                "dlcfs_vec": dlcfs_vec if "dlcfs_vec" in cfg.inputs_cols else 0,
                "depend_vec": depend_vec
                if "depend_vec" in cfg.inputs_cols
                else 0,
                "depended_vec": depended_vec
                if "depended_vec" in cfg.inputs_cols
                else 0,
                "spc_mask_vec": build_spc_mask_vec(cfg, text_raw_bert_indices)
                if "spc_mask_vec" in cfg.inputs_cols
                else 0,
                "text_indices": text_indices
                if "text_indices" in cfg.inputs_cols
                else 0,
                "aspect_bert_indices": aspect_bert_indices
                if "aspect_bert_indices" in cfg.inputs_cols
                else 0,
                "text_raw_bert_indices": text_raw_bert_indices
                if "text_raw_bert_indices" in cfg.inputs_cols
                else 0,
                "polarity": polarity,
            }
            ex_id += 1

            label_set.add(polarity)
            all_data.append(data)

        check_and_fix_labels(label_set, "polarity", all_data, cfg)
        cfg.output_dim = len(label_set)

        all_data = build_sentiment_window(
            all_data,
            self.tokenizer,
            cfg.similarity_threshold,
            input_demands=cfg.inputs_cols,
        )

        for data in all_data:
            member_positions = data["cluster_ids"]
            # Expand the position collection into a fixed-length label vector:
            # listed positions get the example's polarity, all others -100.
            # NOTE(review): -100 is presumably the loss ignore index -- confirm.
            data["cluster_ids"] = np.asarray(
                [
                    data["polarity"] if pad_idx in member_positions else -100
                    for pad_idx in range(cfg.max_seq_len)
                ],
                dtype=np.int64,
            )
            # NOTE(review): both fields are overwritten with a scalar 0 array,
            # discarding the aspect_position stored earlier -- presumably not
            # consumed during training; confirm against the model inputs.
            data["side_ex_ids"] = np.array(0)
            data["aspect_position"] = np.array(0)

        self.data = all_data

    def __getitem__(self, index):
        """Return the example dict stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)