Skip to content

Commit

Permalink
feat(ui): allow removing individual images from batch
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious committed Nov 19, 2024
1 parent 3c43351 commit eb9a417
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 69 deletions.
12 changes: 6 additions & 6 deletions invokeai/frontend/web/src/features/dnd/dnd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ const _setNodeImageFieldImage = buildTypeAndKey('set-node-image-field-image');
export type SetNodeImageFieldImageDndTargetData = DndData<
typeof _setNodeImageFieldImage.type,
typeof _setNodeImageFieldImage.key,
{ fieldIdentifer: FieldIdentifier }
{ fieldIdentifier: FieldIdentifier }
>;
export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDndTargetData, SingleImageDndSourceData> =
{
Expand All @@ -236,8 +236,8 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
},
handler: ({ sourceData, targetData, dispatch }) => {
const { imageDTO } = sourceData.payload;
const { fieldIdentifer } = targetData.payload;
setNodeImageFieldImage({ fieldIdentifer, imageDTO, dispatch });
const { fieldIdentifier } = targetData.payload;
setNodeImageFieldImage({ fieldIdentifier, imageDTO, dispatch });
},
};
//#endregion
Expand All @@ -247,7 +247,7 @@ const _addImagesToNodeImageFieldCollection = buildTypeAndKey('add-images-to-imag
export type AddImagesToNodeImageFieldCollection = DndData<
typeof _addImagesToNodeImageFieldCollection.type,
typeof _addImagesToNodeImageFieldCollection.key,
{ fieldIdentifer: FieldIdentifier }
{ fieldIdentifier: FieldIdentifier }
>;
export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
AddImagesToNodeImageFieldCollection,
Expand All @@ -267,7 +267,7 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
return;
}

const { fieldIdentifer } = targetData.payload;
const { fieldIdentifier } = targetData.payload;
const imageDTOs: ImageDTO[] = [];

if (singleImageDndSource.typeGuard(sourceData)) {
Expand All @@ -276,7 +276,7 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
imageDTOs.push(...sourceData.payload.imageDTOs);
}

addImagesToNodeImageFieldCollectionAction({ fieldIdentifer, imageDTOs, dispatch, getState });
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
},
};
//#endregion
Expand Down
42 changes: 33 additions & 9 deletions invokeai/frontend/web/src/features/imageActions/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,35 +69,59 @@ export const setUpscaleInitialImage = (arg: { imageDTO: ImageDTO; dispatch: AppD

export const setNodeImageFieldImage = (arg: {
imageDTO: ImageDTO;
fieldIdentifer: FieldIdentifier;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
}) => {
const { imageDTO, fieldIdentifer, dispatch } = arg;
dispatch(fieldImageValueChanged({ ...fieldIdentifer, value: imageDTO }));
const { imageDTO, fieldIdentifier, dispatch } = arg;
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
};

export const addImagesToNodeImageFieldCollectionAction = (arg: {
imageDTOs: ImageDTO[];
fieldIdentifer: FieldIdentifier;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageDTOs, fieldIdentifer, dispatch, getState } = arg;
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifer.nodeId,
fieldIdentifer.fieldName
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);

if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifer }, 'Attempted to add images to a non-image field collection');
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}

const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
const uniqueImages = uniqBy(images, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifer, value: uniqueImages }));
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};

export const removeImageFromNodeImageFieldCollectionAction = (arg: {
imageName: string;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageName, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);

if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
return;
}

const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};

export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,62 +1,70 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
import { UploadMultipleImageButton } from 'common/hooks/useImageUploadButton';
import type { AddImagesToNodeImageFieldCollection } from 'features/dnd/dnd';
import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { PiArrowCounterClockwiseBold, PiExclamationMarkBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';

import type { FieldComponentProps } from './types';

const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 1,
},
} satisfies SystemStyleObject;

export const ImageFieldCollectionInputComponent = memo(
(props: FieldComponentProps<ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate>) => {
const { t } = useTranslation();
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const store = useAppStore();

const onReset = useCallback(() => {
dispatch(
fieldImageCollectionValueChanged({
nodeId,
fieldName: field.name,
value: [],
})
);
}, [dispatch, field.name, nodeId]);
const isInvalid = useFieldIsInvalid(nodeId, field.name);

const dndTargetData = useMemo<AddImagesToNodeImageFieldCollection>(
() => addImagesToNodeImageFieldCollectionDndTarget.getData({ fieldIdentifer: { nodeId, fieldName: field.name } }),
() =>
addImagesToNodeImageFieldCollectionDndTarget.getData({ fieldIdentifier: { nodeId, fieldName: field.name } }),
[field, nodeId]
);

const onUpload = useCallback(
(imageDTOs: ImageDTO[]) => {
dispatch(
store.dispatch(
fieldImageCollectionValueChanged({
nodeId,
fieldName: field.name,
value: imageDTOs,
})
);
},
[dispatch, field.name, nodeId]
[store, nodeId, field.name]
);

const onRemoveImage = useCallback(
(imageName: string) => {
removeImageFromNodeImageFieldCollectionAction({
imageName,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
},
[field.name, nodeId, store.dispatch, store.getState]
);

return (
Expand All @@ -80,33 +88,23 @@ export const ImageFieldCollectionInputComponent = memo(
/>
)}
{field.value && field.value.length > 0 && (
<>
<Grid
className="nopan"
borderRadius="base"
w="full"
h="full"
templateColumns={`repeat(${Math.min(field.value.length, 3)}, 1fr)`}
gap={1}
sx={sx}
data-error={isInvalid}
p={1}
>
{field.value.map(({ image_name }) => (
<GridItem key={image_name}>
<DndImageFromImageName imageName={image_name} asThumbnail />
</GridItem>
))}
</Grid>
<IconButton
aria-label="reset"
icon={<PiArrowCounterClockwiseBold />}
position="absolute"
top={0}
insetInlineEnd={0}
onClick={onReset}
/>
</>
<Grid
className="nopan"
borderRadius="base"
w="full"
h="full"
templateColumns="repeat(3, 1fr)"
gap={1}
sx={sx}
data-error={isInvalid}
p={1}
>
{field.value.map(({ image_name }) => (
<GridItem key={image_name} position="relative">
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
</GridItem>
))}
</Grid>
)}
<DndDropTarget
dndTarget={addImagesToNodeImageFieldCollectionDndTarget}
Expand All @@ -119,3 +117,37 @@ export const ImageFieldCollectionInputComponent = memo(
);

ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';

const ImageGridItemContent = memo(
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
const query = useGetImageDTOQuery(imageName);
const onClickRemove = useCallback(() => {
onRemoveImage(imageName);
}, [imageName, onRemoveImage]);

if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;
}

if (!query.data) {
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
}

return (
<>
<DndImage imageDTO={query.data} asThumbnail />
<DndImageIcon
onClick={onClickRemove}
icon={<PiArrowCounterClockwiseBold />}
tooltip="Reset Image"
position="absolute"
flexDir="column"
top={1}
insetInlineEnd={1}
gap={1}
/>
</>
);
}
);
ImageGridItemContent.displayName = 'ImageGridItemContent';
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInst
const dndTargetData = useMemo<SetNodeImageFieldImageDndTargetData>(
() =>
setNodeImageFieldImageDndTarget.getData(
{ fieldIdentifer: { nodeId, fieldName: field.name } },
{ fieldIdentifier: { nodeId, fieldName: field.name } },
field.value?.image_name
),
[field, nodeId]
Expand Down Expand Up @@ -85,13 +85,16 @@ const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInst
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} minW={8} minH={8} />
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<DndImageIcon
onClick={handleReset}
icon={imageDTO ? <PiArrowCounterClockwiseBold /> : undefined}
tooltip="Reset Image"
/>
</Flex>
<DndImageIcon
onClick={handleReset}
icon={imageDTO ? <PiArrowCounterClockwiseBold /> : undefined}
tooltip="Reset Image"
position="absolute"
flexDir="column"
top={1}
insetInlineEnd={1}
gap={1}
/>
</>
)}
<DndDropTarget
Expand Down

0 comments on commit eb9a417

Please sign in to comment.