Skip to content

Commit

Permalink
add custom img and flags
Browse files Browse the repository at this point in the history
  • Loading branch information
maturk committed Oct 3, 2023
1 parent 6ad2a8e commit d9e6499
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
Expand Down Expand Up @@ -93,7 +94,7 @@ def _init_gaussians(self):
def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True):
optimizer = optim.Adam(
[self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
) # try training self.opacities/scales etc.
)
mse_loss = torch.nn.MSELoss()
frames = []
for iter in range(iterations):
Expand Down Expand Up @@ -124,9 +125,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True
)
loss = mse_loss(out_img, self.gt_image)
optimizer.zero_grad()
torch.cuda.synchronize()
loss.backward()
torch.cuda.synchronize()
optimizer.step()
print(f"Iteration {iter + 1}/{iterations}, Loss: {loss.item()}")

Expand All @@ -136,7 +135,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True
# save them as a gif with PIL
frames = [Image.fromarray(frame) for frame in frames]
frames[0].save(
os.getcwd() + f"/renders/training.gif",
os.getcwd() + "/renders/training.gif",
save_all=True,
append_images=frames[1:],
optimize=False,
Expand All @@ -145,24 +144,39 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True
)


def image_path_to_tensor(image_path):
def image_path_to_tensor(image_path: Path):
import torchvision.transforms as transforms

img = Image.open(image_path)
transform = transforms.ToTensor()
img_tensor = transform(img).permute(1, 2, 0)
img_tensor = transform(img).permute(1, 2, 0)[..., :3]
return img_tensor


def main(height: int = 256, width: int = 256) -> None:
gt_image = torch.ones((height, width, 3)) * 1.0
# make top left and bottom right red,blue
gt_image[: height // 2, : width // 2, :] = torch.tensor([1.0, 0.0, 0.0])
gt_image[height // 2 :, width // 2 :, :] = torch.tensor([0.0, 0.0, 1.0])
# gt_image = image_path_to_tensor(os.getcwd() + "path_to_your_image")
trainer = SimpleTrainer(gt_image=gt_image)

trainer.train()
def main(
height: int = 256,
width: int = 256,
num_points: int = 2000,
save_imgs: bool = True,
img_path: Optional[Path] = None,
iterations: int = 1000,
lr: float = 0.01,
) -> None:

if img_path:
gt_image = image_path_to_tensor(img_path)
else:
gt_image = torch.ones((height, width, 3)) * 1.0
# make top left and bottom right red, blue
gt_image[: height // 2, : width // 2, :] = torch.tensor([1.0, 0.0, 0.0])
gt_image[height // 2 :, width // 2 :, :] = torch.tensor([0.0, 0.0, 1.0])

trainer = SimpleTrainer(gt_image=gt_image, num_points=num_points)
trainer.train(
iterations=iterations,
lr=lr,
save_imgs=save_imgs,
)


if __name__ == "__main__":
Expand Down

0 comments on commit d9e6499

Please sign in to comment.