Inference method #7

Open
prnvsheth opened this issue Oct 5, 2023 · 1 comment

Comments

@prnvsheth

Is there existing code for inferring whether a specific image is fake? Could you point me to the relevant documentation in the code base?

@oshita-n

oshita-n commented Jan 30, 2024

Perhaps it could be written like this.

inference.py

import argparse
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from models import get_model  # model factory from this repository

SEED = 0


def set_seed():
    """Seed every RNG in use so repeated runs are reproducible."""
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)


MEAN = {
    "imagenet":[0.485, 0.456, 0.406],
    "clip":[0.48145466, 0.4578275, 0.40821073]
}

STD = {
    "imagenet":[0.229, 0.224, 0.225],
    "clip":[0.26862954, 0.26130258, 0.27577711]
}

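def inference(model, img):
    """Return the sigmoid score for a batch containing a single image.

    Assumes the repository's label convention (0 = real, 1 = fake), so a
    higher score means the image is more likely fake.
    """
    with torch.no_grad():
        # flatten + squeeze reduce the single-image output to a 0-d array
        y_pred = model(img).sigmoid().flatten().squeeze().cpu().numpy()
    return y_pred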

if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--image', type=str, default='./test_images/real.png')
    # Architecture name as passed to get_model, e.g. CLIP:ViT-L/14
    parser.add_argument('--arch', type=str, default='CLIP:ViT-L/14')
    parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth')
    opt = parser.parse_args()

    # Fall back to the CPU when no GPU is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(opt.arch)
    # The checkpoint is loaded into the final linear layer only (model.fc).
    state_dict = torch.load(opt.ckpt, map_location='cpu')
    model.fc.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Use normalization statistics that match the backbone's pretraining.
    stat_from = "imagenet" if opt.arch.lower().startswith("imagenet") else "clip"

    transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
    ])

    # Add a batch dimension so the model sees a tensor of shape (1, 3, 224, 224).
    img_tensor = transform(Image.open(opt.image).convert("RGB")).unsqueeze(0).to(device)
    y_pred = inference(model, img_tensor)

    print("Prediction:", y_pred)

python inference.py  --arch=CLIP:ViT-L/14   --ckpt=pretrained_weights/fc_weights.pth   --image real/image-02.png
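
The script prints a raw sigmoid score rather than a real/fake verdict. Below is a minimal sketch of turning that score into a label. It assumes the score is the probability that the image is fake (label 1 = fake), and the 0.5 threshold is only illustrative. `label_from_score` is a hypothetical helper, not something that exists in the repository.

```python
def label_from_score(score: float, threshold: float = 0.5) -> str:
    """Map a sigmoid score in [0, 1] to a human-readable label.

    Assumption: higher scores mean "fake". The default threshold of 0.5
    is illustrative and has not been tuned on any dataset.
    """
    if not 0.0 <= score <= 1.0:
        raise ValueError(f"score must be in [0, 1], got {score}")
    return "fake" if score > threshold else "real"


if __name__ == "__main__":
    # Made-up scores, used only to show the mapping.
    for s in (0.03, 0.97):
        print(f"{s:.2f} -> {label_from_score(s)}")
```

In inference.py this could be called as `label_from_score(float(y_pred))`. A threshold chosen on a held-out validation set would likely work better than a fixed 0.5.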
