Skip to content

Commit

Permalink
Add callback for querying physical device from OXR runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
shg8 committed Apr 2, 2024
1 parent a92f8a9 commit 00f4257
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 25 deletions.
10 changes: 5 additions & 5 deletions apps/vr_viewer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ project(3dgs_vr_viewer)
find_package(Vulkan REQUIRED)

FetchContent_Declare(OpenXR
GIT_REPOSITORY https://github.com/KhronosGroup/OpenXR-SDK.git
GIT_TAG release-1.0.34
GIT_REPOSITORY https://github.com/KhronosGroup/OpenXR-SDK.git
GIT_TAG release-1.0.34
)

FetchContent_GetProperties(OpenXR)
if (NOT OpenXR_POPULATED)
FetchContent_Populate(OpenXR)
add_subdirectory(${openxr_SOURCE_DIR} ${openxr_BINARY_DIR})
endif()
endif ()

add_executable(3dgs_vr_viewer
src/main.cpp
src/main.cpp
src/VRViewer.cpp
src/VRViewer.h
src/OXRUtils.h
Expand All @@ -32,4 +32,4 @@ target_include_directories(3dgs_vr_viewer PRIVATE
${spdlog_SOURCE_DIR}/include
${Vulkan_INCLUDE_DIRS}
)
target_link_libraries(3dgs_vr_viewer PRIVATE openxr_loader third_party)
target_link_libraries(3dgs_vr_viewer PRIVATE 3dgs_cpp openxr_loader third_party)
16 changes: 16 additions & 0 deletions apps/vr_viewer/src/OXRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,19 @@ std::vector<std::string> OXRContext::getRequiredVulkanDeviceExtensions() const {

return extensionList;
}

void * OXRContext::getPhysicalDevice(void *instance) const {
PFN_xrGetVulkanGraphicsDeviceKHR xrGetVulkanGraphicsDeviceKHR;
auto result = xrGetInstanceProcAddr(oxrInstance, "xrGetVulkanGraphicsDeviceKHR", reinterpret_cast<PFN_xrVoidFunction*>(&xrGetVulkanGraphicsDeviceKHR));
if (XR_FAILED(result)) {
throw std::runtime_error("Failed to get xrGetVulkanGraphicsDeviceKHR");
}

VkPhysicalDevice physicalDevice;
result = xrGetVulkanGraphicsDeviceKHR(oxrInstance, systemId, static_cast<VkInstance>(instance), &physicalDevice);
if (XR_FAILED(result)) {
throw std::runtime_error("Failed to get Vulkan graphics device");
}

return physicalDevice;
}
4 changes: 3 additions & 1 deletion apps/vr_viewer/src/OXRContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
#ifndef OXRCONTEXT_H
#define OXRCONTEXT_H
#include <openxr/openxr.h>
#include <vulkan/vulkan.hpp>

#include "OXRUtils.h"


class OXRContext {
public:
void setup();

void* getPhysicalDevice(void *instance) const;
private:
XrInstance oxrInstance;
XrSystemId systemId;
Expand Down
5 changes: 4 additions & 1 deletion apps/vr_viewer/src/VRViewer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
#include <memory>

#include "OXRContext.h"

#include "3dgs.h"

void VRViewer::run() {
context = std::make_shared<OXRContext>();
context->setup();

VulkanSplatting::OpenXRConfiguration configuration;
configuration.getPhysicalDevice = std::bind(&OXRContext::getPhysicalDevice, context.get(), std::placeholders::_1);
}

3 changes: 3 additions & 0 deletions include/3dgs/3dgs.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef VULKANSPLATTING_H
#define VULKANSPLATTING_H

#include <functional>
#include <optional>
#include <string>
#include <memory>
Expand Down Expand Up @@ -39,6 +40,8 @@ class VulkanSplatting {
struct OpenXRConfiguration {
std::vector<std::string> instanceExtensions;
std::vector<std::string> deviceExtensions;

std::function<void*(void*)> getPhysicalDevice;
};
static std::shared_ptr<RenderingTarget> createOpenXRRenderingTarget(OpenXRConfiguration configuration);
#endif
Expand Down
4 changes: 1 addition & 3 deletions src/3dgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ std::shared_ptr<RenderingTarget> VulkanSplatting::createMetalWindow(void *caMeta
#endif

std::shared_ptr<RenderingTarget> VulkanSplatting::createOpenXRRenderingTarget(OpenXRConfiguration configuration) {
auto target = std::make_shared<OpenXRStereo>();
target->instanceExtensions = configuration.instanceExtensions;
target->deviceExtensions = configuration.deviceExtensions;
auto target = std::make_shared<OpenXRStereo>(configuration);
return target;
}

Expand Down
24 changes: 12 additions & 12 deletions src/Renderer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void Renderer::retrieveTimestamps() {
}

auto metrics = queryManager->parseResults(timestamps);
for (auto& metric: metrics) {
for (auto &metric: metrics) {
if (configuration.enableGui)
guiManager.pushMetric(metric.first, metric.second / 1000000.0);
}
Expand Down Expand Up @@ -124,8 +124,8 @@ void Renderer::initializeVulkan() {

context->createInstance();
auto surface = static_cast<vk::SurfaceKHR>(window->createSurface(context));
auto requiredPhysicalDevice = window->requirePhysicalDevice();
if (!requiredPhysicalDevice.has_value()) {
if (auto requiredPhysicalDevice = window->requirePhysicalDevice(context->instance.get());
!requiredPhysicalDevice.has_value()) {
context->selectPhysicalDevice(configuration.physicalDeviceId, surface);
} else {
context->physicalDevice = requiredPhysicalDevice.value();
Expand Down Expand Up @@ -357,7 +357,7 @@ void Renderer::createRenderPipeline() {
inputSet->build();

auto outputSet = std::make_shared<DescriptorSet>(context, 1);
for (auto& image: swapchain->swapchainImages) {
for (auto &image: swapchain->swapchainImages) {
outputSet->bindImageToDescriptorSet(0, vk::DescriptorType::eStorageImage, vk::ShaderStageFlagBits::eCompute,
image);
}
Expand Down Expand Up @@ -418,7 +418,7 @@ void Renderer::draw() {

try {
ret = context->queues[VulkanContext::Queue::PRESENT].queue.presentKHR(presentInfo);
} catch (vk::OutOfDateKHRError& e) {
} catch (vk::OutOfDateKHRError &e) {
recreateSwapchain();
return;
}
Expand Down Expand Up @@ -528,7 +528,7 @@ void Renderer::recordPreprocessCommandBuffer() {
}

preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eComputeShader, context->queryPool.get(),
queryManager->registerQuery("prefix_sum_end"));
queryManager->registerQuery("prefix_sum_end"));

preprocessCommandBuffer->end();
}
Expand Down Expand Up @@ -583,23 +583,23 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
auto numGroups = (scene->getNumVertices() + 255) / 256;
preprocessSortPipeline->bind(renderCommandBuffer, 0, iters % 2 == 0 ? 0 : 1);
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eComputeShader, context->queryPool.get(),
queryManager->registerQuery("preprocess_sort_start"));
queryManager->registerQuery("preprocess_sort_start"));
uint32_t tileX = (swapchain->swapchainExtent.width + 16 - 1) / 16;
// assert(tileX == 50);
renderCommandBuffer->pushConstants(preprocessSortPipeline->pipelineLayout.get(),
vk::ShaderStageFlagBits::eCompute, 0,
sizeof(uint32_t), &tileX);
vk::ShaderStageFlagBits::eCompute, 0,
sizeof(uint32_t), &tileX);
renderCommandBuffer->dispatch(numGroups, 1, 1);

sortKBufferEven->computeWriteReadBarrier(renderCommandBuffer.get());
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eComputeShader, context->queryPool.get(),
queryManager->registerQuery("preprocess_sort_end"));
queryManager->registerQuery("preprocess_sort_end"));

// std::cout << "Num instances: " << numInstances << std::endl;

assert(numInstances <= scene->getNumVertices() * sortBufferSizeMultiplier);
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eComputeShader, context->queryPool.get(),
queryManager->registerQuery("sort_start"));
queryManager->registerQuery("sort_start"));
for (auto i = 0; i < 8; i++) {
sortHistPipeline->bind(renderCommandBuffer, 0, i % 2 == 0 ? 0 : 1);
auto invocationSize = (numInstances + numRadixSortBlocksPerWorkgroup - 1) / numRadixSortBlocksPerWorkgroup;
Expand Down Expand Up @@ -633,7 +633,7 @@ bool Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
}
}
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eComputeShader, context->queryPool.get(),
queryManager->registerQuery("sort_end"));
queryManager->registerQuery("sort_end"));

renderCommandBuffer->fillBuffer(tileBoundaryBuffer->buffer, 0, VK_WHOLE_SIZE, 0);

Expand Down
2 changes: 1 addition & 1 deletion src/vulkan/RenderingTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class RenderingTarget {

virtual void logMovement(float x, float y) { };

virtual std::optional<vk::PhysicalDevice> requirePhysicalDevice() {
virtual std::optional<vk::PhysicalDevice> requirePhysicalDevice(vk::Instance instance) {
return std::nullopt;
}

Expand Down
5 changes: 5 additions & 0 deletions src/vulkan/targets/OpenXRStereo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ void OpenXRStereo::logTranslation(float x, float y) {
void OpenXRStereo::logMovement(float x, float y) {
RenderingTarget::logMovement(x, y);
}

std::optional<vk::PhysicalDevice> OpenXRStereo::requirePhysicalDevice(vk::Instance instance) {
void *pdPtr = configuration.getPhysicalDevice(instance);
return {static_cast<vk::PhysicalDevice>(static_cast<VkPhysicalDevice>(pdPtr))};
}
11 changes: 9 additions & 2 deletions src/vulkan/targets/OpenXRStereo.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
#ifndef OPENXRSTEREO_H
#define OPENXRSTEREO_H
#include <3dgs.h>
#include <string>
#include <vector>

#include "../RenderingTarget.h"

class OpenXRStereo : public RenderingTarget {
public:
explicit OpenXRStereo(const VulkanSplatting::OpenXRConfiguration &configuration)
: configuration(configuration) {
}

VkSurfaceKHR createSurface(std::shared_ptr<VulkanContext> context) override;

std::array<bool, 3> getMouseButton() override;
Expand All @@ -27,8 +32,10 @@ class OpenXRStereo : public RenderingTarget {

void logMovement(float x, float y) override;

std::vector<std::string> instanceExtensions;
std::vector<std::string> deviceExtensions;
std::optional<vk::PhysicalDevice> requirePhysicalDevice(vk::Instance instance) override;

private:
VulkanSplatting::OpenXRConfiguration configuration;
};


Expand Down

0 comments on commit 00f4257

Please sign in to comment.