Source code for pyabsa.tasks.CodeDefectDetection.dataset_utils.__plm__.data_utils_for_training

# -*- coding: utf-8 -*-
# file: data_utils_for_training.py
# time: 02/11/2022 15:39
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# GScholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# ResearchGate: https://www.researchgate.net/profile/Heng-Yang-17/research
# Copyright (C) 2022. All Rights Reserved.
import random

import tqdm
from pyabsa.framework.tokenizer_class.tokenizer_class import pad_and_truncate

from pyabsa.framework.dataset_class.dataset_template import PyABSADataset
from ..cdd_utils import read_defect_examples, _prepare_corrupt_code
from pyabsa.utils.file_utils.file_utils import load_dataset_from_file
from pyabsa.utils.pyabsa_utils import check_and_fix_labels, fprint


class BERTCDDDataset(PyABSADataset):
    def load_data_from_dict(self, dataset_dict, **kwargs):
        pass

    def load_data_from_file(self, dataset_file, **kwargs):
        lines = load_dataset_from_file(
            self.config.dataset_file[self.dataset_type], config=self.config
        )
        natural_examples = read_defect_examples(
            lines,
            self.config.get("data_num", None),
            self.config.get("remove_comments", True),
            tokenizer=self.tokenizer,
        )
        all_data = []
        label_set = set()
        c_label_set = set()

        for ex_id, line in enumerate(
            tqdm.tqdm(natural_examples, desc="preparing dataloader")
        ):
            code_src, label = line.strip().split("$LABEL$")
            if "$FEATURE$" in code_src:
                code_src, feature = code_src.split("$FEATURE$")
            # print(len(self.tokenizer.tokenize(code_src.replace('\n', ''))))
            code_ids = self.tokenizer.text_to_sequence(
                code_src,
                max_length=self.config.max_seq_len,
                padding="do_not_pad",
                truncation=False,
            )
            if self.dataset_type == "train" and label == "1":
                over_sampling = self.config.get("over_sampling", 2)
            else:
                over_sampling = 1
            for _ in range(over_sampling):
                code_ids = self.prepare_token_ids(
                    code_ids, self.config.get("sliding_window", False)
                )
                for ids in code_ids:
                    all_data.append(
                        {
                            "ex_id": ex_id,
                            # "code": code_src,
                            "source_ids": ids,
                            "label": label,
                            "corrupt_label": 0,
                        }
                    )
                    label_set.add(label)
                    c_label_set.add(0)

        if self.dataset_type == "train":
            corrupt_examples = read_defect_examples(
                lines,
                self.config.get("data_num", None),
                self.config.get("remove_comments", True),
            )
            for _ in range(self.config.get("noise_instance_num", 0)):
                for ex_id, line in enumerate(
                    tqdm.tqdm(
                        corrupt_examples,
                        desc="preparing corrupted code dataloader for training set",
                    )
                ):
                    code_src, label = line.strip().split("$LABEL$")
                    if label == "0":
                        continue
                    if "$FEATURE$" in code_src:
                        code_src, feature = code_src.split("$FEATURE$")
                    code_src = _prepare_corrupt_code(code_src)
                    corrupt_code_ids = self.tokenizer.text_to_sequence(
                        code_src,
                        max_length=self.config.max_seq_len,
                        padding="do_not_pad",
                        truncation=False,
                    )
                    corrupt_code_ids = self.prepare_token_ids(
                        corrupt_code_ids, self.config.get("sliding_window", False)
                    )
                    for ids in corrupt_code_ids:
                        all_data.append(
                            {
                                "ex_id": ex_id,
                                # "code": code_src,
                                "source_ids": ids,
                                "label": "-100",
                                "corrupt_label": 1,
                            }
                        )
                        label_set.add("-100")
                        c_label_set.add(1)

        check_and_fix_labels(label_set, "label", all_data, self.config)
        self.config.output_dim = len(label_set)
        self.data = all_data
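
    # Each raw example line is expected to look like
    #     <source code> $LABEL$ <0 or 1>
    # with an optional "$FEATURE$" segment in front of the label marker; the
    # feature part is split off and ignored here. During training, defective
    # examples (label == "1") are over-sampled ("over_sampling", default 2),
    # and corrupted variants of defective code are appended with label "-100"
    # and corrupt_label == 1 when "noise_instance_num" > 0.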

    def prepare_token_ids(self, code_ids, sliding_window=False):
        all_code_ids = []
        code_ids = code_ids[1:-1]
        if sliding_window is False:
            code_ids = pad_and_truncate(
                code_ids,
                self.config.max_seq_len - 2,
                value=self.tokenizer.pad_token_id,
            )
            all_code_ids.append(
                [self.tokenizer.cls_token_id]
                + code_ids
                + [self.tokenizer.eos_token_id]
            )
            if all_code_ids[-1].count(self.tokenizer.eos_token_id) != 1:
                raise ValueError("last token id is not eos token id")
            return all_code_ids
        else:
            code_ids = pad_and_truncate(
                code_ids,
                self.config.max_seq_len - 2,
                value=self.tokenizer.pad_token_id,
            )
            # for x in range(len(code_ids) // ((self.config.max_seq_len - 2) // 2) + 1):
            #     _code_ids = code_ids[x * (self.config.max_seq_len - 2) // 2:
            #                          (x + 1) * (self.config.max_seq_len - 2) // 2 + (self.config.max_seq_len - 2) // 2]
            #     print(x * (self.config.max_seq_len - 2) // 2)
            #     print((x + 1) * (self.config.max_seq_len - 2) // 2 + (self.config.max_seq_len - 2) // 2)
            for x in range(len(code_ids) // (self.config.max_seq_len - 2) + 1):
                _code_ids = code_ids[
                    x * (self.config.max_seq_len - 2) : (x + 1)
                    * (self.config.max_seq_len - 2)
                ]
                _code_ids = pad_and_truncate(
                    _code_ids,
                    self.config.max_seq_len - 2,
                    value=self.tokenizer.pad_token_id,
                )
                if _code_ids:
                    all_code_ids.append(
                        [self.tokenizer.cls_token_id]
                        + _code_ids
                        + [self.tokenizer.eos_token_id]
                    )
                    if all_code_ids[-1].count(self.tokenizer.eos_token_id) != 1:
                        raise ValueError("last token id is not eos token id")
            return all_code_ids
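
    # Every sequence returned by prepare_token_ids has the fixed form
    #     [cls_token_id] + <max_seq_len - 2 code/pad ids> + [eos_token_id]
    # e.g. with max_seq_len = 16 (an illustrative value), a 10-token snippet
    # becomes [CLS] t1 ... t10 PAD PAD PAD PAD [EOS]. With sliding_window
    # enabled, the ids are sliced into consecutive windows of max_seq_len - 2
    # tokens and each window is wrapped the same way.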

    def __init__(self, config, tokenizer, dataset_type="train", **kwargs):
        super().__init__(config, tokenizer, dataset_type, **kwargs)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
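

# The sketch below is an illustrative, self-contained walkthrough of the line
# parsing and the window-and-wrap scheme used above; the toy token ids
# (cls_id, eos_id, pad_id) and the whitespace "tokenizer" are assumptions made
# for the example only and do not come from pyabsa.
if __name__ == "__main__":
    max_seq_len = 16
    cls_id, eos_id, pad_id = 1, 2, 0

    line = "int f(int *p) { return *p; } $LABEL$ 1"
    code_src, label = line.strip().split("$LABEL$")
    if "$FEATURE$" in code_src:
        code_src, _feature = code_src.split("$FEATURE$")

    # Stand-in for tokenizer.text_to_sequence: one id per whitespace token.
    token_ids = list(range(10, 10 + len(code_src.split())))

    # Fixed-width windows of max_seq_len - 2 ids, each wrapped in CLS/EOS,
    # mirroring prepare_token_ids with sliding_window=True.
    width = max_seq_len - 2
    windows = []
    for x in range(len(token_ids) // width + 1):
        chunk = token_ids[x * width : (x + 1) * width]
        if chunk:
            chunk = chunk + [pad_id] * (width - len(chunk))
            windows.append([cls_id] + chunk + [eos_id])

    print(f"label={label.strip()}, windows={len(windows)}, window_len={len(windows[0])}")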