FLUX.1 Tools | Fill #129

Draft
wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions src/mflux/config/config.py
@@ -17,6 +17,7 @@ def __init__(
guidance: float = 4.0,
image_path: Path | None = None,
image_strength: float | None = None,
masked_image_path: Path | None = None,
controlnet_strength: float | None = None,
):
if width % 16 != 0 or height % 16 != 0:
@@ -27,4 +28,5 @@ def __init__(
self.guidance = guidance
self.image_path = image_path
self.image_strength = image_strength
self.masked_image_path = masked_image_path
self.controlnet_strength = controlnet_strength
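
As a quick orientation, a minimal usage sketch of the extended Config, assembled only from the fields visible in this diff (the concrete paths and values are made up):

```python
from pathlib import Path

from mflux.config.config import Config

# Inpainting only adds masked_image_path next to the existing image_path;
# all other arguments keep their existing meaning.
config = Config(
    width=1024,
    height=1024,
    guidance=30.0,
    image_path=Path("reference.png"),    # reference image to be filled
    masked_image_path=Path("mask.png"),  # white marks the area to fill, per the 1 - mask logic in mask_util.py
)
```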
113 changes: 70 additions & 43 deletions src/mflux/config/model_config.py
@@ -13,80 +13,107 @@ def __init__(
num_train_steps: int,
max_sequence_length: int,
supports_guidance: bool,
requires_sigma_shift: bool,
priority: int,
):
self.alias = alias
self.model_name = model_name
self.base_model = base_model
self.num_train_steps = num_train_steps
self.max_sequence_length = max_sequence_length
self.supports_guidance = supports_guidance
self.requires_sigma_shift = requires_sigma_shift
self.priority = priority

@staticmethod
@lru_cache
def dev() -> "ModelConfig":
return ModelConfig(
alias="dev",
model_name="black-forest-labs/FLUX.1-dev",
base_model=None,
num_train_steps=1000,
max_sequence_length=512,
supports_guidance=True,
)
return AVAILABLE_MODELS["dev"]

@staticmethod
@lru_cache
def dev_fill() -> "ModelConfig":
return AVAILABLE_MODELS["dev-fill"]

@staticmethod
@lru_cache
def schnell() -> "ModelConfig":
return ModelConfig(
alias="schnell",
model_name="black-forest-labs/FLUX.1-schnell",
base_model=None,
num_train_steps=1000,
max_sequence_length=256,
supports_guidance=False,
)
return AVAILABLE_MODELS["schnell"]

@staticmethod
def from_name(
model_name: str,
base_model: Literal["dev", "schnell"] | None = None,
base_model: Literal["dev", "schnell", "dev-fill"] | None = None,
) -> "ModelConfig":
dev = ModelConfig.dev()
schnell = ModelConfig.schnell()
# 0. Get all base models (where base_model is None) sorted by priority
base_models = sorted(
[model for model in AVAILABLE_MODELS.values() if model.base_model is None], key=lambda x: x.priority
)

# 1. Check if model_name matches any base model's alias or full name
for base in base_models:
if model_name in (base.alias, base.model_name):
return base

# 0. Validate explicit base_model
allowed_names = [dev.alias, dev.model_name, schnell.alias, schnell.model_name]
# 2. Validate explicit base_model
allowed_names = []
for base in base_models:
allowed_names.extend([base.alias, base.model_name])
if base_model and base_model not in allowed_names:
raise InvalidBaseModel(f"Invalid base_model. Choose one of {allowed_names}")

# 1. If model_name is "dev" or "schnell" then simply return
if model_name == dev.model_name or model_name == dev.alias:
return dev
if model_name == schnell.model_name or model_name == schnell.alias:
return schnell
# 3. Determine the base model (explicit or inferred)
if base_model:
# Find by explicit base_model name
default_base = next((b for b in base_models if base_model in (b.alias, b.model_name)), None)
else:
# Infer from model_name substring (priority order via sorted base_models)
default_base = next((b for b in base_models if b.alias and b.alias in model_name), None)
if not default_base:
raise ModelConfigError(f"Cannot infer base_model from {model_name}")

# 1. Determine the appropriate base model
default_base = None
if not base_model:
if "dev" in model_name:
default_base = dev
elif "schnell" in model_name:
default_base = schnell
else:
raise ModelConfigError(f"Cannot infer base_model from {model_name}. Specify --base-model.")
elif base_model == dev.model_name or base_model == dev.alias:
default_base = dev
elif base_model == schnell.model_name or base_model == schnell.alias:
default_base = schnell

# 2. Construct the config based on the model name and base default
# 4. Construct the config
return ModelConfig(
alias=default_base.alias,
model_name=model_name,
base_model=default_base.model_name,
num_train_steps=default_base.num_train_steps,
max_sequence_length=default_base.max_sequence_length,
supports_guidance=default_base.supports_guidance,
requires_sigma_shift=default_base.requires_sigma_shift,
priority=default_base.priority,
)

def is_dev(self) -> bool:
return self.alias == "dev"

AVAILABLE_MODELS = {
"schnell": ModelConfig(
alias="schnell",
model_name="black-forest-labs/FLUX.1-schnell",
base_model=None,
num_train_steps=1000,
max_sequence_length=256,
supports_guidance=False,
requires_sigma_shift=False,
priority=2,
),
"dev": ModelConfig(
alias="dev",
model_name="black-forest-labs/FLUX.1-dev",
base_model=None,
num_train_steps=1000,
max_sequence_length=512,
supports_guidance=True,
requires_sigma_shift=True,
priority=1,
),
"dev-fill": ModelConfig(
alias="dev-fill",
model_name="black-forest-labs/FLUX.1-Fill-dev",
base_model=None,
num_train_steps=1000,
max_sequence_length=512,
supports_guidance=True,
requires_sigma_shift=True,
priority=0,
),
}
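
Name resolution now goes through the AVAILABLE_MODELS registry in priority order, which is what lets "dev-fill" win over plain "dev" for names that contain both substrings. A short resolution sketch (the third-party repository name is hypothetical):

```python
from mflux.config.model_config import ModelConfig

# A registered alias or full Hugging Face name resolves directly.
assert ModelConfig.from_name("dev-fill").model_name == "black-forest-labs/FLUX.1-Fill-dev"

# For unregistered checkpoints the base is inferred from the name. Because the
# bases are scanned by priority (dev-fill=0, dev=1, schnell=2), a name that
# contains "dev-fill" picks the Fill base rather than plain dev.
custom = ModelConfig.from_name("acme/flux-dev-fill-4bit")  # hypothetical repo
assert custom.base_model == "black-forest-labs/FLUX.1-Fill-dev"
assert custom.max_sequence_length == 512
```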
8 changes: 6 additions & 2 deletions src/mflux/config/runtime_config.py
@@ -32,7 +32,7 @@ def width(self, value):
self.config.width = value

@property
def guidance(self) -> float:
def guidance(self) -> float | None:
return self.config.guidance

@property
@@ -55,6 +55,10 @@ def image_path(self) -> str:
def image_strength(self) -> float | None:
return self.config.image_strength

@property
def masked_image_path(self) -> str | None:
return self.config.masked_image_path

@property
def init_time_step(self) -> int:
is_img2img = (
@@ -82,7 +86,7 @@ def controlnet_strength(self) -> float | None:
@staticmethod
def _create_sigmas(config: Config, model_config: ModelConfig) -> mx.array:
sigmas = RuntimeConfig._create_sigmas_values(config.num_inference_steps)
if model_config.is_dev():
if model_config.requires_sigma_shift:
sigmas = RuntimeConfig._shift_sigmas(sigmas=sigmas, width=config.width, height=config.height)
return sigmas

19 changes: 16 additions & 3 deletions src/mflux/flux/flux.py
@@ -16,6 +16,7 @@
from mflux.post_processing.array_util import ArrayUtil
from mflux.post_processing.generated_image import GeneratedImage
from mflux.post_processing.image_util import ImageUtil
from mflux.post_processing.mask_util import MaskUtil
from mflux.weights.model_saver import ModelSaver


@@ -76,6 +77,15 @@ def generate_image(
clip_text_encoder=self.clip_text_encoder,
)

# 3. Create the static masked latents
static_masked_latents = MaskUtil.create_masked_latents(
vae=self.vae,
config=config,
latents=latents,
img_path=config.image_path,
mask_path=config.masked_image_path,
)

# (Optional) Call subscribers for beginning of loop
Callbacks.before_loop(
seed=seed,
@@ -86,16 +96,19 @@

for t in time_steps:
try:
# 3.t Predict the noise
# 4.t Concatenate the updated latents with the static masked latents
hidden_states = mx.concatenate([latents, static_masked_latents], axis=-1)

# 5.t Predict the noise
noise = self.transformer(
t=t,
config=config,
hidden_states=latents,
hidden_states=hidden_states,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)

# 4.t Take one denoise step
# 6.t Take one denoise step
dt = config.sigmas[t + 1] - config.sigmas[t]
latents += noise * dt
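
For reference, the channel arithmetic behind the concatenation above, written out as a standalone shape sketch for a 1024x1024 run (derived from the packing code in this PR, not from an actual run):

```python
import mlx.core as mx

tokens = (1024 // 16) * (1024 // 16)                 # 4096 packed positions
latents = mx.zeros((1, tokens, 64))                  # denoised latents: 16 channels * 2 * 2
static_masked_latents = mx.zeros((1, tokens, 320))   # 64 (masked image) + 256 (mask)

hidden_states = mx.concatenate([latents, static_masked_latents], axis=-1)
print(hidden_states.shape)                           # (1, 4096, 384) channels into the transformer
```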

1 change: 1 addition & 0 deletions src/mflux/generate.py
@@ -51,6 +51,7 @@ def main():
width=args.width,
guidance=args.guidance,
image_path=args.image_path,
masked_image_path=args.masked_image_path,
image_strength=args.image_strength,
),
)
6 changes: 3 additions & 3 deletions src/mflux/post_processing/array_util.py
@@ -10,8 +10,8 @@ def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array:
return latents

@staticmethod
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array:
latents = mx.reshape(latents, (1, 16, height // 16, 2, width // 16, 2))
def pack_latents(latents: mx.array, height: int, width: int, num_channels_latents: int = 16) -> mx.array:
latents = mx.reshape(latents, (1, num_channels_latents, height // 16, 2, width // 16, 2))
latents = mx.transpose(latents, (0, 2, 4, 1, 3, 5))
latents = mx.reshape(latents, (1, (width // 16) * (height // 16), 64))
latents = mx.reshape(latents, (1, (width // 16) * (height // 16), num_channels_latents * 4))
return latents
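
A quick shape check of the generalized packing, mirroring how mask_util.py calls it (zeros stand in for real latents):

```python
import mlx.core as mx

from mflux.post_processing.array_util import ArrayUtil

h, w = 1024, 1024
image_latents = mx.zeros((1, 16, h // 8, w // 8))  # VAE output: 16 channels
mask_latents = mx.zeros((1, 64, h // 8, w // 8))   # reshaped mask: 64 channels

print(ArrayUtil.pack_latents(latents=image_latents, height=h, width=w).shape)
# (1, 4096, 64) -- unchanged default behaviour
print(ArrayUtil.pack_latents(latents=mask_latents, height=h, width=w, num_channels_latents=64).shape)
# (1, 4096, 256) -- each packed token now carries 64 * 4 values
```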
11 changes: 9 additions & 2 deletions src/mflux/post_processing/image_util.py
@@ -70,6 +70,10 @@ def _denormalize(images: mx.array) -> mx.array:
def _normalize(images: mx.array) -> mx.array:
return 2.0 * images - 1.0

@staticmethod
def _binarize(image: mx.array) -> mx.array:
return mx.where(image < 0.5, mx.zeros_like(image), mx.ones_like(image))

@staticmethod
def _to_numpy(images: mx.array) -> np.ndarray:
images = mx.transpose(images, (0, 2, 3, 1))
@@ -90,11 +94,14 @@ def _pil_to_numpy(image: PIL.Image.Image) -> np.ndarray:
return images

@staticmethod
def to_array(image: PIL.Image.Image) -> mx.array:
def to_array(image: PIL.Image.Image, is_mask: bool = False) -> mx.array:
image = ImageUtil._pil_to_numpy(image)
array = mx.array(image)
array = mx.transpose(array, (0, 3, 1, 2))
array = ImageUtil._normalize(array)
if is_mask:
array = ImageUtil._binarize(array)
else:
array = ImageUtil._normalize(array)
return array

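
Mask inputs skip the [-1, 1] normalization and are thresholded at 0.5 instead, so anti-aliased or grey pixels snap to exactly 0 or 1. A one-liner showing the effect of the new _binarize branch:

```python
import mlx.core as mx

grey = mx.array([0.0, 0.2, 0.5, 0.7, 1.0])
print(mx.where(grey < 0.5, mx.zeros_like(grey), mx.ones_like(grey)))
# -> [0, 0, 1, 1, 1]
```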
54 changes: 54 additions & 0 deletions src/mflux/post_processing/mask_util.py
@@ -0,0 +1,54 @@
import mlx.core as mx

from mflux.config.runtime_config import RuntimeConfig
from mflux.models.vae.vae import VAE
from mflux.post_processing.array_util import ArrayUtil


class MaskUtil:
@staticmethod
def create_masked_latents(
vae: VAE,
config: RuntimeConfig,
latents: mx.array,
img_path: str,
mask_path: str | None
) -> mx.array: # fmt: off
from mflux import ImageUtil

# 1. Get the reference image
scaled_image = ImageUtil.scale_to_dimensions(
image=ImageUtil.load_image(config.image_path).convert("RGB"),
target_width=config.width,
target_height=config.height,
)
image = ImageUtil.to_array(scaled_image)

# 2. Get the mask
scaled = ImageUtil.scale_to_dimensions(
image=ImageUtil.load_image(mask_path).convert("RGB"),
target_width=config.width,
target_height=config.height,
)
the_mask = ImageUtil.to_array(scaled, is_mask=True)

# 3. Create and pack the masked image
masked_image = image * (1 - the_mask)
masked_image = vae.encode(masked_image)
masked_image = ArrayUtil.pack_latents(latents=masked_image, height=config.height, width=config.width)

# 4. Resize mask and pack latents
mask = MaskUtil._reshape_mask(the_mask=the_mask, height=config.height, width=config.width)
mask = ArrayUtil.pack_latents(latents=mask, height=config.height, width=config.width, num_channels_latents=64)

# 5. Concat the masked_image and the mask
masked_image_latents = mx.concatenate([masked_image, mask], axis=-1)
return masked_image_latents

@staticmethod
def _reshape_mask(the_mask: mx.array, height: int, width: int):
mask = the_mask[:, 0, :, :]
mask = mx.reshape(mask, (1, height // 8, 8, width // 8, 8))
mask = mx.transpose(mask, (0, 2, 4, 1, 3))
mask = mx.reshape(mask, (1, 64, height // 8, width // 8))
return mask
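
A worked shape example of _reshape_mask for a 1024x1024 mask: every 8x8 block of mask pixels is folded into 64 channels at the matching latent cell, so the mask lines up with the VAE's 8x spatial downsampling without itself being encoded.

```python
import mlx.core as mx

the_mask = mx.ones((1, 3, 1024, 1024))         # binarized RGB mask
mask = the_mask[:, 0, :, :]                    # keep one channel: (1, 1024, 1024)
mask = mx.reshape(mask, (1, 128, 8, 128, 8))   # split into 8x8 pixel blocks
mask = mx.transpose(mask, (0, 2, 4, 1, 3))     # (1, 8, 8, 128, 128)
mask = mx.reshape(mask, (1, 64, 128, 128))     # 64 channels per latent cell
print(mask.shape)                              # (1, 64, 128, 128)
```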
6 changes: 4 additions & 2 deletions src/mflux/ui/cli/parsers.py
@@ -11,13 +11,14 @@

class ModelSpecAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if values in ["dev", "schnell"]:
if values in ui_defaults.MODEL_CHOICES:
setattr(namespace, self.dest, values)
return

if values.count("/") != 1:
raise argparse.ArgumentError(
self, f'Value must be either "dev", "schnell", or "in format "org/model". Got: {values}'
self,
(f'Value must be either {" ".join(ui_defaults.MODEL_CHOICES)} or in format "org/model". Got: {values}'),
)

# If we got here, values contains exactly one slash
@@ -77,6 +78,7 @@ def add_image_to_image_arguments(self, required=False) -> None:
self.supports_image_to_image = True
self.add_argument("--image-path", type=Path, required=required, default=None, help="Local path to init image")
self.add_argument("--image-strength", type=float, required=False, default=ui_defaults.IMAGE_STRENGTH, help=f"Controls how strongly the init image influences the output image. A value of 0.0 means no influence. (Default is {ui_defaults.IMAGE_STRENGTH})")
self.add_argument("--masked-image-path", type=Path, required=False, default=None, help="Local path to separate masked image as complement to --image-path")

def add_batch_image_generator_arguments(self) -> None:
self.add_argument("--prompts-file", type=Path, required=True, default=argparse.SUPPRESS, help="Local path for a file that holds a batch of prompts.")
2 changes: 1 addition & 1 deletion src/mflux/ui/defaults.py
@@ -2,7 +2,7 @@
GUIDANCE_SCALE = 3.5
HEIGHT, WIDTH = 1024, 1024
IMAGE_STRENGTH = 0.4
MODEL_CHOICES = ["dev", "schnell"]
MODEL_CHOICES = ["dev", "dev-fill", "schnell"]
MODEL_INFERENCE_STEPS = {
"dev": 14,
"schnell": 4,