Vision Transformer Code

see: https://www.kaggie.com/a-deep-dive-into-vision-transformers-and-clip/

(Note: this trains for only 1 epoch; more training would certainly produce a better result.)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, accuracy_score
import json
import os

# Define the ViT model (simplified for brevity)
class ViT(nn.Module):
    def __init__(self, num_classes: int = 10, img_size: int = 28, patch_size: int = 7,
                 num_layers: int = 6, num_heads: int = 4, embed_dim: int = 64):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(1, embed_dim, kernel_size=patch_size,
                                         stride=patch_size)  # Assuming grayscale images (1 channel)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))  # +1 for CLS token
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=embed_dim * 4,
                                       batch_first=True),
            num_layers)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Patchify the image
        x = self.patch_embedding(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2) # (B, num_patches, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embedding
        x = x + self.pos_embedding[:, :(x.shape[1]), :]

        # Transformer encoder
        x = self.transformer_encoder(x)

        # Classification head
        x = self.fc(x[:, 0]) # Take CLS token for classification
        return x

def main():
    print("Starting MNIST model training...")

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # 1. Download MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)) # Mean and std of MNIST
    ])
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                               download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                              download=True, transform=transform)

    # Use smaller subsets for faster training
    subset_size = 50 * 128  # 50 batches * batch_size = 6400 samples
    train_dataset = torch.utils.data.Subset(train_dataset, range(subset_size))
    test_dataset = torch.utils.data.Subset(test_dataset, range(1000))  # Small test set

    batch_size = 128
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                              shuffle=False)
    print("Downloaded MNIST dataset and created dataloaders.")


    # 2. Define the model
    model = ViT().to(device)
    print(f"Created ViT model: {model}")

    # 3. Training loop
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 1
    loss_list = []
    print(f"Starting training for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_list.append(loss.item())

            if (i+1) % 50 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                      f'Loss: {loss.item():.4f}')
    print("Finished training.")

    # 4. Generate and save loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(loss_list)
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.savefig('loss_curve.png')
    print("Saved loss curve to loss_curve.png")

    # 5. Generate and save metrics
    model.eval()  # Set model to evaluation mode
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
    accuracy = accuracy_score(all_labels, all_predictions)
    final_loss = loss_list[-1]
    metrics = {'accuracy': accuracy, 'final_loss': final_loss}
    with open('metrics.json', 'w') as f:
        json.dump(metrics, f)
    print("Saved metrics to metrics.json")

    # 6. Generate and save confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    df_cm = pd.DataFrame(cm, index=range(10), columns=range(10))
    plt.figure(figsize=(10, 7))
    sns.heatmap(df_cm, annot=True, fmt='d', cmap='viridis')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    print("Saved confusion matrix to confusion_matrix.png")


if __name__ == "__main__":
    if not os.path.exists("./data"):
        os.makedirs("./data")
    main()

Code explanation by Gemini:

This code implements a Vision Transformer (ViT) to classify handwritten digits from the MNIST dataset.

Instead of using traditional Convolutional Neural Networks (CNNs) that look at pixels through sliding windows, this model treats an image like a sequence of text, breaking it into “patches” and using the Transformer architecture (the same tech behind ChatGPT) to understand the relationships between those patches.


1. The ViT Architecture

The class ViT is the heart of the script. Here is how it processes an image:

A. Patchify and Embed

The Transformer can’t “see” a 2D image directly.

  • patch_embedding: It uses an nn.Conv2d with a stride equal to the patch_size. This effectively chops the $28 \times 28$ image into a $4 \times 4$ grid of patches (each $7 \times 7$ pixels).
  • Flattening: These patches are flattened into a 1D sequence of vectors (see the shape-check sketch below).
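
A minimal, self-contained shape check (independent of the ViT class above, using its default sizes) shows how the strided nn.Conv2d turns a batch of $28 \times 28$ grayscale images into a sequence of 16 patch embeddings:

import torch
import torch.nn as nn

patch_size, embed_dim = 7, 64  # same defaults as the ViT class above
patchify = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, 1, 28, 28)          # dummy batch of 2 grayscale images
feat = patchify(x)                     # (2, 64, 4, 4): one 64-dim vector per 7x7 patch
seq = feat.flatten(2).transpose(1, 2)  # (2, 16, 64): 16 patch tokens per image
print(feat.shape, seq.shape)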

B. The CLS Token and Positional Embedding

  • cls_token: A special “Classification” token is prepended to the sequence. By the final layer, this token holds a “summary” of the entire image.
  • pos_embedding: Since Transformers have no built-in sense of sequence order, we add “positional embeddings” so the model knows which patch came from the top-left vs. the bottom-right.
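
Continuing the shape check above (same assumed sizes), prepending the CLS token and adding the positional embedding looks roughly like this:

import torch

batch, num_patches, embed_dim = 2, 16, 64
seq = torch.randn(batch, num_patches, embed_dim)          # patch tokens from the previous sketch

cls_token = torch.randn(1, 1, embed_dim)                  # a learned nn.Parameter in the real model
pos_embedding = torch.randn(1, num_patches + 1, embed_dim)

cls = cls_token.expand(batch, -1, -1)                     # (2, 1, 64): one CLS token per image
tokens = torch.cat((cls, seq), dim=1)                     # (2, 17, 64): CLS + 16 patches
tokens = tokens + pos_embedding[:, :tokens.shape[1], :]   # broadcast over the batch dimension
print(tokens.shape)                                       # torch.Size([2, 17, 64])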

C. The Transformer Encoder

  • nn.TransformerEncoder: This is a stack of layers that uses Self-Attention. Each patch “looks” at every other patch to determine which parts of the image are most important for identifying the digit.
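
A minimal sketch of the same encoder configuration the class above uses (6 layers, 4 heads, 64-dim embeddings, batch_first=True), fed the assumed 17-token sequence from the previous sketch:

import torch
import torch.nn as nn

embed_dim, num_heads, num_layers = 64, 4, 6
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=embed_dim * 4,
                               batch_first=True),
    num_layers)

tokens = torch.randn(2, 17, embed_dim)  # CLS + 16 patch tokens, as built above
out = encoder(tokens)                   # self-attention mixes information across all tokens
print(out.shape)                        # torch.Size([2, 17, 64]) -- same shape in, same shape out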

D. The Head

  • self.fc: After the encoder, we ignore all the patch outputs and take only the output of the CLS token. This vector is passed to a Linear layer to decide which digit ($0$–$9$) the image shows.
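
In code, that is just an index into the token dimension followed by a linear layer; a small sketch with the same assumed shapes:

import torch
import torch.nn as nn

embed_dim, num_classes = 64, 10
fc = nn.Linear(embed_dim, num_classes)

out = torch.randn(2, 17, embed_dim)  # encoder output: CLS token + 16 patch tokens
cls_out = out[:, 0]                  # (2, 64): keep only the CLS token's vector
logits = fc(cls_out)                 # (2, 10): one raw score (logit) per digit class
print(logits.shape)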

2. The Training Pipeline (main)

The rest of the script handles the “life cycle” of the model:

  • Data Prep: Downloads MNIST, normalizes the pixel values, and creates a Subset (6,400 samples) to make training faster for a demonstration.
  • Optimization: Uses the Adam optimizer and CrossEntropyLoss (the standard for classification).
  • Training Loop: Feeds images to the model, calculates error (loss), and updates weights using backpropagation.
  • Evaluation: Switches to model.eval() mode to test the model on unseen data without updating weights.
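
Before running the full pipeline, the training step can be smoke-tested in a few lines. This sketch assumes the ViT class from the script above is in scope; the batch of images and labels is fake:

import torch
import torch.nn as nn
import torch.optim as optim

model = ViT()                        # the ViT class defined in the script above
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

images = torch.randn(8, 1, 28, 28)   # fake batch of 8 grayscale "digits"
labels = torch.randint(0, 10, (8,))  # fake labels in [0, 9]

outputs = model(images)              # forward pass: (8, 10) logits
loss = criterion(outputs, labels)    # cross-entropy against the fake labels
optimizer.zero_grad()
loss.backward()                      # backpropagation
optimizer.step()                     # one weight update
print(f"One-step smoke test loss: {loss.item():.4f}")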

3. Visualization and Outputs

The script doesn’t just train; it documents the results by saving three files:

  1. loss_curve.png: A line graph showing how the error decreased over time.
  2. metrics.json: A text file containing the final accuracy and loss.
  3. confusion_matrix.png: A heatmap showing exactly which digits the model confused (e.g., if it often mistook a 4 for a 9).
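
As a small usage example (assuming the script has already run and written metrics.json to the working directory), the saved metrics can be read back like this:

import json

with open('metrics.json') as f:  # produced by the training script above
    metrics = json.load(f)

print(f"Test accuracy: {metrics['accuracy']:.2%}")
print(f"Final training loss: {metrics['final_loss']:.4f}")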

Key Technical Detail

In the forward pass, notice this line:

x = x + self.pos_embedding[:, :(x.shape[1]), :]

This is the standard way to inject spatial information: the positional embedding is simply added element-wise to the token sequence. Because the sequence length includes the extra CLS token, the positional embedding is sized num_patches + 1, and the slice just matches it to the current sequence length.
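
A tiny sketch (using the default sizes from the ViT class above) shows that the slice is effectively a no-op here, since x always carries num_patches + 1 tokens, while broadcasting adds the same embedding to every image in the batch:

import torch

num_patches, embed_dim = 16, 64
pos_embedding = torch.randn(1, num_patches + 1, embed_dim)  # (1, 17, 64)

x = torch.randn(2, num_patches + 1, embed_dim)              # (2, 17, 64): CLS + patch tokens
sliced = pos_embedding[:, :x.shape[1], :]                   # still (1, 17, 64): the slice covers everything
x = x + sliced                                              # broadcast over the batch dimension
print(sliced.shape, x.shape)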

