From c2bc9dcc48136d6bb482efb30b3d1545a284d477 Mon Sep 17 00:00:00 2001
From: Jan Michelfeit
Date: Thu, 1 Dec 2022 14:41:06 +0100
Subject: [PATCH] #625 extract _preference_feedback_schedule()

---
 .../algorithms/preference_comparisons.py     | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py
index 165c9fc00..c5150db75 100644
--- a/src/imitation/algorithms/preference_comparisons.py
+++ b/src/imitation/algorithms/preference_comparisons.py
@@ -1668,16 +1668,9 @@ def train(
             A dictionary with final metrics such as loss and accuracy
             of the reward model.
         """
-        initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
-        total_comparisons -= initial_comparisons
 
-        # Compute the number of comparisons to request at each iteration in advance.
-        vec_schedule = np.vectorize(self.query_schedule)
-        unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
-        probs = unnormalized_probs / np.sum(unnormalized_probs)
-        shares = util.oric(probs * total_comparisons)
-        schedule = [initial_comparisons] + shares.tolist()
-        print(f"Query schedule: {schedule}")
+        preference_query_schedule = self._preference_gather_schedule(total_comparisons)
+        print(f"Query schedule: {preference_query_schedule}")
 
         timesteps_per_iteration, extra_timesteps = divmod(
             total_timesteps,
@@ -1686,7 +1679,7 @@ def train(
         reward_loss = None
         reward_accuracy = None
 
-        for i, num_pairs in enumerate(schedule):
+        for i, num_pairs in enumerate(preference_query_schedule):
             ##########################
             # Gather new preferences #
             ##########################
@@ -1749,3 +1742,13 @@ def train(
             self._iteration += 1
 
         return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy}
+
+    def _preference_gather_schedule(self, total_comparisons):
+        initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
+        total_comparisons -= initial_comparisons
+        vec_schedule = np.vectorize(self.query_schedule)
+        unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
+        probs = unnormalized_probs / np.sum(unnormalized_probs)
+        shares = util.oric(probs * total_comparisons)
+        schedule = [initial_comparisons] + shares.tolist()
+        return schedule
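
Note (not part of the patch): the scheduling logic moved into the extracted helper can be illustrated with a minimal standalone sketch. The hyperbolic query_schedule lambda and the np.round-based rounding below are stand-ins chosen for illustration only; the patched code evaluates self.query_schedule and rounds with util.oric so that the integer shares sum exactly to the remaining comparison budget.

import numpy as np


def preference_gather_schedule(
    total_comparisons: int,
    num_iterations: int,
    initial_comparison_frac: float,
    query_schedule,
) -> list:
    """Split a comparison budget over training iterations.

    Mirrors the logic of the extracted _preference_gather_schedule: a fixed
    fraction is requested up front, and the remainder is distributed
    according to query_schedule evaluated over training progress in [0, 1].
    """
    initial_comparisons = int(total_comparisons * initial_comparison_frac)
    total_comparisons -= initial_comparisons
    # Evaluate the (unnormalized) schedule at evenly spaced progress points.
    vec_schedule = np.vectorize(query_schedule)
    unnormalized_probs = vec_schedule(np.linspace(0, 1, num_iterations))
    probs = unnormalized_probs / np.sum(unnormalized_probs)
    # The patched code uses util.oric() so the integer shares sum exactly to
    # total_comparisons; plain rounding is only an approximation here.
    shares = np.round(probs * total_comparisons).astype(int)
    return [initial_comparisons] + shares.tolist()


# Example: a hyperbolic schedule front-loads queries early in training.
print(preference_gather_schedule(
    total_comparisons=100,
    num_iterations=5,
    initial_comparison_frac=0.1,
    query_schedule=lambda t: 1.0 / (1.0 + t),
))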