
[PyTorch] CNN on CIFAR10 Image Classification

In this article, we will look at how to build a Convolutional Neural Network for classifying images in the CIFAR-10 dataset, using PyTorch for the implementation.

Convolutional Neural Network

A CNN is a deep learning technique commonly used for image classification tasks. It can recognize and classify particular features from images. At a high level, CNNs contain three main types of layers:

  1. Convolutional layers. The fundamental building block of a CNN. They apply convolutional filters, or kernels (small matrices of weights), to the input to extract features. Each filter slides across the receptive field of the input image to detect spatial features.
  2. Pooling layers. A pooling layer follows a convolutional layer. It downsamples the feature maps (reducing their dimensionality) while retaining critical information. Max pooling and average pooling are commonly used strategies.
  3. Fully-connected layers. These take the high-level features from the convolutional and pooling layers as input for classification. They are responsible for classifying images based on the features extracted by the previous layers, and multiple fully-connected layers can be stacked. A short sketch of how these layers change tensor shapes follows this list.
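
As a quick illustration (my own sketch, not part of the original article), here is how a 5×5 convolution with padding 2 and a 2×2 max-pool change the shape of a CIFAR-10-sized tensor:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)                     # one RGB image, 32x32
conv = nn.Conv2d(3, 6, kernel_size=5, padding=2)  # padding=2 preserves H and W
pool = nn.MaxPool2d(kernel_size=2, stride=2)      # halves H and W

print(conv(x).shape)        # torch.Size([1, 6, 32, 32])
print(pool(conv(x)).shape)  # torch.Size([1, 6, 16, 16])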

About the CIFAR10 Dataset

The CIFAR-10 (Canadian Institute for Advanced Research, 10 classes) dataset has a total of 60,000 images split into 50,000 training images and 10,000 test images. Each image is of size 32 x 32 x 3 (32 pixels wide, 32 pixels high, 3 color channels), and each pixel-channel value is an integer between 0 and 255. Each image belongs to one of 10 classes: plane (class 0), car, bird, cat, deer, dog, frog, horse, ship, truck (class 9). Most deep learning frameworks, including PyTorch, TensorFlow, and Keras, have built-in CIFAR-10 loaders.

Alongside the MNIST dataset, CIFAR-10 is one of the most popular datasets in machine learning research.

[Image: sample images from the CIFAR-10 dataset]

CNN PyTorch implementation

We are going to implement a Convolutional Neural Network (CNN) model to perform multi-class image classification on the famous CIFAR-10 dataset.

Import the libraries

import torch
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.utils import make_grid

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary

from sklearn.metrics import confusion_matrix, classification_report

import numpy as np
from tqdm.notebook import tqdm

# Seed
import random
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import matplotlib.pyplot as plt
%matplotlib inline
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

Hyper-parameter setting

# Hyper-parameters

N_VALID = 0.4  # fraction of the training set held out for validation
NUM_WORKERS = 0
NUM_EPOCHS = 20
BATCH_SIZE = 32
LEARNING_RATE = 0.001

Transforms

Normalizing images is a common preprocessing step in computer vision: each channel value is shifted by the mean and scaled by the standard deviation. For ImageNet (a widely used open-source dataset for pre-training computer vision models), the per-channel mean and std are well known and can be used directly:
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

The official PyTorch tutorials use (0.5, 0.5, 0.5) for both the mean and the standard deviation when normalizing CIFAR-10. You can also calculate the per-channel mean and std of the data you are working with, so that the normalized data has zero mean and unit standard deviation; a short sketch of that computation follows.
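
Here is one way to compute these statistics yourself (a minimal sketch using the ds.data array that torchvision's CIFAR10 dataset exposes; the printed values are approximate):

ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
data = torch.tensor(ds.data, dtype=torch.float32) / 255.0  # shape [50000, 32, 32, 3]
print(data.mean(dim=(0, 1, 2)))  # ~ tensor([0.4914, 0.4822, 0.4465])
print(data.std(dim=(0, 1, 2)))   # ~ tensor([0.2470, 0.2435, 0.2616])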

  • transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) subtracts the mean from each channel value and then divides by the standard deviation, as the short example below shows.
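
For instance, with the CIFAR-10 statistics used in this article, a mid-gray pixel value of 0.5 is transformed as follows (a quick illustration of the arithmetic):

x = torch.tensor([0.5, 0.5, 0.5])
mean = torch.tensor([0.4914, 0.4822, 0.4465])
std = torch.tensor([0.2470, 0.2435, 0.2616])
print((x - mean) / std)  # tensor([0.0348, 0.0731, 0.2045])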

Here, I am not resizing the images, as they are all already 32×32.

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    # transforms.ColorJitter(brightness=0.1,  # randomly jitter the colors
    #                        contrast=0.1,
    #                        saturation=0.1),
    transforms.ToTensor(),  # convert image to tensor
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    transforms.RandomErasing(p=0.75, scale=(0.02, 0.1), value=1.0, inplace=False)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616])
])

Load the data

SubsetRandomSampler is used to split the training dataset into training and validation sets.

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True,
                                             transform=train_transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True,
                                            transform=test_transform)

# Get indices for the training and validation sets
n_train = len(train_dataset)
indices = list(range(n_train))
np.random.shuffle(indices)
split = int(np.floor(N_VALID * n_train))
train_idx, valid_idx = indices[split:], indices[:split]

# Define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           sampler=train_sampler,
                                           num_workers=NUM_WORKERS)
valid_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           sampler=valid_sampler,
                                           num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=BATCH_SIZE,
                                          num_workers=NUM_WORKERS,
                                          shuffle=False)
# Print the shape of one batch from each loader
for key, value in {'Train': train_loader, "Validation": valid_loader, 'Test': test_loader}.items():
    for X, y in value:
        print(f"{key}:")
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}\n")
        break

Train:
Shape of X [N, C, H, W]: torch.Size([32, 3, 32, 32])
Shape of y: torch.Size([32]) torch.int64

Validation:
Shape of X [N, C, H, W]: torch.Size([32, 3, 32, 32])
Shape of y: torch.Size([32]) torch.int64

Test:
Shape of X [N, C, H, W]: torch.Size([32, 3, 32, 32])
Shape of y: torch.Size([32]) torch.int64

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Visualize a batch of training data

# Per-channel statistics used in transforms.Normalize above
cifar_mean = np.array([0.4914, 0.4822, 0.4465])
cifar_std = np.array([0.2470, 0.2435, 0.2616])

# Helper function to display one image
def imshow(img):
    # Convert from CHW tensor layout to HWC, then un-normalize with the
    # same mean/std used in the transforms
    img = np.transpose(img, (1, 2, 0)) * cifar_std + cifar_mean
    plt.imshow(img.clip(0, 1))  # clip pixel values to the valid [0, 1] range

# Get one batch of training images
examples = iter(train_loader)
images, labels = next(examples)
# Convert images to numpy for display
images = images.numpy()

# Plot the images in the batch
fig = plt.figure(figsize=(25, 4))

# Display 20 images
for idx in np.arange(20):
    ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx]])
[Image: a grid of 20 CIFAR-10 training images with their class labels]

Define CNN model

This CNN model consists of four convolutional layers followed by max-pooling operations, three fully connected layers, and dropout regularization.

  • The convolutional layers (conv1, conv2, conv3, and conv4) apply learnable filters to the input image to extract features.
  • Max-pooling layers (pool) downsample the feature maps obtained from convolutional layers, reducing their spatial dimensions while retaining important information.
  • The fully connected layers (fc1, fc2, and fc3) perform classification based on the features extracted by convolutional layers. The first fully connected layer has 240 neurons, followed by a layer with 84 neurons, and finally, a layer with 10 neurons, corresponding to the number of output classes.
  • Dropout regularization (dropout) is applied to the second fully connected layer to prevent overfitting by randomly dropping a fraction of the neurons during training.
  • The forward method defines the forward pass of the network, where input x undergoes a sequence of operations, including convolution, activation (ReLU and Leaky ReLU), max-pooling, flattening, fully connected layers, and dropout. The output represents the class scores for each input sample.
class ConvNet_1(nn.Module):
    def __init__(self):
        super(ConvNet_1, self).__init__()
        # convolutional layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, padding=2)
        # max pooling layer
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # fully connected layers
        self.fc1 = nn.Linear(128 * 2 * 2, 240)
        self.fc2 = nn.Linear(240, 84)
        self.fc3 = nn.Linear(84, 10)
        # dropout
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # input: n, 3, 32, 32
        x = self.pool(F.relu(self.conv1(x)))        # -> n, 6, 16, 16
        x = self.pool(F.relu(self.conv2(x)))        # -> n, 16, 8, 8
        x = self.pool(F.relu(self.conv3(x)))        # -> n, 64, 4, 4
        x = self.pool(F.leaky_relu(self.conv4(x)))  # -> n, 128, 2, 2
        # flattening
        x = x.view(-1, 128 * 2 * 2)                 # -> n, 512
        x = F.relu(self.fc1(x))                     # -> n, 240
        x = self.dropout(F.relu(self.fc2(x)))       # -> n, 84
        x = self.fc3(x)                             # -> n, 10
        return x
    
model = ConvNet_1().to(device)
# Print a model summary (similar to model.summary() in Keras)
summary(model, input_size=(BATCH_SIZE, 3, 32, 32))
[Image: torchinfo summary of the ConvNet_1 model]

Loss function and optimizer

The loss function used is cross-entropy loss, the standard choice for multi-class classification. Note that PyTorch's nn.CrossEntropyLoss expects raw logits and applies the softmax internally, which is why the model's final layer returns unnormalized class scores.
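
A quick worked example (illustrative values of my own, reusing the torch and nn imports from above):

logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw scores for one sample, 3 classes
target = torch.tensor([0])                 # index of the true class
loss = nn.CrossEntropyLoss()(logits, target)
print(loss)  # -log(softmax(logits)[0, 0]) ~= 0.2413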

The Adam optimization algorithm is used: an adaptive learning rate method designed to speed up training of deep neural networks and reach convergence quickly.

# Specify the Loss function
criterion = nn.CrossEntropyLoss()

# Specify the optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

Train the network

valid_loss_min = float('inf')  # track change in validation loss
# Lists to store losses and accuracy for visualization
train_losslist = []
valid_losslist = []
val_accuracies = []

for epoch in range(1, NUM_EPOCHS+1):

    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    correct = total = 0
    
    ###################
    # train the model #
    ###################
    model.train()
    for data, target in train_loader:
        # move tensors to GPU 
        data, target = data.to(device), target.to(device)
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the batch loss
        loss = criterion(output, target)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # update training loss
        train_loss += loss.item()*data.size(0)
        
    ######################    
    # validate the model #
    ######################
    model.eval()
    with torch.no_grad():
        for data, target in valid_loader:
            # move tensors to GPU 
            data, target = data.to(device), target.to(device)
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # update average validation loss 
            valid_loss += loss.item()*data.size(0)
            
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    
    # calculate average losses over each split
    # (divide by the sampler lengths, not len(dataset): each loader only
    # covers its own subset of the training data)
    avg_train_loss = train_loss/len(train_sampler)
    avg_valid_loss = valid_loss/len(valid_sampler)
    
    train_losslist.append(avg_train_loss)
    valid_losslist.append(avg_valid_loss)
    
    val_accuracy = correct / total
    val_accuracies.append(val_accuracy)
  
    # print training/validation statistics
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tValidation accuracy: {:.2f}%'.format(
        epoch, avg_train_loss, avg_valid_loss, val_accuracy*100))
    
    
    # save model if validation loss has decreased
    if avg_valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,avg_valid_loss))
        torch.save(model.state_dict(), 'model_cifar.pt')
        valid_loss_min = avg_valid_loss
# reload the weights that achieved the lowest validation loss
model.load_state_dict(torch.load('model_cifar.pt', map_location=device))

Plotting loss and accuracy

Visualizing training and validation metrics helps us analyze how the performance of the model changes over epochs. It helps us gain insights into issues like underfitting, overfitting, and convergence, allowing us to make informed decisions.

plt.figure(figsize=(12, 4), dpi=300)

plt.subplot(1, 2, 1)
plt.plot(train_losslist, label='Training Loss')
plt.plot(valid_losslist, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(val_accuracies, label='Validation Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Validation Accuracy')
plt.legend()

plt.tight_layout()
plt.show()
[Image: training/validation loss curves and validation accuracy curve]

Testing

We are going to evaluate the CNN model's performance on the test data.

# test model (switch to eval mode so dropout is disabled)
model.eval()
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    n_class_correct = [0 for i in range(10)]
    n_class_samples = [0 for i in range(10)]
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # max returns (value ,index)
        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        for i in range(len(labels)):
            label = labels[i]
            pred = predicted[i]
            if (label == pred):
                n_class_correct[label] += 1
            n_class_samples[label] += 1

    tot_acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network: {tot_acc} %')

    for i in range(10):
        acc = 100.0 * n_class_correct[i] / n_class_samples[i]
        print(f'Accuracy of {classes[i]}: {acc} %')

Accuracy of the network: 68.08 %
Accuracy of plane: 77.3 %
Accuracy of car: 82.4 %
Accuracy of bird: 53.1 %
Accuracy of cat: 46.8 %
Accuracy of deer: 66.8 %
Accuracy of dog: 42.6 %
Accuracy of frog: 78.7 %
Accuracy of horse: 73.2 %
Accuracy of ship: 86.6 %
Accuracy of truck: 73.3 %
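
The confusion_matrix and classification_report functions imported from sklearn at the top can give richer per-class diagnostics. A minimal sketch (it re-runs the test loop to collect every prediction; the all_preds/all_labels names are my own):

all_preds, all_labels = [], []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.numpy())

print(confusion_matrix(all_labels, all_preds))
print(classification_report(all_labels, all_preds, target_names=classes))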

Conclusion

  • We got an overall accuracy of 68.08%, which is okay but not the best. We could improve it by adding more layers to the CNN or by using a pre-trained model such as ResNet; a short sketch of the latter follows below.
  • Our model struggles with images of cats, dogs, and birds.
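
For reference, here is one way to swap in a pre-trained ResNet-18 (a sketch, assuming a recent torchvision version; note that ImageNet-pre-trained models expect larger inputs, so CIFAR-10 images are typically upsampled with transforms.Resize(224) in the transform pipeline):

from torchvision import models

resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # ImageNet weights
resnet.fc = nn.Linear(resnet.fc.in_features, 10)  # replace the 1000-class head
resnet = resnet.to(device)
optimizer = optim.Adam(resnet.parameters(), lr=LEARNING_RATE)
# ...then reuse the same training loop with resnet in place of model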
