A complete implementation of the Vision Transformer (ViT) architecture in pure PyTorch, trained on CIFAR-10.
The Vision Transformer (ViT) revolutionized computer vision by applying the standard Transformer architecture directly to images, with the fewest possible modifications.
In this project, I built a ViT from the ground up in PyTorch without relying on pre-built high-level libraries. Key components implemented include Patch Embeddings, the CLS Token, and Multi-Head Self-Attention.
Instead of processing individual pixels, we split the image into fixed-size patches and linearly project each patch to an embedding vector. Both steps are efficiently implemented with a single Conv2d layer whose kernel size and stride equal the patch size.
```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride == patch_size splits the image into
        # non-overlapping patches and projects each one to embed_dim.
        self.conv = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, Grid_H, Grid_W)
        x = self.conv(x)
        # Flatten the spatial grid -> (B, embed_dim, N_patches)
        x = x.flatten(2)
        # Transpose -> (B, N_patches, embed_dim)
        return x.transpose(1, 2)
```
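As a quick sanity check (a minimal sketch assuming CIFAR-10's 32×32 inputs and the default patch size of 4), the module should produce 8 × 8 = 64 patch tokens:

```python
patcher = PatchEmbedding(in_channels=3, patch_size=4, embed_dim=128)
dummy = torch.randn(2, 3, 32, 32)   # a batch of two CIFAR-10-sized images
tokens = patcher(dummy)
print(tokens.shape)                  # torch.Size([2, 64, 128]) -> (B, N_patches, embed_dim)
```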
I also experimented with a CNN-ViT hybrid stem, which introduces a convolutional inductive bias for local texture recognition before the global transformer layers; a sketch of this idea follows below.
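The sketch below illustrates what such a hybrid stem could look like; the depth and channel counts here are illustrative assumptions, not the exact configuration from the notebook. A small stack of convolutions extracts local features, and the patch projection is then applied to the resulting feature map instead of the raw pixels.

```python
class HybridStem(nn.Module):
    """Small convolutional stem followed by the patch projection (illustrative sketch)."""
    def __init__(self, in_channels=3, stem_channels=64, patch_size=4, embed_dim=128):
        super().__init__()
        # Local feature extractor: conv -> norm -> ReLU, repeated twice (assumed depth).
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, stem_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(stem_channels, stem_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(inplace=True),
        )
        # Patchify the conv feature map rather than the raw image.
        self.patch_embed = PatchEmbedding(in_channels=stem_channels,
                                          patch_size=patch_size,
                                          embed_dim=embed_dim)

    def forward(self, x):
        return self.patch_embed(self.stem(x))   # (B, N_patches, embed_dim)
```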
A learnable [CLS] token is prepended to the patch sequence, and since self-attention is permutation-invariant, learnable positional embeddings are added to all tokens to retain spatial information.
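A minimal sketch of how these pieces could fit together (the module name and the `num_patches` parameter are illustrative assumptions): the [CLS] token is prepended to the patch tokens, and a learnable positional embedding covering all positions, including the [CLS] slot, is added.

```python
class ViTEmbedding(nn.Module):
    """Patch tokens + [CLS] token + learnable positional embeddings (illustrative sketch)."""
    def __init__(self, num_patches=64, embed_dim=128):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, patch_tokens):
        # patch_tokens: (B, N_patches, embed_dim)
        B = patch_tokens.size(0)
        cls = self.cls_token.expand(B, -1, -1)       # (B, 1, embed_dim)
        x = torch.cat([cls, patch_tokens], dim=1)    # (B, N_patches + 1, embed_dim)
        return x + self.pos_embed                    # inject positional information
```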
*Note:* Training on small datasets like CIFAR-10 is challenging for pure ViTs without large-scale pre-training, since they lack the built-in inductive biases of CNNs.
The full implementation is available as a Jupyter Notebook.
View on Colab