Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load the pretrained model in kaggle to interact directly with it #49

Open
peternasser99 opened this issue Mar 31, 2024 · 0 comments
Open

Comments

@peternasser99
Copy link

i am trying to load the pretrained model for imagenet 1k in kaggle to interact with it but the performance iam getting is random at best
any help is much appreciated
data set required in kaggle : imagenet-1k-resized-256

i coped the relevant pieces from the eval script
the code as follows it takes a min to run mostly the downlad

"""

get the repo in cell 1

!git clone https://github.com/facebookresearch/jepa.git
import os
os.chdir('/kaggle/working/jepa')
!pip install .

!wget https://dl.fbaipublicfiles.com/jepa/vitl16/in1k-probe.pth.tar
!wget https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar

config for the model i want

import yaml
with open('/kaggle/working/jepa/configs/evals/vitl16_in1k.yaml', 'r') as y_file:
params = yaml.load(y_file, Loader=yaml.FullLoader)

params['pretrain']['folder'] = '/kaggle/working/jepa'
params['pretrain']['checkpoint'] = 'vitl16.pth.tar'

loading the model

import jepa.src.models.vision_transformer as vit
import torch

def load_pretrained(
encoder,
pretrained,
checkpoint_key='target_encoder'
):
print(f'Loading pretrained model from {pretrained}')
checkpoint = torch.load(pretrained, map_location='cpu')
try:
pretrained_dict = checkpoint[checkpoint_key]
except Exception:
pretrained_dict = checkpoint['encoder']

pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
for k, v in encoder.state_dict().items():
    if k not in pretrained_dict:
        print(f'key "{k}" could not be found in loaded state dict')
    elif pretrained_dict[k].shape != v.shape:
        print(f'key "{k}" is of different shape in model and loaded state dict')
        pretrained_dict[k] = v
msg = encoder.load_state_dict(pretrained_dict, strict=False)
print(f'loaded pretrained model with msg: {msg}')
print(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}')
del checkpoint
return encoder

def init_model(
device,
pretrained,
model_name,
patch_size=16,
crop_size=224,
# Video specific parameters
frames_per_clip=16,
tubelet_size=2,
use_sdpa=False,
use_SiLU=False,
tight_SiLU=True,
uniform_power=False,
checkpoint_key='target_encoder'
):
encoder = vit.dict[model_name](
img_size=crop_size,
patch_size=patch_size,
num_frames=frames_per_clip,
tubelet_size=tubelet_size,
uniform_power=uniform_power,
use_sdpa=use_sdpa,
use_SiLU=use_SiLU,
tight_SiLU=tight_SiLU,
)
if frames_per_clip > 1:
def forward_prehook(module, input):
input = input[0] # [B, C, H, W]
input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1)
return (input)

    encoder.register_forward_pre_hook(forward_prehook)

encoder.to(device)

encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key)
return encoder

args_eval = params

args_pretrain = args_eval.get('pretrain')

checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder')
model_name = args_pretrain.get('model_name', None)
patch_size = args_pretrain.get('patch_size', None)
pretrain_folder = args_pretrain.get('folder', None)
ckp_fname = args_pretrain.get('checkpoint', None)
tag = args_pretrain.get('write_tag', None)
use_sdpa = args_pretrain.get('use_sdpa', True)
use_SiLU = args_pretrain.get('use_silu', False)
tight_SiLU = args_pretrain.get('tight_silu', True)
uniform_power = args_pretrain.get('uniform_power', False)
pretrained_path = os.path.join(pretrain_folder, ckp_fname)

Optional [for Video model]:

tubelet_size = args_pretrain.get('tubelet_size', 2)
frames_per_clip = args_pretrain.get('frames_per_clip', 1)

args_data = args_eval.get('data')
resolution = args_data.get('resolution', 224)
num_classes = args_data.get('num_classes')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = init_model(
crop_size=resolution,
device=device,
pretrained=pretrained_path,
model_name=model_name,
patch_size=patch_size,
frames_per_clip=1,
tubelet_size=1,
uniform_power=uniform_power,
checkpoint_key=checkpoint_key,
use_SiLU=use_SiLU,
tight_SiLU=tight_SiLU,
use_sdpa=use_sdpa)

encoder.eval()
for p in encoder.parameters():
p.requires_grad = False

print(encoder)

#loading the classifier
from jepa.src.models.attentive_pooler import AttentiveClassifier

classifier = AttentiveClassifier(
embed_dim=encoder.embed_dim,
num_heads=encoder.num_heads,
depth=1,
num_classes=num_classes
).to(device)

checkpoint = torch.load("/kaggle/working/jepa/in1k-probe.pth.tar", map_location=torch.device('cpu'))
pretrained_dict = checkpoint['classifier']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}

print(classifier)

msg = classifier.load_state_dict(pretrained_dict)
print(msg)

evaluating

from PIL import Image
from io import BytesIO
import pickle
import os
import pandas as pd

import torch
from torchvision import transforms

transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])

parquet_file_path = "/kaggle/input/imagenet-1k-resized-256/data/train-00001-of-00052-886eb11e764e42fe.parquet"
df = pd.read_parquet(parquet_file_path)
print(df.shape)

file_path = "/kaggle/input/imagenet-1k-resized-256/classes.pkl"
with open(file_path, "rb") as f:
classes = pickle.load( f)

for idx ,row in df.sample(n=15).iterrows():
img = (Image.open(BytesIO(row['image']['bytes'])))
outs = classifier(encoder(transform(img).unsqueeze(0).to(device)))
values, indices = torch.topk( outs, 10 )
display(img)
print(f'real {row["label"]}')
for n ,i in enumerate(indices[0]) :
print(f'{i} : class {classes[int(i)]} value {values[0][n]} ')

"""

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant