-
Notifications
You must be signed in to change notification settings - Fork 2
/
mistral.ts
109 lines (89 loc) · 3.12 KB
/
mistral.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'node:stream/web'
import MistralClient, { ChatCompletionResponseChunk } from '@platformatic/mistral-client'
import { AiProvider, ChatHistory, NoContentError, StreamChunkCallback } from './provider.js'
import { AiStreamEvent, encodeEvent } from './event.js'
type MistralStreamResponse = AsyncGenerator<ChatCompletionResponseChunk, void, unknown>

/**
 * Adapts a Mistral streaming chat response (an async generator of
 * completion chunks) into an `UnderlyingByteSource`, so it can back a
 * `ReadableStream` of server-sent-event-encoded bytes.
 *
 * Each pulled chunk is re-emitted as an encoded `content` event; a chunk
 * with no choices produces an `error` event and ends the stream.
 */
class MistralByteSource implements UnderlyingByteSource {
  type: 'bytes' = 'bytes'
  response: MistralStreamResponse
  // Optional hook that may transform each text chunk before it is emitted.
  chunkCallback?: StreamChunkCallback

  constructor (response: MistralStreamResponse, chunkCallback?: StreamChunkCallback) {
    this.response = response
    this.chunkCallback = chunkCallback
  }

  /**
   * Pulls the next chunk from the Mistral generator and enqueues it on the
   * stream controller as an encoded event.
   */
  async pull (controller: ReadableByteStreamController): Promise<void> {
    const { done, value } = await this.response.next()

    // Generator exhausted — also treat a missing value defensively as end
    // of stream instead of crashing on `value.choices` below.
    if (done === true || value === undefined) {
      controller.close()
      return
    }

    if (value.choices.length === 0) {
      // No choices in the chunk: surface an error event to the consumer
      // and terminate the stream.
      const error = new NoContentError('Mistral (Stream)')

      const eventData: AiStreamEvent = {
        event: 'error',
        data: error
      }
      controller.enqueue(encodeEvent(eventData))
      controller.close()
      return
    }

    const { content } = value.choices[0].delta
    let response = content ?? ''
    if (this.chunkCallback !== undefined) {
      response = await this.chunkCallback(response)
    }

    const eventData: AiStreamEvent = {
      event: 'content',
      data: {
        response
      }
    }
    controller.enqueue(encodeEvent(eventData))
  }

  /**
   * Fix: invoked when the consumer cancels the ReadableStream (e.g. the
   * client disconnects). Previously the underlying Mistral generator was
   * never closed, leaking the in-flight streaming request; `return()`
   * finishes the generator and releases its resources.
   */
  async cancel (): Promise<void> {
    await this.response.return()
  }
}
/** Constructor options for {@link MistralProvider}. */
interface MistralProviderCtorOptions {
  /** Mistral model identifier, forwarded as `model` on every chat request. */
  model: string
  /** API key used to authenticate the underlying `MistralClient`. */
  apiKey: string
}
/**
 * `AiProvider` implementation backed by the Mistral chat API.
 *
 * Supports both one-shot completions (`ask`) and streamed completions
 * (`askStream`), optionally replaying prior prompt/response pairs as
 * chat history before the new prompt.
 */
export class MistralProvider implements AiProvider {
  model: string
  client: MistralClient

  constructor ({ model, apiKey }: MistralProviderCtorOptions) {
    this.model = model
    this.client = new MistralClient(apiKey)
  }

  /**
   * Sends the prompt (preceded by any chat history) and resolves with the
   * first completion's text.
   *
   * @throws NoContentError when the API returns no choices.
   */
  async ask (prompt: string, chatHistory?: ChatHistory): Promise<string> {
    const messages = [
      ...this.chatHistoryToMessages(chatHistory),
      { role: 'user', content: prompt }
    ]

    const response = await this.client.chat({
      model: this.model,
      messages
    })

    const [firstChoice] = response.choices
    if (firstChoice === undefined) {
      throw new NoContentError('Mistral')
    }
    return firstChoice.message.content
  }

  /**
   * Sends the prompt as a streaming request and returns a ReadableStream
   * of encoded events; `chunkCallback` may transform each text chunk.
   */
  async askStream (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory): Promise<ReadableStream> {
    const messages = [
      ...this.chatHistoryToMessages(chatHistory),
      { role: 'user', content: prompt }
    ]

    const stream = this.client.chatStream({
      model: this.model,
      messages
    })
    return new ReadableStream(new MistralByteSource(stream, chunkCallback))
  }

  // Expands each prior interaction into a user/assistant message pair.
  private chatHistoryToMessages (chatHistory?: ChatHistory): Array<{ role: string, content: string }> {
    const messages: Array<{ role: string, content: string }> = []
    if (chatHistory === undefined) {
      return messages
    }
    for (const { prompt, response } of chatHistory) {
      messages.push(
        { role: 'user', content: prompt },
        { role: 'assistant', content: response }
      )
    }
    return messages
  }
}