Skip to content

Commit

Permalink
add more dino models
Browse files Browse the repository at this point in the history
  • Loading branch information
zhoubin-me committed Feb 25, 2025
1 parent 10b8823 commit c8c20ef
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
28 changes: 22 additions & 6 deletions src/vision/dinov2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl DinoVisionTransformer {
Self { patch_embed, cls_token, pos_embed, blocks, norm, head }
}

fn interpolate_pos_encoding(&self, xs: &Tensor, w: i64, h: i64) -> Tensor {
pub fn interpolate_pos_encoding(&self, xs: &Tensor, w: i64, h: i64) -> Tensor {
let npatch = xs.size()[1] - 1;
let n = self.pos_embed.size()[1] - 1;
let sqrt_n = (n as f64).sqrt();
Expand All @@ -192,16 +192,14 @@ impl DinoVisionTransformer {
Tensor::cat(&[class_pos_embed, patch_pos_embed], 1)
}

fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Tensor {
pub fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Tensor {
let (b, _nc, w, h) = xs.size4().unwrap();
let xs = xs.apply(&self.patch_embed);
let xs = Tensor::concat(&[self.cls_token.expand([b, -1, -1], false), xs], 1);
&xs + &self.interpolate_pos_encoding(&xs, w, h)
}
}

impl nn::Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Tensor {
pub fn extract_features(&self, xs: &Tensor) -> Tensor {
let mut xs = self.prepare_tokens_with_mask(xs);
for blk in self.blocks.iter() {
xs = xs.apply(blk)
Expand All @@ -210,10 +208,28 @@ impl nn::Module for DinoVisionTransformer {
let xs_norm_clstoken = xs.i((.., 0));
let xs_norm_patchtokens = xs.i((.., 1..)).mean_dim(1, false, None);
let xs = Tensor::concat(&[xs_norm_clstoken, xs_norm_patchtokens], -1);
xs.apply(&self.head)
xs
}
}

/// `nn::Module` entry point: runs the backbone feature extractor and then
/// applies `head` (presumably the linear-classifier layer of the `_lc`
/// checkpoints — confirm against the export script).
impl nn::Module for DinoVisionTransformer {
    fn forward(&self, xs: &Tensor) -> Tensor {
        let features = self.extract_features(xs);
        features.apply(&self.head)
    }
}

/// Builds the DINOv2 ViT-S/14 backbone.
/// NOTE(review): the arguments are presumably (depth=12, embed_dim=384,
/// num_heads=6) — confirm against `DinoVisionTransformer::new`'s signature.
pub fn vit_small(vs: &nn::Path) -> DinoVisionTransformer {
DinoVisionTransformer::new(vs, 12, 384, 6)
}

/// Builds the DINOv2 ViT-B/14 backbone.
/// NOTE(review): the arguments are presumably (depth=12, embed_dim=768,
/// num_heads=12) — confirm against `DinoVisionTransformer::new`'s signature.
pub fn vit_base(vs: &nn::Path) -> DinoVisionTransformer {
DinoVisionTransformer::new(vs, 12, 768, 12)
}

/// Builds the DINOv2 ViT-L/14 backbone.
/// NOTE(review): the arguments are presumably (depth=24, embed_dim=1024,
/// num_heads=16) — confirm against `DinoVisionTransformer::new`'s signature.
pub fn vit_large(vs: &nn::Path) -> DinoVisionTransformer {
DinoVisionTransformer::new(vs, 24, 1024, 16)
}

/// Builds the DINOv2 ViT-g/14 backbone.
/// NOTE(review): the arguments are presumably (depth=40, embed_dim=1536,
/// num_heads=24) — confirm against `DinoVisionTransformer::new`'s signature.
pub fn vit_giant(vs: &nn::Path) -> DinoVisionTransformer {
DinoVisionTransformer::new(vs, 40, 1536, 24)
}
11 changes: 6 additions & 5 deletions src/vision/export_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ def normalize_key(k):
k = k[7:]
return k

dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc', layers=1)
print(dinov2_vits14)
weights = dinov2_vits14.state_dict()
weights = {normalize_key(k): v for k, v in weights.items()}
save_file(weights, "dinov2_vits14.safetensors")
# Export the linear-classifier ("_lc") DINOv2 checkpoints in every size to
# safetensors, one file per model.
for model_size in ["small", "base", "large", "giant"]:
    # Hub model ids use the size initial: vits14 / vitb14 / vitl14 / vitg14.
    letter = model_size[0]
    # Renamed from `dinov2_vits14`: that identifier was stale and misleading
    # on the base/large/giant iterations of this loop.
    # NOTE(review): network I/O — torch.hub downloads the checkpoint on first use.
    model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{letter}14_lc', layers=1)
    weights = {normalize_key(k): v for k, v in model.state_dict().items()}
    save_file(weights, f"dinov2_vit{letter}14.safetensors")

0 comments on commit c8c20ef

Please sign in to comment.