diff --git a/automation-api/yival_experiments/custom_configuration/vertex_ai_evaluator.py b/automation-api/yival_experiments/custom_configuration/vertex_ai_evaluator.py index ec24717..0b92cef 100644 --- a/automation-api/yival_experiments/custom_configuration/vertex_ai_evaluator.py +++ b/automation-api/yival_experiments/custom_configuration/vertex_ai_evaluator.py @@ -36,9 +36,6 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# because all claude models are avaliable in us-east5 -VERTEX_LOCATION = "us-east5" - class VertexAIEvaluator(BaseEvaluator): """Evaluator using VertexAI's prompt-based evaluation.""" @@ -48,6 +45,10 @@ class VertexAIEvaluator(BaseEvaluator): def __init__(self, config: VertexAIEvaluatorConfig): super().__init__(config) self.config = config + if "claude" in self.config.model_name: + self.vertex_location = "us-east5" + else: + self.vertex_location = "us-central1" def evaluate(self, experiment_result: ExperimentResult) -> EvaluatorOutput: """Evaluate the experiment result using Vertex AI's prompt-based evaluation.""" @@ -70,7 +71,7 @@ def evaluate(self, experiment_result: ExperimentResult) -> EvaluatorOutput: max_tokens=2000, request_timeout=60, caching=True, - vertex_ai_location=VERTEX_LOCATION, + vertex_ai_location=self.vertex_location, vertex_ai_project=os.environ["VERTEXAI_PROJECT"], ) # response = openai.ChatCompletion.create( @@ -114,6 +115,7 @@ def main(): MethodCalculationMethod(MethodCalculationMethod.AVERAGE) ) ], + model_name="gemini-pro-experimental", prompt=prompt, choices=choices, evaluator_type=EvaluatorType.INDIVIDUAL,