chore(langchain): updates #63

Merged · 2 commits · Nov 10, 2023
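This PR tracks LangChain's message API rename: the deprecated `HumanChatMessage` and `SystemChatMessage` classes become `HumanMessage` and `SystemMessage`, and message text is read from `.content` instead of `.text`. It also adds a standalone text-generation example, `examples/langchain/llm.ts`, covering single-prompt, multi-prompt, and streaming usage.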
7 changes: 4 additions & 3 deletions README.md
@@ -243,6 +243,7 @@ await model.call('Tell me a joke.', undefined, [

```typescript
import { GenAIChatModel } from '@ibm-generative-ai/node-sdk/langchain';
+import { SystemMessage, HumanMessage } from 'langchain/schema';

const client = new GenAIChatModel({
  modelId: 'eleutherai/gpt-neox-20b',
@@ -268,13 +269,13 @@
});

const response = await client.call([
-  new SystemChatMessage(
+  new SystemMessage(
    'You are a helpful assistant that translates English to Spanish.',
  ),
-  new HumanChatMessage('I love programming.'),
+  new HumanMessage('I love programming.'),
]);

-console.info(response.text); // "Me encanta la programación."
+console.info(response.content); // "Me encanta la programación."
```

#### Prompt Templates (GenAI x LangChain)
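Readers migrating their own code should note the same two renames shown here: the message classes, and the switch from `response.text` to `response.content`.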
6 changes: 3 additions & 3 deletions examples/langchain/llm-chat.ts
@@ -1,4 +1,4 @@
-import { HumanChatMessage } from 'langchain/schema';
+import { HumanMessage } from 'langchain/schema';

import { GenAIChatModel } from '../../src/langchain/llm-chat.js';

@@ -31,7 +31,7 @@ const makeClient = (stream?: boolean) =>
const chat = makeClient();

const response = await chat.call([
-  new HumanChatMessage(
+  new HumanMessage(
    'What is a good name for a company that makes colorful socks?',
  ),
]);
@@ -43,7 +43,7 @@ const makeClient = (stream?: boolean) =>
// Streaming
const chat = makeClient(true);

-await chat.call([new HumanChatMessage('Tell me a joke.')], undefined, [
+await chat.call([new HumanMessage('Tell me a joke.')], undefined, [
  {
    handleLLMNewToken(token) {
      console.log(token);
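In the streaming variant above, the third argument to `chat.call` is LangChain's callbacks array; `handleLLMNewToken` fires once per generated token as it arrives, so output can be printed before the full response resolves.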
49 changes: 49 additions & 0 deletions examples/langchain/llm.ts
@@ -0,0 +1,49 @@
import { GenAIModel } from '../../src/langchain/index.js';

// Factory for a GenAI LangChain model; `stream` toggles token streaming
const makeClient = (stream?: boolean) =>
  new GenAIModel({
    modelId: 'google/flan-t5-xl',
    stream,
    configuration: {
      endpoint: process.env.ENDPOINT,
      apiKey: process.env.API_KEY,
    },
    parameters: {
      decoding_method: 'greedy',
      min_new_tokens: 5,
      max_new_tokens: 25,
      repetition_penalty: 1.5,
    },
  });

{
  // Basic single-input call
  console.info('---Single Input Example---');
  const model = makeClient();

  const prompt = 'What is a good name for a company that makes colorful socks?';
  console.info(`Request: ${prompt}`);
  const response = await model.call(prompt);
  console.log(`Response: ${response}`);
}

{
  // Multiple prompts in a single generate() call
  console.info('---Multiple Inputs Example---');
  const model = makeClient();

  const prompts = ['What is IBM?', 'What is WatsonX?'];
  console.info('Request prompts:', prompts);
  const response = await model.generate(prompts);
  console.info('Response:', response);
}

{
  // Token-by-token streaming
  console.info('---Streaming Example---');
  const chat = makeClient(true);

  const prompt = 'What is a molecule?';
  console.info(`Request: ${prompt}`);
  for await (const token of await chat.stream(prompt)) {
    console.info(`Received token: ${token}`);
  }
}
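The new example reads `ENDPOINT` and `API_KEY` from the environment, so both must be set before it can reach the Gen AI service.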
2 changes: 1 addition & 1 deletion src/langchain/llm-chat.ts
@@ -74,7 +74,7 @@ export class GenAIChatModel extends BaseChatModel {
            `Unsupported message type "${msg._getType()}"`,
          );
        }
-        return `${type.stopSequence}${msg.text}`;
+        return `${type.stopSequence}${msg.content}`;
      })
      .join('\n')
      .concat(this.#rolesMapping.system.stopSequence);
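For context, the changed line lives in the message-to-prompt serialization of `GenAIChatModel`. A minimal sketch of that logic, assuming illustrative stop sequences (the actual `#rolesMapping` values come from the model configuration and are not shown in this diff):

```typescript
import { BaseMessage, HumanMessage, SystemMessage } from 'langchain/schema';

// Illustrative stand-in for the private #rolesMapping; the real values
// are model-specific and configured elsewhere in the SDK.
const rolesMapping: Record<string, { stopSequence: string }> = {
  system: { stopSequence: '<|system|>' },
  human: { stopSequence: '<|user|>' },
};

const toPrompt = (messages: BaseMessage[]): string =>
  messages
    .map((msg) => {
      const type = rolesMapping[msg._getType()];
      if (!type) {
        throw new Error(`Unsupported message type "${msg._getType()}"`);
      }
      // `msg.content` replaces the removed `msg.text` accessor
      return `${type.stopSequence}${msg.content}`;
    })
    .join('\n')
    .concat(rolesMapping.system.stopSequence);

console.log(
  toPrompt([
    new SystemMessage('You are a helpful assistant.'),
    new HumanMessage('Tell me a joke.'),
  ]),
);
// <|system|>You are a helpful assistant.
// <|user|>Tell me a joke.<|system|>
```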
26 changes: 13 additions & 13 deletions src/tests/e2e/langchain/llm-chat.test.ts
@@ -1,4 +1,4 @@
-import { HumanChatMessage, SystemChatMessage } from 'langchain/schema';
+import { HumanMessage, SystemMessage } from 'langchain/schema';

import { GenAIChatModel } from '../../../langchain/index.js';
import { describeIf } from '../../utils.js';
@@ -47,34 +47,34 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')(
    const chat = makeClient();

    const response = await chat.call([
-      new HumanChatMessage(
+      new HumanMessage(
        'What is a good name for a company that makes colorful socks?',
      ),
    ]);
-    expectIsNonEmptyString(response.text);
+    expectIsNonEmptyString(response.content);
  });

  test('should handle question with additional hint', async () => {
    const chat = makeClient();

    const response = await chat.call([
-      new SystemChatMessage(SYSTEM_MESSAGE),
-      new HumanChatMessage('I love programming.'),
+      new SystemMessage(SYSTEM_MESSAGE),
+      new HumanMessage('I love programming.'),
    ]);
-    expectIsNonEmptyString(response.text);
+    expectIsNonEmptyString(response.content);
  });

  test('should handle multiple questions', async () => {
    const chat = makeClient();

    const response = await chat.generate([
      [
-        new SystemChatMessage(SYSTEM_MESSAGE),
-        new HumanChatMessage('I love programming.'),
+        new SystemMessage(SYSTEM_MESSAGE),
+        new HumanMessage('I love programming.'),
      ],
      [
-        new SystemChatMessage(SYSTEM_MESSAGE),
-        new HumanChatMessage('I love artificial intelligence.'),
+        new SystemMessage(SYSTEM_MESSAGE),
+        new HumanMessage('I love artificial intelligence.'),
      ],
    ]);

@@ -95,7 +95,7 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')(
    });

    const output = await chat.call(
-      [new HumanChatMessage('Tell me a joke.')],
+      [new HumanMessage('Tell me a joke.')],
      undefined,
      [
        {
@@ -105,8 +105,8 @@ describeIf(process.env.RUN_LANGCHAIN_CHAT_TESTS === 'true')(
    );

    expect(handleNewToken).toHaveBeenCalled();
-    expectIsNonEmptyString(output.text);
-    expect(tokens.join('')).toStrictEqual(output.text);
+    expectIsNonEmptyString(output.content);
+    expect(tokens.join('')).toStrictEqual(output.content);
  });
});
},
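Note the `describeIf` guard at the top of this suite: the e2e chat tests run only when `RUN_LANGCHAIN_CHAT_TESTS` is set to `'true'`, so environments without Gen AI credentials skip them.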