Skip to content

Commit

Permalink
Add Azure OpenAI provider (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
flakey5 authored Apr 16, 2024
1 parent 5114fcf commit f065ebf
Show file tree
Hide file tree
Showing 18 changed files with 734 additions and 263 deletions.
148 changes: 148 additions & 0 deletions ai-providers/azure.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import { ReadableStream, ReadableByteStreamController, UnderlyingByteSource } from 'stream/web'
import { AiProvider, NoContentError, StreamChunkCallback } from './provider'
import { AiStreamEvent, encodeEvent } from './event'

// @ts-expect-error
type AzureEventStream<T> = import('@azure/openai').EventStream<T>
// @ts-expect-error
type AzureChatCompletions = import('@azure/openai').ChatCompletions

type AzureStreamResponse = AzureEventStream<AzureChatCompletions>

class AzureByteSource implements UnderlyingByteSource {
type: 'bytes' = 'bytes'
response: AzureStreamResponse
reader?: ReadableStreamDefaultReader<AzureChatCompletions>
chunkCallback?: StreamChunkCallback

constructor (response: AzureStreamResponse, chunkCallback?: StreamChunkCallback) {
this.response = response
this.chunkCallback = chunkCallback
}

start (): void {
this.reader = this.response.getReader()
}

async pull (controller: ReadableByteStreamController): Promise<void> {
// start() defines this.reader and is called before this
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const { done, value } = await this.reader!.read()

if (done !== undefined && done) {
controller.close()
return
}

if (value.choices.length === 0) {
const error = new NoContentError('Azure OpenAI')

const eventData: AiStreamEvent = {
event: 'error',
data: error
}
controller.enqueue(encodeEvent(eventData))
controller.close()

return
}

const { delta } = value.choices[0]
if (delta === undefined || delta.content === null) {
const error = new NoContentError('Azure OpenAI')

const eventData: AiStreamEvent = {
event: 'error',
data: error
}
controller.enqueue(encodeEvent(eventData))
controller.close()

return
}

let response = delta.content
if (this.chunkCallback !== undefined) {
response = await this.chunkCallback(response)
}

const eventData: AiStreamEvent = {
event: 'content',
data: {
response
}
}
controller.enqueue(encodeEvent(eventData))
}
}

interface AzureProviderCtorOptions {
endpoint: string
apiKey: string
deploymentName: string
allowInsecureConnections?: boolean
}

export class AzureProvider implements AiProvider {
endpoint: string
deploymentName: string
apiKey: string
// @ts-expect-error typescript doesn't like this type import even though
// it's fine in the Mistral client?
client?: import('@azure/openai').OpenAIClient = undefined
allowInsecureConnections: boolean

constructor ({ endpoint, apiKey, deploymentName, allowInsecureConnections }: AzureProviderCtorOptions) {
this.endpoint = endpoint
this.apiKey = apiKey
this.deploymentName = deploymentName
this.allowInsecureConnections = allowInsecureConnections ?? false
}

async ask (prompt: string): Promise<string> {
if (this.client === undefined) {
const { OpenAIClient, AzureKeyCredential } = await import('@azure/openai')
this.client = new OpenAIClient(
this.endpoint,
new AzureKeyCredential(this.apiKey),
{
allowInsecureConnection: this.allowInsecureConnections
}
)
}

const { choices } = await this.client.getChatCompletions(this.deploymentName, [
{ role: 'user', content: prompt }
])

if (choices.length === 0) {
throw new NoContentError('Azure OpenAI')
}

const { message } = choices[0]
if (message === undefined || message.content === null) {
throw new NoContentError('Azure OpenAI')
}

return message.content
}

async askStream (prompt: string, chunkCallback?: StreamChunkCallback | undefined): Promise<ReadableStream> {
if (this.client === undefined) {
const { OpenAIClient, AzureKeyCredential } = await import('@azure/openai')
this.client = new OpenAIClient(
this.endpoint,
new AzureKeyCredential(this.apiKey),
{
allowInsecureConnection: this.allowInsecureConnections
}
)
}

const response = await this.client.streamChatCompletions(this.deploymentName, [
{ role: 'user', content: prompt }
])

return new ReadableStream(new AzureByteSource(response, chunkCallback))
}
}
8 changes: 8 additions & 0 deletions config.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ export interface AiWarpConfig {
host: string;
model: string;
};
}
| {
azure: {
endpoint: string;
apiKey: string;
deploymentName: string;
allowInsecureConnections?: boolean;
};
};
promptDecorators?: {
prefix?: string;
Expand Down
9 changes: 9 additions & 0 deletions lib/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ class AiWarpGenerator extends ServiceGenerator {
}
}
break
case 'azure':
config.aiProvider = {
azure: {
endpoint: 'https://myaccount.openai.azure.com/',
apiKey: `{${this.getEnvVarName('PLT_AZURE_API_KEY')}}`,
deploymentName: this.config.aiModel
}
}
break
default:
config.aiProvider = {
openai: {
Expand Down
20 changes: 20 additions & 0 deletions lib/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ const aiWarpSchema = {
},
required: ['ollama'],
additionalProperties: false
},
{
properties: {
azure: {
type: 'object',
properties: {
endpoint: { type: 'string' },
apiKey: { type: 'string' },
deploymentName: { type: 'string' },
allowInsecureConnections: {
type: 'boolean',
default: false
}
},
required: ['endpoint', 'apiKey', 'deploymentName'],
additionalProperties: false
}
},
required: ['azure'],
additionalProperties: false
}
]
},
Expand Down
156 changes: 156 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"typescript": "^5.3.3"
},
"dependencies": {
"@azure/openai": "^1.0.0-beta.12",
"@fastify/error": "^3.4.1",
"@fastify/rate-limit": "^9.1.0",
"@fastify/type-provider-typebox": "^4.0.0",
Expand Down
Loading

0 comments on commit f065ebf

Please sign in to comment.