diff --git a/src/mflux/config/config.py b/src/mflux/config/config.py index b42dbc3..99495d2 100644 --- a/src/mflux/config/config.py +++ b/src/mflux/config/config.py @@ -17,6 +17,7 @@ def __init__( guidance: float = 4.0, image_path: Path | None = None, image_strength: float | None = None, + masked_image_path: Path | None = None, controlnet_strength: float | None = None, ): if width % 16 != 0 or height % 16 != 0: @@ -27,4 +28,5 @@ def __init__( self.guidance = guidance self.image_path = image_path self.image_strength = image_strength + self.masked_image_path = masked_image_path self.controlnet_strength = controlnet_strength diff --git a/src/mflux/config/model_config.py b/src/mflux/config/model_config.py index 8daa6f3..7d4b061 100644 --- a/src/mflux/config/model_config.py +++ b/src/mflux/config/model_config.py @@ -13,6 +13,8 @@ def __init__( num_train_steps: int, max_sequence_length: int, supports_guidance: bool, + requires_sigma_shift: bool, + priority: int, ): self.alias = alias self.model_name = model_name @@ -20,65 +22,57 @@ def __init__( self.num_train_steps = num_train_steps self.max_sequence_length = max_sequence_length self.supports_guidance = supports_guidance + self.requires_sigma_shift = requires_sigma_shift + self.priority = priority @staticmethod @lru_cache def dev() -> "ModelConfig": - return ModelConfig( - alias="dev", - model_name="black-forest-labs/FLUX.1-dev", - base_model=None, - num_train_steps=1000, - max_sequence_length=512, - supports_guidance=True, - ) + return AVAILABLE_MODELS["dev"] + + @staticmethod + @lru_cache + def dev_fill() -> "ModelConfig": + return AVAILABLE_MODELS["dev-fill"] @staticmethod @lru_cache def schnell() -> "ModelConfig": - return ModelConfig( - alias="schnell", - model_name="black-forest-labs/FLUX.1-schnell", - base_model=None, - num_train_steps=1000, - max_sequence_length=256, - supports_guidance=False, - ) + return AVAILABLE_MODELS["schnell"] @staticmethod def from_name( model_name: str, - base_model: Literal["dev", 
"schnell"] | None = None, + base_model: Literal["dev", "schnell", "dev-fill"] | None = None, ) -> "ModelConfig": - dev = ModelConfig.dev() - schnell = ModelConfig.schnell() + # 0. Get all base models (where base_model is None) sorted by priority + base_models = sorted( + [model for model in AVAILABLE_MODELS.values() if model.base_model is None], key=lambda x: x.priority + ) + + # 1. Check if model_name matches any base model's alias or full name + for base in base_models: + if model_name in (base.alias, base.model_name): + return base - # 0. Validate explicit base_model - allowed_names = [dev.alias, dev.model_name, schnell.alias, schnell.model_name] + # 2. Validate explicit base_model + allowed_names = [] + for base in base_models: + allowed_names.extend([base.alias, base.model_name]) if base_model and base_model not in allowed_names: raise InvalidBaseModel(f"Invalid base_model. Choose one of {allowed_names}") - # 1. If model_name is "dev" or "schnell" then simply return - if model_name == dev.model_name or model_name == dev.alias: - return dev - if model_name == schnell.model_name or model_name == schnell.alias: - return schnell + # 3. Determine the base model (explicit or inferred) + if base_model: + # Find by explicit base_model name + default_base = next((b for b in base_models if base_model in (b.alias, b.model_name)), None) + else: + # Infer from model_name substring (priority order via sorted base_models) + default_base = next((b for b in base_models if b.alias and b.alias in model_name), None) + if not default_base: + raise ModelConfigError(f"Cannot infer base_model from {model_name}") - # 1. Determine the appropriate base model - default_base = None - if not base_model: - if "dev" in model_name: - default_base = dev - elif "schnell" in model_name: - default_base = schnell - else: - raise ModelConfigError(f"Cannot infer base_model from {model_name}. 
Specify --base-model.") - elif base_model == dev.model_name or base_model == dev.alias: - default_base = dev - elif base_model == schnell.model_name or base_model == schnell.alias: - default_base = schnell - - # 2. Construct the config based on the model name and base default + # 4. Construct the config return ModelConfig( alias=default_base.alias, model_name=model_name, @@ -86,7 +80,40 @@ def from_name( num_train_steps=default_base.num_train_steps, max_sequence_length=default_base.max_sequence_length, supports_guidance=default_base.supports_guidance, + requires_sigma_shift=default_base.requires_sigma_shift, + priority=default_base.priority, ) - def is_dev(self) -> bool: - return self.alias == "dev" + +AVAILABLE_MODELS = { + "schnell": ModelConfig( + alias="schnell", + model_name="black-forest-labs/FLUX.1-schnell", + base_model=None, + num_train_steps=1000, + max_sequence_length=256, + supports_guidance=False, + requires_sigma_shift=False, + priority=2, + ), + "dev": ModelConfig( + alias="dev", + model_name="black-forest-labs/FLUX.1-dev", + base_model=None, + num_train_steps=1000, + max_sequence_length=512, + supports_guidance=True, + requires_sigma_shift=True, + priority=1, + ), + "dev-fill": ModelConfig( + alias="dev-fill", + model_name="black-forest-labs/FLUX.1-Fill-dev", + base_model=None, + num_train_steps=1000, + max_sequence_length=512, + supports_guidance=True, + requires_sigma_shift=True, + priority=0, + ), +} diff --git a/src/mflux/config/runtime_config.py b/src/mflux/config/runtime_config.py index c346b5e..d55073c 100644 --- a/src/mflux/config/runtime_config.py +++ b/src/mflux/config/runtime_config.py @@ -32,7 +32,7 @@ def width(self, value): self.config.width = value @property - def guidance(self) -> float: + def guidance(self) -> float | None: return self.config.guidance @property @@ -55,6 +55,10 @@ def image_path(self) -> str: def image_strength(self) -> float | None: return self.config.image_strength + @property + def masked_image_path(self) -> str | 
None: + return self.config.masked_image_path + @property def init_time_step(self) -> int: is_img2img = ( @@ -82,7 +86,7 @@ def controlnet_strength(self) -> float | None: @staticmethod def _create_sigmas(config: Config, model_config: ModelConfig) -> mx.array: sigmas = RuntimeConfig._create_sigmas_values(config.num_inference_steps) - if model_config.is_dev(): + if model_config.requires_sigma_shift: sigmas = RuntimeConfig._shift_sigmas(sigmas=sigmas, width=config.width, height=config.height) return sigmas diff --git a/src/mflux/flux/flux.py b/src/mflux/flux/flux.py index bf7f756..836fc0b 100644 --- a/src/mflux/flux/flux.py +++ b/src/mflux/flux/flux.py @@ -16,6 +16,7 @@ from mflux.post_processing.array_util import ArrayUtil from mflux.post_processing.generated_image import GeneratedImage from mflux.post_processing.image_util import ImageUtil +from mflux.post_processing.mask_util import MaskUtil from mflux.weights.model_saver import ModelSaver @@ -76,6 +77,15 @@ def generate_image( clip_text_encoder=self.clip_text_encoder, ) + # 3. 
Create the static masked latents + static_masked_latents = MaskUtil.create_masked_latents( + vae=self.vae, + config=config, + latents=latents, + img_path=config.image_path, + mask_path=config.masked_image_path, + ) + # (Optional) Call subscribers for beginning of loop Callbacks.before_loop( seed=seed, @@ -86,16 +96,19 @@ def generate_image( for t in time_steps: try: - # 3.t Predict the noise + # 4.t Concatenate the updated latents with the static masked latents + hidden_states = mx.concatenate([latents, static_masked_latents], axis=-1) + + # 5.t Predict the noise noise = self.transformer( t=t, config=config, - hidden_states=latents, + hidden_states=hidden_states, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, ) - # 4.t Take one denoise step + # 6.t Take one denoise step dt = config.sigmas[t + 1] - config.sigmas[t] latents += noise * dt diff --git a/src/mflux/generate.py b/src/mflux/generate.py index 49508da..667f9db 100644 --- a/src/mflux/generate.py +++ b/src/mflux/generate.py @@ -51,6 +51,7 @@ def main(): width=args.width, guidance=args.guidance, image_path=args.image_path, + masked_image_path=args.masked_image_path, image_strength=args.image_strength, ), ) diff --git a/src/mflux/post_processing/array_util.py b/src/mflux/post_processing/array_util.py index 9fe82b8..c0dccea 100644 --- a/src/mflux/post_processing/array_util.py +++ b/src/mflux/post_processing/array_util.py @@ -10,8 +10,8 @@ def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: return latents @staticmethod - def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: - latents = mx.reshape(latents, (1, 16, height // 16, 2, width // 16, 2)) + def pack_latents(latents: mx.array, height: int, width: int, num_channels_latents: int = 16) -> mx.array: + latents = mx.reshape(latents, (1, num_channels_latents, height // 16, 2, width // 16, 2)) latents = mx.transpose(latents, (0, 2, 4, 1, 3, 5)) - latents = mx.reshape(latents, (1, (width // 16) * 
(height // 16), 64)) + latents = mx.reshape(latents, (1, (width // 16) * (height // 16), num_channels_latents * 4)) return latents diff --git a/src/mflux/post_processing/image_util.py b/src/mflux/post_processing/image_util.py index 563d9fa..4a4895e 100644 --- a/src/mflux/post_processing/image_util.py +++ b/src/mflux/post_processing/image_util.py @@ -70,6 +70,10 @@ def _denormalize(images: mx.array) -> mx.array: def _normalize(images: mx.array) -> mx.array: return 2.0 * images - 1.0 + @staticmethod + def _binarize(image: mx.array) -> mx.array: + return mx.where(image < 0.5, mx.zeros_like(image), mx.ones_like(image)) + @staticmethod def _to_numpy(images: mx.array) -> np.ndarray: images = mx.transpose(images, (0, 2, 3, 1)) @@ -90,11 +94,14 @@ def _pil_to_numpy(image: PIL.Image.Image) -> np.ndarray: return images @staticmethod - def to_array(image: PIL.Image.Image) -> mx.array: + def to_array(image: PIL.Image.Image, is_mask: bool = False) -> mx.array: image = ImageUtil._pil_to_numpy(image) array = mx.array(image) array = mx.transpose(array, (0, 3, 1, 2)) - array = ImageUtil._normalize(array) + if is_mask: + array = ImageUtil._binarize(array) + else: + array = ImageUtil._normalize(array) return array @staticmethod diff --git a/src/mflux/post_processing/mask_util.py b/src/mflux/post_processing/mask_util.py new file mode 100644 index 0000000..692a8e8 --- /dev/null +++ b/src/mflux/post_processing/mask_util.py @@ -0,0 +1,54 @@ +import mlx.core as mx + +from mflux.config.runtime_config import RuntimeConfig +from mflux.models.vae.vae import VAE +from mflux.post_processing.array_util import ArrayUtil + + +class MaskUtil: + @staticmethod + def create_masked_latents( + vae: VAE, + config: RuntimeConfig, + latents: mx.array, + img_path: str, + mask_path: str | None + ) -> mx.array: # fmt: off + from mflux import ImageUtil + + # 1. 
Get the reference image + scaled_image = ImageUtil.scale_to_dimensions( + image=ImageUtil.load_image(config.image_path).convert("RGB"), + target_width=config.width, + target_height=config.height, + ) + image = ImageUtil.to_array(scaled_image) + + # 2. Get the mask + scaled = ImageUtil.scale_to_dimensions( + image=ImageUtil.load_image(mask_path).convert("RGB"), + target_width=config.width, + target_height=config.height, + ) + the_mask = ImageUtil.to_array(scaled, is_mask=True) + + # 3. Create and pack the masked image + masked_image = image * (1 - the_mask) + masked_image = vae.encode(masked_image) + masked_image = ArrayUtil.pack_latents(latents=masked_image, height=config.height, width=config.width) + + # 4. Resize mask and pack latents + mask = MaskUtil._reshape_mask(the_mask=the_mask, height=config.height, width=config.width) + mask = ArrayUtil.pack_latents(latents=mask, height=config.height, width=config.width, num_channels_latents=64) + + # 5. Concat the masked_image and the mask + masked_image_latents = mx.concatenate([masked_image, mask], axis=-1) + return masked_image_latents + + @staticmethod + def _reshape_mask(the_mask: mx.array, height: int, width: int): + mask = the_mask[:, 0, :, :] + mask = mx.reshape(mask, (1, height // 8, 8, width // 8, 8)) + mask = mx.transpose(mask, (0, 2, 4, 1, 3)) + mask = mx.reshape(mask, (1, 64, height // 8, width // 8)) + return mask diff --git a/src/mflux/ui/cli/parsers.py b/src/mflux/ui/cli/parsers.py index ad898df..face9e1 100644 --- a/src/mflux/ui/cli/parsers.py +++ b/src/mflux/ui/cli/parsers.py @@ -11,13 +11,14 @@ class ModelSpecAction(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): - if values in ["dev", "schnell"]: + if values in ui_defaults.MODEL_CHOICES: setattr(namespace, self.dest, values) return if values.count("/") != 1: raise argparse.ArgumentError( - self, f'Value must be either "dev", "schnell", or "in format "org/model". 
Got: {values}' + self, + (f'Value must be either {" ".join(ui_defaults.MODEL_CHOICES)} or in format "org/model". Got: {values}'), ) # If we got here, values contains exactly one slash @@ -77,6 +78,7 @@ def add_image_to_image_arguments(self, required=False) -> None: self.supports_image_to_image = True self.add_argument("--image-path", type=Path, required=required, default=None, help="Local path to init image") self.add_argument("--image-strength", type=float, required=False, default=ui_defaults.IMAGE_STRENGTH, help=f"Controls how strongly the init image influences the output image. A value of 0.0 means no influence. (Default is {ui_defaults.IMAGE_STRENGTH})") + self.add_argument("--masked-image-path", type=Path, required=False, default=None, help="Local path to separate masked image as complement to --image-path") def add_batch_image_generator_arguments(self) -> None: self.add_argument("--prompts-file", type=Path, required=True, default=argparse.SUPPRESS, help="Local path for a file that holds a batch of prompts.") diff --git a/src/mflux/ui/defaults.py b/src/mflux/ui/defaults.py index 1c12e28..c1bd6f3 100644 --- a/src/mflux/ui/defaults.py +++ b/src/mflux/ui/defaults.py @@ -2,7 +2,7 @@ GUIDANCE_SCALE = 3.5 HEIGHT, WIDTH = 1024, 1024 IMAGE_STRENGTH = 0.4 -MODEL_CHOICES = ["dev", "schnell"] +MODEL_CHOICES = ["dev", "dev-fill", "schnell"] MODEL_INFERENCE_STEPS = { "dev": 14, "schnell": 4, diff --git a/tests/model_config/test_model_config.py b/tests/model_config/test_model_config.py index 99cf540..e18a8bb 100644 --- a/tests/model_config/test_model_config.py +++ b/tests/model_config/test_model_config.py @@ -5,65 +5,134 @@ def test_bfl_dev(): - model_attrs = ModelConfig.from_name("dev") - assert model_attrs.model_name.startswith("black-forest-labs/") - assert model_attrs.max_sequence_length == 512 - assert model_attrs.supports_guidance is True + model = ModelConfig.from_name("dev") + assert model.alias == "dev" + assert model.model_name.startswith("black-forest-labs/") + 
assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.supports_guidance is True + assert model.requires_sigma_shift is True def test_bfl_dev_full_name(): - model_attrs = ModelConfig.from_name("black-forest-labs/FLUX.1-dev") - assert model_attrs.model_name.startswith("black-forest-labs/") - assert model_attrs.max_sequence_length == 512 - assert model_attrs.supports_guidance is True + model = ModelConfig.from_name("black-forest-labs/FLUX.1-dev") + assert model.alias == "dev" + assert model.model_name.startswith("black-forest-labs/") + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.supports_guidance is True + assert model.requires_sigma_shift is True def test_bfl_schnell(): - model_attrs = ModelConfig.from_name("schnell") - assert model_attrs.model_name.startswith("black-forest-labs/") - assert model_attrs.max_sequence_length == 256 - assert model_attrs.supports_guidance is False + model = ModelConfig.from_name("schnell") + assert model.alias == "schnell" + assert model.model_name.startswith("black-forest-labs/") + assert model.max_sequence_length == 256 + assert model.num_train_steps == 1000 + assert model.supports_guidance is False + assert model.requires_sigma_shift is False def test_bfl_schnell_full_name(): - model_attrs = ModelConfig.from_name("black-forest-labs/FLUX.1-schnell") - assert model_attrs.model_name.startswith("black-forest-labs/") - assert model_attrs.max_sequence_length == 256 - assert model_attrs.supports_guidance is False + model = ModelConfig.from_name("black-forest-labs/FLUX.1-schnell") + assert model.alias == "schnell" + assert model.model_name.startswith("black-forest-labs/") + assert model.max_sequence_length == 256 + assert model.num_train_steps == 1000 + assert model.supports_guidance is False + assert model.requires_sigma_shift is False + + +def test_bfl_dev_fill(): + model = ModelConfig.from_name("dev-fill") + assert model.alias == "dev-fill" + assert 
model.model_name == "black-forest-labs/FLUX.1-Fill-dev" + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.supports_guidance is True + assert model.requires_sigma_shift is True + + +def test_bfl_dev_fill_full_name(): + model = ModelConfig.from_name("black-forest-labs/FLUX.1-Fill-dev") + assert model.alias == "dev-fill" + assert model.model_name == "black-forest-labs/FLUX.1-Fill-dev" + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.supports_guidance is True + assert model.requires_sigma_shift is True + + +def test_community_dev_fill_implicit_base_model(): + model = ModelConfig.from_name("acme-lab/some-dev-fill-model") + assert model.alias == "dev-fill" + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.supports_guidance is True + assert model.requires_sigma_shift is True + + +def test_community_dev_fill_explicit_base_model(): + model = ModelConfig.from_name("acme-lab/some-model", base_model="dev-fill") + assert model.alias == "dev-fill" + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.supports_guidance is True + assert model.requires_sigma_shift is True + + +def test_implicit_base_model_prefers_dev_fill_over_dev(): + model = ModelConfig.from_name("acme-lab/dev-fill-based-model") + assert model.alias == "dev-fill" + assert model.base_model == "black-forest-labs/FLUX.1-Fill-dev" + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.requires_sigma_shift is True def test_community_dev_implicit_base_model(): - model_attrs = ModelConfig.from_name("acme-lab/some-awesome-dev-model") - assert model_attrs.max_sequence_length == 512 - assert model_attrs.supports_guidance is True + model = ModelConfig.from_name("acme-lab/some-awesome-dev-model") + assert model.alias == "dev" + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + 
assert model.supports_guidance is True + assert model.requires_sigma_shift is True def test_community_schnell_implicit_base_model(): - model_attrs = ModelConfig.from_name("acme-lab/some-quick-schnell-model") - assert model_attrs.max_sequence_length == 256 - assert model_attrs.supports_guidance is False + model = ModelConfig.from_name("acme-lab/some-quick-schnell-model") + assert model.alias == "schnell" + assert model.max_sequence_length == 256 + assert model.num_train_steps == 1000 + assert model.supports_guidance is False + assert model.requires_sigma_shift is False def test_community_dev_explicit_base_model(): - model_attrs = ModelConfig.from_name("acme-lab/some-awesome-model", base_model="dev") - assert model_attrs.max_sequence_length == 512 - assert model_attrs.supports_guidance is True + model = ModelConfig.from_name("acme-lab/some-awesome-model", base_model="dev") + assert model.alias == "dev" + assert model.base_model == "black-forest-labs/FLUX.1-dev" + assert model.max_sequence_length == 512 + assert model.num_train_steps == 1000 + assert model.supports_guidance is True + assert model.requires_sigma_shift is True def test_community_schnell_explicit_base_model(): - model_attrs = ModelConfig.from_name("acme-lab/some-awesome-model", base_model="schnell") - assert model_attrs.max_sequence_length == 256 - assert model_attrs.supports_guidance is False + model = ModelConfig.from_name("acme-lab/some-awesome-model", base_model="schnell") + assert model.base_model == "black-forest-labs/FLUX.1-schnell" + assert model.max_sequence_length == 256 + assert model.num_train_steps == 1000 + assert model.supports_guidance is False + assert model.requires_sigma_shift is False def test_model_config_error(): - assert pytest.raises(ModelConfigError, ModelConfig.from_name, "acme-lab/some-model-who-knows-what-its-based-on") + with pytest.raises(ModelConfigError): + ModelConfig.from_name("acme-lab/some-model-who-knows-what-its-based-on") def test_invalid_base_model_error(): - 
assert pytest.raises( - InvalidBaseModel, - ModelConfig.from_name, - "acme-lab/some-model-who-knows-what-its-based-on", - base_model="something_unknown", - ) + with pytest.raises(InvalidBaseModel): + ModelConfig.from_name("acme-lab/some-model-who-knows-what-its-based-on", base_model="something_unknown") diff --git a/tools/inpaint_mask_tool.py b/tools/inpaint_mask_tool.py new file mode 100644 index 0000000..07c1a79 --- /dev/null +++ b/tools/inpaint_mask_tool.py @@ -0,0 +1,154 @@ +from pathlib import Path + +import cv2 +import numpy as np + + +class MaskCreator: + BRUSH_SIZES = {"1": 2, "2": 4, "3": 8, "4": 16, "5": 32, "6": 48, "7": 96, "8": 192, "9": 384} + + def __init__(self, image_path: Path): + self.original_image = cv2.imread(image_path) + if self.original_image is None: + raise FileNotFoundError(f"Could not open or find the image: {image_path}") + + self.mask_output_path = image_path.with_name(f"{image_path.stem}_mask").with_suffix(".png") + + # Create a window and set up display image + self.window_name = "MFlux Inpaint Mask Creator - Draw with mouse or trackpad (hot keys: (s)ave, (r)eset, (q)uit" + self.display_image = self.original_image.copy() + + # Create a blank mask the same size as the image + self.mask = np.zeros(self.original_image.shape[:2], dtype=np.uint8) + + # Set up drawing parameters + self.drawing = False + self.brush_size = self.BRUSH_SIZES["5"] + self.last_point = None + + self.overlay = np.zeros_like(self.original_image) + + # Update display every N drawing events - lower is more responsive + self.update_frequency = 1 + self.event_counter = 0 + + # Show the initial display + self.update_display() + + def mouse_callback(self, event, x, y, flags, param): + # Start drawing + if event == cv2.EVENT_LBUTTONDOWN: + self.drawing = True + self.last_point = (x, y) + cv2.circle(self.mask, (x, y), self.brush_size, 255, -1) + self.event_counter += 1 + if self.event_counter % self.update_frequency == 0: + self.update_display() + + # Continue drawing + 
elif event == cv2.EVENT_MOUSEMOVE and self.drawing: + # Use thickness based on brush size for smoother lines + if self.last_point: # Ensure we have a last point + # Draw a line between the last point and current point + cv2.line(self.mask, self.last_point, (x, y), 255, self.brush_size * 2) + # Also draw a circle at the current point to avoid gaps in fast movements + cv2.circle(self.mask, (x, y), self.brush_size, 255, -1) + self.last_point = (x, y) + self.event_counter += 1 + if self.event_counter % self.update_frequency == 0: + self.update_display() + + # Stop drawing + elif event == cv2.EVENT_LBUTTONUP: + self.drawing = False + self.update_display() # Always update display when stopping + + def update_display(self): + # Create a copy of the original image + self.display_image = self.original_image.copy() + + # Create a colored overlay for the mask (semi-transparent red) + self.overlay[:] = 0 # Reset overlay + self.overlay[self.mask > 0] = [0, 0, 255] # Red overlay + + # Apply the overlay + alpha = 0.5 # Transparency level + cv2.addWeighted(self.overlay, alpha, self.display_image, 1 - alpha, 0, self.display_image) + + # Draw brush size indicator in the corner + text = f"Brush Size: {self.brush_size} (Hotkeys 1-9: change brush size) " + cv2.putText( + self.display_image, + text, + (10, 30), + cv2.FONT_HERSHEY_DUPLEX, + 0.5, + (255, 0, 0), + 1, + cv2.LINE_AA, # line type: anti-aliased + ) + + # Display the result with OpenCV's high GUI priority + cv2.imshow(self.window_name, self.display_image) + cv2.waitKey(1) # Process events to force display update + + def save_mask(self, output_path): + cv2.imwrite(output_path, self.mask) + print(f"Mask saved to {output_path}") + + def reset_mask(self): + self.mask = np.zeros(self.original_image.shape[:2], dtype=np.uint8) + self.update_display() + + def set_brush_size(self, size_key): + if size_key in self.BRUSH_SIZES: + self.brush_size = self.BRUSH_SIZES[size_key] + print(f"Brush size {size_key}: {self.brush_size}") + 
self.update_display() + + def run(self): + cv2.namedWindow(self.window_name) + cv2.setMouseCallback(self.window_name, self.mouse_callback) + + while True: + key = cv2.waitKey(1) & 0xFF + key_char = chr(key) if key < 128 else "" + + # Check for brush size hotkeys (1-9) + if key_char in self.BRUSH_SIZES: + self.set_brush_size(key_char) + + # Save mask (press 's') + elif key == ord("s"): + self.save_mask(self.mask_output_path) + + # Reset mask (press 'r') + elif key == ord("r"): + self.reset_mask() + print("Mask reset") + + # Quit (press 'q' or ESC) + elif key == ord("q") or key == 27: + break + + cv2.destroyAllWindows() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Create binary mask image from source image to use as the complementary --masked-image-path arg." + ) + parser.add_argument("image_path", type=Path, help="Path to the input image") + args = parser.parse_args() + + try: + mask_creator = MaskCreator(args.image_path) + mask_creator.run() + except FileNotFoundError as e: + print(f"Error: {e}") + except Exception as e: # noqa + print(f"An unexpected error occurred: {e}") + except KeyboardInterrupt: + pass