#625 extract _preference_feedback_schedule()
Jan Michelfeit committed Dec 1, 2022
1 parent 567e980 commit c2bc9dc
Showing 1 changed file with 13 additions and 10 deletions.
src/imitation/algorithms/preference_comparisons.py (23 changes: 13 additions & 10 deletions)
@@ -1668,16 +1668,9 @@ def train(
             A dictionary with final metrics such as loss and accuracy
             of the reward model.
         """
-        initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
-        total_comparisons -= initial_comparisons
 
-        # Compute the number of comparisons to request at each iteration in advance.
-        vec_schedule = np.vectorize(self.query_schedule)
-        unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
-        probs = unnormalized_probs / np.sum(unnormalized_probs)
-        shares = util.oric(probs * total_comparisons)
-        schedule = [initial_comparisons] + shares.tolist()
-        print(f"Query schedule: {schedule}")
+        preference_query_schedule = self._preference_gather_schedule(total_comparisons)
+        print(f"Query schedule: {preference_query_schedule}")
 
         timesteps_per_iteration, extra_timesteps = divmod(
             total_timesteps,
@@ -1686,7 +1679,7 @@ def train(
         reward_loss = None
         reward_accuracy = None
 
-        for i, num_pairs in enumerate(schedule):
+        for i, num_pairs in enumerate(preference_query_schedule):
             ##########################
             # Gather new preferences #
             ##########################
@@ -1749,3 +1742,13 @@ def train(
             self._iteration += 1
 
         return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy}
+
+    def _preference_gather_schedule(self, total_comparisons):
+        initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
+        total_comparisons -= initial_comparisons
+        vec_schedule = np.vectorize(self.query_schedule)
+        unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
+        probs = unnormalized_probs / np.sum(unnormalized_probs)
+        shares = util.oric(probs * total_comparisons)
+        schedule = [initial_comparisons] + shares.tolist()
+        return schedule
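For reference, here is a minimal self-contained sketch of the computation the extracted helper performs. The example parameters (100 total comparisons, a 0.1 initial fraction, 5 iterations, a linearly decaying query schedule) are illustrative assumptions, and a simple largest-remainder rounding stands in for util.oric:

import numpy as np

def preference_gather_schedule(total_comparisons, initial_comparison_frac,
                               num_iterations, query_schedule):
    # Reserve the comparisons gathered up front, before the first training iteration.
    initial_comparisons = int(total_comparisons * initial_comparison_frac)
    total_comparisons -= initial_comparisons
    # Evaluate the schedule function on [0, 1] and normalize it into probabilities.
    unnormalized_probs = np.vectorize(query_schedule)(np.linspace(0, 1, num_iterations))
    probs = unnormalized_probs / np.sum(unnormalized_probs)
    # Round the fractional shares to integers while preserving their total
    # (largest-remainder rounding as a stand-in for util.oric).
    targets = probs * total_comparisons
    shares = np.floor(targets).astype(int)
    remainder = total_comparisons - int(shares.sum())
    for i in np.argsort(targets - shares)[::-1][:remainder]:
        shares[i] += 1
    return [initial_comparisons] + shares.tolist()

# Example: 100 comparisons, 10% gathered up front, 5 iterations, linear decay.
print(preference_gather_schedule(100, 0.1, 5, lambda t: 1.0 - t))
# [10, 36, 27, 18, 9, 0]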
