Skip to content

Commit

Permalink
Add device-related fwd-decl to madness backend
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Dec 19, 2023
1 parent f424bcd commit 3bf1e16
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 111 deletions.
18 changes: 2 additions & 16 deletions ttg/ttg/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,14 @@
#include <memory>
#include "ttg/impl_selector.h"

#if defined(TTG_IMPL_DEVICE_SUPPORT)

namespace ttg {

template<typename T, typename Allocator = std::allocator<T>>
using Buffer = TTG_IMPL_NS::Buffer<T, Allocator>;

namespace detail {
template<typename T>
struct is_buffer : std::false_type
{ };

template<typename T, typename A>
struct is_buffer<ttg::Buffer<T, A>> : std::true_type
{ };

template<typename T>
constexpr bool is_buffer_v = is_buffer<T>::value;

static_assert(is_buffer_v<ttg::Buffer<double>>);
static_assert(is_buffer_v<TTG_IMPL_NS::Buffer<double>>);
} // namespace detail

} // namespace ttg

#endif // TTG_IMPL_DEVICE_SUPPORT
#endif // TTG_buffer_H
4 changes: 2 additions & 2 deletions ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ namespace ttg::device {
template<typename... Buffers>
[[nodiscard]]
inline auto wait(Buffers&&... args) {
static_assert(((ttg::detail::is_buffer_v<std::decay_t<Buffers>>
||ttg::detail::is_devicescratch_v<std::decay_t<Buffers>>)&&...),
static_assert(((ttg::meta::is_buffer_v<std::decay_t<Buffers>>
||ttg::meta::is_devicescratch_v<std::decay_t<Buffers>>)&&...),
"Only ttg::Buffer and ttg::devicescratch can be waited on!");
return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
}
Expand Down
15 changes: 0 additions & 15 deletions ttg/ttg/devicescratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,6 @@ auto make_scratch(T* val, ttg::scope scope, std::size_t count = 1) {
return devicescratch<T>(val, scope, count);
}

namespace detail {

template<typename T>
struct is_devicescratch : std::false_type
{ };

template<typename T>
struct is_devicescratch<ttg::devicescratch<T>> : std::true_type
{ };

template<typename T>
constexpr bool is_devicescratch_v = is_devicescratch<T>::value;

} // namespace detail

} // namespace ttg

#endif // TTG_DEVICESCRATCH_H
26 changes: 26 additions & 0 deletions ttg/ttg/madness/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,32 @@ namespace ttg_madness {
template <typename T>
inline void ttg_broadcast(ttg::World world, T &data, int source_rank);


/* device definitions, not currently provided by this impl */
template<typename T, typename Allocator>
struct Buffer;

template<typename T>
struct Ptr;

template<typename T>
struct devicescratch;

template<typename T, typename... Args>
Ptr<T> make_ptr(Args&&... args);

template<typename T>
auto get_ptr(T&& obj);

template<typename... Views>
inline bool register_device_memory(std::tuple<Views&...> &views);

template<typename... Buffer>
inline void post_device_out(std::tuple<Buffer&...> &b);

template<typename... Buffer>
inline void mark_device_out(std::tuple<Buffer&...> &b);

} // namespace ttg_madness

#endif // TTG_MADNESS_FWD_H
15 changes: 0 additions & 15 deletions ttg/ttg/parsec/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,21 +403,6 @@ struct Buffer : public detail::ttg_parsec_data_wrapper_t

};

template<typename T>
struct is_buffer : std::false_type
{ };

template<typename T, typename A>
struct is_buffer<Buffer<T, A>> : std::true_type
{ };

template<typename T, typename A>
struct is_buffer<const Buffer<T, A>> : std::true_type
{ };

template<typename T>
constexpr static const bool is_buffer_v = is_buffer<T>::value;

namespace detail {
template<typename T, typename A>
parsec_data_t* get_parsec_data(const ttg_parsec::Buffer<T, A>& db) {
Expand Down
4 changes: 2 additions & 2 deletions ttg/ttg/parsec/devicefunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace ttg_parsec {

auto& view = std::get<I>(views);
bool is_current = false;
static_assert(ttg::is_buffer_v<view_type> || ttg_parsec::is_devicescratch_v<view_type>);
static_assert(ttg::meta::is_buffer_v<view_type> || ttg::meta::is_devicescratch_v<view_type>);
/* get_parsec_data is overloaded for buffer and devicescratch */
parsec_data_t* data = detail::get_parsec_data(view);
/* TODO: check whether the device is current */
Expand All @@ -41,7 +41,7 @@ namespace ttg_parsec {
//if (flows[I].flow_flags != PARSEC_FLOW_ACCESS_RW) {
access = PARSEC_FLOW_ACCESS_READ;
//}
} else if constexpr (ttg_parsec::is_devicescratch_v<view_type>) {
} else if constexpr (ttg::meta::is_devicescratch_v<view_type>) {
if (view.scope() == ttg::scope::Allocate) {
access = PARSEC_FLOW_ACCESS_WRITE;
}
Expand Down
15 changes: 0 additions & 15 deletions ttg/ttg/parsec/devicescratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,6 @@ struct devicescratch {

};

template<typename T>
struct is_devicescratch : std::false_type
{ };

template<typename T>
struct is_devicescratch<devicescratch<T>> : std::true_type
{ };

template<typename T>
struct is_devicescratch<const devicescratch<T>> : std::true_type
{ };

template<typename T>
constexpr static const bool is_devicescratch_v = is_devicescratch<T>::value;

namespace detail {
template<typename T>
parsec_data_t* get_parsec_data(const ttg_parsec::devicescratch<T>& scratch) {
Expand Down
6 changes: 3 additions & 3 deletions ttg/ttg/parsec/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace ttg_parsec {
class TT;

template<typename T>
struct ptr;
struct Ptr;

template<typename T, typename Allocator = std::allocator<T>>
struct Buffer;
Expand Down Expand Up @@ -79,10 +79,10 @@ namespace ttg_parsec {
inline std::pair<bool, std::tuple<ptr<std::decay_t<Args>>...>> get_ptr(Args&&... args);
#endif
template<typename T>
inline ptr<std::decay_t<T>> get_ptr(T&& obj);
inline Ptr<std::decay_t<T>> get_ptr(T&& obj);

template<typename T, typename... Args>
inline ptr<T> make_ptr(Args&&... args);
inline Ptr<T> make_ptr(Args&&... args);


} // namespace ttg_parsec
Expand Down
1 change: 1 addition & 0 deletions ttg/ttg/parsec/import.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define TTG_SELECTED_DEFAULT_IMPL parsec
#define TTG_PARSEC_IMPORTED 1
#define TTG_IMPL_NS ttg_parsec
#define TTG_IMPL_DEVICE_SUPPORT 1

namespace ttg_parsec {}

Expand Down
50 changes: 25 additions & 25 deletions ttg/ttg/parsec/ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ namespace ttg_parsec {
template <typename Value>
inline ttg_data_copy_t *create_new_datacopy(Value &&value);

struct ptr {
struct ptr_impl {
using copy_type = detail::ttg_data_copy_t;

private:
static inline std::unordered_map<ptr*, bool> m_ptr_map;
static inline std::unordered_map<ptr_impl*, bool> m_ptr_map;
static inline std::mutex m_ptr_map_mtx;

copy_type *m_copy = nullptr;
Expand Down Expand Up @@ -51,7 +51,7 @@ namespace ttg_parsec {
}

public:
ptr(copy_type *copy)
ptr_impl(copy_type *copy)
: m_copy(copy)
{
register_self();
Expand All @@ -63,28 +63,28 @@ namespace ttg_parsec {
return m_copy;
}

ptr(const ptr& p)
ptr_impl(const ptr_impl& p)
: m_copy(p.m_copy)
{
register_self();
m_copy->add_ref();
std::cout << "ptr cpy " << m_copy << " ref " << m_copy->num_ref() << std::endl;
}

ptr(ptr&& p)
ptr_impl(ptr_impl&& p)
: m_copy(p.m_copy)
{
register_self();
p.m_copy = nullptr;
std::cout << "ptr mov " << m_copy << " ref " << m_copy->num_ref() << std::endl;
}

~ptr() {
~ptr_impl() {
deregister_self();
drop_copy();
}

ptr& operator=(const ptr& p)
ptr_impl& operator=(const ptr_impl& p)
{
drop_copy();
m_copy = p.m_copy;
Expand All @@ -93,7 +93,7 @@ namespace ttg_parsec {
return *this;
}

ptr& operator=(ptr&& p) {
ptr_impl& operator=(ptr_impl&& p) {
drop_copy();
m_copy = p.m_copy;
p.m_copy = nullptr;
Expand Down Expand Up @@ -128,34 +128,34 @@ namespace ttg_parsec {

// fwd decl
template<typename T, typename... Args>
ptr<T> make_ptr(Args&&... args);
Ptr<T> make_ptr(Args&&... args);

// fwd decl
template<typename T>
ptr<std::decay_t<T>> get_ptr(T&& obj);
Ptr<std::decay_t<T>> get_ptr(T&& obj);

template<typename T>
struct ptr {
struct Ptr {

using value_type = std::decay_t<T>;

private:
using copy_type = detail::ttg_data_value_copy_t<value_type>;

std::unique_ptr<detail::ptr> m_ptr;
std::unique_ptr<detail::ptr_impl> m_ptr;

/* only PaRSEC backend functions are allowed to touch our private parts */
template<typename... Args>
friend ptr<T> make_ptr(Args&&... args);
friend Ptr<T> make_ptr(Args&&... args);
template<typename S>
friend ptr<std::decay_t<S>> get_ptr(S&& obj);
friend Ptr<std::decay_t<S>> get_ptr(S&& obj);
template<typename S>
friend detail::ttg_data_copy_t* detail::get_copy(ptr<S>& p);
friend detail::ttg_data_copy_t* detail::get_copy(Ptr<S>& p);
friend ttg::detail::value_copy_handler<ttg::Runtime::PaRSEC>;

/* only accessible by get_ptr and make_ptr */
ptr(detail::ptr::copy_type *copy)
: m_ptr(new detail::ptr(copy))
Ptr(detail::ptr_impl::copy_type *copy)
: m_ptr(new detail::ptr_impl(copy))
{ }

copy_type* get_copy() const {
Expand All @@ -164,22 +164,22 @@ namespace ttg_parsec {

public:

ptr() = default;
Ptr() = default;

ptr(const ptr& p)
: ptr(p.get_copy())
Ptr(const Ptr& p)
: Ptr(p.get_copy())
{ }

ptr(ptr&& p) = default;
Ptr(Ptr&& p) = default;

~ptr() = default;
~Ptr() = default;

ptr& operator=(const ptr& p) {
m_ptr.reset(new detail::ptr(p.get_copy()));
Ptr& operator=(const Ptr& p) {
m_ptr.reset(new detail::ptr_impl(p.get_copy()));
return *this;
}

ptr& operator=(ptr&& p) = default;
Ptr& operator=(Ptr&& p) = default;

value_type& operator*() const {
return **static_cast<copy_type*>(m_ptr->get_copy());
Expand Down
6 changes: 3 additions & 3 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ namespace ttg_parsec {
if(0 == ttg::default_execution_context().rank())
ttg::default_execution_context().impl().final_task();
ttg::detail::set_default_world(ttg::World{}); // reset the default world
detail::ptr::drop_all_ptr();
detail::ptr_impl::drop_all_ptr();
ttg::detail::destroy_worlds<ttg_parsec::WorldImpl>();
if (detail::initialized_mpi()) MPI_Finalize();
}
Expand Down Expand Up @@ -4004,7 +4004,7 @@ ttg::abort(); // should not happen
template<typename Key, typename Arg, typename... Args, std::size_t I, std::size_t... Is>
void invoke_arglist(std::index_sequence<I, Is...>, const Key& key, Arg&& arg, Args&&... args) {
using arg_type = std::decay_t<Arg>;
if constexpr (ttg::detail::is_ptr_v<arg_type>) {
if constexpr (ttg::meta::is_ptr_v<arg_type>) {
/* add a reference to the object */
auto copy = ttg_parsec::detail::get_copy(arg);
copy->add_ref();
Expand All @@ -4017,7 +4017,7 @@ ttg::abort(); // should not happen
/* if the ptr was moved in we reset it */
arg.reset();
}
} else if constexpr (!ttg::detail::is_ptr_v<arg_type>) {
} else if constexpr (!ttg::meta::is_ptr_v<arg_type>) {
set_arg<I>(key, std::forward<Arg>(arg));
}
if constexpr (sizeof...(Is) > 0) {
Expand Down
16 changes: 1 addition & 15 deletions ttg/ttg/ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace ttg {

template<typename T>
using Ptr = TTG_IMPL_NS::ptr<T>;
using Ptr = TTG_IMPL_NS::Ptr<T>;

template<typename T, typename... Args>
Ptr<T> make_ptr(Args&&... args) {
Expand All @@ -18,20 +18,6 @@ auto get_ptr(T&& obj) {
return TTG_IMPL_NS::get_ptr(std::forward<T>(obj));
}

namespace detail {
template<typename T>
struct is_ptr : std::false_type
{ };

template<typename T>
struct is_ptr<ttg::Ptr<T>> : std::true_type
{ };

template<typename T>
constexpr bool is_ptr_v = is_ptr<T>::value;

} // namespace detail

#if 0
namespace detail {

Expand Down
Loading

0 comments on commit 3bf1e16

Please sign in to comment.