Custom BatchSampler for an MNIST DataLoader in PyTorch

created at 08-01-2021 views: 15

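Note the difference between `list.append()` and `list.extend()`: `append` adds its argument as a single element (so appending a list nests it), while `extend` adds each item of the iterable individually.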

# Build two lists from the same repeated chunk to contrast the methods:
# ``append`` stores each chunk as one nested element, while ``extend``
# splices the chunk's items into the list one by one.
a, b = [], []
for _ in range(5):
    chunk = [1, 2, 3]  # a brand-new list object on every pass
    a.append(chunk)
    b.extend(chunk)

print(a)
print(b)

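Output:

[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]
[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]

The balanced batch sampler below picks `n_classes` random classes for each batch and takes `n_samples` examples from each of them, so every batch contains `n_classes * n_samples` indices (it uses `extend` to build the flat index list):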

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler
import numpy as np



class BalancedBatchSampler(BatchSampler):
    """Yield batches of ``n_samples`` examples from each of ``n_classes`` classes.

    For every batch, ``n_classes`` distinct labels are drawn at random
    (without replacement). The next ``n_samples`` not-yet-used indices of each
    drawn label are then taken. When fewer than ``n_samples`` unused indices
    remain for a label, that label's index list is reshuffled and its cursor
    is reset. Because labels are drawn at random per batch, one pass may
    revisit some examples and never reach others.

    Args:
        labels: One label per dataset example. Any array-like is accepted; it
            is converted to a NumPy array so the element-wise ``==`` lookup
            also works for plain Python lists.
        n_classes: Number of distinct classes in each batch.
        n_samplers: Number of examples per class in each batch (stored as
            ``self.n_samples``; the parameter name is kept for compatibility).

    Raises:
        ValueError: If ``n_classes`` or ``n_samplers`` is not positive, or if
            ``n_classes`` exceeds the number of distinct labels.
    """

    def __init__(self, labels, n_classes, n_samplers):
        # BatchSampler.__init__ is intentionally not called: it expects a
        # sampler/batch_size/drop_last triple, which this class replaces with
        # its own label-driven sampling.
        if n_classes <= 0 or n_samplers <= 0:
            # A zero batch size would make __iter__ loop forever.
            raise ValueError("n_classes and n_samplers must be positive")
        self.labels = np.asarray(labels)
        self.labels_set = list(set(self.labels))
        if n_classes > len(self.labels_set):
            # np.random.choice(..., replace=False) would fail on first iteration.
            raise ValueError(
                f"n_classes={n_classes} exceeds the number of distinct labels "
                f"({len(self.labels_set)})"
            )
        # label -> shuffled array of dataset indices carrying that label
        self.labels_to_indices = {
            label: np.where(self.labels == label)[0] for label in self.labels_set
        }
        for label in self.labels_set:
            np.random.shuffle(self.labels_to_indices[label])

        # label -> cursor into its shuffled index array
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samplers
        self.batch_size = self.n_classes * self.n_samples
        self.n_dataset = len(self.labels)

    def __iter__(self):
        """Yield lists of dataset indices, ``batch_size`` entries per list."""
        self.count = 0
        # ``<=`` so the number of batches equals __len__ (n_dataset // batch_size);
        # a strict ``<`` dropped the last full batch.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                class_indices = self.labels_to_indices[class_]
                start = self.used_label_indices_count[class_]
                # tolist() yields plain Python ints instead of NumPy scalars.
                indices.extend(class_indices[start:start + self.n_samples].tolist())

                # Advance by the number of samples actually consumed (advancing by
                # n_classes re-used or skipped samples when n_classes != n_samples).
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(class_indices):
                    np.random.shuffle(class_indices)
                    self.used_label_indices_count[class_] = 0
            # Debug output: the sampled dataset indices of this batch.
            print(indices)
            yield indices
            self.count += self.batch_size

    def __len__(self):
        """Return the number of full batches that fit in the dataset."""
        return self.n_dataset // self.batch_size

transform = transforms.Compose([
    transforms.ToTensor(),
    # Presumably the MNIST training-set pixel mean / std — confirm if changing data.
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./dataset',
                               train=True, download=True,
                               transform=transform)
# ``data`` / ``targets`` replace the deprecated ``train_data`` / ``train_labels``.
print(train_dataset.data[0].shape)
print(train_dataset.targets.shape)

# Each batch is the list of indices yielded by the balanced sampler
# (10 classes x 10 samples = 100 images per batch).
train_batch_sampler = BalancedBatchSampler(train_dataset.targets.numpy(), n_classes=10, n_samplers=10)
online_train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler)

for x, label in online_train_loader:
    print(x.shape)
    print(label.shape)

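Output (excerpt, first three batches). Each batch prints the list of 100 sampled dataset indices, followed by the shapes of the image batch and the label batch: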

[31483, 10535, 26733, 48799, 12254, 6124, 41593, 35301, 30832, 11899, 28915, 6203, 27524, 49453, 57655, 38734, 25387, 48723, 29194, 37253, 21439, 58740, 12650, 55498, 16384, 5177, 19609, 19944, 18446, 28570, 44042, 25775, 17630, 44943, 29797, 31594, 35361, 48422, 54207, 24207, 10484, 20005, 53892, 19964, 22763, 26977, 52685, 10135, 42490, 50863, 49466, 46537, 41434, 44982, 39428, 20841, 45431, 27237, 41334, 58236, 48337, 27806, 11072, 34847, 25957, 52599, 9922, 35327, 46962, 39800, 8550, 55299, 47119, 51285, 9674, 30535, 15989, 42494, 5234, 22419, 13827, 42807, 45712, 7123, 50337, 15850, 26048, 46070, 254, 26378, 20379, 40444, 14433, 19499, 22676, 43137, 25194, 14501, 48959, 2772]
torch.Size([100, 1, 28, 28])
torch.Size([100])
[38742, 30347, 36870, 44748, 51105, 56560, 36913, 24658, 44176, 40364, 55521, 12259, 59433, 49199, 52526, 462, 24913, 45129, 24075, 13238, 54802, 32725, 57971, 13133, 331, 39837, 6581, 43572, 9745, 38322, 23444, 20988, 41826, 55594, 12060, 56737, 11432, 408, 47327, 36687, 49957, 31545, 6504, 1278, 970, 30595, 15232, 37830, 38773, 23312, 7296, 7578, 22572, 12169, 52457, 36191, 17694, 3753, 33552, 37851, 11474, 38312, 11682, 11820, 34338, 15268, 19268, 51770, 6717, 57294, 32197, 40510, 23178, 59602, 45469, 55535, 34415, 23469, 6713, 53911, 46454, 48088, 26548, 27443, 36592, 36623, 12430, 1883, 28418, 56006, 11655, 32989, 30076, 38427, 26509, 37199, 52186, 32695, 59983, 31978]
torch.Size([100, 1, 28, 28])
torch.Size([100])
[46489, 7154, 44018, 58117, 55184, 259, 38817, 14828, 30396, 27499, 27971, 25242, 49412, 5264, 16966, 16548, 50665, 8758, 3718, 56034, 26992, 17076, 49211, 39375, 25470, 16526, 1537, 54650, 46711, 37683, 23081, 22274, 11444, 160, 42580, 58405, 1914, 3883, 24831, 40981, 46165, 32413, 13474, 22812, 39560, 55222, 50907, 6419, 48751, 34874, 68, 22265, 10469, 37822, 52797, 39156, 14258, 13197, 25285, 23454, 27700, 23652, 10279, 52782, 29060, 48516, 11084, 25122, 36775, 33290, 59257, 45935, 36184, 43035, 8799, 14400, 33886, 3821, 261, 47472, 54559, 38771, 39222, 25091, 46208, 10606, 52818, 53514, 5493, 59375, 20111, 7264, 10382, 9265, 33974, 36102, 59910, 2137, 43729, 24243]
torch.Size([100, 1, 28, 28])
torch.Size([100])
created at: 08-01-2021
edited at: 08-01-2021