-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: adding llama support via replicate and ollama
feat: LLama prompt (#135) * Llama prompt * - * - * - * - feat: wrap up ollama example Update src/Platform/Replicate.php Co-authored-by: Oskar Stark <[email protected]>
- Loading branch information
1 parent
0c12cd1
commit 8f87f8b
Showing
8 changed files
with
439 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
<?php | ||
|
||
use PhpLlm\LlmChain\Chain; | ||
use PhpLlm\LlmChain\Message\Message; | ||
use PhpLlm\LlmChain\Message\MessageBag; | ||
use PhpLlm\LlmChain\Model\Language\Llama; | ||
use PhpLlm\LlmChain\Platform\Ollama; | ||
use Symfony\Component\Dotenv\Dotenv; | ||
use Symfony\Component\HttpClient\HttpClient; | ||
|
||
require_once dirname(__DIR__).'/vendor/autoload.php'; | ||
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env'); | ||
|
||
if (empty($_ENV['OLLAMA_HOST_URL'])) { | ||
echo 'Please set the OLLAMA_HOST_URL environment variable.'.PHP_EOL; | ||
exit(1); | ||
} | ||
|
||
$platform = new Ollama(HttpClient::create(), $_ENV['OLLAMA_HOST_URL']); | ||
$llm = new Llama($platform); | ||
|
||
$chain = new Chain($llm); | ||
$messages = new MessageBag( | ||
Message::forSystem('You are a helpful assistant.'), | ||
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), | ||
); | ||
$response = $chain->call($messages); | ||
|
||
echo $response->getContent().PHP_EOL; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
<?php | ||
|
||
use PhpLlm\LlmChain\Chain; | ||
use PhpLlm\LlmChain\Message\Message; | ||
use PhpLlm\LlmChain\Message\MessageBag; | ||
use PhpLlm\LlmChain\Model\Language\Llama; | ||
use PhpLlm\LlmChain\Platform\Replicate; | ||
use Symfony\Component\Dotenv\Dotenv; | ||
use Symfony\Component\HttpClient\HttpClient; | ||
|
||
require_once dirname(__DIR__).'/vendor/autoload.php'; | ||
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env'); | ||
|
||
if (empty($_ENV['REPLICATE_API_KEY'])) { | ||
echo 'Please set the REPLICATE_API_KEY environment variable.'.PHP_EOL; | ||
exit(1); | ||
} | ||
|
||
$platform = new Replicate(HttpClient::create(), $_ENV['REPLICATE_API_KEY']); | ||
$llm = new Llama($platform); | ||
|
||
$chain = new Chain($llm); | ||
$messages = new MessageBag( | ||
Message::forSystem('You are a helpful assistant.'), | ||
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), | ||
); | ||
$response = $chain->call($messages); | ||
|
||
echo $response->getContent().PHP_EOL; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
### Llama 3.2 on Ollama | ||
POST http://localhost:11434/api/chat | ||
|
||
{ | ||
"model": "llama3.2", | ||
"messages": [ | ||
{ | ||
"role": "user", | ||
"content": "why is the sky blue?" | ||
} | ||
], | ||
"stream": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Model\Language; | ||
|
||
use PhpLlm\LlmChain\Exception\RuntimeException; | ||
use PhpLlm\LlmChain\LanguageModel; | ||
use PhpLlm\LlmChain\Message\AssistantMessage; | ||
use PhpLlm\LlmChain\Message\Content\Image; | ||
use PhpLlm\LlmChain\Message\Content\Text; | ||
use PhpLlm\LlmChain\Message\MessageBag; | ||
use PhpLlm\LlmChain\Message\SystemMessage; | ||
use PhpLlm\LlmChain\Message\UserMessage; | ||
use PhpLlm\LlmChain\Platform\Ollama; | ||
use PhpLlm\LlmChain\Platform\Replicate; | ||
use PhpLlm\LlmChain\Response\TextResponse; | ||
|
||
final readonly class Llama implements LanguageModel | ||
{ | ||
public function __construct( | ||
private Replicate|Ollama $platform, | ||
) { | ||
} | ||
|
||
public function call(MessageBag $messages, array $options = []): TextResponse | ||
{ | ||
if ($this->platform instanceof Replicate) { | ||
$response = $this->platform->request('meta/meta-llama-3.1-405b-instruct', 'predictions', [ | ||
'system' => self::convertMessage($messages->getSystemMessage() ?? new SystemMessage('')), | ||
'prompt' => self::convertToPrompt($messages->withoutSystemMessage()), | ||
]); | ||
|
||
return new TextResponse(implode('', $response['output'])); | ||
} | ||
|
||
$response = $this->platform->request('llama3.2', 'chat', ['messages' => $messages, 'stream' => false]); | ||
|
||
return new TextResponse($response['message']['content']); | ||
} | ||
|
||
/** | ||
* @todo make method private, just for testing, or create a MessageBag to LLama convert class :thinking: | ||
*/ | ||
public static function convertToPrompt(MessageBag $messageBag): string | ||
{ | ||
$messages = []; | ||
|
||
/** @var UserMessage|SystemMessage|AssistantMessage $message */ | ||
foreach ($messageBag->getIterator() as $message) { | ||
$messages[] = self::convertMessage($message); | ||
} | ||
|
||
$messages = array_filter($messages, fn ($message) => '' !== $message); | ||
|
||
return trim(implode(PHP_EOL.PHP_EOL, $messages)).PHP_EOL.PHP_EOL.'<|start_header_id|>assistant<|end_header_id|>'; | ||
} | ||
|
||
/** | ||
* @todo make method private, just for testing | ||
*/ | ||
public static function convertMessage(UserMessage|SystemMessage|AssistantMessage $message): string | ||
{ | ||
if ($message instanceof SystemMessage) { | ||
return trim(<<<SYSTEM | ||
<|begin_of_text|><|start_header_id|>system<|end_header_id|> | ||
{$message->content}<|eot_id|> | ||
SYSTEM); | ||
} | ||
|
||
if ($message instanceof AssistantMessage) { | ||
if ('' === $message->content || null === $message->content) { | ||
return ''; | ||
} | ||
|
||
return trim(<<<ASSISTANT | ||
<|start_header_id|>{$message->getRole()->value}<|end_header_id|> | ||
{$message->content}<|eot_id|> | ||
ASSISTANT); | ||
} | ||
|
||
if ($message instanceof UserMessage) { | ||
$count = count($message->content); | ||
|
||
$contentParts = []; | ||
if ($count > 1) { | ||
foreach ($message->content as $value) { | ||
if ($value instanceof Text) { | ||
$contentParts[] = $value->text; | ||
} | ||
|
||
if ($value instanceof Image) { | ||
$contentParts[] = $value->url; | ||
} | ||
} | ||
} elseif (1 === $count) { | ||
$value = $message->content[0]; | ||
if ($value instanceof Text) { | ||
$contentParts[] = $value->text; | ||
} | ||
|
||
if ($value instanceof Image) { | ||
$contentParts[] = $value->url; | ||
} | ||
} else { | ||
throw new RuntimeException('Unsupported message type.'); | ||
} | ||
|
||
$content = implode(PHP_EOL, $contentParts); | ||
|
||
return trim(<<<USER | ||
<|start_header_id|>{$message->getRole()->value}<|end_header_id|> | ||
{$content}<|eot_id|> | ||
USER); | ||
} | ||
|
||
throw new RuntimeException('Unsupported message type.'); // @phpstan-ignore-line | ||
} | ||
|
||
public function supportsToolCalling(): bool | ||
{ | ||
return false; // it does, but implementation here is still open. | ||
} | ||
|
||
public function supportsImageInput(): bool | ||
{ | ||
return false; // it does, but implementation here is still open. | ||
} | ||
|
||
public function supportsStructuredOutput(): bool | ||
{ | ||
return false; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Platform; | ||
|
||
use Symfony\Contracts\HttpClient\HttpClientInterface; | ||
|
||
final readonly class Ollama | ||
{ | ||
public function __construct( | ||
private HttpClientInterface $httpClient, | ||
private string $hostUrl, | ||
) { | ||
} | ||
|
||
/** | ||
* @param string $model The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct" | ||
* @param array<string, mixed> $body | ||
* | ||
* @return array<string, mixed> | ||
*/ | ||
public function request(string $model, string $endpoint, array $body): array | ||
{ | ||
$url = sprintf('%s/api/%s', $this->hostUrl, $endpoint); | ||
|
||
$response = $this->httpClient->request('POST', $url, [ | ||
'headers' => ['Content-Type' => 'application/json'], | ||
'json' => array_merge($body, [ | ||
'model' => $model, | ||
]), | ||
]); | ||
|
||
return $response->toArray(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Platform; | ||
|
||
use Symfony\Contracts\HttpClient\HttpClientInterface; | ||
|
||
final readonly class Replicate | ||
{ | ||
public function __construct( | ||
private HttpClientInterface $httpClient, | ||
#[\SensitiveParameter] private string $apiKey, | ||
) { | ||
} | ||
|
||
/** | ||
* @param string $model The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct" | ||
* @param array<string, mixed> $body | ||
* | ||
* @return array<string, mixed> | ||
*/ | ||
public function request(string $model, string $endpoint, array $body): array | ||
{ | ||
$url = sprintf('https://api.replicate.com/v1/models/%s/%s', $model, $endpoint); | ||
|
||
$response = $this->httpClient->request('POST', $url, [ | ||
'headers' => ['Content-Type' => 'application/json'], | ||
'auth_bearer' => $this->apiKey, | ||
'json' => ['input' => $body], | ||
])->toArray(); | ||
|
||
while (!in_array($response['status'], ['succeeded', 'failed', 'canceled'], true)) { | ||
sleep(1); | ||
|
||
$response = $this->getResponse($response['id']); | ||
} | ||
|
||
return $response; | ||
} | ||
|
||
/** | ||
* @return array<string, mixed> | ||
*/ | ||
private function getResponse(string $id): array | ||
{ | ||
$url = sprintf('https://api.replicate.com/v1/predictions/%s', $id); | ||
|
||
$response = $this->httpClient->request('GET', $url, [ | ||
'headers' => ['Content-Type' => 'application/json'], | ||
'auth_bearer' => $this->apiKey, | ||
]); | ||
|
||
return $response->toArray(); | ||
} | ||
} |
Oops, something went wrong.