forked from rish-16/cs4243-project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset_analysis.py
35 lines (32 loc) · 1009 Bytes
/
dataset_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from collections import Counter
import os
import imageio
import glob
from torch.utils.data import Dataset, TensorDataset, DataLoader
import cv2
def plot_dataset_dist(d, title="untitled"):
    """Draw a bar chart of how many samples each category contains.

    Args:
        d: Mapping from category label to an object exposing ``.shape``;
            the bar height is ``shape[0]`` (presumably NumPy arrays with
            samples along the first axis -- TODO confirm against caller).
        title: Text placed above the chart.
    """
    plt.title(title)
    labels = list(d.keys())
    counts = [samples.shape[0] for samples in d.values()]
    plt.bar(labels, counts)
    plt.show()
def plot_dataset_mean(d):
    """Show the mean image of every category in a grid of subplots.

    Args:
        d: Mapping from category name to an object supporting
            ``.mean(axis=0)``; presumably a stack of images with samples
            along the first axis (TODO confirm against the caller).

    The grid is 5 columns wide when there are more than 8 categories and 4
    columns wide otherwise. It has at least 2 rows and gains extra rows
    when needed so that no category is dropped. Cells that have no
    category are left blank.
    """
    cats = list(d.keys())
    col = 5 if len(cats) > 8 else 4
    # Keep the original 2-row layout; only add rows when the categories
    # would not otherwise fit (ceil division without importing math).
    rows = max(2, (len(cats) + col - 1) // col)
    _, ax = plt.subplots(rows, col, figsize=(15, 3 * rows))

    # ax.flat walks the grid row by row, so position idx is r * col + c.
    for idx, cell in enumerate(ax.flat):
        # Hide frame and ticks on every cell, including unused ones.
        cell.axis('off')
        if idx >= len(cats):
            continue
        name = cats[idx]
        # Cast to int so the averaged pixel values are drawn as integers.
        mean_img = d[name].mean(axis=0).astype(int)
        cell.set_title(name)
        cell.imshow(mean_img, cmap='gray')
    plt.show()