Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Commit

Permalink
Merge pull request #68 from sublimator/nd-add-signal-support-and-fix-…
Browse files Browse the repository at this point in the history
…timeout-2024-05-01

feat: add signal support and fix timeout, closes #5
  • Loading branch information
lerela authored May 6, 2024
2 parents 06dabd3 + 3ce76a0 commit 7b55eaa
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 33 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/build_publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Build and Publish
on:
push:
branches: ["main"]

# We only deploy on tags and main branch
tags:
# Only run on tags that match the following regex
Expand All @@ -14,13 +14,13 @@ on:
pull_request:

jobs:

lint_and_test:
runs-on: ubuntu-latest

strategy:
matrix:
node-version: [18, 20]
node-version: [18, 20, 22]

steps:
# Checkout the repository
Expand Down Expand Up @@ -72,4 +72,3 @@ jobs:
sed -i 's/VERSION = '\''0.0.1'\''/VERSION = '\''${{ github.ref_name }}'\''/g' src/client.js
npm publish
107 changes: 78 additions & 29 deletions src/client.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,29 @@ class MistralAPIError extends Error {
}
};

/**
 * Merges several abort signals into one.
 *
 * @param {Array<AbortSignal|undefined>} signals to merge
 * @return {AbortSignal} signal which will abort when any of signals abort
 */
function combineSignals(signals) {
  const controller = new AbortController();

  for (const signal of signals) {
    // Callers may pass undefined entries (e.g. no user-supplied signal).
    if (!signal) {
      continue;
    }

    const propagateAbort = () => controller.abort(signal.reason);
    signal.addEventListener('abort', propagateAbort, {once: true});

    // If the signal was already aborted before we subscribed,
    // the 'abort' event will never fire, so propagate immediately.
    if (signal.aborted) {
      propagateAbort();
    }
  }

  return controller.signal;
}

/**
* MistralClient
* @return {MistralClient}
Expand Down Expand Up @@ -69,9 +92,10 @@ class MistralClient {
* @param {*} method
* @param {*} path
* @param {*} request
* @param {*} signal
* @return {Promise<*>}
*/
_request = async function(method, path, request) {
_request = async function(method, path, request, signal) {
const url = `${this.endpoint}/${path}`;
const options = {
method: method,
Expand All @@ -82,7 +106,8 @@ class MistralClient {
'Authorization': `Bearer ${this.apiKey}`,
},
body: method !== 'get' ? JSON.stringify(request) : null,
timeout: this.timeout * 1000,
signal: combineSignals(
[AbortSignal.timeout(this.timeout * 1000), signal]),
};

for (let attempts = 0; attempts < this.maxRetries; attempts++) {
Expand Down Expand Up @@ -208,19 +233,30 @@ class MistralClient {
};

/**
* A chat endpoint without streaming
* @param {*} model the name of the model to chat with, e.g. mistral-tiny
* @param {*} messages an array of messages to chat with, e.g.
* [{role: 'user', content: 'What is the best French cheese?'}]
* @param {*} tools a list of tools to use.
* @param {*} temperature the temperature to use for sampling, e.g. 0.5
* @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
* @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
* @param {*} randomSeed the random seed to use for sampling, e.g. 42
* @param {*} safeMode deprecated use safePrompt instead
* @param {*} safePrompt whether to use safe mode, e.g. true
* @param {*} toolChoice the tool to use, e.g. 'auto'
* @param {*} responseFormat the format of the response, e.g. 'json_format'
* A chat endpoint without streaming.
*
* @param {Object} data - The main chat configuration.
* @param {*} data.model - the name of the model to chat with,
* e.g. mistral-tiny
* @param {*} data.messages - an array of messages to chat with, e.g.
* [{role: 'user', content: 'What is the best
* French cheese?'}]
* @param {*} data.tools - a list of tools to use.
* @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5
* @param {*} data.maxTokens - the maximum number of tokens to generate,
* e.g. 100
* @param {*} data.topP - the cumulative probability of tokens to generate,
* e.g. 0.9
* @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42
* @param {*} data.safeMode - deprecated use safePrompt instead
* @param {*} data.safePrompt - whether to use safe mode, e.g. true
* @param {*} data.toolChoice - the tool to use, e.g. 'auto'
* @param {*} data.responseFormat - the format of the response,
* e.g. 'json_format'
* @param {Object} options - Additional operational options.
* @param {*} [options.signal] - optional AbortSignal instance to control the
*                               request. The signal will be combined with the
*                               default timeout signal.
* @return {Promise<Object>}
*/
chat = async function({
Expand All @@ -235,7 +271,7 @@ class MistralClient {
safePrompt,
toolChoice,
responseFormat,
}) {
}, {signal} = {}) {
const request = this._makeChatCompletionRequest(
model,
messages,
Expand All @@ -254,24 +290,36 @@ class MistralClient {
'post',
'v1/chat/completions',
request,
signal,
);
return response;
};

/**
* A chat endpoint that streams responses.
* @param {*} model the name of the model to chat with, e.g. mistral-tiny
* @param {*} messages an array of messages to chat with, e.g.
* [{role: 'user', content: 'What is the best French cheese?'}]
* @param {*} tools a list of tools to use.
* @param {*} temperature the temperature to use for sampling, e.g. 0.5
* @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
* @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
* @param {*} randomSeed the random seed to use for sampling, e.g. 42
* @param {*} safeMode deprecated use safePrompt instead
* @param {*} safePrompt whether to use safe mode, e.g. true
* @param {*} toolChoice the tool to use, e.g. 'auto'
* @param {*} responseFormat the format of the response, e.g. 'json_format'
*
* @param {Object} data - The main chat configuration.
* @param {*} data.model - the name of the model to chat with,
* e.g. mistral-tiny
* @param {*} data.messages - an array of messages to chat with, e.g.
* [{role: 'user', content: 'What is the best
* French cheese?'}]
* @param {*} data.tools - a list of tools to use.
* @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5
* @param {*} data.maxTokens - the maximum number of tokens to generate,
* e.g. 100
* @param {*} data.topP - the cumulative probability of tokens to generate,
* e.g. 0.9
* @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42
* @param {*} data.safeMode - deprecated use safePrompt instead
* @param {*} data.safePrompt - whether to use safe mode, e.g. true
* @param {*} data.toolChoice - the tool to use, e.g. 'auto'
* @param {*} data.responseFormat - the format of the response,
* e.g. 'json_format'
* @param {Object} options - Additional operational options.
* @param {*} [options.signal] - optional AbortSignal instance to control the
*                               request. The signal will be combined with the
*                               default timeout signal.
* @return {Promise<Object>}
*/
chatStream = async function* ({
Expand All @@ -286,7 +334,7 @@ class MistralClient {
safePrompt,
toolChoice,
responseFormat,
}) {
}, {signal} = {}) {
const request = this._makeChatCompletionRequest(
model,
messages,
Expand All @@ -305,6 +353,7 @@ class MistralClient {
'post',
'v1/chat/completions',
request,
signal,
);

let buffer = '';
Expand Down

0 comments on commit 7b55eaa

Please sign in to comment.