Skip to content

Commit

Permalink
style: apply ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Oct 5, 2024
1 parent 4459e8d commit ac05cd7
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 21 deletions.
2 changes: 1 addition & 1 deletion jflux/math.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
from flax import nnx
from chex import Array
from einops import rearrange
from flax import nnx
from jax import numpy as jnp


Expand Down
2 changes: 1 addition & 1 deletion jflux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from jax import numpy as jnp
from jax.typing import DTypeLike

from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock
from jflux.modules.layers import (
AdaLayerNorm,
Embed,
Identity,
timestep_embedding,
)
from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion jflux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from jax.typing import DTypeLike
from safetensors.numpy import load_file as load_sft

from jflux.model import Flux, FluxParams
from jflux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from jflux.modules.conditioner import HFEmbedder
from jflux.model import Flux, FluxParams


@dataclass
Expand Down
26 changes: 12 additions & 14 deletions tests/modules/test_autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
from einops import rearrange
import jax.numpy as jnp
import numpy as np
import torch
from einops import rearrange
from flax import nnx

from flux.modules.autoencoder import AttnBlock as TorchAttnBlock
from flux.modules.autoencoder import ResnetBlock as TorchResnetBlock
from flux.modules.autoencoder import Downsample as TorchDownsample
from flux.modules.autoencoder import Upsample as TorchUpsample
from flux.modules.autoencoder import Encoder as TorchEncoder
from flux.modules.autoencoder import Decoder as TorchDecoder
from flux.modules.autoencoder import AutoEncoder as TorchAutoEncoder
from flux.modules.autoencoder import AutoEncoderParams as TorchAutoEncoderParams
from flux.modules.autoencoder import Decoder as TorchDecoder
from flux.modules.autoencoder import Downsample as TorchDownsample
from flux.modules.autoencoder import Encoder as TorchEncoder
from flux.modules.autoencoder import ResnetBlock as TorchResnetBlock
from flux.modules.autoencoder import Upsample as TorchUpsample

from jflux.modules.autoencoder import AttnBlock as JaxAttnBlock
from jflux.modules.autoencoder import ResnetBlock as JaxResnetBlock
from jflux.modules.autoencoder import Downsample as JaxDownsample
from jflux.modules.autoencoder import Upsample as JaxUpsample
from jflux.modules.autoencoder import Encoder as JaxEncoder
from jflux.modules.autoencoder import Decoder as JaxDecoder
from jflux.modules.autoencoder import AutoEncoder as JaxAutoEncoder
from jflux.modules.autoencoder import AutoEncoderParams as JaxAutoEncoderParams

import numpy as np
from jflux.modules.autoencoder import Decoder as JaxDecoder
from jflux.modules.autoencoder import Downsample as JaxDownsample
from jflux.modules.autoencoder import Encoder as JaxEncoder
from jflux.modules.autoencoder import ResnetBlock as JaxResnetBlock
from jflux.modules.autoencoder import Upsample as JaxUpsample
from tests.utils import torch2jax


Expand Down
7 changes: 3 additions & 4 deletions tests/test_math.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import unittest

import jax.numpy as jnp
import numpy as np
import torch
import jax.numpy as jnp

from flux.math import rope as torch_rope
from flux.math import apply_rope as torch_apply_rope
from flux.math import attention as torch_attention
from flux.math import rope as torch_rope

from jflux.math import rope as jax_rope
from jflux.math import apply_rope as jax_apply_rope
from jflux.math import attention as jax_attention
from jflux.math import rope as jax_rope


class TestMath(np.testing.TestCase):
Expand Down

0 comments on commit ac05cd7

Please sign in to comment.