Skip to content

Commit

Permalink
ref: grid payloads cleanup (#521)
Browse files Browse the repository at this point in the history
Cleans up the IO payloads (in particular the payloads for the links are put into a single place) and switches the grid writer to write the payloads by volume, which will be needed in order to fill the grid builders. Also removes some old payloads from the grids
  • Loading branch information
niermann999 authored Aug 3, 2023
1 parent e50e0f9 commit 265b987
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 387 deletions.
22 changes: 22 additions & 0 deletions io/include/detray/io/common/detail/definitions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "detray/masks/ring2D.hpp"
#include "detray/masks/single3D.hpp"
#include "detray/masks/trapezoid2D.hpp"
#include "detray/materials/material_rod.hpp"
#include "detray/materials/material_slab.hpp"
#include "detray/utils/type_registry.hpp"

namespace detray {
Expand Down Expand Up @@ -85,6 +87,26 @@ enum class material_type : unsigned int {
unknown = 11u
};

/// Infer the IO material id from the material type
template <typename material_t>
constexpr material_type get_material_id() {
using scalar_t = typename material_t::scalar_type;

/// Register the material types to the @c material_type enum
using mat_registry =
type_registry<material_type, annulus2D<>, cuboid3D<>, cylinder2D<>,
cylinder3D, rectangle2D<>, ring2D<>, trapezoid2D<>,
line<true>, line<false>, material_slab<scalar_t>,
material_rod<scalar_t>>;

// Find the correct material IO id;
if constexpr (mat_registry::is_defined(material_t{})) {
return mat_registry::get_id(material_t{});
} else {
return material_type::unknown;
}
}

/// Enumerate the different acceleration data structures
enum class acc_type : unsigned int {
brute_force = 0u, // try all
Expand Down
8 changes: 2 additions & 6 deletions io/include/detray/io/common/geometry_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,6 @@ class geometry_reader : public reader_interface<detector_t> {
det_builder.template set_volume_finder();
}

/// @returns a link from its io payload @param link_data
static dindex deserialize(const single_link_payload& link_data) {
return static_cast<dindex>(link_data.link);
}

/// @returns a surface transform from its io payload @param trf_data
static typename detector_t::transform3 deserialize(
const transform_payload& trf_data) {
Expand Down Expand Up @@ -177,7 +172,8 @@ class geometry_reader : public reader_interface<detector_t> {
std::back_inserter(mask_boundaries));

return {deserialize(sf_data.transform),
static_cast<nav_link_t>(deserialize(sf_data.mask.volume_link)),
static_cast<nav_link_t>(
base_type::deserialize(sf_data.mask.volume_link)),
std::move(mask_boundaries)};
}

Expand Down
70 changes: 14 additions & 56 deletions io/include/detray/io/common/geometry_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#include "detray/io/common/io_interface.hpp"
#include "detray/io/common/payloads.hpp"
#include "detray/masks/masks.hpp"
#include "detray/materials/material_rod.hpp"
#include "detray/materials/material_slab.hpp"

// System include(s)
#include <algorithm>
Expand Down Expand Up @@ -76,14 +74,6 @@ class geometry_writer : public writer_interface<detector_t> {
return det_data;
}

/// Serialize a link @param idx into its io payload
static single_link_payload serialize(const std::size_t idx) {
single_link_payload link_data;
link_data.link = idx;

return link_data;
}

/// Serialize a surface transform @param trf into its io payload
static transform_payload serialize(
const typename detector_t::transform3& trf) {
Expand All @@ -109,7 +99,7 @@ class geometry_writer : public writer_interface<detector_t> {

mask_data.shape = io::detail::get_shape_id<typename mask_t::shape>();

mask_data.volume_link = serialize(m.volume_link());
mask_data.volume_link = base_type::serialize(m.volume_link());

mask_data.boundaries.resize(mask_t::boundaries::e_size);
std::copy(std::cbegin(m.values()), std::cend(m.values()),
Expand All @@ -118,29 +108,6 @@ class geometry_writer : public writer_interface<detector_t> {
return mask_data;
}

/// Serialize a surface material link @param m into its io payload
template <class material_t>
static material_link_payload serialize(const std::size_t idx) {
using scalar_t = typename material_t::scalar_type;
using type_id = material_link_payload::material_type;

material_link_payload mat_data;

// Find the correct material type index (use name for simplicity)
if constexpr (std::is_same_v<material_t, material_slab<scalar_t>>) {
mat_data.type = type_id::slab;
} else if constexpr (std::is_same_v<material_t,
material_rod<scalar_t>>) {
mat_data.type = type_id::rod;
} else {
mat_data.type = type_id::unknown;
}

mat_data.index = idx;

return mat_data;
}

/// Serialize a detector surface @param sf into its io payload
static surface_payload serialize(const surface<detector_t>& sf) {
surface_payload sf_data;
Expand All @@ -150,28 +117,18 @@ class geometry_writer : public writer_interface<detector_t> {
sf_data.transform = serialize(sf.transform({}));
sf_data.mask = sf.template visit_mask<get_mask_payload>();
sf_data.material = sf.template visit_material<get_material_payload>();
sf_data.source = serialize(sf.source());
sf_data.source = base_type::serialize(sf.source());

return sf_data;
}

/// Serialize a link @param idx into its io payload
static acc_links_payload serialize(const acc_links_payload::acc_type id,
const std::size_t idx) {
acc_links_payload link_data;
link_data.type = id;
link_data.index = idx;

return link_data;
}

/// Serialize a detector portal @param sf into its io payload
static volume_payload serialize(
const typename detector_t::volume_type& vol_desc, const detector_t& det,
const std::string& name) {
volume_payload vol_data;

vol_data.index = serialize(vol_desc.index());
vol_data.index = base_type::serialize(vol_desc.index());
vol_data.name = name;
vol_data.transform =
serialize(det.transform_store()[vol_desc.transform()]);
Expand Down Expand Up @@ -218,8 +175,11 @@ class geometry_writer : public writer_interface<detector_t> {
template <typename material_group_t, typename index_t>
inline auto operator()(const material_group_t&,
const index_t& index) const {
return geometry_writer<detector_t>::template serialize<
typename material_group_t::value_type>(index);
using material_t = typename material_group_t::value_type;

// Find the correct material type index
return base_type::serialize(
io::detail::get_material_id<material_t>(), index);
}
};

Expand All @@ -231,16 +191,14 @@ class geometry_writer : public writer_interface<detector_t> {

using accel_t = typename acc_group_t::value_type;

if constexpr (detail::is_grid_v<accel_t>) {
constexpr auto id{io::detail::get_grid_id<accel_t>()};
auto id{acc_links_payload::type_id::unknown};

return geometry_writer<detector_t>::serialize(id, index);
} else {
// This functor is only called for accelerator data structures
// that are not 'brute force'
return geometry_writer<detector_t>::serialize(
acc_links_payload::acc_type::unknown, index);
// Only serialize grids
if constexpr (detail::is_grid_v<accel_t>) {
id = io::detail::get_grid_id<accel_t>();
}

return base_type::serialize(id, index);
}
};
};
Expand Down
65 changes: 37 additions & 28 deletions io/include/detray/io/common/grid_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ class grid_writer : public writer_interface<detector_t> {
const std::string_view det_name) {
grid_header_payload header_data;

header_data.version = detail::get_detray_version();
header_data.detector = det_name;
header_data.tag = tag;
header_data.date = detail::get_current_date();
header_data.common = base_type::serialize(det_name, tag);

header_data.n_grids = get_n_grids(det.surface_store());
header_data.sub_header.emplace();
auto& grid_sub_header = header_data.sub_header.value();
grid_sub_header.n_grids = get_n_grids(det.surface_store());

return header_data;
}
Expand All @@ -62,21 +61,36 @@ class grid_writer : public writer_interface<detector_t> {

detector_grids_payload grids_data;

// Access the acceleration data structures recursively
get_grid_payload(det.surface_store(), grids_data);
for (const auto& vol_desc : det.volumes()) {
// Links to all acceleration data structures in the volume
const auto& multi_link = vol_desc.full_link();

for (dindex i = 0u; i < multi_link.size(); ++i) {
const auto& acc_link = multi_link[i];
// Don't look at empty links
if (acc_link.is_invalid()) {
continue;
}

// If the accelerator is a grid, insert the payload
det.surface_store().template visit<get_grid_payload>(
acc_link, vol_desc.index(), grids_data);
}
}

return grids_data;
}

/// Serialize a grid @param gr of type @param type and index @param idx
/// into its io payload
template <class grid_t>
static grid_payload serialize(io::detail::acc_type type,
static grid_payload serialize(std::size_t volume_index,
io::detail::acc_type type,
const std::size_t idx, const grid_t& gr) {
grid_payload grid_data;

grid_data.type = type;
grid_data.index = idx;
grid_data.volume_link = base_type::serialize(volume_index);
grid_data.acc_link = base_type::serialize(type, idx);

// Serialize the multi-axis into single axis payloads
const std::array<axis_payload, grid_t::Dim> axes_data =
Expand Down Expand Up @@ -152,30 +166,25 @@ class grid_writer : public writer_interface<detector_t> {
}

private:
/// Retrieve @c grid_payload (s) from grid collection elements
template <std::size_t I = 0u>
static void get_grid_payload(
const typename detector_t::surface_container& store,
detector_grids_payload& grids_data) {
/// Retrieve a @c grid_payload from grid collection elements
struct get_grid_payload {

using store_t = typename detector_t::surface_container;
constexpr auto coll_id{store_t::value_types::to_id(I)};
using accel_t = typename store_t::template get_type<coll_id>;

if constexpr (detail::is_grid_v<accel_t>) {
template <typename grid_group_t, typename index_t>
inline void operator()(
[[maybe_unused]] const grid_group_t& coll,
[[maybe_unused]] const index_t& index,
[[maybe_unused]] std::size_t volume_index,
[[maybe_unused]] detector_grids_payload& grids_data) const {
using accel_t = typename grid_group_t::value_type;

const auto& coll = store.template get<coll_id>();
if constexpr (detail::is_grid_v<accel_t>) {

for (unsigned int i = 0u; i < coll.size(); ++i) {
grids_data.grids.push_back(
serialize(io::detail::get_grid_id<accel_t>(), i, coll[i]));
serialize(volume_index, io::detail::get_grid_id<accel_t>(),
index, coll[index]));
}
}

if constexpr (I < store_t::n_collections() - 1u) {
get_grid_payload<I + 1>(store, grids_data);
}
}
};

/// Retrieve number of overall grids in detector
template <std::size_t I = 0u>
Expand Down
4 changes: 2 additions & 2 deletions io/include/detray/io/common/homogeneous_material_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

namespace detray {

/// @brief Abstract base class for tracking geometry readers
/// @brief Abstract base class for a homogeneous material reader.
template <class detector_t>
class homogeneous_material_reader : public reader_interface<detector_t> {

Expand Down Expand Up @@ -57,7 +57,7 @@ class homogeneous_material_reader : public reader_interface<detector_t> {
for (const auto& mv_data : det_mat_data.volumes) {
// Decorate the current volume builder with material
auto vm_builder = det_builder.template decorate<material_builder>(
static_cast<dindex>(mv_data.index));
base_type::deserialize(mv_data.volume_link));

// Add the material data to the factory
auto mat_factory = std::make_shared<material_factory<detector_t>>();
Expand Down
28 changes: 13 additions & 15 deletions io/include/detray/io/common/homogeneous_material_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "detray/io/common/payloads.hpp"
#include "detray/materials/material_rod.hpp"
#include "detray/materials/material_slab.hpp"
#include "detray/utils/ranges.hpp"

// System include(s)
#include <string>
Expand All @@ -26,6 +25,7 @@ template <class detector_t>
class homogeneous_material_writer : public writer_interface<detector_t> {

using base_type = writer_interface<detector_t>;
using scalar_t = typename detector_t::scalar_type;

protected:
/// Tag the writer as "homogeneous_material"
Expand Down Expand Up @@ -82,10 +82,10 @@ class homogeneous_material_writer : public writer_interface<detector_t> {
static material_volume_payload serialize(
const typename detector_t::volume_type& vol_desc,
const detector_t& det) {
using material_type = material_slab_payload::material_type;
using material_type = material_slab_payload::type;

material_volume_payload mv_data;
mv_data.index = vol_desc.index();
mv_data.volume_link = base_type::serialize(vol_desc.index());

// Find all surfaces that belong to the volume
for (const auto& sf_desc : det.surface_lookup()) {
Expand All @@ -97,17 +97,17 @@ class homogeneous_material_writer : public writer_interface<detector_t> {
const material_slab_payload mslp =
sf.template visit_material<get_material_payload>();

if (mslp.type == material_type::slab) {
if (mslp.mat_link.type == material_type::slab) {
mv_data.mat_slabs.push_back(mslp);
} else if (mslp.type == material_type::rod) {
} else if (mslp.mat_link.type == material_type::rod) {
if (not mv_data.mat_rods.has_value()) {
mv_data.mat_rods.emplace();
}
mv_data.mat_rods->push_back(mslp);
} else {
throw std::runtime_error(
"Material could not be matched to payload (found type " +
std::to_string(static_cast<int>(mslp.type)) + ")");
std::to_string(static_cast<int>(mslp.mat_link.type)) + ")");
}
}

Expand All @@ -131,26 +131,24 @@ class homogeneous_material_writer : public writer_interface<detector_t> {

/// Serialize a surface material slab @param mat_slab into its io payload
static material_slab_payload serialize(
const material_slab<typename detector_t::scalar_type>& mat_slab,
std::size_t idx) {
const material_slab<scalar_t>& mat_slab, std::size_t idx) {
material_slab_payload mat_data;

mat_data.type = material_slab_payload::material_type::slab;
mat_data.index = idx;
mat_data.mat_link = base_type::serialize(
io::detail::get_material_id<material_slab<scalar_t>>(), idx);
mat_data.thickness = mat_slab.thickness();
mat_data.mat = serialize(mat_slab.get_material());

return mat_data;
}

/// Serialize a line material rod @param mat_rod into its io payload
/// Serialize a wire material rod @param mat_rod into its io payload
static material_slab_payload serialize(
const material_rod<typename detector_t::scalar_type>& mat_rod,
std::size_t idx) {
const material_rod<scalar_t>& mat_rod, std::size_t idx) {
material_slab_payload mat_data;

mat_data.type = material_slab_payload::material_type::rod;
mat_data.index = idx;
mat_data.mat_link = base_type::serialize(
io::detail::get_material_id<material_rod<scalar_t>>(), idx);
mat_data.thickness = mat_rod.radius();
mat_data.mat = serialize(mat_rod.get_material());

Expand Down
Loading

0 comments on commit 265b987

Please sign in to comment.