From 2d08078a7d33e378959aef519b93fc33d939e0b5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 8 Nov 2024 10:17:46 +1000 Subject: [PATCH] fix(ui): fit bbox to layers math --- .../konva/CanvasTool/CanvasBboxToolModule.ts | 21 ++---- .../features/controlLayers/konva/util.test.ts | 74 ++++++++++++++++++- .../src/features/controlLayers/konva/util.ts | 28 +++++++ 3 files changed, 106 insertions(+), 17 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts index 933ce203533..0bd6bd063a7 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts @@ -1,13 +1,8 @@ -import { - roundDownToMultiple, - roundToMultiple, - roundToMultipleMin, - roundUpToMultiple, -} from 'common/util/roundDownToMultiple'; +import { roundToMultiple, roundToMultipleMin } from 'common/util/roundDownToMultiple'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule'; -import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util'; +import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util'; import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice'; import { selectBbox } from 'features/controlLayers/store/selectors'; import type { Coordinate, Rect } from 'features/controlLayers/store/types'; @@ -398,18 +393,12 @@ export class CanvasBboxToolModule extends CanvasModuleBase { } // Determine the bbox size that fits within the visible rect. The bbox must be at least 64px in width and height, - // and its width and height must be multiples of 8px. + // and its width and height must be multiples of the bbox grid size. const gridSize = this.manager.stateApi.getBboxGridSize(); - // To be conservative, we will round up the x and y to the nearest grid size, and round down the width and height. - // This ensures the bbox is never _larger_ than the visible rect. If the bbox is larger than the visible, we - // will always trigger the outpainting workflow, which is not what the user wants. - const x = roundUpToMultiple(visibleRect.x, gridSize); - const y = roundUpToMultiple(visibleRect.y, gridSize); - const width = roundDownToMultiple(visibleRect.width, gridSize); - const height = roundDownToMultiple(visibleRect.height, gridSize); + const rect = fitRectToGrid(visibleRect, gridSize); - this.manager.stateApi.setGenerationBbox({ x, y, width, height }); + this.manager.stateApi.setGenerationBbox(rect); }; /** diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/util.test.ts b/invokeai/frontend/web/src/features/controlLayers/konva/util.test.ts index a7db030ae4a..f3c44821a06 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/util.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/util.test.ts @@ -1,4 +1,6 @@ -import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util'; +import { roundUpToMultiple } from 'common/util/roundDownToMultiple'; +import { fitRectToGrid, getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util'; +import type { Rect } from 'features/controlLayers/store/types'; import { describe, expect, it } from 'vitest'; describe('util', () => { @@ -44,4 +46,74 @@ describe('util', () => { expect(union).toEqual({ x: 0, y: 0, width: 0, height: 0 }); }); }); + + describe('fitRectToGrid', () => { + it('should fit rect within grid without exceeding bounds', () => { + const rect: Rect = { x: 0, y: 0, width: 1047, height: 1758 }; + const gridSize = 50; + const result = fitRectToGrid(rect, gridSize); + + expect(result.x).toBe(roundUpToMultiple(rect.x, gridSize)); + expect(result.y).toBe(roundUpToMultiple(rect.y, gridSize)); + expect(result.width).toBeLessThanOrEqual(rect.width); + expect(result.height).toBeLessThanOrEqual(rect.height); + expect(result.width % gridSize).toBe(0); + expect(result.height % gridSize).toBe(0); + }); + + it('should handle small rect within grid bounds', () => { + const rect: Rect = { x: 20, y: 30, width: 80, height: 90 }; + const gridSize = 25; + const result = fitRectToGrid(rect, gridSize); + + expect(result.x).toBe(25); + expect(result.y).toBe(50); + expect(result.width % gridSize).toBe(0); + expect(result.height % gridSize).toBe(0); + expect(result.width).toBeLessThanOrEqual(rect.width); + expect(result.height).toBeLessThanOrEqual(rect.height); + }); + + it('should handle rect starting outside of grid alignment', () => { + const rect: Rect = { x: 13, y: 27, width: 94, height: 112 }; + const gridSize = 20; + const result = fitRectToGrid(rect, gridSize); + + expect(result.x).toBe(20); + expect(result.y).toBe(40); + expect(result.width % gridSize).toBe(0); + expect(result.height % gridSize).toBe(0); + expect(result.width).toBeLessThanOrEqual(rect.width); + expect(result.height).toBeLessThanOrEqual(rect.height); + }); + + it('should return the same rect if already aligned to grid', () => { + const rect: Rect = { x: 100, y: 100, width: 200, height: 300 }; + const gridSize = 50; + const result = fitRectToGrid(rect, gridSize); + + expect(result).toEqual(rect); + }); + + it('should handle large grid sizes relative to rect dimensions', () => { + const rect: Rect = { x: 250, y: 300, width: 400, height: 500 }; + const gridSize = 100; + const result = fitRectToGrid(rect, gridSize); + + expect(result.x).toBe(300); + expect(result.y).toBe(300); + expect(result.width % gridSize).toBe(0); + expect(result.height % gridSize).toBe(0); + expect(result.width).toBeLessThanOrEqual(rect.width); + expect(result.height).toBeLessThanOrEqual(rect.height); + }); + + it('should handle rect with zero width and height', () => { + const rect: Rect = { x: 40, y: 60, width: 100, height: 200 }; + const gridSize = 20; + const result = fitRectToGrid(rect, gridSize); + + expect(result).toEqual({ x: 40, y: 60, width: 100, height: 200 }); + }); + }); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts index 38ce6389cc2..a2d28b5ac0e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts @@ -1,5 +1,6 @@ import type { Selector, Store } from '@reduxjs/toolkit'; import { $authToken } from 'app/store/nanostores/authToken'; +import { roundDownToMultiple, roundUpToMultiple } from 'common/util/roundDownToMultiple'; import type { CanvasEntityIdentifier, CanvasObjectState, @@ -560,6 +561,33 @@ export const getRectIntersection = (...rects: Rect[]): Rect => { return rect || getEmptyRect(); }; +/** + * Fits a rect to the nearest multiple of the grid size, rounding down. The returned rect will be smaller than or equal + * to the input rect, and will be aligned to the grid. + * + * In other words, shrink the rect inwards on each size until it fits within the visible rect and aligns to the grid. + * + * @param rect The rect to fit + * @param gridSize The size of the grid + * @returns The fitted rect + */ +export const fitRectToGrid = (rect: Rect, gridSize: number): Rect => { + // Rounding x and y up effectively shrinks the left and top edges of the rect, and rounding width and height down + // effectively shrinks the right and bottom edges. + const x = roundUpToMultiple(rect.x, gridSize); + const y = roundUpToMultiple(rect.y, gridSize); + + // Because we've just shifted the rect's x and y, we need to adjust the width and height by the same amount before + // we round those values down. + const offsetX = x - rect.x; + const offsetY = y - rect.y; + + const width = roundDownToMultiple(rect.width - offsetX, gridSize); + const height = roundDownToMultiple(rect.height - offsetY, gridSize); + + return { x, y, width, height }; +}; + /** * Asserts that the value is never reached. Used for exhaustive checks in switch statements or conditional logic to ensure * that all possible values are handled.