Skip to content

Commit

Permalink
Merge branch 'jolibrain:master' into feat_discriminator_unet_mha
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 authored Nov 7, 2023
2 parents 2090e9c + 98a09fe commit 439d06e
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 14 deletions.
6 changes: 3 additions & 3 deletions data/online_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def crop_image(
if len(load_size) == 1:
load_size.extend(load_size)
old_size = img.size
img = F.resize(img, load_size)
img = F.resize(img, (load_size[1], load_size[0]))
new_size = img.size
ratio_x = img.size[0] / old_size[0]
ratio_y = img.size[1] / old_size[1]
Expand All @@ -59,7 +59,6 @@ def crop_image(
# Bbox file
f = open(bbox_path, "r")
else:

# bbox_img = np.array(Image.open(img_path))
import cv2

Expand Down Expand Up @@ -95,7 +94,6 @@ def crop_image(
print("%s does not describe a bbox" % line)

else:

cat = str(int(np.max(bbox_img)))

# Find the indices of non-zero elements in the image
Expand Down Expand Up @@ -169,6 +167,8 @@ def crop_image(
else:
raise ValueError("mask_delta value is incorrect.")
else:
if len(mask_delta) <= cat - 1:
raise ValueError("too few classes, can't find mask_delta value")
mask_delta_cat = mask_delta[cat - 1]
if isinstance(mask_delta[0][0], float):
if len(mask_delta_cat) == 1:
Expand Down
2 changes: 1 addition & 1 deletion docs/options.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Here are all the available options to call with `train.py`
| --D_norm | string | instance | instance normalization or batch normalization for D<br/><br/> **Values:** instance, batch, none |
| --D_proj_config_segformer | string | models/configs/segformer/segformer_config_b0.json | path to segformer configuration file |
| --D_proj_interp | int | -1 | whether to force projected discriminator interpolation to a value \> 224, -1 means no interpolation |
| --D_proj_network_type | string | efficientnet | projected discriminator architecture<br/><br/> **Values:** efficientnet, segformer, vitbase, vitsmall, vitsmall2, vitclip16, vitclip14, depth |
| --D_proj_network_type | string | efficientnet | projected discriminator architecture<br/><br/> **Values:** efficientnet, segformer, vitbase, vitsmall, vitsmall2, vitclip16, vitclip14, depth, dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14 |
| --D_proj_weight_segformer | string | models/configs/segformer/pretrain/segformer_mit-b0.pth | path to segformer weight |
| --D_spectral | flag | | whether to use spectral norm in the discriminator |
| --D_temporal_every | int | 4 | apply temporal discriminator every x steps |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/_static/openapi.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Discriminator
+----------------------------+-----------------+--------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| --D_proj_interp | int | -1 | whether to force projected discriminator interpolation to a value > 224, -1 means no interpolation |
+----------------------------+-----------------+--------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| --D_proj_network_type | string | efficientnet | projected discriminator architecture **Values:** efficientnet, segformer, vitbase, vitsmall, vitsmall2, vitclip16, vitclip14, depth |
| --D_proj_network_type | string | efficientnet | projected discriminator architecture **Values:** efficientnet, segformer, vitbase, vitsmall, vitsmall2, vitclip16, vitclip14, depth, dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14 |
+----------------------------+-----------------+--------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| --D_proj_weight_segformer | string | models/configs/segformer/pretrain/segformer_mit-b0.pth | path to segformer weight |
+----------------------------+-----------------+--------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
Expand Down
3 changes: 2 additions & 1 deletion models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ def compute_D_loss_generic(
else:
real = getattr(self, real_name)

loss = loss.compute_loss_D(netD, real, fake, fake_2)
with torch.cuda.amp.autocast(enabled=self.with_amp):
loss = loss.compute_loss_D(netD, real, fake, fake_2)
return loss

def compute_D_loss(self):
Expand Down
49 changes: 47 additions & 2 deletions models/modules/projected_d/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def _make_depth(model):
return model


def _make_dinov2(model):
    """Equip a DINOv2 model with its ``get_feats`` hook and return it.

    ``configure_get_feats_dinov2`` sets the hook on ``model`` in place, so
    the object returned here is the same instance that was passed in.
    """
    configure_get_feats_dinov2(model)
    return model


def configure_forward_network(net):
def forward(x):
out0 = net.layer0(x)
Expand Down Expand Up @@ -144,7 +149,6 @@ def get_feats(x):

def configure_get_feats_depth(net):
def get_feats(x):

x = net.transform(x)

if net.channels_last == True:
Expand Down Expand Up @@ -184,6 +188,23 @@ def get_feats(x):
net.get_feats = get_feats


def configure_get_feats_dinov2(net):
    """Attach a ``get_feats`` callable to a DINOv2 backbone.

    The installed ``net.get_feats(x)`` forwards ``x`` to
    ``net.get_intermediate_layers`` and returns its result unchanged, asking
    for four intermediate blocks and no class token.

    Which four blocks are tapped depends on the depth of the backbone, read
    from ``net.n_blocks``. Previously the per-model table was built but never
    consulted and every backbone was tapped at ``[2, 5, 8, 11]``, so the
    24- and 40-block variants only ever exposed their earliest blocks.
    Backbones whose depth is missing or not listed keep the historical
    ``[2, 5, 8, 11]`` taps.

    Args:
        net: DINOv2 model exposing ``get_intermediate_layers``; it is
            modified in place (a ``get_feats`` attribute is set).
    """
    # Tap indices keyed by transformer depth rather than by hub name, because
    # ``net`` does not carry the name it was loaded under.
    # NOTE(review): assumes ViT-S/14 and ViT-B/14 are both 12 blocks deep,
    # ViT-L/14 is 24 and ViT-g/14 is 40 -- confirm against the DINOv2 hub
    # definitions.
    layers_by_depth = {
        12: [2, 5, 8, 11],  # dinov2_vits14, dinov2_vitb14
        24: [4, 10, 16, 23],  # dinov2_vitl14
        40: [6, 16, 26, 39],  # dinov2_vitg14
    }
    default_layers = [2, 5, 8, 11]

    # NOTE(review): relies on the backbone publishing its depth as
    # ``n_blocks``; when it does not, fall back to the original taps.
    depth = getattr(net, "n_blocks", None)
    layers = layers_by_depth.get(depth, default_layers)

    def get_feats(x):
        # Pass a copy so the callee can never mutate the shared tap list.
        return net.get_intermediate_layers(
            x, n=list(layers), return_class_token=False
        )

    net.get_feats = get_feats


def calc_channels(pretrained, inp_res=224):
channels = []
feats = []
Expand Down Expand Up @@ -216,8 +237,12 @@ def create_clip_model(model_name, config_path, weight_path, img_size):
return model[0].visual.float().cpu()


def create_segformer_model(model_name, config_path, weight_path, img_size):
def create_dinov2_model(model_name, config_path, weight_path, img_size):
    """Fetch a DINOv2 backbone through ``torch.hub``.

    Only ``model_name`` is used: it is passed to ``torch.hub.load`` as the
    entry point of the ``facebookresearch/dinov2`` repository. The remaining
    parameters are accepted so the signature matches the other
    ``create_*_model`` factories in this module, and are ignored.

    Args:
        model_name: hub entry point to load, e.g. ``"dinov2_vits14"``.
        config_path: unused.
        weight_path: unused.
        img_size: unused.

    Returns:
        The object returned by ``torch.hub.load``.
    """
    return torch.hub.load("facebookresearch/dinov2", model_name)


def create_segformer_model(model_name, config_path, weight_path, img_size):
cfg = load_config_file(config_path)
try:
weights = torch.jit.load(weight_path).state_dict()
Expand Down Expand Up @@ -303,6 +328,26 @@ def create_depth_model(model_name, config_path, weight_path, img_size):
"create_model_function": create_depth_model,
"make_function": _make_depth,
},
"dinov2_vits14": {
"model_name": "dinov2_vits14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
"dinov2_vitb14": {
"model_name": "dinov2_vitb14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
"dinov2_vitl14": {
"model_name": "dinov2_vitl14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
"dinov2_vitg14": {
"model_name": "dinov2_vitg14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
}


Expand Down
4 changes: 4 additions & 0 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ def initialize(self, parser):
"vitclip16",
"vitclip14",
"depth",
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vitg14",
],
help="projected discriminator architecture",
)
Expand Down
3 changes: 3 additions & 0 deletions scripts/gen_single_image_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ def generate(
"""if crop_width > 0 and crop_height > 0:
mask = resize(mask).clone().detach()"""
if ref is not None:
ref = cv2.resize(
ref, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC
)
ref_tensor = tran(ref).clone().detach()

if not cpu:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_run_diffusion_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
"data_online_creation_mask_delta_A_ratio": [[0.2, 0.2]],
"data_online_creation_crop_size_B": 420,
"data_online_creation_crop_delta_B": 50,
"data_online_creation_load_size_A": [2500, 1000],
"data_online_creation_load_size_B": [2500, 1000],
"data_online_creation_load_size_A": [1000, 2500],
"data_online_creation_load_size_B": [1000, 2500],
"train_n_epochs": 1,
"train_n_epochs_decay": 0,
"data_max_dataset_size": 10,
Expand Down
11 changes: 8 additions & 3 deletions tests/test_run_semantic_mask_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
"data_online_creation_mask_delta_A_ratio": [[0.2, 0.2]],
"data_online_creation_crop_size_B": 420,
"data_online_creation_crop_delta_B": 50,
"data_online_creation_load_size_A": [2500, 1000],
"data_online_creation_load_size_B": [2500, 1000],
"data_online_creation_load_size_A": [1000, 2500],
"data_online_creation_load_size_B": [1000, 2500],
"data_online_context_pixels": 0,
"train_n_epochs": 1,
"train_n_epochs_decay": 0,
Expand Down Expand Up @@ -59,7 +59,7 @@

G_netG = ["mobile_resnet_attn", "segformer_attn_conv"]

D_proj_network_type = ["efficientnet", "vitsmall"]
D_proj_network_type = ["efficientnet", "vitsmall", "dinov2_vits14"]

D_netDs = [
["basic", "projected_d"],
Expand All @@ -73,6 +73,8 @@

data_online_context_pixels = [0, 10]

with_amp = [False, True]

product_list = product(
models_semantic_mask,
G_netG,
Expand All @@ -81,6 +83,7 @@
f_s_net,
model_type_sam,
data_online_context_pixels,
with_amp,
)


Expand All @@ -96,6 +99,7 @@ def test_semantic_mask_online(dataroot):
f_s_type,
sam_type,
data_online_context_pixels,
with_amp,
) in product_list:
if model == "cycle_gan":
if (
Expand All @@ -113,6 +117,7 @@ def test_semantic_mask_online(dataroot):
json_like_dict_c["f_s_net"] = f_s_type
json_like_dict_c["model_type_sam"] = sam_type
json_like_dict_c["data_online_context_pixels"] = data_online_context_pixels
json_like_dict_c["with_amp"] = with_amp

opt = TrainOptions().parse_json(json_like_dict_c, save_config=True)
train.launch_training(opt)

0 comments on commit 439d06e

Please sign in to comment.