From a14bcb7d26796cbb8cac0dc31fc863c94fbfc391 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 13 Nov 2024 12:47:34 -0800 Subject: [PATCH] mlp folder --- examples/blur_opt.py | 54 +++------------------------------- examples/mlp/__init__.py | 2 ++ examples/mlp/encoder.py | 47 +++++++++++++++++++++++++++++ examples/{ => mlp}/external.py | 0 examples/{ => mlp}/mlp.py | 2 +- 5 files changed, 54 insertions(+), 51 deletions(-) create mode 100644 examples/mlp/__init__.py create mode 100644 examples/mlp/encoder.py rename examples/{ => mlp}/external.py (100%) rename examples/{ => mlp}/mlp.py (98%) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index e2fcc506d..ee529d6c3 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -2,7 +2,7 @@ import torch.nn as nn from torch import Tensor import torch.nn.functional as F -from examples.mlp import create_mlp +from examples.mlp import create_mlp, get_encoder from gsplat.utils import log_transform @@ -12,9 +12,9 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 4): super().__init__() self.embeds = torch.nn.Embedding(n, embed_dim) - self.means_encoder = get_encoder(3, 3) - self.depths_encoder = get_encoder(3, 1) - self.grid_encoder = get_encoder(1, 2) + self.means_encoder = get_encoder(num_freqs=3, input_dims=3) + self.depths_encoder = get_encoder(num_freqs=3, input_dims=1) + self.grid_encoder = get_encoder(num_freqs=1, input_dims=2) self.blur_mask_mlp = create_mlp( in_dim=embed_dim + self.depths_encoder.out_dim + self.grid_encoder.out_dim, num_layers=5, @@ -96,49 +96,3 @@ def loss_fn(x: Tensor): ys = loss_fn(xs) c = ys.min() return lambda x: loss_fn(x) - c - - -def get_encoder(num_freqs: int, input_dims: int): - kwargs = { - "include_input": True, - "input_dims": input_dims, - "max_freq_log2": num_freqs - 1, - "num_freqs": num_freqs, - "log_sampling": True, - "periodic_fns": [torch.sin, torch.cos], - } - encoder = Encoder(**kwargs) - return encoder - - -class Encoder: - def __init__(self, **kwargs): - self.kwargs = kwargs - self.create_embedding_fn() - - def create_embedding_fn(self): - embed_fns = [] - d = self.kwargs["input_dims"] - out_dim = 0 - if self.kwargs["include_input"]: - embed_fns.append(lambda x: x) - out_dim += d - - max_freq = self.kwargs["max_freq_log2"] - N_freqs = self.kwargs["num_freqs"] - - if self.kwargs["log_sampling"]: - freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) - else: - freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) - - for freq in freq_bands: - for p_fn in self.kwargs["periodic_fns"]: - embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) - out_dim += d - - self.embed_fns = embed_fns - self.out_dim = out_dim - - def encode(self, inputs): - return torch.cat([fn(inputs) for fn in self.embed_fns], -1) diff --git a/examples/mlp/__init__.py b/examples/mlp/__init__.py new file mode 100644 index 000000000..58f4382b4 --- /dev/null +++ b/examples/mlp/__init__.py @@ -0,0 +1,2 @@ +from .encoder import get_encoder +from .mlp import create_mlp diff --git a/examples/mlp/encoder.py b/examples/mlp/encoder.py new file mode 100644 index 000000000..188a25bf8 --- /dev/null +++ b/examples/mlp/encoder.py @@ -0,0 +1,47 @@ +import torch + + +def get_encoder(num_freqs: int, input_dims: int): + kwargs = { + "include_input": True, + "input_dims": input_dims, + "max_freq_log2": num_freqs - 1, + "num_freqs": num_freqs, + "log_sampling": True, + "periodic_fns": [torch.sin, torch.cos], + } + encoder = Encoder(**kwargs) + return encoder + + +class Encoder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs["input_dims"] + out_dim = 0 + if self.kwargs["include_input"]: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs["max_freq_log2"] + N_freqs = self.kwargs["num_freqs"] + + if self.kwargs["log_sampling"]: + freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs["periodic_fns"]: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def encode(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) diff --git a/examples/external.py b/examples/mlp/external.py similarity index 100% rename from examples/external.py rename to examples/mlp/external.py diff --git a/examples/mlp.py b/examples/mlp/mlp.py similarity index 98% rename from examples/mlp.py rename to examples/mlp/mlp.py index f5bc48acd..f68330bee 100644 --- a/examples/mlp.py +++ b/examples/mlp/mlp.py @@ -20,7 +20,7 @@ from torch import nn -from examples.external import TCNN_EXISTS, tcnn +from examples.mlp.external import TCNN_EXISTS, tcnn def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str: