From 9271a5b98ad24ce6cc2074f8762aa698fe6cd584 Mon Sep 17 00:00:00 2001 From: Roy Garrido Date: Mon, 17 Feb 2025 23:25:46 +0100 Subject: [PATCH] feat: added initial Google provider supporting basic text generation with it --- .env | 3 + examples/chat-gemini-google.php | 28 +++++ src/Bridge/Google/GoogleModel.php | 60 ++++++++++ .../Google/GoogleRequestBodyProducer.php | 97 +++++++++++++++++ src/Bridge/Google/ModelHandler.php | 103 ++++++++++++++++++ src/Bridge/Google/PlatformFactory.php | 23 ++++ src/Model/Message/AssistantMessage.php | 5 + src/Model/Message/Content/Audio.php | 5 + src/Model/Message/Content/Content.php | 1 + src/Model/Message/Content/ContentVisitor.php | 12 ++ src/Model/Message/Content/Image.php | 5 + src/Model/Message/Content/Text.php | 5 + src/Model/Message/MessageInterface.php | 2 + src/Model/Message/MessageVisitor.php | 14 +++ src/Model/Message/SystemMessage.php | 5 + src/Model/Message/ToolCallMessage.php | 5 + src/Model/Message/UserMessage.php | 5 + src/Platform/RequestBodyProducer.php | 8 ++ 18 files changed, 386 insertions(+) create mode 100644 examples/chat-gemini-google.php create mode 100644 src/Bridge/Google/GoogleModel.php create mode 100644 src/Bridge/Google/GoogleRequestBodyProducer.php create mode 100644 src/Bridge/Google/ModelHandler.php create mode 100644 src/Bridge/Google/PlatformFactory.php create mode 100644 src/Model/Message/Content/ContentVisitor.php create mode 100644 src/Model/Message/MessageVisitor.php create mode 100644 src/Platform/RequestBodyProducer.php diff --git a/.env b/.env index c6bf086..e54cafe 100644 --- a/.env +++ b/.env @@ -39,3 +39,6 @@ PINECONE_HOST= # Some examples are expensive to run, so we disable them by default RUN_EXPENSIVE_EXAMPLES=false + +# For using Gemini +GOOGLE_API_KEY= \ No newline at end of file diff --git a/examples/chat-gemini-google.php b/examples/chat-gemini-google.php new file mode 100644 index 0000000..04f6865 --- /dev/null +++ b/examples/chat-gemini-google.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__).'/.env'); + +if (empty($_ENV['GOOGLE_API_KEY'])) { + echo 'Please set the GOOGLE_API_KEY environment variable.'.PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']); +$llm = new GoogleModel(GoogleModel::GEMINI_2_FLASH); + +$chain = new Chain($platform, $llm); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $chain->call($messages); + +echo $response->getContent().PHP_EOL; diff --git a/src/Bridge/Google/GoogleModel.php b/src/Bridge/Google/GoogleModel.php new file mode 100644 index 0000000..5943a05 --- /dev/null +++ b/src/Bridge/Google/GoogleModel.php @@ -0,0 +1,60 @@ + $options The default options for the model usage + */ + public function __construct( + private string $version = self::GEMINI_2_PRO, + private array $options = ['temperature' => 1.0], + ) { + } + + public function getVersion(): string + { + return $this->version; + } + + public function getOptions(): array + { + return $this->options; + } + + public function supportsAudioInput(): bool + { + return false; // it does, but implementation here is still open; in_array($this->version, [self::GEMINI_2_FLASH, self::GEMINI_2_PRO, self::GEMINI_1_5_FLASH], true); + } + + public function supportsImageInput(): bool + { + return false; // it does, but implementation here is still open;in_array($this->version, [self::GEMINI_2_FLASH, self::GEMINI_2_PRO, self::GEMINI_2_FLASH_LITE, self::GEMINI_2_FLASH_THINKING, self::GEMINI_1_5_FLASH], true); + } + + public function supportsStreaming(): bool + { + return true; + } + + public function supportsStructuredOutput(): bool + { + return false; + } + + public function supportsToolCalling(): bool + { + return false; + } +} diff --git a/src/Bridge/Google/GoogleRequestBodyProducer.php b/src/Bridge/Google/GoogleRequestBodyProducer.php new file mode 100644 index 0000000..292b2b3 --- /dev/null +++ b/src/Bridge/Google/GoogleRequestBodyProducer.php @@ -0,0 +1,97 @@ +bag = $bag; + } + + public function createBody(): array + { + $contents = []; + foreach ($this->bag->withoutSystemMessage()->getMessages() as $message) { + $contents[] = [ + 'role' => $message->getRole(), + 'parts' => $message->accept($this), + ]; + } + + $body = [ + 'contents' => $contents, + ]; + + $systemMessage = $this->bag->getSystemMessage(); + if (null !== $systemMessage) { + $body['systemInstruction'] = [ + 'parts' => $systemMessage->accept($this), + ]; + } + + return $body; + } + + public function visitUserMessage(UserMessage $message): array + { + $parts = []; + foreach ($message->content as $content) { + $parts[] = [...$content->accept($this)]; + } + + return $parts; + } + + public function visitAssistantMessage(AssistantMessage $message): array + { + return [['text' => $message->content]]; + } + + public function visitSystemMessage(SystemMessage $message): array + { + return [['text' => $message->content]]; + } + + public function visitText(Text $content): array + { + return ['text' => $content->text]; + } + + public function visitImage(Image $content): array + { + // TODO: support image + return []; + } + + public function visitAudio(Audio $content): array + { + // TODO: support audio + return []; + } + + public function visitToolCallMessage(ToolCallMessage $message): array + { + // TODO: support tool call message + return []; + } + + public function jsonSerialize(): array + { + return $this->createBody(); + } +} diff --git a/src/Bridge/Google/ModelHandler.php b/src/Bridge/Google/ModelHandler.php new file mode 100644 index 0000000..f02adc2 --- /dev/null +++ b/src/Bridge/Google/ModelHandler.php @@ -0,0 +1,103 @@ +httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + } + + public function supports(Model $model, array|string|object $input): bool + { + return $model instanceof GoogleModel && $input instanceof MessageBagInterface; + } + + /** + * @throws TransportExceptionInterface + */ + public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface + { + Assert::isInstanceOf($input, MessageBagInterface::class); + + $body = new GoogleRequestBodyProducer($input); + + return $this->httpClient->request('POST', sprintf('https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent', $model->getVersion()), [ + 'headers' => [ + 'x-goog-api-key' => $this->apiKey, + ], + 'json' => $body, + ]); + } + + /** + * @throws TransportExceptionInterface + * @throws ServerExceptionInterface + * @throws RedirectionExceptionInterface + * @throws DecodingExceptionInterface + * @throws ClientExceptionInterface + */ + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + if ($options['stream'] ?? false) { + return new StreamResponse($this->convertStream($response)); + } + + $data = $response->toArray(); + + if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + throw new RuntimeException('Response does not contain any content'); + } + + return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']); + } + + private function convertStream(ResponseInterface $response): \Generator + { + foreach ((new EventSourceHttpClient())->stream($response) as $chunk) { + if (!$chunk instanceof ServerSentEvent || '[DONE]' === $chunk->getData()) { + continue; + } + + try { + $data = $chunk->getArrayData(); + } catch (JsonException) { + // try catch only needed for Symfony 6.4 + continue; + } + + if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + continue; + } + + yield $data['candidates'][0]['content']['parts'][0]['text']; + } + } +} diff --git a/src/Bridge/Google/PlatformFactory.php b/src/Bridge/Google/PlatformFactory.php new file mode 100644 index 0000000..7e601f8 --- /dev/null +++ b/src/Bridge/Google/PlatformFactory.php @@ -0,0 +1,23 @@ +visitAssistantMessage($this); + } } diff --git a/src/Model/Message/Content/Audio.php b/src/Model/Message/Content/Audio.php index fe04f9a..1927021 100644 --- a/src/Model/Message/Content/Audio.php +++ b/src/Model/Message/Content/Audio.php @@ -32,4 +32,9 @@ public function jsonSerialize(): array ], ]; } + + public function accept(ContentVisitor $visitor): array + { + return $visitor->visitAudio($this); + } } diff --git a/src/Model/Message/Content/Content.php b/src/Model/Message/Content/Content.php index a97cc9d..ee017d7 100644 --- a/src/Model/Message/Content/Content.php +++ b/src/Model/Message/Content/Content.php @@ -6,4 +6,5 @@ interface Content extends \JsonSerializable { + public function accept(ContentVisitor $visitor): array; } diff --git a/src/Model/Message/Content/ContentVisitor.php b/src/Model/Message/Content/ContentVisitor.php new file mode 100644 index 0000000..8c9c9a7 --- /dev/null +++ b/src/Model/Message/Content/ContentVisitor.php @@ -0,0 +1,12 @@ +visitImage($this); + } } diff --git a/src/Model/Message/Content/Text.php b/src/Model/Message/Content/Text.php index 08b0d87..551df00 100644 --- a/src/Model/Message/Content/Text.php +++ b/src/Model/Message/Content/Text.php @@ -18,4 +18,9 @@ public function jsonSerialize(): array { return ['type' => 'text', 'text' => $this->text]; } + + public function accept(ContentVisitor $visitor): array + { + return $visitor->visitText($this); + } } diff --git a/src/Model/Message/MessageInterface.php b/src/Model/Message/MessageInterface.php index efae114..34ad302 100644 --- a/src/Model/Message/MessageInterface.php +++ b/src/Model/Message/MessageInterface.php @@ -7,4 +7,6 @@ interface MessageInterface extends \JsonSerializable { public function getRole(): Role; + + public function accept(MessageVisitor $visitor): array; } diff --git a/src/Model/Message/MessageVisitor.php b/src/Model/Message/MessageVisitor.php new file mode 100644 index 0000000..3a084ca --- /dev/null +++ b/src/Model/Message/MessageVisitor.php @@ -0,0 +1,14 @@ + $this->content, ]; } + + public function accept(MessageVisitor $visitor): array + { + return $visitor->visitSystemMessage($this); + } } diff --git a/src/Model/Message/ToolCallMessage.php b/src/Model/Message/ToolCallMessage.php index 20a9767..2ddef68 100644 --- a/src/Model/Message/ToolCallMessage.php +++ b/src/Model/Message/ToolCallMessage.php @@ -34,4 +34,9 @@ public function jsonSerialize(): array 'tool_call_id' => $this->toolCall->id, ]; } + + public function accept(MessageVisitor $visitor): array + { + return $visitor->visitToolCallMessage($this); + } } diff --git a/src/Model/Message/UserMessage.php b/src/Model/Message/UserMessage.php index ef925c5..8cf9d05 100644 --- a/src/Model/Message/UserMessage.php +++ b/src/Model/Message/UserMessage.php @@ -68,4 +68,9 @@ public function jsonSerialize(): array return $array; } + + public function accept(MessageVisitor $visitor): array + { + return $visitor->visitUserMessage($this); + } } diff --git a/src/Platform/RequestBodyProducer.php b/src/Platform/RequestBodyProducer.php new file mode 100644 index 0000000..ff31b72 --- /dev/null +++ b/src/Platform/RequestBodyProducer.php @@ -0,0 +1,8 @@ +