# -*- coding: utf-8 -*-
# file: imblanced_sampler.py
# time: 23:10 2023/1/13
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2021. All Rights Reserved.
from typing import Callable
import numpy as np
import pandas as pd
import torch
import torch.utils.data
# based on https://github.com/ufoym/imbalanced-dataset-sampler
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples indices with replacement, re-weighted to balance the classes.

    Every candidate index is drawn with probability proportional to
    ``1 / (number of candidates sharing its label)``, so in expectation each
    class is drawn equally often.

    Arguments:
        dataset: the dataset to sample from. Its length defines the default
            ``indices``; when ``labels`` is not given, the labels are read
            from it (see ``_get_labels``).
        labels: optional labels, one per entry of ``indices`` and in the same
            order. A ``torch.Tensor`` is converted to a NumPy array first.
        indices: a list of dataset indices to sample from (default: every
            index of ``dataset``).
        num_samples: number of samples to draw per iteration (default:
            ``len(indices)``).
        callback_get_label: optional function that takes the dataset (a
            single argument) and returns its labels, one per dataset item.

    Raises:
        ValueError: if the number of labels does not match the number of
            indices.
        NotImplementedError: if labels are needed but cannot be derived
            from ``dataset``.
    """

    def __init__(
        self,
        dataset,
        labels: list = None,
        indices: list = None,
        num_samples: int = None,
        callback_get_label: Callable = None,
    ):
        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices
        # define custom callback (must be set before _get_labels is called)
        self.callback_get_label = callback_get_label
        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        if labels is None:
            labels = self._to_numpy(self._get_labels(dataset))
            # Dataset-derived labels normally have one entry per dataset item,
            # so for a custom subset of indices keep only the labels of those
            # indices. Labels whose length already equals len(indices) are
            # treated as aligned with `indices` (the previous contract).
            if indices is not None and len(labels) != len(self.indices):
                labels = labels[np.asarray(self.indices, dtype=np.int64)]
        elif isinstance(labels, torch.Tensor):
            labels = self._to_numpy(labels)

        if len(labels) != len(self.indices):
            raise ValueError(
                f"expected one label per index ({len(self.indices)}), "
                f"got {len(labels)} labels"
            )

        # distribution of classes among the candidate indices
        label_series = pd.Series(labels)
        label_to_count = label_series.value_counts(dropna=False)
        # `map` looks the counts up by label value (indexing the counts with a
        # boolean label Series would be interpreted as a mask). There is
        # deliberately no sorting: weight i must stay paired with
        # self.indices[i], which __iter__ relies on.
        class_sizes = label_series.map(label_to_count).to_numpy(dtype=np.float64)
        self.weights = torch.as_tensor(1.0 / class_sizes, dtype=torch.double)

    @staticmethod
    def _to_numpy(values):
        """Convert a label container (tensor, list, array, ...) to a NumPy array."""
        if isinstance(values, torch.Tensor):
            # np.asarray fails on CUDA tensors and tensors requiring grad
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def _get_labels(self, dataset):
        """Return the labels of ``dataset``, one per dataset item.

        Resolution order: ``callback_get_label`` if set; the second tensor of
        a ``TensorDataset``; for a ``Subset``, the labels of the wrapped
        dataset at the subset's indices; otherwise ``dataset.get_labels()``.

        Raises:
            NotImplementedError: if ``dataset`` is not a
                ``torch.utils.data.Dataset`` and no callback is set.
        """
        if self.callback_get_label is not None:
            return self.callback_get_label(dataset)
        if isinstance(dataset, torch.utils.data.TensorDataset):
            return dataset.tensors[1]
        if isinstance(dataset, torch.utils.data.Subset):
            parent = dataset.dataset
            if hasattr(parent, "imgs"):
                # presumably an ImageFolder-style parent whose `imgs` holds
                # (path, label) pairs; the old `imgs[:][1]` returned only the
                # second pair instead of one label per subset item.
                return [parent.imgs[i][1] for i in dataset.indices]
            # otherwise derive the parent's labels and pick the subset's ones
            parent_labels = self._to_numpy(self._get_labels(parent))
            return parent_labels[np.asarray(dataset.indices, dtype=np.int64)]
        if isinstance(dataset, torch.utils.data.Dataset):
            return dataset.get_labels()
        raise NotImplementedError

    def __iter__(self):
        """Yield ``num_samples`` entries of ``indices`` drawn with replacement by weight."""
        drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
        # .tolist() gives plain ints, so `indices` can be any indexable sequence
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        """Return the number of indices yielded per iteration."""
        return self.num_samples