WIP flax.linen flux.1 ported from nnx jflux #141
Conversation
@jcaraban thank you for the PR. I was able to get it running on my TPU machine, however I am getting a noisy image, so perhaps I am not running it correctly. Regarding the TODO items and my branch, I'll keep analyzing the code and thinking of a good merging strategy; I might end up cherry-picking some changes from here.

Going back to the noisy image. First I tried running with the following command:

python src/maxdiffusion/generate_jflux.py src/maxdiffusion/configs/base_jflux.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/

Then I noticed it's using schnell, so I tried running as follows:

python src/maxdiffusion/generate_jflux.py src/maxdiffusion/configs/base_jflux.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ num_inference_steps=4 guidance_scale=0.0

Any ideas? This is the output:
Hi! Regarding the noise in the output: the model starts from random weights by default. To download the pre-trained model and use it, you first need to run:

python src/maxdiffusion/create_jflux_checkpoints.py src/maxdiffusion/configs/base_jflux.yml

This stores the weights in an Orbax checkpoint, which is then loaded when running inference and training.
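For context on what the checkpoint script produces, the underlying Orbax save/restore pattern looks roughly like the sketch below. The path and the params pytree are placeholders for illustration, not the actual maxdiffusion code.

```python
# Minimal sketch of the Orbax pattern, assuming a plain params pytree.
# The target directory must not already exist when saving.
import jax.numpy as jnp
import orbax.checkpoint as ocp

params = {"layer0": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))}}

ckptr = ocp.PyTreeCheckpointer()
ckptr.save("/tmp/flux_ckpt_demo", params)        # create_jflux_checkpoints.py does this step once
restored = ckptr.restore("/tmp/flux_ckpt_demo")  # generate_jflux.py then loads the weights at startup
```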
Sorry @entrpn, I forgot to mention that.
878796f fixes the issue with
@jcaraban I took bits and pieces from your code and got my branch working. To run it, there is no need to do weight conversion; just run:

python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/

I'm working with a limited-memory TPU, so I offload the T5 encoder. If you have more memory, you can set
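In case it helps others running on small HBM, the T5-offloading idea in general looks roughly like the sketch below: keep the encoder weights on host CPU, run the prompt encoding there, and move only the resulting embeddings to the accelerator. `t5_encode` and the parameter pytree are hypothetical stand-ins, not the real maxdiffusion encoder.

```python
# Rough sketch of encoder offloading, assuming a hypothetical t5_encode function.
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
accel = jax.devices()[0]  # first accelerator (TPU/GPU), or CPU if none present

def t5_encode(params, token_ids):
    # Placeholder for the real encoder apply function.
    return jnp.ones((token_ids.shape[0], 77, 4096))

t5_params = jax.device_put({"dummy": jnp.zeros((1,))}, cpu)      # weights stay on host
token_ids = jax.device_put(jnp.zeros((1, 77), jnp.int32), cpu)

with jax.default_device(cpu):
    prompt_embeds = t5_encode(t5_params, token_ids)              # encoding runs on host

prompt_embeds = jax.device_put(prompt_embeds, accel)             # only activations go to the device
```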
@jcaraban thanks again for sharing your PR. I incorporated some more changes from it and now have it working end-to-end on my branch for both schnell and dev. I also fixed the flash attention issue. Sharing some inference numbers from a TPU v4-8 as well. I played with the sharding for a bit, and this is the best configuration I've found so far; I'll need to spend more time with the profiler to see if I can find a better one. Let me know what you think. I plan to merge this to main next week.
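For anyone following along, a minimal, generic sketch of the kind of sharding experiment described above is below. The mesh axis name and the choice to shard only the batch dimension are assumptions for illustration, not the layout used on the v4-8.

```python
# Generic JAX sharding sketch: build a 1D device mesh and split a batch across it.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

latents = jnp.ones((8, 64, 64, 16))                                 # e.g. a batch of latents
sharded = jax.device_put(latents, NamedSharding(mesh, P("data")))   # shard the batch dimension
print(sharded.sharding)                                             # inspect the resulting layout
```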
Nice @entrpn, you've made great progress! I will check that your branch also works for our GPU configurations and compare the performance to our stale branch. At this point merging your

These are the inference numbers for MI300x on your branch:
Background in #115
Note @entrpn this PR is not in a mergeable state, but it runs, so you can see where we stand. We are working on integrating flash attention (FA) and multi-node support (from MaxText), so we are bound to mess up the code even more. Hence it might be easier if you cherry-pick our changes rather than merging what we have, but let's discuss that.
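For readers unfamiliar with what the FA integration buys, here is the plain (unfused) scaled-dot-product attention that a flash-attention kernel replaces, written as a generic JAX sketch with made-up shapes; it is not the attention code in this PR.

```python
# Reference attention; a fused flash-attention kernel computes the same result
# without materializing the full [seq, seq] logits matrix.
import jax
import jax.numpy as jnp

def sdpa(q, k, v):
    # q, k, v: [batch, heads, seq, head_dim]
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

q = k = v = jnp.ones((1, 8, 256, 64))
out = sdpa(q, k, v)
```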
TODO: