Skip to content

This repository contains a PyTorch based Vision Transformer implementation trained on Devanagari Handwritten Character Dataset.

Notifications You must be signed in to change notification settings

1rsh/vit-from-scratch

Repository files navigation

Vision Transformer Implementation for Devanagari Character Recognition PyTorch

vit-header

This repository contains a PyTorch implementation of the Vision Transformer (ViT) architecture introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

The model is trained on the Devanagari Handwritten Character Dataset, available from the UC Irvine Machine Learning Repository. Training was done on an M1 Pro and achieved a validation accuracy of 96.70%.


Repository Walkthrough

1. PatchEmbedding

class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence for a transformer encoder.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is projected to an ``embed_dim`` vector, one learnable
    class token is prepended, learnable position embeddings are added, and
    dropout is applied.

    Args:
        embed_dim: Size of every token vector.
        patch_size: Side length of a square patch; used as both kernel size
            and stride of the projecting convolution.
        num_patches: Number of patches the convolution produces per image.
            Must match the actual patch count, otherwise the position
            embeddings cannot be added to the sequence.
        dropout: Dropout probability applied to the embedded sequence.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()

        # A convolution whose kernel and stride both equal patch_size maps each
        # patch to one embed_dim vector; Flatten(2) then merges all dimensions
        # after the channel dimension into a single patch axis.
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2),
        )

        # Exactly ONE class token per sequence. The token count must not
        # depend on in_channels: the position embeddings below reserve a
        # single extra slot (num_patches + 1), so any other count would make
        # the addition in forward() fail for multi-channel images.
        self.cls_token = nn.Parameter(
            torch.randn(size=(1, 1, embed_dim)), requires_grad=True
        )
        self.position_embeddings = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch accepted by the ``nn.Conv2d`` patcher
                (assumed layout ``(batch, in_channels, height, width)``).

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embed_dim)`` whose
            first token along dimension 1 is the class token.
        """
        # Share the single learned token across the batch without copying it.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # (batch, embed_dim, patches) -> (batch, patches, embed_dim)
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)  # class token goes first
        x = self.position_embeddings + x  # broadcast over the batch dimension
        x = self.dropout(x)
        return x

2. VisionTransformer

class VisionTransformer(nn.Module):
    """Vision Transformer classifier built on ``nn.TransformerEncoder``.

    Pipeline: ``PatchEmbedding`` -> stack of pre-norm transformer encoder
    layers -> LayerNorm + Linear head applied to the class token.

    Args:
        num_patches: Number of patches per image, forwarded to
            ``PatchEmbedding``.
        img_size: Accepted for interface compatibility; not used by this
            class.
        num_classes: Number of output logits.
        patch_size: Side length of a square patch, forwarded to
            ``PatchEmbedding``.
        embed_dim: Token width (``d_model`` of the encoder).
        num_encoders: Number of stacked encoder layers.
        num_heads: Number of attention heads per encoder layer.
        hidden_dim: Width of the feed-forward sub-layer in every encoder
            layer.
        dropout: Dropout probability used in the embedding block and the
            encoder layers.
        activation: Activation of the feed-forward sub-layer, in any form
            ``nn.TransformerEncoderLayer`` accepts (e.g. ``"gelu"``).
        in_channels: Number of input image channels, forwarded to
            ``PatchEmbedding``.
    """

    def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):
        super().__init__()

        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            # Without this argument hidden_dim would be silently ignored and
            # the layer would fall back to the library's default width.
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch in the layout expected by ``PatchEmbedding``.

        Returns:
            Tensor of class logits with ``num_classes`` values per image.
        """
        x = self.embeddings_block(x)
        x = self.encoder_blocks(x)
        # Only the first token (the class token) feeds the classification head.
        x = self.mlp_head(x[:, 0, :])
        return x

3. Hyperparameters

# --- Data / input geometry ---
img_size = 32
patch_size = 4
in_channels = 1
num_classes = 46

# --- Model architecture ---
num_encoders = 4
num_heads = 8
hidden_dim = 1024
activation = "gelu"
dropout = 0.001

# --- Optimisation (the adam_* names suggest an Adam-family optimizer; confirm
# against the optimizer construction, which is not shown here) ---
batch_size = 512
num_epochs = 40
learning_rate = 1e-4
adam_weight_decay = 0
adam_betas = (0.9, 0.999)

# --- Derived sizes ---
# One embedding value per raw pixel value in a patch: 4 * 4 * 1 = 16.
embed_dim = patch_size * patch_size * in_channels
# Patches per image side, squared: (32 // 4) ** 2 = 64.
patches_per_side = img_size // patch_size
num_patches = patches_per_side ** 2

4. Training

# Train for num_epochs epochs; every 5th epoch also evaluate on test_dataloader.
# Per epoch this leaves train_loss (mean loss per batch), train_labels and
# train_preds populated; val_loss, val_labels and val_preds are only refreshed
# on evaluation epochs.
for epoch in tqdm(range(num_epochs), position=0, leave=True):
    model.train()

    train_labels = []
    train_preds = []

    train_running_loss = 0
    # Counted explicitly so the mean below does not rely on the loop variable
    # surviving the loop (it would be undefined or stale for an empty loader).
    train_batches = 0

    for idx, img_label in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        img = img_label[0].float().to(device)
        # int64 is the usual dtype for class-index targets; a uint8 cast would
        # also wrap silently for label values above 255.
        label = img_label[1].type(torch.long).to(device)

        y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        # Collect targets and predictions on the CPU for later metrics.
        train_labels.extend(label.cpu().detach())
        train_preds.extend(y_pred_label.cpu().detach())

        loss = criterion(y_pred, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()
        train_batches += 1

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss = train_running_loss / max(train_batches, 1)

    if (epoch + 1) % 5 == 0:
        model.eval()

        val_labels = []
        val_preds = []
        val_running_loss = 0
        val_batches = 0

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for idx, img_label in enumerate(tqdm(test_dataloader, position=0, leave=True)):
                img = img_label[0].float().to(device)
                label = img_label[1].type(torch.long).to(device)

                y_pred = model(img)
                y_pred_label = torch.argmax(y_pred, dim=1)

                val_labels.extend(label.cpu().detach())
                val_preds.extend(y_pred_label.cpu().detach())

                loss = criterion(y_pred, label)

                val_running_loss += loss.item()
                val_batches += 1

            val_loss = val_running_loss / max(val_batches, 1)

training-history

Footnote

If you wish to use this PyTorch implementation of the Vision Transformer in your own project, download the notebook and update train_dir and test_dir according to your file hierarchy. Also make sure to adjust variables such as img_size.
Feel free to contact me at [email protected].

Python

About

This repository contains a PyTorch based Vision Transformer implementation trained on Devanagari Handwritten Character Dataset.

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published