Skip to content

Commit

Permalink
Using new multisample option for decoding instead of workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
ahottung committed Jun 4, 2024
1 parent e039432 commit b081e8f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rl4co/models/zoo/polynet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def __init__(
self.augment = None

# Add `_multistart` to decode type for train, val and test in policy
for phase in ["train", "val", "test"]:
self.set_decode_type_multistart(phase)
# for phase in ["train", "val", "test"]:
# self.set_decode_type_multistart(phase)

def shared_step(
self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
Expand All @@ -144,8 +144,8 @@ def shared_step(
self.env,
phase=phase,
num_starts=n_start,
multisample=True,
return_actions=True,
select_start_nodes_fn=(lambda *args: None),
)

# Unbatchify reward to [batch_size, num_augment, num_starts].
Expand Down

0 comments on commit b081e8f

Please sign in to comment.