Aritra Roy Gosthipaty edited this page Sep 4, 2024 · 5 revisions

Welcome to the flux-jax wiki!

Today I learnt

import jax
device = "gpu" if jax.default_backend() == "gpu" else "cpu"  # default_backend() returns the platform name, e.g. "cpu", "gpu" or "tpu"

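Query tensor layout expected by each attention function (Hidden_Dims is the per-head dimension; key and value use the source length in place of Target_Length):

JAX    jax.nn.dot_product_attention                        (Batch, Target_Length, Num_Heads, Hidden_Dims)
Torch  torch.nn.functional.scaled_dot_product_attention    (Batch, Num_Heads, Target_Length, Hidden_Dims)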
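The two layouts differ only in the order of the Num_Heads and Target_Length axes, so converting between them should be a single swap of axes 1 and 2. The sketch below checks this with plain NumPy instead of calling either library, so it runs without JAX or PyTorch installed. It assumes the default 1/sqrt(Hidden_Dims) scaling and no mask; the helper names (attention_btnh, attention_bnth, btnh_to_bnth, bnth_to_btnh) are hypothetical, and jax.numpy offers the same swapaxes call.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_btnh(q, k, v):
    """Attention on (Batch, Length, Num_Heads, Hidden_Dims) inputs, the JAX-style layout."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    # logits: (Batch, Num_Heads, Target_Length, Source_Length)
    logits = np.einsum("btnh,bsnh->bnts", q, k) * scale
    weights = softmax(logits, axis=-1)
    # Weighted sum over the source axis, returned in (B, T, N, H) order.
    return np.einsum("bnts,bsnh->btnh", weights, v)


def attention_bnth(q, k, v):
    """The same computation on (Batch, Num_Heads, Length, Hidden_Dims) inputs, the Torch-style layout."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    # (B, N, T, H) @ (B, N, H, S) -> (B, N, T, S)
    logits = (q @ np.swapaxes(k, -1, -2)) * scale
    weights = softmax(logits, axis=-1)
    # (B, N, T, S) @ (B, N, S, H) -> (B, N, T, H)
    return weights @ v


def btnh_to_bnth(x):
    # (B, T, N, H) -> (B, N, T, H): swap the length and head axes.
    return np.swapaxes(x, 1, 2)


def bnth_to_btnh(x):
    # (B, N, T, H) -> (B, T, N, H): the same swap undoes itself.
    return np.swapaxes(x, 1, 2)


# Usage: random inputs in the JAX-style layout.
rng = np.random.default_rng(0)
B, T, S, N, H = 2, 5, 7, 4, 8
q = rng.standard_normal((B, T, N, H))
k = rng.standard_normal((B, S, N, H))
v = rng.standard_normal((B, S, N, H))

out_jax_layout = attention_btnh(q, k, v)
out_torch_layout = attention_bnth(btnh_to_bnth(q), btnh_to_bnth(k), btnh_to_bnth(v))

# Both layouts give the same result once the output axes are swapped back.
assert np.allclose(out_jax_layout, bnth_to_btnh(out_torch_layout))
```

When porting Torch attention code to JAX, swapping axes 1 and 2 on query, key and value before the call, and once more on the output, should be the only layout change needed, assuming there are no mask or bias tensors whose axes would also need rearranging.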