Skip to content

Commit

Permalink
#5 edit model name / fix model fc gender layer
Browse files Browse the repository at this point in the history
  • Loading branch information
“KimDaeYu” committed Mar 3, 2022
1 parent a417823 commit 9b06fc8
Showing 1 changed file with 114 additions and 6 deletions.
120 changes: 114 additions & 6 deletions Submission_codes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ def forward(self, x):
return x

# Custom Model Template
class Res183Ways(nn.Module):
class Res18_3MultiLabel(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.res18 = Res18(num_classes)
self.res18.load_state_dict(torch.load("./model/res18/best.pth"))
self.res18.load_state_dict(torch.load("/opt/ml/workspace/code/model/Reres18/best.pth"))
self.res18 = nn.Sequential(*list(self.res18.pretrain_model.children())[:-1])

self.mask = nn.Linear(512, 3, bias=True)
self.age = nn.Linear(512, 3, bias=True)
self.gender = nn.Linear(512, 3, bias=True)
self.gender = nn.Linear(512, 2, bias=True)

def dfs_freeze(model):
for name, child in model.named_children():
Expand All @@ -89,16 +89,16 @@ def forward(self, x):
return {"mask":m, "age":a, "gender":s}

# Custom Model Template
class Res503Ways(nn.Module):
class Res50_3MultiLabel(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.res50 = Res50(num_classes)
self.res50.load_state_dict(torch.load("./model/res50/best.pth"))
self.res50.load_state_dict(torch.load("/opt/ml/workspace/code/model/Reres50/best.pth"))
self.res50 = nn.Sequential(*list(self.res50.pretrain_model.children())[:-1])

self.mask = nn.Linear(2048, 3, bias=True)
self.age = nn.Linear(2048, 3, bias=True)
self.gender = nn.Linear(2048, 3, bias=True)
self.gender = nn.Linear(2048, 2, bias=True)

def dfs_freeze(model):
for name, child in model.named_children():
Expand All @@ -117,3 +117,111 @@ def forward(self, x):
a = self.age(x)
s = self.gender(x)
return {"mask":m, "age":a, "gender":s}



# Custom Model Template
class Res18_2MultiLabel(nn.Module):
def __init__(self, num_classes):
# super().__init__()
# self.res50 = Res50(num_classes)
# self.res50.load_state_dict(torch.load("./model/Reres50/best.pth"))
# self.res50 = nn.Sequential(*list(self.res50.pretrain_model.children())[:-1])
super().__init__()
self.res18 = Res18(num_classes)
self.res18.load_state_dict(torch.load("/opt/ml/workspace/code/model/Reres18/best.pth"))
self.res18 = nn.Sequential(*list(self.res18.pretrain_model.children())[:-1])

self.mask = nn.Linear(512, 3, bias=True)
self.mask.load_state_dict(torch.load("./resnext50_32x4dfc3ways_maskv3.pt"))

self.age = nn.Linear(524, 3, bias=True)
self.gender = nn.Linear(524, 2, bias=True)
def dfs_freeze(model):
for name, child in model.named_children():
for param in child.parameters():
#print(param)
param.requires_grad = False
#print(param)
dfs_freeze(child)
dfs_freeze(self.res18)

def forward(self, x):
x = self.res18.forward(x)
x = torch.flatten(x, start_dim=1)

m = self.mask(x)

pred_mask = torch.argmax(m, dim=-1).cpu().numpy()
nx = []
base = torch.ones(12)
for i,k in enumerate(pred_mask):
v = int(k)
if v == 0:
base = base * 0
elif v == 1:
base = base * 10
elif v == 2:
base = base * -10
tmp = torch.cat([x[i], self.base.to("cuda")])
nx.append(tmp)

nx = torch.stack(nx)
#print(nx.shape)
a = self.age(nx)
g = self.gender(nx)
return {"mask":m, "age":a, "gender":g}


# Custom Model Template
class Res50_M2MultiLabel(nn.Module):
def __init__(self, num_classes):
# super().__init__()
# self.res50 = Res50(num_classes)
# self.res50.load_state_dict(torch.load("./model/Reres50/best.pth"))
# self.res50 = nn.Sequential(*list(self.res50.pretrain_model.children())[:-1])
super().__init__()
self.res50 = Res50(num_classes)
self.res50.load_state_dict(torch.load("/opt/ml/workspace/code/model/Reres50/best.pth"))
self.res50 = nn.Sequential(*list(self.res50.pretrain_model.children())[:-1])

self.mask = nn.Linear(2048, 3, bias=True)
self.mask.load_state_dict(torch.load("./resnext50_32x4dfc3ways_maskv3.pt"))

self.age = nn.Linear(2060, 3, bias=True)
self.gender = nn.Linear(2060, 2, bias=True)

def dfs_freeze(model):
for name, child in model.named_children():
for param in child.parameters():
#print(param)
param.requires_grad = False
#print(param)
dfs_freeze(child)
dfs_freeze(self.res50)

def forward(self, x):
x = self.res50.forward(x)
x = torch.flatten(x, start_dim=1)

m = self.mask(x)

pred_mask = torch.argmax(m, dim=-1).cpu().numpy()
nx = []
self.base = torch.ones(12)
for i,k in enumerate(pred_mask):
v = int(k)
if v == 0:
self.base = self.base * 0
elif v == 1:
self.base = self.base * 10
elif v == 2:
self.base = self.base * -10
tmp = torch.cat([x[i], self.base.to("cuda")])
nx.append(tmp)

nx = torch.stack(nx)
#print(nx.shape)
a = self.age(nx)
g = self.gender(nx)
return {"mask":m, "age":a, "gender":g}

0 comments on commit 9b06fc8

Please sign in to comment.