-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: adding llama support via replicate and ollama
- Loading branch information
1 parent
eb8eca5
commit 3af2a0a
Showing
6 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
<?php | ||
|
||
use PhpLlm\LlmChain\Chain; | ||
use PhpLlm\LlmChain\Message\Message; | ||
use PhpLlm\LlmChain\Message\MessageBag; | ||
use PhpLlm\LlmChain\Model\Language\Llama; | ||
use PhpLlm\LlmChain\Platform\Ollama; | ||
use Symfony\Component\Dotenv\Dotenv; | ||
use Symfony\Component\HttpClient\HttpClient; | ||
|
||
require_once dirname(__DIR__).'/vendor/autoload.php'; | ||
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env'); | ||
|
||
if (empty($_ENV['OLLAMA_HOST_URL'])) { | ||
echo 'Please set the OLLAMA_HOST_URL environment variable.'.PHP_EOL; | ||
exit(1); | ||
} | ||
|
||
$platform = new Ollama(HttpClient::create(), $_ENV['OLLAMA_HOST_URL']); | ||
$llm = new Llama($platform); | ||
|
||
$chain = new Chain($llm); | ||
$messages = new MessageBag( | ||
Message::forSystem('You are a helpful assistant.'), | ||
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), | ||
); | ||
$response = $chain->call($messages); | ||
|
||
echo $response.PHP_EOL; | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
<?php | ||
|
||
use PhpLlm\LlmChain\Chain; | ||
use PhpLlm\LlmChain\Message\Message; | ||
use PhpLlm\LlmChain\Message\MessageBag; | ||
use PhpLlm\LlmChain\Model\Language\Llama; | ||
use PhpLlm\LlmChain\Platform\Replicate; | ||
use Symfony\Component\Dotenv\Dotenv; | ||
use Symfony\Component\HttpClient\HttpClient; | ||
|
||
require_once dirname(__DIR__).'/vendor/autoload.php'; | ||
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env'); | ||
|
||
if (empty($_ENV['REPLICATE_API_KEY'])) { | ||
echo 'Please set the REPLICATE_API_KEY environment variable.'.PHP_EOL; | ||
exit(1); | ||
} | ||
|
||
$platform = new Replicate(HttpClient::create(), $_ENV['REPLICATE_API_KEY']); | ||
$llm = new Llama($platform); | ||
|
||
$chain = new Chain($llm); | ||
$messages = new MessageBag( | ||
Message::forSystem('You are a helpful assistant.'), | ||
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), | ||
); | ||
$response = $chain->call($messages); | ||
|
||
echo $response.PHP_EOL; | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Model\Language; | ||
|
||
use PhpLlm\LlmChain\LanguageModel; | ||
use PhpLlm\LlmChain\Message\MessageBag; | ||
use PhpLlm\LlmChain\Platform\Ollama; | ||
use PhpLlm\LlmChain\Platform\Replicate; | ||
use PhpLlm\LlmChain\Response\Choice; | ||
use PhpLlm\LlmChain\Response\Response; | ||
|
||
final readonly class Llama implements LanguageModel | ||
{ | ||
public function __construct( | ||
private Replicate|Ollama $platform, | ||
) { | ||
} | ||
|
||
public function call(MessageBag $messages, array $options = []): Response | ||
{ | ||
$systemMessage = $messages->getSystemMessage(); | ||
$endpoint = $this->platform instanceof Replicate ? 'predictions' : 'chat'; | ||
|
||
$response = $this->platform->request('meta/meta-llama-3.1-405b-instruct', $endpoint, [ | ||
'system' => $systemMessage?->content, | ||
'prompt' => $messages->withoutSystemMessage()->getIterator()->current()->content[0]->text, // @phpstan-ignore-line TODO: Multiple messages | ||
]); | ||
|
||
return new Response(new Choice(implode('', $response['output']))); | ||
} | ||
|
||
public function supportsToolCalling(): bool | ||
{ | ||
return false; // it does, but implementation here is still open. | ||
} | ||
|
||
public function supportsImageInput(): bool | ||
{ | ||
return false; // it does, but implementation here is still open. | ||
} | ||
|
||
public function supportsStructuredOutput(): bool | ||
{ | ||
return false; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Platform; | ||
|
||
use Symfony\Contracts\HttpClient\HttpClientInterface; | ||
|
||
final readonly class Ollama | ||
{ | ||
public function __construct( | ||
private HttpClientInterface $httpClient, | ||
private string $hostUrl, | ||
) { | ||
} | ||
|
||
/** | ||
* @param string $model The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct" | ||
* @param array<string, mixed> $body | ||
* | ||
* @return array<string, mixed> | ||
*/ | ||
public function request(string $model, string $endpoint, array $body): array | ||
{ | ||
$url = sprintf('%s/api/%s', $this->hostUrl, $endpoint); | ||
|
||
$response = $this->httpClient->request('POST', $url, [ | ||
'headers' => ['Content-Type' => 'application/json'], | ||
'json' => array_merge($body, [ | ||
'model' => $model, | ||
]), | ||
]); | ||
|
||
return dump($response->toArray()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Platform; | ||
|
||
use Symfony\Contracts\HttpClient\HttpClientInterface; | ||
|
||
final readonly class Replicate | ||
{ | ||
public function __construct( | ||
private HttpClientInterface $httpClient, | ||
private string $apiKey, | ||
) { | ||
} | ||
|
||
/** | ||
* @param string $model The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct" | ||
* @param array<string, mixed> $body | ||
* | ||
* @return array<string, mixed> | ||
*/ | ||
public function request(string $model, string $endpoint, array $body): array | ||
{ | ||
$url = sprintf('https://api.replicate.com/v1/models/%s/%s', $model, $endpoint); | ||
|
||
$response = $this->httpClient->request('POST', $url, [ | ||
'headers' => ['Content-Type' => 'application/json'], | ||
'auth_bearer' => $this->apiKey, | ||
'json' => ['input' => $body], | ||
])->toArray(); | ||
|
||
while (!in_array($response['status'], ['succeeded', 'failed', 'canceled'], true)) { | ||
sleep(1); | ||
|
||
$response = $this->getResponse($response['id']); | ||
} | ||
|
||
return $response; | ||
} | ||
|
||
/** | ||
* @return array<string, mixed> | ||
*/ | ||
private function getResponse(string $id): array | ||
{ | ||
$url = sprintf('https://api.replicate.com/v1/predictions/%s', $id); | ||
|
||
$response = $this->httpClient->request('GET', $url, [ | ||
'headers' => ['Content-Type' => 'application/json'], | ||
'auth_bearer' => $this->apiKey, | ||
]); | ||
|
||
return $response->toArray(); | ||
} | ||
} |