-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_nima.py
61 lines (49 loc) · 1.59 KB
/
train_nima.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import importlib.util
import math
import os.path
from os import path
import time
from fastai.script import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from fastai.callbacks.tracker import *
torch.backends.cudnn.benchmark = True
import numpy as np
import pandas as pd
from torch import nn
# fastai transform pair: (training-time transforms, validation-time transforms).
# Only the training set is augmented.
tfms = (
    [
        flip_lr(p=0.5),
        brightness(change=(0.4, 0.6)),
        contrast(scale=(0.7, 1.3)),
    ],
    [],
)

# Map each image name (column `image_name`) to its raw score string
# (column `tags`); NimaLabelList.get looks labels up in this dict.
df = pd.read_csv('./labels_sample.csv')
scores_map = dict(zip(df.image_name, df.tags))


def func(o):
    """Label function: return the leading integer of the file name in path *o*.

    Example: './data/123.jpg' -> 123. The returned value is used as the key
    into ``scores_map``.
    """
    # basename works at any directory depth and with the OS path separator
    # fastai uses to build item paths; indexing split('/') assumed exactly
    # one folder level and '/' as separator.
    return int(os.path.basename(o).split('.')[0])


# Every known image name; used as the default `classes` of NimaLabelList.
labels = list(scores_map)
class NimaLabelList(CategoryList):
    """Label list whose targets are normalized score distributions.

    Each item is a key into the module-level ``scores_map``. The mapped value
    is a whitespace-separated string of per-bucket counts, which ``get``
    turns into a probability vector.
    """

    # Disable the parent's label processor so items stay raw keys into
    # scores_map (get() looks them up directly).
    _processor = None

    def __init__(self, items: Iterator, classes=labels, label_delim: str = None, **kwargs):
        # label_delim is accepted but not forwarded to CategoryList.
        super().__init__(items, classes=classes, **kwargs)

    def get(self, i):
        """Return the normalized score distribution for item *i* (float64 array).

        Raises:
            ValueError: if the item's counts sum to zero, since normalizing
                would yield NaN targets that poison the loss.
        """
        key = self.items[i]
        tags = scores_map[key]
        # split() (not split(' ')) tolerates repeated/trailing whitespace.
        # Float parsing accepts integer counts (same result as before) and
        # fractional weights alike.
        dist = np.asarray(str(tags).split(), dtype=float)
        total = dist.sum()
        if total == 0:
            raise ValueError(f'score distribution for item {key!r} sums to zero')
        return dist / total
# Build the DataBunch with the following steps:
# - list images from the CSV (./data/<name>.jpg);
# - make a random train/valid split;
# - label each image via `func` + NimaLabelList;
# - augment and resize to 224;
# - batch with bs=8.
data = (
    ImageList.from_csv('./', 'labels_sample.csv', folder='data', suffix='.jpg')
    .split_by_rand_pct()
    .label_from_func(func, label_cls=NimaLabelList)
    .transform(tfms, size=224)
    .databunch(bs=8)
    # The backbone is loaded with pretrained=True, so feed it inputs
    # normalized with the ImageNet statistics it was trained on.
    .normalize(imagenet_stats)
)

# One training batch, kept for a forward-pass sanity check later.
x, y = next(iter(data.train_dl))

# CategoryList derives `c` from len(classes) (here: the image names), which is
# not the head width we need. Use the width of the target distributions
# instead of hard-coding 10, so the model output always matches the labels.
data.c = y.shape[-1]
def emd(y, y_hat, r=2):
    """Earth Mover's Distance between batches of discrete distributions.

    Compares cumulative sums along the last axis and returns
    ``mean_batch((mean_k |CDF(y)_k - CDF(y_hat)_k| ** r) ** (1 / r))``.

    Args:
        y: tensor whose last axis holds per-bucket probabilities. When used
            as a fastai ``loss_func`` this argument receives the prediction,
            because fastai calls ``loss_func(pred, target)``.
        y_hat: tensor of the same shape holding the other distribution
            (the target, in that setting).
        r: norm exponent. The default of 2 gives the root-mean-square
            CDF difference.

    Returns:
        Scalar float64 tensor: the distance averaged over all leading
        (batch) dimensions.
    """
    # Cast both CDFs to float64 explicitly instead of relying on implicit
    # float32/float64 promotion. Predictions are typically float32, while
    # the numpy-built targets are float64.
    cdf_y = torch.cumsum(y, dim=-1).double()
    cdf_y_hat = torch.cumsum(y_hat, dim=-1).double()
    diff = torch.abs(cdf_y - cdf_y_hat)
    per_sample = torch.mean(torch.pow(diff, r), dim=-1)
    if r == 2:
        # Use sqrt for the default case to keep its exact numerics.
        per_sample = torch.sqrt(per_sample)
    else:
        per_sample = torch.pow(per_sample, 1.0 / r)
    return torch.mean(per_sample)
arch = models.mobilenet_v2
learn = cnn_learner(data, arch, pretrained=True)


def _emd_from_logits(out, target):
    """EMD loss on raw model outputs.

    fastai's default cnn_learner head ends in a Linear layer (no softmax),
    so *out* holds unnormalized logits. Softmax turns them into a
    probability distribution before `emd` compares cumulative sums with
    the target distribution.
    """
    return emd(torch.softmax(out, dim=-1), target)


# fastai calls loss_func(prediction, target).
learn.loss_func = _emd_from_logits

# Sanity-check a forward pass on the sample batch. Running in eval mode
# under no_grad avoids updating BatchNorm running statistics and building
# an autograd graph before training; training mode is restored afterwards.
learn.model.eval()
with torch.no_grad():
    y_hat = learn.model(x)
learn.model.train()