Skip to content

Commit

Permalink
feat: adding llama support via replicate and ollama
Browse files Browse the repository at this point in the history
feat: LLama prompt (#135)

* Llama prompt

* -

* -

* -

* -

feat: wrap up ollama example

Update src/Platform/Replicate.php

Co-authored-by: Oskar Stark <[email protected]>
  • Loading branch information
chr-hertel and OskarStark committed Oct 13, 2024
1 parent 0c12cd1 commit 8f87f8b
Show file tree
Hide file tree
Showing 8 changed files with 439 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ ANTHROPIC_API_KEY=
# For using Voyage
VOYAGE_API_KEY=

# For using Replicate
REPLICATE_API_KEY=

# For using Ollama
OLLAMA_HOST_URL=

# For using GPT on Azure
AZURE_OPENAI_BASEURL=
AZURE_OPENAI_DEPLOYMENT=
Expand Down
29 changes: 29 additions & 0 deletions examples/chat-llama-ollama.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
<?php

use PhpLlm\LlmChain\Chain;
use PhpLlm\LlmChain\Message\Message;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Model\Language\Llama;
use PhpLlm\LlmChain\Platform\Ollama;
use Symfony\Component\Dotenv\Dotenv;
use Symfony\Component\HttpClient\HttpClient;

require_once dirname(__DIR__).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');

// Bail out early when no local Ollama instance is configured.
if (empty($_ENV['OLLAMA_HOST_URL'])) {
    echo 'Please set the OLLAMA_HOST_URL environment variable.'.PHP_EOL;
    exit(1);
}

// Wire the Llama model to the locally running Ollama server.
$ollama = new Ollama(HttpClient::create(), $_ENV['OLLAMA_HOST_URL']);
$model = new Llama($ollama);

$conversation = new MessageBag(
    Message::forSystem('You are a helpful assistant.'),
    Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
);

$result = (new Chain($model))->call($conversation);

echo $result->getContent().PHP_EOL;
29 changes: 29 additions & 0 deletions examples/chat-llama-replicate.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
<?php

use PhpLlm\LlmChain\Chain;
use PhpLlm\LlmChain\Message\Message;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Model\Language\Llama;
use PhpLlm\LlmChain\Platform\Replicate;
use Symfony\Component\Dotenv\Dotenv;
use Symfony\Component\HttpClient\HttpClient;

require_once dirname(__DIR__).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');

// Bail out early when no Replicate credentials are configured.
if (empty($_ENV['REPLICATE_API_KEY'])) {
    echo 'Please set the REPLICATE_API_KEY environment variable.'.PHP_EOL;
    exit(1);
}

// Wire the Llama model to the hosted Replicate platform.
$replicate = new Replicate(HttpClient::create(), $_ENV['REPLICATE_API_KEY']);
$model = new Llama($replicate);

$conversation = new MessageBag(
    Message::forSystem('You are a helpful assistant.'),
    Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
);

$result = (new Chain($model))->call($conversation);

echo $result->getContent().PHP_EOL;
13 changes: 13 additions & 0 deletions ollama.http
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
### Llama 3.2 on Ollama
POST http://localhost:11434/api/chat

{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
}
],
"stream": false
}
137 changes: 137 additions & 0 deletions src/Model/Language/Llama.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Model\Language;

use PhpLlm\LlmChain\Exception\RuntimeException;
use PhpLlm\LlmChain\LanguageModel;
use PhpLlm\LlmChain\Message\AssistantMessage;
use PhpLlm\LlmChain\Message\Content\Image;
use PhpLlm\LlmChain\Message\Content\Text;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Message\SystemMessage;
use PhpLlm\LlmChain\Message\UserMessage;
use PhpLlm\LlmChain\Platform\Ollama;
use PhpLlm\LlmChain\Platform\Replicate;
use PhpLlm\LlmChain\Response\TextResponse;

/**
 * Llama language model, usable via either the Replicate or the Ollama platform.
 *
 * Replicate expects a single raw prompt string in the Llama 3 chat template
 * format (<|begin_of_text|>, <|start_header_id|>, <|eot_id|> markers), whereas
 * Ollama accepts a structured chat message list.
 */
final readonly class Llama implements LanguageModel
{
    public function __construct(
        private Replicate|Ollama $platform,
    ) {
    }

    /**
     * Sends the conversation to the configured platform and returns the model's reply.
     *
     * @param array<string, mixed> $options currently unused
     */
    public function call(MessageBag $messages, array $options = []): TextResponse
    {
        if ($this->platform instanceof Replicate) {
            $response = $this->platform->request('meta/meta-llama-3.1-405b-instruct', 'predictions', [
                // Replicate takes the system prompt separately from the conversation prompt.
                'system' => self::convertMessage($messages->getSystemMessage() ?? new SystemMessage('')),
                'prompt' => self::convertToPrompt($messages->withoutSystemMessage()),
            ]);

            // Replicate streams the completion as a list of output chunks.
            return new TextResponse(implode('', $response['output']));
        }

        // NOTE(review): passes the MessageBag object directly; relies on it
        // serializing to the chat format Ollama expects — confirm.
        $response = $this->platform->request('llama3.2', 'chat', ['messages' => $messages, 'stream' => false]);

        return new TextResponse($response['message']['content']);
    }

    /**
     * Converts a message bag into a single Llama 3 chat template prompt string,
     * ending with an open assistant header so the model continues from there.
     *
     * @todo make method private, just for testing, or create a MessageBag to LLama convert class :thinking:
     */
    public static function convertToPrompt(MessageBag $messageBag): string
    {
        $messages = [];

        /** @var UserMessage|SystemMessage|AssistantMessage $message */
        foreach ($messageBag->getIterator() as $message) {
            $messages[] = self::convertMessage($message);
        }

        // Empty assistant messages convert to '' and must not produce empty blocks.
        $messages = array_filter($messages, fn ($message) => '' !== $message);

        return trim(implode(PHP_EOL.PHP_EOL, $messages)).PHP_EOL.PHP_EOL.'<|start_header_id|>assistant<|end_header_id|>';
    }

    /**
     * Converts a single message into its Llama 3 chat template representation.
     *
     * Returns an empty string for assistant messages without content.
     *
     * @throws RuntimeException when a user message has no content parts, or the
     *                          message type is not supported
     *
     * @todo make method private, just for testing
     */
    public static function convertMessage(UserMessage|SystemMessage|AssistantMessage $message): string
    {
        if ($message instanceof SystemMessage) {
            return trim(<<<SYSTEM
                <|begin_of_text|><|start_header_id|>system<|end_header_id|>
                {$message->content}<|eot_id|>
                SYSTEM);
        }

        if ($message instanceof AssistantMessage) {
            if ('' === $message->content || null === $message->content) {
                return '';
            }

            return trim(<<<ASSISTANT
                <|start_header_id|>{$message->getRole()->value}<|end_header_id|>
                {$message->content}<|eot_id|>
                ASSISTANT);
        }

        if ($message instanceof UserMessage) {
            // A user message without any content parts cannot be rendered.
            if (0 === count($message->content)) {
                throw new RuntimeException('Unsupported message type.');
            }

            // Text parts contribute their text, image parts their URL;
            // any other content part type is silently skipped.
            $contentParts = [];
            foreach ($message->content as $value) {
                if ($value instanceof Text) {
                    $contentParts[] = $value->text;
                }

                if ($value instanceof Image) {
                    $contentParts[] = $value->url;
                }
            }

            $content = implode(PHP_EOL, $contentParts);

            return trim(<<<USER
                <|start_header_id|>{$message->getRole()->value}<|end_header_id|>
                {$content}<|eot_id|>
                USER);
        }

        throw new RuntimeException('Unsupported message type.'); // @phpstan-ignore-line
    }

    public function supportsToolCalling(): bool
    {
        return false; // it does, but implementation here is still open.
    }

    public function supportsImageInput(): bool
    {
        return false; // it does, but implementation here is still open.
    }

    public function supportsStructuredOutput(): bool
    {
        return false;
    }
}
36 changes: 36 additions & 0 deletions src/Platform/Ollama.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Platform;

use Symfony\Contracts\HttpClient\HttpClientInterface;

/**
 * Minimal HTTP client for a (typically local) Ollama instance.
 */
final readonly class Ollama
{
    public function __construct(
        private HttpClientInterface $httpClient,
        // Base URL of the Ollama server, e.g. "http://localhost:11434" (no trailing slash).
        private string $hostUrl,
    ) {
    }

    /**
     * Sends a POST request to the Ollama API and returns the decoded response.
     *
     * @param string               $model    The model name known to the Ollama instance, e.g. "llama3.2"
     * @param string               $endpoint The API endpoint under /api/, e.g. "chat"
     * @param array<string, mixed> $body     Request payload; the model name is merged into it
     *
     * @return array<string, mixed>
     */
    public function request(string $model, string $endpoint, array $body): array
    {
        $url = sprintf('%s/api/%s', $this->hostUrl, $endpoint);

        $response = $this->httpClient->request('POST', $url, [
            'headers' => ['Content-Type' => 'application/json'],
            'json' => array_merge($body, [
                'model' => $model,
            ]),
        ]);

        return $response->toArray();
    }
}
56 changes: 56 additions & 0 deletions src/Platform/Replicate.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Platform;

use Symfony\Contracts\HttpClient\HttpClientInterface;

/**
 * Minimal HTTP client for the Replicate predictions API.
 *
 * Replicate predictions are asynchronous: creating one returns a pending
 * resource that must be polled until it reaches a terminal status.
 */
final readonly class Replicate
{
    public function __construct(
        private HttpClientInterface $httpClient,
        #[\SensitiveParameter] private string $apiKey,
    ) {
    }

    /**
     * Creates a prediction and polls until it finishes, failing or being canceled included.
     *
     * @param string               $model    The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct"
     * @param string               $endpoint The endpoint under the model resource, e.g. "predictions"
     * @param array<string, mixed> $body     Prediction input payload (sent as the "input" field)
     *
     * @return array<string, mixed> the final prediction resource, including its "status" and output
     */
    public function request(string $model, string $endpoint, array $body): array
    {
        $url = sprintf('https://api.replicate.com/v1/models/%s/%s', $model, $endpoint);

        $response = $this->httpClient->request('POST', $url, [
            'headers' => ['Content-Type' => 'application/json'],
            'auth_bearer' => $this->apiKey,
            'json' => ['input' => $body],
        ])->toArray();

        // Poll once per second until the prediction reaches a terminal status.
        // NOTE(review): there is no upper bound on iterations — a prediction
        // stuck in "processing" blocks forever; consider a timeout.
        while (!in_array($response['status'], ['succeeded', 'failed', 'canceled'], true)) {
            sleep(1);

            $response = $this->getResponse($response['id']);
        }

        return $response;
    }

    /**
     * Fetches the current state of a prediction by its id.
     *
     * @return array<string, mixed>
     */
    private function getResponse(string $id): array
    {
        $url = sprintf('https://api.replicate.com/v1/predictions/%s', $id);

        $response = $this->httpClient->request('GET', $url, [
            'headers' => ['Content-Type' => 'application/json'],
            'auth_bearer' => $this->apiKey,
        ]);

        return $response->toArray();
    }
}
Loading

0 comments on commit 8f87f8b

Please sign in to comment.