
Refactor code for speed and clarity
UltralyticsAssistant committed Aug 26, 2024
1 parent 68c3966 commit b586692
Showing 4 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion clip/clip.py
@@ -222,7 +222,7 @@ def tokenize(
     texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
 ) -> Union[torch.IntTensor, torch.LongTensor]:
     """
-    Returns the tokenized representation of given input string(s)
+    Returns the tokenized representation of given input string(s).
 
     Parameters
     ----------
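
As context for the docstring touched above, a minimal usage sketch of clip.tokenize (the input strings and the printed shape are illustrative assumptions, not part of this commit):

import clip

# tokenize pads or truncates each string to context_length and returns one row per input
tokens = clip.tokenize(["a diagram", "a dog", "a cat"], context_length=77, truncate=True)
print(tokens.shape)  # torch.Size([3, 77]); dtype is IntTensor or LongTensor depending on version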
14 changes: 14 additions & 0 deletions clip/model.py
@@ -8,6 +8,8 @@
 
 
 class Bottleneck(nn.Module):
+    """Implements a residual bottleneck block with downsampling and expansion for deep neural networks."""
+
     expansion = 4
 
     def __init__(self, inplanes, planes, stride=1):
@@ -62,6 +64,8 @@ def forward(self, x: torch.Tensor):
 
 
 class AttentionPool2d(nn.Module):
+    """Applies multi-head attention pooling over 2D spatial data, transforming it into a fixed-size output embedding."""
+
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
         """Initializes AttentionPool2d with spatial dimension, embedding dimension, number of heads, and optional output
         dimension.
@@ -189,12 +193,16 @@ def forward(self, x: torch.Tensor):
 
 
 class QuickGELU(nn.Module):
+    """Applies the QuickGELU activation function, a faster approximation of GELU, to an input tensor."""
+
     def forward(self, x: torch.Tensor):
         """Applies the QuickGELU activation function to an input tensor."""
         return x * torch.sigmoid(1.702 * x)
 
 
 class ResidualAttentionBlock(nn.Module):
+    """Implements a residual attention block with multi-head attention and MLP layers for transformer models."""
+
     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
         """Initializes the ResidualAttentionBlock with model dimension, number of heads, and optional attention mask."""
         super().__init__()
@@ -228,6 +236,8 @@ def forward(self, x: torch.Tensor):
 
 
 class Transformer(nn.Module):
+    """Processes input tensors through multiple residual attention blocks for sequence modeling tasks."""
+
     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
         """Initializes the Transformer model with specified width, layers, heads, and optional attention mask."""
         super().__init__()
@@ -241,6 +251,8 @@ def forward(self, x: torch.Tensor):
 
 
 class VisionTransformer(nn.Module):
+    """Vision Transformer model for image classification using patch embeddings and multi-head self-attention."""
+
     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
         """Initialize a VisionTransformer with given input resolution, patch size, width, layers, heads, and output
         dimension.
@@ -289,6 +301,8 @@ def forward(self, x: torch.Tensor):
 
 
 class CLIP(nn.Module):
+    """Multi-modal model combining vision and text encoders for joint embeddings based on arxiv.org/abs/2103.00020."""
+
     def __init__(
         self,
         embed_dim: int,
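
The QuickGELU hunk above quotes the full forward pass, so a standalone check against the exact GELU is easy to sketch (a rough illustration only, not part of the commit; the sample input range is an arbitrary assumption):

import torch
import torch.nn as nn

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        # sigmoid-based approximation of GELU: x * sigmoid(1.702 * x)
        return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-4, 4, steps=9)
approx = QuickGELU()(x)
exact = nn.functional.gelu(x)
print((approx - exact).abs().max())  # small but non-zero: QuickGELU trades a little accuracy for speed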
2 changes: 2 additions & 0 deletions clip/simple_tokenizer.py
@@ -66,6 +66,8 @@ def whitespace_clean(text):
 
 
 class SimpleTokenizer(object):
+    """Tokenizes text using byte pair encoding (BPE) and predefined tokenization rules for efficient text processing."""
+
     def __init__(self, bpe_path: str = default_bpe()):
         """Initialize the SimpleTokenizer object with byte pair encoding (BPE) paths and set up encoders, decoders, and
         patterns.
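
For reference, the class documented above is the BPE tokenizer that clip.tokenize wraps. A round-trip sketch, assuming the usual encode/decode methods which this hunk does not show:

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                 # loads the bundled BPE vocabulary from default_bpe()
ids = tokenizer.encode("a photo of a cat")    # list of integer BPE token ids
text = tokenizer.decode(ids)                  # lower-cased, whitespace-cleaned reconstruction
print(ids, text)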
3 changes: 1 addition & 2 deletions tests/test_consistency.py
@@ -1,10 +1,9 @@
+import clip
 import numpy as np
 import pytest
 import torch
 from PIL import Image
 
-import clip
-
 
 @pytest.mark.parametrize("model_name", clip.available_models())
 def test_consistency(model_name):
