Skip to content

Commit

Permalink
Mouse capture & vertical movements (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
shg8 authored Feb 24, 2024
1 parent 875485a commit 98a8e57
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 75 deletions.
29 changes: 26 additions & 3 deletions 3dgs/GUIManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ void GUIManager::init() {
}

void GUIManager::buildGui() {
if (mouseCapture) {
ImGui::BeginDisabled(true);
}

ImGui::SetNextWindowSize(ImVec2(400, 250), ImGuiCond_FirstUseEver);
ImGui::SetNextWindowPos(ImVec2(10, 10), ImGuiCond_FirstUseEver);
ImGui::Begin("Performance");
Expand All @@ -37,11 +41,22 @@ void GUIManager::buildGui() {
ImGui::SliderFloat("History", &history, 1, 30, "%.1f s");
ImGui::End();

// always auto resize

bool popen = true;
ImGui::SetNextWindowPos(ImVec2(10, 270), ImGuiCond_FirstUseEver);
ImGui::Begin("Controls");
ImGui::Text("WASD: Move");
ImGui::Text("Mouse: Look");
ImGui::Begin("Controls", &popen, ImGuiWindowFlags_AlwaysAutoResize);
ImGui::Text("WASD: move");
ImGui::Text("Space: up");
ImGui::Text("Shift: down");
ImGui::Text("Left click: capture mouse");
ImGui::Text("ESC: release mouse");
ImGui::Text("Mouse captured: %s", mouseCapture ? "true" : "false");
ImGui::End();

if (mouseCapture) {
ImGui::EndDisabled();
}
}

void GUIManager::pushMetric(const std::string& name, float value) {
Expand All @@ -57,3 +72,11 @@ void GUIManager::pushMetric(const std::unordered_map<std::string, float>& name)
pushMetric(n, v);
}
}

bool GUIManager::wantCaptureMouse() {
return ImGui::GetIO().WantCaptureMouse;
}

bool GUIManager::wantCaptureKeyboard() {
return ImGui::GetIO().WantCaptureKeyboard;
}
8 changes: 7 additions & 1 deletion 3dgs/GUIManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ class GUIManager {

static void init();

static void buildGui();
void buildGui();

static void pushMetric(const std::string& name, float value);

static void pushMetric(const std::unordered_map<std::string, float>& name);

static bool wantCaptureMouse();

static bool wantCaptureKeyboard();

bool mouseCapture = false;

};

#endif //GUIMANAGER_H
145 changes: 86 additions & 59 deletions 3dgs/Renderer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,39 +33,60 @@ void Renderer::handleInput() {
auto translation = window->getCursorTranslation();
auto keys = window->getKeys(); // W, A, S, D

if (!guiManager.wantCaptureMouse() && !guiManager.mouseCapture && window->getMouseButton()[0]) {
window->mouseCapture(true);
guiManager.mouseCapture = true;
}

// rotate camera
if (translation[0] != 0.0 || translation[1] != 0.0) {
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[0]) * 0.005f,
glm::vec3(0.0f, -1.0f, 0.0f));
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[1]) * 0.005f,
glm::vec3(-1.0f, 0.0f, 0.0f));
if (guiManager.mouseCapture) {
if (translation[0] != 0.0 || translation[1] != 0.0) {
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[0]) * 0.005f,
glm::vec3(0.0f, -1.0f, 0.0f));
camera.rotation = glm::rotate(camera.rotation, static_cast<float>(translation[1]) * 0.005f,
glm::vec3(-1.0f, 0.0f, 0.0f));
}
}


// move camera
glm::vec3 direction = glm::vec3(0.0f, 0.0f, 0.0f);
if (keys[0]) {
direction += glm::vec3(0.0f, 0.0f, -1.0f);
}
if (keys[1]) {
direction += glm::vec3(-1.0f, 0.0f, 0.0f);
}
if (keys[2]) {
direction += glm::vec3(0.0f, 0.0f, 1.0f);
}
if (keys[3]) {
direction += glm::vec3(1.0f, 0.0f, 0.0f);
}
if (direction != glm::vec3(0.0f, 0.0f, 0.0f)) {
direction = glm::normalize(direction);
camera.position += (glm::mat4_cast(camera.rotation) * glm::vec4(direction, 1.0f)).xyz() * 0.3f;
if (!guiManager.wantCaptureKeyboard()) {
glm::vec3 direction = glm::vec3(0.0f, 0.0f, 0.0f);
if (keys[0]) {
direction += glm::vec3(0.0f, 0.0f, -1.0f);
}
if (keys[1]) {
direction += glm::vec3(-1.0f, 0.0f, 0.0f);
}
if (keys[2]) {
direction += glm::vec3(0.0f, 0.0f, 1.0f);
}
if (keys[3]) {
direction += glm::vec3(1.0f, 0.0f, 0.0f);
}
if (keys[4]) {
direction += glm::vec3(0.0f, 1.0f, 0.0f);
}
if (keys[5]) {
direction += glm::vec3(0.0f, -1.0f, 0.0f);
}
if (keys[6]) {
window->mouseCapture(false);
guiManager.mouseCapture = false;
}
if (direction != glm::vec3(0.0f, 0.0f, 0.0f)) {
direction = glm::normalize(direction);
camera.position += (glm::mat4_cast(camera.rotation) * glm::vec4(direction, 1.0f)).xyz() * 0.3f;
}
}
}

void Renderer::retrieveTimestamps() {
std::vector<uint64_t> timestamps(queryManager->nextId);
auto res = context->device->getQueryPoolResults(context->queryPool.get(), 0, queryManager->nextId,
timestamps.size() * sizeof(uint64_t),
timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
timestamps.size() * sizeof(uint64_t),
timestamps.data(), sizeof(uint64_t),
vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
if (res != vk::Result::eSuccess) {
throw std::runtime_error("Failed to retrieve timestamps");
}
Expand Down Expand Up @@ -302,7 +323,7 @@ void Renderer::createRenderPipeline() {
inputSet->build();

auto outputSet = std::make_shared<DescriptorSet>(context, 1);
for (auto&image: swapchain->swapchainImages) {
for (auto& image: swapchain->swapchainImages) {
outputSet->bindImageToDescriptorSet(0, vk::DescriptorType::eStorageImage, vk::ShaderStageFlagBits::eCompute,
image);
}
Expand Down Expand Up @@ -331,8 +352,7 @@ void Renderer::run() {
if (res == vk::Result::eErrorOutOfDateKHR) {
swapchain->recreate();
continue;
}
else if (res != vk::Result::eSuccess && res != vk::Result::eSuboptimalKHR) {
} else if (res != vk::Result::eSuccess && res != vk::Result::eSuboptimalKHR) {
throw std::runtime_error("Failed to acquire swapchain image");
}

Expand All @@ -341,7 +361,7 @@ void Renderer::run() {

updateUniforms();

auto submitInfo = vk::SubmitInfo {}.setCommandBuffers(preprocessCommandBuffer.get());
auto submitInfo = vk::SubmitInfo{}.setCommandBuffers(preprocessCommandBuffer.get());
context->queues[VulkanContext::Queue::COMPUTE].queue.submit(submitInfo, inflightFences[0].get());

ret = context->device->waitForFences(inflightFences[0].get(), VK_TRUE, UINT64_MAX);
Expand All @@ -354,10 +374,10 @@ void Renderer::run() {
goto startOfRenderLoop;
}
vk::PipelineStageFlags waitStage = vk::PipelineStageFlagBits::eComputeShader;
submitInfo = vk::SubmitInfo {}.setWaitSemaphores(swapchain->imageAvailableSemaphores[0].get())
.setCommandBuffers(renderCommandBuffer.get())
.setSignalSemaphores(renderFinishedSemaphores[0].get())
.setWaitDstStageMask(waitStage);
submitInfo = vk::SubmitInfo{}.setWaitSemaphores(swapchain->imageAvailableSemaphores[0].get())
.setCommandBuffers(renderCommandBuffer.get())
.setSignalSemaphores(renderFinishedSemaphores[0].get())
.setWaitDstStageMask(waitStage);
context->queues[VulkanContext::Queue::COMPUTE].queue.submit(submitInfo, inflightFences[0].get());

vk::PresentInfoKHR presentInfo{};
Expand All @@ -370,8 +390,7 @@ void Renderer::run() {
ret = context->queues[VulkanContext::Queue::PRESENT].queue.presentKHR(presentInfo);
if (ret == vk::Result::eErrorOutOfDateKHR || ret == vk::Result::eSuboptimalKHR) {
swapchain->recreate();
}
else if (ret != vk::Result::eSuccess) {
} else if (ret != vk::Result::eSuccess) {
throw std::runtime_error("Failed to present swapchain image");
}

Expand Down Expand Up @@ -436,7 +455,8 @@ void Renderer::recordPreprocessCommandBuffer() {
preprocessCommandBuffer->resetQueryPool(context->queryPool.get(), 0, 12);

preprocessPipeline->bind(preprocessCommandBuffer, 0, 0);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_start"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_start"));
preprocessCommandBuffer->dispatch(numGroups, 1, 1);
tileOverlapBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

Expand All @@ -445,10 +465,12 @@ void Renderer::recordPreprocessCommandBuffer() {

prefixSumPingBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_end"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_end"));

prefixSumPipeline->bind(preprocessCommandBuffer, 0, 0);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("prefix_sum_start"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("prefix_sum_start"));
const auto iters = static_cast<uint32_t>(std::ceil(std::log2(static_cast<float>(scene->getNumVertices()))));
for (uint32_t timestep = 0; timestep <= iters; timestep++) {
preprocessCommandBuffer->pushConstants(prefixSumPipeline->pipelineLayout.get(),
Expand All @@ -459,8 +481,7 @@ void Renderer::recordPreprocessCommandBuffer() {
if (timestep % 2 == 0) {
prefixSumPongBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());
prefixSumPingBuffer->computeReadWriteBarrier(preprocessCommandBuffer.get());
}
else {
} else {
prefixSumPingBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());
prefixSumPongBuffer->computeReadWriteBarrier(preprocessCommandBuffer.get());
}
Expand All @@ -470,18 +491,19 @@ void Renderer::recordPreprocessCommandBuffer() {
if (iters % 2 == 0) {
preprocessCommandBuffer->copyBuffer(prefixSumPingBuffer->buffer, totalSumBufferHost->buffer, 1,
&totalSumRegion);
}
else {
} else {
preprocessCommandBuffer->copyBuffer(prefixSumPongBuffer->buffer, totalSumBufferHost->buffer, 1,
&totalSumRegion);
}

vertexAttributeBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("prefix_sum_end"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("prefix_sum_end"));

preprocessSortPipeline->bind(preprocessCommandBuffer, 0, iters % 2 == 0 ? 0 : 1);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_sort_start"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_sort_start"));
uint32_t tileX = (swapchain->swapchainExtent.width + 16 - 1) / 16;
// assert(tileX == 50);
preprocessCommandBuffer->pushConstants(preprocessSortPipeline->pipelineLayout.get(),
Expand All @@ -490,7 +512,8 @@ void Renderer::recordPreprocessCommandBuffer() {
preprocessCommandBuffer->dispatch(numGroups, 1, 1);

sortKBufferEven->computeWriteReadBarrier(preprocessCommandBuffer.get());
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_sort_end"));
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("preprocess_sort_end"));

preprocessCommandBuffer->end();
}
Expand Down Expand Up @@ -535,7 +558,8 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
for (auto i = 0; i < 8; i++) {
sortHistPipeline->bind(renderCommandBuffer, 0, i % 2 == 0 ? 0 : 1);
if (i == 0) {
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("sort_start"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("sort_start"));
}
auto invocationSize = (numInstances + numRadixSortBlocksPerWorkgroup - 1) / numRadixSortBlocksPerWorkgroup;
invocationSize = (invocationSize + 255) / 256;
Expand All @@ -562,14 +586,14 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
if (i % 2 == 0) {
sortKBufferOdd->computeWriteReadBarrier(renderCommandBuffer.get());
sortVBufferOdd->computeWriteReadBarrier(renderCommandBuffer.get());
}
else {
} else {
sortKBufferEven->computeWriteReadBarrier(renderCommandBuffer.get());
sortVBufferEven->computeWriteReadBarrier(renderCommandBuffer.get());
}

if (i == 7) {
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("sort_end"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("sort_end"));
}
}

Expand All @@ -583,17 +607,20 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {

// Since we have 64 bit keys, the sort result is always in the even buffer
tileBoundaryPipeline->bind(renderCommandBuffer, 0, 0);
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("tile_boundary_start"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("tile_boundary_start"));
renderCommandBuffer->pushConstants(tileBoundaryPipeline->pipelineLayout.get(),
vk::ShaderStageFlagBits::eCompute, 0,
sizeof(uint32_t), &numInstances);
renderCommandBuffer->dispatch((numInstances + 255) / 256, 1, 1);

tileBoundaryBuffer->computeWriteReadBarrier(renderCommandBuffer.get());
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("tile_boundary_end"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("tile_boundary_end"));

renderPipeline->bind(renderCommandBuffer, 0, std::vector<uint32_t>{0, currentImageIndex});
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("render_start"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(),
queryManager->registerQuery("render_start"));
auto [width, height] = window->getFramebufferSize();
uint32_t constants[2] = {width, height};
renderCommandBuffer->pushConstants(renderPipeline->pipelineLayout.get(),
Expand Down Expand Up @@ -626,19 +653,20 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
imageMemoryBarrier.newLayout = vk::ImageLayout::eColorAttachmentOptimal;
imageMemoryBarrier.dstAccessMask = vk::AccessFlagBits::eColorAttachmentWrite;
renderCommandBuffer->pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eColorAttachmentOutput,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
vk::PipelineStageFlagBits::eColorAttachmentOutput,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
} else {
imageMemoryBarrier.newLayout = vk::ImageLayout::ePresentSrcKHR;
imageMemoryBarrier.dstAccessMask = vk::AccessFlagBits::eMemoryRead;
renderCommandBuffer->pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
}
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("render_end"));
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(),
queryManager->registerQuery("render_end"));

if (configuration.enableGui) {
imguiManager->draw(renderCommandBuffer.get(), currentImageIndex, &GUIManager::buildGui);
imguiManager->draw(renderCommandBuffer.get(), currentImageIndex, std::bind(&GUIManager::buildGui, &guiManager));

imageMemoryBarrier.oldLayout = vk::ImageLayout::eColorAttachmentOptimal;
imageMemoryBarrier.srcAccessMask = vk::AccessFlagBits::eColorAttachmentWrite;
Expand All @@ -647,8 +675,8 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
imageMemoryBarrier.dstAccessMask = vk::AccessFlagBits::eMemoryRead;

renderCommandBuffer->pipelineBarrier(vk::PipelineStageFlagBits::eColorAttachmentOutput,
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
}
renderCommandBuffer->end();

Expand Down Expand Up @@ -690,5 +718,4 @@ void Renderer::updateUniforms() {
}

Renderer::~Renderer() {

}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ VulkanSplatting is an (not-yet-highly-) optimized, cross-platform implementation
## TODO
The goal of this project is to provide a go-to implementation for high performance rendering of point-based radiance fields that works on all platforms, but we need your help! Please feel free to open an issue if you have any ideas or are interested in contributing.

- [ ] Better controls and GUI on GLFW
- [x] Better controls and GUI on GLFW
- [ ] Implement SOTA parallel radix sort for sorting Gaussian instances
- [ ] Use Vulkan subgroups to batch Gaussian retrievals at the warp level
- [ ] OpenXR support
Expand Down
Loading

0 comments on commit 98a8e57

Please sign in to comment.