
Vision Transformer from Scratch

A complete implementation of the Vision Transformer (ViT) architecture in pure PyTorch, trained on CIFAR-10.

PyTorch · ViT · Computer Vision · Transformer Architecture

01. The Concept

The Vision Transformer (ViT) revolutionized computer vision by applying the standard Transformer architecture directly to images, with the fewest possible modifications.

In this project, I built a ViT from the ground up in PyTorch without relying on pre-built high-level libraries. Key components implemented include Patch Embeddings, the CLS Token, and Multi-Head Self-Attention.
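Of these, multi-head self-attention is the core building block of every encoder layer. The sketch below shows one way it can be written from scratch with plain linear layers; the dimensions and names are illustrative and may differ from the actual notebook.

import torch
from torch import nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x: (B, N_Tokens, Embed_Dim)
        B, N, D = x.shape
        # Project to queries, keys and values in a single matmul
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (B, Heads, N, Head_Dim)
        # Scaled dot-product attention per head
        attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        attn = attn.softmax(dim=-1)
        out = attn @ v                        # (B, Heads, N, Head_Dim)
        # Merge heads and project back to the embedding dimension
        out = out.transpose(1, 2).reshape(B, N, D)
        return self.proj(out)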

02. Technical Implementation

Patch Embedding

Instead of processing individual pixels, we split the image into fixed-size patches. This is efficiently implemented with a Conv2d layer whose kernel size and stride both equal the patch size, which is equivalent to applying a linear projection to each flattened patch.

model.py
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        self.patch_size = patch_size
        self.conv = nn.Conv2d(in_channels, embed_dim, 
                              kernel_size=patch_size, 
                              stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, Embed_Dim, Grid_H, Grid_W)
        x = self.conv(x) 
        # Flatten -> (B, Embed_Dim, N_Patches)
        x = x.flatten(2)
        # Transpose -> (B, N_Patches, Embed_Dim)
        return x.transpose(1, 2)
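As a quick sanity check, the module maps a batch of 32×32 CIFAR-10 images to 64 patch tokens of dimension 128 (a usage sketch with the default arguments above):

x = torch.randn(2, 3, 32, 32)   # dummy batch of CIFAR-10-sized images
tokens = PatchEmbedding()(x)    # patch_size=4 -> an 8 x 8 grid of patches
print(tokens.shape)             # torch.Size([2, 64, 128])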

Experimental Hybrid

I experimented with a CNN-ViT Hybrid stem to introduce inductive bias for local texture recognition before the global transformer layers.
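A minimal sketch of such a hybrid stem is shown below. It replaces the single patchify convolution with a small strided CNN; the layer widths and depth here are illustrative assumptions, not the exact configuration used in the experiment.

from torch import nn

class HybridStem(nn.Module):
    """Illustrative CNN stem that produces patch tokens for the transformer."""
    def __init__(self, in_channels=3, embed_dim=128):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, embed_dim, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, x):
        # (B, 3, 32, 32) -> (B, Embed_Dim, 8, 8)
        x = self.stem(x)
        # Flatten the grid into tokens: (B, 64, Embed_Dim)
        return x.flatten(2).transpose(1, 2)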

Positional Embedding

Since self-attention is permutation-invariant, learnable positional embeddings are added to the patch embeddings to retain spatial information.
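A minimal sketch of this step, assuming the 8 × 8 patch grid from above (64 tokens) plus one CLS token; the class and parameter names are assumptions for illustration.

import torch
from torch import nn

class TokenPrep(nn.Module):
    def __init__(self, n_patches=64, embed_dim=128):
        super().__init__()
        # Learnable classification token and positional table
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))

    def forward(self, x):
        # x: (B, N_Patches, Embed_Dim)
        B = x.shape[0]
        cls = self.cls_token.expand(B, -1, -1)  # (B, 1, Embed_Dim)
        x = torch.cat([cls, x], dim=1)          # (B, N_Patches + 1, Embed_Dim)
        return x + self.pos_embed               # broadcast over the batch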

Results (CIFAR-10)

Final Accuracy: 66.44%

*Training pure ViTs on small datasets like CIFAR-10 is challenging without large-scale pre-training, since the architecture lacks the built-in inductive biases of CNNs.

Source Code

The full implementation is available as a Jupyter Notebook.

View on Colab