Add Gemm+Elementwise+Gemm support #1774
Conversation
```cpp
RewritePatternSet patternsGemmElementwiseGemm(&ctx);
patternsGemmElementwiseGemm.add<GemmElementwiseGemmRewritePattern>(&ctx);
if (failed(applyOpPatternsGreedily(
        getOperations<rock::GemmElementwiseGemmOp>(func),
        std::move(patternsGemmElementwiseGemm), config)))
```
I wonder, instead of applying each set of patterns separately, whether you could assign a ranking as is done in TosaToRock, or whether there is some other way.
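A minimal sketch of what such a ranking-based registration might look like (the pattern names other than GemmElementwiseGemmRewritePattern and the benefit values are illustrative assumptions, not the actual code):

```cpp
// Hypothetical: register all patterns in one set and let PatternBenefit
// decide priority, instead of applying each pattern set separately.
RewritePatternSet patterns(&ctx);
patterns.add<AttentionRewritePattern>(&ctx, /*benefit=*/3);           // assumed name
patterns.add<GemmElementwiseGemmRewritePattern>(&ctx, /*benefit=*/2);
patterns.add<GemmRewritePattern>(&ctx, /*benefit=*/1);                // assumed name
```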
The issue here is that we are replacing (for example) a GemmOp with another GemmOp. So we need to use GreedyRewriteStrictness::ExistingOps with applyOpPatternsGreedily and pass exactly the operations the patterns should run on. Otherwise the matcher would match the new GemmOp again and we'd go into an infinite loop.
It might work by passing all operations and all rewrite patterns, but I think that can be done in another clean-up PR; I'm already making too many changes here.
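For reference, the strictness setting being referred to is configured on the greedy driver roughly like this (a sketch; the field name follows upstream MLIR):

```cpp
// Only ops in the initial worklist are eligible for matching; ops created
// by a rewrite (e.g. the replacement GemmOp) are never re-matched.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
```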
Codecov Report

```
@@            Coverage Diff             @@
##           develop    #1774     +/-   ##
===========================================
- Coverage    78.60%   78.37%    -0.23%
===========================================
  Files           99      100       +1
  Lines        29389    29768     +379
  Branches      4379     4442      +63
===========================================
+ Hits         23100    23332     +232
- Misses        4492     4600     +108
- Partials      1797     1836      +39
```
```tablegen
: Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<
                                        MemoryEffectsOpInterface>,
                                    RockFusionRoot]>,
  AllElementTypesMatch<["a", "b", "c"]>,
```
What if the elementwise op converts the data type in between? In that case a * b may produce an output of a different dtype than c expects. We should check for allowed elementwise ops.
See the comment below: we don't know whether the element-wise tensors are used for indirect purposes (for example, a mask), so forcing the types to match with AllElementTypesMatch would rule out valid fusions.
But you are right; I'm not sure how this is verified for attention either. I think we probably fail when lowering instead of at the rock.attention verifier.
> But you are right, I'm not sure how this is verified for attention either. I think we probably fail when lowering instead of at the rock.attention verifier.

It is better to fail during verification than during lowering, IMO.
I agree; I'd prefer to fix this in a future PR, because it affects attention as well. Also, in practice I'm not sure whether this is an issue when lowering from MIGraphX or only in hand-crafted Rock IR kernels.
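A rough sketch of what such a verifier-time check could look like (a hypothetical helper, not part of this PR; the getter names follow the snippets below):

```cpp
// Hypothetical: reject dtype mismatches at verification instead of lowering.
static LogicalResult verifyElementTypes(rock::GemmElementwiseGemmOp op) {
  Type abType = cast<ShapedType>(op.getA().getType()).getElementType();
  Type cType = cast<ShapedType>(op.getC().getType()).getElementType();
  if (abType != cType)
    return op.emitOpError(
        "element type produced by the first GEMM must match the second "
        "GEMM's input element type");
  return success();
}
```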
```tablegen
AllElementTypesMatch<["a", "b", "c"]>,
Arguments<(ins TensorOrMemRefOf<[F32]>:$a, TensorOrMemRefOf<[F32]>:$b,
    TensorOrMemRefOf<[F32]>:$c,
    Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
```
I think the dtypes allowed for the elementwise inputs will be tied to the dtypes the second GEMM expects later on, so this can't be AnyTensorOrMemRef.
You can have a tensor that is used to indirectly modify a * b (for example, a causal mask in attention kernels), so the tensors don't have to share the element type.
```cpp
Type elemTypeQ = cast<MemRefType>(op.getA().getType()).getElementType();
Type elemTypeK = cast<MemRefType>(op.getB().getType()).getElementType();
Type elemTypeV = cast<MemRefType>(op.getC().getType()).getElementType();
```
I was wondering whether we could classify both Attention and GemmEwiseGemm under a GemmPlusGemmWrapperInterface and use that wherever we want to apply the same treatment to both. You could then use dyn_cast on GemmPlusGemmWrapperInterface to check whether it is Attention or GemmEwGemm and, based on that, enable or disable softmax and other logic.
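A sketch of the dispatch being proposed (GemmPlusGemmWrapperInterface is the name suggested above; the PR itself defines RockGemmGemmWrapperInterface):

```cpp
// Hypothetical: share the GEMM+GEMM handling through the interface and
// branch on the concrete op only where behavior differs, e.g. softmax.
if (auto wrapper = dyn_cast<GemmPlusGemmWrapperInterface>(op)) {
  bool needsSoftmax = isa<rock::AttentionOp>(wrapper.getOperation());
  // ... common GEMM+GEMM treatment; softmax only for attention ...
}
```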
@dhernandez0 this PR has merge conflicts.
Solved.
```tablegen
/*desc=*/[{
  Set the tuning parameters attribute of the first GEMM

  This is needed for --affix-tuning-params to work and can go away if it does
```
Do you mean it can go away if --affix-tuning-params doesn't require it?
Sorry, this is just copy-paste from RockGemmWrapperInterface. I'm not sure what the original context was; we use the method, so it can't go away currently. Should I remove the sentence?
Yes, please remove the sentence.
```tablegen
let summary = "Attention operation of transformer models";
let description = [{
  Performs the operation out = SOFTMAX((queries * keys) .* scale) * values.
```
I think scale here is meant to stand for the preSoftmaxElemwise ops. Do we want to keep it?
You can see here: b1e26db
scale used to be an actual input of the attention op; I guess that was before pre-softmax fusion was implemented, so it made sense to mention scale since it was an input. Now scale is just one possible fusion, so to me it no longer makes sense to have it in the description. What do you think?
Perhaps we can mention that it also applies elementwise ops before doing the softmax.
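For example, the description could be reworded along these lines (a suggestion, not the final text):

```tablegen
let description = [{
  Performs out = SOFTMAX(elementwise(queries * keys)) * values, where the
  elementwise stage is an optional fused pre-softmax computation
  (e.g. scaling by a constant or applying a mask).
}];
```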
```cpp
  return GemmGemmSize(g, m, k, n, o);
}

static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
```
nit:
```diff
-static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
+static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
```
I don't have any pressing comments in particular. Looks good.
This PR introduces GEMM+GEMM fusion. There are some things left for future PRs; I've created an epic for the pending tasks: https://github.com/ROCm/rocMLIR-internal/issues/1791
I've done an initial perf comparison of these GEMM+GEMM problems:
These correspond to the following 4 GEMMs:
Using tuningRunner and perfRunner, I get the following run-times: