Merge pull request #72 from jmisilo/develop
Develop
jmisilo authored Nov 24, 2022
2 parents a9100e3 + 2e8065d commit 792c7a3
Showing 22 changed files with 453 additions and 319 deletions.
Binary file added examples/23012796-RS.jpg
Binary file removed examples/23012796.jpg
Binary file added examples/36979-RS.jpg
Binary file removed examples/36979.jpg
Binary file added examples/89407459-RL.jpg
Binary file removed examples/89407459.jpg
Binary file removed examples/loss_lr.jpg
34 changes: 22 additions & 12 deletions readme.md
@@ -1,28 +1,38 @@
# CLIPxGPT Captioner

### Description
## Description

**`CLIPxGPT Captioner`** is an Image Captioning Model based on [OpenAI's](https://openai.com/) [CLIP](https://openai.com/blog/clip/) and [GPT-2](https://openai.com/blog/better-language-models/). The Model uses a Mapping Module to "translate" CLIP embeddings to GPT-2. The Model is trained on the [Flickr30k](https://shannon.cs.illinois.edu/DenotationGraph/) dataset, downloaded from [Kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset).

**The goal** of the project was to find out about the possibility of CLIP + GPT-2 connection and to check whether, with a relatively short training time and a small dataset, the model will be able to recognize situations in the pictures. In the first version, the model achieved satisfactory results.
**The goal** of the project was to explore the possibility of connecting CLIP with GPT-2 and to check whether, with a relatively short training time and a small dataset, the model would be able to recognize situations in the pictures. The model achieved satisfactory results.

The Model uses prefixes as in the [ClipCap](https://arxiv.org/abs/2111.09734) paper. In my original idea the prefix length was 1, but after reading the publication it was changed to 4, which increased performance.

The Model was trained with a frozen CLIP, a fully trained Mapping Module (6x Transformer Encoder Layers) and with partially frozen GPT-2 (the first and last 14 layers were trained).
The Model was trained with a frozen CLIP, a fully trained Mapping Module (5-6x Transformer Encoder Layers) and a partially frozen GPT-2 (the first and last 14 layers were trained).
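
Below is a minimal sketch of what such a Mapping Module could look like. The actual implementation lives in `src/model` and is not shown in this diff, so the class name, layer sizes and projection scheme are assumptions (CLIP ViT-L/14 pooled embeddings and GPT-2 Medium's hidden size are both 1024-dimensional):

```python
# Hypothetical sketch of the Mapping Module described above -- the real
# implementation in src/model may differ. Dimensions and names are assumptions.
import torch
import torch.nn as nn

class MappingModule(nn.Module):
    def __init__(self, clip_dim=1024, gpt_dim=1024, prefix_len=4, num_layers=5, n_heads=16):
        super().__init__()
        self.prefix_len = prefix_len
        # project a single CLIP image embedding into `prefix_len` pseudo-tokens
        self.proj = nn.Linear(clip_dim, gpt_dim * prefix_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=gpt_dim, nhead=n_heads, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, img_emb):                          # (batch, clip_dim)
        prefix = self.proj(img_emb)                      # (batch, prefix_len * gpt_dim)
        prefix = prefix.view(img_emb.size(0), self.prefix_len, -1)
        # output is fed to GPT-2 as prefix embeddings: (batch, prefix_len, gpt_dim)
        return self.encoder(prefix)
```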

The training process was carried out using the [Kaggle](https://www.kaggle.com/) P100 GPU. Training time - about 3 x 11h (150 epochs) with a linear learning rate warmup (max LR `3e-3`) and batch size 64.
The training process was carried out using the [Kaggle](https://www.kaggle.com/) P100 GPU.

#### Loss and Learning Rate during training
### Model Versions

![LOSSxLR](./examples/loss_lr.jpg)
> **Small** - [Download](https://drive.google.com/uc?id=1p91KBj-oUmuMfG2Gc33tEN5Js5HpV8YH)
> * Text Model - GPT-2 Small - 124M parameters
> * Mapping Module - 6x Transformer Encoder Layers
> * CLIP Base - Patch 32 model
> * 256M Parameters
### Example results
> **Large** - [Download](https://drive.google.com/uc?id=12h-NgryAf6zZdA1KclHdfzU35D1icjEp)
> * Text Model - GPT-2 Medium - 355M parameters
> * Mapping Module - 5x Transformer Encoder Layers
> * CLIP Large - Patch 14 model
> * 736M Parameters
![Example1](./examples/23012796.jpg)
![Example2](./examples/36979.jpg)
![Example3](./examples/89407459.jpg)
## Example results

### Usage
![Example1](./examples/23012796-RS.jpg)
![Example2](./examples/36979-RS.jpg)
![Example3](./examples/89407459-RL.jpg)

## Usage

Clone repository using:

@@ -47,7 +57,7 @@ pip install -r requirements.txt
And run prediction:

```bash
python .\src\predict.py -I <image_path>
python .\src\predict.py -I <image_path> -S <model_size [S/L]> -C <checkpoint_name>
```

### References:
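The two versions listed in the Model Versions section above correspond to the `ConfigS` and `ConfigL` objects that `src/evaluate.py` imports later in this diff. A hedged sketch of how those configs might encode the lists (only `clip_model`, `text_model`, `ep_len`, `num_layers`, `n_heads` and `weights_dir` are visible in the diff; the exact values and any extra fields are assumptions):

```python
# Hedged sketch of the two configurations -- the real ConfigS / ConfigL live in
# src/utils and may define more fields; values follow the readme bullets above.
from dataclasses import dataclass

@dataclass
class ConfigS:
    clip_model: str = 'openai/clip-vit-base-patch32'   # CLIP Base, patch 32
    text_model: str = 'gpt2'                            # GPT-2 Small, 124M parameters
    ep_len: int = 4                                      # prefix length (assumption)
    num_layers: int = 6                                  # 6x Transformer Encoder Layers
    n_heads: int = 16                                    # assumption
    weights_dir: str = 'weights/small'                   # assumption

@dataclass
class ConfigL:
    clip_model: str = 'openai/clip-vit-large-patch14'   # CLIP Large, patch 14
    text_model: str = 'gpt2-medium'                      # GPT-2 Medium, 355M parameters
    ep_len: int = 4
    num_layers: int = 5                                  # 5x Transformer Encoder Layers
    n_heads: int = 16
    weights_dir: str = 'weights/large'
```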
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,6 +3,7 @@ matplotlib==3.6.0
numpy==1.23.3
pandas==1.5.0
Pillow==9.3.0
torch==1.12.1+cu116
torch==1.13.0+cu117
tqdm==4.64.1
transformers==4.22.1
wandb==0.13.4
7 changes: 6 additions & 1 deletion src/data/dataset.py
@@ -6,6 +6,7 @@
* get_loader returns DataLoader object.
'''

import os
import pickle

import numpy as np
@@ -16,6 +17,10 @@

class MiniFlickrDataset(Dataset):
def __init__(self, path):
# check that the dataset file exists
if not os.path.isfile(path):
raise OSError('Dataset file not found. Downloading...')

with open(path, 'rb') as f:
self.data = pickle.load(f)

@@ -40,7 +45,7 @@ def cl_fn(batch, tokenizer):
return img_emb, input_ids, attention_mask

def get_loader(dataset, bs_exp=5, shuffle=True, num_workers=0, pin_memory=False):
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
tokenizer.pad_token = tokenizer.eos_token

return DataLoader(
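
A short usage sketch for the updated loader; the dataset path, the meaning of `bs_exp` and the exact import path are assumptions based on the docstring and defaults visible above:

```python
# Hypothetical usage of MiniFlickrDataset + get_loader after this change.
from data.dataset import MiniFlickrDataset, get_loader  # import path is an assumption

dataset = MiniFlickrDataset('data/processed/dataset.pkl')  # path is an assumption

# bs_exp appears to be a batch-size exponent, i.e. batch size = 2 ** bs_exp
loader = get_loader(dataset, bs_exp=6, shuffle=True, num_workers=2, pin_memory=True)

for img_emb, input_ids, attention_mask in loader:
    # img_emb: precomputed CLIP embeddings; input_ids / attention_mask: tokenized captions
    ...
```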
4 changes: 2 additions & 2 deletions src/dataset_generation.py
@@ -28,8 +28,8 @@
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load CLIP model and processor
preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').vision_model.to(device)
preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14').vision_model.to(device)

# Load dataset
df = pd.read_csv(os.path.join(DATA_PATH, 'raw', 'results.csv'), sep='|')
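
For context, embedding an image with the larger CLIP vision tower follows the usual `transformers` pattern; whether the project keeps the pooled output or the full token sequence is not visible in this hunk, so that choice is an assumption:

```python
# Minimal sketch of producing an image embedding with the CLIP vision tower
# loaded above. Using pooler_output is an assumption.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
vision_model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14').vision_model.to(device)

img = Image.open('examples/36979-RS.jpg').convert('RGB')
inputs = preprocessor(images=img, return_tensors='pt').to(device)

with torch.no_grad():
    out = vision_model(**inputs)

img_emb = out.pooler_output.squeeze(0)  # shape: (1024,) for ViT-L/14
```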
55 changes: 49 additions & 6 deletions src/evaluate.py
@@ -7,25 +7,36 @@
import random

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch.utils.data import random_split
from tqdm import tqdm

from data import MiniFlickrDataset
from model import evaluate_dataset, Net
from utils import Config, download_weights
from model import Net
from utils import ConfigS, ConfigL, download_weights

config = Config()
parser = argparse.ArgumentParser()

parser.add_argument(
'-C',
'--checkpoint-name',
type=str,
default='',
default='model.pt',
help='Checkpoint name'
)

parser.add_argument(
'-S',
'--size',
type=str,
default='S',
help='Model size [S, L]',
choices=['S', 'L', 's', 'l']
)

parser.add_argument(
'-I',
'--img-path',
@@ -52,6 +63,8 @@

args = parser.parse_args()

config = ConfigL() if args.size.upper() == 'L' else ConfigS()

ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)

assert os.path.exists(args.img_path), 'Path to the test image folder does not exist'
@@ -66,8 +79,38 @@
is_cuda = torch.cuda.is_available()
device = 'cuda' if is_cuda else 'cpu'

def evaluate_dataset(model, dataset, img_path, save_path, temperature=1.0):
'''
Evaluate model on dataset.
Args:
model: model to evaluate
dataset: dataset to evaluate on
img_path: path to images
save_path: path to save results
temperature: temperature passed to the model when generating captions
'''
model.eval()

loop = tqdm(dataset, total=len(dataset))
for img_name, _, _ in loop:
img = Image.open(os.path.join(img_path, img_name))

with torch.no_grad():
caption, _ = model(img, temperature)

plt.imshow(img)
plt.title(caption)
plt.axis('off')

plt.savefig(os.path.join(save_path, img_name), bbox_inches='tight')

plt.clf()
plt.close()

if __name__ == '__main__':
model = Net(
clip_model=config.clip_model,
text_model=config.text_model,
ep_len=config.ep_len,
num_layers=config.num_layers,
n_heads=config.n_heads,
@@ -89,12 +132,12 @@
os.makedirs(config.weights_dir)

if not os.path.isfile(ckp_path):
download_weights(ckp_path)
download_weights(ckp_path, args.size)

checkpoint = torch.load(ckp_path, map_location=device)
model.load_state_dict(checkpoint)

save_path = os.path.join(args.res_path, args.checkpoint_name[:-3])
save_path = os.path.join(args.res_path, f'{args.checkpoint_name[:-3]}_{args.size.upper()}')

if not os.path.exists(save_path):
os.mkdir(save_path)
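
The `download_weights(ckp_path, args.size)` call above now has to pick between the two checkpoints linked in the readme. A hedged sketch of such a helper (the real one lives in `src/utils`; using `gdown` and this URL-to-size mapping are assumptions, though the links themselves come from the readme):

```python
# Hypothetical size-aware download_weights helper. The URLs are the Google Drive
# links from the readme; the gdown dependency and function shape are assumptions.
import gdown

WEIGHTS_URLS = {
    'S': 'https://drive.google.com/uc?id=1p91KBj-oUmuMfG2Gc33tEN5Js5HpV8YH',
    'L': 'https://drive.google.com/uc?id=12h-NgryAf6zZdA1KclHdfzU35D1icjEp',
}

def download_weights(ckp_path, size='S'):
    # download the checkpoint for the requested model size to ckp_path
    gdown.download(WEIGHTS_URLS[size.upper()], ckp_path, quiet=False)
```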
2 changes: 1 addition & 1 deletion src/model/__init__.py
@@ -1,2 +1,2 @@
from model.loops import *
from model.trainer import *
from model.model import *
159 changes: 0 additions & 159 deletions src/model/loops.py

This file was deleted.
