Skip to content

Commit

Permalink
feat: support langchain serialization
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D committed May 28, 2024
1 parent b89924a commit b7ed959
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 3 deletions.
35 changes: 33 additions & 2 deletions src/langchain/llm-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { ChatGenerationChunk, ChatResult } from '@langchain/core/outputs';
import { BaseLanguageModelCallOptions as BaseChatModelCallOptions } from '@langchain/core/language_models/base';
import merge from 'lodash/merge.js';
import { load } from '@langchain/core/load';
import type { Serialized } from '@langchain/core/load/serializable';

import { Client, Configuration } from '../client.js';
import { TextChatCreateInput, TextChatCreateStreamInput } from '../schema.js';
Expand Down Expand Up @@ -218,11 +220,40 @@ export class GenAIChatModel extends BaseChatModel<GenAIChatModelOptions> {
});
}

_llmType(): string {
return 'GenAIChat';
// Opts this chat model into LangChain's serialization machinery.
lc_serializable = true;
// Module path LangChain uses to locate this class when deserializing;
// must correspond to the key supplied in fromJSON's optionalImportsMap.
lc_namespace = ['@ibm-generative-ai/node-sdk', 'langchain', 'llm-chat'];

/** Fully-qualified serialization id: the namespace segments plus the class name. */
get lc_id(): string[] {
  const id = [...this.lc_namespace];
  id.push('GenAIChatModel');
  return id;
}

// Constructor keyword arguments recorded for serialization.
// NOTE(review): every value is hard-coded to undefined, which appears to
// shadow the kwargs LangChain's Serializable base normally captures from the
// constructor — a serialize/deserialize round-trip may drop these settings.
// TODO confirm against the emitted toJSON() output.
lc_kwargs = {
modelId: undefined,
promptId: undefined,
conversationId: undefined,
parameters: undefined,
moderations: undefined,
useConversationParameters: undefined,
parentId: undefined,
trimMethod: undefined,
};

/**
 * Reconstructs a GenAIChatModel from its serialized representation.
 *
 * @param value Either a JSON string produced by `toJSON()`/serialization,
 *   or the already-parsed `Serialized` object.
 * @returns The deserialized model instance.
 */
static async fromJSON(value: string | Serialized) {
  const input = typeof value === 'string' ? value : JSON.stringify(value);
  return await load(input, {
    optionalImportsMap: {
      '@ibm-generative-ai/node-sdk/langchain/llm-chat': {
        // The key must match the final segment of lc_id ('GenAIChatModel');
        // the original key 'GenAIModel' would make `load` fail to resolve
        // this class during deserialization.
        GenAIChatModel: GenAIChatModel,
      },
    },
  });
}

/** Reports the configured model identifier as the model type. */
_modelType(): string {
  const { modelId } = this;
  return modelId;
}

/** Type label LangChain uses for callbacks and tracing. */
_llmType(): string {
  const llmType = 'GenAIChatModel';
  return llmType;
}
}
33 changes: 32 additions & 1 deletion src/langchain/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import type { LLMResult } from '@langchain/core/outputs';
import { GenerationChunk } from '@langchain/core/outputs';
import merge from 'lodash/merge.js';
import { load } from '@langchain/core/load';
import type { Serialized } from '@langchain/core/load/serializable';

import { Client, Configuration } from '../client.js';
import { concatUnique, isNullish } from '../helpers/common.js';
Expand Down Expand Up @@ -183,11 +185,40 @@ export class GenAIModel extends BaseLLM<GenAIModelOptions> {
return result.results.at(0)?.token_count ?? 0;
}

/**
 * Rebuilds a GenAIModel instance from its serialized form.
 *
 * @param value A JSON string, or an already-parsed `Serialized` object.
 * @returns The deserialized model instance.
 */
static async fromJSON(value: string | Serialized) {
  const serializedText =
    typeof value === 'string' ? value : JSON.stringify(value);
  return await load(serializedText, {
    optionalImportsMap: {
      // Module path and export name must line up with lc_namespace / lc_id.
      '@ibm-generative-ai/node-sdk/langchain/llm': { GenAIModel },
    },
  });
}

/** Reports the configured model identifier as the model type. */
_modelType(): string {
  const { modelId } = this;
  return modelId;
}

/** Type label LangChain uses for callbacks and tracing. */
_llmType(): string {
  // The stale `return 'GenAI';` line (a leftover removed-diff line) made the
  // intended return unreachable; keep only the current identifier.
  return 'GenAIModel';
}

// Opts this model into LangChain's serialization machinery.
lc_serializable = true;
// Module path LangChain uses to locate this class when deserializing;
// must correspond to the key supplied in fromJSON's optionalImportsMap.
lc_namespace = ['@ibm-generative-ai/node-sdk', 'langchain', 'llm'];

/** Fully-qualified serialization id: the namespace segments plus the class name. */
get lc_id(): string[] {
  const id = [...this.lc_namespace];
  id.push('GenAIModel');
  return id;
}

// Constructor keyword arguments recorded for serialization.
// NOTE(review): every value is hard-coded to undefined, which appears to
// shadow the kwargs LangChain's Serializable base normally captures from the
// constructor — a serialize/deserialize round-trip may drop these settings.
// TODO confirm against the emitted toJSON() output.
lc_kwargs = {
modelId: undefined,
promptId: undefined,
conversationId: undefined,
parameters: undefined,
moderations: undefined,
useConversationParameters: undefined,
parentId: undefined,
trimMethod: undefined,
};
}
7 changes: 7 additions & 0 deletions tests/e2e/langchain/llm-chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,11 @@ describe('LangChain Chat', () => {
expect(tokens).toStrictEqual(contents);
});
});

it('Serializes', async () => {
  // Round-trip: serialize a fresh model, restore it, and check the type.
  const original = makeModel();
  const restored = await GenAIChatModel.fromJSON(original.toJSON());
  expect(restored).toBeInstanceOf(GenAIChatModel);
});
});
7 changes: 7 additions & 0 deletions tests/e2e/langchain/llm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ describe('Langchain', () => {
});
});

it('Serializes', async () => {
  // Round-trip serialization should yield a GenAIModel instance again.
  const model = makeClient('google/flan-ul2');
  const restored = await GenAIModel.fromJSON(model.toJSON());
  expect(restored).toBeInstanceOf(GenAIModel);
});

describe('generate', () => {
// TODO: enable once we will set default model for the test account
test.skip('should handle empty modelId', async () => {
Expand Down

0 comments on commit b7ed959

Please sign in to comment.