Deploying to gh-pages from main @ 4b94064 🚀
benjijamorris committed Aug 14, 2024
1 parent c8b09e4 commit b2b084f
Showing 13 changed files with 91 additions and 21 deletions.
10 changes: 7 additions & 3 deletions _modules/cyto_dl/nn/vits/blocks/patchify.html
@@ -432,7 +432,7 @@ Source code for cyto_dl.nn.vits.blocks.patchify
 from einops.layers.torch import Rearrange, Reduce
 from timm.models.layers import trunc_normal_
 
-from cyto_dl.nn.vits.utils import take_indexes
+from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes
 
 
 def random_indexes(size: int, device):
@@ -453,6 +453,7 @@
         context_pixels: List[int] = [0, 0, 0],
         input_channels: int = 1,
         tasks: Optional[List[str]] = [],
+        learnable_pos_embedding: bool = True,
     ):
         """
         Parameters
@@ -471,12 +472,16 @@
             Number of input channels
         tasks: List[str]
             List of tasks to encode
+        learnable_pos_embedding: bool
+            If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images.
         """
         super().__init__()
         self.n_patches = np.asarray(n_patches)
         self.spatial_dims = spatial_dims
 
-        self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim))
+        self.pos_embedding = get_positional_embedding(
+            n_patches, emb_dim, learnable=learnable_pos_embedding, use_cls_token=False
+        )
 
         context_pixels = context_pixels[:spatial_dims]
         weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int)
@@ -538,7 +543,6 @@
         self._init_weight()
 
     def _init_weight(self):
-        trunc_normal_(self.pos_embedding, std=0.02)
        for task in self.task_embedding:
            trunc_normal_(self.task_embedding[task], std=0.02)
 
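The change above replaces a zero-initialized torch.nn.Parameter (plus its trunc_normal_ init) with a get_positional_embedding factory whose implementation is not part of this diff. As a rough illustration only, here is a minimal sketch of such a factory, assuming the standard 1D sin/cos formulation and matching just the call signature visible in the diff; everything beyond that signature is an assumption, not the actual cyto_dl.nn.vits.utils code:

import numpy as np
import torch


def get_positional_embedding(n_patches, emb_dim, learnable=True, use_cls_token=True):
    # Hypothetical stand-in for cyto_dl.nn.vits.utils.get_positional_embedding.
    # Assumes an even emb_dim, as is typical for transformer models.
    n_tokens = int(np.prod(n_patches)) + int(use_cls_token)  # optional extra slot for a cls token
    if learnable:
        # Learnable variant: zero-initialized parameter, trained with the rest of
        # the model (the pre-change code also applied trunc_normal_(..., std=0.02)).
        return torch.nn.Parameter(torch.zeros(n_tokens, 1, emb_dim))
    # Fixed variant: classic 1D sin/cos table, sine on even channels, cosine on
    # odd channels, at geometrically spaced frequencies.
    position = torch.arange(n_tokens, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, emb_dim, 2, dtype=torch.float32) * (-np.log(10000.0) / emb_dim)
    )
    pe = torch.zeros(n_tokens, 1, emb_dim)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    # requires_grad=False keeps the table fixed while letting it follow .to(device)
    return torch.nn.Parameter(pe, requires_grad=False)

Note that the List-valued n_patches in the diff implies a 2D/3D patch grid; the real helper presumably handles per-axis embeddings, while this sketch flattens the grid to a single axis for brevity.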
14 changes: 9 additions & 5 deletions _modules/cyto_dl/nn/vits/cross_mae.html
@@ -426,15 +426,14 @@
 Source code for cyto_dl.nn.vits.cross_mae
 from typing import List, Optional
 
-import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange
 from einops.layers.torch import Rearrange
 from timm.models.layers import trunc_normal_
 
 from cyto_dl.nn.vits.blocks import CrossAttentionBlock
-from cyto_dl.nn.vits.utils import take_indexes
+from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes
 
 
 class CrossMAE_Decoder(torch.nn.Module):
@@ -450,6 +449,7 @@
         emb_dim: Optional[int] = 192,
         num_layer: Optional[int] = 4,
         num_head: Optional[int] = 3,
+        learnable_pos_embedding: Optional[bool] = True,
     ) -> None:
         """
         Parameters
@@ -466,6 +466,8 @@
             Number of transformer layers
         num_head: int
             Number of heads in transformer
+        learnable_pos_embedding: bool
+            If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
         """
         super().__init__()
 
@@ -484,7 +486,10 @@
 
         self.projection = torch.nn.Linear(enc_dim, emb_dim)
         self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
-        self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))
+
+        self.pos_embedding = get_positional_embedding(
+            num_patches, emb_dim, learnable=learnable_pos_embedding
+        )
 
         self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
         self.num_patches = torch.as_tensor(num_patches)
@@ -511,8 +516,7 @@
         self.init_weight()
 
     def init_weight(self):
-        trunc_normal_(self.mask_token, std=0.02)
-        trunc_normal_(self.pos_embedding, std=0.02)
+        trunc_normal_(self.mask_token, std=0.02)
 
     def forward(self, features, forward_indexes, backward_indexes):
         # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers
