Adding layers and math #7

Merged 9 commits on Oct 7, 2024
8 changes: 1 addition & 7 deletions jflux/cli.py
@@ -10,13 +10,7 @@
from jax.typing import DTypeLike

from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
-from jflux.util import (
-    configs,
-    load_ae,
-    load_clip,
-    load_flow_model,
-    load_t5,
-)
+from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5


@dataclass
42 changes: 3 additions & 39 deletions jflux/math.py
@@ -1,41 +1,21 @@
import typing

import jax
from chex import Array
from einops import rearrange
from flax import nnx
from jax import numpy as jnp


@typing.no_type_check
def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
    # TODO (ariG23498): Change all usage of attention to use this function
    q, k = apply_rope(q, k, pe)

-    # jax expects this shape
-    x = rearrange(x, "B H L D -> B L H D")  # noqa
-    x = jax.nn.dot_product_attention(q, k, v)
-    x = rearrange(x, "B L H D -> B L (H D)")  # reshape again
+    x = nnx.dot_product_attention(q, k, v)
+    x = rearrange(x, "B H L D -> B L (H D)")

    return x
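
A quick usage sketch of the new code path. The shapes and the pe construction here are assumptions for illustration (following the flux reference layout), not taken from this PR:

import jax.numpy as jnp
from jflux.math import attention, rope

B, H, L, D = 1, 2, 4, 8                            # batch, heads, seq len, head dim
q = k = v = jnp.ones((B, H, L, D))                 # "B H L D" layout, as above
pos = jnp.arange(L, dtype=jnp.float32)[None, :]    # (1, L) token positions
pe = rope(pos, D, 10_000)[:, None, ...]            # broadcast over the head axis
out = attention(q, k, v, pe)                       # (B, L, H * D) after the rearrange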


def rope(pos: Array, dim: int, theta: int) -> Array:
"""
Generate Rotary Position Embedding (RoPE) for positional encoding.

Args:
pos (Array): Positional values, typically a sequence of positions in an array format.
dim (int): The embedding dimension, which must be an even number.
theta (int): A scaling parameter for RoPE that controls the frequency range of rotations.

Returns:
Array: Rotary embeddings with cosine and sine components for each position and dimension.
"""

# Embedding dimension must be an even number
assert dim % 2 == 0

    # Generate the RoPE embeddings
    scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = jnp.einsum("...n,d->...nd", pos, omega)
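
As a minimal numeric check of the frequency math above (dim and theta chosen arbitrarily), the einsum is just an outer product of positions and frequencies:

import jax.numpy as jnp

dim, theta = 4, 10_000
scale = jnp.arange(0, dim, 2) / dim            # [0.0, 0.5]
omega = 1.0 / (theta**scale)                   # [1.0, 0.01]
pos = jnp.arange(3.0)                          # positions 0, 1, 2
out = jnp.einsum("...n,d->...nd", pos, omega)  # out[p, i] = pos[p] * omega[i]
# Each position gets one angle per frequency; the cos/sin stacking in the
# collapsed part of the diff presumably turns these angles into 2x2 rotations.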
@@ -45,26 +25,10 @@ def rope(pos: Array, dim: int, theta: int) -> Array:


def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
-    """
-    Apply RoPE to the input query and key tensors.
-
-    Args:
-        xq (Array): Query tensor.
-        xk (Array): Key tensor.
-        freqs_cis (Array): RoPE frequencies.
-
-    Returns:
-        tuple[Array, Array]: Query and key tensors with RoPE applied.
-    """
-    # Reshape and typecast the input tensors
    xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
-
-    # Apply RoPE to the input tensors
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-
-    # Reshape and typecast the output tensors
    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(
        xk.dtype
    )
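
The two multiply-add lines are a 2x2 matrix-vector product over the last-axis pairs. A self-contained sketch, assuming freqs_cis packs one rotation matrix per position/frequency as in the flux reference:

import jax.numpy as jnp

angle = 0.3
rot = jnp.array([[jnp.cos(angle), -jnp.sin(angle)],
                 [jnp.sin(angle), jnp.cos(angle)]])
freqs_cis = rot[None, None, :, :]          # (1, 1, 2, 2): one pair, one matrix
xq = jnp.array([[1.0, 0.0]])               # one token with a single 2-d pair
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
print(xq_out.reshape(xq.shape))            # [[cos 0.3, sin 0.3]]
# The pair (1, 0) is rotated by 0.3 rad; norms are preserved, which is
# what makes RoPE a pure rotation of the query/key pairs.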
7 changes: 1 addition & 6 deletions jflux/model.py
@@ -6,13 +6,8 @@
from jax import numpy as jnp
from jax.typing import DTypeLike

-from jflux.modules.layers import (
-    AdaLayerNorm,
-    Embed,
-    Identity,
-    timestep_embedding,
-)
from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock
+from jflux.modules.layers import AdaLayerNorm, Embed, Identity, timestep_embedding


@dataclass