
WIP flax.linen flux.1 ported from nnx jflux #141

Closed
wants to merge 3 commits

Conversation

jcaraban

Background in #115

Note @entrpn: this PR is not in a mergeable state, but it runs, so you can see where we stand. We are working on integrating flash attention and multi-node support (from MaxText), so we are bound to mess up the code even more. Hence it might be easier if you cherry-pick our changes rather than merging what we have, but let's discuss that.

TODO:

  • rename JFlux to FlaxFlux or whatever is appropriate
  • clean up max_utils.py; we copied logic from MaxText to test multi-node
  • check the checkpointing logic, in case we applied some dubious fixes
  • double-check the Flux TFlop calculation, it might be 10% off (see the rough estimate after this list)
  • migrate ae_flux_nnx.py and HFEmbedder(nnx.Module) to linen?
  • how to get rid of the torch->jax porting of the autoencoder model?
  • verify and improve sharding
  • verify the time scheduler
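
For the TFlop item, a standard transformer-style estimate to sanity-check against could look like the sketch below (all numbers are placeholders, not exact Flux figures, and the real calculation also has to account for the double-/single-stream block split):

# Back-of-the-envelope transformer FLOP estimate per training step.
# All values below are placeholders for illustration, not Flux-exact numbers.
def training_tflops_per_step(n_params, batch, seq_len, n_layers, d_model):
    matmul = 6 * n_params * batch * seq_len                     # 2x fwd + 4x bwd matmul FLOPs
    attention = 12 * n_layers * batch * seq_len ** 2 * d_model  # QK^T and attn@V, fwd + bwd
    return (matmul + attention) / 1e12

# Example call with rough placeholder values (~12B params, ~4.5k image+text tokens).
print(training_tflops_per_step(n_params=12e9, batch=1, seq_len=4608, n_layers=57, d_model=3072))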

@entrpn
Collaborator

entrpn commented Jan 28, 2025

@jcaraban thank you for the PR. I was able to get it running on my TPU machine; however, I am getting a noisy image. Perhaps I am not running it correctly.

Regarding the TODO items, my branch flux_impl covers some of them: I made changes to the VAE so it can load the Flux VAE and the text encoders from the diffusers Flux checkpoint onto the CPU using the transformers library, and I implemented flash attention. I haven't verified my sharding; it's probably wrong right now.
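
Concretely, loading the text encoders and VAE from the diffusers Flux checkpoint onto the CPU looks roughly like the sketch below (repo id and subfolder names follow the public HF layout; this is a minimal illustration, not the exact code in flux_impl):

# Minimal sketch: load the Flux text encoders and VAE on CPU from the diffusers checkpoint.
# The repo id and subfolder names follow the public HF layout; flux_impl may differ in detail.
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKL

repo = "black-forest-labs/FLUX.1-dev"

clip = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float32)
clip_tok = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
t5 = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2", torch_dtype=torch.float32)
t5_tok = T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2")

# VAE weights, loaded once so they can be converted to JAX/Flax params afterwards.
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae")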

I'll keep analyzing the code and think of a good merging strategy. I might end up cherry-picking some changes from here into my branch, as you mentioned.

Going back to the noisy image. First I tried running using the following command:

python src/maxdiffusion/generate_jflux.py src/maxdiffusion/configs/base_jflux.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/

Then I noticed it's using schnell, so I tried running as follows:

python src/maxdiffusion/generate_jflux.py src/maxdiffusion/configs/base_jflux.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ num_inference_steps=4 guidance_scale=0.0

Any ideas? This is the output:

[attached image: maxdiff_img_1 showing the noisy output]

@ksikiric

Hi! Regarding the noise in the output: the model starts from random weights by default. To download the pre-trained model and use it, you first need to run:

python src/maxdiffusion/create_jflux_checkpoints.py src/maxdiffusion/configs/base_jflux.yml

This will store the weights in an Orbax checkpoint, which will then be loaded when running inference and training.
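
Under the hood this is just the standard Orbax PyTree save/restore pattern, roughly as below (the path and the params tree are placeholders, not exactly what the script uses):

# Rough sketch of the Orbax save/restore pattern; path and params are placeholders.
import numpy as np
import orbax.checkpoint as ocp

flux_params = {"params": {"some_layer": {"kernel": np.zeros((4, 4))}}}  # stand-in PyTree

ckptr = ocp.PyTreeCheckpointer()
ckptr.save("/tmp/flux_checkpoints/schnell", flux_params)          # done by the conversion script
restored_params = ckptr.restore("/tmp/flux_checkpoints/schnell")  # done at inference/training time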

@jcaraban
Author

Sorry @entrpn, I forgot to mention that create_jflux_checkpoints.py needs to run first so that generate_jflux.py picks up that checkpoint; otherwise it will start from noise, as @ksikiric explained above. I also noticed that schnell works but dev fails to load the checkpoint; I'll check that. I pushed a commit to split the YAML in two, so it's easier to test the dev and schnell variants. Run like:

❯ python -m src.maxdiffusion.create_jflux_checkpoints src/maxdiffusion/configs/base_jflux_schnell.yml
❯ python -m src.maxdiffusion.generate_jflux src/maxdiffusion/configs/base_jflux_schnell.yml

@jcaraban
Author

878796f fixes the issue with dev; we had forgotten to map the guidance_in layer. Now I see proper images for both schnell and dev inference. Run the latter like:

❯ python -m src.maxdiffusion.create_jflux_checkpoints src/maxdiffusion/configs/base_jflux_dev.yml
❯ python -m src.maxdiffusion.generate_jflux src/maxdiffusion/configs/base_jflux_dev.yml
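
Schematically, the fix carries the dev-only guidance_in MLP weights over during the torch->jax conversion, something like the sketch below (parameter paths and the helper name are illustrative; the actual conversion code may differ):

# Illustrative sketch of porting the dev-only guidance_in MLP in a torch->jax conversion.
import numpy as np

def port_guidance_in(torch_state_dict, flax_params):
    # torch Linear stores weight as (out, in); a Flax Dense kernel is (in, out), so transpose.
    for layer in ("in_layer", "out_layer"):
        w = torch_state_dict[f"guidance_in.{layer}.weight"].numpy()
        b = torch_state_dict[f"guidance_in.{layer}.bias"].numpy()
        flax_params["guidance_in"][layer] = {"kernel": np.transpose(w), "bias": b}
    return flax_params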

@entrpn
Collaborator

entrpn commented Jan 29, 2025

@jcaraban @ksikiric thanks for sharing. I was able to convert the checkpoint to Orbax, but I'm still working on loading the models. I'm trying to load them on a TPU v4, but I'm OOMing. Will keep you updated once I figure it out.

@entrpn
Collaborator

entrpn commented Jan 30, 2025

@jcaraban I took bits and pieces from your code and got my branch flux_impl running e2e; however, images are coming out distorted and I'm still working through that. If you want to take a look at some of the changes, the code shards and jits the model and runs much faster on my end, although the sharding is probably still not optimal.

To run it, no need to do weight conversion, just run:

python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/

I'm working with a limited-memory TPU, so I offload the T5 encoder. If you have more memory, you can set offload_encoders to False in the config or on the CLI.
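
The offloading is basically: tokenize and encode the prompt on the CPU with the torch text encoders, then move only the embeddings to the device. A rough sketch (names are placeholders, not the exact code in my branch):

# Rough sketch of encoder offloading: encode on CPU, ship only the embeddings to the device.
import jax
import jax.numpy as jnp
import torch

def encode_prompt_on_cpu(prompt, t5, t5_tok, max_length=512):
    tokens = t5_tok(prompt, max_length=max_length, padding="max_length",
                    truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = t5(tokens.input_ids).last_hidden_state  # runs on the CPU
    # Convert to a JAX array and place it on the first accelerator device.
    return jax.device_put(jnp.asarray(hidden.numpy()), jax.devices()[0])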

@entrpn
Collaborator

entrpn commented Feb 3, 2025

@jcaraban thanks again for sharing your PR. I incorporated some more changes from it and now I have it working e2e on my branch for both schnell and dev. I also fixed the flash attention issue. Sharing some inference numbers using a TPU v4-8 as well. I played with the sharding for a bit, and this was the best one I found so far. I'll need to spend more time with the profiler to see if I can find a better sharding configuration. Let me know what you think, I plan to merge this to main next week.

# dev
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="A dog riding a skateboard"
# TPU v4-8
# Compile time: 100.4s.
# Inference time: 4.4s.
# Inference time: 4.5s.
# schnell
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="A dog riding a skateboard"
# TPU v4-8
# Compile time: 163.4s.
# Inference time: 58.8s.
# Inference time: 58.8s.
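
For reference, the sharding setup follows the usual JAX mesh + NamedSharding pattern, roughly as sketched below (axis names and the partitioning shown are illustrative, not necessarily the configuration I landed on):

# Illustrative sketch of a data/fsdp device mesh and a NamedSharding for one weight.
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((1, jax.device_count()))  # e.g. (1, 4) on a v4-8
mesh = Mesh(devices, axis_names=("data", "fsdp"))

# Shard a (in_features, out_features) kernel over the "fsdp" axis, replicate over "data".
kernel = np.zeros((3072, 3072), dtype=np.float32)
sharded_kernel = jax.device_put(kernel, NamedSharding(mesh, P(None, "fsdp")))
print(sharded_kernel.sharding)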

@jcaraban
Author

jcaraban commented Feb 4, 2025

Nice @entrpn, you made some great progress! I will check that your branch also works for our GPU configurations and compare the performance against our stale branch. At this point, merging your flux_impl branch is possibly the easier way forward. However, if you do, given that you have cherry-picked changes from our work, I hope you can in some way attribute that contribution to our employer AMD 🙂

These are the inference numbers for MI300x in your branch:

## dev 50 steps
- Compile: 123.1s
- Inference: 14.6s

## schnell 50 steps
- Compile: 53.8s
- Inference: 13.4s

@jcaraban
Author

Closing as #146 cherry-picked the best parts.
The training logic is rebased in #147.

@jcaraban jcaraban closed this Feb 12, 2025