-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8a7c59a
commit f163e64
Showing
1,575 changed files
with
274,133 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
|
||
from argparse import ArgumentParser, Namespace | ||
import sys | ||
import os | ||
|
||
class GroupParams: | ||
pass | ||
|
||
class ParamGroup: | ||
def __init__(self, parser: ArgumentParser, name : str, fill_none = False): | ||
group = parser.add_argument_group(name) | ||
for key, value in vars(self).items(): | ||
shorthand = False | ||
if key.startswith("_"): | ||
shorthand = True | ||
key = key[1:] | ||
t = type(value) | ||
value = value if not fill_none else None | ||
if shorthand: | ||
if t == bool: | ||
group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") | ||
else: | ||
group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) | ||
else: | ||
if t == bool: | ||
group.add_argument("--" + key, default=value, action="store_true") | ||
else: | ||
group.add_argument("--" + key, default=value, type=t) | ||
|
||
def extract(self, args): | ||
group = GroupParams() | ||
for arg in vars(args).items(): | ||
if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): | ||
setattr(group, arg[0], arg[1]) | ||
return group | ||
|
||
class ModelParams(ParamGroup): | ||
def __init__(self, parser, sentinel=False): | ||
self.sh_degree = 3 | ||
self._source_path = "" | ||
self._model_path = "" | ||
self._images = "images" | ||
self._resolution = -1 | ||
self._white_background = False | ||
self.data_device = "cuda" | ||
self.eval = False | ||
super().__init__(parser, "Loading Parameters", sentinel) | ||
|
||
def extract(self, args): | ||
g = super().extract(args) | ||
g.source_path = os.path.abspath(g.source_path) | ||
return g | ||
|
||
class PipelineParams(ParamGroup): | ||
def __init__(self, parser): | ||
self.convert_SHs_python = False | ||
self.compute_cov3D_python = False | ||
self.debug = False | ||
super().__init__(parser, "Pipeline Parameters") | ||
|
||
class OptimizationParams(ParamGroup): | ||
def __init__(self, parser): | ||
self.iterations = 30_000 | ||
self.position_lr_init = 0.00016 | ||
self.position_lr_final = 0.0000016 | ||
self.position_lr_delay_mult = 0.01 | ||
self.position_lr_max_steps = 30_000 | ||
self.feature_lr = 0.0025 | ||
self.opacity_lr = 0.05 | ||
self.scaling_lr = 0.005 | ||
self.rotation_lr = 0.001 | ||
self.percent_dense = 0.01 | ||
self.lambda_dssim = 0.2 | ||
self.densification_interval = 100 | ||
self.opacity_reset_interval = 3000 | ||
self.densify_from_iter = 500 | ||
self.densify_until_iter = 15_000 | ||
self.densify_grad_threshold = 0.0002 | ||
self.random_background = False | ||
super().__init__(parser, "Optimization Parameters") | ||
|
||
def get_combined_args(parser : ArgumentParser): | ||
cmdlne_string = sys.argv[1:] | ||
cfgfile_string = "Namespace()" | ||
args_cmdline = parser.parse_args(cmdlne_string) | ||
|
||
try: | ||
cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") | ||
print("Looking for config file in", cfgfilepath) | ||
with open(cfgfilepath) as cfg_file: | ||
print("Config file found: {}".format(cfgfilepath)) | ||
cfgfile_string = cfg_file.read() | ||
except TypeError: | ||
print("Config file not found at") | ||
pass | ||
args_cfgfile = eval(cfgfile_string) | ||
|
||
merged_dict = vars(args_cfgfile).copy() | ||
for k,v in vars(args_cmdline).items(): | ||
if v != None: | ||
merged_dict[k] = v | ||
return Namespace(**merged_dict) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
name: gaussian_splatting | ||
channels: | ||
- pytorch | ||
- conda-forge | ||
- defaults | ||
dependencies: | ||
- cudatoolkit=11.6 | ||
- plyfile=0.8.1 | ||
- python=3.7.13 | ||
- pip=22.3.1 | ||
- pytorch=1.12.1 | ||
- torchaudio=0.12.1 | ||
- torchvision=0.13.1 | ||
- tqdm | ||
- pip: | ||
- submodules/diff-gaussian-rasterization | ||
- submodules/simple-knn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import torch | ||
import math | ||
import sys | ||
sys.path.append('..') | ||
from submodules.diff_gaussian_rasterization.diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer | ||
from scene.gaussian_model import GaussianModel | ||
from utils.sh_utils import eval_sh | ||
|
||
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, low_pass = 0.3): | ||
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 | ||
try: | ||
screenspace_points.retain_grad() | ||
except: | ||
pass | ||
|
||
|
||
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) | ||
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) | ||
|
||
raster_settings = GaussianRasterizationSettings( | ||
image_height=int(viewpoint_camera.image_height), | ||
image_width=int(viewpoint_camera.image_width), | ||
tanfovx=tanfovx, | ||
tanfovy=tanfovy, | ||
bg=bg_color, | ||
scale_modifier=scaling_modifier, | ||
viewmatrix=viewpoint_camera.world_view_transform, | ||
projmatrix=viewpoint_camera.full_proj_transform, | ||
sh_degree=pc.active_sh_degree, | ||
campos=viewpoint_camera.camera_center, | ||
prefiltered=False, | ||
debug=pipe.debug, | ||
low_pass=low_pass | ||
) | ||
rasterizer = GaussianRasterizer(raster_settings=raster_settings) | ||
|
||
means3D = pc.get_xyz | ||
means2D = screenspace_points | ||
opacity = pc.get_opacity | ||
|
||
scales = None | ||
rotations = None | ||
cov3D_precomp = None | ||
if pipe.compute_cov3D_python: | ||
cov3D_precomp = pc.get_covariance(scaling_modifier) | ||
else: | ||
scales = pc.get_scaling | ||
rotations = pc.get_rotation | ||
|
||
shs = None | ||
colors_precomp = None | ||
if override_color is None: | ||
if pipe.convert_SHs_python: | ||
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) | ||
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) | ||
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) | ||
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) | ||
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) | ||
else: | ||
shs = pc.get_features | ||
else: | ||
colors_precomp = override_color | ||
|
||
rendered_image, radii, depth = rasterizer( | ||
means3D = means3D, | ||
means2D = means2D, | ||
shs = shs, | ||
colors_precomp = colors_precomp, | ||
opacities = opacity, | ||
scales = scales, | ||
rotations = rotations, | ||
cov3D_precomp = cov3D_precomp) | ||
|
||
|
||
return {"render": rendered_image, | ||
"viewspace_points": screenspace_points, | ||
"visibility_filter" : radii > 0, | ||
"radii": radii, | ||
"depth": depth} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
|
||
import torch | ||
import traceback | ||
import socket | ||
import json | ||
from scene.cameras import MiniCam | ||
|
||
host = "127.0.0.1" | ||
port = 6009 | ||
|
||
conn = None | ||
addr = None | ||
|
||
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||
|
||
def init(wish_host, wish_port): | ||
global host, port, listener | ||
host = wish_host | ||
port = wish_port | ||
listener.bind((host, port)) | ||
listener.listen() | ||
listener.settimeout(0) | ||
|
||
def try_connect(): | ||
global conn, addr, listener | ||
try: | ||
conn, addr = listener.accept() | ||
print(f"\nConnected by {addr}") | ||
conn.settimeout(None) | ||
except Exception as inst: | ||
pass | ||
|
||
def read(): | ||
global conn | ||
messageLength = conn.recv(4) | ||
messageLength = int.from_bytes(messageLength, 'little') | ||
message = conn.recv(messageLength) | ||
return json.loads(message.decode("utf-8")) | ||
|
||
def send(message_bytes, verify): | ||
global conn | ||
if message_bytes != None: | ||
conn.sendall(message_bytes) | ||
conn.sendall(len(verify).to_bytes(4, 'little')) | ||
conn.sendall(bytes(verify, 'ascii')) | ||
|
||
def receive(): | ||
message = read() | ||
|
||
width = message["resolution_x"] | ||
height = message["resolution_y"] | ||
|
||
if width != 0 and height != 0: | ||
try: | ||
do_training = bool(message["train"]) | ||
fovy = message["fov_y"] | ||
fovx = message["fov_x"] | ||
znear = message["z_near"] | ||
zfar = message["z_far"] | ||
do_shs_python = bool(message["shs_python"]) | ||
do_rot_scale_python = bool(message["rot_scale_python"]) | ||
keep_alive = bool(message["keep_alive"]) | ||
scaling_modifier = message["scaling_modifier"] | ||
world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() | ||
world_view_transform[:,1] = -world_view_transform[:,1] | ||
world_view_transform[:,2] = -world_view_transform[:,2] | ||
full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() | ||
full_proj_transform[:,1] = -full_proj_transform[:,1] | ||
custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) | ||
except Exception as e: | ||
print("") | ||
traceback.print_exc() | ||
raise e | ||
return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier | ||
else: | ||
return None, None, None, None, None, None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch | ||
|
||
from .modules.lpips import LPIPS | ||
|
||
|
||
def lpips(x: torch.Tensor, | ||
y: torch.Tensor, | ||
net_type: str = 'alex', | ||
version: str = '0.1'): | ||
device = x.device | ||
criterion = LPIPS(net_type, version).to(device) | ||
return criterion(x, y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from .networks import get_network, LinLayers | ||
from .utils import get_state_dict | ||
|
||
|
||
class LPIPS(nn.Module): | ||
def __init__(self, net_type: str = 'alex', version: str = '0.1'): | ||
|
||
assert version in ['0.1'], 'v0.1 is only supported now' | ||
|
||
super(LPIPS, self).__init__() | ||
|
||
|
||
self.net = get_network(net_type) | ||
|
||
|
||
self.lin = LinLayers(self.net.n_channels_list) | ||
self.lin.load_state_dict(get_state_dict(net_type, version)) | ||
|
||
def forward(self, x: torch.Tensor, y: torch.Tensor): | ||
feat_x, feat_y = self.net(x), self.net(y) | ||
|
||
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] | ||
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] | ||
|
||
return torch.sum(torch.cat(res, 0), 0, True) |
Oops, something went wrong.