Skip to content

Commit

Permalink
feat: initial commit for RAIN-GS
Browse files Browse the repository at this point in the history
  • Loading branch information
ONground-Korea committed Mar 14, 2024
1 parent 8a7c59a commit f163e64
Show file tree
Hide file tree
Showing 1,575 changed files with 274,133 additions and 0 deletions.
Empty file modified LICENSE
100644 → 100755
Empty file.
6 changes: 6 additions & 0 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ We implement **RAIN-GS** above the official implementation of 3D Gaussian Splatt

To train 3D Gaussian Splatting with our novel strategy (**RAIN-GS**), all you need to do is:

```bash
python train.py -s {dataset_path} --exp_name {exp_name} --eval --ours
```

For dense-small-variance (DSV) initialization, you can simply run the following command:
```bash
python train.py -s {dataset_path} --exp_name {exp_name} --eval --DSV
```

To train with Mip-NeRF360 dataset, you can add argument `--images images_4` for outdoor scenes and `--images images_2` for indoor scenes to modify the resolution of the input images.

## Acknowledgement

Expand Down
102 changes: 102 additions & 0 deletions arguments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@

from argparse import ArgumentParser, Namespace
import sys
import os

class GroupParams:
pass

class ParamGroup:
def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
group = parser.add_argument_group(name)
for key, value in vars(self).items():
shorthand = False
if key.startswith("_"):
shorthand = True
key = key[1:]
t = type(value)
value = value if not fill_none else None
if shorthand:
if t == bool:
group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
else:
group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
else:
if t == bool:
group.add_argument("--" + key, default=value, action="store_true")
else:
group.add_argument("--" + key, default=value, type=t)

def extract(self, args):
group = GroupParams()
for arg in vars(args).items():
if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
setattr(group, arg[0], arg[1])
return group

class ModelParams(ParamGroup):
def __init__(self, parser, sentinel=False):
self.sh_degree = 3
self._source_path = ""
self._model_path = ""
self._images = "images"
self._resolution = -1
self._white_background = False
self.data_device = "cuda"
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

def extract(self, args):
g = super().extract(args)
g.source_path = os.path.abspath(g.source_path)
return g

class PipelineParams(ParamGroup):
def __init__(self, parser):
self.convert_SHs_python = False
self.compute_cov3D_python = False
self.debug = False
super().__init__(parser, "Pipeline Parameters")

class OptimizationParams(ParamGroup):
def __init__(self, parser):
self.iterations = 30_000
self.position_lr_init = 0.00016
self.position_lr_final = 0.0000016
self.position_lr_delay_mult = 0.01
self.position_lr_max_steps = 30_000
self.feature_lr = 0.0025
self.opacity_lr = 0.05
self.scaling_lr = 0.005
self.rotation_lr = 0.001
self.percent_dense = 0.01
self.lambda_dssim = 0.2
self.densification_interval = 100
self.opacity_reset_interval = 3000
self.densify_from_iter = 500
self.densify_until_iter = 15_000
self.densify_grad_threshold = 0.0002
self.random_background = False
super().__init__(parser, "Optimization Parameters")

def get_combined_args(parser : ArgumentParser):
cmdlne_string = sys.argv[1:]
cfgfile_string = "Namespace()"
args_cmdline = parser.parse_args(cmdlne_string)

try:
cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
print("Looking for config file in", cfgfilepath)
with open(cfgfilepath) as cfg_file:
print("Config file found: {}".format(cfgfilepath))
cfgfile_string = cfg_file.read()
except TypeError:
print("Config file not found at")
pass
args_cfgfile = eval(cfgfile_string)

merged_dict = vars(args_cfgfile).copy()
for k,v in vars(args_cmdline).items():
if v != None:
merged_dict[k] = v
return Namespace(**merged_dict)
Empty file modified assets/teaser.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 17 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: gaussian_splatting
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- cudatoolkit=11.6
- plyfile=0.8.1
- python=3.7.13
- pip=22.3.1
- pytorch=1.12.1
- torchaudio=0.12.1
- torchvision=0.13.1
- tqdm
- pip:
- submodules/diff-gaussian-rasterization
- submodules/simple-knn
79 changes: 79 additions & 0 deletions gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch
import math
import sys
sys.path.append('..')
from submodules.diff_gaussian_rasterization.diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, low_pass = 0.3):
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
try:
screenspace_points.retain_grad()
except:
pass


tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

raster_settings = GaussianRasterizationSettings(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=pc.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=pipe.debug,
low_pass=low_pass
)
rasterizer = GaussianRasterizer(raster_settings=raster_settings)

means3D = pc.get_xyz
means2D = screenspace_points
opacity = pc.get_opacity

scales = None
rotations = None
cov3D_precomp = None
if pipe.compute_cov3D_python:
cov3D_precomp = pc.get_covariance(scaling_modifier)
else:
scales = pc.get_scaling
rotations = pc.get_rotation

shs = None
colors_precomp = None
if override_color is None:
if pipe.convert_SHs_python:
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
else:
shs = pc.get_features
else:
colors_precomp = override_color

rendered_image, radii, depth = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
opacities = opacity,
scales = scales,
rotations = rotations,
cov3D_precomp = cov3D_precomp)


return {"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter" : radii > 0,
"radii": radii,
"depth": depth}
76 changes: 76 additions & 0 deletions gaussian_renderer/network_gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

import torch
import traceback
import socket
import json
from scene.cameras import MiniCam

host = "127.0.0.1"
port = 6009

conn = None
addr = None

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def init(wish_host, wish_port):
global host, port, listener
host = wish_host
port = wish_port
listener.bind((host, port))
listener.listen()
listener.settimeout(0)

def try_connect():
global conn, addr, listener
try:
conn, addr = listener.accept()
print(f"\nConnected by {addr}")
conn.settimeout(None)
except Exception as inst:
pass

def read():
global conn
messageLength = conn.recv(4)
messageLength = int.from_bytes(messageLength, 'little')
message = conn.recv(messageLength)
return json.loads(message.decode("utf-8"))

def send(message_bytes, verify):
global conn
if message_bytes != None:
conn.sendall(message_bytes)
conn.sendall(len(verify).to_bytes(4, 'little'))
conn.sendall(bytes(verify, 'ascii'))

def receive():
message = read()

width = message["resolution_x"]
height = message["resolution_y"]

if width != 0 and height != 0:
try:
do_training = bool(message["train"])
fovy = message["fov_y"]
fovx = message["fov_x"]
znear = message["z_near"]
zfar = message["z_far"]
do_shs_python = bool(message["shs_python"])
do_rot_scale_python = bool(message["rot_scale_python"])
keep_alive = bool(message["keep_alive"])
scaling_modifier = message["scaling_modifier"]
world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
world_view_transform[:,1] = -world_view_transform[:,1]
world_view_transform[:,2] = -world_view_transform[:,2]
full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
full_proj_transform[:,1] = -full_proj_transform[:,1]
custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
except Exception as e:
print("")
traceback.print_exc()
raise e
return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
else:
return None, None, None, None, None, None
12 changes: 12 additions & 0 deletions lpipsPyTorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch

from .modules.lpips import LPIPS


def lpips(x: torch.Tensor,
y: torch.Tensor,
net_type: str = 'alex',
version: str = '0.1'):
device = x.device
criterion = LPIPS(net_type, version).to(device)
return criterion(x, y)
28 changes: 28 additions & 0 deletions lpipsPyTorch/modules/lpips.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn

from .networks import get_network, LinLayers
from .utils import get_state_dict


class LPIPS(nn.Module):
def __init__(self, net_type: str = 'alex', version: str = '0.1'):

assert version in ['0.1'], 'v0.1 is only supported now'

super(LPIPS, self).__init__()


self.net = get_network(net_type)


self.lin = LinLayers(self.net.n_channels_list)
self.lin.load_state_dict(get_state_dict(net_type, version))

def forward(self, x: torch.Tensor, y: torch.Tensor):
feat_x, feat_y = self.net(x), self.net(y)

diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

return torch.sum(torch.cat(res, 0), 0, True)
Loading

0 comments on commit f163e64

Please sign in to comment.