Skip to content

Commit

Permalink
CUDA Throughput Improvement, main branch (2024.05.06.) (#576)
Browse files Browse the repository at this point in the history
* Made I/O functions use traccc::io::get_absolute_path(...) during reading.

This is to make it possible to pick up files from anywhere, as long
as the user provides an absolute file name.

* Returning a reduced amount of info about the fitted tracks.

This is to avoid the (currently) very expensive copy of the jagged
vector of track states back to the host.

* Fixed the geometry file locations for the tests.

So that they would not be treated as absolute path names
by the updated I/O code.

* Removed the comparison of chi2 values from the fit results.

This is a temporary measure, until the chi2 values "start behaving".
But for now it's not completely ridiculous that the CPU and GPU
code would produce different values for that parameter.
  • Loading branch information
krasznaa authored May 6, 2024
1 parent 592231d commit 2b0b03a
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 69 deletions.
8 changes: 4 additions & 4 deletions examples/run/cuda/full_chain_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ full_chain_algorithm::full_chain_algorithm(
m_fitting(fitting_config,
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
m_stream),
m_result_copy(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy),
m_finder_config(finder_config),
m_grid_config(grid_config),
m_filter_config(filter_config),
Expand Down Expand Up @@ -113,7 +112,6 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
m_fitting(parent.m_fitting_config,
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
m_stream),
m_result_copy(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy),
m_finder_config(parent.m_finder_config),
m_grid_config(parent.m_grid_config),
m_filter_config(parent.m_filter_config),
Expand Down Expand Up @@ -179,8 +177,10 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
m_fitting(m_device_detector_view, m_field, navigation_buffer,
track_candidates);

// Return the final container, copied back to the host.
return m_result_copy(track_states);
// Copy a limited amount of result data back to the host.
output_type result{&m_host_mr};
m_copy(track_states.headers, result)->wait();
return result;

}
// If not, copy the track parameters back to the host, and return a dummy
Expand Down
12 changes: 5 additions & 7 deletions examples/run/cuda/full_chain_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "traccc/cuda/seeding/seeding_algorithm.hpp"
#include "traccc/cuda/seeding/track_params_estimation.hpp"
#include "traccc/cuda/utils/stream.hpp"
#include "traccc/device/container_d2h_copy_alg.hpp"
#include "traccc/edm/cell.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
Expand All @@ -30,6 +29,7 @@
#include "detray/propagator/rk_stepper.hpp"

// VecMem include(s).
#include <vecmem/containers/vector.hpp>
#include <vecmem/memory/binary_page_memory_resource.hpp>
#include <vecmem/memory/cuda/device_memory_resource.hpp>
#include <vecmem/memory/memory_resource.hpp>
Expand All @@ -44,9 +44,10 @@ namespace traccc::cuda {
///
/// At least as much as is implemented in the project at any given moment.
///
class full_chain_algorithm : public algorithm<track_state_container_types::host(
const cell_collection_types::host&,
const cell_module_collection_types::host&)> {
class full_chain_algorithm
: public algorithm<vecmem::vector<fitting_result<default_algebra>>(
const cell_collection_types::host&,
const cell_module_collection_types::host&)> {

public:
/// @name Type declaration(s)
Expand Down Expand Up @@ -161,9 +162,6 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
/// Track fitting algorithm
fitting_algorithm m_fitting;

/// Algorithm copying the result container back to the host
device::container_d2h_copy_alg<track_state_container_types> m_result_copy;

/// @}

/// @name Algorithm configurations
Expand Down
31 changes: 21 additions & 10 deletions io/src/event_map2.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022 CERN for the benefit of the ACTS project
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/
Expand All @@ -13,26 +13,37 @@
#include "traccc/io/csv/make_measurement_reader.hpp"
#include "traccc/io/csv/make_particle_reader.hpp"
#include "traccc/io/utils.hpp"

// System include(s).
#include <filesystem>

namespace traccc {

event_map2::event_map2(std::size_t event, const std::string& measurement_dir,
const std::string& hit_dir,
const std::string particle_dir) {

std::string io_measurement_hit_id_file =
io::data_directory() + hit_dir +
io::get_event_filename(event, "-measurement-simhit-map.csv");
io::get_absolute_path((std::filesystem::path(hit_dir) /
std::filesystem::path(io::get_event_filename(
event, "-measurement-simhit-map.csv")))
.native());

std::string io_particle_file =
io::data_directory() + particle_dir +
io::get_event_filename(event, "-particles.csv");
std::string io_particle_file = io::get_absolute_path(
(std::filesystem::path(particle_dir) /
std::filesystem::path(io::get_event_filename(event, "-particles.csv")))
.native());

std::string io_hit_file = io::data_directory() + hit_dir +
io::get_event_filename(event, "-hits.csv");
std::string io_hit_file = io::get_absolute_path(
(std::filesystem::path(hit_dir) /
std::filesystem::path(io::get_event_filename(event, "-hits.csv")))
.native());

std::string io_measurement_file =
io::data_directory() + measurement_dir +
io::get_event_filename(event, "-measurements.csv");
io::get_absolute_path((std::filesystem::path(measurement_dir) /
std::filesystem::path(io::get_event_filename(
event, "-measurements.csv")))
.native());

auto mreader = io::csv::make_measurement_reader(io_measurement_file);

Expand Down
33 changes: 23 additions & 10 deletions io/src/mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include "traccc/clusterization/measurement_creation_algorithm.hpp"
#include "traccc/clusterization/sparse_ccl_algorithm.hpp"

// System include(s).
#include <filesystem>

namespace traccc {

particle_map generate_particle_map(std::size_t event,
Expand All @@ -31,8 +34,10 @@ particle_map generate_particle_map(std::size_t event,

// Read the particles from the relevant event file
std::string io_particles_file =
io::data_directory() + particle_dir +
io::get_event_filename(event, "-particles_initial.csv");
io::get_absolute_path((std::filesystem::path(particle_dir) /
std::filesystem::path(io::get_event_filename(
event, "-particles_initial.csv")))
.native());

auto preader = io::csv::make_particle_reader(io_particles_file);

Expand Down Expand Up @@ -61,8 +66,10 @@ hit_particle_map generate_hit_particle_map(std::size_t event,
auto pmap = generate_particle_map(event, particle_dir);

// Read the hits from the relevant event file
std::string io_hits_file = io::data_directory() + hits_dir +
io::get_event_filename(event, "-hits.csv");
std::string io_hits_file = io::get_absolute_path(
(std::filesystem::path(hits_dir) /
std::filesystem::path(io::get_event_filename(event, "-hits.csv")))
.native());

auto hreader = io::csv::make_hit_reader(io_hits_file);

Expand Down Expand Up @@ -93,17 +100,21 @@ hit_map generate_hit_map(std::size_t event, const std::string& hits_dir) {
hit_map result;

// Read the hits from the relevant event file
std::string io_hits_file = io::data_directory() + hits_dir +
io::get_event_filename(event, "-hits.csv");
std::string io_hits_file = io::get_absolute_path(
(std::filesystem::path(hits_dir) /
std::filesystem::path(io::get_event_filename(event, "-hits.csv")))
.native());

auto hreader = io::csv::make_hit_reader(io_hits_file);

io::csv::hit iohit;

// Read the hits from the relevant event file
std::string io_measurement_hit_id_file =
io::data_directory() + hits_dir +
io::get_event_filename(event, "-measurement-simhit-map.csv");
io::get_absolute_path((std::filesystem::path(hits_dir) /
std::filesystem::path(io::get_event_filename(
event, "-measurement-simhit-map.csv")))
.native());

auto mhid_reader =
io::csv::make_measurement_hit_id_reader(io_measurement_hit_id_file);
Expand Down Expand Up @@ -141,8 +152,10 @@ hit_cell_map generate_hit_cell_map(std::size_t event,
auto hmap = generate_hit_map(event, hits_dir);

// Read the cells from the relevant event file
std::string io_cells_file = io::data_directory() + cells_dir +
io::get_event_filename(event, "-cells.csv");
std::string io_cells_file = io::get_absolute_path(
(std::filesystem::path(cells_dir) /
std::filesystem::path(io::get_event_filename(event, "-cells.csv")))
.native());

auto creader = io::csv::make_cell_reader(io_cells_file);

Expand Down
28 changes: 20 additions & 8 deletions io/src/read_cells.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "read_binary.hpp"
#include "traccc/io/utils.hpp"

// System include(s).
#include <filesystem>

namespace traccc::io {

void read_cells(
Expand All @@ -23,19 +26,28 @@ void read_cells(

switch (format) {
case data_format::csv: {
read_cells(out,
data_directory() + directory.data() +
get_event_filename(event, "-cells.csv"),
format, geom, dconfig, barcode_map, deduplicate);
read_cells(
out,
get_absolute_path((std::filesystem::path(directory) /
std::filesystem::path(
get_event_filename(event, "-cells.csv")))
.native()),
format, geom, dconfig, barcode_map, deduplicate);
break;
}
case data_format::binary: {
details::read_binary_collection<cell_collection_types::host>(
out.cells, data_directory() + directory.data() +
get_event_filename(event, "-cells.dat"));
out.cells,
get_absolute_path((std::filesystem::path(directory) /
std::filesystem::path(
get_event_filename(event, "-cells.dat")))
.native()));
details::read_binary_collection<cell_module_collection_types::host>(
out.modules, data_directory() + directory.data() +
get_event_filename(event, "-modules.dat"));
out.modules,
get_absolute_path((std::filesystem::path(directory) /
std::filesystem::path(get_event_filename(
event, "-modules.dat")))
.native()));
break;
}
default:
Expand Down
4 changes: 2 additions & 2 deletions io/src/read_digitization_config.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022 CERN for the benefit of the ACTS project
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/
Expand Down Expand Up @@ -64,7 +64,7 @@ digitization_config read_digitization_config(std::string_view filename,
data_format format) {

// Construct the full filename.
std::string full_filename = data_directory() + filename.data();
std::string full_filename = get_absolute_path(filename);

// Decide how to read the file.
switch (format) {
Expand Down
2 changes: 1 addition & 1 deletion io/src/read_geometry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ std::pair<geometry,
read_geometry(std::string_view filename, data_format format) {

// Construct the full file name.
const std::string full_filename = data_directory() + filename.data();
const std::string full_filename = get_absolute_path(filename);

// Decide how to read the file.
switch (format) {
Expand Down
24 changes: 17 additions & 7 deletions io/src/read_measurements.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022 CERN for the benefit of the ACTS project
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/
Expand All @@ -12,6 +12,9 @@
#include "read_binary.hpp"
#include "traccc/io/utils.hpp"

// System include(s).
#include <filesystem>

namespace traccc::io {

void read_measurements(measurement_reader_output& out, std::size_t event,
Expand All @@ -21,20 +24,27 @@ void read_measurements(measurement_reader_output& out, std::size_t event,
case data_format::csv: {
read_measurements(
out,
data_directory() + directory.data() +
get_event_filename(event, "-measurements.csv"),
get_absolute_path((std::filesystem::path(directory) /
std::filesystem::path(get_event_filename(
event, "-measurements.csv")))
.native()),
format);
break;
}
case data_format::binary: {

details::read_binary_collection<measurement_collection_types::host>(
out.measurements,
data_directory() + directory.data() +
get_event_filename(event, "-measurements.dat"));
get_absolute_path((std::filesystem::path(directory) /
std::filesystem::path(get_event_filename(
event, "-measurements.dat")))
.native()));
details::read_binary_collection<cell_module_collection_types::host>(
out.modules, data_directory() + directory.data() +
get_event_filename(event, "-modules.dat"));
out.modules,
get_absolute_path((std::filesystem::path(directory) /
std::filesystem::path(get_event_filename(
event, "-modules.dat")))
.native()));
break;
}
default:
Expand Down
11 changes: 8 additions & 3 deletions io/src/read_particles.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022 CERN for the benefit of the ACTS project
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/
Expand All @@ -11,6 +11,9 @@
#include "csv/read_particles.hpp"
#include "traccc/io/utils.hpp"

// System include(s).
#include <filesystem>

namespace traccc::io {

particle_collection_types::host read_particles(std::size_t event,
Expand All @@ -21,8 +24,10 @@ particle_collection_types::host read_particles(std::size_t event,
switch (format) {
case data_format::csv:
return read_particles(
data_directory() + directory.data() +
get_event_filename(event, "-particles.csv"),
get_absolute_path((std::filesystem::path(directory) /
std::filesystem::path(get_event_filename(
event, "-particles.csv")))
.native()),
format, mr);
default:
throw std::invalid_argument("Unsupported data format");
Expand Down
Loading

0 comments on commit 2b0b03a

Please sign in to comment.