Skip to content

Commit

Permalink
Add MultipatchField types
Browse files Browse the repository at this point in the history
Add two new multipatch types:
- `MultipatchField` : To store all `Field` types (`Field`, `VectorField`, `DerivField`)
- `MultipatchFieldMem` : To store all `FieldMem` types (`FieldMem`, `VectorFieldMem`, `DerivFieldMem`)

The `MultipatchField` type is useful to provide all the aliases usually provided by `Field` types.
The `MultipatchFieldMem` type is necessary to allow these types to have `__host__` only versions of the functions in `MultipatchField`.

2 utility booleans are added:
- `has_data_access_methods_v` : To tell if a type is a `Field` type on which data access methods can be called:
     - `get_idx_range`
     - `get_field`
     - `get_const_field`
- `is_mem_type_v` : To tell if a type might allocate a chunk of memory (`FieldMem`, `VectorFieldMem`, `DerivFieldMem`)

These booleans are activated via `enable_X` booleans.

See merge request gysela-developpers/gyselalibxx!735

--------------------------------------------
  • Loading branch information
EmilyBourne committed Oct 21, 2024
1 parent 663fd65 commit 0c10faa
Show file tree
Hide file tree
Showing 19 changed files with 561 additions and 106 deletions.
2 changes: 1 addition & 1 deletion ci_tools/gyselalib_static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
HOME_DIR = Path(__file__).parent.parent.absolute()
global_folders = [HOME_DIR / f for f in ('src', 'simulations', 'tests')]

auto_functions = set()
auto_functions = set(['build_kokkos_layout'])
field_mem_functions = set()

def report_error(level, file, linenr, message):
Expand Down
11 changes: 8 additions & 3 deletions src/data_types/derivative_field.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ template <class ElementType, class SupportType, class LayoutStridedPolicy, class
inline constexpr bool enable_borrowed_deriv_field<
DerivField<ElementType, SupportType, LayoutStridedPolicy, MemorySpace>> = true;

template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
inline constexpr bool enable_data_access_methods<
DerivField<ElementType, SupportType, LayoutStridedPolicy, MemorySpace>> = true;

namespace ddcHelper {

/**
Expand Down Expand Up @@ -170,7 +174,7 @@ class DerivField<ElementType, IdxRange<DDims...>, LayoutStridedPolicy, MemorySpa
using reference = typename chunk_type::reference;

/// @brief The number of chunks which must be created to describe this object.
static constexpr int n_fields = base_type::n_fields;
using base_type::n_fields;

private:
/** @brief Get the subindex range to be extracted from a DerivFieldMem to build the internal_chunk at position ArrayIndex
Expand Down Expand Up @@ -251,7 +255,8 @@ class DerivField<ElementType, IdxRange<DDims...>, LayoutStridedPolicy, MemorySpa
int NDerivs,
class Allocator,
class = std::enable_if_t<std::is_same_v<typename Allocator::memory_space, MemorySpace>>>
constexpr DerivField(DerivFieldMem<OElementType, index_range_type, NDerivs, Allocator>& field)
explicit constexpr DerivField(
DerivFieldMem<OElementType, index_range_type, NDerivs, Allocator>& field)
: base_type(
field.m_physical_idx_range,
field.m_deriv_idx_range,
Expand All @@ -273,7 +278,7 @@ class DerivField<ElementType, IdxRange<DDims...>, LayoutStridedPolicy, MemorySpa
int NDerivs,
class Allocator,
class = std::enable_if_t<std::is_same_v<typename Allocator::memory_space, MemorySpace>>>
constexpr DerivField(
explicit constexpr DerivField(
DerivFieldMem<OElementType, index_range_type, NDerivs, Allocator> const& field)
: base_type(
field.m_physical_idx_range,
Expand Down
13 changes: 11 additions & 2 deletions src/data_types/derivative_field_mem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#pragma once
#include <ddc/ddc.hpp>

#include "ddc_alias_inline_functions.hpp"
#include "ddc_aliases.hpp"
#include "derivative_field.hpp"
#include "derivative_field_common.hpp"
Expand All @@ -17,6 +18,14 @@ template <class ElementType, class SupportType, int NDerivs, class MemSpace>
inline constexpr bool
enable_deriv_field<DerivFieldMem<ElementType, SupportType, NDerivs, MemSpace>> = true;

template <class ElementType, class SupportType, int NDerivs, class Allocator>
inline constexpr bool enable_data_access_methods<
DerivFieldMem<ElementType, SupportType, NDerivs, Allocator>> = true;

template <class ElementType, class SupportType, int NDerivs, class Allocator>
inline constexpr bool
enable_mem_type<DerivFieldMem<ElementType, SupportType, NDerivs, Allocator>> = true;


/**
* @brief A class which holds a chunk of memory describing a field and its derivatives.
Expand Down Expand Up @@ -132,7 +141,7 @@ class DerivFieldMem<ElementType, IdxRange<DDims...>, NDerivs, MemSpace>
using internal_mdspan_type = typename base_type::internal_mdspan_type;

/// @brief The number of chunks which must be created to describe this object.
static constexpr int n_fields = base_type::n_fields;
using base_type::n_fields;

private:
/// @brief A function to get the index range along direction Tag for the ArrayIndex-th element of internal_fields.
Expand Down Expand Up @@ -354,7 +363,7 @@ class DerivFieldMem<ElementType, IdxRange<DDims...>, NDerivs, MemSpace>
*/
span_type span_view()
{
return *this;
return span_type(*this);
}
};

Expand Down
10 changes: 10 additions & 0 deletions src/data_types/vector_field.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT

#pragma once
#include "ddc_alias_inline_functions.hpp"
#include "ddc_aliases.hpp"
#include "vector_field_mem.hpp"

Expand All @@ -22,6 +23,15 @@ template <
inline constexpr bool enable_field<
VectorField<ElementType, IdxRangeType, NDTag, LayoutStridedPolicy, MemorySpace>> = true;

template <
class ElementType,
class IdxRangeType,
class NDTag,
class LayoutStridedPolicy,
class MemorySpace>
inline constexpr bool enable_data_access_methods<
VectorField<ElementType, IdxRangeType, NDTag, LayoutStridedPolicy, MemorySpace>> = true;

template <
class ElementType,
class IdxRangeType,
Expand Down
9 changes: 9 additions & 0 deletions src/data_types/vector_field_mem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#pragma once
#include <ddc/ddc.hpp>

#include "ddc_alias_inline_functions.hpp"
#include "ddc_aliases.hpp"
#include "vector_field_common.hpp"

Expand All @@ -21,6 +22,14 @@ template <class ElementType, class IdxRangeType, class DimSeq, class MemSpace>
inline constexpr bool
enable_field<VectorFieldMem<ElementType, IdxRangeType, DimSeq, MemSpace>> = true;

template <class ElementType, class IdxRangeType, class DimSeq, class Allocator>
inline constexpr bool enable_data_access_methods<
VectorFieldMem<ElementType, IdxRangeType, DimSeq, Allocator>> = true;

template <class ElementType, class IdxRangeType, class DimSeq, class Allocator>
inline constexpr bool
enable_mem_type<VectorFieldMem<ElementType, IdxRangeType, DimSeq, Allocator>> = true;

/**
* @brief Pre-declaration of VectorField.
*/
Expand Down
189 changes: 189 additions & 0 deletions src/multipatch/data_types/multipatch_field.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// SPDX-License-Identifier: MIT

#pragma once
#include "multipatch_type.hpp"


/**
* @brief A class to store field objects on patches.
*
* On a multipatch domain when we have objects and types defined on different patches, e.g. fields.
* They can be stored in this class and then be accessed by the patch they are defined
* on.
*
* @tparam T The type of the fields/derivative fields/vector fields that are stored on the given patches.
* @tparam Patches The patches of the objects in the same order of the patches
* that the given objects are defined on.
*
* @warning The objects have to be defined on different patches. Otherwise retrieving
* them by their patch is ill-defined.
*/
template <template <typename P> typename T, class... Patches>
class MultipatchField : public MultipatchType<T, Patches...>
{
static_assert(
(has_data_access_methods_v<T<Patches>> && ...),
"The MultipatchField type should only contain instances of objects that can be "
"manipulated like fields.");

public:
/// @brief The MultipatchType from which this class inherits
using base_type = MultipatchType<T, Patches...>;

/// @brief A tag storing the order of Patches in this MultipatchField
using typename base_type::PatchOrdering;

private:
/// An internal type alias that is only instantiated if the idx_range method is called.
template <class Patch>
using InternalIdxRangeOnPatch = typename T<Patch>::discrete_domain_type;

/// An internal type alias that is only instantiated if the get_const_field method is called.
template <class Patch>
using InternalFieldOnPatch = typename T<Patch>::span_type;

/// An internal type alias that is only instantiated if the get_const_field method is called.
template <class Patch>
using InternalConstFieldOnPatch = typename T<Patch>::view_type;

template <template <typename P> typename OT, class... OPatches>
friend class MultipatchField;

using example_element = T<ddc::type_seq_element_t<0, PatchOrdering>>;

static_assert(
!is_mem_type_v<example_element>,
"For correct GPU handling a FieldMem object must be saved in a MultipatchFieldMem "
"type.");

public:
/// The type of a modifiable reference to this multipatch field
using span_type = MultipatchField<InternalFieldOnPatch, Patches...>;
/// The type of a constant reference to this multipatch field
using view_type = MultipatchField<InternalConstFieldOnPatch, Patches...>;
/// The type of the index ranges that can be used to access this field.
using discrete_domain_type = MultipatchType<InternalIdxRangeOnPatch, Patches...>;
/// The memory space (CPU/GPU) where the data is saved.
using memory_space = typename example_element::memory_space;
/// The type of the elements inside the field.
using element_type = typename example_element::element_type;

public:
/**
* Instantiate the MultipatchField class from an arbitrary number of objects.
*
* @param args The objects to be stored in the class.
*/
explicit KOKKOS_FUNCTION MultipatchField(T<Patches>... args) : base_type(args...) {}

/**
* Create a MultipatchField class by copying an instance of another compatible MultipatchField.
*
* A compatible MultipatchField is one which uses all the patches used by this class. The object
* being copied may include more patches than this MultipatchField. Further the original
* MultipatchField must store objects of the correct type (the type template may be different
* but return the same type depending on how it is designed).
*
* @param other The equivalent MultipatchField being copied.
*/
template <class MultipatchObj, std::enable_if_t<!is_mem_type_v<MultipatchObj>, bool> = true>
KOKKOS_FUNCTION MultipatchField(MultipatchObj& other)
: base_type(std::move(T<Patches>(other.template get<Patches>()))...)
{
}

/**
* Create a MultipatchField class from a compatible MultipatchFieldMem.
*
* A compatible MultipatchField is one which uses all the patches used by this class. The object
* being copied may include more patches than this MultipatchField. Further the original
* MultipatchField must store objects of the correct type.
*
* @param other The MultipatchFieldMem being accessed.
*/
template <class MultipatchObj, std::enable_if_t<is_mem_type_v<MultipatchObj>, bool> = true>
MultipatchField(MultipatchObj& other)
: base_type(std::move(T<Patches>(other.template get<Patches>()))...)
{
}

/**
* Create a MultipatchField class from an r-value (temporary) instance of another MultipatchField which
* uses the same type for the internal tuple.
*
* @param other The equivalent MultipatchField being copied.
*/
template <template <typename P> typename OT, class... OPatches>
MultipatchField(MultipatchField<OT, OPatches...>&& other) : base_type(other)
{
static_assert(
std::is_same_v<ddc::detail::TypeSeq<Patches...>, ddc::detail::TypeSeq<OPatches...>>,
"Cannot create a MultipatchField from a temporary MultipatchField with a different "
"ordering");
static_assert(
std::is_same_v<std::tuple<T<Patches>...>, std::tuple<OT<OPatches>...>>,
"MultipatchFields are not equivalent");
}

KOKKOS_FUNCTION ~MultipatchField() {}

/**
* Retrieve an object from the patch that it is defined on.
*
* @tparam Patch The patch of the object to be returned.
* @return The object on the given patch.
*/
template <class Patch>
KOKKOS_FUNCTION auto get() const
{
return ::get_const_field(std::get<T<Patch>>(base_type::m_tuple));
}

/**
* Retrieve an object from the patch that it is defined on.
*
* @tparam Patch The patch of the object to be returned.
* @return The object on the given patch.
*/
template <class Patch>
KOKKOS_FUNCTION auto get()
{
return ::get_field(std::get<T<Patch>>(base_type::m_tuple));
}

/**
* @brief Get a MultipatchType containing the index ranges on which the fields are defined.
*
* @returns The set of index ranges on which the set of fields stored in this class are defined.
*/
auto idx_range() const
{
return MultipatchType<InternalIdxRangeOnPatch, Patches...>(
get_idx_range(std::get<T<Patches>>(base_type::m_tuple))...);
}

/**
* @brief Get a MultipatchField containing modifiable fields.
*
* @returns A set of modifiable fields providing access to the fields stored in this class.
*/
KOKKOS_FUNCTION auto get_field()
{
return MultipatchField<InternalFieldOnPatch, Patches...>(
::get_field(std::get<T<Patches>>(base_type::m_tuple))...);
}

/**
* @brief Get a MultipatchField containing constant fields so the values cannot be modified.
*
* @returns A set of constant fields providing access to the fields stored in this class.
*/
KOKKOS_FUNCTION auto get_const_field() const
{
return MultipatchField<InternalConstFieldOnPatch, Patches...>(
::get_const_field(std::get<T<Patches>>(base_type::m_tuple))...);
}
};

template <template <typename P> typename T, class... Patches>
inline constexpr bool enable_multipatch_type<MultipatchField<T, Patches...>> = true;
Loading

0 comments on commit 0c10faa

Please sign in to comment.