From 7fcf790661e64d458238c2ac68f440098caa0bae Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Wed, 8 Nov 2023 14:59:26 +0000 Subject: [PATCH] feat(ml): dinov2 discriminator with registers --- models/modules/projected_d/projector.py | 28 ++++++++++++++++++++++++- options/base_options.py | 4 ++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/models/modules/projected_d/projector.py b/models/modules/projected_d/projector.py index 1f14a8fb9..4d71bceda 100644 --- a/models/modules/projected_d/projector.py +++ b/models/modules/projected_d/projector.py @@ -194,6 +194,10 @@ def configure_get_feats_dinov2(net): "dinov2_vitb14": [3, 8, 12, 17], "dinov2_vitl14": [4, 10, 16, 23], "dinov2_vitg14": [6, 16, 26, 39], + "dinov2_vits14_reg": [2, 5, 8, 11], + "dinov2_vitb14_reg": [3, 8, 12, 17], + "dinov2_vitl14_reg": [4, 10, 16, 23], + "dinov2_vitg14_reg": [6, 16, 26, 39], } def get_feats(x): @@ -238,7 +242,9 @@ def create_clip_model(model_name, config_path, weight_path, img_size): def create_dinov2_model(model_name, config_path, weight_path, img_size): - dinov2_model = torch.hub.load("facebookresearch/dinov2", model_name) + dinov2_model = torch.hub.load( + "facebookresearch/dinov2", model_name, force_reload=True + ) return dinov2_model @@ -348,6 +354,26 @@ def create_depth_model(model_name, config_path, weight_path, img_size): "create_model_function": create_dinov2_model, "make_function": _make_dinov2, }, + "dinov2_vits14_reg": { + "model_name": "dinov2_vits14_reg", + "create_model_function": create_dinov2_model, + "make_function": _make_dinov2, + }, + "dinov2_vitb14_reg": { + "model_name": "dinov2_vitb14_reg", + "create_model_function": create_dinov2_model, + "make_function": _make_dinov2, + }, + "dinov2_vitl14": { + "model_name": "dinov2_vitl14_reg", + "create_model_function": create_dinov2_model, + "make_function": _make_dinov2, + }, + "dinov2_vitg14_reg": { + "model_name": "dinov2_vitg14_reg", + "create_model_function": create_dinov2_model, + "make_function": _make_dinov2, + }, } diff --git a/options/base_options.py b/options/base_options.py index c0a44e0dd..a72a20070 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -447,6 +447,10 @@ def initialize(self, parser): "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14", + "dinov2_vits14_reg", + "dinov2_vitb14_reg", + "dinov2_vitl14_reg", + "dinov2_vitg14_reg", ], help="projected discriminator architecture", )