Add Gemm+Elementwise+Gemm support #1774


Merged

dhernandez0 merged 1 commit into develop from 1704-gemm-elementwise-gemm on Apr 8, 2025
Conversation

Contributor

@dhernandez0 dhernandez0 commented Mar 12, 2025

This PR introduces GEMM+GEMM fusion. There are some things left for future PRs:

  • extend allowed types
  • Jenkins GEMM+GEMM (tuna-script.sh, perfRegression report, etc)
  • CK benchmark script
  • conv+gemm

I've created an epic for the pending tasks: https://github.com/ROCm/rocMLIR-internal/issues/1791

I've done an initial perf comparison of these GEMM+GEMM problems:

-transO false -transC false -transB true -transA false -t f32 -g 32 -m 8192 -n 8192 -k 128 -gemmO 128
-transO false -transC false -transB true -transA false -t f32 -g 16 -m 16384 -n 16384 -k 128 -gemmO 128

These correspond to the following four GEMMs (two per fused problem):

#-transO false -transC false -transB true -transA false -t f32 -g 32 -m 8192 -n 8192 -k 128 -gemmO 128
-transB true -transA false -g 32 -m 8192 -n 8192 -k 128 -t f32 -out_datatype f32
-transB false -transA false -g 32 -m 8192 -n 128 -k 8192 -t f32 -out_datatype f32
#-transO false -transC false -transB true -transA false -t f32 -g 16 -m 16384 -n 16384 -k 128 -gemmO 128
-transB true -transA false -g 16 -m 16384 -n 16384 -k 128 -t f32 -out_datatype f32
-transB false -transA false -g 16 -m 16384 -n 128 -k 16384 -t f32 -out_datatype f32

Using tuningRunner and perfRunner I get the following run-times:

Problem     run-time PR (ms)   run-time 2 GEMMs (ms)
Problem 1   8.9470             10.9019
Problem 2   17.5556            22.1153

@dhernandez0 dhernandez0 self-assigned this Mar 12, 2025
@dhernandez0 dhernandez0 requested a review from causten as a code owner March 12, 2025 12:17
@dhernandez0 dhernandez0 force-pushed the 1704-gemm-elementwise-gemm branch from a5dd470 to 0b3e9e1 Compare March 12, 2025 12:31
Comment on lines 571 to 575
RewritePatternSet patternsGemmElementwiseGemm(&ctx);
patternsGemmElementwiseGemm.add<GemmElementwiseGemmRewritePattern>(&ctx);
if (failed(applyOpPatternsGreedily(
getOperations<rock::GemmElementwiseGemmOp>(func),
std::move(patternsGemmElementwiseGemm), config)))
Member

I wonder, instead of applying each pattern separately, whether you could provide a ranking as is done in TosaToRock.

Or if there is some other way.

Contributor Author

The issue here is that we are replacing (for example) a GemmOp with another GemmOp. So we need to use GreedyRewriteStrictness::ExistingOps with applyOpPatternsGreedily and pass exactly the operations the pass should run on. Otherwise the matcher will match the new GemmOp again and we'll go into an infinite loop.

It might work by passing all operations and all rewrite patterns at once, but I think that can be done in a separate clean-up PR; I'm already making too many changes here.

@dhernandez0 dhernandez0 force-pushed the 1704-gemm-elementwise-gemm branch 3 times, most recently from 8943210 to ee9f776 Compare March 24, 2025 15:47
@dhernandez0 dhernandez0 changed the title [DRAFT] Add Gemm+Elementwise+Gemm support Add Gemm+Elementwise+Gemm support Mar 24, 2025
@dhernandez0 dhernandez0 mentioned this pull request Mar 24, 2025
1 task

codecov bot commented Mar 25, 2025

Codecov Report

Attention: Patch coverage is 67.18547% with 253 lines in your changes missing coverage. Please review.

Project coverage is 78.37%. Comparing base (1fb495e) to head (8d27a5c).
Report is 2 commits behind head on develop.

Files with missing lines Patch % Lines
mlir/tools/rocmlir-gen/rocmlir-gen.cpp 72.56% 38 Missing and 24 partials ⚠️
mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp 34.52% 40 Missing and 15 partials ⚠️
mlir/lib/Dialect/Rock/IR/RockDialect.cpp 54.70% 49 Missing and 4 partials ⚠️
...ialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp 75.74% 34 Missing and 15 partials ⚠️
...lir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp 77.77% 20 Missing and 6 partials ⚠️
mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h 0.00% 5 Missing ⚠️
.../Dialect/Rock/Transforms/AffixTuningParameters.cpp 88.23% 2 Missing ⚠️
mlir/lib/Dialect/Rock/Transforms/Regularize.cpp 0.00% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #1774      +/-   ##
===========================================
- Coverage    78.60%   78.37%   -0.23%     
===========================================
  Files           99      100       +1     
  Lines        29389    29768     +379     
  Branches      4379     4442      +63     
===========================================
+ Hits         23100    23332     +232     
- Misses        4492     4600     +108     
- Partials      1797     1836      +39     
Flag Coverage Δ
mfma 78.37% <67.18%> (-0.23%) ⬇️

@dhernandez0 dhernandez0 force-pushed the 1704-gemm-elementwise-gemm branch 2 times, most recently from bd2c24e to 54237b0 Compare March 26, 2025 07:53
@dhernandez0 dhernandez0 mentioned this pull request Mar 28, 2025
2 tasks
: Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<
MemoryEffectsOpInterface>,
RockFusionRoot]>,
AllElementTypesMatch<["a", "b", "c"]>,
Member

What if an elementwise op converts the data type in between? In that case, a * b may generate an output of a different dtype than c expects.

We should check for allowed elementwise ops.

Contributor Author

See the comment below: we don't know whether the element-wise tensors are used for indirect things (for example, a mask). So forcing the types to be the same with AllElementTypesMatch would rule out valid fusions.

But you are right; I'm not sure how this is verified for attention either. I think we probably fail when lowering rather than in the rock.attention verifier.

Member

But you are right, I'm not sure how this is verified for attention either, I think we probably fail when lowering instead of at rock.attention verifiers.

It is better to fail during verification than during lowering IMO.

Contributor Author

I agree. I'd prefer to fix this in the future, because it affects attention as well. Also, in practice I'm not sure whether this is an issue when lowering from migraphx or only in hand-crafted Rock IR kernels.

AllElementTypesMatch<["a", "b", "c"]>,
Arguments<(ins TensorOrMemRefOf<[F32]>:$a, TensorOrMemRefOf<[F32]>:$b,
TensorOrMemRefOf<[F32]>:$c,
Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
Member

I think the dtypes that are allowed for the elementwise inputs will be tied to the dtypes that the later GEMM expects.

Therefore it can't be AnyTensorOrMemRef.

Contributor Author

You can have a tensor that is used to indirectly change a * b (for example, the causal mask in attention kernels), so the tensors don't have to be the same type.

Comment on lines 338 to 340
Type elemTypeQ = cast<MemRefType>(op.getA().getType()).getElementType();
Type elemTypeK = cast<MemRefType>(op.getB().getType()).getElementType();
Type elemTypeV = cast<MemRefType>(op.getC().getType()).getElementType();
Member

I was wondering whether we could classify both Attention and GemmEwiseGemm under a GemmPlusGemmWrapperInterface and use that wherever we want to apply the same treatment to both.

You could then use dyn_cast on GemmPlusGemmWrapperInterface to check whether it is Attention or GemmEwiseGemm and, based on that, enable/disable softmax and other stages.

@umangyadav
Member

@dhernandez0 this PR has merge conflicts.

@dhernandez0 dhernandez0 force-pushed the 1704-gemm-elementwise-gemm branch from 3ae73c6 to a7a804e Compare April 7, 2025 14:22
@dhernandez0
Contributor Author

@dhernandez0 this PR has merge conflicts.

Solved.

/*desc=*/[{
Set the tuning parameters attribute of the first GEMM

This is needed for --affix-tuning-params to work and can go away if it does
Member

Can go away if it doesn't require it?

Contributor Author

Sorry, this is just copy-pasted from RockGemmWrapperInterface; I'm not sure what the original context was. We use the method, so it can't go away currently. Should I remove the sentence?

Member

Yes please remove the sentence

let summary = "Attention operation of transformer models";
let description = [{
Performs the operation out = SOFTMAX((queries * keys) .* scale) * values.
Member

I think it meant the scale for pre-softmax elementwise ops. Do we want to keep it?

Contributor Author

You can see here: b1e26db. scale used to be an actual input of AttentionOp; I guess that was before pre-softmax fusion was implemented, so it made sense to mention scale there since it was an input. Now scale is just one possible fusion, so to me it doesn't make sense to keep it in the description. What do you think?

Member

Perhaps we can mention that it also applies an elementwise op before the softmax.

return GemmGemmSize(g, m, k, n, o);
}

static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
Member

nit:

Suggested change
static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,

Member

@umangyadav umangyadav left a comment

I don't have any pressing comments in particular. Looks good.

@dhernandez0 dhernandez0 force-pushed the 1704-gemm-elementwise-gemm branch from 3f02a22 to d7ad776 Compare April 8, 2025 10:09
@dhernandez0 dhernandez0 force-pushed the 1704-gemm-elementwise-gemm branch from 4821b84 to 8d27a5c Compare April 8, 2025 10:41
@dhernandez0 dhernandez0 merged commit ceae1ec into develop Apr 8, 2025
14 of 25 checks passed
@dhernandez0 dhernandez0 deleted the 1704-gemm-elementwise-gemm branch April 8, 2025 13:45