"""finetune_clip.py

Fine-tune the ``clip-ViT-B-32`` SentenceTransformer model on
(image, pseudolabel) pairs listed in a pseudolabel CSV file.
"""
import pandas as pd
from argparse import ArgumentParser
import os
from PIL import Image
from sentence_transformers import InputExample
from torch.utils.data import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers import losses
from torch.utils.data import DataLoader
def get_opts(argv=None):
    """Parse the command-line options of the fine-tuning script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line is parsed, exactly as
            before, so existing ``get_opts()`` callers are unaffected.

    Returns:
        The parsed ``argparse.Namespace`` with the attributes
        ``pseudolabels``, ``epochs``, ``lr``, ``batch_size``,
        ``num_workers``, ``output`` and ``data_dir``.
    """
    parser = ArgumentParser()
    parser.add_argument('--pseudolabels', '-p', type=str, default="data/pseudolabels.csv", help="filename of pseudolabel csv file")
    parser.add_argument('--epochs', '-e', type=int, default=5, help="epochs")
    parser.add_argument('--lr', '-l', type=float, default=1e-6, help="learning rate")
    parser.add_argument('--batch_size', '-b', type=int, default=2 ** 7, help="batch size")
    parser.add_argument('--num_workers', '-n', type=int, default=0, help="number of dataloader workers")
    parser.add_argument('--output', '-o', type=str, default="data/clip_ckpt", help="output checkpoint directory")
    parser.add_argument('--data_dir', '-d', type=str, default="data/wikiscenes", help="directory WikiScenes data is stored in")
    # parse_args(None) falls back to sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)
def row2ex(row, data_dir):
    """Build one training example pairing an image with its pseudolabel.

    Args:
        row: Record exposing ``pseudolabel`` and ``fn`` attributes; ``fn``
            is the image path relative to ``data_dir``.
        data_dir: Directory that ``row.fn`` is joined onto.

    Returns:
        An ``InputExample`` whose ``texts`` are ``[RGB image, pseudolabel]``.
    """
    ps, fn_ = row.pseudolabel, row.fn
    fn = os.path.join(data_dir, fn_)
    # Close the source file deterministically instead of leaving it to the
    # garbage collector; convert() hands back a separate image object, so it
    # stays usable after the file is closed.
    with Image.open(fn) as raw:
        img = raw.convert('RGB')
    return InputExample(texts=[img, ps])
class DS(Dataset):
    """Dataset view over the pseudolabel table.

    Indexing builds the example for one row through ``row2ex``, so the
    image for a row is only opened when that row is requested.
    """

    def __init__(self, df, data_dir):
        """Store the table ``df`` and the image root ``data_dir``."""
        self.data_dir = data_dir
        self.df = df

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the example for the row at position ``idx``."""
        record = self.df.iloc[idx]
        return row2ex(record, self.data_dir)
def main():
    """Fine-tune CLIP on the training split of the pseudolabel table.

    Reads the options from ``get_opts``, loads the pseudolabel CSV, drops
    rows whose ``pseudolabel`` is missing, keeps the rows with
    ``spl == 'train'`` and fits ``clip-ViT-B-32`` on them with
    ``MultipleNegativesRankingLoss``, passing ``args.output`` as the
    output path.

    Raises:
        FileNotFoundError: If the pseudolabel file or the data directory
            does not exist.
    """
    args = get_opts()

    # Validate inputs with real exceptions: ``assert`` statements are removed
    # under ``python -O``, which would postpone the failure until after the
    # model has been loaded.
    if not os.path.exists(args.pseudolabels):
        raise FileNotFoundError(f'Missing pseudolabel file: {args.pseudolabels}')
    if not os.path.exists(args.data_dir):
        raise FileNotFoundError(f'Missing data directory: {args.data_dir}')

    print("Loading pseudolabel table...")
    df = pd.read_csv(args.pseudolabels)
    print("Pseudolabel table loaded")
    print(f"{len(df)} rows")

    # Rows without a pseudolabel cannot form an (image, text) pair.
    df = df[df.pseudolabel.notna()].copy()
    print(f"Only using {len(df)} rows with non-empty pseudolabels")

    print("Train-test split:")
    print(df.spl.value_counts())

    df_train = df[df.spl == 'train'].copy()
    ds = DS(df_train, args.data_dir)

    print("Loading CLIP model...")
    model = SentenceTransformer('clip-ViT-B-32')
    print("CLIP model loaded")

    # The model's own collate function turns InputExample batches into the
    # features that the loss consumes.
    dl = DataLoader(
        ds,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=model.smart_batching_collate,
        pin_memory=True,
    )
    loss = losses.MultipleNegativesRankingLoss(model=model)

    print("Training model...")
    model.fit(
        train_objectives=[(dl, loss)],
        epochs=args.epochs,
        output_path=args.output,
        optimizer_params={'lr': args.lr},
    )
    print("Model saved to:", args.output)
    print("done")
# Run the fine-tuning pipeline only when this file is executed as a script;
# importing it as a module must not start training.
if __name__ == "__main__":
    main()