-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvis_embed.py
31 lines (28 loc) · 993 Bytes
/
vis_embed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from tensorboardX import SummaryWriter
import pickle
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import argparse
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with two attributes:
        ``embeddings`` -- path to the embedding pickle file (required);
        ``save_folder`` -- directory for the TensorBoard output
        (defaults to ``'./runs'``).
    """
    cli = argparse.ArgumentParser(
        description='Visualize embedding in tensorboard.',
    )
    cli.add_argument(
        '--embeddings',
        type=str,
        required=True,
        help='embedding pickle file.',
    )
    cli.add_argument(
        '--save-folder',
        type=str,
        default='./runs',
        help='directory to save tensorboard file.',
    )
    parsed = cli.parse_args()
    return parsed
def stack_embeddings(data):
    """Turn a ``{label: [feature, ...]}`` mapping into a matrix plus labels.

    Every feature item becomes exactly one row of the returned matrix, and
    its dictionary key is appended once to ``labels``. This keeps
    ``len(labels) == feats.shape[0]``, which the embedding metadata needs.

    Args:
        data: mapping from label to a sequence of feature items. Each item is
            anything ``torch.as_tensor`` accepts (array, list, tensor).
            NOTE(review): the on-disk schema is not visible here; 1-D
            vectors and single-row (1, D) arrays are both handled.

    Returns:
        ``(feats, labels)``. ``feats`` is a float32 tensor of shape
        (num_items, D) and ``labels`` is a list of the same length.

    Raises:
        ValueError: if ``data`` contains no feature items at all.
    """
    rows = []
    labels = []
    for label, items in data.items():
        for feat in items:
            # Flatten each item to one row. ``torch.cat`` on 1-D vectors would
            # instead merge everything into a single long vector.
            rows.append(torch.as_tensor(feat, dtype=torch.float32).reshape(-1))
            labels.append(label)
    if not rows:
        raise ValueError('embedding file contains no features.')
    return torch.stack(rows, 0), labels


def main():
    """Load a pickled embedding dict and write it as a TensorBoard embedding."""
    args = parse_args()
    # Validate explicitly: ``assert`` statements are removed under ``python -O``.
    if not args.embeddings.endswith('.pkl'):
        raise SystemExit(
            'expected a .pkl file for --embeddings, got {!r}'.format(args.embeddings))

    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only load embedding files you produced yourself or otherwise trust.
    with open(args.embeddings, 'rb') as f:
        data = pickle.load(f)

    feats, labels = stack_embeddings(data)
    print(feats.shape, len(labels))

    # Create the writer only after the data has loaded and been validated,
    # and close it even if add_embedding raises.
    writer = SummaryWriter(args.save_folder)
    try:
        writer.add_embedding(feats, metadata=labels)
    finally:
        writer.close()


if __name__ == '__main__':
    main()