Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP flax.linen flux.1 ported from nnx jflux #141

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@
"FlaxStableDiffusionInpaintPipeline",
"FlaxStableDiffusionPipeline",
"FlaxStableDiffusionXLPipeline",
"JfluxPipeline",
]
)

Expand Down Expand Up @@ -478,6 +479,7 @@
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
FlaxStableDiffusionXLPipeline,
JfluxPipeline,
)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from maxdiffusion.transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection)

from maxdiffusion.checkpointing.checkpointing_utils import (
create_orbax_checkpoint_manager,
create_stable_diffusion_orbax_checkpoint_manager,
load_stable_diffusion_configs,
)

Expand All @@ -57,24 +57,18 @@ def __init__(self, config, checkpoint_type):
self.mesh = Mesh(devices_array, self.config.mesh_axes)
self.total_train_batch_size = self.config.total_train_batch_size

self.checkpoint_manager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
checkpoint_type=checkpoint_type,
dataset_type=config.dataset_type,
self.checkpoint_manager = create_stable_diffusion_orbax_checkpoint_manager(
self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type
)

def _create_optimizer(self, config, learning_rate):

learning_rate_scheduler = max_utils.create_learning_rate_schedule(
learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
)
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
return tx, learning_rate_scheduler

def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training):

tx, learning_rate_scheduler = None, None
if is_training:
learning_rate = self.config.learning_rate
Expand All @@ -96,7 +90,6 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
return unet_state, state_mesh_shardings, learning_rate_scheduler

def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):

# Currently VAE training is not supported.
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
return max_utils.setup_initial_state(
Expand All @@ -112,7 +105,6 @@ def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=F
)

def create_text_encoder_state(self, pipeline, params, checkpoint_item_name, is_training):

tx = None
if is_training:
learning_rate = self.config.text_encoder_learning_rate
Expand Down Expand Up @@ -259,11 +251,9 @@ def config_to_json(model_or_config):
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

def load_params(self, step=None):

self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX

def load_checkpoint(self, step=None, scheduler_class=None):

pipeline_class = self._get_pipeline_class()

self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX
Expand Down
57 changes: 42 additions & 15 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@
def create_orbax_checkpoint_manager(
checkpoint_dir: str,
enable_checkpointing: bool,
save_interval_steps,
save_interval_steps: int,
checkpoint_type: str,
dataset_type: str = "tf",
use_async: bool = True,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
item_names=None,
):
"""
Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.
Expand All @@ -56,6 +57,29 @@ def create_orbax_checkpoint_manager(
max_logging.log(f"checkpoint dir: {checkpoint_dir}")
p = epath.Path(checkpoint_dir)

print("item_names: ", item_names)

mngr = CheckpointManager(
p,
item_names=item_names,
options=CheckpointManagerOptions(
create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async
),
logger=orbax_logger,
)

max_logging.log("Checkpoint manager created!")
return mngr


def create_stable_diffusion_orbax_checkpoint_manager(
checkpoint_dir: str,
enable_checkpointing: bool,
save_interval_steps: int,
checkpoint_type: str,
use_async: bool = True,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
):
item_names = (
"unet_config",
"vae_config",
Expand All @@ -74,6 +98,8 @@ def create_orbax_checkpoint_manager(
if dataset_type == "grain":
item_names += ("iter",)

if override_item_names is not None:
item_names = override_item_names
print("item_names: ", item_names)

mngr = CheckpointManager(
Expand All @@ -84,9 +110,9 @@ def create_orbax_checkpoint_manager(
),
logger=orbax_logger,
)

max_logging.log("Checkpoint manager created!")
return mngr
return create_orbax_checkpoint_manager(
checkpoint_dir, enable_checkpointing, save_interval_steps, use_async, orbax_logger, item_names
)


def load_stable_diffusion_configs(
Expand Down Expand Up @@ -204,11 +230,10 @@ def load_state_if_possible(
if latest_step is None:
return None
else:
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
try:
if not enable_single_replica_ckpt_restoring:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
max_logging.log(
f"restoring from this run's directory latest step {latest_step}"
)

def map_to_pspec(data):
pspec = data.sharding.spec
Expand All @@ -227,18 +252,20 @@ def map_to_pspec(data):
dtype=data.dtype,
)

array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
if enable_single_replica_ckpt_restoring:
array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(
map_to_pspec,
abstract_unboxed_pre_state,
)

item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
except:
max_logging.log(f"could not load {checkpoint_item} from orbax")
except Exception as e:
max_logging.log(f"could not load {checkpoint_item} from orbax: {e}")
return None
157 changes: 157 additions & 0 deletions src/maxdiffusion/checkpointing/jflux_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
import functools
import jax
from jax.sharding import PartitionSpec
from jax.sharding import Mesh
import orbax.checkpoint as ocp
from maxdiffusion import (max_utils)
from maxdiffusion.pipelines.jflux.pipeline_jflux import JfluxPipeline
from maxdiffusion.models.flux_utils import configs
from maxdiffusion.models.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.models.embeddings_flax import HFEmbedder
from maxdiffusion.models.flux_utils import load_ae
from flax.linen import partitioning as nn_partitioning

from maxdiffusion.checkpointing.checkpointing_utils import (
create_orbax_checkpoint_manager,
)


class JfluxCheckpointer(ABC):
flux_state_item_name = "flux_state"
config_item_name = "config"

def __init__(self, config):
self.config = config

self.rng = jax.random.PRNGKey(self.config.seed)
devices_array = max_utils.create_device_mesh(config)
self.mesh = Mesh(devices_array, self.config.mesh_axes)
self.total_train_batch_size = self.config.total_train_batch_size

self.checkpoint_manager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=self.config.save_interval_steps,
checkpoint_type="none",
item_names={JfluxCheckpointer.flux_state_item_name, JfluxCheckpointer.config_item_name},
)

def _create_optimizer(self, config):
learning_rate_scheduler = max_utils.create_learning_rate_schedule(config)
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
return tx, learning_rate_scheduler

def create_flux_state(self, flux, init_flux_weights, params, is_training, use_jit=True):
tx, learning_rate_scheduler = None, None
if is_training:

tx, learning_rate_scheduler = self._create_optimizer(self.config)

if init_flux_weights is not None:
weights_init_fn = functools.partial(init_flux_weights, rng=self.rng)
else:
weights_init_fn = None
flux_state, state_mesh_shardings = max_utils.setup_initial_state(
model=flux,
tx=tx,
config=self.config,
mesh=self.mesh,
weights_init_fn=weights_init_fn,
model_params=params.get(JfluxCheckpointer.flux_state_item_name, None) if params is not None else None,
checkpoint_manager=self.checkpoint_manager,
checkpoint_item=JfluxCheckpointer.flux_state_item_name,
training=is_training,
use_jit=use_jit,
)

return flux_state, state_mesh_shardings, learning_rate_scheduler

def _get_pipeline_class(self):
return JfluxPipeline

def save_checkpoint(self, train_step, pipeline, train_states):
items = {
JfluxCheckpointer.config_item_name: ocp.args.JsonSave({"model_name": self.config.model_name}),
}

items[JfluxCheckpointer.flux_state_item_name] = ocp.args.PyTreeSave(train_states[JfluxCheckpointer.flux_state_item_name])

self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

def load_pretrained_model(self, model_name):
# This code to generate the safetensors filename may not generalize
# but loading does not work without it
print(f"loading pretrained model {self.config.pretrained_model_name_or_path}")
stname = self.config.pretrained_model_name_or_path.split("/")[1].lower().replace(".", "")
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
flux, weights = FluxTransformer2DModel.from_pretrained(
pretrained_model_name_or_path=self.config.pretrained_model_name_or_path,
subfolder="transformer",
from_pt=True,
filename=f"{stname}.safetensors",
mesh=self.mesh,
)
weights = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), weights)
return flux, weights

def load_checkpoint(self, step=None, scheduler_class=None):
with jax.default_device(jax.devices("cpu")[0]):
t5 = HFEmbedder(
"ariG23498/t5-v1-1-xxl-flax",
max_length=256 if self.config.model_name == "flux-schnell" else 512,
dtype=jax.numpy.bfloat16,
)

clip = HFEmbedder(
"ariG23498/clip-vit-large-patch14-text-flax",
max_length=77,
dtype=jax.numpy.bfloat16,
)

ae = load_ae(self.config.model_name, "cpu")

precision = max_utils.get_precision(self.config)
flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
data_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(*self.config.data_sharding))
# loading from pretrained here causes a crash when trying to compile the model
# Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
model_params = configs[self.config.model_name].params
flux = FluxTransformer2DModel(
num_layers=model_params.depth,
num_single_layers=model_params.depth_single_blocks,
in_channels=model_params.in_channels,
attention_head_dim=int(model_params.hidden_size / model_params.num_heads),
num_attention_heads=model_params.num_heads,
joint_attention_dim=model_params.context_in_dim,
pooled_projection_dim=model_params.vec_in_dim,
mlp_ratio=model_params.mlp_ratio,
qkv_bias=model_params.qkv_bias,
theta=model_params.theta,
guidance_embeds=model_params.guidance_embed,
axes_dims_rope=model_params.axes_dim,
dtype=self.config.activations_dtype,
weights_dtype=self.config.weights_dtype,
attention_kernel=self.config.attention,
flash_block_sizes=flash_block_sizes,
mesh=self.mesh,
precision=precision,
)

return JfluxPipeline(t5, clip, flux, ae, dtype=self.config.activations_dtype, sharding=data_sharding, scheduler=None)
13 changes: 13 additions & 0 deletions src/maxdiffusion/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,21 @@

BATCH = "activation_batch"
LENGTH = "activation_length"
EMBED = "activation_embed"
HEAD = "activation_heads"
KV_BATCH = "activation_kv_batch"
KV_HEAD = "activation_kv_heads"
KV_HEAD_DIM = "activation_kv_head_dim"
D_KV = "activation_kv"
KEEP_1 = "activation_keep_1"
KEEP_2 = "activation_keep_2"
CONV_OUT = "activation_conv_out_channels"

# needed for flash attention
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_TRAIN = "train"

# A large negative mask value is used for masking to ensure that the
# softmax function assigns an extremely low probability to the masked positions.
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
Loading