Skip to content

Commit

Permalink
Allow yearJob to be non-copyable (#1757)
Browse files Browse the repository at this point in the history
* Shared function object

Signed-off-by: Sylvain Leclerc <[email protected]>

* Extract in another header

Signed-off-by: Sylvain Leclerc <[email protected]>

* Move solution to concurrency lib

Signed-off-by: Sylvain Leclerc <[email protected]>

* Some test

Signed-off-by: Sylvain Leclerc <[email protected]>

* Unnecessary change

Signed-off-by: Sylvain Leclerc <[email protected]>

* Rename implementation class

Signed-off-by: Sylvain Leclerc <[email protected]>

---------

Signed-off-by: Sylvain Leclerc <[email protected]>
  • Loading branch information
sylvlecl authored Nov 10, 2023
1 parent 10278e9 commit 1e137f2
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ using TaskFuture = std::future<void>;
const Task& task,
Yuni::Job::Priority priority = Yuni::Job::priorityDefault);

/*!
* \brief Queues the provided function objects and returns the corresponding std::future.
*
* T must define operator ().
*
* This allows to handle exceptions occuring in the underlying task,
* as opposite to Yuni::Job::QueueService::add which swallows them.
*/
template <class T>
[[nodiscard]] TaskFuture AddTask(Yuni::Job::QueueService& threadPool,
const std::shared_ptr<T>& task,
Yuni::Job::Priority priority = Yuni::Job::priorityDefault);

/*!
* \brief Utility class to gather futures to wait for.
*/
Expand Down Expand Up @@ -82,6 +95,44 @@ class FutureSet
std::vector<TaskFuture> futures_;
};


namespace Detail { //implementation details

/*!
* Utility class to wrap a callable object pointer
* into a copyable callable object.
*
* @tparam T the underlying callable type
*/
template<class T>
class CopyableCallable
{
public:
explicit CopyableCallable(const std::shared_ptr<T>& functionObject) :
functionObject_(functionObject)
{
}

void operator()()
{
(*functionObject_)();
}

private:
std::shared_ptr<T> functionObject_;
};

}

template <class T>
TaskFuture AddTask(Yuni::Job::QueueService& threadPool,
const std::shared_ptr<T>& task,
Yuni::Job::Priority priority)
{
Task wrappedTask = Detail::CopyableCallable<T>(task);
return AddTask(threadPool, wrappedTask, priority);
}

}


Expand Down
7 changes: 6 additions & 1 deletion src/solver/simulation/solver.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ public:
hydroHotStart = (study.parameters.initialReservoirLevels.iniLevels == Data::irlHotStart);
}

yearJob(const yearJob&) = delete;
yearJob& operator =(const yearJob&) = delete;
~yearJob() = default;

private:
ISimulation<Impl>* simulation_;
unsigned int y;
Expand Down Expand Up @@ -233,6 +237,7 @@ public:
} // End of onExecute() method
};


template<class Impl>
inline ISimulation<Impl>::ISimulation(Data::Study& study,
const ::Settings& settings,
Expand Down Expand Up @@ -983,7 +988,7 @@ void ISimulation<Impl>::loopThroughYears(uint firstYear,
// have to be rerun (meaning : they must be run once). if(!set_it->yearFailed[y])
// continue;

Concurrency::Task task = yearJob<ImplementationType>(this,
auto task = std::make_shared<yearJob<ImplementationType>>(this,
y,
set_it->yearFailed,
set_it->isFirstPerformedYearOfASet,
Expand Down
24 changes: 24 additions & 0 deletions src/tests/src/libs/antares/concurrency/test_concurrency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,27 @@ BOOST_AUTO_TEST_CASE(test_future_set_rethrows_first_submitted)
futures.add(AddTask(*threadPool, failingTask<TestExceptionN<2>>()));
BOOST_CHECK_THROW(futures.join(), TestExceptionN<1>);
}

struct NonCopyableFunctionObject
{
NonCopyableFunctionObject() = default;
NonCopyableFunctionObject(const NonCopyableFunctionObject&) = delete;
NonCopyableFunctionObject& operator=(const NonCopyableFunctionObject&) = delete;

bool called = false;

void operator()()
{
called = true;
}
};

BOOST_AUTO_TEST_CASE(allow_to_use_function_object_pointer)
{
auto threadPool = createThreadPool(1);
auto functionObjectPtr = std::make_shared<NonCopyableFunctionObject>();
BOOST_CHECK(!functionObjectPtr->called);
TaskFuture future = AddTask(*threadPool, functionObjectPtr);
future.get();
BOOST_CHECK(functionObjectPtr->called);
}

0 comments on commit 1e137f2

Please sign in to comment.