Skip to content

Commit

Permalink
feat: added initial Google provider supporting basic text generation …
Browse files Browse the repository at this point in the history
…with it
  • Loading branch information
onemoreangle authored and chr-hertel committed Feb 21, 2025
1 parent d7ecb4d commit 9271a5b
Show file tree
Hide file tree
Showing 18 changed files with 386 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ PINECONE_HOST=

# Some examples are expensive to run, so we disable them by default
RUN_EXPENSIVE_EXAMPLES=false

# For using Gemini
GOOGLE_API_KEY=
28 changes: 28 additions & 0 deletions examples/chat-gemini-google.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?php

use PhpLlm\LlmChain\Bridge\Google\GoogleModel;
use PhpLlm\LlmChain\Bridge\Google\PlatformFactory;
use PhpLlm\LlmChain\Chain;
use PhpLlm\LlmChain\Model\Message\Message;
use PhpLlm\LlmChain\Model\Message\MessageBag;
use Symfony\Component\Dotenv\Dotenv;

require_once dirname(__DIR__).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');

if (empty($_ENV['GOOGLE_API_KEY'])) {
echo 'Please set the GOOGLE_API_KEY environment variable.'.PHP_EOL;
exit(1);
}

$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']);
$llm = new GoogleModel(GoogleModel::GEMINI_2_FLASH);

$chain = new Chain($platform, $llm);
$messages = new MessageBag(
Message::forSystem('You are a pirate and you write funny.'),
Message::ofUser('What is the Symfony framework?'),
);
$response = $chain->call($messages);

echo $response->getContent().PHP_EOL;
60 changes: 60 additions & 0 deletions src/Bridge/Google/GoogleModel.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Bridge\Google;

use PhpLlm\LlmChain\Model\LanguageModel;

final readonly class GoogleModel implements LanguageModel
{
public const GEMINI_2_FLASH = 'gemini-2.0-flash';
public const GEMINI_2_PRO = 'gemini-2.0-pro-exp-02-05';
public const GEMINI_2_FLASH_LITE = 'gemini-2.0-flash-lite-preview-02-05';
public const GEMINI_2_FLASH_THINKING = 'gemini-2.0-flash-thinking-exp-01-21';
public const GEMINI_1_5_FLASH = 'gemini-1.5-flash';

/**
* @param array<string, mixed> $options The default options for the model usage
*/
public function __construct(
private string $version = self::GEMINI_2_PRO,
private array $options = ['temperature' => 1.0],
) {
}

public function getVersion(): string
{
return $this->version;
}

public function getOptions(): array
{
return $this->options;
}

public function supportsAudioInput(): bool
{
return false; // it does, but implementation here is still open; in_array($this->version, [self::GEMINI_2_FLASH, self::GEMINI_2_PRO, self::GEMINI_1_5_FLASH], true);
}

public function supportsImageInput(): bool
{
return false; // it does, but implementation here is still open;in_array($this->version, [self::GEMINI_2_FLASH, self::GEMINI_2_PRO, self::GEMINI_2_FLASH_LITE, self::GEMINI_2_FLASH_THINKING, self::GEMINI_1_5_FLASH], true);
}

public function supportsStreaming(): bool
{
return true;
}

public function supportsStructuredOutput(): bool
{
return false;
}

public function supportsToolCalling(): bool
{
return false;
}
}
97 changes: 97 additions & 0 deletions src/Bridge/Google/GoogleRequestBodyProducer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
<?php

namespace PhpLlm\LlmChain\Bridge\Google;

use PhpLlm\LlmChain\Model\Message\AssistantMessage;
use PhpLlm\LlmChain\Model\Message\Content\Audio;
use PhpLlm\LlmChain\Model\Message\Content\ContentVisitor;
use PhpLlm\LlmChain\Model\Message\Content\Image;
use PhpLlm\LlmChain\Model\Message\Content\Text;
use PhpLlm\LlmChain\Model\Message\MessageBagInterface;
use PhpLlm\LlmChain\Model\Message\MessageVisitor;
use PhpLlm\LlmChain\Model\Message\SystemMessage;
use PhpLlm\LlmChain\Model\Message\ToolCallMessage;
use PhpLlm\LlmChain\Model\Message\UserMessage;
use PhpLlm\LlmChain\Platform\RequestBodyProducer;

final class GoogleRequestBodyProducer implements RequestBodyProducer, MessageVisitor, ContentVisitor, \JsonSerializable
{
protected MessageBagInterface $bag;

public function __construct(MessageBagInterface $bag)
{
$this->bag = $bag;
}

public function createBody(): array
{
$contents = [];
foreach ($this->bag->withoutSystemMessage()->getMessages() as $message) {
$contents[] = [
'role' => $message->getRole(),
'parts' => $message->accept($this),
];
}

$body = [
'contents' => $contents,
];

$systemMessage = $this->bag->getSystemMessage();
if (null !== $systemMessage) {
$body['systemInstruction'] = [
'parts' => $systemMessage->accept($this),
];
}

return $body;
}

public function visitUserMessage(UserMessage $message): array
{
$parts = [];
foreach ($message->content as $content) {
$parts[] = [...$content->accept($this)];
}

return $parts;
}

public function visitAssistantMessage(AssistantMessage $message): array
{
return [['text' => $message->content]];
}

public function visitSystemMessage(SystemMessage $message): array
{
return [['text' => $message->content]];
}

public function visitText(Text $content): array
{
return ['text' => $content->text];
}

public function visitImage(Image $content): array
{
// TODO: support image
return [];
}

public function visitAudio(Audio $content): array
{
// TODO: support audio
return [];
}

public function visitToolCallMessage(ToolCallMessage $message): array
{
// TODO: support tool call message
return [];
}

public function jsonSerialize(): array
{
return $this->createBody();
}
}
103 changes: 103 additions & 0 deletions src/Bridge/Google/ModelHandler.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Bridge\Google;

use PhpLlm\LlmChain\Exception\RuntimeException;
use PhpLlm\LlmChain\Model\Message\MessageBagInterface;
use PhpLlm\LlmChain\Model\Model;
use PhpLlm\LlmChain\Model\Response\ResponseInterface as LlmResponse;
use PhpLlm\LlmChain\Model\Response\StreamResponse;
use PhpLlm\LlmChain\Model\Response\TextResponse;
use PhpLlm\LlmChain\Platform\ModelClient;
use PhpLlm\LlmChain\Platform\ResponseConverter;
use Symfony\Component\HttpClient\Chunk\ServerSentEvent;
use Symfony\Component\HttpClient\EventSourceHttpClient;
use Symfony\Component\HttpClient\Exception\JsonException;
use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface;
use Symfony\Contracts\HttpClient\Exception\DecodingExceptionInterface;
use Symfony\Contracts\HttpClient\Exception\RedirectionExceptionInterface;
use Symfony\Contracts\HttpClient\Exception\ServerExceptionInterface;
use Symfony\Contracts\HttpClient\Exception\TransportExceptionInterface;
use Symfony\Contracts\HttpClient\HttpClientInterface;
use Symfony\Contracts\HttpClient\ResponseInterface;
use Webmozart\Assert\Assert;

final readonly class ModelHandler implements ModelClient, ResponseConverter
{
private EventSourceHttpClient $httpClient;

public function __construct(
HttpClientInterface $httpClient,
#[\SensitiveParameter] private string $apiKey,
) {
$this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
}

public function supports(Model $model, array|string|object $input): bool
{
return $model instanceof GoogleModel && $input instanceof MessageBagInterface;
}

/**
* @throws TransportExceptionInterface
*/
public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface
{
Assert::isInstanceOf($input, MessageBagInterface::class);

$body = new GoogleRequestBodyProducer($input);

return $this->httpClient->request('POST', sprintf('https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent', $model->getVersion()), [
'headers' => [
'x-goog-api-key' => $this->apiKey,
],
'json' => $body,
]);
}

/**
* @throws TransportExceptionInterface
* @throws ServerExceptionInterface
* @throws RedirectionExceptionInterface
* @throws DecodingExceptionInterface
* @throws ClientExceptionInterface
*/
public function convert(ResponseInterface $response, array $options = []): LlmResponse
{
if ($options['stream'] ?? false) {
return new StreamResponse($this->convertStream($response));
}

$data = $response->toArray();

if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
throw new RuntimeException('Response does not contain any content');
}

return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']);
}

private function convertStream(ResponseInterface $response): \Generator
{
foreach ((new EventSourceHttpClient())->stream($response) as $chunk) {
if (!$chunk instanceof ServerSentEvent || '[DONE]' === $chunk->getData()) {
continue;
}

try {
$data = $chunk->getArrayData();
} catch (JsonException) {
// try catch only needed for Symfony 6.4
continue;
}

if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
continue;
}

yield $data['candidates'][0]['content']['parts'][0]['text'];
}
}
}
23 changes: 23 additions & 0 deletions src/Bridge/Google/PlatformFactory.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Bridge\Google;

use PhpLlm\LlmChain\Platform;
use Symfony\Component\HttpClient\EventSourceHttpClient;
use Symfony\Contracts\HttpClient\HttpClientInterface;

final readonly class PlatformFactory
{
public static function create(
#[\SensitiveParameter]
string $apiKey,
?HttpClientInterface $httpClient = null,
): Platform {
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
$responseHandler = new ModelHandler($httpClient, $apiKey);

return new Platform([$responseHandler], [$responseHandler]);
}
}
5 changes: 5 additions & 0 deletions src/Model/Message/AssistantMessage.php
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ public function jsonSerialize(): array

return $array;
}

public function accept(MessageVisitor $visitor): array
{
return $visitor->visitAssistantMessage($this);
}
}
5 changes: 5 additions & 0 deletions src/Model/Message/Content/Audio.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,9 @@ public function jsonSerialize(): array
],
];
}

public function accept(ContentVisitor $visitor): array
{
return $visitor->visitAudio($this);
}
}
1 change: 1 addition & 0 deletions src/Model/Message/Content/Content.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@

interface Content extends \JsonSerializable
{
public function accept(ContentVisitor $visitor): array;
}
12 changes: 12 additions & 0 deletions src/Model/Message/Content/ContentVisitor.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<?php

namespace PhpLlm\LlmChain\Model\Message\Content;

interface ContentVisitor
{
public function visitAudio(Audio $content): array;

public function visitImage(Image $content): array;

public function visitText(Text $content): array;
}
5 changes: 5 additions & 0 deletions src/Model/Message/Content/Image.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,9 @@ private function fromFile(string $filePath): string

return sprintf('data:image/%s;base64,%s', $type, base64_encode($data));
}

public function accept(ContentVisitor $visitor): array
{
return $visitor->visitImage($this);
}
}
5 changes: 5 additions & 0 deletions src/Model/Message/Content/Text.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,9 @@ public function jsonSerialize(): array
{
return ['type' => 'text', 'text' => $this->text];
}

public function accept(ContentVisitor $visitor): array
{
return $visitor->visitText($this);
}
}
2 changes: 2 additions & 0 deletions src/Model/Message/MessageInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@
interface MessageInterface extends \JsonSerializable
{
public function getRole(): Role;

public function accept(MessageVisitor $visitor): array;
}
Loading

0 comments on commit 9271a5b

Please sign in to comment.