diff --git a/config.d.ts b/config.d.ts index 85143b6..3a67332 100644 --- a/config.d.ts +++ b/config.d.ts @@ -312,6 +312,11 @@ export interface AiWarpConfig { }; rateLimiting?: { max?: number; + maxByClaims?: { + claim: string; + claimValue: string; + max: number; + }[]; timeWindow?: number | string; hook?: "onRequest" | "preParsing" | "preValidation" | "preHandler"; cache?: number; diff --git a/index.ts b/index.ts index c1b7930..9509541 100644 --- a/index.ts +++ b/index.ts @@ -1,6 +1,5 @@ import { platformaticService, Stackable } from '@platformatic/service' import fastifyUser from 'fastify-user' -import fastifyRateLimit from '@fastify/rate-limit' import fastifyPlugin from 'fastify-plugin' import { schema } from './lib/schema' import { Generator } from './lib/generator' @@ -8,7 +7,7 @@ import { AiWarpConfig } from './config' import warpPlugin from './plugins/warp' import authPlugin from './plugins/auth' import apiPlugin from './plugins/api' -import createError from '@fastify/error' +import rateLimitPlugin from './plugins/rate-limiting' const stackable: Stackable = async function (fastify, opts) { const { config } = fastify.platformatic @@ -17,58 +16,7 @@ const stackable: Stackable = async function (fastify, opts) { await fastify.register(warpPlugin, opts) // needs to be registered here for fastify.ai to be decorated - const { rateLimiting } = fastify.ai - const { rateLimiting: rateLimitingConfig } = config - await fastify.register(fastifyRateLimit, { - max: async (req, key) => { - if (rateLimiting.max !== undefined) { - return await rateLimiting.max(req, key) - } else { - return rateLimitingConfig?.max ?? 1000 - } - }, - allowList: async (req, key) => { - if (rateLimiting.allowList !== undefined) { - return await rateLimiting.allowList(req, key) - } else if (rateLimitingConfig?.allowList !== undefined) { - return rateLimitingConfig.allowList.includes(key) - } - return false - }, - onBanReach: (req, key) => { - if (rateLimiting.onBanReach !== undefined) { - rateLimiting.onBanReach(req, key) - } - }, - keyGenerator: async (req) => { - if (rateLimiting.keyGenerator !== undefined) { - return await rateLimiting.keyGenerator(req) - } else { - return req.ip - } - }, - errorResponseBuilder: (req, context) => { - if (rateLimiting.errorResponseBuilder !== undefined) { - return rateLimiting.errorResponseBuilder(req, context) - } else { - const RateLimitError = createError('RATE_LIMITED', 'Rate limit exceeded, retry in %s') - const err = new RateLimitError(context.after) - err.statusCode = 429 // TODO: use context.statusCode https://github.com/fastify/fastify-rate-limit/pull/366 - return err - } - }, - onExceeding: (req, key) => { - if (rateLimiting.onExceeded !== undefined) { - rateLimiting.onExceeded(req, key) - } - }, - onExceeded: (req, key) => { - if (rateLimiting.onExceeding !== undefined) { - rateLimiting.onExceeding(req, key) - } - }, - ...rateLimitingConfig - }) + await fastify.register(rateLimitPlugin, opts) await fastify.register(apiPlugin, opts) await fastify.register(platformaticService, opts) diff --git a/lib/schema.ts b/lib/schema.ts index aed0749..1954ec7 100644 --- a/lib/schema.ts +++ b/lib/schema.ts @@ -192,6 +192,19 @@ const aiWarpSchema = { properties: { // Pulled from https://github.com/fastify/fastify-rate-limit/blob/master/types/index.d.ts#L81 max: { type: 'number' }, + maxByClaims: { + type: 'array', + items: { + type: 'object', + properties: { + claim: { type: 'string' }, + claimValue: { type: 'string' }, + max: { type: 'number' } + }, + additionalProperties: false, + required: ['claim', 'claimValue', 'max'] + } + }, timeWindow: { oneOf: [ { type: 'number' }, diff --git a/plugins/auth.ts b/plugins/auth.ts index b740d5f..cfac271 100644 --- a/plugins/auth.ts +++ b/plugins/auth.ts @@ -9,7 +9,7 @@ const UnauthorizedError = createError('UNAUTHORIZED', 'Unauthorized', 401) export default fastifyPlugin(async (fastify: FastifyInstance) => { const { config } = fastify.platformatic - fastify.addHook('preHandler', async (request) => { + fastify.addHook('onRequest', async (request) => { await request.extractUser() const isAuthRequired = config.auth?.required !== undefined && config.auth?.required diff --git a/plugins/rate-limiting.ts b/plugins/rate-limiting.ts new file mode 100644 index 0000000..3f158f0 --- /dev/null +++ b/plugins/rate-limiting.ts @@ -0,0 +1,114 @@ +// eslint-disable-next-line +/// +import { FastifyInstance } from 'fastify' +import createError from '@fastify/error' +import fastifyPlugin from 'fastify-plugin' +import fastifyRateLimit from '@fastify/rate-limit' +import { AiWarpConfig } from '../config' + +interface RateLimitMax { + // One claim to many values & maxes + values: Record +} + +function buildMaxByClaimLookupTable (config: AiWarpConfig['rateLimiting']): Record { + const table: Record = {} + if (config === undefined || config.maxByClaims === undefined) { + return table + } + + for (const { claim, claimValue: value, max } of config.maxByClaims) { + if (!(claim in table)) { + table[claim] = { values: {} } + } + + table[claim].values[value] = max + } + + return table +} + +export default fastifyPlugin(async (fastify: FastifyInstance) => { + const { config } = fastify.platformatic + const { rateLimiting: rateLimitingConfig } = config + const maxByClaimLookupTable = buildMaxByClaimLookupTable(rateLimitingConfig) + const { rateLimiting } = fastify.ai + + await fastify.register(fastifyRateLimit, { + // Note: user can override this by setting it in their platformatic config + max: async (req, key) => { + if (rateLimiting.max !== undefined) { + return await rateLimiting.max(req, key) + } + + if (rateLimitingConfig !== undefined) { + if ( + req.user !== undefined && + req.user !== null && + typeof req.user === 'object' + ) { + for (const claim of Object.keys(req.user)) { + if (claim in maxByClaimLookupTable) { + const { values } = maxByClaimLookupTable[claim] + + // @ts-expect-error + if (req.user[claim] in values) { + // @ts-expect-error + return values[req.user[claim]] + } + } + } + } + + const { max } = rateLimitingConfig + if (max !== undefined) { + return max + } + } + + return 1000 // default used in @fastify/rate-limit + }, + // Note: user can override this by setting it in their platformatic config + allowList: async (req, key) => { + if (rateLimiting.allowList !== undefined) { + return await rateLimiting.allowList(req, key) + } else if (rateLimitingConfig?.allowList !== undefined) { + return rateLimitingConfig.allowList.includes(key) + } + return false + }, + onBanReach: (req, key) => { + if (rateLimiting.onBanReach !== undefined) { + rateLimiting.onBanReach(req, key) + } + }, + keyGenerator: async (req) => { + if (rateLimiting.keyGenerator !== undefined) { + return await rateLimiting.keyGenerator(req) + } else { + return req.ip + } + }, + errorResponseBuilder: (req, context) => { + if (rateLimiting.errorResponseBuilder !== undefined) { + return rateLimiting.errorResponseBuilder(req, context) + } else { + const RateLimitError = createError('RATE_LIMITED', 'Rate limit exceeded, retry in %s') + const err = new RateLimitError(context.after) + err.statusCode = 429 // TODO: use context.statusCode https://github.com/fastify/fastify-rate-limit/pull/366 + return err + } + }, + onExceeding: (req, key) => { + if (rateLimiting.onExceeded !== undefined) { + rateLimiting.onExceeded(req, key) + } + }, + onExceeded: (req, key) => { + if (rateLimiting.onExceeding !== undefined) { + rateLimiting.onExceeding(req, key) + } + }, + ...rateLimitingConfig + }) +}) diff --git a/tests/e2e/rate-limiting.test.ts b/tests/e2e/rate-limiting.test.ts index 94523b2..07ead42 100644 --- a/tests/e2e/rate-limiting.test.ts +++ b/tests/e2e/rate-limiting.test.ts @@ -4,6 +4,7 @@ import assert from 'node:assert' import fastifyPlugin from 'fastify-plugin' import { AiWarpConfig } from '../../config' import { buildAiWarpApp } from '../utils/stackable' +import { authConfig, createToken } from '../utils/auth' const aiProvider: AiWarpConfig['aiProvider'] = { openai: { @@ -131,3 +132,44 @@ it('calls ai.rateLimiting.errorResponseBuilder callback', async () => { await app.close() } }) + +it('uses the max for a specific claim', async () => { + const [app, port] = await buildAiWarpApp({ + aiProvider, + rateLimiting: { + maxByClaims: [ + { + claim: 'rateLimitMax', + claimValue: '10', + max: 10 + }, + { + claim: 'rateLimitMax', + claimValue: '100', + max: 100 + } + ] + }, + auth: authConfig + }) + + try { + await app.start() + + let res = await fetch(`http://localhost:${port}`, { + headers: { + Authorization: `Bearer ${createToken({ rateLimitMax: '10' })}` + } + }) + assert.strictEqual(res.headers.get('x-ratelimit-limit'), '10') + + res = await fetch(`http://localhost:${port}`, { + headers: { + Authorization: `Bearer ${createToken({ rateLimitMax: '100' })}` + } + }) + assert.strictEqual(res.headers.get('x-ratelimit-limit'), '100') + } finally { + await app.close() + } +})