Skip to content

Commit

Permalink
attempted rebase of filipstrand#129 on ce7dfea
Browse files Browse the repository at this point in the history
  • Loading branch information
Anthony Wu committed Mar 8, 2025
1 parent ce7dfea commit 452c34c
Show file tree
Hide file tree
Showing 12 changed files with 424 additions and 91 deletions.
2 changes: 2 additions & 0 deletions src/mflux/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
guidance: float = 4.0,
image_path: Path | None = None,
image_strength: float | None = None,
masked_image_path: Path | None = None,
controlnet_strength: float | None = None,
):
if width % 16 != 0 or height % 16 != 0:
Expand All @@ -27,4 +28,5 @@ def __init__(
self.guidance = guidance
self.image_path = image_path
self.image_strength = image_strength
self.masked_image_path = masked_image_path
self.controlnet_strength = controlnet_strength
113 changes: 70 additions & 43 deletions src/mflux/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,80 +13,107 @@ def __init__(
num_train_steps: int,
max_sequence_length: int,
supports_guidance: bool,
requires_sigma_shift: bool,
priority: int,
):
self.alias = alias
self.model_name = model_name
self.base_model = base_model
self.num_train_steps = num_train_steps
self.max_sequence_length = max_sequence_length
self.supports_guidance = supports_guidance
self.requires_sigma_shift = requires_sigma_shift
self.priority = priority

@staticmethod
@lru_cache
def dev() -> "ModelConfig":
return ModelConfig(
alias="dev",
model_name="black-forest-labs/FLUX.1-dev",
base_model=None,
num_train_steps=1000,
max_sequence_length=512,
supports_guidance=True,
)
return AVAILABLE_MODELS["dev"]

@staticmethod
@lru_cache
def dev_fill() -> "ModelConfig":
return AVAILABLE_MODELS["dev-fill"]

@staticmethod
@lru_cache
def schnell() -> "ModelConfig":
return ModelConfig(
alias="schnell",
model_name="black-forest-labs/FLUX.1-schnell",
base_model=None,
num_train_steps=1000,
max_sequence_length=256,
supports_guidance=False,
)
return AVAILABLE_MODELS["schnell"]

@staticmethod
def from_name(
model_name: str,
base_model: Literal["dev", "schnell"] | None = None,
base_model: Literal["dev", "schnell", "dev-fill"] | None = None,
) -> "ModelConfig":
dev = ModelConfig.dev()
schnell = ModelConfig.schnell()
# 0. Get all base models (where base_model is None) sorted by priority
base_models = sorted(
[model for model in AVAILABLE_MODELS.values() if model.base_model is None], key=lambda x: x.priority
)

# 1. Check if model_name matches any base model's alias or full name
for base in base_models:
if model_name in (base.alias, base.model_name):
return base

# 0. Validate explicit base_model
allowed_names = [dev.alias, dev.model_name, schnell.alias, schnell.model_name]
# 2. Validate explicit base_model
allowed_names = []
for base in base_models:
allowed_names.extend([base.alias, base.model_name])
if base_model and base_model not in allowed_names:
raise InvalidBaseModel(f"Invalid base_model. Choose one of {allowed_names}")

# 1. If model_name is "dev" or "schnell" then simply return
if model_name == dev.model_name or model_name == dev.alias:
return dev
if model_name == schnell.model_name or model_name == schnell.alias:
return schnell
# 3. Determine the base model (explicit or inferred)
if base_model:
# Find by explicit base_model name
default_base = next((b for b in base_models if base_model in (b.alias, b.model_name)), None)
else:
# Infer from model_name substring (priority order via sorted base_models)
default_base = next((b for b in base_models if b.alias and b.alias in model_name), None)
if not default_base:
raise ModelConfigError(f"Cannot infer base_model from {model_name}")

# 1. Determine the appropriate base model
default_base = None
if not base_model:
if "dev" in model_name:
default_base = dev
elif "schnell" in model_name:
default_base = schnell
else:
raise ModelConfigError(f"Cannot infer base_model from {model_name}. Specify --base-model.")
elif base_model == dev.model_name or base_model == dev.alias:
default_base = dev
elif base_model == schnell.model_name or base_model == schnell.alias:
default_base = schnell

# 2. Construct the config based on the model name and base default
# 4. Construct the config
return ModelConfig(
alias=default_base.alias,
model_name=model_name,
base_model=default_base.model_name,
num_train_steps=default_base.num_train_steps,
max_sequence_length=default_base.max_sequence_length,
supports_guidance=default_base.supports_guidance,
requires_sigma_shift=default_base.requires_sigma_shift,
priority=default_base.priority,
)

def is_dev(self) -> bool:
return self.alias == "dev"

AVAILABLE_MODELS = {
"schnell": ModelConfig(
alias="schnell",
model_name="black-forest-labs/FLUX.1-schnell",
base_model=None,
num_train_steps=1000,
max_sequence_length=256,
supports_guidance=False,
requires_sigma_shift=False,
priority=2,
),
"dev": ModelConfig(
alias="dev",
model_name="black-forest-labs/FLUX.1-dev",
base_model=None,
num_train_steps=1000,
max_sequence_length=512,
supports_guidance=True,
requires_sigma_shift=True,
priority=1,
),
"dev-fill": ModelConfig(
alias="dev-fill",
model_name="black-forest-labs/FLUX.1-Fill-dev",
base_model=None,
num_train_steps=1000,
max_sequence_length=512,
supports_guidance=True,
requires_sigma_shift=True,
priority=0,
),
}
8 changes: 6 additions & 2 deletions src/mflux/config/runtime_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def width(self, value):
self.config.width = value

@property
def guidance(self) -> float:
def guidance(self) -> float | None:
return self.config.guidance

@property
Expand All @@ -55,6 +55,10 @@ def image_path(self) -> str:
def image_strength(self) -> float | None:
return self.config.image_strength

@property
def masked_image_path(self) -> str | None:
return self.config.masked_image_path

@property
def init_time_step(self) -> int:
is_img2img = (
Expand Down Expand Up @@ -82,7 +86,7 @@ def controlnet_strength(self) -> float | None:
@staticmethod
def _create_sigmas(config: Config, model_config: ModelConfig) -> mx.array:
sigmas = RuntimeConfig._create_sigmas_values(config.num_inference_steps)
if model_config.is_dev():
if model_config.requires_sigma_shift:
sigmas = RuntimeConfig._shift_sigmas(sigmas=sigmas, width=config.width, height=config.height)
return sigmas

Expand Down
19 changes: 16 additions & 3 deletions src/mflux/flux/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mflux.post_processing.array_util import ArrayUtil
from mflux.post_processing.generated_image import GeneratedImage
from mflux.post_processing.image_util import ImageUtil
from mflux.post_processing.mask_util import MaskUtil
from mflux.weights.model_saver import ModelSaver


Expand Down Expand Up @@ -76,6 +77,15 @@ def generate_image(
clip_text_encoder=self.clip_text_encoder,
)

# 3. Create the static masked latents
static_masked_latents = MaskUtil.create_masked_latents(
vae=self.vae,
config=config,
latents=latents,
img_path=config.image_path,
mask_path=config.masked_image_path,
)

# (Optional) Call subscribers for beginning of loop
Callbacks.before_loop(
seed=seed,
Expand All @@ -86,16 +96,19 @@ def generate_image(

for t in time_steps:
try:
# 3.t Predict the noise
# 4.t Concatenate the updated latents with the static masked latents
hidden_states = mx.concatenate([latents, static_masked_latents], axis=-1)

# 5.t Predict the noise
noise = self.transformer(
t=t,
config=config,
hidden_states=latents,
hidden_states=hidden_states,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)

# 4.t Take one denoise step
# 6.t Take one denoise step
dt = config.sigmas[t + 1] - config.sigmas[t]
latents += noise * dt

Expand Down
1 change: 1 addition & 0 deletions src/mflux/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def main():
width=args.width,
guidance=args.guidance,
image_path=args.image_path,
masked_image_path=args.masked_image_path,
image_strength=args.image_strength,
),
)
Expand Down
6 changes: 3 additions & 3 deletions src/mflux/post_processing/array_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array:
return latents

@staticmethod
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array:
latents = mx.reshape(latents, (1, 16, height // 16, 2, width // 16, 2))
def pack_latents(latents: mx.array, height: int, width: int, num_channels_latents: int = 16) -> mx.array:
latents = mx.reshape(latents, (1, num_channels_latents, height // 16, 2, width // 16, 2))
latents = mx.transpose(latents, (0, 2, 4, 1, 3, 5))
latents = mx.reshape(latents, (1, (width // 16) * (height // 16), 64))
latents = mx.reshape(latents, (1, (width // 16) * (height // 16), num_channels_latents * 4))
return latents
11 changes: 9 additions & 2 deletions src/mflux/post_processing/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def _denormalize(images: mx.array) -> mx.array:
def _normalize(images: mx.array) -> mx.array:
return 2.0 * images - 1.0

@staticmethod
def _binarize(image: mx.array) -> mx.array:
return mx.where(image < 0.5, mx.zeros_like(image), mx.ones_like(image))

@staticmethod
def _to_numpy(images: mx.array) -> np.ndarray:
images = mx.transpose(images, (0, 2, 3, 1))
Expand All @@ -90,11 +94,14 @@ def _pil_to_numpy(image: PIL.Image.Image) -> np.ndarray:
return images

@staticmethod
def to_array(image: PIL.Image.Image) -> mx.array:
def to_array(image: PIL.Image.Image, is_mask: bool = False) -> mx.array:
image = ImageUtil._pil_to_numpy(image)
array = mx.array(image)
array = mx.transpose(array, (0, 3, 1, 2))
array = ImageUtil._normalize(array)
if is_mask:
array = ImageUtil._binarize(array)
else:
array = ImageUtil._normalize(array)
return array

@staticmethod
Expand Down
54 changes: 54 additions & 0 deletions src/mflux/post_processing/mask_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import mlx.core as mx

from mflux.config.runtime_config import RuntimeConfig
from mflux.models.vae.vae import VAE
from mflux.post_processing.array_util import ArrayUtil


class MaskUtil:
@staticmethod
def create_masked_latents(
vae: VAE,
config: RuntimeConfig,
latents: mx.array,
img_path: str,
mask_path: str | None
) -> mx.array: # fmt: off
from mflux import ImageUtil

# 1. Get the reference image
scaled_image = ImageUtil.scale_to_dimensions(
image=ImageUtil.load_image(config.image_path).convert("RGB"),
target_width=config.width,
target_height=config.height,
)
image = ImageUtil.to_array(scaled_image)

# 2. Get the mask
scaled = ImageUtil.scale_to_dimensions(
image=ImageUtil.load_image(mask_path).convert("RGB"),
target_width=config.width,
target_height=config.height,
)
the_mask = ImageUtil.to_array(scaled, is_mask=True)

# 3. Create and pack the masked image
masked_image = image * (1 - the_mask)
masked_image = vae.encode(masked_image)
masked_image = ArrayUtil.pack_latents(latents=masked_image, height=config.height, width=config.width)

# 4. Resize mask and pack latents
mask = MaskUtil._reshape_mask(the_mask=the_mask, height=config.height, width=config.width)
mask = ArrayUtil.pack_latents(latents=mask, height=config.height, width=config.width, num_channels_latents=64)

# 5. Concat the masked_image and the mask
masked_image_latents = mx.concatenate([masked_image, mask], axis=-1)
return masked_image_latents

@staticmethod
def _reshape_mask(the_mask: mx.array, height: int, width: int):
mask = the_mask[:, 0, :, :]
mask = mx.reshape(mask, (1, height // 8, 8, width // 8, 8))
mask = mx.transpose(mask, (0, 2, 4, 1, 3))
mask = mx.reshape(mask, (1, 64, height // 8, width // 8))
return mask
6 changes: 4 additions & 2 deletions src/mflux/ui/cli/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@

class ModelSpecAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if values in ["dev", "schnell"]:
if values in ui_defaults.MODEL_CHOICES:
setattr(namespace, self.dest, values)
return

if values.count("/") != 1:
raise argparse.ArgumentError(
self, f'Value must be either "dev", "schnell", or "in format "org/model". Got: {values}'
self,
(f'Value must be either {" ".join(ui_defaults.MODEL_CHOICES)} or in format "org/model". Got: {values}'),
)

# If we got here, values contains exactly one slash
Expand Down Expand Up @@ -77,6 +78,7 @@ def add_image_to_image_arguments(self, required=False) -> None:
self.supports_image_to_image = True
self.add_argument("--image-path", type=Path, required=required, default=None, help="Local path to init image")
self.add_argument("--image-strength", type=float, required=False, default=ui_defaults.IMAGE_STRENGTH, help=f"Controls how strongly the init image influences the output image. A value of 0.0 means no influence. (Default is {ui_defaults.IMAGE_STRENGTH})")
self.add_argument("--masked-image-path", type=Path, required=False, default=None, help="Local path to separate masked image as complement to --image-path")

def add_batch_image_generator_arguments(self) -> None:
self.add_argument("--prompts-file", type=Path, required=True, default=argparse.SUPPRESS, help="Local path for a file that holds a batch of prompts.")
Expand Down
2 changes: 1 addition & 1 deletion src/mflux/ui/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
GUIDANCE_SCALE = 3.5
HEIGHT, WIDTH = 1024, 1024
IMAGE_STRENGTH = 0.4
MODEL_CHOICES = ["dev", "schnell"]
MODEL_CHOICES = ["dev", "dev-fill", "schnell"]
MODEL_INFERENCE_STEPS = {
"dev": 14,
"schnell": 4,
Expand Down
Loading

0 comments on commit 452c34c

Please sign in to comment.