Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
dyabel committed Jul 13, 2022
1 parent e72aa54 commit 49d6f35
Show file tree
Hide file tree
Showing 46 changed files with 3,774 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
.history
.vscode
changelog.py
*.pyc
*.pth
Empty file added MODEL/__init__.py
Empty file.
1 change: 1 addition & 0 deletions MODEL/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .defaults import _C as cfg
132 changes: 132 additions & 0 deletions MODEL/config/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os
from pickle import FALSE

from yacs.config import CfgNode as CN

# Default configuration tree, built with yacs ``CfgNode``.
# MODEL/config/__init__.py re-exports this object as ``cfg``.
_C = CN()

_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"

_C.MODEL.META_ARCHITECTURE = "GEMModel"
_C.MODEL.NAME = ""

_C.MODEL.WEIGHT = ""

_C.MODEL.SCALE = 20.0
_C.MODEL.SCALE_SEMANTIC = 25.0

_C.MODEL.ATTEN_THR = 9.0
_C.MODEL.HID = 512
_C.MODEL.ORTH=False




# -----------------------------------------------------------------------------
# Backbone, ResNet101
# -----------------------------------------------------------------------------
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.PRETRAINED = True


# -----------------------------------------------------------------------------
# Attention
# -----------------------------------------------------------------------------
_C.MODEL.ATTENTION = CN()
_C.MODEL.ATTENTION.MODE = 'add' # 'add', 'concat'
_C.MODEL.ATTENTION.CHANNEL = 512
_C.MODEL.ATTENTION.WEIGHT_SHARED = True
# NOTE(review): "Attrbute" looks like a typo for "Attribute"; confirm against
# the real directory name on disk before changing this path.
_C.MODEL.ATTENTION.W2V_PATH = "datasets/Attrbute/w2v"

# -----------------------------------------------------------------------------
# Loss
# -----------------------------------------------------------------------------
_C.MODEL.LOSS = CN()
_C.MODEL.LOSS.LAMBDA1 = 1.
_C.MODEL.LOSS.LAMBDA2 = 1.
_C.MODEL.LOSS.LAMBDA3 = 1.
_C.MODEL.LOSS.LAMBDA4 = 1.
_C.MODEL.LOSS.LAMBDA5 = 1.
_C.MODEL.LOSS.LAMBDA6 = 1.
_C.MODEL.LOSS.LAMBDA7 = 1.
# NOTE: this key lives on MODEL, not MODEL.LOSS, even though it is declared
# in the middle of the loss section.
_C.MODEL.RESUME_FROM = None

_C.MODEL.LOSS.TEMP = 0.07
_C.MODEL.LOSS.MARGIN = 0.8
_C.MODEL.LOSS.ALPHA = 0.5
_C.MODEL.LOSS.BETA = 0.



# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# build_dataloader uses NAME both as the dataset sub-directory and to choose
# how image paths are rewritten; it has branches for 'CUB', 'AwA2', 'SUN'
# and 'APY'.
_C.DATASETS.NAME = "CUB"
# Passed as ``size`` to data_transform for both the train and test transforms.
_C.DATASETS.IMAGE_SIZE = 224
# In 'random' mode the training batch size is WAYS * SHOTS; in 'episode' mode
# both values are handed to the category samplers.
_C.DATASETS.WAYS = 16
_C.DATASETS.SHOTS = 4

# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
# NOTE(review): build_dataloader in MODEL/data/build.py does not read this
# key; it hard-codes num_workers (8 for training, 4 for testing).
_C.DATALOADER.NUM_WORKERS = 4
# N_BATCH and EP_PER_BATCH are only read in 'episode' mode, where they are
# passed to CategoriesSampler / DCategoriesSampler.
_C.DATALOADER.N_BATCH = 1000
_C.DATALOADER.EP_PER_BATCH = 1
# Selects RandDataset ('random') or EpiDataset ('episode') for training.
_C.DATALOADER.MODE = 'random' # random, episode


# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
_C.SOLVER.MAX_EPOCH = 100

_C.SOLVER.BASE_LR = 1e-3
_C.SOLVER.BIAS_LR_FACTOR = 2


_C.SOLVER.WEIGHT_DECAY = 1e-5
_C.SOLVER.WEIGHT_DECAY_BIAS = 0

_C.SOLVER.MOMENTUM = 0.9

_C.SOLVER.GAMMA = 0.5
_C.SOLVER.STEPS = 10

_C.SOLVER.CHECKPOINT_PERIOD = 50

# Name of the training-time transform; passed to data_transform.
_C.SOLVER.DATA_AUG = "resize_random_crop"

_C.SOLVER.RESUME_OPTIM = False
_C.SOLVER.RESUME_SCHED = False

# ---------------------------------------------------------------------------- #
# Specific test options
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
# Batch size used by both the seen and unseen test loaders.
_C.TEST.IMS_PER_BATCH = 100
# Name of the test-time transform; passed to data_transform.
_C.TEST.DATA_AUG = "resize_crop"
_C.TEST.GAMMA = 0.7


# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
_C.OUTPUT_DIR = "."
_C.LOG_FILE_NAME = ""
_C.MODEL_FILE_NAME = ""
_C.PRETRAINED_MODELS = "./pretrained_models"


# ---------------------------------------------------------------------------- #
# Precision options
# ---------------------------------------------------------------------------- #

# Precision of input, allowable: (float32, float16)
_C.DTYPE = "float32"

_C.PREFIX=''
1 change: 1 addition & 0 deletions MODEL/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .build import build_dataloader
213 changes: 213 additions & 0 deletions MODEL/data/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from os.path import join

import torch
from torch.utils.data import DataLoader
import numpy as np

from scipy import io
from sklearn import preprocessing

from .random_dataset import RandDataset
from .episode_dataset import EpiDataset, CategoriesSampler, DCategoriesSampler
from .test_dataset import TestDataset

from .transforms import data_transform

from MODEL.utils.comm import get_world_size
import numpy
import random
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Intended for use as a ``worker_init_fn`` of a ``DataLoader``. The seed is
    ``torch.initial_seed()`` reduced modulo 2**32 so that it fits the range
    accepted by ``numpy.random.seed``.

    Args:
        worker_id: Worker index supplied by the caller; not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    for apply_seed in (numpy.random.seed, random.seed):
        apply_seed(seed)
class ImgDatasetParam(object):
    """Static lookup of dataset directories and ``.mat`` file stems."""

    # Template shared by every dataset. Treated as read-only: ``get`` works on
    # a copy so that repeated calls always start from these values.
    DATASETS = {
        "imgroot": 'datasets',
        "dataroot": 'datasets/Data',
        "image_embedding": 'res101',
        "class_embedding": 'att'
    }

    @staticmethod
    def get(dataset):
        """Return the parameter dict for ``dataset``.

        Args:
            dataset: Dataset name; it is appended to the template ``imgroot``
                and also stored under the ``"dataset"`` key.

        Returns:
            A new dict with the keys ``dataset``, ``imgroot``, ``dataroot``,
            ``image_embedding`` and ``class_embedding``.
        """
        # Copy first: writing into the class-level dict would make every
        # later call join onto an already-joined path ("datasets/CUB/CUB").
        attrs = dict(ImgDatasetParam.DATASETS)
        attrs["imgroot"] = join(attrs["imgroot"], dataset)
        args = dict(
            dataset=dataset
        )
        args.update(attrs)
        return args

def build_dataloader(cfg, is_distributed=False):
    """Build the training loader and the unseen/seen test loaders.

    Loads ``<dataroot>/<dataset>/<image_embedding>.mat`` (image file list and
    labels) and ``<dataroot>/<dataset>/<class_embedding>_splits.mat`` (split
    indices, class names and attributes), rewrites the stored image paths so
    that they live under ``imgroot``, and remaps class labels to contiguous
    ids: training classes first, unseen test classes after them.

    Args:
        cfg: Config node. Reads ``DATASETS.NAME/WAYS/SHOTS/IMAGE_SIZE``,
            ``DATALOADER.MODE/N_BATCH/EP_PER_BATCH``, ``SOLVER.DATA_AUG`` and
            ``TEST.DATA_AUG/IMS_PER_BATCH``.
        is_distributed: When true, distributed samplers are used for the
            training and test loaders.

    Returns:
        Tuple ``(tr_dataloader, tu_loader, ts_loader, res)`` where ``res`` is
        a dict holding the remapped labels, attribute tensors, class ids and
        class names.

    Raises:
        ValueError: If ``cfg.DATALOADER.MODE`` is neither ``'random'`` nor
            ``'episode'``.
    """
    # Fail fast on a bad mode. Previously an unknown mode fell through both
    # branches and surfaced as an UnboundLocalError at the return statement,
    # after all data had already been loaded.
    mode = cfg.DATALOADER.MODE
    if mode not in ('random', 'episode'):
        raise ValueError(
            "Unsupported cfg.DATALOADER.MODE {!r}; expected 'random' or 'episode'".format(mode)
        )

    args = ImgDatasetParam.get(cfg.DATASETS.NAME)
    imgroot = args['imgroot']
    dataroot = args['dataroot']
    image_embedding = args['image_embedding']
    class_embedding = args['class_embedding']
    dataset = args['dataset']

    matcontent = io.loadmat(dataroot + "/" + dataset + "/" + image_embedding + ".mat")

    # The .mat file stores paths from another machine; drop a dataset-specific
    # number of leading components and re-root the remainder under imgroot.
    img_files = np.squeeze(matcontent['image_files'])
    new_img_files = []
    for img_file in img_files:
        img_path = img_file[0]
        if dataset == 'CUB':
            img_path = join(imgroot, '/'.join(img_path.split('/')[5:]))
        elif dataset == 'AwA2':
            eff_path = img_path.split('/')[5:]
            # Removes the first empty component (e.g. from a doubled slash);
            # raises ValueError if there is none, as before.
            eff_path.remove('')
            img_path = join(imgroot, '/'.join(eff_path))
        elif dataset in ('SUN', 'APY'):
            img_path = join(imgroot, '/'.join(img_path.split('/')[7:]))
        # Any other dataset name keeps the stored path unchanged.
        new_img_files.append(img_path)

    new_img_files = np.array(new_img_files)
    # Labels and split locations are stored 1-based; shift to 0-based.
    label = matcontent['labels'].astype(int).squeeze() - 1
    matcontent = io.loadmat(dataroot + "/" + dataset + "/" + class_embedding + "_splits.mat")
    trainvalloc = matcontent['trainval_loc'].squeeze() - 1
    test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1
    test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1

    att_name = 'att'
    cls_name = matcontent['allclasses_names']

    attribute = matcontent[att_name].T

    train_img = new_img_files[trainvalloc]
    train_label = label[trainvalloc].astype(int)
    # Per-sample attributes must be looked up with the ORIGINAL class ids,
    # i.e. before train_label is remapped below.
    train_att = attribute[train_label]

    train_id, idx = np.unique(train_label, return_inverse=True)
    train_att_unique = attribute[train_id]
    train_clsname = cls_name[train_id]

    # Remap training labels to 0..num_train-1.
    num_train = len(train_id)
    train_label = idx
    train_id = np.unique(train_label)

    test_img_unseen = new_img_files[test_unseen_loc]
    test_label_unseen = label[test_unseen_loc].astype(int)
    test_id, idx = np.unique(test_label_unseen, return_inverse=True)
    att_unseen = attribute[test_id]
    test_clsname = cls_name[test_id]
    # Unseen classes are numbered after the training classes.
    test_label_unseen = idx + num_train
    test_id = np.unique(test_label_unseen)
    train_test_att = np.concatenate((train_att_unique, att_unseen))
    train_test_id = np.concatenate((train_id, test_id))

    test_img_seen = new_img_files[test_seen_loc]
    test_label_seen = label[test_seen_loc].astype(int)
    _, idx = np.unique(test_label_seen, return_inverse=True)
    test_label_seen = idx

    att_unseen = torch.from_numpy(att_unseen).float()
    test_label_seen = torch.tensor(test_label_seen)
    test_label_unseen = torch.tensor(test_label_unseen)
    train_label = torch.tensor(train_label)
    att_seen = torch.from_numpy(train_att_unique).float()

    res = {
        'train_label': train_label,
        'train_att': train_att,
        'test_label_seen': test_label_seen,
        'test_label_unseen': test_label_unseen,
        'att_unseen': att_unseen,
        'att_seen': att_seen,
        'train_id': train_id,
        'test_id': test_id,
        'train_test_id': train_test_id,
        'train_clsname': train_clsname,
        'test_clsname': test_clsname
    }
    # NOTE(review): the result is not used below; the call is kept in case
    # get_world_size has effects that are not visible from this file.
    num_gpus = get_world_size()

    # train dataloader
    ways = cfg.DATASETS.WAYS
    shots = cfg.DATASETS.SHOTS
    data_aug_train = cfg.SOLVER.DATA_AUG
    img_size = cfg.DATASETS.IMAGE_SIZE
    transforms = data_transform(data_aug_train, size=img_size)
    # NOTE(review): num_workers is hard-coded here (8 for training, 4 for
    # testing); cfg.DATALOADER.NUM_WORKERS is not consulted.
    if mode == 'random':
        train_set = RandDataset(train_img, train_att, train_label, transforms)
        batch = ways * shots

        if not is_distributed:
            sampler = torch.utils.data.sampler.RandomSampler(train_set)
            batch_sampler = torch.utils.data.sampler.BatchSampler(
                sampler, batch_size=batch, drop_last=True)
            tr_dataloader = torch.utils.data.DataLoader(
                dataset=train_set,
                num_workers=8,
                batch_sampler=batch_sampler,
            )
        else:
            sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
            tr_dataloader = torch.utils.data.DataLoader(
                train_set, batch_size=batch, sampler=sampler,
                num_workers=8, worker_init_fn=seed_worker)

    elif mode == 'episode':
        n_batch = cfg.DATALOADER.N_BATCH
        ep_per_batch = cfg.DATALOADER.EP_PER_BATCH
        train_set = EpiDataset(train_img, train_att, train_label, transforms)
        if not is_distributed:
            sampler = CategoriesSampler(
                train_label,
                n_batch,
                ways,
                shots,
                ep_per_batch
            )
        else:
            sampler = DCategoriesSampler(
                train_label,
                n_batch,
                ways,
                shots,
                ep_per_batch
            )
        tr_dataloader = DataLoader(
            dataset=train_set, batch_sampler=sampler, num_workers=8,
            pin_memory=True, worker_init_fn=seed_worker)

    data_aug_test = cfg.TEST.DATA_AUG
    transforms = data_transform(data_aug_test, size=img_size)
    test_batch_size = cfg.TEST.IMS_PER_BATCH

    if not is_distributed:
        # test unseen dataloader
        tu_data = TestDataset(test_img_unseen, test_label_unseen, transforms)
        tu_loader = torch.utils.data.DataLoader(
            tu_data, batch_size=test_batch_size, shuffle=False,
            num_workers=4, pin_memory=False)

        # test seen dataloader
        ts_data = TestDataset(test_img_seen, test_label_seen, transforms)
        ts_loader = torch.utils.data.DataLoader(
            ts_data, batch_size=test_batch_size, shuffle=False,
            num_workers=4, pin_memory=False)
    else:
        # test unseen dataloader
        tu_data = TestDataset(test_img_unseen, test_label_unseen, transforms)
        tu_sampler = torch.utils.data.distributed.DistributedSampler(dataset=tu_data, shuffle=False)
        tu_loader = torch.utils.data.DataLoader(
            tu_data, batch_size=test_batch_size, sampler=tu_sampler,
            num_workers=4, pin_memory=False)

        # test seen dataloader
        ts_data = TestDataset(test_img_seen, test_label_seen, transforms)
        ts_sampler = torch.utils.data.distributed.DistributedSampler(dataset=ts_data, shuffle=False)
        ts_loader = torch.utils.data.DataLoader(
            ts_data, batch_size=test_batch_size, sampler=ts_sampler,
            num_workers=4, pin_memory=False)

    return tr_dataloader, tu_loader, ts_loader, res

2 changes: 2 additions & 0 deletions MODEL/data/episode_dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .dataset import EpiDataset
from .samplers import CategoriesSampler, DCategoriesSampler
29 changes: 29 additions & 0 deletions MODEL/data/episode_dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
import torch.utils.data as data

import numpy as np
from PIL import Image

class EpiDataset(data.Dataset):
    """Map-style dataset yielding ``(image, attribute, label)`` triples.

    Args:
        img_path: Indexable collection of image file paths.
        atts: Per-sample attribute values; converted to a float tensor.
        labels: Per-sample class labels; converted to a long tensor.
        transforms: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, img_path, atts, labels, transforms=None):
        self.transforms = transforms
        self.img_path = img_path

        attribute_tensor = torch.tensor(atts)
        self.atts = attribute_tensor.float()

        label_tensor = torch.tensor(labels)
        self.labels = label_tensor.long()

        # Sorted distinct label values.
        self.classes = np.unique(labels)

    def __len__(self):
        """Return the number of samples (length of the label tensor)."""
        return self.labels.size(0)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it with its attribute and label."""
        image = Image.open(self.img_path[index]).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)

        target = self.labels[index]
        attribute = self.atts[index]
        return image, attribute, target
Loading

0 comments on commit 49d6f35

Please sign in to comment.