-
Notifications
You must be signed in to change notification settings - Fork 2
/
azure.ts
138 lines (111 loc) · 3.95 KB
/
azure.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import { ReadableStream, ReadableByteStreamController, UnderlyingByteSource } from 'stream/web'
import { AiProvider, ChatHistory, NoContentError, StreamChunkCallback } from './provider.js'
import { AiStreamEvent, encodeEvent } from './event.js'
import { AzureKeyCredential, ChatCompletions, ChatRequestMessageUnion, EventStream, OpenAIClient } from '@azure/openai'
// Stream of chat-completion chunks, as obtained from OpenAIClient.streamChatCompletions
// in this file and consumed by AzureByteSource.
type AzureStreamResponse = EventStream<ChatCompletions>
/**
 * Underlying byte source that turns an Azure OpenAI chat-completion stream
 * into encoded `AiStreamEvent` chunks for a `ReadableStream`.
 *
 * Every upstream chunk with content becomes one `content` event. A chunk
 * without choices, without a delta, or with `null` content produces a single
 * `error` event carrying a `NoContentError`, after which the stream is closed.
 */
class AzureByteSource implements UnderlyingByteSource {
  type: 'bytes' = 'bytes'
  response: AzureStreamResponse
  reader?: ReadableStreamDefaultReader<ChatCompletions>
  chunkCallback?: StreamChunkCallback

  /**
   * @param response Stream of chat-completion chunks to read from.
   * @param chunkCallback Optional hook that receives each content delta; the
   * value it resolves to is what gets emitted instead of the raw delta.
   */
  constructor (response: AzureStreamResponse, chunkCallback?: StreamChunkCallback) {
    this.response = response
    this.chunkCallback = chunkCallback
  }

  /** Acquires the reader on the upstream response. */
  start (): void {
    this.reader = this.response.getReader()
  }

  /**
   * Reads one upstream chunk and enqueues the matching encoded event.
   *
   * @param controller Controller of the stream this source feeds.
   */
  async pull (controller: ReadableByteStreamController): Promise<void> {
    // start() defines this.reader and is called before this
    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
    const { done, value } = await this.reader!.read()
    if (done !== undefined && done) {
      controller.close()
      return
    }

    // NOTE(review): the service may send chunks with an empty `choices` list
    // (e.g. content-filter-only chunks) — confirm that ending the stream with
    // an error is the intended reaction to those.
    if (value.choices.length === 0) {
      this.closeWithNoContent(controller)
      return
    }

    const { delta } = value.choices[0]
    if (delta === undefined || delta.content === null) {
      this.closeWithNoContent(controller)
      return
    }

    let response = delta.content
    if (this.chunkCallback !== undefined) {
      response = await this.chunkCallback(response)
    }

    const eventData: AiStreamEvent = {
      event: 'content',
      data: {
        response
      }
    }
    controller.enqueue(encodeEvent(eventData))
  }

  /**
   * Propagates a consumer-side cancellation to the upstream reader so the
   * Azure response is not left open once nobody reads the stream any more.
   *
   * @param reason Cancellation reason forwarded to the upstream reader.
   */
  async cancel (reason?: unknown): Promise<void> {
    // Nothing to release if the reader was never acquired.
    if (this.reader !== undefined) {
      await this.reader.cancel(reason)
    }
  }

  /** Emits a single `error` event carrying a NoContentError, then closes the stream. */
  private closeWithNoContent (controller: ReadableByteStreamController): void {
    const error = new NoContentError('Azure OpenAI')
    const eventData: AiStreamEvent = {
      event: 'error',
      data: error
    }
    controller.enqueue(encodeEvent(eventData))
    controller.close()
  }
}
/** Options accepted by the AzureProvider constructor. */
interface AzureProviderCtorOptions {
// Endpoint passed as the first argument to OpenAIClient.
endpoint: string
// Key wrapped in an AzureKeyCredential for the client.
apiKey: string
// Deployment name passed to every chat-completion request.
deploymentName: string
// Forwarded to the client as its `allowInsecureConnection` option.
allowInsecureConnections?: boolean
}
/**
 * AiProvider implementation that answers prompts through an Azure OpenAI
 * chat-completion deployment.
 */
export class AzureProvider implements AiProvider {
  deploymentName: string
  client: OpenAIClient

  /**
   * @param options Endpoint, API key and deployment name of the Azure OpenAI
   * resource, plus the optional insecure-connection switch for the client.
   */
  constructor (options: AzureProviderCtorOptions) {
    this.deploymentName = options.deploymentName
    const credential = new AzureKeyCredential(options.apiKey)
    this.client = new OpenAIClient(options.endpoint, credential, {
      allowInsecureConnection: options.allowInsecureConnections
    })
  }

  /**
   * Sends the prompt (preceded by the chat history, if any) and resolves to
   * the content of the first returned choice.
   *
   * @param prompt User prompt to send.
   * @param chatHistory Optional earlier prompt/response pairs.
   * @returns Content of the first choice's message.
   * @throws NoContentError when there is no choice, no message, or `null` content.
   */
  async ask (prompt: string, chatHistory?: ChatHistory): Promise<string> {
    const completion = await this.client.getChatCompletions(
      this.deploymentName,
      this.buildMessages(prompt, chatHistory)
    )

    const reply = completion.choices.length === 0
      ? undefined
      : completion.choices[0].message
    if (reply === undefined || reply.content === null) {
      throw new NoContentError('Azure OpenAI')
    }
    return reply.content
  }

  /**
   * Starts a streamed chat completion and exposes it as a ReadableStream
   * backed by an AzureByteSource.
   *
   * @param prompt User prompt to send.
   * @param chunkCallback Optional hook handed to the byte source.
   * @param chatHistory Optional earlier prompt/response pairs.
   */
  async askStream (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory): Promise<ReadableStream> {
    const completionStream = await this.client.streamChatCompletions(
      this.deploymentName,
      this.buildMessages(prompt, chatHistory)
    )
    const byteSource = new AzureByteSource(completionStream, chunkCallback)
    return new ReadableStream(byteSource)
  }

  /** Builds the request messages: converted history first, then the new user prompt. */
  private buildMessages (prompt: string, chatHistory?: ChatHistory): ChatRequestMessageUnion[] {
    return [
      ...this.chatHistoryToMessages(chatHistory),
      { role: 'user', content: prompt }
    ]
  }

  /**
   * Expands each history entry into a `user` message followed by an
   * `assistant` message; a missing history yields an empty list.
   */
  private chatHistoryToMessages (chatHistory?: ChatHistory): ChatRequestMessageUnion[] {
    const converted: ChatRequestMessageUnion[] = []
    if (chatHistory !== undefined) {
      for (const { prompt, response } of chatHistory) {
        converted.push(
          { role: 'user', content: prompt },
          { role: 'assistant', content: response }
        )
      }
    }
    return converted
  }
}