
Minimal Visual Transformer implementation in PyTorch

Hey Friends!

Welcome to this tiny experiment where we compare a classic Convolutional Neural Network (CNN) against the modern Vision Transformer (ViT). The task: the old but gold problem of handwritten digit recognition.

Why digits? Because my PC would explode with anything more serious 😂 and because the goal here is not to ACHIEVE state-of-the-art results, but to see how ViTs actually work under the hood.


⚙️ Installation

Create a virtual environment and install dependencies:

python -m venv vit-venv
source vit-venv/Scripts/activate   # on Linux/macOS: source vit-venv/bin/activate
pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
pip3 install einops torchsummary

How to run:

python train.py

🥊 CNN vs ViT

We train two models, each with roughly 2k trainable parameters.
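For reference, counting trainable parameters in PyTorch is a one-liner. The tiny CNN below is a hypothetical stand-in just to show the arithmetic behind a ~2k budget; the actual architectures live in train.py:

```python
import torch.nn as nn

# Hypothetical tiny CNN, only to illustrate how the parameter budget is counted.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # 8*(1*3*3) + 8 = 80 params
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(4),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),                   # 128*10 + 10 = 1290 params
)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)  # 1370
```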

🔵 CNN! The test accuracy looks like this:

[CNN test accuracy plot]

🔴 ViT! And here’s the ViT’s performance:

[ViT test accuracy plot]

🤔 Observations

  • Attention layers involve matrix multiplications over the full token sequence (O(N²) complexity), so the ViT is SLOWER. Not quite a turtle… more like a turtle loaded with bags from the supermarket 😂
  • The ViT also reaches lower accuracy: Transformers work well when they can model long-range dependencies and are trained on lots of data. On MNIST the images are tiny (28×28) and the dataset is small, while CNNs are simply better at extracting local patterns like edges, strokes, and curves.
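The O(N²) claim is easy to see in code: each of the N tokens attends to every other token, so the attention matrix itself has N×N entries. A minimal scaled dot-product attention sketch (dimensions chosen to match the 16-patch example above, not taken from train.py):

```python
import torch

# Scaled dot-product attention over N tokens produces an N x N score matrix,
# which is where the quadratic cost in sequence length comes from.
N, d = 16, 32                                  # 16 patch tokens, embedding dim 32
q = torch.randn(1, N, d)
k = torch.randn(1, N, d)
scores = q @ k.transpose(-2, -1) / d ** 0.5    # (1, N, N): 256 pairwise scores
attn = scores.softmax(dim=-1)                  # each row sums to 1
print(attn.shape)  # torch.Size([1, 16, 16])
```

Double the image resolution and the patch count N quadruples, so the attention matrix grows 16x; this is why ViTs on large images get expensive fast.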

In other words: asking a ViT to classify MNIST is like hiring a theoretical physicist to count apples at a grocery store. So choose your model wisely! 🧐

⚖️ Pros & Cons

CNN

  • ✅ Fast to train
  • ✅ Great at local pattern recognition (edges, textures, shapes)
  • ✅ Works very well with small datasets
  • ❌ Limited ability to capture global context

ViT

  • ✅ Elegant, unified architecture (no handcrafted kernels)
  • ✅ Scales with data (huge datasets)
  • ✅ Attention maps are interpretable (attention weights show where the model is “looking”)
  • ❌ Training is slower
  • ❌ Needs more data to reach CNN-level performance
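On the interpretability point: PyTorch's built-in attention module returns the attention weights directly, so inspecting "where the model is looking" takes only a few lines. A sketch (the embedding size and head count are illustrative, not the repo's settings):

```python
import torch
import torch.nn as nn

# nn.MultiheadAttention returns the attention weights alongside the output:
# one row per query token, showing how strongly it attends to every other token.
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
tokens = torch.randn(1, 16, 32)  # 16 patch tokens from one image
out, weights = attn(tokens, tokens, tokens, average_attn_weights=True)
print(weights.shape)  # torch.Size([1, 16, 16])
```

Reshaping one row of `weights` back onto the 4×4 patch grid gives a heatmap over the image for that token, which is what published ViT attention visualizations show.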