Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/controls: mouse capture & vertical movements #13

Merged
merged 1 commit into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions 3dgs/GUIManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ void GUIManager::init() {
}

void GUIManager::buildGui() {
if (mouseCapture) {
ImGui::BeginDisabled(true);
}

ImGui::SetNextWindowSize(ImVec2(400, 250), ImGuiCond_FirstUseEver);
ImGui::SetNextWindowPos(ImVec2(10, 10), ImGuiCond_FirstUseEver);
ImGui::Begin("Performance");
Expand All @@ -37,11 +41,22 @@ void GUIManager::buildGui() {
ImGui::SliderFloat("History", &history, 1, 30, "%.1f s");
ImGui::End();

// always auto resize

bool popen = true;
ImGui::SetNextWindowPos(ImVec2(10, 270), ImGuiCond_FirstUseEver);
ImGui::Begin("Controls");
ImGui::Text("WASD: Move");
ImGui::Text("Mouse: Look");
ImGui::Begin("Controls", &popen, ImGuiWindowFlags_AlwaysAutoResize);
ImGui::Text("WASD: move");
ImGui::Text("Space: up");
ImGui::Text("Shift: down");
ImGui::Text("Left click: capture mouse");
ImGui::Text("ESC: release mouse");
ImGui::Text("Mouse captured: %s", mouseCapture ? "true" : "false");
ImGui::End();

if (mouseCapture) {
ImGui::EndDisabled();
}
}

void GUIManager::pushMetric(const std::string& name, float value) {
Expand All @@ -57,3 +72,11 @@ void GUIManager::pushMetric(const std::unordered_map<std::string, float>& name)
pushMetric(n, v);
}
}

bool GUIManager::wantCaptureMouse() {
return ImGui::GetIO().WantCaptureMouse;
}

bool GUIManager::wantCaptureKeyboard() {
return ImGui::GetIO().WantCaptureKeyboard;
}
8 changes: 7 additions & 1 deletion 3dgs/GUIManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ class GUIManager {

static void init();

static void buildGui();
void buildGui();

static void pushMetric(const std::string& name, float value);

static void pushMetric(const std::unordered_map<std::string, float>& name);

static bool wantCaptureMouse();

static bool wantCaptureKeyboard();

bool mouseCapture = false;

};

#endif //GUIMANAGER_H
145 changes: 86 additions & 59 deletions 3dgs/Renderer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,39 +33,60 @@ void Renderer::handleInput() {
auto translation = window->getCursorTranslation();
auto keys = window->getKeys(); // W, A, S, D

if (!guiManager.wantCaptureMouse() && !guiManager.mouseCapture && window->getMouseButton()[0]) {
window->mouseCapture(true);
guiManager.mouseCapture = true;
}

// rotate camera
if (translation[0] != 0.0 || translation[1] != 0.0) {
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[0]) * 0.005f,
glm::vec3(0.0f, -1.0f, 0.0f));
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[1]) * 0.005f,
glm::vec3(-1.0f, 0.0f, 0.0f));
if (guiManager.mouseCapture) {
if (translation[0] != 0.0 || translation[1] != 0.0) {
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[0]) * 0.005f,
glm::vec3(0.0f, -1.0f, 0.0f));
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[1]) * 0.005f,
glm::vec3(-1.0f, 0.0f, 0.0f));
}
}


// move camera
glm::vec3 direction = glm::vec3(0.0f, 0.0f, 0.0f);
if (keys[0]) {
direction += glm::vec3(0.0f, 0.0f, -1.0f);
}
if (keys[1]) {
direction += glm::vec3(-1.0f, 0.0f, 0.0f);
}
if (keys[2]) {
direction += glm::vec3(0.0f, 0.0f, 1.0f);
}
if (keys[3]) {
direction += glm::vec3(1.0f, 0.0f, 0.0f);
}
if (direction != glm::vec3(0.0f, 0.0f, 0.0f)) {
direction = glm::normalize(direction);
camera.position += (glm::mat4_cast(camera.rotation) * glm::vec4(direction, 1.0f)).xyz() * 0.3f;
if (!guiManager.wantCaptureKeyboard()) {
glm::vec3 direction = glm::vec3(0.0f, 0.0f, 0.0f);
if (keys[0]) {
direction += glm::vec3(0.0f, 0.0f, -1.0f);
}
if (keys[1]) {
direction += glm::vec3(-1.0f, 0.0f, 0.0f);
}
if (keys[2]) {
direction += glm::vec3(0.0f, 0.0f, 1.0f);
}
if (keys[3]) {
direction += glm::vec3(1.0f, 0.0f, 0.0f);
}
if (keys[4]) {
direction += glm::vec3(0.0f, 1.0f, 0.0f);
}
if (keys[5]) {
direction += glm::vec3(0.0f, -1.0f, 0.0f);
}
if (keys[6]) {
window->mouseCapture(false);
guiManager.mouseCapture = false;
}
if (direction != glm::vec3(0.0f, 0.0f, 0.0f)) {
direction = glm::normalize(direction);
camera.position += (glm::mat4_cast(camera.rotation) * glm::vec4(direction, 1.0f)).xyz() * 0.3f;
}
}
}

void Renderer::retrieveTimestamps() {
std::vector<uint64_t> timestamps(queryManager->nextId);
auto res = context->device->getQueryPoolResults(context->queryPool.get(), 0, queryManager->nextId,
timestamps.size() * sizeof(uint64_t),
timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
timestamps.size() * sizeof(uint64_t),
timestamps.data(), sizeof(uint64_t),
vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
if (res != vk::Result::eSuccess) {
throw std::runtime_error("Failed to retrieve timestamps");
}
Expand Down Expand Up @@ -302,7 +323,7 @@ void Renderer::createRenderPipeline() {
inputSet->build();

auto outputSet = std::make_shared<DescriptorSet>(context, 1);
for (auto&image: swapchain->swapchainImages) {
for (auto& image: swapchain->swapchainImages) {
outputSet->bindImageToDescriptorSet(0, vk::DescriptorType::eStorageImage, vk::ShaderStageFlagBits::eCompute,
image);
}
Expand Down Expand Up @@ -331,8 +352,7 @@ void Renderer::run() {
if (res == vk::Result::eErrorOutOfDateKHR) {
swapchain->recreate();
continue;
}
else if (res != vk::Result::eSuccess && res != vk::Result::eSuboptimalKHR) {
} else if (res != vk::Result::eSuccess && res != vk::Result::eSuboptimalKHR) {
throw std::runtime_error("Failed to acquire swapchain image");
}

Expand All @@ -341,7 +361,7 @@ void Renderer::run() {

updateUniforms();

auto submitInfo = vk::SubmitInfo {}.setCommandBuffers(preprocessCommandBuffer.get());
auto submitInfo = vk::SubmitInfo{}.setCommandBuffers(preprocessCommandBuffer.get());
context->queues[VulkanContext::Queue::COMPUTE].queue.submit(submitInfo, inflightFences[0].get());

ret = context->device->waitForFences(inflightFences[0].get(), VK_TRUE, UINT64_MAX);
Expand All @@ -354,10 +374,10 @@ void Renderer::run() {
goto startOfRenderLoop;
}
vk::PipelineStageFlags waitStage = vk::PipelineStageFlagBits::eComputeShader;
submitInfo = vk::SubmitInfo {}.setWaitSemaphores(swapchain->imageAvailableSemaphores[0].get())
.setCommandBuffers(renderCommandBuffer.get())
.setSignalSemaphores(renderFinishedSemaphores[0].get())
.setWaitDstStageMask(waitStage);
submitInfo = vk::SubmitInfo{}.setWaitSemaphores(swapchain->imageAvailableSemaphores[0].get())
.setCommandBuffers(renderCommandBuffer.get())
.setSignalSemaphores(renderFinishedSemaphores[0].get())
.setWaitDstStageMask(waitStage);
context->queues[VulkanContext::Queue::COMPUTE].queue.submit(submitInfo, inflightFences[0].get());

vk::PresentInfoKHR presentInfo{};
Expand All @@ -370,8 +390,7 @@ void Renderer::run() {
ret = context->queues[VulkanContext::Queue::PRESENT].queue.presentKHR(presentInfo);
if (ret == vk::Result::eErrorOutOfDateKHR || ret == vk::Result::eSuboptimalKHR) {
swapchain->recreate();
}
else if (ret != vk::Result::eSuccess) {
} else if (ret != vk::Result::eSuccess) {
throw std::runtime_error("Failed to present swapchain image");
}

Expand Down Expand Up @@ -436,7 +455,8 @@ void Renderer::recordPreprocessCommandBuffer() {
preprocessCommandBuffer->resetQueryPool(context->queryPool.get(), 0, 12);

preprocessPipeline->bind(preprocessCommandBuffer, 0, 0);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_start"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_start"));
preprocessCommandBuffer->dispatch(numGroups, 1, 1);
tileOverlapBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

Expand All @@ -445,10 +465,12 @@ void Renderer::recordPreprocessCommandBuffer() {

prefixSumPingBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_end"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_end"));

prefixSumPipeline->bind(preprocessCommandBuffer, 0, 0);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("prefix_sum_start"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("prefix_sum_start"));
const auto iters = static_cast<uint32_t>(std::ceil(std::log2(static_cast<float>(scene->getNumVertices()))));
for (uint32_t timestep = 0; timestep <= iters; timestep++) {
preprocessCommandBuffer->pushConstants(prefixSumPipeline->pipelineLayout.get(),
Expand All @@ -459,8 +481,7 @@ void Renderer::recordPreprocessCommandBuffer() {
if (timestep % 2 == 0) {
prefixSumPongBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());
prefixSumPingBuffer->computeReadWriteBarrier(preprocessCommandBuffer.get());
}
else {
} else {
prefixSumPingBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());
prefixSumPongBuffer->computeReadWriteBarrier(preprocessCommandBuffer.get());
}
Expand All @@ -470,18 +491,19 @@ void Renderer::recordPreprocessCommandBuffer() {
if (iters % 2 == 0) {
preprocessCommandBuffer->copyBuffer(prefixSumPingBuffer->buffer, totalSumBufferHost->buffer, 1,
&totalSumRegion);
}
else {
} else {
preprocessCommandBuffer->copyBuffer(prefixSumPongBuffer->buffer, totalSumBufferHost->buffer, 1,
&totalSumRegion);
}

vertexAttributeBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("prefix_sum_end"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("prefix_sum_end"));

preprocessSortPipeline->bind(preprocessCommandBuffer, 0, iters % 2 == 0 ? 0 : 1);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_sort_start"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_sort_start"));
uint32_t tileX = (swapchain->swapchainExtent.width + 16 - 1) / 16;
// assert(tileX == 50);
preprocessCommandBuffer->pushConstants(preprocessSortPipeline->pipelineLayout.get(),
Expand All @@ -490,7 +512,8 @@ void Renderer::recordPreprocessCommandBuffer() {
preprocessCommandBuffer->dispatch(numGroups, 1, 1);

sortKBufferEven->computeWriteReadBarrier(preprocessCommandBuffer.get());
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_sort_end"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_sort_end"));

preprocessCommandBuffer->end();
}
Expand Down Expand Up @@ -535,7 +558,8 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
for (auto i = 0; i < 8; i++) {
sortHistPipeline->bind(renderCommandBuffer, 0, i % 2 == 0 ? 0 : 1);
if (i == 0) {
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("sort_start"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("sort_start"));
}
auto invocationSize = (numInstances + numRadixSortBlocksPerWorkgroup - 1) / numRadixSortBlocksPerWorkgroup;
invocationSize = (invocationSize + 255) / 256;
Expand All @@ -562,14 +586,14 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
if (i % 2 == 0) {
sortKBufferOdd->computeWriteReadBarrier(renderCommandBuffer.get());
sortVBufferOdd->computeWriteReadBarrier(renderCommandBuffer.get());
}
else {
} else {
sortKBufferEven->computeWriteReadBarrier(renderCommandBuffer.get());
sortVBufferEven->computeWriteReadBarrier(renderCommandBuffer.get());
}

if (i == 7) {
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("sort_end"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("sort_end"));
}
}

Expand All @@ -583,17 +607,20 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {

// Since we have 64 bit keys, the sort result is always in the even buffer
tileBoundaryPipeline->bind(renderCommandBuffer, 0, 0);
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("tile_boundary_start"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("tile_boundary_start"));
renderCommandBuffer->pushConstants(tileBoundaryPipeline->pipelineLayout.get(),
vk::ShaderStageFlagBits::eCompute, 0,
sizeof(uint32_t), &numInstances);
renderCommandBuffer->dispatch((numInstances + 255) / 256, 1, 1);

tileBoundaryBuffer->computeWriteReadBarrier(renderCommandBuffer.get());
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("tile_boundary_end"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("tile_boundary_end"));

renderPipeline->bind(renderCommandBuffer, 0, std::vector<uint32_t>{0, currentImageIndex});
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("render_start"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("render_start"));
auto [width, height] = window->getFramebufferSize();
uint32_t constants[2] = {width, height};
renderCommandBuffer->pushConstants(renderPipeline->pipelineLayout.get(),
Expand Down Expand Up @@ -626,19 +653,20 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
imageMemoryBarrier.newLayout = vk::ImageLayout::eColorAttachmentOptimal;
imageMemoryBarrier.dstAccessMask = vk::AccessFlagBits::eColorAttachmentWrite;
renderCommandBuffer->pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eColorAttachmentOutput,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
vk::PipelineStageFlagBits::eColorAttachmentOutput,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
} else {
imageMemoryBarrier.newLayout = vk::ImageLayout::ePresentSrcKHR;
imageMemoryBarrier.dstAccessMask = vk::AccessFlagBits::eMemoryRead;
renderCommandBuffer->pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
}
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("render_end"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("render_end"));

if (configuration.enableGui) {
imguiManager->draw(renderCommandBuffer.get(), currentImageIndex, &GUIManager::buildGui);
imguiManager->draw(renderCommandBuffer.get(), currentImageIndex, std::bind(&GUIManager::buildGui, &guiManager));

imageMemoryBarrier.oldLayout = vk::ImageLayout::eColorAttachmentOptimal;
imageMemoryBarrier.srcAccessMask = vk::AccessFlagBits::eColorAttachmentWrite;
Expand All @@ -647,8 +675,8 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
imageMemoryBarrier.dstAccessMask = vk::AccessFlagBits::eMemoryRead;

renderCommandBuffer->pipelineBarrier(vk::PipelineStageFlagBits::eColorAttachmentOutput,
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
}
renderCommandBuffer->end();

Expand Down Expand Up @@ -690,5 +718,4 @@ void Renderer::updateUniforms() {
}

Renderer::~Renderer() {

}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ VulkanSplatting is an (not-yet-highly-) optimized, cross-platform implementation
## TODO
The goal of this project is to provide a go-to implementation for high performance rendering of point-based radiance fields that works on all platforms, but we need your help! Please feel free to open an issue if you have any ideas or are interested in contributing.

- [ ] Better controls and GUI on GLFW
- [x] Better controls and GUI on GLFW
- [ ] Implement SOTA parallel radix sort for sorting Gaussian instances
- [ ] Use Vulkan subgroups to batch Gaussian retrievals at the warp level
- [ ] OpenXR support
Expand Down
Loading
Loading