
[Tutorial] Remove incorrect caching from softmax tutorial #5162

Open · wants to merge 1 commit into main

Conversation

Mogball (Collaborator) commented Nov 15, 2024

The fused softmax implementation in the tutorial precompiles the kernel in order to query its register and shared-memory usage for the parameters used to specialize it. On top of this, it implements a simple caching system for that precompilation step, keyed only on the block size.

As noted in #4739, this caching is incorrect: the cache key does not include the `num_stages` constexpr argument or the shapes of the tensors. Since Triton already has its own JIT compilation cache, and this hand-rolled cache is not really relevant to the tutorial, just remove it to get rid of the footgun.
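
To make the footgun concrete, here is a toy sketch (hypothetical, not the tutorial's code) of how a cache keyed only on the block size goes stale: a later call with the same BLOCK_SIZE but a different num_stages or different tensor shapes silently reuses the first entry.

# Hypothetical, simplified illustration of the stale cache described above.
# Only BLOCK_SIZE is used as the key, so everything else that affects the
# compiled kernel (num_stages, tensor shapes/strides) is ignored on reuse.
kernels = {}

def precompile(BLOCK_SIZE, num_stages, shape):
    if BLOCK_SIZE not in kernels:
        # stand-in for the expensive warmup/compile step
        kernels[BLOCK_SIZE] = f"kernel(BLOCK_SIZE={BLOCK_SIZE}, num_stages={num_stages}, shape={shape})"
    return kernels[BLOCK_SIZE]

print(precompile(1024, num_stages=4, shape=(1823, 781)))
print(precompile(1024, num_stages=2, shape=(64, 781)))  # stale: still the num_stages=4 entry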

kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                               num_stages=num_stages, num_warps=num_warps, grid=(1, ))
occupancy = min(occupancy, SIZE_SMEM // size_smem)  # cap occupancy by shared-memory usage
num_programs = NUM_SM * occupancy
kernels[BLOCK_SIZE] = (kernel, num_programs)        # cache keyed only on BLOCK_SIZE (removed by this PR)
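
For reference, a minimal sketch of what this step could look like with the per-BLOCK_SIZE cache dropped, relying on Triton's own JIT cache to deduplicate repeated compilations. This is modeled on the tutorial, not the exact PR diff; NUM_REGS, WARP_SIZE, NUM_SM, SIZE_SMEM, num_warps and num_stages are assumed to be defined as in the surrounding tutorial code.

# Sketch only: precompile unconditionally (same warmup call as above) and let
# Triton's own JIT cache handle reuse instead of the kernels[BLOCK_SIZE] dict.
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                               num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs              # registers per thread for this specialization
size_smem = kernel.metadata.shared  # shared memory (bytes) for this specialization
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = min(NUM_SM * occupancy, n_rows)
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols)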
Contributor commented:

I think there is still a bug in this warmup-based precompilation, because warmup uses MockTensor:

return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)

And MockTensor doesn't respect the real pointer's alignment:

@staticmethod
def data_ptr():
    return 0  # optimistically assumes multiple of 16

This seems questionable though. I'm not sure why warmup couldn't just operate on the real tensors since it doesn't actually run any device code.
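
As a small, hypothetical illustration of why this matters (not code from the PR): Triton specializes pointer arguments on 16-byte divisibility, and a data_ptr() of 0 always passes that check, while a real tensor, e.g. a sliced view, may not.

import torch

def looks_16_byte_aligned(t: torch.Tensor) -> bool:
    # The divisibility check that MockTensor's data_ptr() == 0 always satisfies.
    return t.data_ptr() % 16 == 0

x = torch.randn(128, 128)  # fresh allocations are normally 16-byte aligned
view = x[:, 1:]            # offset by one float32 (4 bytes) from x's base pointer
print(looks_16_byte_aligned(x), looks_16_byte_aligned(view))  # typically: True False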

Mogball (author) replied:

Nice catch! Let me try digging into this.
