Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for macOS #5

Merged
merged 15 commits into from
Feb 17, 2024
Prev Previous commit
Next Next commit
WIP: add timestamps
shg8 committed Feb 15, 2024
commit 829d238372b84ed64880268a8290b2a750c693eb
37 changes: 35 additions & 2 deletions 3dgs/Renderer.cpp
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@

#include "../vulkan/Utils.h"

#define SORT_ALLOCATE_MULTIPLIER 5
#define SORT_ALLOCATE_MULTIPLIER 50

void Renderer::initialize() {
initializeVulkan();
@@ -60,8 +60,20 @@ void Renderer::handleInput() {
}
}

// Reads back every timestamp query registered with the QueryManager and hands
// the raw 64-bit values to QueryManager::parseResults.
//
// The read uses vk::QueryResultFlagBits::eWait, so the call blocks until the
// requested results are available.
//
// Throws std::runtime_error if the driver reports anything but eSuccess.
void Renderer::retrieveTimestamps() {
    const auto queryCount = static_cast<uint32_t>(queryManager->nextId);
    if (queryCount == 0) {
        // Nothing has been registered yet. vkGetQueryPoolResults requires a
        // non-zero dataSize, so issuing the call here would be invalid usage.
        return;
    }

    std::vector<uint64_t> timestamps(queryCount);
    const auto res = context->device->getQueryPoolResults(
        context->queryPool.get(), 0, queryCount,
        timestamps.size() * sizeof(uint64_t),
        timestamps.data(), sizeof(uint64_t),
        vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
    if (res != vk::Result::eSuccess) {
        throw std::runtime_error("Failed to retrieve timestamps");
    }

    queryManager->parseResults(timestamps);
}

void Renderer::initializeVulkan() {
window = std::make_shared<Window>("Vulkan Splatting", 800, 600);
window = std::make_shared<Window>("Vulkan Splatting", 1920, 1080);
context = std::make_shared<VulkanContext>(Window::getRequiredInstanceExtensions(), std::vector<std::string>{},
configuration.enableVulkanValidationLayers);

@@ -346,6 +358,8 @@ void Renderer::run() {
fpsCounter++;
}

retrieveTimestamps();

// auto nn = totalSumBufferHost->readOne<uint32_t>() ;
// auto staging = Buffer::staging(context, nn* sizeof(uint64_t));
// sortKVBufferEven->downloadTo(staging);
@@ -386,6 +400,7 @@ void Renderer::recordPreprocessCommandBuffer() {
preprocessCommandBuffer->begin(vk::CommandBufferBeginInfo{});

preprocessPipeline->bind(preprocessCommandBuffer, 0, 0);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_start"));
preprocessCommandBuffer->dispatch(numGroups, 1, 1);
tileOverlapBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

@@ -394,7 +409,10 @@ void Renderer::recordPreprocessCommandBuffer() {

prefixSumPingBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_end"));

prefixSumPipeline->bind(preprocessCommandBuffer, 0, 0);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("prefix_sum_start"));
const auto iters = static_cast<uint32_t>(std::ceil(std::log2(static_cast<float>(scene->getNumVertices()))));
for (uint32_t timestep = 0; timestep <= iters; timestep++) {
preprocessCommandBuffer->pushConstants(prefixSumPipeline->pipelineLayout.get(),
@@ -424,7 +442,10 @@ void Renderer::recordPreprocessCommandBuffer() {

vertexAttributeBuffer->computeWriteReadBarrier(preprocessCommandBuffer.get());

preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("prefix_sum_end"));

preprocessSortPipeline->bind(preprocessCommandBuffer, 0, iters % 2 == 0 ? 0 : 1);
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_sort_start"));
uint32_t tileX = (swapchain->swapchainExtent.width + 16 - 1) / 16;
// assert(tileX == 50);
preprocessCommandBuffer->pushConstants(preprocessSortPipeline->pipelineLayout.get(),
@@ -433,6 +454,7 @@ void Renderer::recordPreprocessCommandBuffer() {
preprocessCommandBuffer->dispatch(numGroups, 1, 1);

sortKBufferEven->computeWriteReadBarrier(preprocessCommandBuffer.get());
preprocessCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("preprocess_sort_end"));

preprocessCommandBuffer->end();
}
@@ -451,6 +473,9 @@ void Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
assert(numInstances <= scene->getNumVertices() * SORT_ALLOCATE_MULTIPLIER);
for (auto i = 0; i < 8; i++) {
sortHistPipeline->bind(renderCommandBuffer, 0, i % 2 == 0 ? 0 : 1);
if (i == 0) {
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("sort_start"));
}
auto invocationSize = (numInstances + numRadixSortBlocksPerWorkgroup - 1) / numRadixSortBlocksPerWorkgroup;
invocationSize = (invocationSize + 255) / 256;

@@ -481,6 +506,10 @@ void Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
sortKBufferEven->computeWriteReadBarrier(renderCommandBuffer.get());
sortVBufferEven->computeWriteReadBarrier(renderCommandBuffer.get());
}

if (i == 7) {
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("sort_end"));
}
}

renderCommandBuffer->fillBuffer(tileBoundaryBuffer->buffer, 0, VK_WHOLE_SIZE, 0);
@@ -493,14 +522,17 @@ void Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {

// Since we have 64 bit keys, the sort result is always in the even buffer
tileBoundaryPipeline->bind(renderCommandBuffer, 0, 0);
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("tile_boundary_start"));
renderCommandBuffer->pushConstants(tileBoundaryPipeline->pipelineLayout.get(),
vk::ShaderStageFlagBits::eCompute, 0,
sizeof(uint32_t), &numInstances);
renderCommandBuffer->dispatch((numInstances + 255) / 256, 1, 1);

tileBoundaryBuffer->computeWriteReadBarrier(renderCommandBuffer.get());
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("tile_boundary_end"));

renderPipeline->bind(renderCommandBuffer, 0, std::vector<uint32_t>{0, currentImageIndex});
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eTopOfPipe, context->queryPool.get(), queryManager->registerQuery("render_start"));
auto [width, height] = window->getFramebufferSize();
uint32_t constants[2] = {width, height};
renderCommandBuffer->pushConstants(renderPipeline->pipelineLayout.get(),
@@ -533,6 +565,7 @@ void Renderer::recordRenderCommandBuffer(uint32_t currentFrame) {
renderCommandBuffer->pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eBottomOfPipe,
vk::DependencyFlagBits::eByRegion, nullptr, nullptr, imageMemoryBarrier);
renderCommandBuffer->writeTimestamp(vk::PipelineStageFlagBits::eBottomOfPipe, context->queryPool.get(), queryManager->registerQuery("render_end"));
renderCommandBuffer->end();
}

5 changes: 5 additions & 0 deletions 3dgs/Renderer.h
Original file line number Diff line number Diff line change
@@ -11,6 +11,8 @@
#include "../vulkan/Swapchain.h"
#include <glm/gtc/quaternion.hpp>

#include "../vulkan/QueryManager.h"

struct RendererConfiguration {
bool enableVulkanValidationLayers = false;
std::optional<uint8_t> physicalDeviceId = std::nullopt;
@@ -64,6 +66,8 @@ class Renderer {

void handleInput();

void retrieveTimestamps();

void run();

~Renderer();
@@ -72,6 +76,7 @@ class Renderer {
std::shared_ptr<Window> window;
std::shared_ptr<VulkanContext> context;
std::shared_ptr<GSScene> scene;
std::shared_ptr<QueryManager> queryManager = std::make_shared<QueryManager>();

std::shared_ptr<ComputePipeline> preprocessPipeline;
std::shared_ptr<ComputePipeline> renderPipeline;
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -105,6 +105,8 @@ add_executable(vulkan_splatting main.cpp
vulkan/pipelines/ComputePipeline.h
vulkan/Swapchain.cpp
vulkan/Swapchain.h
vulkan/QueryManager.cpp
vulkan/QueryManager.h
)

add_dependencies(vulkan_splatting Shaders)
60 changes: 60 additions & 0 deletions vulkan/QueryManager.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//
// Created by Steven on 2/15/24.
//

#include "QueryManager.h"

#include <iostream>

// Returns the query index associated with `name`, assigning the next free
// index (and advancing nextId) the first time a name is seen. Repeated calls
// with the same name return the same index.
uint32_t QueryManager::registerQuery(const std::string& name) {
    const std::lock_guard<std::mutex> guard(mutex);

    const auto existing = registry.find(name);
    if (existing != registry.end()) {
        return existing->second;
    }

    const auto assigned = static_cast<uint32_t>(nextId++);
    registry.emplace(name, assigned);
    return assigned;
}

// Looks up the query index previously assigned to `name`.
// Returns 0 when the name has never been registered; note that 0 is also the
// index handed to the first registered query, so the two cases are not
// distinguishable from the return value alone.
uint32_t QueryManager::getQueryId(const std::string& name) {
    const std::lock_guard<std::mutex> guard(mutex);
    const auto entry = registry.find(name);
    return entry == registry.end() ? 0 : entry->second;
}

// Converts one frame's raw timestamp values into per-section durations and
// periodically reports their average on stdout.
//
// `results` is indexed by the query ids handed out by registerQuery. Every
// registered name ending in "_start" is paired with the matching "_end" name;
// the difference (end - start) is stored as one sample under the "_start" name.
// Pairs whose ids fall outside `results` are skipped instead of being read out
// of bounds.
//
// At most once per second the average of the samples collected since the last
// report is printed per section, and the samples are then discarded so memory
// use stays bounded while the renderer runs.
//
// NOTE(review): the printed "ms" value treats one timestamp tick as one
// nanosecond; no timestampPeriod scaling is applied here — confirm whether the
// caller is expected to pre-scale the values.
void QueryManager::parseResults(const std::vector<uint64_t>& results) {
    const std::lock_guard<std::mutex> lock(mutex);

    for (const auto& [name, startId] : registry) {
        if (!name.ends_with("_start")) {
            continue;
        }
        // Drop the trailing "start" (5 chars, the underscore stays) and append "end".
        const auto endName = name.substr(0, name.size() - 5) + "end";
        const auto endEntry = registry.find(endName);
        if (endEntry == registry.end()) {
            continue;
        }
        const auto endId = endEntry->second;
        if (startId >= results.size() || endId >= results.size()) {
            continue;
        }
        this->results[name].push_back(results[endId] - results[startId]);
    }

    const auto now = std::chrono::high_resolution_clock::now();
    if (now - lastPrint > std::chrono::seconds(1)) {
        lastPrint = now;
        for (auto& [name, samples] : this->results) {
            if (samples.empty()) {
                continue;
            }
            uint64_t sum = 0;
            for (const auto sample : samples) {
                sum += sample;
            }
            // Average in floating point so sub-tick precision is not truncated
            // before the conversion to milliseconds.
            const double averageMs =
                static_cast<double>(sum) / static_cast<double>(samples.size()) / 1000000.0;
            // Strip the "_start" suffix (6 chars) to get the section label.
            std::cout << name.substr(0, name.size() - 6) << ": " << averageMs << "ms" << '\n';
            // Keep the capacity, drop the samples: the next report covers a fresh window.
            samples.clear();
        }
        std::cout << std::flush;
    }
}
24 changes: 24 additions & 0 deletions vulkan/QueryManager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef QUERYMANAGER_H
#define QUERYMANAGER_H
#include <sys/types.h>

#include <chrono>
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>


// Maps human-readable query names to indices in a timestamp query pool and
// accumulates the durations measured between "<section>_start" and
// "<section>_end" queries.
//
// The member functions take the internal mutex; `nextId` is a plain public
// member and is not protected by it when accessed directly.
class QueryManager {
public:
    // Returns the index for `name`, assigning a new one on first use.
    uint32_t registerQuery(const std::string &name);

    // Returns the index previously assigned to `name`, or 0 if it is unknown.
    [[nodiscard]] uint32_t getQueryId(const std::string &name);

    // Consumes one set of raw timestamp values indexed by query id.
    void parseResults(const std::vector<uint64_t>& results);

    // Next index to hand out, which is also the number of registered queries.
    // Unsigned to match the query indices and counts it is used as.
    uint32_t nextId = 0;

private:
    std::mutex mutex;
    // Query name -> query index.
    std::unordered_map<std::string, uint32_t> registry;
    // "<section>_start" name -> collected (end - start) samples.
    std::unordered_map<std::string, std::vector<uint64_t>> results;
    // Time of the last stdout report; default-constructed until the first one.
    std::chrono::time_point<std::chrono::high_resolution_clock> lastPrint;
};



#endif //QUERYMANAGER_H
12 changes: 12 additions & 0 deletions vulkan/VulkanContext.cpp
Original file line number Diff line number Diff line change
@@ -186,6 +186,13 @@ VulkanContext::QueueFamilyIndices VulkanContext::findQueueFamilies() {
return indices;
}

void VulkanContext::createQueryPool() {
vk::QueryPoolCreateInfo queryPoolCreateInfo = {};
queryPoolCreateInfo.queryType = vk::QueryType::eTimestamp;
queryPoolCreateInfo.queryCount = 20;
queryPool = device->createQueryPoolUnique(queryPoolCreateInfo);
}

void VulkanContext::createLogicalDevice(vk::PhysicalDeviceFeatures deviceFeatures,
vk::PhysicalDeviceVulkan11Features deviceFeatures11,
vk::PhysicalDeviceVulkan12Features deviceFeatures12) {
@@ -212,6 +219,10 @@ void VulkanContext::createLogicalDevice(vk::PhysicalDeviceFeatures deviceFeature
createInfo.pNext = &deviceFeatures11;
deviceFeatures11.pNext = &deviceFeatures12;

vk::PhysicalDeviceHostQueryResetFeatures hostQueryResetFeatures = {};
hostQueryResetFeatures.hostQueryReset = VK_TRUE;
deviceFeatures12.pNext = &hostQueryResetFeatures;

device = physicalDevice.createDeviceUnique(createInfo);

for (auto unique_queue_family: uniqueQueueFamilies) {
@@ -235,6 +246,7 @@ void VulkanContext::createLogicalDevice(vk::PhysicalDeviceFeatures deviceFeature
// Create VMA
setupVma();
createCommandPool();
createQueryPool();
}

vk::UniqueCommandBuffer VulkanContext::beginOneTimeCommandBuffer() {
4 changes: 4 additions & 0 deletions vulkan/VulkanContext.h
Original file line number Diff line number Diff line change
@@ -74,6 +74,8 @@ class VulkanContext {

VulkanContext::QueueFamilyIndices findQueueFamilies();

void createQueryPool();

void createLogicalDevice(vk::PhysicalDeviceFeatures deviceFeatures, vk::PhysicalDeviceVulkan11Features deviceFeatures11, vk::PhysicalDeviceVulkan12Features deviceFeatures12);

void createDescriptorPool(uint8_t framesInFlight);
@@ -92,6 +94,8 @@ class VulkanContext {
VmaAllocator allocator;

vk::UniqueDescriptorPool descriptorPool;
vk::UniqueQueryPool queryPool;

private:
vk::DynamicLoader dl;
std::vector<std::string> instanceExtensions;