fix groupnorm int32 index overflow #1845

Open · wants to merge 2 commits into base: master
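Before the file diffs, a minimal sketch of the bug being fixed. With h, w, c, hw, hwc held in 32-bit int, the NHWC offset arithmetic inside these kernels wraps once a tensor exceeds 2^31 - 1 elements, which the new [1, 128, 16128, 1200] test case does (about 2.48 billion elements). The snippet below is illustrative only, not code copied from the kernels:

#include <cstdint>
#include <cstdio>

int main() {
  // Shape from the new test case: N=1, C=128, H=16128, W=1200 (NHWC layout).
  int h = 16128, w = 1200, c = 128;

  // 32-bit arithmetic: h * w * c = 2,477,260,800 > INT32_MAX, so the product
  // overflows (formally undefined behavior; it wraps negative in practice).
  int hwc_32 = h * w * c;

  // 64-bit arithmetic, which the PR's int64_t fields provide.
  int64_t hwc_64 = (int64_t) h * w * c;

  printf("32-bit hwc: %d\n", hwc_32);                // garbage
  printf("64-bit hwc: %lld\n", (long long) hwc_64);  // 2477260800
  return 0;
}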
8 changes: 4 additions & 4 deletions apex/contrib/csrc/group_norm/group_norm_nhwc.h
@@ -123,7 +123,7 @@ struct Group_norm_nhwc_fwd_params {
   // The number of instances in the batch.
   int n;
   // The height and width of each activation map. The number of channels.
-  int h, w, c, hw, hwc;
+  int64_t h, w, c, hw, hwc;
   // The number of groups.
   int groups;
   // Do we apply the Swish activation function?
@@ -138,7 +138,7 @@ struct Group_norm_nhwc_fwd_params {
   // The number of groups in each block.
   int groups_per_block;
   // The number of channels per group = c / groups.
-  int channels_per_group;
+  int channels_per_group;
   // The number of channels per block = groups_per_block * channels_per_group.
   int channels_per_block;
   // The inverse of hwc in floats (to compute mean/var).
@@ -149,7 +149,7 @@

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&,
+void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&,
                                           size_t &red_buffer_elts);

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -190,7 +190,7 @@ struct Group_norm_nhwc_bwd_params {
   // The number of instances in the batch.
   int n;
   // The height and width of each activation map. The number of channels.
-  int h, w, c, hw, hwc;
+  int64_t h, w, c, hw, hwc;
   // The number of groups.
   int groups;
   // Do we apply the Swish activation function?
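The field widening matters because hw and hwc are derived products, not independent sizes. A hypothetical sketch of how the setup code populates them once the fields are int64_t (the assignments themselves are not part of this hunk):

// With int64_t fields, these products are computed in 64 bits.
params.hw  = params.h * params.w;    // 19,353,600 for the new test case
params.hwc = params.hw * params.c;   // 2,477,260,800, beyond INT32_MAX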
10 changes: 5 additions & 5 deletions apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu
@@ -91,7 +91,7 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params
   // The first activation loaded by that block.
   int hw_begin = blockIdx.y * params.acts_per_block;
   // The last activation loaded by that block.
-  int hw_end = min(hw_begin + params.acts_per_block, params.hw);
+  int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);

   // The gradients for gamma/beta.
   float2 dgamma = make_float2(0.f, 0.f), dbeta = make_float2(0.f, 0.f);
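The (int64_t) cast here does two jobs: it matches the argument types of min() now that params.hw is 64-bit, and it forces the addition itself into 64-bit arithmetic, since int + int would wrap before any widening takes place. A standalone illustration with demo values (not taken from the kernel):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  int hw_begin = 2'000'000'000;   // demo values near INT32_MAX
  int acts_per_block = 500'000'000;
  int64_t hw = 2'400'000'000LL;

  // Wrong: the sum is computed in 32 bits and wraps (undefined behavior,
  // typically a negative value) before min() ever sees it.
  int64_t bad = std::min((int64_t) (hw_begin + acts_per_block), hw);

  // Right: casting one operand first promotes the addition to 64 bits.
  int64_t good = std::min((int64_t) hw_begin + acts_per_block, hw);

  printf("bad:  %lld\n", (long long) bad);    // wrapped value
  printf("good: %lld\n", (long long) good);   // 2400000000
  return 0;
}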
@@ -212,7 +212,7 @@ void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params &params,

   // Define the number of blocks per activation map. That's a simple heuristic.
   int blocks_per_act_slice = 0;
-  if( params.c >= 1280 ) {
+  if( params.c >= 1280 ) {
     blocks_per_act_slice = 128 / params.n;
   } else if( params.c >= 640 ) {
     blocks_per_act_slice = 256 / params.n;
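The heuristic divides a fixed block budget by the batch size, so the total number of CTAs per activation slice stays roughly constant as n grows; larger channel counts get a smaller budget because each block already covers more work. A worked sketch of the two branches visible in this hunk (the remaining branches are elided from the diff):

// Sketch of the visible branches only; the real function continues with
// branches this hunk does not show.
int blocks_per_slice(int c, int n) {
  if (c >= 1280) return 128 / n;   // e.g. c = 2560, n = 8  ->  16 blocks
  if (c >= 640)  return 256 / n;   // e.g. c = 640,  n = 8  ->  32 blocks
  return 0;                        // further cases elided in the diff
}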
@@ -267,13 +267,13 @@ void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params &params,
   // Make sure a group does not span multiple blocks.
   assert(params.channels_per_block % params.channels_per_group == 0);

-  // The number of elements in the reduction buffer (for the sums and sums of squared).
+  // The number of elements in the reduction buffer (for the sums and sums of squared).
   zeroed_red_buffer_elts = params.n * params.groups * 2 + params.c * 2;
 }
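The sizing counts two partial values per (sample, group), the sums behind mean and variance, plus two per channel for the dgamma/dbeta partials; the forward setup later in this PR allocates only n * groups * 2 because it has no per-channel gradients. A quick sanity check against the new test case, which is run by test_16_groups (so groups = 16), as a sketch:

#include <cstdint>
#include <cstdio>

int main() {
  // New test case [1, 128, 16128, 1200], run with 16 groups.
  int64_t n = 1, c = 128, groups = 16;

  int64_t bwd_elts = n * groups * 2 + c * 2;  // 32 + 256 = 288 elements
  int64_t fwd_elts = n * groups * 2;          // 32 elements

  printf("bwd: %lld, fwd: %lld\n", (long long) bwd_elts, (long long) fwd_elts);
  return 0;
}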

////////////////////////////////////////////////////////////////////////////////////////////////////

-void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params &params,
+void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params &params,
                                         cudaStream_t stream) {

// The dimension of the grid.
@@ -376,7 +376,7 @@ __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params params
   // The first activation loaded by that block.
   int hw_begin = blockIdx.y * params.acts_per_block;
   // The last activation loaded by that block.
-  int hw_end = min(hw_begin + params.acts_per_block, params.hw);
+  int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);

   // Iterate over the activations to compute the sums.
   for( int hwi = hw_begin; hwi < hw_end; ++hwi ) {
10 changes: 5 additions & 5 deletions apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu
@@ -55,7 +55,7 @@ __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params
   // The first activation loaded by that block.
   int hw_begin = blockIdx.y * params.acts_per_block;
   // The last activation loaded by that block.
-  int hw_end = min(hw_begin + params.acts_per_block, params.hw);
+  int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);

   // The sums.
   float sum = 0.f, sum_sq = 0.f;
@@ -132,7 +132,7 @@ void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params &params,

   // Define the number of blocks per activation map. That's a simple heuristic.
   int blocks_per_act_slice = 0;
-  if( params.c >= 1280 ) {
+  if( params.c >= 1280 ) {
     blocks_per_act_slice = 128 / params.n;
   } else if( params.c >= 640 ) {
     blocks_per_act_slice = 256 / params.n;
@@ -186,13 +186,13 @@ void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params &params,
   // Make sure a group does not span multiple blocks.
   assert(params.channels_per_block % params.channels_per_group == 0);

-  // The number of elements in the reduction buffer (for the sums and sums of squared).
+  // The number of elements in the reduction buffer (for the sums and sums of squared).
   zeroed_red_buffer_elts = params.n * params.groups * 2;
 }

////////////////////////////////////////////////////////////////////////////////////////////////////

-void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params &params,
+void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params &params,
                                         cudaStream_t stream) {

// The dimension of the grid.
@@ -285,7 +285,7 @@ __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params params
   // The first activation loaded by that block.
   int hw_begin = blockIdx.y * params.acts_per_block;
   // The last activation loaded by that block.
-  int hw_end = min(hw_begin + params.acts_per_block, params.hw);
+  int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);

   // Iterate over the activations to compute the sums.
   for( int hwi = hw_begin; hwi < hw_end; ++hwi ) {
9 changes: 5 additions & 4 deletions apex/contrib/test/group_norm/test_group_norm.py
@@ -89,10 +89,10 @@ def verify_group_norm(self,
         dx_tst, dw_tst, db_tst = [t.grad.clone() for t in [x, weight, bias]]

         # compare
-        torch.testing.assert_close(y_tst, y_ref, atol=4e-2, rtol=0)
-        torch.testing.assert_close(dx_tst, dx_ref, atol=4e-2, rtol=0)
-        torch.testing.assert_close(dw_tst, dw_ref, atol=4e-2, rtol=0)
-        torch.testing.assert_close(db_tst, db_ref, atol=4e-2, rtol=0)
+        torch.testing.assert_close(y_tst, y_ref, atol=7e-2, rtol=0)
+        torch.testing.assert_close(dx_tst, dx_ref, atol=7e-2, rtol=0)
+        torch.testing.assert_close(dw_tst, dw_ref, atol=7e-2, rtol=0)
+        torch.testing.assert_close(db_tst, db_ref, atol=7e-2, rtol=0)
Comment on lines +92 to +95

Collaborator:
is this for the large tensor?

Author:
yes, floating-point reduction error grows with larger tensors
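The loosened atol reflects a real fp32 property: a running sum loses low-order bits as the accumulator grows, and in the extreme stops growing once the increment falls below the accumulator's ULP. A small demonstration of that effect (not the kernels' actual block-parallel reduction, which suffers far less):

#include <cstdio>

int main() {
  // Add 1.0f one hundred million times. The fp32 sum saturates at
  // 16,777,216 (2^24): past that point, sum + 1.0f rounds back to sum.
  float sum = 0.f;
  for (long i = 0; i < 100000000L; ++i) {
    sum += 1.0f;
  }
  printf("fp32 sum: %.1f (exact answer: 100000000)\n", sum);  // 16777216.0
  return 0;
}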


     def test_fp16_one_pass_algo(self):
         self.verify_group_norm(cuda_group_norm_nhwc_one_pass, act="")
@@ -177,6 +177,7 @@ def test_16_groups(self):
             [8, 1920, 32, 32],
             [8, 1920, 16, 16],
             [8, 2560, 8, 8],
+            [1, 128, 16128, 1200],
Collaborator:
this seems to require about 50 GB, how about checking the memory size like

diff --git a/apex/contrib/test/group_norm/test_group_norm.py b/apex/contrib/test/group_norm/test_group_norm.py
index 5675749..687c2e1 100644
--- a/apex/contrib/test/group_norm/test_group_norm.py
+++ b/apex/contrib/test/group_norm/test_group_norm.py
@@ -177,8 +177,9 @@ class GroupNormTest(unittest.TestCase):
             [8, 1920, 32, 32],
             [8, 1920, 16, 16],
             [8, 2560, 8, 8],
-            [1, 128, 16128, 1200],
         ]
+        if torch.cuda.get_device_properties().total_memory > 50_000_000_000:
+            sizes.append([1, 128, 16128, 1200])
         for sz in sizes:
             n, c, h, w = sz
             self.verify_group_norm(GroupNorm,

Author (@tlogn, Nov 2, 2024):
I've checked carefully: this tensor has 128 * 16128 * 1200 = 2,477,260,800 elements, about 5 GB. The kernel will take about 10 GB in total, so this check doesn't seem necessary? (A back-of-envelope check of these numbers follows the end of this diff.)

         ]
         for sz in sizes:
             n, c, h, w = sz
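A back-of-envelope check of the numbers in that exchange, assuming fp16 activations at 2 bytes per element:

#include <cstdint>
#include <cstdio>

int main() {
  int64_t n = 1, c = 128, h = 16128, w = 1200;
  int64_t elts = n * c * h * w;   // 2,477,260,800 elements
  double gb = elts * 2.0 / 1e9;   // fp16: 2 bytes per element

  printf("one tensor:     %.2f GB\n", gb);        // ~4.95 GB, the author's ~5 GB
  printf("input + output: %.2f GB\n", 2.0 * gb);  // ~9.91 GB, the author's ~10 GB
  return 0;
}

This supports the author's figure rather than the 50 GB estimate, unless several gradient tensors of this size are resident at once.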