Skip to content

Commit

Permalink
feat(bedrock): Added bedrock reasoning support
Browse files Browse the repository at this point in the history
  • Loading branch information
Und3rf10w committed Feb 25, 2025
1 parent 88f6e01 commit 476bfe4
Showing 1 changed file with 54 additions and 27 deletions.
81 changes: 54 additions & 27 deletions packages/amazon-bedrock/src/bedrock-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
readonly modelId: BedrockChatModelId,
private readonly settings: BedrockChatSettings,
private readonly config: BedrockChatConfig,
) {}
) { }

private getArgs({
mode,
Expand Down Expand Up @@ -207,20 +207,20 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
const providerMetadata =
response.trace || response.usage
? {
bedrock: {
...(response.trace && typeof response.trace === 'object'
? { trace: response.trace as JSONObject }
: {}),
...(response.usage && {
usage: {
cacheReadInputTokens:
response.usage?.cacheReadInputTokens ?? Number.NaN,
cacheWriteInputTokens:
response.usage?.cacheWriteInputTokens ?? Number.NaN,
},
}),
},
}
bedrock: {
...(response.trace && typeof response.trace === 'object'
? { trace: response.trace as JSONObject }
: {}),
...(response.usage && {
usage: {
cacheReadInputTokens:
response.usage?.cacheReadInputTokens ?? Number.NaN,
cacheWriteInputTokens:
response.usage?.cacheWriteInputTokens ?? Number.NaN,
},
}),
},
}
: undefined;

return {
Expand All @@ -246,6 +246,12 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings,
reasoning: response.output?.message?.reasoningContent
?.map(part => ({
type: 'text',
text: part?.reasoningText?.text,
signature: part?.reasoningText?.signature,
})),
...(providerMetadata && { providerMetadata }),
};
}
Expand Down Expand Up @@ -345,23 +351,23 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {

const cacheUsage =
value.metadata.usage?.cacheReadInputTokens != null ||
value.metadata.usage?.cacheWriteInputTokens != null
value.metadata.usage?.cacheWriteInputTokens != null
? {
usage: {
cacheReadInputTokens:
value.metadata.usage?.cacheReadInputTokens ??
Number.NaN,
cacheWriteInputTokens:
value.metadata.usage?.cacheWriteInputTokens ??
Number.NaN,
},
}
usage: {
cacheReadInputTokens:
value.metadata.usage?.cacheReadInputTokens ??
Number.NaN,
cacheWriteInputTokens:
value.metadata.usage?.cacheWriteInputTokens ??
Number.NaN,
},
}
: undefined;

const trace = value.metadata.trace
? {
trace: value.metadata.trace as JSONObject,
}
trace: value.metadata.trace as JSONObject,
}
: undefined;

if (cacheUsage || trace) {
Expand All @@ -385,6 +391,17 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
});
}

if (
value.contentBlockDelta?.delta &&
'reasoningContent' in value.contentBlockDelta.delta &&
value.contentBlockDelta.delta.reasoningContent
) {
controller.enqueue({
type: 'reasoning',
text: value.contentBlockDelta.delta.reasoningContent.text
})
}

const contentBlockStart = value.contentBlockStart;
if (contentBlockStart?.start?.toolUse != null) {
const toolUse = contentBlockStart.start.toolUse;
Expand Down Expand Up @@ -468,6 +485,11 @@ const BedrockToolUseSchema = z.object({
input: z.unknown(),
});

const BedrockReasoningTextSchema = z.object({
signature: z.string().nullish(),
text: z.string(),
});

// limited version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const BedrockResponseSchema = z.object({
Expand All @@ -482,6 +504,9 @@ const BedrockResponseSchema = z.object({
z.object({
text: z.string().nullish(),
toolUse: BedrockToolUseSchema.nullish(),
reasoningContent: z.object({
reasoningText: BedrockReasoningTextSchema
}).nullish()
}),
),
role: z.string(),
Expand All @@ -508,6 +533,7 @@ const BedrockStreamSchema = z.object({
.union([
z.object({ text: z.string() }),
z.object({ toolUse: z.object({ input: z.string() }) }),
z.object({ reasoningContent: z.object({ text: z.string() }) })
])
.nullish(),
})
Expand Down Expand Up @@ -551,3 +577,4 @@ const BedrockStreamSchema = z.object({
throttlingException: z.record(z.unknown()).nullish(),
validationException: z.record(z.unknown()).nullish(),
});

0 comments on commit 476bfe4

Please sign in to comment.