Vision Transformer

An annotated PyTorch implementation of the Vision Transformer (ViT), training a miniature model on CIFAR-10 to demonstrate patch-based image classification.

Overview

This project provides an annotated implementation of the Vision Transformer (ViT) model based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., ICLR 2021). While Convolutional Neural Networks (CNNs) have long dominated computer vision, this project explores how Transformers, originally designed for NLP, can be applied directly to sequences of image patches. We build the model from the ground up and analyze its performance when trained from scratch on the CIFAR-10 dataset.

Background

For years, CNNs such as VGG and ResNet have been the standard for vision tasks because of their built-in inductive biases:

  • Locality: Convolutions operate on local neighborhoods of pixels.
  • Translation Equivariance: The same filters are applied at every position, so a shifted input produces a correspondingly shifted feature map.
  • Hierarchical Features: Stacking layers builds from edges to textures to objects.

However, the Transformer revolution in NLP (BERT, GPT) showed that self-attention trained on massive datasets could outperform architectures built around hand-crafted priors. The Vision Transformer paper made a simple but powerful observation: treat image patches as tokens.

The paper reports that ViT trails ResNets of comparable size when trained on mid-sized datasets such as ImageNet without strong regularization, because it lacks their inductive biases. When pre-trained on much larger datasets (such as JFT-300M), it achieves state-of-the-art results, effectively learning the spatial relationships that CNNs have hard-coded.

Model Architecture

The ViT architecture has four components that together map an image into a sequence a standard Transformer can process.

Vision Transformer Architecture
Figure 1: The Vision Transformer workflow. Image patches are linearly embedded, prepended with a class token, and processed by a Transformer encoder.

1. Patch Embedding

The first step converts an image into a sequence. We split the image into fixed-size patches and project them to an embedding dimension. In our implementation, we use a Conv2d layer with kernel size and stride equal to the patch size to perform this extraction and projection simultaneously.

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, embed_dim=768, patch_size=16):
        super().__init__()
        # kernel_size == stride == patch_size, so each patch is visited exactly once
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)       # (B, Embed, H/P, W/P)
        x = x.flatten(2)       # (B, Embed, N)
        x = x.transpose(1, 2)  # (B, N, Embed)
        return x
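
To see why a strided convolution counts as "split into patches, then project", the sketch below compares it against an explicit patch extraction followed by a matrix multiply. The small sizes (patch size 4, embedding width 8, a random 32x32 batch) are chosen purely for illustration and are not taken from the project code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

patch_size, in_channels, embed_dim = 4, 3, 8  # illustrative sizes only
conv = nn.Conv2d(in_channels, embed_dim,
                 kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, in_channels, 32, 32)  # (B, C, H, W)

# Path A: strided convolution, as in PatchEmbedding above.
tokens_conv = conv(x).flatten(2).transpose(1, 2)  # (B, N, Embed)

# Path B: cut out non-overlapping patches explicitly, then apply a linear
# map that reuses the convolution's weights and bias.
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, C*P*P, N)
patches = patches.transpose(1, 2)                                 # (B, N, C*P*P)
weight = conv.weight.reshape(embed_dim, -1)                       # (Embed, C*P*P)
tokens_linear = patches @ weight.T + conv.bias                    # (B, N, Embed)

print(tokens_conv.shape)
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-5))
```

Both paths produce one embedding per patch; the convolution simply fuses the extraction and the projection into a single operation.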

2. Class Token & Position Embedding

Following BERT, ViT prepends a learnable [CLS] token to the patch embeddings. This token aggregates information from the entire sequence via self-attention. Additionally, since Transformers have no inherent notion of order, learnable 1D position embeddings are added to the sequence to retain spatial information.

class ClassToken(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        return torch.cat((cls_tokens, x), dim=1)
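
The ClassToken module above covers the first half of this step. The position embedding could look roughly like the following sketch, which assumes one learnable vector per token (including the [CLS] slot) and a truncated-normal initialization. The class name PositionEmbedding and the standard deviation of 0.02 are assumptions for illustration, not taken from the repository.

```python
import torch
import torch.nn as nn

class PositionEmbedding(nn.Module):
    """Learnable 1D position embeddings (hypothetical helper)."""

    def __init__(self, num_tokens, embed_dim):
        super().__init__()
        # One vector per sequence position, including the [CLS] slot.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)  # assumed init

    def forward(self, x):
        # x: (B, num_tokens, Embed); the embedding broadcasts over the batch.
        return x + self.pos_embed

num_patches, embed_dim = 64, 128
pos = PositionEmbedding(num_patches + 1, embed_dim)
tokens = torch.zeros(2, num_patches + 1, embed_dim)
out = pos(tokens)
print(out.shape)
```

Because the same embedding table is added to every image in the batch, position information depends only on the token index, not on the image content.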

3. Transformer Encoder

The core is a standard Transformer encoder using a Pre-Norm architecture...

Transformer Encoder Block
Figure 2: Encoder block with residual connections.

class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Pre-Norm: normalize before each sub-layer, then add the residual
        x_norm = self.ln1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out  # Residual

        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = x + mlp_out   # Residual
        return x
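
For intuition about what nn.MultiheadAttention computes inside the block, here is single-head scaled dot-product self-attention written out by hand. It is a simplified sketch: the function name and the random projection matrices are illustrative, and the real module adds multiple heads, biases, and an output projection.

```python
import math
import torch

def single_head_self_attention(x, w_q, w_k, w_v):
    """x: (B, N, D); w_q, w_k, w_v: (D, D) projection matrices."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Similarity of every token with every other token, scaled by sqrt(D).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # (B, N, N)
    weights = scores.softmax(dim=-1)  # each row sums to 1
    return weights @ v, weights       # (B, N, D), (B, N, N)

torch.manual_seed(0)
B, N, D = 2, 65, 32  # 64 patches + [CLS]; D matches a 128-dim model with 4 heads
x = torch.randn(B, N, D)
w_q, w_k, w_v = (torch.randn(D, D) / math.sqrt(D) for _ in range(3))

out, weights = single_head_self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)
```

The (N, N) weight matrix is what lets any token, including [CLS], draw on every patch within a single layer.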

4. Classification Head

Finally, a classification head maps the [CLS] token's representation to the output classes. While the original paper uses an MLP head during pre-training, we use a single linear layer for our CIFAR-10 model. This keeps the head's parameter count small, which helps limit overfitting on a dataset of this size.
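
A minimal sketch of such a head, assuming the encoder output has shape (B, N+1, Embed) with the [CLS] token at index 0. The final LayerNorm before the linear layer follows the paper's design and is an assumption about this project, as are the variable names.

```python
import torch
import torch.nn as nn

embed_dim, num_classes = 128, 10  # Mini-ViT width, CIFAR-10 classes

norm = nn.LayerNorm(embed_dim)            # assumed final norm, as in the paper
head = nn.Linear(embed_dim, num_classes)  # single linear classification layer

encoded = torch.randn(2, 65, embed_dim)   # stand-in for encoder output
cls_repr = norm(encoded)[:, 0]            # keep only the [CLS] token: (B, Embed)
logits = head(cls_repr)                   # (B, num_classes)
print(logits.shape)
```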

Inference Pipeline

Once trained, the inference process is straightforward. We implemented a pipeline that takes an image, processes it through the transformer, and outputs probabilities.

  1. Input: Image tensor of shape (B, C, H, W).
  2. Patching: Image is split into patches and projected.
  3. Tokens: [CLS] token prepended; Position embeddings added.
  4. Encoder: Sequence passes through L encoder blocks (attention + MLP).
  5. Prediction: [CLS] token is extracted and projected to class logits.

import torch
import torch.nn.functional as F

@torch.no_grad()
def predict(model, image, device):
    model.eval()
    image = image.to(device)
    # Accept a single (C, H, W) image by adding a batch dimension
    if image.dim() == 3:
        image = image.unsqueeze(0)

    logits = model(image)
    probs = F.softmax(logits, dim=1)
    pred_class = logits.argmax(dim=1)

    return pred_class, probs
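
The five steps above can be sketched end to end as a single module. To keep the example self-contained, this version swaps the project's TransformerEncoderBlock for PyTorch's built-in nn.TransformerEncoderLayer with norm_first=True; the class name MiniViT, the final LayerNorm, and the default arguments are assumptions for illustration rather than the repository's actual model.

```python
import torch
import torch.nn as nn

class MiniViT(nn.Module):
    """Illustrative ViT assembled from built-in PyTorch layers."""

    def __init__(self, image_size=32, patch_size=4, in_channels=3,
                 embed_dim=128, depth=4, num_heads=4, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(in_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        # Built-in Pre-Norm encoder layer, standing in for TransformerEncoderBlock.
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=4 * embed_dim,
            dropout=0.0, activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Steps 1-2: (B, C, H, W) -> (B, N, Embed)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        # Step 3: prepend [CLS], add position embeddings -> (B, N+1, Embed)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls, x), dim=1) + self.pos_embed
        # Step 4: encoder blocks
        x = self.encoder(x)
        # Step 5: classify from the [CLS] token
        return self.head(self.norm(x[:, 0]))

model = MiniViT().eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))
    probs = logits.softmax(dim=1)
print(logits.shape)
```

A model built this way can be passed directly to the predict function above, since it maps a (B, C, H, W) batch to (B, num_classes) logits.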

Experimental Setup (Mini-ViT)

To demonstrate feasibility on consumer hardware (CPU/Laptop), we trained a miniaturized version of ViT on the CIFAR-10 dataset. The original ViT-Base is too large for this scale of data, so we adapted the architecture:

Model Config

  • Image Size: 32x32 (vs 224)
  • Patch Size: 4x4 (64 patches)
  • Depth: 4 Layers
  • Embed Dim: 128
  • Heads: 4
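
The quantities implied by this configuration can be derived with a few lines of arithmetic. The dataclass below is a hypothetical container, not part of the project code; it assumes 3-channel RGB input, as in CIFAR-10.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MiniViTConfig:
    """Hypothetical config holder for the values listed above."""
    image_size: int = 32
    patch_size: int = 4
    in_channels: int = 3  # CIFAR-10 images are RGB
    depth: int = 4
    embed_dim: int = 128
    num_heads: int = 4

    @property
    def num_patches(self) -> int:
        return (self.image_size // self.patch_size) ** 2

    @property
    def seq_len(self) -> int:
        return self.num_patches + 1  # patches plus the [CLS] token

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    @property
    def patch_dim(self) -> int:
        return self.in_channels * self.patch_size ** 2  # raw values per patch

cfg = MiniViTConfig()
print(cfg.num_patches, cfg.seq_len, cfg.head_dim, cfg.patch_dim)  # 64 65 32 48
```

Each 4x4 RGB patch therefore carries 48 raw values, which the patch embedding projects up to 128 dimensions, and each of the 4 heads attends in a 32-dimensional subspace.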

Training Config

  • Epochs: 20
  • Optimizer: AdamW
  • Learning Rate: 3e-4
  • Schedule: Warmup + Cosine Decay
  • Augmentation: RandomCrop, Flip

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch average is per-sample
        total_loss += loss.item() * images.size(0)
        correct += outputs.argmax(1).eq(labels).sum().item()
        total += labels.size(0)

    # Mean loss and accuracy (in percent) over the epoch
    return total_loss / total, 100.0 * correct / total
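
The "Warmup + Cosine Decay" schedule from the training config can be expressed as a multiplier on the base learning rate. The sketch below is one common formulation, assuming a linear warmup and per-epoch stepping; the function name and the 2-epoch warmup length are assumptions, since the write-up does not state them.

```python
import math

def warmup_cosine_factor(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given step."""
    if step < warmup_steps:
        # Linear ramp up to 1.0 over the warmup period.
        return (step + 1) / warmup_steps
    # Cosine decay from 1.0 down to 0.0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

base_lr = 3e-4
total_steps, warmup_steps = 20, 2  # 20 epochs; warmup length is assumed

schedule = [base_lr * warmup_cosine_factor(s, warmup_steps, total_steps)
            for s in range(total_steps)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.2e}")
```

In PyTorch, a function like this can be supplied as the lr_lambda argument of torch.optim.lr_scheduler.LambdaLR wrapped around an AdamW optimizer, with scheduler.step() called once per epoch (or once per batch if the step counts are expressed in batches).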

Results & Analysis

Training Curves
Figure 3: Training/Validation Loss and Accuracy over 20 epochs.

The training curves show an 8.55-point gap between training accuracy (80.23%) and validation accuracy (71.68%). This overfitting is expected and reflects a well-known limitation of Vision Transformers trained from scratch: they are data-hungry.

Model Predictions
Figure 4: Sample predictions on CIFAR-10 test images.

Why does Mini-ViT underperform compared to CNNs on CIFAR-10?

  • Lack of Inductive Biases: CNNs have locality and translation equivariance built in. ViT must learn these spatial concepts from scratch, which is difficult with only 50,000 training images.
  • Low Resolution: At 32x32 resolution with 4x4 patches, the model has only 64 patch tokens to work with. This limits the "long-range" context that self-attention is designed to exploit.
  • No Pre-training: ViT relies heavily on large-scale pre-training (ImageNet-21k, JFT-300M) to learn robust features. Training from scratch on a small dataset yields weaker results than transfer learning from a pre-trained model.

Discussion

Strengths:

  • Scalability: In the paper's experiments, ViT keeps improving as pre-training data and compute grow, eventually overtaking ResNet baselines.
  • Global Attention: Can capture relationships between distant parts of an image in the very first layer.
  • Simplicity: Uses a standard Transformer encoder without complex vision-specific operators.

Weaknesses:

  • Quadratic Complexity: Attention costs grow quadratically with the number of patches, limiting high-resolution applications.
  • Data Inefficiency: Without large-scale pre-training, ViT tends to trail comparable CNN baselines.
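
To make the quadratic cost concrete, the short calculation below counts the entries in one attention matrix (per head, per layer) for a few image and patch sizes. The helper name is hypothetical, and the count includes the [CLS] token.

```python
def attention_matrix_entries(image_size, patch_size):
    """Return (tokens, tokens^2) for a square image, counting the [CLS] token."""
    tokens = (image_size // patch_size) ** 2 + 1
    return tokens, tokens * tokens

for image_size, patch_size in [(32, 4), (224, 16), (448, 16)]:
    tokens, entries = attention_matrix_entries(image_size, patch_size)
    print(f"{image_size}x{image_size}, patch {patch_size}: "
          f"{tokens} tokens, {entries} attention entries")
# 32x32, patch 4: 65 tokens, 4225 attention entries
# 224x224, patch 16: 197 tokens, 38809 attention entries
# 448x448, patch 16: 785 tokens, 616225 attention entries
```

Doubling the image side from 224 to 448 roughly quadruples the token count and increases the attention matrix by a factor of about 16.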

Conclusion: This project implemented a working Vision Transformer from scratch. While the Mini-ViT falls short of state-of-the-art accuracy on CIFAR-10, it demonstrates the mechanics of patch embedding, positional encoding, and self-attention in a computer vision context. Future work includes exploring data-efficient training recipes (like DeiT) and hybrid CNN-Transformer architectures.