-
Notifications
You must be signed in to change notification settings - Fork 60.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of API between Docs and CSE Copilot (#52892)
Co-authored-by: Evan Bonsignori <[email protected]> Co-authored-by: Evan Bonsignori <[email protected]>
- Loading branch information
1 parent
d35127d
commit ff3bd58
Showing
11 changed files
with
490 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import { Request, Response } from 'express' | ||
import got from 'got' | ||
import { getHmacWithEpoch } from '@/search/lib/helpers/get-cse-copilot-auth' | ||
import { getCSECopilotSource } from '#src/search/lib/helpers/cse-copilot-docs-versions.js' | ||
|
||
export const aiSearchProxy = async (req: Request, res: Response) => { | ||
const { query, version, language } = req.body | ||
const errors = [] | ||
|
||
// Validate request body | ||
if (!query) { | ||
errors.push({ message: `Missing required key 'query' in request body` }) | ||
} else if (typeof query !== 'string') { | ||
errors.push({ message: `Invalid 'query' in request body. Must be a string` }) | ||
} | ||
if (!version) { | ||
errors.push({ message: `Missing required key 'version' in request body` }) | ||
} | ||
if (!language) { | ||
errors.push({ message: `Missing required key 'language' in request body` }) | ||
} | ||
|
||
let docsSource = '' | ||
try { | ||
docsSource = getCSECopilotSource(version, language) | ||
} catch (error: any) { | ||
errors.push({ message: error?.message || 'Invalid version or language' }) | ||
} | ||
|
||
if (errors.length) { | ||
res.status(400).json({ errors }) | ||
return | ||
} | ||
|
||
const body = { | ||
chat_context: 'defaults', | ||
docs_source: docsSource, | ||
query, | ||
stream: true, | ||
} | ||
|
||
try { | ||
const stream = got.post(`${process.env.CSE_COPILOT_ENDPOINT}/answers`, { | ||
json: body, | ||
headers: { | ||
Authorization: getHmacWithEpoch(), | ||
'Content-Type': 'application/json', | ||
}, | ||
isStream: true, | ||
}) | ||
|
||
// Set response headers | ||
res.setHeader('Content-Type', 'application/x-ndjson') | ||
res.flushHeaders() | ||
|
||
// Pipe the got stream directly to the response | ||
stream.pipe(res) | ||
|
||
// Handle stream errors | ||
stream.on('error', (error) => { | ||
console.error('Error streaming from cse-copilot:', error) | ||
// Only send error response if headers haven't been sent | ||
if (!res.headersSent) { | ||
res.status(500).json({ errors: [{ message: 'Internal server error' }] }) | ||
} else { | ||
res.end() | ||
} | ||
}) | ||
|
||
// Ensure response ends when stream ends | ||
stream.on('end', () => { | ||
res.end() | ||
}) | ||
} catch (error) { | ||
console.error('Error posting /answers to cse-copilot:', error) | ||
res.status(500).json({ errors: [{ message: 'Internal server error' }] }) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// Versions used by cse-copilot | ||
import { allVersions } from '@/versions/lib/all-versions' | ||
const CSE_COPILOT_DOCS_VERSIONS = ['dotcom', 'ghec', 'ghes'] | ||
|
||
// Languages supported by cse-copilot | ||
const DOCS_LANGUAGES = ['en'] | ||
export function supportedCSECopilotLanguages() { | ||
return DOCS_LANGUAGES | ||
} | ||
|
||
export function getCSECopilotSource( | ||
version: (typeof CSE_COPILOT_DOCS_VERSIONS)[number], | ||
language: (typeof DOCS_LANGUAGES)[number], | ||
) { | ||
const cseCopilotDocsVersion = getMiscBaseNameFromVersion(version) | ||
if (!CSE_COPILOT_DOCS_VERSIONS.includes(cseCopilotDocsVersion)) { | ||
throw new Error( | ||
`Invalid 'version' in request body: '${version}'. Must be one of: ${CSE_COPILOT_DOCS_VERSIONS.join(', ')}`, | ||
) | ||
} | ||
if (!DOCS_LANGUAGES.includes(language)) { | ||
throw new Error( | ||
`Invalid 'language' in request body '${language}'. Must be one of: ${DOCS_LANGUAGES.join(', ')}`, | ||
) | ||
} | ||
return `docs_${version}_${language}` | ||
} | ||
|
||
function getMiscBaseNameFromVersion(Version: string): string { | ||
const miscBaseName = | ||
Object.values(allVersions).find( | ||
(info) => | ||
info.shortName === Version || | ||
info.plan === Version || | ||
info.miscVersionName === Version || | ||
info.currentRelease === Version, | ||
)?.miscBaseName || '' | ||
|
||
if (!miscBaseName) { | ||
return '' | ||
} | ||
|
||
return miscBaseName | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import crypto from 'crypto' | ||
|
||
// github/cse-copilot's API requires an HMAC-SHA256 signature with each request | ||
export function getHmacWithEpoch() { | ||
const epochTime = getEpochTime().toString() | ||
// CSE_COPILOT_SECRET needs to be set for the api-ai-search tests to work | ||
if (process.env.NODE_ENV === 'test') { | ||
process.env.CSE_COPILOT_SECRET = 'mock-secret' | ||
} | ||
if (!process.env.CSE_COPILOT_SECRET) { | ||
throw new Error('CSE_COPILOT_SECRET is not defined') | ||
} | ||
const hmac = generateHmacSha256(process.env.CSE_COPILOT_SECRET, epochTime) | ||
return `${epochTime}.${hmac}` | ||
} | ||
|
||
// In seconds | ||
function getEpochTime(): number { | ||
return Math.floor(Date.now() / 1000) | ||
} | ||
|
||
function generateHmacSha256(key: string, data: string): string { | ||
return crypto.createHmac('sha256', key).update(data).digest('hex') | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import express, { Request, Response } from 'express' | ||
|
||
import catchMiddlewareError from '#src/observability/middleware/catch-middleware-error.js' | ||
import { aiSearchProxy } from '../lib/ai-search-proxy' | ||
|
||
const router = express.Router() | ||
|
||
router.post( | ||
'/v1', | ||
catchMiddlewareError(async (req: Request, res: Response) => { | ||
await aiSearchProxy(req, res) | ||
}), | ||
) | ||
|
||
// Redirect to most recent version | ||
router.post('/', (req, res) => { | ||
res.redirect(307, req.originalUrl.replace('/ai-search', '/ai-search/v1')) | ||
}) | ||
|
||
export default router |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import { expect, test, describe, beforeAll, afterAll } from 'vitest' | ||
|
||
import { post } from 'src/tests/helpers/e2etest.js' | ||
import { startMockServer, stopMockServer } from '@/tests/mocks/start-mock-server' | ||
|
||
describe('AI Search Routes', () => { | ||
beforeAll(() => { | ||
startMockServer() | ||
}) | ||
afterAll(() => stopMockServer()) | ||
|
||
test('/api/ai-search/v1 should handle a successful response', async () => { | ||
let apiBody = { query: 'How do I create a Repository?', language: 'en', version: 'dotcom' } | ||
|
||
const response = await fetch('http://localhost:4000/api/ai-search/v1', { | ||
method: 'POST', | ||
headers: { 'Content-Type': 'application/json' }, | ||
body: JSON.stringify(apiBody), | ||
}) | ||
|
||
expect(response.ok).toBe(true) | ||
expect(response.headers.get('content-type')).toBe('application/x-ndjson') | ||
expect(response.headers.get('transfer-encoding')).toBe('chunked') | ||
|
||
if (!response.body) { | ||
throw new Error('ReadableStream not supported in this environment.') | ||
} | ||
|
||
const decoder = new TextDecoder('utf-8') | ||
const reader = response.body.getReader() | ||
let done = false | ||
const chunks = [] | ||
|
||
while (!done) { | ||
const { value, done: readerDone } = await reader.read() | ||
done = readerDone | ||
|
||
if (value) { | ||
// Decode the Uint8Array chunk into a string | ||
const chunkStr = decoder.decode(value, { stream: true }) | ||
chunks.push(chunkStr) | ||
} | ||
} | ||
|
||
// Combine all chunks into a single string | ||
const fullResponse = chunks.join('') | ||
// Split the response into individual chunk lines | ||
const chunkLines = fullResponse.split('\n').filter((line) => line.trim() !== '') | ||
|
||
// Assertions: | ||
|
||
// 1. First chunk should be the SOURCES chunk | ||
expect(chunkLines.length).toBeGreaterThan(0) | ||
const firstChunkMatch = chunkLines[0].match(/^Chunk: (.+)$/) | ||
expect(firstChunkMatch).not.toBeNull() | ||
|
||
const sourcesChunk = JSON.parse(firstChunkMatch?.[1] || '') | ||
expect(sourcesChunk).toHaveProperty('chunkType', 'SOURCES') | ||
expect(sourcesChunk).toHaveProperty('sources') | ||
expect(Array.isArray(sourcesChunk.sources)).toBe(true) | ||
expect(sourcesChunk.sources.length).toBe(3) | ||
|
||
// 2. Subsequent chunks should be MESSAGE_CHUNKs | ||
for (let i = 1; i < chunkLines.length; i++) { | ||
const line = chunkLines[i] | ||
const messageChunk = JSON.parse(line) | ||
expect(messageChunk).toHaveProperty('chunkType', 'MESSAGE_CHUNK') | ||
expect(messageChunk).toHaveProperty('text') | ||
expect(typeof messageChunk.text).toBe('string') | ||
} | ||
|
||
// 3. Verify the complete message is expected | ||
const expectedMessage = | ||
'Creating a repository on GitHub is something you should already know how to do :shrug:' | ||
const receivedMessage = chunkLines | ||
.slice(1) | ||
.map((line) => JSON.parse(line).text) | ||
.join('') | ||
expect(receivedMessage).toBe(expectedMessage) | ||
}) | ||
|
||
test('should handle validation errors: query missing', async () => { | ||
let body = { language: 'en', version: 'dotcom' } | ||
const response = await post('/api/ai-search/v1', { | ||
body: JSON.stringify(body), | ||
headers: { 'Content-Type': 'application/json' }, | ||
}) | ||
|
||
const responseBody = JSON.parse(response.body) | ||
|
||
expect(response.ok).toBe(false) | ||
expect(responseBody['errors']).toEqual([ | ||
{ message: `Missing required key 'query' in request body` }, | ||
]) | ||
}) | ||
|
||
test('should handle validation errors: language missing', async () => { | ||
let body = { query: 'example query', version: 'dotcom' } | ||
const response = await post('/api/ai-search/v1', { | ||
body: JSON.stringify(body), | ||
headers: { 'Content-Type': 'application/json' }, | ||
}) | ||
|
||
const responseBody = JSON.parse(response.body) | ||
|
||
expect(response.ok).toBe(false) | ||
expect(responseBody['errors']).toEqual([ | ||
{ message: `Missing required key 'language' in request body` }, | ||
{ message: `Invalid 'language' in request body 'undefined'. Must be one of: en` }, | ||
]) | ||
}) | ||
|
||
test('should handle validation errors: version missing', async () => { | ||
let body = { query: 'example query', language: 'en' } | ||
const response = await post('/api/ai-search/v1', { | ||
body: JSON.stringify(body), | ||
headers: { 'Content-Type': 'application/json' }, | ||
}) | ||
|
||
const responseBody = JSON.parse(response.body) | ||
|
||
expect(response.ok).toBe(false) | ||
expect(responseBody['errors']).toEqual([ | ||
{ message: `Missing required key 'version' in request body` }, | ||
{ | ||
message: `Invalid 'version' in request body: 'undefined'. Must be one of: dotcom, ghec, ghes`, | ||
}, | ||
]) | ||
}) | ||
|
||
test('should handle multiple validation errors: query missing, invalid language and version', async () => { | ||
let body = { language: 'fr', version: 'fpt' } | ||
const response = await post('/api/ai-search/v1', { | ||
body: JSON.stringify(body), | ||
headers: { 'Content-Type': 'application/json' }, | ||
}) | ||
|
||
const responseBody = JSON.parse(response.body) | ||
|
||
expect(response.ok).toBe(false) | ||
expect(responseBody['errors']).toEqual([ | ||
{ message: `Missing required key 'query' in request body` }, | ||
{ | ||
message: `Invalid 'language' in request body 'fr'. Must be one of: en`, | ||
}, | ||
]) | ||
}) | ||
}) |
Oops, something went wrong.