Skip to content

Commit

Permalink
Stream LLM response when editing apps (#430)
Browse files Browse the repository at this point in the history
* Streaming WIP

* More wip

* Fix up the streaming parser and the tests

* Remove extra debug log

* Remove unused file

* Fix more parser bugs

* Hook up frontend

---------

Co-authored-by: Nicholas Charriere <[email protected]>
  • Loading branch information
benjreinhart and nichochar authored Oct 29, 2024
1 parent 3ff5428 commit a6bfde7
Show file tree
Hide file tree
Showing 12 changed files with 894 additions and 65 deletions.
42 changes: 26 additions & 16 deletions packages/api/ai/generate.mts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { generateText, type GenerateTextResult } from 'ai';
import { streamText, generateText, type GenerateTextResult } from 'ai';
import { getModel } from './config.mjs';
import {
type CodeLanguageType,
Expand All @@ -13,7 +13,7 @@ import Path from 'node:path';
import { PROMPTS_DIR } from '../constants.mjs';
import { encode, decodeCells } from '../srcmd.mjs';
import { buildProjectXml, type FileContent } from '../ai/app-parser.mjs';
import { type AppGenerationLog, logAppGeneration } from './logger.mjs';
import { logAppGeneration } from './logger.mjs';

const makeGenerateSrcbookSystemPrompt = () => {
return readFileSync(Path.join(PROMPTS_DIR, 'srcbook-generator.txt'), 'utf-8');
Expand Down Expand Up @@ -259,30 +259,40 @@ export async function generateApp(
return result.text;
}

export async function editApp(
export async function streamEditApp(
projectId: string,
files: FileContent[],
query: string,
appId: string,
planId: string,
): Promise<string> {
) {
const model = await getModel();

const systemPrompt = makeAppEditorSystemPrompt();
const userPrompt = makeAppEditorUserPrompt(projectId, files, query);
const result = await generateText({

let response = '';

const result = await streamText({
model,
system: systemPrompt,
prompt: userPrompt,
onChunk: (chunk) => {
if (chunk.chunk.type === 'text-delta') {
response += chunk.chunk.textDelta;
}
},
onFinish: () => {
if (process.env.SRCBOOK_DISABLE_ANALYTICS !== 'true') {
logAppGeneration({
appId,
planId,
llm_request: { model, system: systemPrompt, prompt: userPrompt },
llm_response: response,
});
}
},
});
const log: AppGenerationLog = {
appId,
planId,
llm_request: { model, system: systemPrompt, prompt: userPrompt },
llm_response: result,
};

if (process.env.SRCBOOK_DISABLE_ANALYTICS !== 'true') {
logAppGeneration(log);
}
return result.text;

return result.textStream;
}
110 changes: 110 additions & 0 deletions packages/api/ai/plan-parser.mts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { XMLParser } from 'fast-xml-parser';
import Path from 'node:path';
import { type App as DBAppType } from '../db/schema.mjs';
import { loadFile } from '../apps/disk.mjs';
import { StreamingXMLParser, TagType } from './stream-xml-parser.mjs';
import { ActionChunkType, DescriptionChunkType } from '@srcbook/shared';

// The ai proposes a plan that we expect to contain both files and commands
// Here is an example of a plan:
Expand Down Expand Up @@ -167,3 +169,111 @@ export function getPackagesToInstall(plan: Plan): string[] {
)
.flatMap((action) => action.packages);
}

/**
 * Incrementally parse a streamed LLM plan and expose it as a `ReadableStream`
 * of newline-delimited JSON chunks.
 *
 * Every completed `planDescription` or `action` tag is converted with
 * `toStreamingChunk` and enqueued as `JSON.stringify(chunk) + '\n'`.
 *
 * @param stream - Text fragments of the plan XML, in order.
 * @param app - App passed through to `toStreamingChunk`.
 * @param _query - Unused.
 * @param planId - Identifier attached to every emitted chunk.
 * @returns A stream that closes once the input is exhausted and every chunk
 *   has been emitted, or errors if reading, parsing or conversion fails.
 */
export async function streamParsePlan(
  stream: AsyncIterable<string>,
  app: DBAppType,
  _query: string,
  planId: string,
) {
  let parser: StreamingXMLParser;

  // The parser invokes `onTag` synchronously and ignores its return value, while
  // converting a tag is asynchronous. Chaining the conversions keeps chunks in
  // document order and gives us something to await before closing the stream.
  let pending: Promise<void> = Promise.resolve();

  return new ReadableStream({
    async pull(controller) {
      if (parser === undefined) {
        parser = new StreamingXMLParser({
          onTag(tag) {
            if (tag.name !== 'planDescription' && tag.name !== 'action') {
              return;
            }

            pending = pending.then(async () => {
              const chunk = await toStreamingChunk(app, tag, planId);
              if (chunk) {
                controller.enqueue(JSON.stringify(chunk) + '\n');
              }
            });
          },
        });
      }

      try {
        for await (const chunk of stream) {
          parser.parse(chunk);
        }

        // Closing earlier would make the late `enqueue` calls throw.
        await pending;
        controller.close();
      } catch (error) {
        // Let in-flight conversions settle so none of them touches the
        // controller after it has been errored; their failure (if any) is
        // either `error` itself or secondary to it.
        await pending.catch(() => undefined);

        console.error(error);
        controller.enqueue(
          JSON.stringify({
            type: 'error',
            data: { content: 'Error while parsing streaming response' },
          }) + '\n',
        );
        controller.error(error);
      }
    },
  });
}

/**
 * Convert a completed plan tag into the chunk that is streamed to the client.
 *
 * - `planDescription` becomes a description chunk carrying the tag content.
 * - `action` with `type="file"` becomes a file action; the current file is
 *   loaded with `loadFile` to populate `original` (left `null` when loading fails).
 * - `action` with `type="command"` becomes a command action built from the
 *   `commandType` and `package` children.
 *
 * @returns The chunk, or `null` for unknown tags, unknown action types and
 *   actions that lack the children/attributes required to build a chunk.
 */
async function toStreamingChunk(
  app: DBAppType,
  tag: TagType,
  planId: string,
): Promise<DescriptionChunkType | ActionChunkType | null> {
  switch (tag.name) {
    case 'planDescription':
      return {
        type: 'description',
        planId: planId,
        data: { content: tag.content },
      } as DescriptionChunkType;
    case 'action': {
      const descriptionTag = tag.children.find((t) => t.name === 'description');
      const description = descriptionTag?.content ?? '';
      const type = tag.attributes.type;

      if (type === 'file') {
        const fileTag = tag.children.find((t) => t.name === 'file');
        const filePath = fileTag?.attributes.filename;

        // The plan is model output, so it may be malformed; skip the action
        // instead of throwing a TypeError on a missing child or attribute.
        if (fileTag === undefined || !filePath) {
          console.error('Skipping file action without a <file filename="..."> child');
          return null;
        }

        let originalContent = null;

        try {
          const fileContent = await loadFile(app, filePath);
          originalContent = fileContent.source;
        } catch {
          // If the file doesn't exist, it's likely that it's a new file.
        }

        return {
          type: 'action',
          planId: planId,
          data: {
            type: 'file',
            description,
            path: filePath,
            dirname: Path.dirname(filePath),
            basename: Path.basename(filePath),
            modified: fileTag.content,
            original: originalContent,
          },
        } as ActionChunkType;
      } else if (type === 'command') {
        const commandTag = tag.children.find((t) => t.name === 'commandType');

        if (commandTag === undefined) {
          console.error('Skipping command action without a <commandType> child');
          return null;
        }

        const packageTags = tag.children.filter((t) => t.name === 'package');

        return {
          type: 'action',
          planId: planId,
          data: {
            type: 'command',
            description,
            command: commandTag.content,
            packages: packageTags.map((t) => t.content),
          },
        } as ActionChunkType;
      } else {
        return null;
      }
    }
    default:
      return null;
  }
}
138 changes: 138 additions & 0 deletions packages/api/ai/stream-xml-parser.mts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/**
 * An XML element as reported by `StreamingXMLParser`.
 */
export type TagType = {
  /** Element name, e.g. `action`. */
  name: string;
  /** Attributes of the opening tag, without the surrounding quotes. */
  attributes: Record<string, string>;
  /** CDATA payload(s) and trimmed text found directly inside the element. */
  content: string;
  /** Nested elements in document order. */
  children: TagType[];
};

/** Invoked once for every element as soon as its closing tag has been parsed. */
export type TagCallbackType = (tag: TagType) => void;

/**
 * A small, forgiving XML parser that accepts its input in arbitrary fragments
 * and reports each element through `onTag` when the element is closed.
 *
 * It understands nested elements, quoted attributes, CDATA sections, plain
 * text content and self-closing tags. Fragments may split the input anywhere,
 * including in the middle of a tag or of the CDATA delimiters. Entities are
 * not decoded and a closing tag that does not match the open element is
 * ignored.
 */
export class StreamingXMLParser {
  private static readonly CDATA_OPEN = '<![CDATA[';
  private static readonly CDATA_CLOSE = ']]>';

  /** Input that has been received but not consumed yet. */
  private buffer = '';
  private currentTag: TagType | null = null;
  private tagStack: TagType[] = [];
  private isInCDATA = false;
  /** Where to resume looking for the CDATA terminator in `buffer`. */
  private cdataSearchFrom = 0;
  private onTag: TagCallbackType;

  constructor({ onTag }: { onTag: TagCallbackType }) {
    this.onTag = onTag;
  }

  /**
   * Extract `name="value"` / `name='value'` pairs from the attribute part of
   * an opening tag.
   */
  private parseAttributes(attributeString: string): Record<string, string> {
    const attributes: Record<string, string> = {};
    // Capture groups (rather than splitting on '=') keep values that contain '='.
    const pattern = /([\w:.-]+)\s*=\s*(?:"([^"]*)"|'([^']*)')/g;

    let match: RegExpExecArray | null;
    while ((match = pattern.exec(attributeString)) !== null) {
      const key = match[1];
      if (key !== undefined) {
        attributes[key] = match[2] ?? match[3] ?? '';
      }
    }

    return attributes;
  }

  /**
   * Open a new element from the text between `<` and `>` and make it current.
   *
   * @returns The name of the opened element.
   */
  private handleOpenTag(tagContent: string): string {
    const trimmed = tagContent.trim();
    // The name ends at the first whitespace of any kind, not just a space.
    const nameEnd = trimmed.search(/\s/);
    const tagName = nameEnd === -1 ? trimmed : trimmed.substring(0, nameEnd);
    const attributeString = nameEnd === -1 ? '' : trimmed.substring(nameEnd + 1);

    const newTag: TagType = {
      name: tagName,
      attributes: this.parseAttributes(attributeString),
      content: '',
      children: [],
    };

    if (this.currentTag) {
      this.tagStack.push(this.currentTag);
      this.currentTag.children.push(newTag);
    }

    this.currentTag = newTag;
    return tagName;
  }

  /**
   * Close the current element if `tagName` matches it, report it through
   * `onTag` and make its parent current again. Mismatching names are ignored.
   */
  private handleCloseTag(tagName: string) {
    if (!this.currentTag) return;

    if (this.currentTag.name === tagName) {
      this.onTag(this.currentTag);
      this.currentTag = this.tagStack.pop() ?? null;
    }
  }

  /**
   * Add text found between tags to the current element. Whitespace-only runs
   * (indentation between elements) are dropped and the rest is trimmed.
   */
  private appendText(text: string) {
    if (!this.currentTag) return;

    const trimmed = text.trim();
    if (trimmed.length > 0) {
      this.currentTag.content += trimmed;
    }
  }

  /**
   * Feed the next fragment of the document. Everything that can be parsed is
   * consumed immediately; an incomplete trailing construct stays buffered
   * until a later call completes it.
   */
  parse(chunk: string) {
    this.buffer += chunk;

    while (this.buffer.length > 0) {
      // Handle CDATA sections
      if (this.isInCDATA) {
        const cdataEndIdx = this.buffer.indexOf(
          StreamingXMLParser.CDATA_CLOSE,
          this.cdataSearchFrom,
        );

        if (cdataEndIdx === -1) {
          // Resume just before the end so a terminator that is split across
          // fragments is still found without rescanning the whole payload.
          this.cdataSearchFrom = Math.max(0, this.buffer.length - 2);
          return;
        }

        if (this.currentTag) {
          this.currentTag.content += this.buffer.substring(0, cdataEndIdx);
        }

        this.isInCDATA = false;
        this.cdataSearchFrom = 0;
        this.buffer = this.buffer.substring(cdataEndIdx + StreamingXMLParser.CDATA_CLOSE.length);
        continue;
      }

      const tagStartIdx = this.buffer.indexOf('<');
      if (tagStartIdx === -1) {
        // Only text so far. Inside an element it is kept until the next tag
        // shows where it ends; outside of any element it is noise.
        if (!this.currentTag) {
          this.buffer = '';
        }
        return;
      }

      if (tagStartIdx > 0) {
        this.appendText(this.buffer.substring(0, tagStartIdx));
        this.buffer = this.buffer.substring(tagStartIdx);
      }

      // The buffer now starts with '<'.
      if (this.buffer.startsWith(StreamingXMLParser.CDATA_OPEN)) {
        this.isInCDATA = true;
        this.cdataSearchFrom = 0;
        this.buffer = this.buffer.substring(StreamingXMLParser.CDATA_OPEN.length);
        // Keep going: the terminator and following tags may already be buffered.
        continue;
      }

      const tagEndIdx = this.buffer.indexOf('>');
      if (tagEndIdx === -1) {
        // Incomplete tag (or incomplete CDATA opener); wait for more input.
        return;
      }

      const tagContent = this.buffer.substring(1, tagEndIdx);
      this.buffer = this.buffer.substring(tagEndIdx + 1);

      if (tagContent.startsWith('/')) {
        // Closing tag
        this.handleCloseTag(tagContent.substring(1).trim());
      } else if (tagContent.startsWith('?') || tagContent.startsWith('!')) {
        // XML declaration, processing instruction, DOCTYPE or comment: not an element.
      } else if (tagContent.endsWith('/')) {
        // Self-closing tag: open and close in one step.
        const name = this.handleOpenTag(tagContent.slice(0, -1));
        this.handleCloseTag(name);
      } else {
        // Opening tag
        this.handleOpenTag(tagContent);
      }
    }
  }
}
12 changes: 7 additions & 5 deletions packages/api/server/http.mts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import {
listSessions,
exportSrcmdText,
} from '../session.mjs';
import { generateCells, generateSrcbook, healthcheck, editApp } from '../ai/generate.mjs';
import { parsePlan } from '../ai/plan-parser.mjs';
import { generateCells, generateSrcbook, healthcheck, streamEditApp } from '../ai/generate.mjs';
import { streamParsePlan } from '../ai/plan-parser.mjs';
import {
getConfig,
updateConfig,
Expand Down Expand Up @@ -63,6 +63,7 @@ import { CreateAppSchema } from '../apps/schemas.mjs';
import { AppGenerationFeedbackType } from '@srcbook/shared';
import { createZipFromApp } from '../apps/disk.mjs';
import { checkoutCommit, commitAllFiles, getCurrentCommitSha } from '../apps/git.mjs';
import { streamJsonResponse } from './utils.mjs';

const app: Application = express();

Expand Down Expand Up @@ -555,9 +556,10 @@ router.post('/apps/:id/edit', cors(), async (req, res) => {
}
const validName = toValidPackageName(app.name);
const files = await getFlatFilesForApp(String(app.externalId));
const result = await editApp(validName, files, query, id, planId);
const parsedResult = await parsePlan(result, app, query, planId);
return res.json({ data: parsedResult });
const result = await streamEditApp(validName, files, query, app.externalId, planId);
const planStream = await streamParsePlan(result, app, query, planId);

return streamJsonResponse(planStream, res, { status: 200 });
} catch (e) {
return error500(res, e as Error);
}
Expand Down
28 changes: 28 additions & 0 deletions packages/api/server/utils.mts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { ServerResponse } from 'node:http';
import { StreamToIterable } from '@srcbook/shared';

/**
 * Pipe a `ReadableStream` through a Node `ServerResponse` object.
 *
 * Writes the status line and headers, forwards every chunk of `stream` to
 * `response` and ends the response when the stream is exhausted.
 *
 * @param stream - Source of the chunks to write.
 * @param response - Response to write to; its headers must not have been sent yet.
 * @param options - Optional status code (default 200) and extra headers.
 *   `Content-Type` and `Transfer-Encoding` are always set by this function.
 * @throws Re-throws whatever reading from `stream` throws. The response is
 *   ended first, so the client is not left waiting on an open connection.
 */
export async function streamJsonResponse(
  stream: ReadableStream,
  response: ServerResponse,
  options?: {
    headers?: Record<string, string>;
    status?: number;
  },
) {
  options ??= {};

  response.writeHead(options.status || 200, {
    ...options.headers,
    'Content-Type': 'text/plain',
    'Transfer-Encoding': 'chunked',
  });

  try {
    for await (const chunk of StreamToIterable(stream)) {
      response.write(chunk);
    }
  } finally {
    // Headers are already out, so the only way to signal completion (or
    // failure) to the client is to terminate the body.
    response.end();
  }
}
Loading

0 comments on commit a6bfde7

Please sign in to comment.