
Fix "best effort" hf sharding now that we have fancy meshes #622

Merged 4 commits into main from fix_hf_loading on Jun 11, 2024

Conversation

@dlwh (Member) commented Jun 11, 2024

Fixes #609 I think

@dlwh (Member, Author) commented Jun 11, 2024

got a report it's working. @versae lemme know if it still gives you problems

@dlwh merged commit 7bdd375 into main on Jun 11, 2024 (5 checks passed)
@dlwh deleted the fix_hf_loading branch on June 11, 2024 at 06:59
@rjpower (Collaborator) commented Jun 11, 2024

Dumb question, but why does this fix things? The original error was happening here:

lev_model = load_from_state_dict(state_dict)

Is it this line: haliax.partitioning._get_mesh() -- we now pick up the default mesh from the parent context and use that instead of inferring a sharding?

(I know a lot about meshes but almost nothing about JAX meshes, so I was a bit confused about why it threw an error originally instead of just (maybe) emitting a warning and reshuffling the naively sharded data into the final form. I'm guessing either it wants an explicit copy between meshes, or there's some "lower-level" mesh where it no longer has the information to reshard.)
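
For concreteness, a minimal plain-JAX sketch (not Haliax code) of what "picking up the default mesh from the parent context" means; the device count and axis name here are illustrative:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a mesh over whatever devices are available (e.g. 8 CPU devices in tests).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

with mesh:
    # Inside this block the mesh is the active "context" mesh. A helper like
    # haliax.partitioning._get_mesh() can return it here instead of building a
    # fresh one, so shardings created while loading a checkpoint refer to the
    # same mesh the later jitted computation sees.
    sharding = NamedSharding(mesh, P("data"))
    x = jax.device_put(np.arange(len(devices), dtype=np.float32), sharding)

print(x.sharding)  # a NamedSharding over the context mesh
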

@dlwh (Member, Author) commented Jun 11, 2024

So... it's the _get_mesh.

It's kind of working around a problem in Haliax (which I'm now pretty sure is working around a problem in JAX) more than anything. named_jit takes three optional axis mapping arguments (input, output, context/compute) and expects a context mesh (I should probably make it take a mesh arg). https://github.com/stanford-crfm/haliax/blob/main/src/haliax/partitioning.py#L312-L327 . This is partly historical, from the pre-jax.Array era, when arrays didn't know their shardings.
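
To make the three-mapping idea concrete, here is a rough plain-JAX analog (this is not named_jit's actual signature; the axis names and the axis_mapping dict are made up for illustration): input and output shardings are built by translating logical axis names through an axis mapping against the mesh.

from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
axis_mapping = {"batch": "data"}  # logical axis name -> physical mesh axis

def sharding_for(*logical_axes):
    # Translate logical axis names via the mapping; unmapped axes are replicated.
    return NamedSharding(mesh, P(*(axis_mapping.get(a) for a in logical_axes)))

# The input/output mappings become in_shardings/out_shardings; the
# "context/compute" mapping would drive sharding constraints inside the body.
@partial(jax.jit,
         in_shardings=sharding_for("batch"),
         out_shardings=sharding_for("batch"))
def step(x):
    return 2 * x

y = step(np.ones(len(jax.devices()), dtype=np.float32))
print(y.sharding)
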

Now that they do, it ought to be the case that if the input mapping isn't specified, named_jit should just omit the input shardings. It should actually further be the case that we don't use input_axis_mapping at all and just always preserve shardings. However, whenever I try to make that change, CPU tests fail when I use XLA_FLAGS=--xla_force_host_platform_device_count=8, so I never pulled the trigger. I realized the other day this is probably a bug in JAX, since xla_force_host_platform_device_count is kind of an afterthought for debugging.
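
A small sketch of that point, assuming current jax.Array semantics (the names below are illustrative): when no in_shardings are given, jit respects the sharding a committed input already carries, which is what would let named_jit drop the explicit input mapping.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
x = jax.device_put(np.ones(len(jax.devices()), dtype=np.float32),
                   NamedSharding(mesh, P("data")))

@jax.jit  # note: no in_shardings / out_shardings
def double(a):
    return 2 * a

y = double(x)
print(y.sharding)  # typically matches x.sharding: the input's layout was preserved
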

So what was happening is that we were using a different mesh than the "real" one, and then down the road telling jit that the shardings used the "real" mesh. This works around that by ensuring it's the same mesh... Gross, but it fixes things for three users and doesn't cause too much damage.

I'll see if I can do the real fix and just file a bug on the xla_force_host_platform_device_count thing.
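
For reference, the CPU debugging setup mentioned above looks roughly like this (the flag must be set before JAX is first imported):

import os

# Pretend the single host CPU is several devices so sharding code paths can be
# exercised in CPU-only tests. Must run before the first `import jax`.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # eight CPU devices
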

@rjpower (Collaborator) commented Jun 11, 2024

Ah interesting -- that makes sense -- thanks for the fix and explanation! I agree, the XLA CPU situation is always a bit of a gamble: it's great that it's there, but it definitely doesn't have the same functionality as the GPU/TPU side.

Successfully merging this pull request may close these issues: Incompatible devices for jitted computation (#609)