Skip to content

Commit

Permalink
feat: adding llama support via replicate and ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
chr-hertel committed Oct 5, 2024
1 parent eb8eca5 commit 3af2a0a
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ ANTHROPIC_API_KEY=
# For using Voyage
VOYAGE_API_KEY=

# For using Replicate
REPLICATE_API_KEY=

# For using Ollama
OLLAMA_HOST_URL=

# For using GPT on Azure
AZURE_OPENAI_BASEURL=
AZURE_OPENAI_DEPLOYMENT=
Expand Down
29 changes: 29 additions & 0 deletions examples/chat-llama-ollama.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
<?php

use PhpLlm\LlmChain\Chain;
use PhpLlm\LlmChain\Message\Message;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Model\Language\Llama;
use PhpLlm\LlmChain\Platform\Ollama;
use Symfony\Component\Dotenv\Dotenv;
use Symfony\Component\HttpClient\HttpClient;

require_once dirname(__DIR__).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');

if (empty($_ENV['OLLAMA_HOST_URL'])) {
echo 'Please set the OLLAMA_HOST_URL environment variable.'.PHP_EOL;
exit(1);
}

$platform = new Ollama(HttpClient::create(), $_ENV['OLLAMA_HOST_URL']);
$llm = new Llama($platform);

$chain = new Chain($llm);
$messages = new MessageBag(
Message::forSystem('You are a helpful assistant.'),
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
);
$response = $chain->call($messages);

echo $response.PHP_EOL;

Check failure on line 29 in examples/chat-llama-ollama.php

View workflow job for this annotation

GitHub Actions / qa

Binary operation "." between PhpLlm\LlmChain\Response\ResponseInterface and "\n"|"\r\n" results in an error.
29 changes: 29 additions & 0 deletions examples/chat-llama-replicate.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
<?php

use PhpLlm\LlmChain\Chain;
use PhpLlm\LlmChain\Message\Message;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Model\Language\Llama;
use PhpLlm\LlmChain\Platform\Replicate;
use Symfony\Component\Dotenv\Dotenv;
use Symfony\Component\HttpClient\HttpClient;

require_once dirname(__DIR__).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');

if (empty($_ENV['REPLICATE_API_KEY'])) {
echo 'Please set the REPLICATE_API_KEY environment variable.'.PHP_EOL;
exit(1);
}

$platform = new Replicate(HttpClient::create(), $_ENV['REPLICATE_API_KEY']);
$llm = new Llama($platform);

$chain = new Chain($llm);
$messages = new MessageBag(
Message::forSystem('You are a helpful assistant.'),
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
);
$response = $chain->call($messages);

echo $response.PHP_EOL;

Check failure on line 29 in examples/chat-llama-replicate.php

View workflow job for this annotation

GitHub Actions / qa

Binary operation "." between PhpLlm\LlmChain\Response\ResponseInterface and "\n"|"\r\n" results in an error.
48 changes: 48 additions & 0 deletions src/Model/Language/Llama.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Model\Language;

use PhpLlm\LlmChain\LanguageModel;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Platform\Ollama;
use PhpLlm\LlmChain\Platform\Replicate;
use PhpLlm\LlmChain\Response\Choice;
use PhpLlm\LlmChain\Response\Response;

final readonly class Llama implements LanguageModel
{
public function __construct(
private Replicate|Ollama $platform,
) {
}

public function call(MessageBag $messages, array $options = []): Response
{
$systemMessage = $messages->getSystemMessage();
$endpoint = $this->platform instanceof Replicate ? 'predictions' : 'chat';

$response = $this->platform->request('meta/meta-llama-3.1-405b-instruct', $endpoint, [
'system' => $systemMessage?->content,
'prompt' => $messages->withoutSystemMessage()->getIterator()->current()->content[0]->text, // @phpstan-ignore-line TODO: Multiple messages
]);

return new Response(new Choice(implode('', $response['output'])));
}

public function supportsToolCalling(): bool
{
return false; // it does, but implementation here is still open.
}

public function supportsImageInput(): bool
{
return false; // it does, but implementation here is still open.
}

public function supportsStructuredOutput(): bool
{
return false;
}
}
36 changes: 36 additions & 0 deletions src/Platform/Ollama.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Platform;

use Symfony\Contracts\HttpClient\HttpClientInterface;

final readonly class Ollama
{
public function __construct(
private HttpClientInterface $httpClient,
private string $hostUrl,
) {
}

/**
* @param string $model The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct"
* @param array<string, mixed> $body
*
* @return array<string, mixed>
*/
public function request(string $model, string $endpoint, array $body): array
{
$url = sprintf('%s/api/%s', $this->hostUrl, $endpoint);

$response = $this->httpClient->request('POST', $url, [
'headers' => ['Content-Type' => 'application/json'],
'json' => array_merge($body, [
'model' => $model,
]),
]);

return dump($response->toArray());
}
}
56 changes: 56 additions & 0 deletions src/Platform/Replicate.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Platform;

use Symfony\Contracts\HttpClient\HttpClientInterface;

final readonly class Replicate
{
public function __construct(
private HttpClientInterface $httpClient,
private string $apiKey,
) {
}

/**
* @param string $model The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct"
* @param array<string, mixed> $body
*
* @return array<string, mixed>
*/
public function request(string $model, string $endpoint, array $body): array
{
$url = sprintf('https://api.replicate.com/v1/models/%s/%s', $model, $endpoint);

$response = $this->httpClient->request('POST', $url, [
'headers' => ['Content-Type' => 'application/json'],
'auth_bearer' => $this->apiKey,
'json' => ['input' => $body],
])->toArray();

while (!in_array($response['status'], ['succeeded', 'failed', 'canceled'], true)) {
sleep(1);

$response = $this->getResponse($response['id']);
}

return $response;
}

/**
* @return array<string, mixed>
*/
private function getResponse(string $id): array
{
$url = sprintf('https://api.replicate.com/v1/predictions/%s', $id);

$response = $this->httpClient->request('GET', $url, [
'headers' => ['Content-Type' => 'application/json'],
'auth_bearer' => $this->apiKey,
]);

return $response->toArray();
}
}

0 comments on commit 3af2a0a

Please sign in to comment.