Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reapply streaming #434

Merged
merged 5 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions packages/api/ai/generate.mts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { generateText, type GenerateTextResult } from 'ai';
import { streamText, generateText, type GenerateTextResult } from 'ai';
import { getModel } from './config.mjs';
import {
type CodeLanguageType,
Expand All @@ -13,7 +13,7 @@ import Path from 'node:path';
import { PROMPTS_DIR } from '../constants.mjs';
import { encode, decodeCells } from '../srcmd.mjs';
import { buildProjectXml, type FileContent } from '../ai/app-parser.mjs';
import { type AppGenerationLog, logAppGeneration } from './logger.mjs';
import { logAppGeneration } from './logger.mjs';

const makeGenerateSrcbookSystemPrompt = () => {
return readFileSync(Path.join(PROMPTS_DIR, 'srcbook-generator.txt'), 'utf-8');
Expand Down Expand Up @@ -259,30 +259,40 @@ export async function generateApp(
return result.text;
}

export async function editApp(
export async function streamEditApp(
projectId: string,
files: FileContent[],
query: string,
appId: string,
planId: string,
): Promise<string> {
) {
const model = await getModel();

const systemPrompt = makeAppEditorSystemPrompt();
const userPrompt = makeAppEditorUserPrompt(projectId, files, query);
const result = await generateText({

let response = '';

const result = await streamText({
model,
system: systemPrompt,
prompt: userPrompt,
onChunk: (chunk) => {
if (chunk.chunk.type === 'text-delta') {
response += chunk.chunk.textDelta;
}
},
onFinish: () => {
if (process.env.SRCBOOK_DISABLE_ANALYTICS !== 'true') {
logAppGeneration({
appId,
planId,
llm_request: { model, system: systemPrompt, prompt: userPrompt },
llm_response: response,
});
}
},
});
const log: AppGenerationLog = {
appId,
planId,
llm_request: { model, system: systemPrompt, prompt: userPrompt },
llm_response: result,
};

if (process.env.SRCBOOK_DISABLE_ANALYTICS !== 'true') {
logAppGeneration(log);
}
return result.text;

return result.textStream;
}
110 changes: 110 additions & 0 deletions packages/api/ai/plan-parser.mts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { XMLParser } from 'fast-xml-parser';
import Path from 'node:path';
import { type App as DBAppType } from '../db/schema.mjs';
import { loadFile } from '../apps/disk.mjs';
import { StreamingXMLParser, TagType } from './stream-xml-parser.mjs';
import { ActionChunkType, DescriptionChunkType } from '@srcbook/shared';

// The ai proposes a plan that we expect to contain both files and commands
// Here is an example of a plan:
Expand Down Expand Up @@ -167,3 +169,111 @@ export function getPackagesToInstall(plan: Plan): string[] {
)
.flatMap((action) => action.packages);
}

export async function streamParsePlan(
stream: AsyncIterable<string>,
app: DBAppType,
_query: string,
planId: string,
) {
let parser: StreamingXMLParser;

return new ReadableStream({
async pull(controller) {
if (parser === undefined) {
parser = new StreamingXMLParser({
async onTag(tag) {
if (tag.name === 'planDescription' || tag.name === 'action') {
const chunk = await toStreamingChunk(app, tag, planId);
if (chunk) {
controller.enqueue(JSON.stringify(chunk) + '\n');
}
}
},
});
}

try {
for await (const chunk of stream) {
parser.parse(chunk);
}
controller.close();
} catch (error) {
console.error(error);
controller.enqueue(
JSON.stringify({
type: 'error',
data: { content: 'Error while parsing streaming response' },
}) + '\n',
);
controller.error(error);
}
},
});
}

async function toStreamingChunk(
app: DBAppType,
tag: TagType,
planId: string,
): Promise<DescriptionChunkType | ActionChunkType | null> {
switch (tag.name) {
case 'planDescription':
return {
type: 'description',
planId: planId,
data: { content: tag.content },
} as DescriptionChunkType;
case 'action': {
const descriptionTag = tag.children.find((t) => t.name === 'description');
const description = descriptionTag?.content ?? '';
const type = tag.attributes.type;

if (type === 'file') {
const fileTag = tag.children.find((t) => t.name === 'file')!;

const filePath = fileTag.attributes.filename as string;
let originalContent = null;

try {
const fileContent = await loadFile(app, filePath);
originalContent = fileContent.source;
} catch (error) {
// If the file doesn't exist, it's likely that it's a new file.
}

return {
type: 'action',
planId: planId,
data: {
type: 'file',
description,
path: filePath,
dirname: Path.dirname(filePath),
basename: Path.basename(filePath),
modified: fileTag.content,
original: originalContent,
},
} as ActionChunkType;
} else if (type === 'command') {
const commandTag = tag.children.find((t) => t.name === 'commandType')!;
const packageTags = tag.children.filter((t) => t.name === 'package');

return {
type: 'action',
planId: planId,
data: {
type: 'command',
description,
command: commandTag.content,
packages: packageTags.map((t) => t.content),
},
} as ActionChunkType;
} else {
return null;
}
}
default:
return null;
}
}
207 changes: 207 additions & 0 deletions packages/api/ai/stream-xml-parser.mts
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
export type NodeSchema = {
isContentNode?: boolean;
hasCdata?: boolean;
allowedChildren?: string[];
};

export const xmlSchema: Record<string, NodeSchema> = {
plan: { isContentNode: false, hasCdata: false },
action: { isContentNode: false, hasCdata: false },
description: { isContentNode: true, hasCdata: true },
file: { isContentNode: false, hasCdata: true },
commandType: { isContentNode: true, hasCdata: false },
package: { isContentNode: true, hasCdata: false },
planDescription: { isContentNode: true, hasCdata: true },
};

export type TagType = {
name: string;
attributes: Record<string, string>;
content: string;
children: TagType[];
};

export type TagCallbackType = (tag: TagType) => void;

export class StreamingXMLParser {
private buffer = '';
private currentTag: TagType | null = null;
private tagStack: TagType[] = [];
private isInCDATA = false;
private cdataBuffer = '';
private textBuffer = '';
private onTag: TagCallbackType;

constructor({ onTag }: { onTag: TagCallbackType }) {
this.onTag = onTag;
}

private parseAttributes(attributeString: string): Record<string, string> {
const attributes: Record<string, string> = {};
const matches = attributeString.match(/(\w+)="([^"]*?)"/g);

if (matches) {
matches.forEach((match) => {
const [key, value] = match.split('=') as [string, string];
attributes[key] = value.replace(/"/g, '');
});
}

return attributes;
}

private handleOpenTag(tagContent: string) {
// First, save any accumulated text content to the current tag
if (this.currentTag && this.textBuffer.trim()) {
this.currentTag.content = this.textBuffer.trim();
}
this.textBuffer = '';

const spaceIndex = tagContent.indexOf(' ');
const tagName = spaceIndex === -1 ? tagContent : tagContent.substring(0, spaceIndex);
const attributeString = spaceIndex === -1 ? '' : tagContent.substring(spaceIndex + 1);

const newTag: TagType = {
name: tagName,
attributes: this.parseAttributes(attributeString),
content: '',
children: [],
};

if (this.currentTag) {
// Push current tag to stack before moving to new tag
this.tagStack.push(this.currentTag);
this.currentTag.children.push(newTag);
}

this.currentTag = newTag;
}

private handleCloseTag(tagName: string) {
if (!this.currentTag) {
console.warn('Attempted to handle close tag with no current tag');
return;
}

// Save any remaining text content before closing
// Don't overwrite CDATA content, it's already been written
const schema = xmlSchema[this.currentTag.name];
const isCdataNode = schema ? schema.hasCdata : false;
if (!isCdataNode) {
this.currentTag.content = this.textBuffer.trim();
}
this.textBuffer = '';

if (this.currentTag.name !== tagName) {
return;
}

// Clean and emit the completed tag
this.currentTag = this.cleanNode(this.currentTag);
this.onTag(this.currentTag);

// Pop the parent tag from the stack
if (this.tagStack.length > 0) {
this.currentTag = this.tagStack.pop()!;
} else {
this.currentTag = null;
}
}

private cleanNode(node: TagType): TagType {
const schema = xmlSchema[node.name];

// If it's not in the schema, default to treating it as a content node
const isContentNode = schema ? schema.isContentNode : true;

// If it's not a content node and has children, remove its content
if (!isContentNode && node.children.length > 0) {
node.content = '';
}

// Recursively clean children
node.children = node.children.map((child) => this.cleanNode(child));

return node;
}

parse(chunk: string) {
this.buffer += chunk;

while (this.buffer.length > 0) {
// Handle CDATA sections
if (this.isInCDATA) {
const cdataEndIndex = this.cdataBuffer.indexOf(']]>');
if (cdataEndIndex === -1) {
this.cdataBuffer += this.buffer;
// Sometimes ]]> is in the next chunk, and we don't want to lose what's behind it
const nextCdataEnd = this.cdataBuffer.indexOf(']]>');
if (nextCdataEnd !== -1) {
this.buffer = this.cdataBuffer.substring(nextCdataEnd);
} else {
this.buffer = '';
}
return;
}

this.cdataBuffer = this.cdataBuffer.substring(0, cdataEndIndex);
if (this.currentTag) {
this.currentTag.content = this.cdataBuffer.trim();
}
this.isInCDATA = false;
this.buffer = this.cdataBuffer.substring(cdataEndIndex + 3) + this.buffer;
this.cdataBuffer = '';
continue;
}

// Look for the next tag
const openTagStartIdx = this.buffer.indexOf('<');
if (openTagStartIdx === -1) {
// No more tags in this chunk, save the rest as potential content
this.textBuffer += this.buffer;
this.buffer = '';
return;
}

// Save any text content before this tag
if (openTagStartIdx > 0) {
this.textBuffer += this.buffer.substring(0, openTagStartIdx);
this.buffer = this.buffer.substring(openTagStartIdx);
}

// Check for CDATA
if (this.sequenceExistsAt('<![CDATA[', 0)) {
this.isInCDATA = true;
const cdataStart = this.buffer.substring(9);
this.cdataBuffer = cdataStart;
this.buffer = '';
return;
}

const openTagEndIdx = this.buffer.indexOf('>');
if (openTagEndIdx === -1) {
return;
}

const tagContent = this.buffer.substring(1, openTagEndIdx);
this.buffer = this.buffer.substring(openTagEndIdx + 1);

if (tagContent.startsWith('/')) {
// Closing tag
this.handleCloseTag(tagContent.substring(1));
} else {
// Opening tag
this.handleOpenTag(tagContent);
}
}
}

private sequenceExistsAt(sequence: string, idx: number, buffer: string = this.buffer) {
for (let i = 0; i < sequence.length; i++) {
if (buffer[idx + i] !== sequence[i]) {
return false;
}
}
return true;
}
}
Loading