Skip to content

Commit

Permalink
refactor: Content filter helper function (#441)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhongpinWang authored Feb 3, 2025
1 parent ec6a0a4 commit 54a9044
Show file tree
Hide file tree
Showing 21 changed files with 700 additions and 480 deletions.
6 changes: 6 additions & 0 deletions .changeset/afraid-cooks-shave.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@sap-ai-sdk/orchestration': minor
---

[Compatibility Note] Deprecate `buildAzureContentFilter()` function.
Use `buildAzureContentSafetyFilter()` function instead.
64 changes: 39 additions & 25 deletions packages/orchestration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,24 +265,42 @@ Use the orchestration client with filtering to restrict content that is passed t

This feature allows filtering both the [input](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/input-filtering) and [output](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/output-filtering) of a model based on content safety criteria.

```ts
import {
OrchestrationClient,
buildAzureContentFilter
} from '@sap-ai-sdk/orchestration';
#### Azure Content Filter

Use `buildAzureContentSafetyFilter()` function to build an Azure content filter for both input and output.
Each category of the filter can be assigned a specific severity level, which corresponds to an Azure threshold value.

| Severity Level | Azure Threshold Value |
| ----------------------- | --------------------- |
| `ALLOW_SAFE` | 0 |
| `ALLOW_SAFE_LOW` | 2 |
| `ALLOW_SAFE_LOW_MEDIUM` | 4 |
| `ALLOW_ALL` | 6 |

const filter = buildAzureContentFilter({ Hate: 2, Violence: 4 });
```ts
import { OrchestrationClient, ContentFilters } from '@sap-ai-sdk/orchestration';
const llm = {
model_name: 'gpt-4o',
model_params: { max_tokens: 50, temperature: 0.1 }
};
const templating = {
template: [{ role: 'user', content: '{{?input}}' }]
};

const filter = buildAzureContentSafetyFilter({
Hate: 'ALLOW_SAFE_LOW',
Violence: 'ALLOW_SAFE_LOW_MEDIUM'
});
const orchestrationClient = new OrchestrationClient({
llm: {
model_name: 'gpt-4o',
model_params: { max_tokens: 50, temperature: 0.1 }
},
templating: {
template: [{ role: 'user', content: '{{?input}}' }]
},
llm,
templating,
filtering: {
input: filter,
output: filter
input: {
filters: [filter]
},
output: {
filters: [filter]
}
}
});

Expand All @@ -296,23 +314,19 @@ try {
}
```

#### Error Handling

Both `chatCompletion()` and `getContent()` methods can throw errors.

- **axios errors**:
- **Axios Errors**:
When the chat completion request fails with a `400` status code, the caught error will be an `Axios` error.
The property `error.response.data.message` may provide additional details about the failure's cause.
The property `error.response.data.message` provides additional details about the failure.

- **output content filtered**:
The method `getContent()` can throw an error if the output filter filters the model output.
- **Output Content Filtered**:
The `getContent()` method can throw an error if the output filter filters the model output.
This can occur even if the chat completion request responds with a `200` HTTP status code.
The `error.message` property indicates if the output was filtered.

Therefore, handle errors appropriately to ensure meaningful feedback for both types of errors.

`buildAzureContentFilter()` is a convenience function that creates an Azure content filter configuration based on the provided inputs.
The Azure content filter supports four categories: `Hate`, `Violence`, `Sexual`, and `SelfHarm`.
Each category can be configured with severity levels of 0, 2, 4, or 6.

### Data Masking

You can anonymize or pseudonomize the prompt using the data masking capabilities of the orchestration service.
Expand Down
7 changes: 5 additions & 2 deletions packages/orchestration/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ export type {
StreamOptions,
DocumentGroundingServiceConfig,
DocumentGroundingServiceFilter,
LlmModelParams
LlmModelParams,
AzureContentFilter,
AzureFilterThreshold
} from './orchestration-types.js';

export { OrchestrationStreamResponse } from './orchestration-stream-response.js';
Expand All @@ -21,8 +23,9 @@ export { OrchestrationClient } from './orchestration-client.js';

export {
buildAzureContentFilter,
buildAzureContentSafetyFilter,
buildDocumentGroundingConfig
} from './orchestration-utils.js';
} from './util/index.js';

export { OrchestrationResponse } from './orchestration-response.js';

Expand Down
2 changes: 1 addition & 1 deletion packages/orchestration/src/internal.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export * from './orchestration-client.js';
export * from './orchestration-utils.js';
export * from './util/index.js';
export * from './orchestration-types.js';
export * from './orchestration-response.js';
26 changes: 20 additions & 6 deletions packages/orchestration/src/orchestration-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import {
parseMockResponse
} from '../../../test-util/mock-http.js';
import { OrchestrationClient } from './orchestration-client.js';
import { OrchestrationResponse } from './orchestration-response.js';
import {
buildAzureContentFilter,
constructCompletionPostRequestFromJsonModuleConfig,
constructCompletionPostRequest,
constructCompletionPostRequestFromJsonModuleConfig
} from './orchestration-utils.js';
import { OrchestrationResponse } from './orchestration-response.js';
buildAzureContentSafetyFilter
} from './util/index.js';
import type { CompletionPostResponse } from './client/api/schema/index.js';
import type {
OrchestrationModuleConfig,
Expand Down Expand Up @@ -162,8 +162,22 @@ describe('orchestration service client', () => {
]
},
filtering: {
input: buildAzureContentFilter({ Hate: 4, SelfHarm: 2 }),
output: buildAzureContentFilter({ Sexual: 0, Violence: 4 })
input: {
filters: [
buildAzureContentSafetyFilter({
Hate: 'ALLOW_SAFE_LOW_MEDIUM',
SelfHarm: 'ALLOW_SAFE_LOW'
})
]
},
output: {
filters: [
buildAzureContentSafetyFilter({
Sexual: 'ALLOW_SAFE',
Violence: 'ALLOW_SAFE_LOW_MEDIUM'
})
]
}
}
};
const prompt = {
Expand Down
2 changes: 1 addition & 1 deletion packages/orchestration/src/orchestration-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { OrchestrationResponse } from './orchestration-response.js';
import {
constructCompletionPostRequest,
constructCompletionPostRequestFromJsonModuleConfig
} from './orchestration-utils.js';
} from './util/index.js';
import type {
HttpResponse,
CustomRequestConfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { constructCompletionPostRequestFromJsonModuleConfig } from './orchestration-utils.js';
import { constructCompletionPostRequestFromJsonModuleConfig } from './util/module-config.js';

describe('construct completion post request from JSON', () => {
it('should construct completion post request from JSON', () => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {
constructCompletionPostRequest,
buildAzureContentFilter
} from './orchestration-utils.js';
buildAzureContentSafetyFilter
} from './util/index.js';
import type { CompletionPostRequest } from './client/api/schema/index.js';
import type {
OrchestrationModuleConfig,
Expand Down Expand Up @@ -169,7 +169,14 @@ describe('construct completion post request', () => {
const config: OrchestrationModuleConfig = {
...defaultConfig,
filtering: {
input: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 })
input: {
filters: [
buildAzureContentSafetyFilter({
Hate: 'ALLOW_SAFE_LOW_MEDIUM',
SelfHarm: 'ALLOW_SAFE'
})
]
}
}
};
const expectedCompletionPostRequest: CompletionPostRequest = {
Expand Down Expand Up @@ -209,7 +216,14 @@ describe('construct completion post request', () => {
const config: OrchestrationModuleConfig = {
...defaultConfig,
filtering: {
output: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 })
output: {
filters: [
buildAzureContentSafetyFilter({
Hate: 'ALLOW_SAFE_LOW_MEDIUM',
SelfHarm: 'ALLOW_SAFE'
})
]
}
}
};

Expand Down
50 changes: 50 additions & 0 deletions packages/orchestration/src/orchestration-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ export interface OrchestrationModuleConfig {
llm: LlmModuleConfig;
/**
* Filtering module configuration.
* Construct filter configuration for both input and output filters using convenience functions.
* @example
* ```ts
* filtering: {
* input: {
* filters: [
* buildAzureContentSafetyFilter({ Hate: 'ALLOW_SAFE', Violence: 'ALLOW_SAFE_LOW_MEDIUM' })
* ]
* }
* }
* ```
*/
filtering?: FilteringModuleConfig;
/**
Expand Down Expand Up @@ -151,3 +162,42 @@ export interface DocumentGroundingServiceConfig {
*/
output_param: string;
}

/**
* Filter configuration for Azure content safety Filter.
*/
export interface AzureContentFilter {
/**
* The filter category for hate content.
*/
Hate?: AzureFilterThreshold;
/**
* The filter category for self-harm content.
*/
SelfHarm?: AzureFilterThreshold;
/**
* The filter category for sexual content.
*/
Sexual?: AzureFilterThreshold;
/**
* The filter category for violence content.
*/
Violence?: AzureFilterThreshold;
}

/**
* A descriptive constant for Azure content safety filter threshold.
* @internal
*/
export const supportedAzureFilterThresholds = {
ALLOW_SAFE: 0,
ALLOW_SAFE_LOW: 2,
ALLOW_SAFE_LOW_MEDIUM: 4,
ALLOW_ALL: 6
} as const;

/**
* The Azure threshold level supported for each azure content filter category.
*
*/
export type AzureFilterThreshold = keyof typeof supportedAzureFilterThresholds;
Loading

0 comments on commit 54a9044

Please sign in to comment.