From a5ebc25bc9986abd77f433f313766a7f0bb6ba43 Mon Sep 17 00:00:00 2001 From: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com> Date: Tue, 26 Mar 2024 18:55:46 +0000 Subject: [PATCH 1/2] feat: use dict for extra infos --- model/gym-interface/cpp/ns3-ai-gym-env.h | 2 +- model/gym-interface/cpp/ns3-ai-gym-interface.cc | 13 ++++++++----- model/gym-interface/cpp/ns3-ai-gym-interface.h | 6 +++--- model/gym-interface/messages.proto | 2 +- .../py/ns3ai_gym_env/envs/ns3_environment.py | 6 ++---- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/model/gym-interface/cpp/ns3-ai-gym-env.h b/model/gym-interface/cpp/ns3-ai-gym-env.h index ee8909f..db93a10 100644 --- a/model/gym-interface/cpp/ns3-ai-gym-env.h +++ b/model/gym-interface/cpp/ns3-ai-gym-env.h @@ -71,7 +71,7 @@ class OpenGymEnv : public Object /** * Get extra information */ - virtual std::string GetExtraInfo() = 0; + virtual std::map GetExtraInfo() = 0; /** * Execute actions. E.g., modify the contention window in TCP. diff --git a/model/gym-interface/cpp/ns3-ai-gym-interface.cc b/model/gym-interface/cpp/ns3-ai-gym-interface.cc index 452d2be..cf5e8bb 100644 --- a/model/gym-interface/cpp/ns3-ai-gym-interface.cc +++ b/model/gym-interface/cpp/ns3-ai-gym-interface.cc @@ -147,7 +147,7 @@ OpenGymInterface::NotifyCurrentState() Ptr obsDataContainer = GetObservation(); float reward = GetReward(); bool isGameOver = IsGameOver(); - std::string extraInfo = GetExtraInfo(); + std::map extraInfo = GetExtraInfo(); ns3_ai_gym::EnvStateMsg envStateMsg; // observation ns3_ai_gym::DataContainer obsDataContainerPbMsg; @@ -173,7 +173,10 @@ OpenGymInterface::NotifyCurrentState() } } // extra info - envStateMsg.set_info(extraInfo); + for (auto const &[key, value] : extraInfo) + { + (*envStateMsg.mutable_info())[key] = value; + } // get the interface Ns3AiMsgInterfaceImpl* msgInterface = @@ -298,11 +301,11 @@ OpenGymInterface::IsGameOver() return (gameOver || m_simEnd); } -std::string +std::map OpenGymInterface::GetExtraInfo() { NS_LOG_FUNCTION(this); - std::string info; + std::map info; if (!m_extraInfoCb.IsNull()) { info = m_extraInfoCb(); @@ -353,7 +356,7 @@ OpenGymInterface::SetGetRewardCb(Callback cb) } void -OpenGymInterface::SetGetExtraInfoCb(Callback cb) +OpenGymInterface::SetGetExtraInfoCb(Callback> cb) { m_extraInfoCb = cb; } diff --git a/model/gym-interface/cpp/ns3-ai-gym-interface.h b/model/gym-interface/cpp/ns3-ai-gym-interface.h index be6de8f..d3ef6c5 100644 --- a/model/gym-interface/cpp/ns3-ai-gym-interface.h +++ b/model/gym-interface/cpp/ns3-ai-gym-interface.h @@ -55,7 +55,7 @@ class OpenGymInterface : public Object Ptr GetObservation(); float GetReward(); bool IsGameOver(); - std::string GetExtraInfo(); + std::map GetExtraInfo(); bool ExecuteActions(Ptr action); void SetGetActionSpaceCb(Callback> cb); @@ -63,7 +63,7 @@ class OpenGymInterface : public Object void SetGetObservationCb(Callback> cb); void SetGetRewardCb(Callback cb); void SetGetGameOverCb(Callback cb); - void SetGetExtraInfoCb(Callback cb); + void SetGetExtraInfoCb(Callback> cb); void SetExecuteActionsCb(Callback> cb); void Notify(Ptr entity); @@ -86,7 +86,7 @@ class OpenGymInterface : public Object Callback m_gameOverCb; Callback> m_obsCb; Callback m_rewardCb; - Callback m_extraInfoCb; + Callback> m_extraInfoCb; Callback> m_actionCb; }; diff --git a/model/gym-interface/messages.proto b/model/gym-interface/messages.proto index 9045ec5..0a00c43 100644 --- a/model/gym-interface/messages.proto +++ b/model/gym-interface/messages.proto @@ -114,7 +114,7 @@ message EnvStateMsg { GameOver = 1; } Reason reason = 4; - string info = 5; + map info = 5; } message EnvActMsg { diff --git a/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py b/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py index cbbbe39..775b619 100644 --- a/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py +++ b/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py @@ -163,9 +163,7 @@ def rx_env_state(self): if self.gameOver: self.send_close_command() - self.extraInfo = envStateMsg.info - if not self.extraInfo: - self.extraInfo = {} + self.extraInfo = dict(envStateMsg.info) self.newStateRx = True @@ -268,7 +266,7 @@ def get_state(self): obs = self.get_obs() reward = self.get_reward() done = self.is_game_over() - extraInfo = {"info": self.get_extra_info()} + extraInfo = self.get_extra_info() return obs, reward, done, False, extraInfo def __init__(self, targetName, ns3Path, ns3Settings=None, shmSize=4096): From 83e97cc476a541e2b2377dac747558b84762470b Mon Sep 17 00:00:00 2001 From: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com> Date: Sun, 31 Mar 2024 15:02:35 +0000 Subject: [PATCH 2/2] adapt examples --- examples/a-plus-b/use-gym/apb.cc | 6 +++--- examples/rl-tcp/use-gym/tcp-rl-env.cc | 4 ++-- examples/rl-tcp/use-gym/tcp-rl-env.h | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/a-plus-b/use-gym/apb.cc b/examples/a-plus-b/use-gym/apb.cc index 471d508..04d21a0 100644 --- a/examples/a-plus-b/use-gym/apb.cc +++ b/examples/a-plus-b/use-gym/apb.cc @@ -45,7 +45,7 @@ class ApbEnv : public OpenGymEnv bool GetGameOver() override; Ptr GetObservation() override; float GetReward() override; - std::string GetExtraInfo() override; + std::map GetExtraInfo() override; bool ExecuteActions(Ptr action) override; uint32_t m_a; @@ -125,10 +125,10 @@ ApbEnv::GetReward() return 0.0; } -std::string +std::map ApbEnv::GetExtraInfo() { - return ""; + return {}; } bool diff --git a/examples/rl-tcp/use-gym/tcp-rl-env.cc b/examples/rl-tcp/use-gym/tcp-rl-env.cc index 6eed6c2..b420a1c 100644 --- a/examples/rl-tcp/use-gym/tcp-rl-env.cc +++ b/examples/rl-tcp/use-gym/tcp-rl-env.cc @@ -187,11 +187,11 @@ TcpEnvBase::GetReward() /* Define extra info. Optional */ -std::string +std::map TcpEnvBase::GetExtraInfo() { NS_LOG_INFO("MyGetExtraInfo: " << m_info); - return m_info; + return {{"info", m_info}}; } /* diff --git a/examples/rl-tcp/use-gym/tcp-rl-env.h b/examples/rl-tcp/use-gym/tcp-rl-env.h index f124b08..44c4bfa 100644 --- a/examples/rl-tcp/use-gym/tcp-rl-env.h +++ b/examples/rl-tcp/use-gym/tcp-rl-env.h @@ -52,7 +52,7 @@ class TcpEnvBase : public OpenGymEnv Ptr GetActionSpace() override; bool GetGameOver() override; float GetReward() override; - std::string GetExtraInfo() override; + std::map GetExtraInfo() override; bool ExecuteActions(Ptr action) override; Ptr GetObservationSpace() override = 0;