-
Notifications
You must be signed in to change notification settings - Fork 290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
2DGS #208
2DGS #208
Changes from 58 commits
086dfe0
906a955
b288462
436b424
9a549d4
5764be5
cd7786f
f408b13
8580e5e
e596d75
fecabc5
151ee3d
8cb5154
9a92cca
d6ead12
449e1a4
af66b5e
4c3717c
f15625b
39ca680
966ec1f
9530c89
79298a4
6a0c8c9
d7024ec
3b19b70
f484aa0
3a00b5f
37a9f3e
f2248b6
0473dc9
84ac77a
3909121
7fabcf4
5ab8391
406cd6f
b0b5da1
0ff18e2
1d245be
39b6e1a
1b25093
fd36aa9
4c11611
b37f171
780936e
3cd6690
4d68c74
37106b8
9d47e82
30108ec
b55ea01
76fdf71
ca9f0e1
7101d5f
164c1bc
e4646dd
0bb01b7
6574f15
6e7b302
bb508c0
7afd909
56d0a81
6c1dee0
8bcf6ab
eebe72b
bc4fdfe
091fc71
0a62fe8
7e86657
65ce16d
c79a32e
c31a4dc
2a8f0d6
7784955
2b350c9
8345c5d
a006f94
c82ce96
353fc27
6e2cc36
fd7c866
e0f2879
c2420a8
666705b
d57b75f
a57c525
6dfd704
48abf70
2e8da2e
c538b51
7daa1ef
d34baa2
2f9eb75
5566aa5
bade5bf
afe0308
1a91b03
7bb5af9
b1697c2
408f16e
6987c00
217fdc8
896cc45
4630230
61ce019
c5edf16
4176666
935cb1d
cf6b32f
c5d1449
75de227
8f39645
14a097a
9d80b19
24add4d
3487c74
fd219f8
8c4efba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,5 +121,8 @@ venv.bak/ | |
compile_commands.json | ||
*.dump | ||
|
||
data | ||
results | ||
/data | ||
/result | ||
/results | ||
/renders | ||
*.png | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should avoid puting this into gitignore as png file might be needed for the doc page at some point in the future. |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we could merge this file with |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import math | ||
import os | ||
from pathlib import Path | ||
from typing import Optional, Literal | ||
|
||
import torch | ||
import numpy as np | ||
import tyro | ||
import matplotlib | ||
from PIL import Image | ||
|
||
from gsplat import rasterization, rasterization_2dgs | ||
|
||
class SimpleTrainer: | ||
|
||
def __init__( | ||
self, | ||
num_points_per_axis: int = 8, | ||
): | ||
self.device = torch.device("cuda:0") | ||
self.num_points_per_axis = num_points_per_axis | ||
|
||
fov_x = math.pi / 2.0 | ||
self.H, self.W = 256, 256 | ||
self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x) | ||
self.img_size = torch.tensor([self.W, self.H, 1], device=self.device) | ||
|
||
self._init_gaussians() | ||
|
||
def _init_gaussians(self): | ||
length = 0.4 | ||
x = np.linspace(-1, 1, self.num_points_per_axis) | ||
y = np.linspace(-1, 1, self.num_points_per_axis) | ||
x, y = np.meshgrid(x, y) | ||
means3D = torch.from_numpy(np.stack([x, y, np.ones_like(x)], axis=-1).reshape(-1, 3)).cuda().float() | ||
quats = torch.zeros(1, 4).repeat(len(means3D), 1).cuda() | ||
quats[..., 0] = 1. | ||
scale = 0.6 / (self.num_points_per_axis - 1) | ||
scales = torch.zeros(1, 3).repeat(len(means3D), 1).fill_(scale).cuda() | ||
num_points = self.num_points_per_axis ** 2 | ||
colors = matplotlib.colormaps['Accent'](np.random.randint(1, num_points, num_points) / num_points)[..., :3] | ||
colors = torch.from_numpy(colors).cuda() | ||
opacity = torch.ones_like(means3D[:, 0]) | ||
|
||
self.viewmat = torch.tensor( | ||
[[-8.6086e-01, 3.7950e-01, -3.3896e-01, 6.7791e-01], | ||
[ 5.0884e-01, 6.4205e-01, -5.7346e-01, 1.1469e+00], | ||
[ 1.0934e-08, -6.6614e-01, -7.4583e-01, 1.4917e+00], | ||
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]], | ||
device=self.device, | ||
) | ||
|
||
self.means = means3D.float() | ||
self.scales = scales.float() | ||
self.rgbs = colors.float() | ||
self.opacities = opacity.float() | ||
self.quats = quats.float() | ||
|
||
def render( | ||
self, | ||
model_type: Literal["3dgs", "2dgs"] = "3dgs", | ||
): | ||
frames = [] | ||
K = torch.tensor( | ||
[ | ||
[self.focal, 0, self.W / 2], | ||
[0, self.focal, self.H / 2], | ||
[0, 0, 1], | ||
], | ||
device=self.device, | ||
) | ||
if model_type == "3dgs": | ||
renders, _ = rasterization( | ||
self.means, | ||
self.quats / self.quats.norm(dim=-1, keepdim=True), | ||
self.scales, | ||
self.opacities, | ||
self.rgbs, | ||
self.viewmat[None], | ||
K[None], | ||
self.W, | ||
self.H, | ||
packed=False, | ||
) | ||
|
||
elif model_type == "2dgs": | ||
renders, _ = rasterization_2dgs( | ||
self.means, | ||
self.quats / self.quats.norm(dim=-1, keepdim=True), | ||
self.scales, | ||
self.opacities, | ||
self.rgbs, | ||
self.viewmat[None], | ||
K[None], | ||
self.W, | ||
self.H, | ||
packed=False, | ||
) | ||
else: | ||
raise NotImplementedError("Model not implemented") | ||
out_img = renders[0].squeeze(0) | ||
torch.cuda.synchronize() | ||
|
||
frame = (out_img.detach().cpu().numpy() * 255).astype(np.uint8) | ||
frame_img = Image.fromarray(frame) | ||
out_dir = os.path.join(os.getcwd(), "renders") | ||
os.makedirs(out_dir, exist_ok=True) | ||
frame_img.save(f"{out_dir}/{model_type}.png") | ||
|
||
def main( | ||
height: int = 256, | ||
width: int = 256, | ||
num_points_per_axis: int = 8, | ||
save_imgs: bool = True, | ||
img_path: Optional[Path] = None, | ||
model_type: Literal["3dgs", "2dgs"] = "3dgs", | ||
) -> None: | ||
trainer = SimpleTrainer(num_points_per_axis=num_points_per_axis) | ||
trainer.render( | ||
model_type=model_type, | ||
) | ||
|
||
if __name__ == "__main__": | ||
tyro.cli(main) |
liruilong940607 marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,3 +16,4 @@ opencv-python | |
tyro | ||
Pillow | ||
tensorboard | ||
matplotlib |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should dump all outputs to a single
results
folder so that the local file structure is cleaner and we dont need to put all these variation folders into gitignore.