Use DisposableElementsAttr for ZHigh constant propagation #3013

Open
wants to merge 13 commits into main from zhigh-constprop-with-dispose
Conversation

@tungld tungld commented Nov 18, 2024

  • This patch reverts PR #2917 ([NNPA] Memory reduction of stickified constant by stickifying at file writing).
  • This patch extends the constant-propagation mechanism (currently applied to ONNXConstantOp) to ZHighStickifiedConstantOp, so that ZHigh constant propagation benefits from the memory management provided by DisposableElementsAttr. In particular, it:
    • makes all ZHigh optimizations/rewritings work with DisposableElementsAttr;
    • changes ZHighStickifiedConstantOp to accept DisposableElementsAttr; its parser and printer are changed to read/write DenseElementsAttr for lit tests;
    • changes ZHighConstantPropagationPass so that it reads data directly from DisposableElementsAttr instead of DenseElementsAttr;
    • adds two passes, ZHighDisposableGarbageCollector and ZHighScrubDisposablePass, to manage the buffers used by ZHighStickifiedConstantOp (a rough sketch of the buffer-management idea follows below).
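To illustrate the last bullet, here is a minimal, self-contained C++ sketch of the buffer-management idea. It is not the onnx-mlir implementation; BufferPool, BufferId, and garbageCollect are illustrative names only. A pool owns the raw buffers behind disposable constants, the garbage-collection step frees every buffer no longer referenced by a live constant (e.g. the inputs of folded Stick ops), and a final scrub step would then materialize the surviving constants into dense form before serialization.

// Hypothetical sketch, not the onnx-mlir API.
#include <cstdint>
#include <iterator>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using BufferId = std::uint64_t;
using Buffer = std::shared_ptr<std::vector<char>>;

class BufferPool {
public:
  // Register a buffer backing a stickified constant and return its id.
  BufferId insert(Buffer buf) {
    BufferId id = nextId++;
    buffers[id] = std::move(buf);
    return id;
  }

  // Keep only buffers whose ids are still referenced by constants in the
  // module; release the rest so their memory is freed immediately.
  void garbageCollect(const std::unordered_set<BufferId> &liveIds) {
    for (auto it = buffers.begin(); it != buffers.end();)
      it = liveIds.count(it->first) ? std::next(it) : buffers.erase(it);
  }

private:
  BufferId nextId = 0;
  std::unordered_map<BufferId, Buffer> buffers;
};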

Quick experiment: the peak compile memory consumption of #2917 and of this PR when compiling the gpt2-large model for NNPA (744M parameters; the constant file is 3.2 GB) is quite similar, about 9 GB in both cases.

This patch contains the reverting code, so it is not easy to follow. To ease the review, I merged all new changes (excluding the reverting code) into a single commit: 265ff90. Please look at that commit for the review.

… at file writing (onnx#2917)"

This reverts commit 33b466e.

Signed-off-by: Tung D. Le <[email protected]>
@tungld tungld force-pushed the zhigh-constprop-with-dispose branch from 68979b3 to 48cf039 on November 18, 2024 06:05
@AlexandreEichenberger

@tungld Just to understand the high level, without the class names: you are using Soren's approach of applying "logical" operations to the constants, so that, for example, if we have <large-constant-tensor> * 2 + 1, we just keep the original <large-constant-tensor> and tag the mult and add operators along with the constant; if we need to materialize the multiplied/added large constant tensor, we first apply these operations before generating the constant. And so, you added a stickify (presumably we never need an unstickify) operator?

@tungld tungld commented Nov 18, 2024

You are using Soren's approach of applying "logical" operations to the constants

Yes, I extended it to ZHigh operations, so the same approach is used for both ONNX and ZHigh until lowering to krnl. We could extend it to cover krnl operations as well, but that needs more work, so I did not do it in this PR.
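To make that concrete, below is a tiny, MLIR-free C++ sketch of the "keep the original constant and tag transforms along" idea, with a final stickify step added at the end of the chain. All names here are illustrative; the real mechanism is DisposableElementsAttr plus the ZHigh constant-propagation patterns.

#include <functional>
#include <vector>

using Transform = std::function<std::vector<float>(std::vector<float>)>;

struct LazyConstant {
  std::vector<float> original;       // large buffer, kept untouched
  std::vector<Transform> transforms; // e.g. *2, +1, stickify

  // The recorded transforms run only when the constant must be emitted.
  std::vector<float> materialize() const {
    std::vector<float> data = original;
    for (const Transform &t : transforms)
      data = t(data);
    return data;
  }
};

int main() {
  // <large-constant-tensor> * 2 + 1, recorded but not computed yet.
  LazyConstant c{{1.f, 2.f, 3.f}, {}};
  c.transforms.push_back([](std::vector<float> v) {
    for (float &x : v)
      x *= 2.f;
    return v;
  });
  c.transforms.push_back([](std::vector<float> v) {
    for (float &x : v)
      x += 1.f;
    return v;
  });
  // A final "stickify" step would reorder/convert elements into the NNPA
  // layout; here it is just a placeholder identity.
  c.transforms.push_back([](std::vector<float> v) { return v; });
  std::vector<float> out = c.materialize(); // {3, 5, 7}
  (void)out;
}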


@imaihal imaihal left a comment


LGTM!

Comment on lines +317 to +336
struct ConstantStickPattern : public OpRewritePattern<ZHighStickOp> {
  ConstantStickPattern(MLIRContext *context) : OpRewritePattern(context) {}
  LogicalResult matchAndRewrite(
      ZHighStickOp stickOp, PatternRewriter &rewriter) const override {
    Value input = stickOp.getIn();
    Value output = stickOp.getOut();
    StringAttr layout = stickOp.getLayoutAttr();

    // Match
    if (!isDenseONNXConstant(input)) {
      return failure();
    }

    // Rewrite
    Value stickifiedVal =
        createConstantForStick(rewriter, output, input, layout);
    replaceOpAndGC(rewriter, stickOp, stickifiedVal);
    return success();
  }
};
Collaborator


Just to confirm: you replaced TableGen with C++. Is this because replaceOpAndGC() is difficult to use in the TableGen format?

Collaborator Author


Yes, I don't know how to do that with TableGen. Let me know if you know how to do it. Thanks!
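For reference, a C++ pattern like ConstantStickPattern is typically added to the rewrite-pattern set next to (or instead of) the TableGen-generated patterns, roughly as below. The helper name is illustrative; the exact wiring in onnx-mlir may differ.

#include "mlir/IR/PatternMatch.h"

// Illustrative registration helper, not the actual onnx-mlir entry point.
void addZHighConstPropPatterns(
    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx) {
  patterns.insert<ConstantStickPattern>(ctx);
}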

@AlexandreEichenberger

Can you post the improvements you got here, just for future reference? It does not need to be super detailed. Thanks.

@imaihal imaihal commented Nov 25, 2024

Can you post the improvements you got here, just for future reference? It does not need to be super detailed. Thanks.

I put the measurement results for gpt2-large and Mistral-7b below.
For gpt2-large, peak memory usage dropped from 8.9 GB to 7.4 GB, and compilation time improved from 5 min 22 sec to 4 min 30 sec. The left graph is current main, and the right graph is PR #3013.

[graph: peak compile memory and time for gpt2-large, main (left) vs. PR #3013 (right)]

For Mistral-7b, peak memory usage dropped from 33.2 GB to 27.9 GB, and compilation time improved from 17 min 4 sec to 13 min 58 sec. The left graph is current main, and the right graph is PR #3013.

[graph: peak compile memory and time for Mistral-7b, main (left) vs. PR #3013 (right)]

@tungld tungld commented Nov 27, 2024

@jenkins-droid test this please
