#69 #70 Penalize bad actions and removed softmax from ceiling policy (#71)

* #70 Removed Softmax from ceiling policy

* #69 Penalized bad action
rafaelBauer authored Aug 30, 2024
1 parent 9b5ff3e commit fb9cca0
Showing 9 changed files with 37 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/config/ceiling_evaluate_config.py
@@ -50,7 +50,7 @@
visual_embedding_dim=256,
proprioceptive_dim=9,
action_dim=4,
from_file="automatic_100_corr0_evalceiling_policy_100.pt",
from_file="automatic_100_corr_0_eval_ceiling_noisy_policy_100.pt",
),
MotionPlannerPolicyConfig(),
]
@@ -61,7 +61,7 @@
policies=policies,
learn_algorithms=learn_algorithms,
environment_config=env_config,
- episodes=100,
+ episodes=200,
task="StackCubesInd",
train=False,
)
9 changes: 5 additions & 4 deletions src/config/ceiling_train_config.py
@@ -8,7 +8,7 @@
# ====== ManiSkill environment ========
from envs.maniskill import ManiSkillEnvironmentConfig

- env_config = ManiSkillEnvironmentConfig(task_config=task_config, headless=True)
+ env_config = ManiSkillEnvironmentConfig(task_config=task_config, headless=False)

# ====== Mock environment ========
# from envs.mock import MockEnvironmentConfig
@@ -23,8 +23,8 @@
# from learnalgorithm.feedbackdevice.keyboardfeedback import KeyboardFeedbackConfig
from learnalgorithm.feedbackdevice.automaticfeedback import AutomaticFeedbackConfig

- noise_distribution = [[88.9, 6.6, 0.2, 4.3], [9.7, 85.3, 0.2, 4.8], [0.5, 1.1, 90.4, 8.0], [1.6, 2.1, 3.8, 92.5]]
- # noise_distribution = []
+ # noise_distribution = [[88.9, 6.6, 0.2, 4.3], [9.7, 85.3, 0.2, 4.8], [0.5, 1.1, 90.4, 8.0], [1.6, 2.1, 3.8, 92.5]]
+ noise_distribution = []

feedback_device_config = AutomaticFeedbackConfig(
action_dim=4, task_config=task_config, corrective_probability=100, noise_distribution=noise_distribution
@@ -50,7 +50,7 @@
# at the next index.
controllers = [
PeriodicControllerConfig(
ACTION_TYPE="PickPlaceObject", polling_period_s=5, initial_goal=[0, 0, 0, 1], log_metrics=True
ACTION_TYPE="PickPlaceObject", polling_period_s=5, initial_goal=[0, 0, 0, 1], log_metrics=True, max_steps=100
),
PeriodicControllerConfig(
ACTION_TYPE="TargetJointPositionAction", polling_period_s=0.05, initial_goal=[0, 0, 0, 0, 0, 0, 0]
Expand All @@ -75,6 +75,7 @@
+ learn_algorithms[0].name
+ "_"
+ has_noise
+ "_penalized"
+ "_policy.pt", # Number of episodes will be appended to the name before the extension
),
MotionPlannerPolicyConfig(),
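Note on the (now disabled) noise_distribution above: each row belongs to one ground-truth action and gives, in percent, how often the automatic feedback reports each of the four possible actions. A minimal sketch of how such a row-stochastic matrix could be sampled, with a hypothetical noisy_label helper that is not part of the repository:

import numpy as np

# Hypothetical helper: draw a (possibly wrong) feedback label for a true action,
# assuming each row of noise_distribution is a percentage distribution over labels.
noise_distribution = [
    [88.9, 6.6, 0.2, 4.3],
    [9.7, 85.3, 0.2, 4.8],
    [0.5, 1.1, 90.4, 8.0],
    [1.6, 2.1, 3.8, 92.5],
]

def noisy_label(true_action: int, rng: np.random.Generator) -> int:
    probs = np.asarray(noise_distribution[true_action]) / 100.0
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
print(noisy_label(2, rng))  # usually 2, occasionally 3 (8% of the time)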
4 changes: 2 additions & 2 deletions src/controller/controller.py
@@ -323,8 +323,8 @@ def set_post_step_function(self, post_step_function):
"""
self.__post_step_function = post_step_function

- def publish_model(self):
-     self._policy.publish_model()
+ def publish_model(self, trained_model: bool = False):
+     self._policy.publish_model(trained_model)

def publish_dataset(self):
if self._learn_algorithm is not None:
2 changes: 1 addition & 1 deletion src/policy/ceilingpolicy.py
@@ -43,7 +43,7 @@ def __init__(self, config: CeilingPolicyConfig, **kwargs):
self.__lstm = nn.LSTM(LSTM_DIM, LSTM_DIM)
self.__lstm_state: tuple[torch.Tensor, torch.Tensor] | None = None

- self.__action_net = nn.Sequential(nn.Linear(LSTM_DIM, config.action_dim), nn.Tanh(), nn.Softmax(dim=-1))
+ self.__action_net = nn.Sequential(nn.Linear(LSTM_DIM, config.action_dim), nn.Tanh())
self.__action_net = self.__action_net.to(device)

self._CONFIG: CeilingPolicyConfig = config
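Why dropping the Softmax (#70) changes the head's behaviour: with Softmax stacked on Tanh the four action outputs were forced onto a probability simplex (non-negative, summing to 1), whereas the plain Tanh head gives each dimension its own value in (-1, 1). A small sketch of the difference, with an illustrative LSTM_DIM (the real value comes from the policy config):

import torch
import torch.nn as nn

LSTM_DIM, ACTION_DIM = 64, 4  # illustrative sizes only
features = torch.randn(1, LSTM_DIM)

old_head = nn.Sequential(nn.Linear(LSTM_DIM, ACTION_DIM), nn.Tanh(), nn.Softmax(dim=-1))
new_head = nn.Sequential(nn.Linear(LSTM_DIM, ACTION_DIM), nn.Tanh())

print(old_head(features))  # non-negative, rows sum to 1: a distribution over actions
print(new_head(features))  # each entry in (-1, 1), dimensions move independently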
4 changes: 2 additions & 2 deletions src/policy/manualpolicy.py
@@ -49,9 +49,9 @@ def __init__(self, config: ManualPolicyConfig, keyboard_observer: KeyboardObserv
@override
def forward(self, states) -> Tensor:
# For when the keyboard observer is not working
- # action = numpy.array([0.0, 0.0, 0.0, -0.9, 0.0, 0.9, 0.0])
+ action = Tensor([0.0, 0.0, 0.0, -0.9, 0.0, 0.9, 0.0])
assert isinstance(states, SceneObservation), "states should be of type SceneObservation"
- action = self._feedback_device.check_corrective_feedback(states).numpy()
+ action = self._feedback_device.check_corrective_feedback(action, states).to("cpu").numpy()
return self.specific_forward(action, states)

@override
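The feedback device now receives the currently proposed action along with the observation, and its answer may live on the GPU, hence the added .to("cpu") before converting to numpy. A hypothetical stub (not the real AutomaticFeedback class) showing the new call shape:

from torch import Tensor

class DummyFeedbackDevice:
    """Hypothetical stand-in illustrating the new signature."""

    def check_corrective_feedback(self, action: Tensor, states) -> Tensor:
        # A real device could overwrite the proposed action here; we pass it through.
        return action

device = DummyFeedbackDevice()
proposed = Tensor([0.0, 0.0, 0.0, -0.9, 0.0, 0.9, 0.0])
action = device.check_corrective_feedback(proposed, states=None).to("cpu").numpy()
print(action)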
7 changes: 5 additions & 2 deletions src/policy/policy.py
@@ -80,14 +80,17 @@ def save_to_file(self):
logger.info(f"Saving policy to file: {self._CONFIG.save_to_file}")
torch.save(self.state_dict(), self._CONFIG.save_to_file)

- def publish_model(self):
+ def publish_model(self, trained_model: bool = False):
if self._CONFIG.save_to_file:
self.save_to_file()
logger.info(f"Publishing policy to wandb: {self._CONFIG.save_to_file}")
file_name_and_extension = os.path.basename(self._CONFIG.save_to_file)
artifact = wandb.Artifact(f"{os.path.splitext(file_name_and_extension)[0]}", type="model")
artifact.add_file(self._CONFIG.save_to_file)
- wandb.run.log_artifact(artifact)
+ published_artifact = wandb.run.log_artifact(artifact)
+ if trained_model:
+     wandb.run.summary["model_version"] = published_artifact.source_name
+     wandb.run.summary["model_name"] = published_artifact.name.split(":")[0]

@abstractmethod
def forward(self, states) -> Tensor:
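The same publish flow as a standalone sketch, using only the wandb calls visible above; the project name and checkpoint path are placeholders:

import os
import wandb

def publish_checkpoint(path: str, trained_model: bool = False) -> None:
    run = wandb.init(project="ceiling", job_type="upload")  # placeholder project name
    name = os.path.splitext(os.path.basename(path))[0]
    artifact = wandb.Artifact(name, type="model")
    artifact.add_file(path)
    published = run.log_artifact(artifact)
    if trained_model:
        # Record which model this run produced so evaluation runs can find it later.
        run.summary["model_version"] = published.source_name
        run.summary["model_name"] = published.name.split(":")[0]
    run.finish()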
4 changes: 3 additions & 1 deletion src/pre_train.py
@@ -58,6 +58,8 @@ def create_config_from_args() -> Config:
config: Config = OmegaConf.to_container(dict_config, resolve=True, structured_config_mode=SCMode.INSTANTIATE)
if args.task:
config.task = args.task
+ if args.model_file:
+     config.policies[0].from_file = args.model_file
return config


@@ -190,7 +192,7 @@ def post_step(controller_step: ControllerStep):
keyboard_obs.stop()

if config.train:
- controllers[0].publish_model()
+ controllers[0].publish_model(config.train)
logger.info("Successfully trained policy for task {}", config.task)

if config.episodes > 0 and learn_algorithms[0] is not None:
8 changes: 8 additions & 0 deletions src/utils/argparse.py
@@ -49,6 +49,14 @@ def parse_args(
help="Path to a dataset. May be provided instead of -t.",
)

+ parser.add_argument(
+     "-m",
+     "--model_file",
+     type=str,
+     default="",
+     help="Sets model file name",
+ )

for arg in extra_args:
arg_flag = arg.pop("flag")
arg_name = arg.pop("name")
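Together with the pre_train.py change above, this flag lets a run point the first policy at a specific checkpoint. A minimal, self-contained sketch of that flow; the config classes here are hypothetical stand-ins for the real OmegaConf-backed config:

import argparse
from dataclasses import dataclass, field

@dataclass
class PolicyConfigStub:
    from_file: str = ""

@dataclass
class ConfigStub:
    policies: list = field(default_factory=lambda: [PolicyConfigStub()])

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model_file", type=str, default="", help="Sets model file name")
args = parser.parse_args(["-m", "some_pretrained_policy.pt"])  # placeholder checkpoint name

config = ConfigStub()
if args.model_file:
    config.policies[0].from_file = args.model_file
print(config.policies[0].from_file)  # some_pretrained_policy.pt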
9 changes: 9 additions & 0 deletions src/utils/dataset.py
@@ -125,6 +125,8 @@ def __getitem__(self, idx):
for trajectory_step in trajectory:
if trajectory_step.feedback == HumanFeedback.CORRECTED:
trajectory_step.feedback = alpha
+ elif trajectory_step.feedback == HumanFeedback.BAD:
+     trajectory_step.feedback = 1

return trajectory

@@ -149,6 +151,7 @@ def add(self, step: TrajectoryData):
assert isinstance(step.feedback, Tensor), f"Expected feedback to be a tensor, got {type(step.feedback)}"

self.__feedback_counter[HumanFeedback(step.feedback[0].item())] += 1
+ self.adapt_action(self.__current_trajectory[-1])
return

def save_current_traj(self):
@@ -203,8 +206,14 @@ def modify_feedback_from_current_step(self, feedback: HumanFeedback):
self.__feedback_counter[HumanFeedback(self.__current_trajectory[-1].feedback[0].item())] -= 1
self.__current_trajectory[-1].feedback = torch.Tensor([feedback])
self.__feedback_counter[feedback] += 1
+ self.adapt_action(self.__current_trajectory[-1])
return

+ def adapt_action(self, step: TrajectoryData):
+     if step.feedback == HumanFeedback.BAD:
+         # One-cold encode the bad action
+         step.action = 1 - step.action

def __down_sample_current_trajectory(self):
"""
Down-samples the current trajectory to match the trajectory size.
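The penalty itself (#69): a step marked HumanFeedback.BAD now has its feedback value set to 1 in __getitem__ and its stored action flipped to the one-cold complement, so training targets everything except the action that was judged bad. A small worked example, assuming a one-hot discrete action as in the 4-dimensional PickPlaceObject case:

import torch

action = torch.tensor([0.0, 1.0, 0.0, 0.0])  # the action the policy actually took
feedback_is_bad = True                        # automatic feedback marked it as BAD

if feedback_is_bad:
    # One-cold encoding: every entry except the bad action becomes the target.
    action = 1 - action

print(action)  # tensor([1., 0., 1., 1.])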
