Skip to content

Commit

Permalink
Beam: state cleanup and sync
Browse files Browse the repository at this point in the history
  • Loading branch information
enricoros committed Mar 11, 2024
1 parent 9a0fda8 commit 5c22061
Showing 3 changed files with 56 additions and 69 deletions.
28 changes: 5 additions & 23 deletions src/common/beam/BeamGatherControls.tsx
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ const beamGatherControlsSx: SxProps = {
// layout
display: 'flex',
alignItems: 'center',
gap: 'var(--Pad_2)',
};

export function BeamGatherControls(props: {
@@ -31,31 +32,9 @@ export function BeamGatherControls(props: {
return (
<Box sx={beamGatherControlsSx}>

{/* Title */}
<Box sx={{ display: 'flex', gap: 'var(--Pad_2)', my: 'auto', border: '2px solid red' }}>
{/*<Typography level='h4'>*/}
{/* <ChatBeamIcon sx={{ animation: `${animationColorDarkerRainbow} 2s linear 2.66` }} />*/}
{/*</Typography>*/}
<div>
<Typography level='h4' component='h2'>
{/*big-AGI · */}
Gather
</Typography>

<Typography level='body-sm'>
Test
</Typography>
</div>
</Box>

<Box sx={{ whiteSpace: 'break-spaces', border: '2px solid red' }}>
{JSON.stringify(props)}
</Box>

{/* Algo */}
<FormControl sx={{ flex: 1, display: 'flex', justifyContent: 'space-between' /* gridColumn: '1 / -1' */ }}>
<FormControl>
{!props.isMobile && <FormLabelStart title='Beam Fusion' />}

<ButtonGroup variant='soft' color='success'>
<Button
sx={{
@@ -89,6 +68,9 @@ export function BeamGatherControls(props: {
</ButtonGroup>
</FormControl>

<Typography sx={{ flex: 1, whiteSpace: 'break-spaces', border: '1px solid red' }}>
{JSON.stringify(props)}
</Typography>

<Button variant='solid' color='neutral' onClick={props.onClose} sx={{ ml: 'auto', minWidth: 100 }}>
Close
6 changes: 2 additions & 4 deletions src/common/beam/BeamRay.tsx
Original file line number Diff line number Diff line change
@@ -145,13 +145,11 @@ export function BeamRay(props: {
const isScattering = rayIsScattering(ray);
const isSelectable = rayIsSelectable(ray);
const isSelected = rayIsUserSelected(ray);
const { removeRay, updateRay, toggleScattering /*, toggleUserSelection*/ } = props.beamStore.getState();
const { removeRay, toggleScattering, setRayLlmId } = props.beamStore.getState();

const isLlmLinked = !!props.gatherLlmId && !ray?.scatterLlmId;
const llmId: DLLMId | null = isLlmLinked ? props.gatherLlmId : ray?.scatterLlmId || null;
const setLlmId = React.useCallback((llmId: DLLMId | null) => {
updateRay(props.rayId, { scatterLlmId: llmId });
}, [props.rayId, updateRay]);
const setLlmId = React.useCallback((llmId: DLLMId | null) => setRayLlmId(props.rayId, llmId), [props.rayId, setRayLlmId]);
const handleLlmLink = React.useCallback(() => setLlmId(null), [setLlmId]);
const [_, llmComponent] = useLLMSelect(llmId, setLlmId, '', true, isScattering);

91 changes: 49 additions & 42 deletions src/common/beam/store-beam.ts
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ function rayScatterStart(ray: DRay, onlyIdle: boolean, beamStore: BeamStore): DR
if (onlyIdle && ray.status !== 'empty')
return ray;

const { gatherLlmId, inputHistory, rays, updateRay, syncRaysStateToBeam } = beamStore;
const { gatherLlmId, inputHistory, rays, _updateRay, syncRaysStateToBeam } = beamStore;

// validate model
const rayLlmId = ray.scatterLlmId || gatherLlmId;
@@ -57,7 +57,7 @@ function rayScatterStart(ray: DRay, onlyIdle: boolean, beamStore: BeamStore): DR

const abortController = new AbortController();

const updateMessage = (update: Partial<DMessage>) => updateRay(ray.rayId, (ray) => ({
const updateMessage = (update: Partial<DMessage>) => _updateRay(ray.rayId, (ray) => ({
...ray,
message: {
...ray.message,
@@ -69,13 +69,13 @@ function rayScatterStart(ray: DRay, onlyIdle: boolean, beamStore: BeamStore): DR
// stream the assistant's messages
streamAssistantMessage(rayLlmId, inputHistory, rays.length, 'off', updateMessage, abortController.signal)
.then((outcome) => {
updateRay(ray.rayId, {
_updateRay(ray.rayId, {
status: (outcome === 'success') ? 'success' : (outcome === 'aborted') ? 'stopped' : (outcome === 'errored') ? 'error' : 'empty',
genAbortController: undefined,
});
})
.catch((error) => {
updateRay(ray.rayId, {
_updateRay(ray.rayId, {
status: 'error',
scatterIssue: error?.message || error?.toString() || 'Unknown error',
genAbortController: undefined,
@@ -154,13 +154,14 @@ interface BeamStore extends BeamState {

setGatherLlmId: (llmId: DLLMId | null) => void;
setRayCount: (count: number) => void;
removeRay: (rayId: DRayId) => void;

startScatteringAll: () => void;
stopScatteringAll: () => void;
toggleScattering: (rayId: DRayId) => void;
toggleUserSelection: (rayId: DRayId) => void;
removeRay: (rayId: DRayId) => void;
updateRay: (rayId: DRayId, update: Partial<DRay> | ((ray: DRay) => Partial<DRay>)) => void;
setRayLlmId: (rayId: DRayId, llmId: DLLMId | null) => void;
_updateRay: (rayId: DRayId, update: Partial<DRay> | ((ray: DRay) => Partial<DRay>)) => void;

syncRaysStateToBeam: () => void;

@@ -237,7 +238,7 @@ export const createBeamStore = () => createStore<BeamStore>()(
}),

setRayCount: (count: number) => {
const { rays } = _get();
const { rays, syncRaysStateToBeam } = _get();
if (count < rays.length) {
rays.slice(count).forEach(rayScatterStop);
_set({
@@ -248,20 +249,31 @@ export const createBeamStore = () => createStore<BeamStore>()(
rays: [...rays, ...Array(count - rays.length).fill(null).map(() => createDRay(_get().gatherLlmId))],
});
}
syncRaysStateToBeam();
},

removeRay: (rayId: DRayId) => {
const { syncRaysStateToBeam } = _get();
_set((state) => ({
rays: state.rays.filter((ray) => {
if (ray.rayId === rayId) {
rayScatterStop(ray);
return false;
}
return true;
}),
}));
syncRaysStateToBeam();
},


startScatteringAll: () => {
const { readyScatter, isScattering, inputHistory, rays } = _get();
if (!readyScatter) {
console.warn('startScattering: not ready', { isScattering, readyScatter, inputHistory });
return;
}
const newRays = rays.map(ray => rayScatterStart(ray, false, _get()));
const { rays, syncRaysStateToBeam } = _get();
_set({
isScattering: newRays.some((ray) => ray.status === 'scattering'),
rays: newRays
rays: rays.map(ray => rayScatterStart(ray, false, _get())),
});
// always need to invoke syncRaysStateToBeam after rayScatterStart
syncRaysStateToBeam();
},

stopScatteringAll: () => {
@@ -273,16 +285,15 @@ export const createBeamStore = () => createStore<BeamStore>()(
},

toggleScattering: (rayId: DRayId) => {
const store = _get();
const newRays = store.rays.map((ray) => (ray.rayId === rayId)
? (ray.status === 'scattering' ? rayScatterStop(ray) : rayScatterStart(ray, false, _get()))
: ray,
);
const anyStarted = newRays.some((ray) => ray.status === 'scattering');
const { rays, syncRaysStateToBeam } = _get();
_set({
isScattering: anyStarted,
rays: newRays,
rays: rays.map((ray) => (ray.rayId === rayId)
? (ray.status === 'scattering' ? rayScatterStop(ray) : rayScatterStart(ray, false, _get()))
: ray,
),
});
// always need to invoke syncRaysStateToBeam after rayScatterStart
syncRaysStateToBeam();
},

toggleUserSelection: (rayId: DRayId) => _set((state) => ({
@@ -292,17 +303,15 @@ export const createBeamStore = () => createStore<BeamStore>()(
),
})),

removeRay: (rayId: DRayId) => _set((state) => ({
rays: state.rays.filter((ray) => {
if (ray.rayId === rayId) {
rayScatterStop(ray);
return false;
}
return true;
}),
setRayLlmId: (rayId: DRayId, llmId: DLLMId | null) => _set((state) => ({
rays: state.rays.map((ray) => (ray.rayId === rayId)
? { ...ray, scatterLlmId: llmId }
: ray,
),
})),

updateRay: (rayId: DRayId, update: Partial<DRay> | ((ray: DRay) => Partial<DRay>)) => _set((state) => ({

_updateRay: (rayId: DRayId, update: Partial<DRay> | ((ray: DRay) => Partial<DRay>)) => _set((state) => ({
rays: state.rays.map((ray) => (ray.rayId === rayId)
? { ...ray, ...(typeof update === 'function' ? update(ray) : update) }
: ray,
@@ -314,17 +323,15 @@ export const createBeamStore = () => createStore<BeamStore>()(
const { rays } = _get();

// Check if all rays have finished generating
const allDone = rays.every(ray => ray.status !== 'scattering');
const hasRays = rays.length > 0;
const allDone = !rays.some(rayIsScattering);
const raysReady = rays.filter(rayIsSelectable).length;

if (allDone) {
// If all rays are done, update state accordingly
_set({
isScattering: false,
// Update other state properties as needed
});
// TODO... continue
console.log('All rays have finished generating - TODO: ');
}
console.log('syncRaysStateToBeam', { count: rays.length, isScattering: hasRays && !allDone, allDone, raysReady });

_set({
isScattering: hasRays && !allDone,
});
},

}),

0 comments on commit 5c22061

Please sign in to comment.