diff --git a/src/createChat.test.ts b/src/createChat.test.ts index 1043f94..904105d 100644 --- a/src/createChat.test.ts +++ b/src/createChat.test.ts @@ -265,6 +265,52 @@ test("overrides function call", async () => { assert.equal(getCurrentWeather.mock.calls.length, 0); }); +test("calls a user defined function more than once", async () => { + const getCurrentWeather = mock.fn(({ location }: { location: string }) => { + return { + location, + temperature: "72", + unit: "fahrenheit", + forecast: ["sunny", "windy"], + }; + }); + + const chat = createChat({ + apiKey: OPENAI_API_KEY, + model: "gpt-3.5-turbo-0613", + functions: [ + { + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and state, e.g. San Francisco, CA", + }, + unit: { type: "string", enum: ["celsius", "fahrenheit"] }, + }, + required: ["location"], + }, + function: getCurrentWeather, + }, + ], + }); + + await chat.sendMessage("Tell me the weather in Albuquerque and Chicago?"); + + assert.equal(getCurrentWeather.mock.calls.length, 2); + assert.equal( + getCurrentWeather.mock.calls[0].arguments[0].location, + "Albuquerque" + ); + assert.equal( + getCurrentWeather.mock.calls[1].arguments[0].location, + "Chicago" + ); +}); + test("overrides message options", async () => { const chat = createChat({ apiKey: OPENAI_API_KEY, diff --git a/src/createChat.ts b/src/createChat.ts index 60b3584..2cd9670 100644 --- a/src/createChat.ts +++ b/src/createChat.ts @@ -179,7 +179,7 @@ export const createChat = ( messages.push(omit(choice, "finishReason")); - if (choice.function_call) { + while (choice.function_call) { const functionName = choice.function_call.name; const userFunction = userFunctions[functionName]; @@ -204,6 +204,9 @@ export const createChat = ( choice = await complete(messageOptions); + messages.push(omit(choice, "finishReason")); + + // TODO record a trail of function calls choice.functionCall = { name: functionName, arguments: functionArgs,