diff --git a/jflux/math.py b/jflux/math.py index 69f42c4..b502b03 100644 --- a/jflux/math.py +++ b/jflux/math.py @@ -1,7 +1,7 @@ import jax -from flax import nnx from chex import Array from einops import rearrange +from flax import nnx from jax import numpy as jnp diff --git a/jflux/model.py b/jflux/model.py index bd02bac..6ef9c35 100644 --- a/jflux/model.py +++ b/jflux/model.py @@ -6,13 +6,13 @@ from jax import numpy as jnp from jax.typing import DTypeLike +from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock from jflux.modules.layers import ( AdaLayerNorm, Embed, Identity, timestep_embedding, ) -from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock @dataclass diff --git a/jflux/util.py b/jflux/util.py index 722a2e4..53928d9 100644 --- a/jflux/util.py +++ b/jflux/util.py @@ -8,9 +8,9 @@ from jax.typing import DTypeLike from safetensors.numpy import load_file as load_sft +from jflux.model import Flux, FluxParams from jflux.modules.autoencoder import AutoEncoder, AutoEncoderParams from jflux.modules.conditioner import HFEmbedder -from jflux.model import Flux, FluxParams @dataclass diff --git a/tests/modules/test_autoencoder.py b/tests/modules/test_autoencoder.py index a4aa73a..1fea730 100644 --- a/tests/modules/test_autoencoder.py +++ b/tests/modules/test_autoencoder.py @@ -1,27 +1,25 @@ -from einops import rearrange import jax.numpy as jnp +import numpy as np import torch +from einops import rearrange from flax import nnx - from flux.modules.autoencoder import AttnBlock as TorchAttnBlock -from flux.modules.autoencoder import ResnetBlock as TorchResnetBlock -from flux.modules.autoencoder import Downsample as TorchDownsample -from flux.modules.autoencoder import Upsample as TorchUpsample -from flux.modules.autoencoder import Encoder as TorchEncoder -from flux.modules.autoencoder import Decoder as TorchDecoder from flux.modules.autoencoder import AutoEncoder as TorchAutoEncoder from flux.modules.autoencoder import AutoEncoderParams as TorchAutoEncoderParams +from flux.modules.autoencoder import Decoder as TorchDecoder +from flux.modules.autoencoder import Downsample as TorchDownsample +from flux.modules.autoencoder import Encoder as TorchEncoder +from flux.modules.autoencoder import ResnetBlock as TorchResnetBlock +from flux.modules.autoencoder import Upsample as TorchUpsample from jflux.modules.autoencoder import AttnBlock as JaxAttnBlock -from jflux.modules.autoencoder import ResnetBlock as JaxResnetBlock -from jflux.modules.autoencoder import Downsample as JaxDownsample -from jflux.modules.autoencoder import Upsample as JaxUpsample -from jflux.modules.autoencoder import Encoder as JaxEncoder -from jflux.modules.autoencoder import Decoder as JaxDecoder from jflux.modules.autoencoder import AutoEncoder as JaxAutoEncoder from jflux.modules.autoencoder import AutoEncoderParams as JaxAutoEncoderParams - -import numpy as np +from jflux.modules.autoencoder import Decoder as JaxDecoder +from jflux.modules.autoencoder import Downsample as JaxDownsample +from jflux.modules.autoencoder import Encoder as JaxEncoder +from jflux.modules.autoencoder import ResnetBlock as JaxResnetBlock +from jflux.modules.autoencoder import Upsample as JaxUpsample from tests.utils import torch2jax diff --git a/tests/test_math.py b/tests/test_math.py index 30aa35b..c0107a3 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,16 +1,15 @@ import unittest +import jax.numpy as jnp import numpy as np import torch -import jax.numpy as jnp - -from flux.math import rope as torch_rope from flux.math import apply_rope as torch_apply_rope from flux.math import attention as torch_attention +from flux.math import rope as torch_rope -from jflux.math import rope as jax_rope from jflux.math import apply_rope as jax_apply_rope from jflux.math import attention as jax_attention +from jflux.math import rope as jax_rope class TestMath(np.testing.TestCase):