
Commit

mlp folder
jefequien committed Nov 13, 2024
1 parent f355877 commit a14bcb7
Showing 5 changed files with 54 additions and 51 deletions.
54 changes: 4 additions & 50 deletions examples/blur_opt.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 from torch import Tensor
 import torch.nn.functional as F
-from examples.mlp import create_mlp
+from examples.mlp import create_mlp, get_encoder
 from gsplat.utils import log_transform


@@ -12,9 +12,9 @@ class BlurOptModule(nn.Module):
     def __init__(self, n: int, embed_dim: int = 4):
         super().__init__()
         self.embeds = torch.nn.Embedding(n, embed_dim)
-        self.means_encoder = get_encoder(3, 3)
-        self.depths_encoder = get_encoder(3, 1)
-        self.grid_encoder = get_encoder(1, 2)
+        self.means_encoder = get_encoder(num_freqs=3, input_dims=3)
+        self.depths_encoder = get_encoder(num_freqs=3, input_dims=1)
+        self.grid_encoder = get_encoder(num_freqs=1, input_dims=2)
         self.blur_mask_mlp = create_mlp(
             in_dim=embed_dim + self.depths_encoder.out_dim + self.grid_encoder.out_dim,
             num_layers=5,
@@ -96,49 +96,3 @@ def loss_fn(x: Tensor):
     ys = loss_fn(xs)
     c = ys.min()
     return lambda x: loss_fn(x) - c
-
-
-def get_encoder(num_freqs: int, input_dims: int):
-    kwargs = {
-        "include_input": True,
-        "input_dims": input_dims,
-        "max_freq_log2": num_freqs - 1,
-        "num_freqs": num_freqs,
-        "log_sampling": True,
-        "periodic_fns": [torch.sin, torch.cos],
-    }
-    encoder = Encoder(**kwargs)
-    return encoder
-
-
-class Encoder:
-    def __init__(self, **kwargs):
-        self.kwargs = kwargs
-        self.create_embedding_fn()
-
-    def create_embedding_fn(self):
-        embed_fns = []
-        d = self.kwargs["input_dims"]
-        out_dim = 0
-        if self.kwargs["include_input"]:
-            embed_fns.append(lambda x: x)
-            out_dim += d
-
-        max_freq = self.kwargs["max_freq_log2"]
-        N_freqs = self.kwargs["num_freqs"]
-
-        if self.kwargs["log_sampling"]:
-            freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
-        else:
-            freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
-
-        for freq in freq_bands:
-            for p_fn in self.kwargs["periodic_fns"]:
-                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
-                out_dim += d
-
-        self.embed_fns = embed_fns
-        self.out_dim = out_dim
-
-    def encode(self, inputs):
-        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
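
Note: with include_input=True, the width follows from create_embedding_fn as out_dim = input_dims * (1 + 2 * num_freqs), i.e. the identity term plus a sin and a cos per frequency band. A quick sanity check of the widths feeding blur_mask_mlp (a minimal sketch, assuming the post-commit layout where examples.mlp re-exports get_encoder):

from examples.mlp import get_encoder

depths_encoder = get_encoder(num_freqs=3, input_dims=1)
grid_encoder = get_encoder(num_freqs=1, input_dims=2)

assert depths_encoder.out_dim == 1 * (1 + 2 * 3)  # 7
assert grid_encoder.out_dim == 2 * (1 + 2 * 1)    # 6
# With embed_dim=4, blur_mask_mlp receives in_dim = 4 + 7 + 6 = 17.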
2 changes: 2 additions & 0 deletions examples/mlp/__init__.py
@@ -0,0 +1,2 @@
from .encoder import get_encoder
from .mlp import create_mlp
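
Note: the two re-exports keep callers on a single stable import path after the split into a package; a minimal check (illustrative):

from examples.mlp import create_mlp, get_encoder  # same names as before the move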
47 changes: 47 additions & 0 deletions examples/mlp/encoder.py
@@ -0,0 +1,47 @@
import torch


def get_encoder(num_freqs: int, input_dims: int):
    kwargs = {
        "include_input": True,
        "input_dims": input_dims,
        "max_freq_log2": num_freqs - 1,
        "num_freqs": num_freqs,
        "log_sampling": True,
        "periodic_fns": [torch.sin, torch.cos],
    }
    encoder = Encoder(**kwargs)
    return encoder


class Encoder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs["input_dims"]
        out_dim = 0
        if self.kwargs["include_input"]:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs["max_freq_log2"]
        N_freqs = self.kwargs["num_freqs"]

        if self.kwargs["log_sampling"]:
            freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs["periodic_fns"]:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def encode(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
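
Note: for reference, a short usage sketch of the moved encoder (the random batch is illustrative; the shape follows from create_embedding_fn above):

import torch

from examples.mlp.encoder import get_encoder

encoder = get_encoder(num_freqs=3, input_dims=3)
x = torch.rand(8, 3)                    # e.g. a batch of 3-D means
y = encoder.encode(x)                   # identity plus sin/cos at 2^0, 2^1, 2^2
assert y.shape == (8, encoder.out_dim)  # out_dim = 3 * (1 + 2 * 3) = 21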
examples/external.py → examples/mlp/external.py

File renamed without changes.
2 changes: 1 addition & 1 deletion examples/mlp.py → examples/mlp/mlp.py
@@ -20,7 +20,7 @@

 from torch import nn
 
-from examples.external import TCNN_EXISTS, tcnn
+from examples.mlp.external import TCNN_EXISTS, tcnn
 
 
 def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str:
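Note: activation_to_tcnn_string, whose signature closes this hunk, converts a PyTorch activation module to the string form used in tiny-cuda-nn network configs. Its body is collapsed in this diff, so the following is only a hypothetical sketch of such a mapping, not the repository's implementation:

from typing import Union

from torch import nn

def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str:
    # Hypothetical sketch: tiny-cuda-nn configs name activations as strings.
    if activation is None:
        return "None"
    if isinstance(activation, nn.ReLU):
        return "ReLU"
    if isinstance(activation, nn.Sigmoid):
        return "Sigmoid"
    raise NotImplementedError(f"Unsupported activation: {activation}")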
