PyTorch CIFAR10 image classification: data loading and visualization

created at 08-03-2021

General PyTorch workflow:

  • Data reading
  • Data processing
  • Building the network
  • Model training
  • Model deployment

This article first covers loading the CIFAR10 data and visualizing the images; the network will be introduced and implemented after that.

1. Data reading

CIFAR-10 is a small dataset for recognizing common objects, compiled by Alex Krizhevsky and Ilya Sutskever, students of Hinton. It contains RGB color images in 10 categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck. Each image is 32×32, and the dataset has 50,000 training images and 10,000 test images.

Compared with the MNIST data set, CIFAR-10 has the following differences:

  • CIFAR-10 images are 3-channel RGB color images, while MNIST images are grayscale.
  • CIFAR-10 images are 32×32, slightly larger than the 28×28 MNIST images.
  • Unlike handwritten digits, CIFAR-10 shows real-world objects: the images are noisy, and the objects vary in scale and appearance, which makes recognition considerably harder.
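To make the size difference concrete, here is a small plain-Python sketch that counts the input values per image for each dataset; the helper name `values_per_image` is hypothetical.

```python
def values_per_image(height: int, width: int, channels: int) -> int:
    """Number of scalar pixel values the network receives for one image."""
    return height * width * channels

mnist = values_per_image(28, 28, 1)    # grayscale: 1 channel
cifar10 = values_per_image(32, 32, 3)  # RGB: 3 channels

print(mnist, cifar10)                  # 784 3072
print(round(cifar10 / mnist, 2))       # 3.92
```

So each CIFAR-10 image carries roughly four times as many input values as an MNIST digit.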


First use torchvision to load and normalize our training data and test data.

  1. torchvision implements loaders for commonly used image datasets, such as CIFAR10, ImageNet and MNIST, in the torchvision.datasets module.
  2. It also provides common data-processing (image transformation) methods, in the torchvision.transforms module.
  3. Some models are provided as well: for example, torchvision.models contains architectures such as AlexNet, VGG, ResNet and SqueezeNet.

torchvision's datasets return PILImage images; transforms.ToTensor converts them to tensors with values in [0, 1], which we then normalize to [-1, 1].
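What ToTensor (PIL image to [0, 1] tensor) and Normalize do to the pixel values can be imitated in plain numpy. This is a simplified sketch, not torchvision's implementation: the hypothetical `to_tensor_like` mimics ToTensor for a uint8 H x W x C image (scale to [0, 1], move channels first), and `normalize` applies (x - mean) / std per channel.

```python
import numpy as np

def to_tensor_like(img_hwc_uint8: np.ndarray) -> np.ndarray:
    """Mimic ToTensor: uint8 H x W x C in [0, 255] -> float32 C x H x W in [0, 1]."""
    return np.transpose(img_hwc_uint8, (2, 0, 1)).astype(np.float32) / 255.0

def normalize(img_chw: np.ndarray, mean, std) -> np.ndarray:
    """Mimic Normalize: (x - mean) / std, applied channel by channel."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (img_chw - mean) / std

# A fake 32x32 RGB image containing both extreme values, 0 and 255
img = np.zeros((32, 32, 3), dtype=np.uint8)
img[0, 0] = 255

x = to_tensor_like(img)
y = normalize(x, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(x.shape, x.min(), x.max())  # (3, 32, 32) 0.0 1.0
print(y.min(), y.max())           # -1.0 1.0
```

With mean 0.5 and std 0.5, 0 maps to (0 - 0.5) / 0.5 = -1 and 1 maps to (1 - 0.5) / 0.5 = 1, which is where the [-1, 1] range comes from.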

First, a transform is defined. transforms.Compose(), from the transforms module mentioned above, chains several transformations together; here ToTensor and Normalize are combined, preceded by two data-augmentation steps, RandomCrop and RandomHorizontalFlip.
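Conceptually, Compose just feeds the output of one transform into the next. A minimal pure-Python sketch of that idea (the name `compose` is hypothetical, and the toy lambdas only stand in for ToTensor's scaling and for Normalize):

```python
from typing import Callable, Sequence

def compose(funcs: Sequence[Callable]) -> Callable:
    """Return a function that applies each transform in order, like transforms.Compose."""
    def apply(x):
        for f in funcs:
            x = f(x)  # the output of one transform is the input of the next
        return x
    return apply

# Toy "transforms" on a single pixel value, just to show the ordering
pipeline = compose([lambda x: x / 255.0, lambda x: (x - 0.5) / 0.5])
print(pipeline(255))  # 1.0
print(pipeline(0))    # -1.0
```

Order matters: in the real pipeline Normalize operates on tensors, so ToTensor has to come before it.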

In transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), the first tuple is the mean of each of the three RGB channels and the second tuple is the standard deviation of each channel; every channel is transformed as output = (input - mean) / std. Note that the channel order is RGB; if you have used OpenCV, you will know that it reads images in BGR order. Using 0.5 for both is only an approximation: the true per-channel mean and standard deviation of CIFAR-10 are not exactly 0.5, but for this example the difference hardly matters. The exact values are obtained by computing the mean and standard deviation of the R, G and B channels separately over the data.
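The exact statistics can be computed with numpy; on CIFAR-10 you would pass the training images as an N x H x W x C uint8 array (trainset.data). The sketch below runs on toy data instead, and also shows the channel flip needed for BGR images read with OpenCV; both helper names are hypothetical.

```python
import numpy as np

def channel_mean_std(images: np.ndarray):
    """Per-channel mean and std of uint8 images shaped N x H x W x C, scaled to [0, 1]."""
    x = images.astype(np.float64) / 255.0
    # Reduce over every axis except the last (channel) axis
    return x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2))

def bgr_to_rgb(img_hwc: np.ndarray) -> np.ndarray:
    """Reverse the channel axis, e.g. for images read with OpenCV (BGR order)."""
    return img_hwc[..., ::-1]

# Toy data: 2 images of 2x2 pixels; R is always 255, G always 0,
# B is 255 in the first image and 0 in the second
imgs = np.zeros((2, 2, 2, 3), dtype=np.uint8)
imgs[..., 0] = 255
imgs[0, ..., 2] = 255

mean, std = channel_mean_std(imgs)
print(mean)  # per channel: R = 1.0, G = 0.0, B = 0.5
print(std)   # per channel: R = 0.0, G = 0.0, B = 0.5
```

The resulting tuples can then be passed to transforms.Normalize in place of (0.5, 0.5, 0.5).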

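import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
#     transforms.CenterCrop(224),
    transforms.RandomCrop(32, padding=4),  # Data augmentation: random 32x32 crop of the 4-pixel-padded image
    transforms.RandomHorizontalFlip(),     # Data augmentation: mirror left-right with probability 0.5
    transforms.ToTensor(),                 # PIL image (H x W x C, 0-255) -> tensor (C x H x W, 0.0-1.0)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [0, 1] -> [-1, 1]
])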
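To see what the two augmentation transforms do, here is a numpy sketch operating on an H x W x C array. It follows the documented behavior of RandomCrop(32, padding=4) (zero padding by default) and RandomHorizontalFlip() (p = 0.5), but it is an illustration, not torchvision's code; the function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_crop(img_hwc: np.ndarray, size: int = 32, padding: int = 4) -> np.ndarray:
    """Zero-pad each side by `padding` pixels, then cut out a random size x size window."""
    padded = np.pad(img_hwc, ((padding, padding), (padding, padding), (0, 0)))  # constant 0 padding
    top = rng.integers(0, padded.shape[0] - size + 1)
    left = rng.integers(0, padded.shape[1] - size + 1)
    return padded[top:top + size, left:left + size]

def random_horizontal_flip(img_hwc: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Mirror the image left-right with probability p."""
    return img_hwc[:, ::-1] if rng.random() < p else img_hwc

img = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
out = random_horizontal_flip(random_crop(img))
print(out.shape)  # (32, 32, 3)
```

Because a new crop offset and flip decision is drawn every time an image is loaded, the network sees a slightly different version of each picture in every epoch, while the image size stays 32×32.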

trainloader is an important object: later we will feed data to the network through it. The name trainloader is just a variable name and can be anything you like; what matters is how it is defined. It is a DataLoader from the torch.utils.data module, built from the dataset as follows.

Batch_Size = 256
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# DataLoader wraps a Dataset and yields mini-batches of (images, labels)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=Batch_Size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=Batch_Size, shuffle=True, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Files already downloaded and verified
Files already downloaded and verified
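A DataLoader with shuffle=True draws a new random order of sample indices at the start of every epoch and slices it into batches. The sketch below reproduces only that indexing logic in numpy (no worker processes, no collation into tensors); `iterate_minibatches` is a hypothetical name.

```python
import math
import numpy as np

def iterate_minibatches(n_samples: int, batch_size: int, shuffle: bool = True, seed: int = 0):
    """Yield index arrays covering every sample once, like one epoch of a DataLoader (drop_last=False)."""
    order = np.random.default_rng(seed).permutation(n_samples) if shuffle else np.arange(n_samples)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]

batches = list(iterate_minibatches(50000, 256))
print(len(batches), math.ceil(50000 / 256))  # 196 196
print(len(batches[0]), len(batches[-1]))      # 256 80
```

With 50,000 training images and Batch_Size = 256, one epoch therefore has 196 batches; the last holds only 80 images (50000 - 195 × 256), because DataLoader keeps the incomplete final batch unless drop_last=True.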

2. View data (format, size, shape)

First, you can view the categories and the integer label assigned to each (trainset.classes gives just the list of names):

trainset.class_to_idx
{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}
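The labels produced by the loaders are these integers, so printing a readable name needs the reverse mapping. A small sketch, with the dictionary copied from the output above (for CIFAR10, trainset.classes is already ordered by index, so trainset.classes[label] gives the same result):

```python
class_to_idx = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

# Invert the mapping so an integer label can be turned back into a name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print(idx_to_class[3])  # cat
```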

You can also check the shape of the training data:

trainset.data.shape  # 50000 is the number of pictures, 32x32 is the picture size, 3 is the number of RGB channels
(50000, 32, 32, 3)
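Where does this H x W x C layout come from? In the python version of the CIFAR-10 files, each image is stored as a row of 3072 values: 1024 red, then 1024 green, then 1024 blue, each a 32×32 plane in row-major order. torchvision reshapes these rows to N x 3 x 32 x 32 and then transposes them to N x 32 x 32 x 3. The sketch below repeats that reshaping on fake data rather than the real files.

```python
import numpy as np

# Fake stand-in for raw CIFAR-10 rows: each row has 3072 values,
# the first 1024 are the red plane, the next 1024 green, the last 1024 blue.
raw = np.arange(2 * 3072, dtype=np.int64).reshape(2, 3072)

images = raw.reshape(-1, 3, 32, 32)    # N x C x H x W
images = images.transpose(0, 2, 3, 1)  # N x H x W x C, the layout of trainset.data

print(images.shape)          # (2, 32, 32, 3)
print(images[0, 0, 0])       # [   0 1024 2048]: R, G, B values of the top-left pixel
```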

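View the data types:

print(type(trainset.data))
print(type(trainset))
<class 'numpy.ndarray'>
<class 'torchvision.datasets.cifar.CIFAR10'>

  • trainset.data is a standard numpy.ndarray of shape (50000, 32, 32, 3): 50000 is the number of pictures, 32x32 is the picture size, and 3 is the number of RGB channels;
  • trainset is a torchvision.datasets.cifar.CIFAR10 object, i.e. a Dataset; len(trainset) is 50000. Indexing it returns an (image, label) pair: trainset[0][0] is the first picture and trainset[0][1] is its label. After the transform above, the picture is a tensor of shape 3×32×32 (3 RGB channels, 32×32 pixels).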
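Map-style datasets such as CIFAR10 follow a simple protocol: len(dataset) gives the number of samples and dataset[i] returns one (image, label) pair, with the transform applied on access. The pure-numpy class below is a hypothetical illustration of that protocol, not torchvision's CIFAR10 class (which also converts each array to a PIL image before applying the transform).

```python
import numpy as np

class ToyImageDataset:
    """Hypothetical dataset following the __len__/__getitem__ protocol used by torch Datasets."""

    def __init__(self, data: np.ndarray, targets, transform=None):
        self.data = data        # N x H x W x C, like trainset.data
        self.targets = targets  # one integer label per image
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, target      # an (image, label) pair, like trainset[index]

ds = ToyImageDataset(np.zeros((5, 32, 32, 3), dtype=np.uint8), [0, 1, 2, 3, 4],
                     transform=lambda im: np.transpose(im, (2, 0, 1)))  # HWC -> CHW
img, label = ds[2]
print(len(ds), img.shape, label)  # 5 (3, 32, 32) 2
```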

3. View pictures

import numpy as np
import matplotlib.pyplot as plt
im, label = next(iter(trainloader))  # one batch: im is 256 x 3 x 32 x 32, label holds 256 class indices


Convert np.ndarray to torch.Tensor

In deep learning, raw images must be converted into the data format used by the framework; in PyTorch, that format is torch.Tensor.
PyTorch provides interfaces for converting between torch.Tensor and numpy.ndarray:

  • torch.from_numpy(ndarray): converts a numpy.ndarray to a torch.Tensor (the two share the same memory)
  • tensor1.numpy(): returns the data of the CPU tensor tensor1 as a numpy.ndarray (also sharing memory)

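In PyTorch, a batch of images is laid out as N x C x H x W (number of images, channels, height, width).

In numpy (and matplotlib, PIL, OpenCV), an image is laid out as H x W x C, so a batch is N x H x W x C.

Therefore, you need numpy.transpose() to reorder the axes when converting between the two; for a single image, np.transpose(img, (1, 2, 0)) turns C x H x W into H x W x C.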
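The axis permutations look like this on a random batch (numpy only; the array stands in for a tensor that has already been converted with .numpy()):

```python
import numpy as np

batch_nchw = np.random.default_rng(0).random((4, 3, 32, 32))  # PyTorch layout: N x C x H x W
batch_nhwc = np.transpose(batch_nchw, (0, 2, 3, 1))            # numpy/matplotlib layout: N x H x W x C
single_hwc = np.transpose(batch_nchw[0], (1, 2, 0))            # one image: C x H x W -> H x W x C

print(batch_nhwc.shape, single_hwc.shape)  # (4, 32, 32, 3) (32, 32, 3)

# Transposing back with the inverse permutation restores the original array
assert np.array_equal(np.transpose(batch_nhwc, (0, 3, 1, 2)), batch_nchw)
```

np.transpose returns a view, so no pixel data is copied; inside PyTorch itself, tensor.permute(0, 2, 3, 1) performs the same reordering without leaving the tensor world.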

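def imshow(img):
    img = img / 2 + 0.5  # unnormalize: [-1, 1] -> [0, 1]
    img = np.transpose(img.numpy(), (1, 2, 0))  # C x H x W -> H x W x C for matplotlib
    plt.imshow(img)
    plt.show()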




edited at: 08-03-2021