From 54a90448d86d7d281f98d8e9d522d2a5b037b59a Mon Sep 17 00:00:00 2001 From: Zhongpin Wang Date: Mon, 3 Feb 2025 16:23:38 +0100 Subject: [PATCH] refactor: Content filter helper function (#441) --- .changeset/afraid-cooks-shave.md | 6 + packages/orchestration/README.md | 64 +-- packages/orchestration/src/index.ts | 7 +- packages/orchestration/src/internal.ts | 2 +- .../src/orchestration-client.test.ts | 26 +- .../orchestration/src/orchestration-client.ts | 2 +- ...-completion-post-request-from-json.test.ts | 2 +- ...hestration-completion-post-request.test.ts | 22 +- .../orchestration/src/orchestration-types.ts | 50 +++ .../src/orchestration-utils.test.ts | 378 ------------------ .../orchestration/src/util/filtering.test.ts | 214 ++++++++++ packages/orchestration/src/util/filtering.ts | 60 +++ .../orchestration/src/util/grounding.test.ts | 57 +++ packages/orchestration/src/util/grounding.ts | 25 ++ packages/orchestration/src/util/index.ts | 3 + .../src/util/module-config.test.ts | 155 +++++++ .../module-config.ts} | 52 +-- packages/orchestration/tsconfig.json | 2 +- sample-code/README.md | 4 +- sample-code/src/orchestration.ts | 25 +- tests/type-tests/test/orchestration.test-d.ts | 24 +- 21 files changed, 700 insertions(+), 480 deletions(-) create mode 100644 .changeset/afraid-cooks-shave.md delete mode 100644 packages/orchestration/src/orchestration-utils.test.ts create mode 100644 packages/orchestration/src/util/filtering.test.ts create mode 100644 packages/orchestration/src/util/filtering.ts create mode 100644 packages/orchestration/src/util/grounding.test.ts create mode 100644 packages/orchestration/src/util/grounding.ts create mode 100644 packages/orchestration/src/util/index.ts create mode 100644 packages/orchestration/src/util/module-config.test.ts rename packages/orchestration/src/{orchestration-utils.ts => util/module-config.ts} (74%) diff --git a/.changeset/afraid-cooks-shave.md b/.changeset/afraid-cooks-shave.md new file mode 100644 index 000000000..fa3a2b281 --- /dev/null +++ b/.changeset/afraid-cooks-shave.md @@ -0,0 +1,6 @@ +--- +'@sap-ai-sdk/orchestration': minor +--- + +[Compatibility Note] Deprecate `buildAzureContentFilter()` function. +Use `buildAzureContentSafetyFilter()` function instead. \ No newline at end of file diff --git a/packages/orchestration/README.md b/packages/orchestration/README.md index 1c2d36f11..946bb2c15 100644 --- a/packages/orchestration/README.md +++ b/packages/orchestration/README.md @@ -265,24 +265,42 @@ Use the orchestration client with filtering to restrict content that is passed t This feature allows filtering both the [input](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/input-filtering) and [output](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/output-filtering) of a model based on content safety criteria. -```ts -import { - OrchestrationClient, - buildAzureContentFilter -} from '@sap-ai-sdk/orchestration'; +#### Azure Content Filter + +Use `buildAzureContentSafetyFilter()` function to build an Azure content filter for both input and output. +Each category of the filter can be assigned a specific severity level, which corresponds to an Azure threshold value. + +| Severity Level | Azure Threshold Value | +| ----------------------- | --------------------- | +| `ALLOW_SAFE` | 0 | +| `ALLOW_SAFE_LOW` | 2 | +| `ALLOW_SAFE_LOW_MEDIUM` | 4 | +| `ALLOW_ALL` | 6 | -const filter = buildAzureContentFilter({ Hate: 2, Violence: 4 }); +```ts +import { OrchestrationClient, ContentFilters } from '@sap-ai-sdk/orchestration'; +const llm = { + model_name: 'gpt-4o', + model_params: { max_tokens: 50, temperature: 0.1 } +}; +const templating = { + template: [{ role: 'user', content: '{{?input}}' }] +}; + +const filter = buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE_LOW', + Violence: 'ALLOW_SAFE_LOW_MEDIUM' +}); const orchestrationClient = new OrchestrationClient({ - llm: { - model_name: 'gpt-4o', - model_params: { max_tokens: 50, temperature: 0.1 } - }, - templating: { - template: [{ role: 'user', content: '{{?input}}' }] - }, + llm, + templating, filtering: { - input: filter, - output: filter + input: { + filters: [filter] + }, + output: { + filters: [filter] + } } }); @@ -296,23 +314,19 @@ try { } ``` +#### Error Handling + Both `chatCompletion()` and `getContent()` methods can throw errors. -- **axios errors**: +- **Axios Errors**: When the chat completion request fails with a `400` status code, the caught error will be an `Axios` error. - The property `error.response.data.message` may provide additional details about the failure's cause. + The property `error.response.data.message` provides additional details about the failure. -- **output content filtered**: - The method `getContent()` can throw an error if the output filter filters the model output. +- **Output Content Filtered**: + The `getContent()` method can throw an error if the output filter filters the model output. This can occur even if the chat completion request responds with a `200` HTTP status code. The `error.message` property indicates if the output was filtered. -Therefore, handle errors appropriately to ensure meaningful feedback for both types of errors. - -`buildAzureContentFilter()` is a convenience function that creates an Azure content filter configuration based on the provided inputs. -The Azure content filter supports four categories: `Hate`, `Violence`, `Sexual`, and `SelfHarm`. -Each category can be configured with severity levels of 0, 2, 4, or 6. - ### Data Masking You can anonymize or pseudonomize the prompt using the data masking capabilities of the orchestration service. diff --git a/packages/orchestration/src/index.ts b/packages/orchestration/src/index.ts index ef3e251b5..98466d1c7 100644 --- a/packages/orchestration/src/index.ts +++ b/packages/orchestration/src/index.ts @@ -8,7 +8,9 @@ export type { StreamOptions, DocumentGroundingServiceConfig, DocumentGroundingServiceFilter, - LlmModelParams + LlmModelParams, + AzureContentFilter, + AzureFilterThreshold } from './orchestration-types.js'; export { OrchestrationStreamResponse } from './orchestration-stream-response.js'; @@ -21,8 +23,9 @@ export { OrchestrationClient } from './orchestration-client.js'; export { buildAzureContentFilter, + buildAzureContentSafetyFilter, buildDocumentGroundingConfig -} from './orchestration-utils.js'; +} from './util/index.js'; export { OrchestrationResponse } from './orchestration-response.js'; diff --git a/packages/orchestration/src/internal.ts b/packages/orchestration/src/internal.ts index 037444bef..f2366e6cc 100644 --- a/packages/orchestration/src/internal.ts +++ b/packages/orchestration/src/internal.ts @@ -1,4 +1,4 @@ export * from './orchestration-client.js'; -export * from './orchestration-utils.js'; +export * from './util/index.js'; export * from './orchestration-types.js'; export * from './orchestration-response.js'; diff --git a/packages/orchestration/src/orchestration-client.test.ts b/packages/orchestration/src/orchestration-client.test.ts index 425739912..fdce271dc 100644 --- a/packages/orchestration/src/orchestration-client.test.ts +++ b/packages/orchestration/src/orchestration-client.test.ts @@ -9,12 +9,12 @@ import { parseMockResponse } from '../../../test-util/mock-http.js'; import { OrchestrationClient } from './orchestration-client.js'; +import { OrchestrationResponse } from './orchestration-response.js'; import { - buildAzureContentFilter, + constructCompletionPostRequestFromJsonModuleConfig, constructCompletionPostRequest, - constructCompletionPostRequestFromJsonModuleConfig -} from './orchestration-utils.js'; -import { OrchestrationResponse } from './orchestration-response.js'; + buildAzureContentSafetyFilter +} from './util/index.js'; import type { CompletionPostResponse } from './client/api/schema/index.js'; import type { OrchestrationModuleConfig, @@ -162,8 +162,22 @@ describe('orchestration service client', () => { ] }, filtering: { - input: buildAzureContentFilter({ Hate: 4, SelfHarm: 2 }), - output: buildAzureContentFilter({ Sexual: 0, Violence: 4 }) + input: { + filters: [ + buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE_LOW_MEDIUM', + SelfHarm: 'ALLOW_SAFE_LOW' + }) + ] + }, + output: { + filters: [ + buildAzureContentSafetyFilter({ + Sexual: 'ALLOW_SAFE', + Violence: 'ALLOW_SAFE_LOW_MEDIUM' + }) + ] + } } }; const prompt = { diff --git a/packages/orchestration/src/orchestration-client.ts b/packages/orchestration/src/orchestration-client.ts index eba03b335..ea51c4c39 100644 --- a/packages/orchestration/src/orchestration-client.ts +++ b/packages/orchestration/src/orchestration-client.ts @@ -7,7 +7,7 @@ import { OrchestrationResponse } from './orchestration-response.js'; import { constructCompletionPostRequest, constructCompletionPostRequestFromJsonModuleConfig -} from './orchestration-utils.js'; +} from './util/index.js'; import type { HttpResponse, CustomRequestConfig diff --git a/packages/orchestration/src/orchestration-completion-post-request-from-json.test.ts b/packages/orchestration/src/orchestration-completion-post-request-from-json.test.ts index d58e7985c..afd91c531 100644 --- a/packages/orchestration/src/orchestration-completion-post-request-from-json.test.ts +++ b/packages/orchestration/src/orchestration-completion-post-request-from-json.test.ts @@ -1,4 +1,4 @@ -import { constructCompletionPostRequestFromJsonModuleConfig } from './orchestration-utils.js'; +import { constructCompletionPostRequestFromJsonModuleConfig } from './util/module-config.js'; describe('construct completion post request from JSON', () => { it('should construct completion post request from JSON', () => { diff --git a/packages/orchestration/src/orchestration-completion-post-request.test.ts b/packages/orchestration/src/orchestration-completion-post-request.test.ts index d3d8bf523..39cf5a287 100644 --- a/packages/orchestration/src/orchestration-completion-post-request.test.ts +++ b/packages/orchestration/src/orchestration-completion-post-request.test.ts @@ -1,7 +1,7 @@ import { constructCompletionPostRequest, - buildAzureContentFilter -} from './orchestration-utils.js'; + buildAzureContentSafetyFilter +} from './util/index.js'; import type { CompletionPostRequest } from './client/api/schema/index.js'; import type { OrchestrationModuleConfig, @@ -169,7 +169,14 @@ describe('construct completion post request', () => { const config: OrchestrationModuleConfig = { ...defaultConfig, filtering: { - input: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) + input: { + filters: [ + buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE_LOW_MEDIUM', + SelfHarm: 'ALLOW_SAFE' + }) + ] + } } }; const expectedCompletionPostRequest: CompletionPostRequest = { @@ -209,7 +216,14 @@ describe('construct completion post request', () => { const config: OrchestrationModuleConfig = { ...defaultConfig, filtering: { - output: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) + output: { + filters: [ + buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE_LOW_MEDIUM', + SelfHarm: 'ALLOW_SAFE' + }) + ] + } } }; diff --git a/packages/orchestration/src/orchestration-types.ts b/packages/orchestration/src/orchestration-types.ts index 74508211b..e9e5e315b 100644 --- a/packages/orchestration/src/orchestration-types.ts +++ b/packages/orchestration/src/orchestration-types.ts @@ -63,6 +63,17 @@ export interface OrchestrationModuleConfig { llm: LlmModuleConfig; /** * Filtering module configuration. + * Construct filter configuration for both input and output filters using convenience functions. + * @example + * ```ts + * filtering: { + * input: { + * filters: [ + * buildAzureContentSafetyFilter({ Hate: 'ALLOW_SAFE', Violence: 'ALLOW_SAFE_LOW_MEDIUM' }) + * ] + * } + * } + * ``` */ filtering?: FilteringModuleConfig; /** @@ -151,3 +162,42 @@ export interface DocumentGroundingServiceConfig { */ output_param: string; } + +/** + * Filter configuration for Azure content safety Filter. + */ +export interface AzureContentFilter { + /** + * The filter category for hate content. + */ + Hate?: AzureFilterThreshold; + /** + * The filter category for self-harm content. + */ + SelfHarm?: AzureFilterThreshold; + /** + * The filter category for sexual content. + */ + Sexual?: AzureFilterThreshold; + /** + * The filter category for violence content. + */ + Violence?: AzureFilterThreshold; +} + +/** + * A descriptive constant for Azure content safety filter threshold. + * @internal + */ +export const supportedAzureFilterThresholds = { + ALLOW_SAFE: 0, + ALLOW_SAFE_LOW: 2, + ALLOW_SAFE_LOW_MEDIUM: 4, + ALLOW_ALL: 6 +} as const; + +/** + * The Azure threshold level supported for each azure content filter category. + * + */ +export type AzureFilterThreshold = keyof typeof supportedAzureFilterThresholds; diff --git a/packages/orchestration/src/orchestration-utils.test.ts b/packages/orchestration/src/orchestration-utils.test.ts deleted file mode 100644 index 353cda280..000000000 --- a/packages/orchestration/src/orchestration-utils.test.ts +++ /dev/null @@ -1,378 +0,0 @@ -import { createLogger } from '@sap-cloud-sdk/util'; -import { jest } from '@jest/globals'; -import { - addStreamOptions, - addStreamOptionsToLlmModuleConfig, - addStreamOptionsToOutputFilteringConfig, - buildAzureContentFilter, - buildDocumentGroundingConfig, - constructCompletionPostRequest -} from './orchestration-utils.js'; -import type { - CompletionPostRequest, - FilteringModuleConfig, - ModuleConfigs, - OrchestrationConfig -} from './client/api/schema/index.js'; -import type { - OrchestrationModuleConfig, - DocumentGroundingServiceConfig, - StreamOptions -} from './orchestration-types.js'; - -describe('orchestration utils', () => { - describe('stream util tests', () => { - const defaultOrchestrationModuleConfig: OrchestrationModuleConfig = { - llm: { - model_name: 'gpt-35-turbo-16k', - model_params: { max_tokens: 50, temperature: 0.1 } - }, - templating: { - template: [ - { role: 'user', content: 'Create paraphrases of {{?phrase}}' } - ] - } - }; - - const defaultModuleConfigs: ModuleConfigs = { - llm_module_config: defaultOrchestrationModuleConfig.llm, - templating_module_config: defaultOrchestrationModuleConfig.templating - }; - - const defaultStreamOptions: StreamOptions = { - global: { chunk_size: 100 }, - llm: { include_usage: false }, - outputFiltering: { overlap: 100 } - }; - - it('should add include_usage to llm module config', () => { - const llmConfig = addStreamOptionsToLlmModuleConfig( - defaultOrchestrationModuleConfig.llm - ); - expect(llmConfig.model_params?.stream_options).toEqual({ - include_usage: true - }); - }); - - it('should set include_usage to false in llm module config', () => { - const llmConfig = addStreamOptionsToLlmModuleConfig( - defaultOrchestrationModuleConfig.llm, - defaultStreamOptions - ); - expect(llmConfig.model_params?.stream_options).toEqual({ - include_usage: false - }); - }); - - it('should not add any stream options to llm module config', () => { - const llmConfig = addStreamOptionsToLlmModuleConfig( - defaultOrchestrationModuleConfig.llm, - { - llm: null - } - ); - expect( - Object.keys(llmConfig.model_params ?? {}).every( - key => key !== 'stream_options' - ) - ).toBe(true); - }); - - it('should add stream options to output filtering config', () => { - const config: OrchestrationModuleConfig = { - ...defaultOrchestrationModuleConfig, - filtering: { - output: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) - } - }; - const filteringConfig = addStreamOptionsToOutputFilteringConfig( - config.filtering!.output!, - defaultStreamOptions.outputFiltering! - ); - expect(filteringConfig.filters).toEqual( - config.filtering?.output?.filters - ); - expect(filteringConfig.stream_options).toEqual({ - overlap: 100 - }); - }); - - it('should add stream options to orchestration config', () => { - const config: ModuleConfigs = { - ...defaultModuleConfigs, - filtering_module_config: { - output: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) - } - }; - - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { llm, ...streamOptions } = defaultStreamOptions; - - const expectedOrchestrationConfig: OrchestrationConfig = { - stream: true, - stream_options: streamOptions.global, - module_configurations: { - ...config, - llm_module_config: { - ...config.llm_module_config, - model_params: { - ...config.llm_module_config.model_params, - stream_options: { include_usage: true } - } - }, - filtering_module_config: { - output: { - ...config.filtering_module_config!.output!, - stream_options: streamOptions.outputFiltering - } - } - } - }; - const orchestrationConfig = addStreamOptions(config, streamOptions); - expect(orchestrationConfig).toEqual(expectedOrchestrationConfig); - }); - - it('should warn if no filter config was set, but streaming options were set', () => { - const logger = createLogger({ - package: 'orchestration', - messageContext: 'orchestration-utils' - }); - - const warnSpy = jest.spyOn(logger, 'warn'); - - const config = addStreamOptions( - defaultModuleConfigs, - defaultStreamOptions - ); - - expect(warnSpy).toHaveBeenCalledWith( - 'Output filter stream options are not applied because filtering module is not configured.' - ); - expect( - config.module_configurations.filtering_module_config - ).toBeUndefined(); - }); - }); - - describe('azure filter', () => { - const config: OrchestrationModuleConfig = { - llm: { - model_name: 'gpt-35-turbo-16k', - model_params: { max_tokens: 50, temperature: 0.1 } - }, - templating: { - template: [ - { role: 'user', content: 'Create {number} paraphrases of {phrase}' } - ] - } - }; - const prompt = { inputParams: { phrase: 'I hate you.', number: '3' } }; - - afterEach(() => { - config.filtering = undefined; - }); - - it('constructs filter configuration with only input', async () => { - const filtering: FilteringModuleConfig = { - input: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) - }; - const expectedFilterConfig: FilteringModuleConfig = { - input: { - filters: [ - { - type: 'azure_content_safety', - config: { - Hate: 4, - SelfHarm: 0 - } - } - ] - } - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); - - it('constructs filter configuration with only output', async () => { - const filtering: FilteringModuleConfig = { - output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) - }; - const expectedFilterConfig: FilteringModuleConfig = { - output: { - filters: [ - { - type: 'azure_content_safety', - config: { - Sexual: 2, - Violence: 6 - } - } - ] - } - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); - - it('constructs filter configuration with both input and output', async () => { - const filtering: FilteringModuleConfig = { - input: buildAzureContentFilter({ - Hate: 4, - SelfHarm: 0, - Sexual: 2, - Violence: 6 - }), - output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) - }; - const expectedFilterConfig: FilteringModuleConfig = { - input: { - filters: [ - { - type: 'azure_content_safety', - config: { - Hate: 4, - SelfHarm: 0, - Sexual: 2, - Violence: 6 - } - } - ] - }, - output: { - filters: [ - { - type: 'azure_content_safety', - config: { - Sexual: 2, - Violence: 6 - } - } - ] - } - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); - - it('omits filters if not set', async () => { - const filtering: FilteringModuleConfig = { - input: buildAzureContentFilter(), - output: buildAzureContentFilter() - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - const expectedFilterConfig: FilteringModuleConfig = { - input: { - filters: [ - { - type: 'azure_content_safety' - } - ] - }, - output: { - filters: [ - { - type: 'azure_content_safety' - } - ] - } - }; - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); - - it('omits filter configuration if not set', async () => { - const filtering: FilteringModuleConfig = {}; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toBeUndefined(); - }); - - it('throw error when configuring empty filter', async () => { - const createFilterConfig = () => { - { - buildAzureContentFilter({}); - } - }; - expect(createFilterConfig).toThrow( - 'Filter property cannot be an empty object' - ); - }); - }); - describe('document grounding', () => { - it('builds grounding configuration with minimal required properties', () => { - const groundingConfig: DocumentGroundingServiceConfig = { - filters: [ - { - id: 'filter-id' - } - ], - input_params: ['input'], - output_param: 'output' - }; - expect(buildDocumentGroundingConfig(groundingConfig)).toEqual({ - type: 'document_grounding_service', - config: { - filters: [ - { - id: 'filter-id', - data_repository_type: 'vector' - } - ], - input_params: ['input'], - output_param: 'output' - } - }); - }); - - it('overrides default data repository type', () => { - const groundingConfig: DocumentGroundingServiceConfig = { - filters: [ - { - id: 'filter-id', - data_repositories: ['repo1', 'repo2'], - data_repository_type: 'custom-type' - } - ], - input_params: ['input'], - output_param: 'output' - }; - expect(buildDocumentGroundingConfig(groundingConfig)).toEqual({ - type: 'document_grounding_service', - config: { - filters: [ - { - id: 'filter-id', - data_repositories: ['repo1', 'repo2'], - data_repository_type: 'custom-type' - } - ], - input_params: ['input'], - output_param: 'output' - } - }); - }); - }); -}); diff --git a/packages/orchestration/src/util/filtering.test.ts b/packages/orchestration/src/util/filtering.test.ts new file mode 100644 index 000000000..77fcd4a9a --- /dev/null +++ b/packages/orchestration/src/util/filtering.test.ts @@ -0,0 +1,214 @@ +import { + buildAzureContentFilter, + buildAzureContentSafetyFilter +} from './filtering.js'; +import { constructCompletionPostRequest } from './module-config.js'; +import type { OrchestrationModuleConfig } from '../orchestration-types.js'; +import type { + CompletionPostRequest, + FilterConfig, + FilteringModuleConfig +} from '../client/api/schema/index.js'; + +describe('Content filter util', () => { + // TODO: Remove this test collection once `buildAzureContentFilter` is removed. + describe('buildAzureContentFilter', () => { + const config: OrchestrationModuleConfig = { + llm: { + model_name: 'gpt-35-turbo-16k', + model_params: { max_tokens: 50, temperature: 0.1 } + }, + templating: { + template: [ + { role: 'user', content: 'Create {number} paraphrases of {phrase}' } + ] + } + }; + + const prompt = { inputParams: { phrase: 'I hate you.', number: '3' } }; + + afterEach(() => { + config.filtering = undefined; + }); + + it('constructs filter configuration with only input', async () => { + const filtering: FilteringModuleConfig = { + input: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) + }; + const expectedFilterConfig: FilteringModuleConfig = { + input: { + filters: [ + { + type: 'azure_content_safety', + config: { + Hate: 4, + SelfHarm: 0 + } + } + ] + } + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); + + it('constructs filter configuration with only output', async () => { + const filtering: FilteringModuleConfig = { + output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) + }; + const expectedFilterConfig: FilteringModuleConfig = { + output: { + filters: [ + { + type: 'azure_content_safety', + config: { + Sexual: 2, + Violence: 6 + } + } + ] + } + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); + + it('constructs filter configuration with both input and output', async () => { + const filtering: FilteringModuleConfig = { + input: buildAzureContentFilter({ + Hate: 4, + SelfHarm: 0, + Sexual: 2, + Violence: 6 + }), + output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) + }; + const expectedFilterConfig: FilteringModuleConfig = { + input: { + filters: [ + { + type: 'azure_content_safety', + config: { + Hate: 4, + SelfHarm: 0, + Sexual: 2, + Violence: 6 + } + } + ] + }, + output: { + filters: [ + { + type: 'azure_content_safety', + config: { + Sexual: 2, + Violence: 6 + } + } + ] + } + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); + + it('omits filters if not set', async () => { + const filtering: FilteringModuleConfig = { + input: buildAzureContentFilter(), + output: buildAzureContentFilter() + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + const expectedFilterConfig: FilteringModuleConfig = { + input: { + filters: [ + { + type: 'azure_content_safety' + } + ] + }, + output: { + filters: [ + { + type: 'azure_content_safety' + } + ] + } + }; + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); + + it('omits filter configuration if not set', async () => { + const filtering: FilteringModuleConfig = {}; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toBeUndefined(); + }); + + it('throw error when configuring empty filter', async () => { + const createFilterConfig = () => { + { + buildAzureContentFilter({}); + } + }; + expect(createFilterConfig).toThrow( + 'Filter property cannot be an empty object' + ); + }); + }); + + describe('Azure content filter', () => { + it('builds filter config', async () => { + const filterConfig = buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE_LOW_MEDIUM', + SelfHarm: 'ALLOW_SAFE' + }); + const expectedFilterConfig: FilterConfig = { + type: 'azure_content_safety', + config: { + Hate: 4, + SelfHarm: 0 + } + }; + expect(filterConfig).toEqual(expectedFilterConfig); + }); + + it('builds filter config with no config', async () => { + const filterConfig = buildAzureContentSafetyFilter(); + const expectedFilterConfig: FilterConfig = { + type: 'azure_content_safety' + }; + expect(filterConfig).toEqual(expectedFilterConfig); + }); + + it('throw error when configuring empty filter', async () => { + expect(() => buildAzureContentSafetyFilter({})).toThrow( + 'Filtering configuration cannot be an empty object' + ); + }); + }); +}); diff --git a/packages/orchestration/src/util/filtering.ts b/packages/orchestration/src/util/filtering.ts new file mode 100644 index 000000000..7c3c120ac --- /dev/null +++ b/packages/orchestration/src/util/filtering.ts @@ -0,0 +1,60 @@ +import { supportedAzureFilterThresholds } from '../orchestration-types.js'; +import type { + AzureContentSafety, + AzureContentSafetyFilterConfig, + InputFilteringConfig, + OutputFilteringConfig +} from '../client/api/schema/index.js'; +import type { + AzureContentFilter, + AzureFilterThreshold +} from '../orchestration-types.js'; + +/** + * Convenience function to create Azure content filters. + * @param filter - Filtering configuration for Azure filter. If skipped, the default Azure content filter configuration is used. + * @returns An object with the Azure filtering configuration. + * @deprecated Since 1.8.0. Use {@link buildAzureContentSafetyFilter()} instead. + */ +export function buildAzureContentFilter( + filter?: AzureContentSafety +): InputFilteringConfig | OutputFilteringConfig { + if (filter && !Object.keys(filter).length) { + throw new Error('Filter property cannot be an empty object'); + } + return { + filters: [ + { + type: 'azure_content_safety', + ...(filter && { config: filter }) + } + ] + }; +} + +/** + * Convenience function to create Azure content filters. + * @param config - Configuration for Azure content safety filter. + * If skipped, the default configuration of `ALLOW_SAFE_LOW` is used for all filter categories. + * @returns Filter config object. + */ +export function buildAzureContentSafetyFilter( + config?: AzureContentFilter +): AzureContentSafetyFilterConfig { + if (config && !Object.keys(config).length) { + throw new Error('Filtering configuration cannot be an empty object'); + } + return { + type: 'azure_content_safety', + ...(config && { + config: { + ...Object.fromEntries( + Object.entries(config).map(([key, value]) => [ + key, + supportedAzureFilterThresholds[value as AzureFilterThreshold] + ]) + ) + } + }) + }; +} diff --git a/packages/orchestration/src/util/grounding.test.ts b/packages/orchestration/src/util/grounding.test.ts new file mode 100644 index 000000000..25b8cc05e --- /dev/null +++ b/packages/orchestration/src/util/grounding.test.ts @@ -0,0 +1,57 @@ +import { buildDocumentGroundingConfig } from './grounding.js'; +import type { DocumentGroundingServiceConfig } from '../orchestration-types.js'; + +describe('document grounding util', () => { + it('builds grounding configuration with minimal required properties', () => { + const groundingConfig: DocumentGroundingServiceConfig = { + filters: [ + { + id: 'filter-id' + } + ], + input_params: ['input'], + output_param: 'output' + }; + expect(buildDocumentGroundingConfig(groundingConfig)).toEqual({ + type: 'document_grounding_service', + config: { + filters: [ + { + id: 'filter-id', + data_repository_type: 'vector' + } + ], + input_params: ['input'], + output_param: 'output' + } + }); + }); + + it('overrides default data repository type', () => { + const groundingConfig: DocumentGroundingServiceConfig = { + filters: [ + { + id: 'filter-id', + data_repositories: ['repo1', 'repo2'], + data_repository_type: 'custom-type' + } + ], + input_params: ['input'], + output_param: 'output' + }; + expect(buildDocumentGroundingConfig(groundingConfig)).toEqual({ + type: 'document_grounding_service', + config: { + filters: [ + { + id: 'filter-id', + data_repositories: ['repo1', 'repo2'], + data_repository_type: 'custom-type' + } + ], + input_params: ['input'], + output_param: 'output' + } + }); + }); +}); diff --git a/packages/orchestration/src/util/grounding.ts b/packages/orchestration/src/util/grounding.ts new file mode 100644 index 000000000..1ee76b41b --- /dev/null +++ b/packages/orchestration/src/util/grounding.ts @@ -0,0 +1,25 @@ +import type { GroundingModuleConfig } from '../client/api/schema/index.js'; +import type { DocumentGroundingServiceConfig } from '../orchestration-types.js'; + +/** + * Convenience function to create Document Grounding configuration. + * @param groundingConfig - Configuration for the document grounding service. + * @returns An object with the full grounding configuration. + */ +export function buildDocumentGroundingConfig( + groundingConfig: DocumentGroundingServiceConfig +): GroundingModuleConfig { + return { + type: 'document_grounding_service', + config: { + input_params: groundingConfig.input_params, + output_param: groundingConfig.output_param, + ...(groundingConfig.filters && { + filters: groundingConfig.filters?.map(filter => ({ + data_repository_type: 'vector', + ...filter + })) + }) + } + }; +} diff --git a/packages/orchestration/src/util/index.ts b/packages/orchestration/src/util/index.ts new file mode 100644 index 000000000..2ee5f942e --- /dev/null +++ b/packages/orchestration/src/util/index.ts @@ -0,0 +1,3 @@ +export * from './filtering.js'; +export * from './grounding.js'; +export * from './module-config.js'; diff --git a/packages/orchestration/src/util/module-config.test.ts b/packages/orchestration/src/util/module-config.test.ts new file mode 100644 index 000000000..b7b7e62e0 --- /dev/null +++ b/packages/orchestration/src/util/module-config.test.ts @@ -0,0 +1,155 @@ +import { createLogger } from '@sap-cloud-sdk/util'; +import { jest } from '@jest/globals'; +import { + addStreamOptions, + addStreamOptionsToLlmModuleConfig, + addStreamOptionsToOutputFilteringConfig +} from './module-config.js'; +import { buildAzureContentSafetyFilter } from './filtering.js'; +import type { + ModuleConfigs, + OrchestrationConfig +} from '../client/api/schema/index.js'; +import type { + OrchestrationModuleConfig, + StreamOptions +} from '../orchestration-types.js'; +describe('stream util tests', () => { + const defaultOrchestrationModuleConfig: OrchestrationModuleConfig = { + llm: { + model_name: 'gpt-35-turbo-16k', + model_params: { max_tokens: 50, temperature: 0.1 } + }, + templating: { + template: [{ role: 'user', content: 'Create paraphrases of {{?phrase}}' }] + } + }; + + const defaultModuleConfigs: ModuleConfigs = { + llm_module_config: defaultOrchestrationModuleConfig.llm, + templating_module_config: defaultOrchestrationModuleConfig.templating + }; + + const defaultStreamOptions: StreamOptions = { + global: { chunk_size: 100 }, + llm: { include_usage: false }, + outputFiltering: { overlap: 100 } + }; + + it('should add include_usage to llm module config', () => { + const llmConfig = addStreamOptionsToLlmModuleConfig( + defaultOrchestrationModuleConfig.llm + ); + expect(llmConfig.model_params?.stream_options).toEqual({ + include_usage: true + }); + }); + + it('should set include_usage to false in llm module config', () => { + const llmConfig = addStreamOptionsToLlmModuleConfig( + defaultOrchestrationModuleConfig.llm, + defaultStreamOptions + ); + expect(llmConfig.model_params?.stream_options).toEqual({ + include_usage: false + }); + }); + + it('should not add any stream options to llm module config', () => { + const llmConfig = addStreamOptionsToLlmModuleConfig( + defaultOrchestrationModuleConfig.llm, + { + llm: null + } + ); + expect( + Object.keys(llmConfig.model_params ?? {}).every( + key => key !== 'stream_options' + ) + ).toBe(true); + }); + + it('should add stream options to output filtering config', () => { + const config: OrchestrationModuleConfig = { + ...defaultOrchestrationModuleConfig, + filtering: { + output: { + filters: [ + buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE_LOW_MEDIUM', + SelfHarm: 'ALLOW_SAFE' + }) + ] + } + } + }; + const filteringConfig = addStreamOptionsToOutputFilteringConfig( + config.filtering!.output!, + defaultStreamOptions.outputFiltering! + ); + expect(filteringConfig.filters).toEqual(config.filtering?.output?.filters); + expect(filteringConfig.stream_options).toEqual({ + overlap: 100 + }); + }); + + it('should add stream options to orchestration config', () => { + const config: ModuleConfigs = { + ...defaultModuleConfigs, + filtering_module_config: { + output: { + filters: [ + buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE_LOW_MEDIUM', + SelfHarm: 'ALLOW_SAFE' + }) + ] + } + } + }; + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { llm, ...streamOptions } = defaultStreamOptions; + + const expectedOrchestrationConfig: OrchestrationConfig = { + stream: true, + stream_options: streamOptions.global, + module_configurations: { + ...config, + llm_module_config: { + ...config.llm_module_config, + model_params: { + ...config.llm_module_config.model_params, + stream_options: { include_usage: true } + } + }, + filtering_module_config: { + output: { + ...config.filtering_module_config!.output!, + stream_options: streamOptions.outputFiltering + } + } + } + }; + const orchestrationConfig = addStreamOptions(config, streamOptions); + expect(orchestrationConfig).toEqual(expectedOrchestrationConfig); + }); + + it('should warn if no filter config was set, but streaming options were set', () => { + const logger = createLogger({ + package: 'orchestration', + messageContext: 'orchestration-utils' + }); + + const warnSpy = jest.spyOn(logger, 'warn'); + + const config = addStreamOptions(defaultModuleConfigs, defaultStreamOptions); + + expect(warnSpy).toHaveBeenCalledWith( + 'Output filter stream options are not applied because filtering module is not configured.' + ); + expect( + config.module_configurations.filtering_module_config + ).toBeUndefined(); + }); +}); diff --git a/packages/orchestration/src/orchestration-utils.ts b/packages/orchestration/src/util/module-config.ts similarity index 74% rename from packages/orchestration/src/orchestration-utils.ts rename to packages/orchestration/src/util/module-config.ts index 5dbbf06ca..4cac255a1 100644 --- a/packages/orchestration/src/orchestration-utils.ts +++ b/packages/orchestration/src/util/module-config.ts @@ -1,22 +1,18 @@ import { createLogger } from '@sap-cloud-sdk/util'; import type { - DocumentGroundingServiceConfig, Prompt, StreamOptions, LlmModuleConfig, OrchestrationModuleConfig -} from './orchestration-types.js'; +} from '../orchestration-types.js'; import type { - AzureContentSafety, - GroundingModuleConfig, - InputFilteringConfig, CompletionPostRequest, FilteringStreamOptions, ModuleConfigs, OrchestrationConfig, OutputFilteringConfig, GlobalStreamOptions -} from './client/api/schema/index.js'; +} from '../client/api/schema/index.js'; const logger = createLogger({ package: 'orchestration', @@ -182,47 +178,3 @@ function mergeStreamOptions( }) }; } - -/** - * Convenience function to create Azure content filters. - * @param filter - Filtering configuration for Azure filter. If skipped, the default Azure content filter configuration is used. - * @returns An object with the Azure filtering configuration. - */ -export function buildAzureContentFilter( - filter?: AzureContentSafety -): InputFilteringConfig | OutputFilteringConfig { - if (filter && !Object.keys(filter).length) { - throw new Error('Filter property cannot be an empty object'); - } - return { - filters: [ - { - type: 'azure_content_safety', - ...(filter && { config: filter }) - } - ] - }; -} - -/** - * Convenience function to create Document Grounding configuration. - * @param groundingConfig - Configuration for the document grounding service. - * @returns An object with the full grounding configuration. - */ -export function buildDocumentGroundingConfig( - groundingConfig: DocumentGroundingServiceConfig -): GroundingModuleConfig { - return { - type: 'document_grounding_service', - config: { - input_params: groundingConfig.input_params, - output_param: groundingConfig.output_param, - ...(groundingConfig.filters && { - filters: groundingConfig.filters?.map(filter => ({ - data_repository_type: 'vector', - ...filter - })) - }) - } - }; -} diff --git a/packages/orchestration/tsconfig.json b/packages/orchestration/tsconfig.json index 78a9a5e29..2caf7a3cc 100644 --- a/packages/orchestration/tsconfig.json +++ b/packages/orchestration/tsconfig.json @@ -6,7 +6,7 @@ "tsBuildInfoFile": "./dist/.tsbuildinfo", "composite": true }, - "include": ["src/**/*.ts", "src/orchestration-utils.test.ts"], + "include": ["src/**/*.ts"], "exclude": ["dist/**/*", "test/**/*", "**/*.test.ts", "node_modules/**/*"], "references": [{ "path": "../core" }, { "path": "../ai-api" }] } diff --git a/sample-code/README.md b/sample-code/README.md index 763a30add..6188bb66b 100644 --- a/sample-code/README.md +++ b/sample-code/README.md @@ -147,14 +147,14 @@ Get chat completion response with image input. `GET /orchestration/inputFiltering` Get chat completion response with Azure content filter for the input. -Use `buildAzureContentFilter()` to build the content filter. +For example, use `buildAzureContentSafetyFilter()` function to build Azure content filter. #### Output Filtering `GET /orchestration/outputFiltering` Get chat completion response with Azure content filter for the output. -Use `buildAzureContentFilter()` to build the content filter. +For example, use `buildAzureContentSafetyFilter()` function to build Azure content filter. #### Custom Request Config diff --git a/sample-code/src/orchestration.ts b/sample-code/src/orchestration.ts index 7123a6633..e08a83c12 100644 --- a/sample-code/src/orchestration.ts +++ b/sample-code/src/orchestration.ts @@ -3,8 +3,8 @@ import { join, dirname } from 'path'; import { fileURLToPath } from 'url'; import { OrchestrationClient, - buildAzureContentFilter, - buildDocumentGroundingConfig + buildDocumentGroundingConfig, + buildAzureContentSafetyFilter } from '@sap-ai-sdk/orchestration'; import { createLogger } from '@sap-cloud-sdk/util'; import type { @@ -168,14 +168,18 @@ const templating = { template: [{ role: 'user', content: '{{?input}}' }] }; */ export async function orchestrationInputFiltering(): Promise { // create a filter with minimal thresholds for hate and violence - // lower numbers mean more strict filtering - const filter = buildAzureContentFilter({ Hate: 0, Violence: 0 }); + const azureContentFilter = buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE', + Violence: 'ALLOW_SAFE' + }); const orchestrationClient = new OrchestrationClient({ llm, templating, - // configure the filter to be applied for both input and output + // configure the filter to be applied for input filtering: { - input: filter + input: { + filters: [azureContentFilter] + } } }); @@ -201,12 +205,17 @@ export async function orchestrationInputFiltering(): Promise { export async function orchestrationOutputFiltering(): Promise { // output filters are build in the same way as input filters // set the thresholds to the minimum to maximize the chance the LLM output will be filtered - const filter = buildAzureContentFilter({ Hate: 0, Violence: 0 }); + const azureContentFilter = buildAzureContentSafetyFilter({ + Hate: 'ALLOW_SAFE', + Violence: 'ALLOW_SAFE' + }); const orchestrationClient = new OrchestrationClient({ llm, templating, filtering: { - output: filter + output: { + filters: [azureContentFilter] + } } }); /** diff --git a/tests/type-tests/test/orchestration.test-d.ts b/tests/type-tests/test/orchestration.test-d.ts index d3accedf7..f5422f112 100644 --- a/tests/type-tests/test/orchestration.test-d.ts +++ b/tests/type-tests/test/orchestration.test-d.ts @@ -1,6 +1,7 @@ import { expectError, expectType, expectAssignable } from 'tsd'; import { OrchestrationClient, + buildAzureContentSafetyFilter, buildDocumentGroundingConfig } from '@sap-ai-sdk/orchestration'; import type { @@ -9,7 +10,8 @@ import type { TokenUsage, ChatModel, GroundingModuleConfig, - LlmModelParams + LlmModelParams, + AzureContentSafetyFilterConfig } from '@sap-ai-sdk/orchestration'; /** @@ -244,6 +246,26 @@ expectType>( expect('custom-model'); expect('gemini-1.0-pro'); +/** + * Filtering Util. + */ + +expectType( + buildAzureContentSafetyFilter({ + Hate: 'ALLOW_ALL', + SelfHarm: 'ALLOW_SAFE_LOW', + Sexual: 'ALLOW_SAFE_LOW_MEDIUM', + Violence: 'ALLOW_SAFE' + }) +); + +expectError( + buildAzureContentSafetyFilter({ + Hate: 2, + SelfHarm: 4 + }) +); + /** * Grounding util. */