I am trying to load the pretrained ImageNet-1k model in Kaggle to interact with it, but the performance I am getting is random at best. Any help is much appreciated.

Dataset required in Kaggle: imagenet-1k-resized-256

I copied the relevant pieces from the eval script. The code follows; it takes about a minute to run, mostly spent on the download.
Get the repo in cell 1:
!git clone https://github.com/facebookresearch/jepa.git
import os
os.chdir('/kaggle/working/jepa')
!pip install .
!wget https://dl.fbaipublicfiles.com/jepa/vitl16/in1k-probe.pth.tar
!wget https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar
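As a quick sanity check (not part of the eval script, just a sketch), it is worth confirming that both checkpoints downloaded intact and expose the keys the code below relies on ('classifier', 'epoch', 'target_encoder' / 'encoder'):

import torch

probe_ckpt = torch.load('/kaggle/working/jepa/in1k-probe.pth.tar', map_location='cpu')
encoder_ckpt = torch.load('/kaggle/working/jepa/vitl16.pth.tar', map_location='cpu')
print(list(probe_ckpt.keys()))    # expecting 'classifier' and 'epoch' among the keys
print(list(encoder_ckpt.keys()))  # expecting 'target_encoder' (or 'encoder')
del probe_ckpt, encoder_ckpt      # free the memory; the real loading happens below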
Config for the model I want:
import yaml
with open('/kaggle/working/jepa/configs/evals/vitl16_in1k.yaml', 'r') as y_file:
params = yaml.load(y_file, Loader=yaml.FullLoader)
params['pretrain']['folder'] = '/kaggle/working/jepa'
params['pretrain']['checkpoint'] = 'vitl16.pth.tar'
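To double-check that the patched config resolves to the values the rest of the notebook expects (model name, patch size, number of classes), a quick print; the field names mirror the .get() calls further down:

print(params['pretrain'].get('model_name'), params['pretrain'].get('patch_size'))
print(params['data'].get('num_classes'), params['data'].get('resolution', 224))
print(params['pretrain']['folder'], params['pretrain']['checkpoint'])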
Loading the model:
import jepa.src.models.vision_transformer as vit
import torch

def load_pretrained(
    encoder,
    pretrained,
    checkpoint_key='target_encoder'
):
    print(f'Loading pretrained model from {pretrained}')
    checkpoint = torch.load(pretrained, map_location='cpu')
    try:
        pretrained_dict = checkpoint[checkpoint_key]
    except Exception:
        pretrained_dict = checkpoint['encoder']
    # Strip 'module.' (DDP) and 'backbone.' prefixes so the keys match the bare encoder
    pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
    for k, v in encoder.state_dict().items():
        if k not in pretrained_dict:
            print(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            print(f'key "{k}" is of different shape in model and loaded state dict')
            pretrained_dict[k] = v
    msg = encoder.load_state_dict(pretrained_dict, strict=False)
    print(f'loaded pretrained model with msg: {msg}')
    print(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}')
    del checkpoint
    return encoder
def init_model(
    device,
    pretrained,
    model_name,
    patch_size=16,
    crop_size=224,
    # Video specific parameters
    frames_per_clip=16,
    tubelet_size=2,
    use_sdpa=False,
    use_SiLU=False,
    tight_SiLU=True,
    uniform_power=False,
    checkpoint_key='target_encoder'
):
    encoder = vit.__dict__[model_name](
        img_size=crop_size,
        patch_size=patch_size,
        num_frames=frames_per_clip,
        tubelet_size=tubelet_size,
        uniform_power=uniform_power,
        use_sdpa=use_sdpa,
        use_SiLU=use_SiLU,
        tight_SiLU=tight_SiLU,
    )
    if frames_per_clip > 1:
        # Video checkpoints expect a clip; repeat the single frame along the time axis
        def forward_prehook(module, input):
            input = input[0]  # [B, C, H, W]
            input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1)
            return (input)

        encoder.register_forward_pre_hook(forward_prehook)

    # As in the eval script: move to device and load the pretrained weights
    encoder.to(device)
    encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key)
    return encoder
args_eval = params
args_pretrain = args_eval.get('pretrain')
checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder')
model_name = args_pretrain.get('model_name', None)
patch_size = args_pretrain.get('patch_size', None)
pretrain_folder = args_pretrain.get('folder', None)
ckp_fname = args_pretrain.get('checkpoint', None)
tag = args_pretrain.get('write_tag', None)
use_sdpa = args_pretrain.get('use_sdpa', True)
use_SiLU = args_pretrain.get('use_silu', False)
tight_SiLU = args_pretrain.get('tight_silu', True)
uniform_power = args_pretrain.get('uniform_power', False)
pretrained_path = os.path.join(pretrain_folder, ckp_fname)
# Optional [for Video model]:
tubelet_size = args_pretrain.get('tubelet_size', 2)
frames_per_clip = args_pretrain.get('frames_per_clip', 1)
args_data = args_eval.get('data')
resolution = args_data.get('resolution', 224)
num_classes = args_data.get('num_classes')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = init_model(
    crop_size=resolution,
    device=device,
    pretrained=pretrained_path,
    model_name=model_name,
    patch_size=patch_size,
    frames_per_clip=1,
    tubelet_size=1,
    uniform_power=uniform_power,
    checkpoint_key=checkpoint_key,
    use_SiLU=use_SiLU,
    tight_SiLU=tight_SiLU,
    use_sdpa=use_sdpa)
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False
print(encoder)
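Before attaching the probe, a quick shape check on a dummy batch confirms the frozen encoder behaves as expected; for ViT-L/16 at 224 px the output should be roughly [1, 196, 1024] (196 patch tokens of width 1024). A sketch reusing resolution and device from above:

with torch.no_grad():
    dummy = torch.zeros(1, 3, resolution, resolution, device=device)
    tokens = encoder(dummy)
print(tokens.shape)  # expecting about [1, 196, 1024] for ViT-L/16 at 224 px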
Loading the classifier:
from jepa.src.models.attentive_pooler import AttentiveClassifier

classifier = AttentiveClassifier(
    embed_dim=encoder.embed_dim,
    num_heads=encoder.num_heads,
    depth=1,
    num_classes=num_classes
).to(device)
checkpoint = torch.load("/kaggle/working/jepa/in1k-probe.pth.tar", map_location=torch.device('cpu'))
pretrained_dict = checkpoint['classifier']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
print(classifier)
msg = classifier.load_state_dict(pretrained_dict)
print(msg)
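One small thing the snippet above skips: the probe, like the encoder, is only used for inference here, so it can be put in eval mode and frozen the same way:

classifier.eval()
for p in classifier.parameters():
    p.requires_grad = False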
Evaluating:
from PIL import Image
from io import BytesIO
import pickle
import os
import pandas as pd
import torch
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
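Note that, if I am reading the repo's eval transforms correctly, images are also normalized with ImageNet statistics before they reach the encoder, which the Resize + ToTensor pair above does not do. A sketch of a fuller transform, assuming the standard ImageNet mean/std:

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])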
parquet_file_path = "/kaggle/input/imagenet-1k-resized-256/data/train-00001-of-00052-886eb11e764e42fe.parquet"
df = pd.read_parquet(parquet_file_path)
print(df.shape)
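The loop below assumes each parquet row stores the raw JPEG bytes under image['bytes'] alongside an integer label column; peeking at one row confirms the schema (a sketch):

print(df.columns.tolist())
first = df.iloc[0]
print(type(first['image']), first['label'])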
file_path = "/kaggle/input/imagenet-1k-resized-256/classes.pkl"
with open(file_path, "rb") as f:
    classes = pickle.load(f)
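Since the printout below compares the dataset's integer label against an index into classes.pkl, it is worth spot-checking that the two orderings agree (a mismatch here would also look like random predictions). A minimal check, assuming classes is indexable by int as used below:

sample = df.iloc[0]
print(sample['label'], '->', classes[int(sample['label'])])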
for idx, row in df.sample(n=15).iterrows():
    # Some ImageNet images are grayscale / CMYK; force 3-channel RGB before the transform
    img = Image.open(BytesIO(row['image']['bytes'])).convert('RGB')
    with torch.no_grad():
        outs = classifier(encoder(transform(img).unsqueeze(0).to(device)))
    values, indices = torch.topk(outs, 10)
    display(img)
    print(f'real {row["label"]}')
    for n, i in enumerate(indices[0]):
        print(f'{int(i)} : class {classes[int(i)]} value {values[0][n]}')
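Finally, eyeballing fifteen images makes "random at best" hard to quantify; a rough top-1 accuracy over a few hundred rows gives a cleaner signal. A sketch reusing the objects defined above, and assuming the dataset labels follow the same ordering as classes.pkl:

correct = total = 0
with torch.no_grad():
    for _, row in df.sample(n=256, random_state=0).iterrows():
        img = Image.open(BytesIO(row['image']['bytes'])).convert('RGB')
        logits = classifier(encoder(transform(img).unsqueeze(0).to(device)))
        correct += int(logits.argmax(dim=-1).item() == int(row['label']))
        total += 1
print(f'top-1 over {total} samples: {correct / total:.3f}')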
"""
The text was updated successfully, but these errors were encountered: