Expose optional moderation data on generate output.

  * src/api-types.ts: add GenerateModerationSchema (an optional `hap` array
    whose entries carry success, flagged, score in [0, 1] and a start/stop
    position) and an optional `moderation` field on GenerateOutputSchema.
  * src/client/client.ts: the streaming callback falls back to an empty
    result when a chunk has no `results`, forwards any remaining result
    properties, and attaches `moderation` when the chunk carries it.
  * src/client/types.ts: client-level GenerateOutput gains an optional
    `moderation` property.
  * src/tests/e2e/client.test.ts: makeValidStream accepts extra parameters;
    stream chunks are validated for `moderation`; new streaming test that
    requests moderation and expects at least one moderation chunk.

NOTE(review): this patch was restored from a copy in which line breaks,
indentation and generic type arguments had been lost. Indentation and the
type arguments (`z.infer<typeof ...>`, `z.input<typeof ...>`,
`Record<string, any>`) are reconstructed -- confirm against the repository
before applying (or apply with whitespace ignored).

diff --git a/src/api-types.ts b/src/api-types.ts
index 2342f8c..8310146 100644
--- a/src/api-types.ts
+++ b/src/api-types.ts
@@ -82,6 +82,26 @@ export const GenerateStopReasonSchema = z.enum([
 ]);
 export type GenerateStopReason = z.infer<typeof GenerateStopReasonSchema>;
 
+const GenerateModerationSchema = z
+  .object({
+    hap: z.optional(
+      z.array(
+        z
+          .object({
+            success: z.boolean(),
+            flagged: z.boolean(),
+            score: z.number().min(0).max(1),
+            position: z.object({
+              start: z.number().int().min(0),
+              stop: z.number().int().min(0),
+            }),
+          })
+          .passthrough(),
+      ),
+    ),
+  })
+  .passthrough();
+
 export const GenerateResultSchema = z
   .object({
     generated_text: z.string(),
@@ -97,6 +117,7 @@ export const GenerateOutputSchema = z
     model_id: z.string(),
     created_at: z.coerce.date(),
     results: z.array(GenerateResultSchema),
+    moderation: GenerateModerationSchema.optional(),
   })
   .passthrough();
 export type GenerateOutput = z.infer<typeof GenerateOutputSchema>;
diff --git a/src/client/client.ts b/src/client/client.ts
index c07af46..7fce265 100644
--- a/src/client/client.ts
+++ b/src/client/client.ts
@@ -457,13 +457,18 @@ export class Client {
               stop_reason = null,
               input_token_count = 0,
               generated_token_count = 0,
-            } = chunk.results[0];
+              ...props
+            } = (chunk.results || [{}])[0];
 
             callback(null, {
               generated_text,
               stop_reason,
               input_token_count,
               generated_token_count,
+              ...(chunk.moderation && {
+                moderation: chunk.moderation,
+              }),
+              ...props,
             } as GenerateOutput);
           } catch (e) {
             const err = (chunk || e) as unknown as Error;
diff --git a/src/client/types.ts b/src/client/types.ts
index 1a93c33..061b728 100644
--- a/src/client/types.ts
+++ b/src/client/types.ts
@@ -43,7 +43,9 @@ export const GenerateInputSchema = z.union([
   }),
 ]);
 export type GenerateInput = z.infer<typeof GenerateInputSchema>;
-export type GenerateOutput = ApiTypes.GenerateOutput['results'][number];
+export type GenerateOutput = ApiTypes.GenerateOutput['results'][number] & {
+  moderation?: ApiTypes.GenerateOutput['moderation'];
+};
 export const GenerateConfigInputSchema =
   ApiTypes.GenerateConfigInputSchema;
 export type GenerateConfigInput = z.input<typeof GenerateConfigInputSchema>;
diff --git a/src/tests/e2e/client.test.ts b/src/tests/e2e/client.test.ts
index 2f3f297..bd07263 100644
--- a/src/tests/e2e/client.test.ts
+++ b/src/tests/e2e/client.test.ts
@@ -51,13 +51,14 @@ describe('client', () => {
   }, 15_000);
 
   describe('streaming', () => {
-    const makeValidStream = () =>
+    const makeValidStream = (parameters: Record<string, any> = {}) =>
       client.generate(
         {
           model_id: 'google/ul2',
           input: 'Hello, World',
           parameters: {
             max_new_tokens: 10,
+            ...parameters,
           },
         },
         {
@@ -73,6 +74,10 @@ describe('client', () => {
       expect(chunk.generated_token_count).not.toBeNegative();
       expect(chunk.input_token_count).not.toBeNegative();
       expect(chunk.stop_reason).toSatisfy(isNumberOrNull);
+      expect(chunk.moderation).toBeOneOf([
+        undefined,
+        expect.objectContaining({ hap: expect.any(Array) }),
+      ]);
     };
 
     test('should throw for multiple inputs', () => {
@@ -95,6 +100,27 @@ describe('client', () => {
       ).toThrowError('Cannot do streaming for more than one input!');
     });
 
+    test('should correctly process moderation chunks during streaming', async () => {
+      const stream = makeValidStream({
+        moderations: {
+          min_new_tokens: 1,
+          max_new_tokens: 5,
+          hap: {
+            input: true,
+            threshold: 0.01,
+          },
+        },
+      });
+
+      for await (const chunk of stream) {
+        validateStreamChunk(chunk);
+        if (chunk.moderation) {
+          return;
+        }
+      }
+      throw Error('No moderation chunks has been retrieved from the API');
+    });
+
     test('should return valid stream for a single input', async () => {
       const stream = makeValidStream();
 