Skip to content

Commit

Permalink
feat: introduce chain awareness in chain processors to nest them
Browse files Browse the repository at this point in the history
  • Loading branch information
chr-hertel committed Oct 5, 2024
1 parent 1be8cb2 commit 009d216
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 32 deletions.
58 changes: 58 additions & 0 deletions examples/structured-output-clock.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<?php

use PhpLlm\LlmChain\Chain;
use PhpLlm\LlmChain\Message\Message;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\OpenAI\Model\Gpt;
use PhpLlm\LlmChain\OpenAI\Model\Gpt\Version;
use PhpLlm\LlmChain\OpenAI\Platform\OpenAI;
use PhpLlm\LlmChain\StructuredOutput\ChainProcessor as StructuredOutputProcessor;
use PhpLlm\LlmChain\StructuredOutput\ResponseFormatFactory;
use PhpLlm\LlmChain\ToolBox\ChainProcessor as ToolProcessor;
use PhpLlm\LlmChain\ToolBox\Tool\Clock;
use PhpLlm\LlmChain\ToolBox\ToolAnalyzer;
use PhpLlm\LlmChain\ToolBox\ToolBox;
use Symfony\Component\Clock\Clock as SymfonyClock;
use Symfony\Component\Dotenv\Dotenv;
use Symfony\Component\HttpClient\HttpClient;
use Symfony\Component\Serializer\Encoder\JsonEncoder;
use Symfony\Component\Serializer\Normalizer\ObjectNormalizer;
use Symfony\Component\Serializer\Serializer;

require_once dirname(__DIR__).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');

if (empty($_ENV['OPENAI_API_KEY'])) {
echo 'Please set the OPENAI_API_KEY environment variable.'.PHP_EOL;
exit(1);
}

$platform = new OpenAI(HttpClient::create(), $_ENV['OPENAI_API_KEY']);
$llm = new Gpt($platform, Version::gpt4oMini());

$clock = new Clock(new SymfonyClock());
$toolBox = new ToolBox(new ToolAnalyzer(), [$clock]);
$toolProcessor = new ToolProcessor($toolBox);
$serializer = new Serializer([new ObjectNormalizer()], [new JsonEncoder()]);
$structuredOutputProcessor = new StructuredOutputProcessor(new ResponseFormatFactory(), $serializer);
$chain = new Chain($llm, [$toolProcessor, $structuredOutputProcessor], [$toolProcessor, $structuredOutputProcessor]);

$messages = new MessageBag(Message::ofUser('What date and time is it?'));
$response = $chain->call($messages, ['response_format' => [
'type' => 'json_schema',
'json_schema' => [
'name' => 'clock',
'strict' => true,
'schema' => [
'type' => 'object',
'properties' => [
'date' => ['type' => 'string', 'description' => 'The current date in the format YYYY-MM-DD.'],
'time' => ['type' => 'string', 'description' => 'The current time in the format HH:MM:SS.'],
],
'required' => ['date', 'time'],
'additionalProperties' => false,
],
],
]]);

dump($response->getContent());
30 changes: 23 additions & 7 deletions src/Chain.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

namespace PhpLlm\LlmChain;

use PhpLlm\LlmChain\Chain\ChainAwareProcessor;
use PhpLlm\LlmChain\Chain\Input;
use PhpLlm\LlmChain\Chain\InputProcessor;
use PhpLlm\LlmChain\Chain\Output;
use PhpLlm\LlmChain\Chain\OutputProcessor;
use PhpLlm\LlmChain\Exception\InvalidArgumentException;
use PhpLlm\LlmChain\Exception\MissingModelSupport;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Response\ResponseInterface;
Expand All @@ -33,8 +35,8 @@ public function __construct(
iterable $inputProcessor = [],
iterable $outputProcessor = [],
) {
$this->inputProcessor = $inputProcessor instanceof \Traversable ? iterator_to_array($inputProcessor) : $inputProcessor;
$this->outputProcessor = $outputProcessor instanceof \Traversable ? iterator_to_array($outputProcessor) : $outputProcessor;
$this->inputProcessor = $this->initializeProcessors($inputProcessor, InputProcessor::class);
$this->outputProcessor = $this->initializeProcessors($outputProcessor, OutputProcessor::class);
}

/**
Expand All @@ -52,14 +54,28 @@ public function call(MessageBag $messages, array $options = []): ResponseInterfa
$response = $this->llm->call($messages, $options = $input->getOptions());

$output = new Output($this->llm, $response, $messages, $options);
foreach ($this->outputProcessor as $outputProcessor) {
$result = $outputProcessor->processOutput($output);
array_map(fn (OutputProcessor $processor) => $processor->processOutput($output), $this->outputProcessor);

if (null !== $result) {
return $result;
return $output->response;
}

/**
* @param InputProcessor[]|OutputProcessor[] $processors
*
* @return InputProcessor[]|OutputProcessor[]
*/
private function initializeProcessors(iterable $processors, string $interface): array
{
foreach ($processors as $processor) {
if (!$processor instanceof $interface) {
throw new InvalidArgumentException(sprintf('Processor %s must implement %s interface.', $processor::class, $interface));
}

if ($processor instanceof ChainAwareProcessor) {
$processor->setChain($this);
}
}

return $response;
return $processors instanceof \Traversable ? iterator_to_array($processors) : $processors;
}
}
12 changes: 12 additions & 0 deletions src/Chain/ChainAwareProcessor.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Chain;

use PhpLlm\LlmChain\Chain;

interface ChainAwareProcessor
{
public function setChain(Chain $chain): void;
}
17 changes: 17 additions & 0 deletions src/Chain/ChainAwareTrait.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Chain;

use PhpLlm\LlmChain\Chain;

trait ChainAwareTrait
{
private Chain $chain;

public function setChain(Chain $chain): void
{
$this->chain = $chain;
}
}
8 changes: 4 additions & 4 deletions src/Chain/Output.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Response\ResponseInterface;

final readonly class Output
final class Output
{
/**
* @param array<string, mixed> $options
*/
public function __construct(
public LanguageModel $llm,
public readonly LanguageModel $llm,
public ResponseInterface $response,
public MessageBag $messages,
public array $options,
public readonly MessageBag $messages,
public readonly array $options,
) {
}
}
2 changes: 1 addition & 1 deletion src/Chain/OutputProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

interface OutputProcessor
{
public function processOutput(Output $output): mixed;
public function processOutput(Output $output): void;
}
18 changes: 14 additions & 4 deletions src/StructuredOutput/ChainProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,25 @@ public function processInput(Input $input): void
$input->setOptions($options);
}

public function processOutput(Output $output): ?StructuredResponse
public function processOutput(Output $output): void
{
$options = $output->options;

if (!isset($options['output_structure'])) {
return null;
if ($output->response instanceof StructuredResponse) {
return;
}

if (!isset($options['response_format'])) {
return;
}

if (!isset($this->outputStructure)) {
$output->response = new StructuredResponse(json_decode($output->response->getContent(), true));

return;
}

return new StructuredResponse(
$output->response = new StructuredResponse(
$this->serializer->deserialize($output->response->getContent(), $this->outputStructure, 'json')
);
}
Expand Down
18 changes: 9 additions & 9 deletions src/ToolBox/ChainProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

namespace PhpLlm\LlmChain\ToolBox;

use PhpLlm\LlmChain\Chain\ChainAwareProcessor;
use PhpLlm\LlmChain\Chain\ChainAwareTrait;
use PhpLlm\LlmChain\Chain\Input;
use PhpLlm\LlmChain\Chain\InputProcessor;
use PhpLlm\LlmChain\Chain\Output;
use PhpLlm\LlmChain\Chain\OutputProcessor;
use PhpLlm\LlmChain\Exception\MissingModelSupport;
use PhpLlm\LlmChain\Message\Message;
use PhpLlm\LlmChain\Response\ResponseInterface;
use PhpLlm\LlmChain\Response\ToolCallResponse;

final readonly class ChainProcessor implements InputProcessor, OutputProcessor
final class ChainProcessor implements InputProcessor, OutputProcessor, ChainAwareProcessor
{
use ChainAwareTrait;

public function __construct(
private ToolBoxInterface $toolBox,
) {
Expand All @@ -31,23 +34,20 @@ public function processInput(Input $input): void
$input->setOptions($options);
}

public function processOutput(Output $output): ResponseInterface
public function processOutput(Output $output): void
{
$response = $output->response;
$messages = clone $output->messages;

while ($response instanceof ToolCallResponse) {
$toolCalls = $response->getContent();
while ($output->response instanceof ToolCallResponse) {
$toolCalls = $output->response->getContent();
$messages[] = Message::ofAssistant(toolCalls: $toolCalls);

foreach ($toolCalls as $toolCall) {
$result = $this->toolBox->execute($toolCall);
$messages[] = Message::ofToolCall($toolCall, $result);
}

$response = $output->llm->call($messages, $output->options);
$output->response = $this->chain->call($messages, $output->options);
}

return $response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

declare(strict_types=1);

namespace PhpLlm\LlmChain\Tests\ToolBox;
namespace PhpLlm\LlmChain\Tests\StructuredOutput;

use PhpLlm\LlmChain\Chain\Input;
use PhpLlm\LlmChain\Chain\Output;
use PhpLlm\LlmChain\Exception\MissingModelSupport;
use PhpLlm\LlmChain\LanguageModel;
use PhpLlm\LlmChain\Message\MessageBag;
use PhpLlm\LlmChain\Response\Choice;
use PhpLlm\LlmChain\Response\StructuredResponse;
use PhpLlm\LlmChain\Response\TextResponse;
use PhpLlm\LlmChain\StructuredOutput\ChainProcessor;
use PhpLlm\LlmChain\Tests\Double\ConfigurableResponseFormatFactory;
Expand Down Expand Up @@ -96,12 +97,13 @@ public function processOutputWithResponseFormat(): void

$response = new TextResponse('{"some": "data"}');

$output = new Output($llm, $response, new MessageBag(), $options);
$output = new Output($llm, $response, new MessageBag(), $input->getOptions());

$result = $chainProcessor->processOutput($output)->getContent();
$chainProcessor->processOutput($output);

self::assertInstanceOf(SomeStructure::class, $result);
self::assertSame('data', $result->some);
self::assertInstanceOf(StructuredResponse::class, $output->response);
self::assertInstanceOf(SomeStructure::class, $output->response->getContent());
self::assertSame('data', $output->response->getContent()->some);
}

#[Test]
Expand All @@ -116,8 +118,8 @@ public function processOutputWithoutResponseFormat(): void

$output = new Output($llm, $response, new MessageBag(), []);

$result = $chainProcessor->processOutput($output);
$chainProcessor->processOutput($output);

self::assertNull($result);
self::assertSame($response, $output->response);
}
}

0 comments on commit 009d216

Please sign in to comment.