13
13
#include " iree/compiler/Dialect/Stream/IR/StreamOps.h"
14
14
#include " mlir/Dialect/Arith/IR/Arith.h"
15
15
#include " mlir/Dialect/Tensor/IR/Tensor.h"
16
+ #include " mlir/IR/BuiltinDialect.h"
16
17
#include " mlir/IR/IRMapping.h"
17
18
#include " mlir/Interfaces/FunctionInterfaces.h"
18
19
19
20
namespace mlir ::iree_compiler {
20
21
21
22
namespace {
22
23
24
+ static SmallVector<Value> flattenValues (ArrayRef<ValueRange> values) {
25
+ SmallVector<Value> vec;
26
+ for (auto v : values) {
27
+ vec.append (v.begin (), v.end ());
28
+ }
29
+ return vec;
30
+ }
31
+
23
32
// Inserts a sizeof calculation for the given tensor value type and dims.
24
33
// This should only be used to produce sizes for values produced by an op; the
25
34
// size of operands must be queried from the input resource.
@@ -142,6 +151,33 @@ struct ConvertTensorCastLikeOp
142
151
}
143
152
};
144
153
154
+ template <typename CastOpTy>
155
+ struct ConvertOneToNTensorCastLikeOp
156
+ : public AffinityAwareConversionPattern<CastOpTy> {
157
+ using AffinityAwareConversionPattern<
158
+ CastOpTy>::AffinityAwareConversionPattern;
159
+ LogicalResult matchAndRewrite (
160
+ CastOpTy op,
161
+ typename OpConversionPattern<CastOpTy>::OneToNOpAdaptor adaptor,
162
+ ConversionPatternRewriter &rewriter) const override {
163
+ auto resultAffinityAttr = this ->lookupResultAffinity (op.getResult ());
164
+ Value convertedSource =
165
+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getSource ());
166
+ auto source = this ->transferTensorOperand (op.getLoc (), op.getSource (),
167
+ convertedSource,
168
+ resultAffinityAttr, rewriter);
169
+ auto resultSize =
170
+ buildResultSizeOf (op.getLoc (), op.getResult (), op.getResultDims (),
171
+ resultAffinityAttr, rewriter);
172
+ auto unknownType = rewriter.getType <IREE::Stream::ResourceType>();
173
+ rewriter.replaceOpWithNewOp <IREE::Stream::TensorCloneOp>(
174
+ op, unknownType, source.resource , op.getSource ().getType (),
175
+ op.getSourceDims (), source.resourceSize , op.getResult ().getType (),
176
+ flattenValues (adaptor.getResultDims ()), resultSize, resultAffinityAttr);
177
+ return success ();
178
+ }
179
+ };
180
+
145
181
struct ConvertTensorAllocaOp
146
182
: public AffinityOpConversionPattern<IREE::Flow::TensorAllocaOp> {
147
183
using AffinityOpConversionPattern::AffinityOpConversionPattern;
@@ -237,46 +273,55 @@ struct ConvertTensorTransferOp
237
273
};
238
274
239
275
struct ConvertTensorSliceOp
240
- : public AffinityOpConversionPattern <IREE::Flow::TensorSliceOp> {
241
- using AffinityOpConversionPattern::AffinityOpConversionPattern ;
276
+ : public AffinityOneToNOpConversionPattern <IREE::Flow::TensorSliceOp> {
277
+ using AffinityOneToNOpConversionPattern::AffinityOneToNOpConversionPattern ;
242
278
LogicalResult matchAndRewriteOnAffinity (
243
- IREE::Flow::TensorSliceOp op, OpAdaptor adaptor,
279
+ IREE::Flow::TensorSliceOp op, OneToNOpAdaptor adaptor,
244
280
IREE::Stream::AffinityAttr executionAffinityAttr,
245
281
ConversionPatternRewriter &rewriter) const override {
282
+ Value convertedSource =
283
+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getSource ());
246
284
auto source =
247
- transferTensorOperand (op.getLoc (), op.getSource (), adaptor. getSource () ,
285
+ transferTensorOperand (op.getLoc (), op.getSource (), convertedSource ,
248
286
executionAffinityAttr, rewriter);
249
287
auto resultSize =
250
288
buildResultSizeOf (op.getLoc (), op.getResult (), op.getResultDims (),
251
289
executionAffinityAttr, rewriter);
252
290
auto unknownType = rewriter.getType <IREE::Stream::ResourceType>();
253
291
rewriter.replaceOpWithNewOp <IREE::Stream::TensorSliceOp>(
254
292
op, unknownType, source.resource , op.getSource ().getType (),
255
- op.getSourceDims (), source.resourceSize , adaptor.getStartIndices (),
256
- adaptor.getLengths (), op.getResult ().getType (), adaptor.getResultDims (),
257
- resultSize, executionAffinityAttr);
293
+ op.getSourceDims (), source.resourceSize ,
294
+ flattenValues (adaptor.getStartIndices ()),
295
+ flattenValues (adaptor.getLengths ()), op.getResult ().getType (),
296
+ flattenValues (adaptor.getResultDims ()), resultSize,
297
+ executionAffinityAttr);
258
298
return success ();
259
299
}
260
300
};
261
301
262
302
struct ConvertTensorUpdateOp
263
- : public AffinityOpConversionPattern <IREE::Flow::TensorUpdateOp> {
264
- using AffinityOpConversionPattern::AffinityOpConversionPattern ;
303
+ : public AffinityOneToNOpConversionPattern <IREE::Flow::TensorUpdateOp> {
304
+ using AffinityOneToNOpConversionPattern::AffinityOneToNOpConversionPattern ;
265
305
LogicalResult matchAndRewriteOnAffinity (
266
- IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor,
306
+ IREE::Flow::TensorUpdateOp op, OneToNOpAdaptor adaptor,
267
307
IREE::Stream::AffinityAttr executionAffinityAttr,
268
308
ConversionPatternRewriter &rewriter) const override {
309
+ Value convertedTarget =
310
+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getTarget ());
269
311
auto target =
270
- transferTensorOperand (op.getLoc (), op.getTarget (), adaptor. getTarget () ,
312
+ transferTensorOperand (op.getLoc (), op.getTarget (), convertedTarget ,
271
313
executionAffinityAttr, rewriter);
314
+ Value convertedUpdate =
315
+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getUpdate ());
272
316
auto update =
273
- transferTensorOperand (op.getLoc (), op.getUpdate (), adaptor. getUpdate () ,
317
+ transferTensorOperand (op.getLoc (), op.getUpdate (), convertedUpdate ,
274
318
executionAffinityAttr, rewriter);
275
319
rewriter.replaceOpWithNewOp <IREE::Stream::TensorUpdateOp>(
276
320
op, target.resource .getType (), target.resource ,
277
- op.getTarget ().getType (), adaptor.getTargetDims (), target.resourceSize ,
278
- adaptor.getStartIndices (), update.resource , op.getUpdate ().getType (),
279
- op.getUpdateDims (), update.resourceSize , executionAffinityAttr);
321
+ op.getTarget ().getType (), flattenValues (adaptor.getTargetDims ()),
322
+ target.resourceSize , flattenValues (adaptor.getStartIndices ()),
323
+ update.resource , op.getUpdate ().getType (), op.getUpdateDims (),
324
+ update.resourceSize , executionAffinityAttr);
280
325
return success ();
281
326
}
282
327
};
@@ -296,10 +341,12 @@ struct ConvertTensorLoadOp
296
341
: public AffinityAwareConversionPattern<IREE::Flow::TensorLoadOp> {
297
342
using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
298
343
LogicalResult
299
- matchAndRewrite (IREE::Flow::TensorLoadOp op, OpAdaptor adaptor,
344
+ matchAndRewrite (IREE::Flow::TensorLoadOp op, OneToNOpAdaptor adaptor,
300
345
ConversionPatternRewriter &rewriter) const override {
346
+ Value convertedSource =
347
+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getSource ());
301
348
auto source = resolveTensorOperand (op.getLoc (), op.getSource (),
302
- adaptor. getSource () , rewriter);
349
+ convertedSource , rewriter);
303
350
304
351
// If the source is not a staging resource then we need to transfer it to
305
352
// a staging resource. We slice out just what is being loaded so that we
@@ -311,10 +358,13 @@ struct ConvertTensorLoadOp
311
358
auto stagingType = rewriter.getType <IREE::Stream::ResourceType>(
312
359
IREE::Stream::Lifetime::Staging);
313
360
auto resultType = getTypeConverter ()->convertType (op.getResult ().getType ());
361
+ SmallVector<Value> convertedSourceDims =
362
+ flattenValues (adaptor.getSourceDims ());
363
+ SmallVector<Value> convertedIndices = flattenValues (adaptor.getIndices ());
314
364
if (source.resource .getType () == stagingType) {
315
365
rewriter.replaceOpWithNewOp <IREE::Stream::TensorLoadOp>(
316
366
op, resultType, source.resource , op.getSource ().getType (),
317
- adaptor. getSourceDims () , source.resourceSize , adaptor. getIndices () );
367
+ convertedSourceDims , source.resourceSize , convertedIndices );
318
368
return success ();
319
369
}
320
370
@@ -328,19 +378,18 @@ struct ConvertTensorLoadOp
328
378
/* result_affinity=*/ source.affinity );
329
379
rewriter.replaceOpWithNewOp <IREE::Stream::TensorLoadOp>(
330
380
op, resultType, transferOp.getResult (), sourceEncoding,
331
- adaptor.getSourceDims (), transferOp.getResultSize (),
332
- adaptor.getIndices ());
381
+ convertedSourceDims, transferOp.getResultSize (), convertedIndices);
333
382
return success ();
334
383
}
335
384
336
385
// Slice out the individual element value.
337
386
IndexSet indexSet (op.getLoc (), rewriter);
338
- indexSet.populate (adaptor. getIndices () );
387
+ indexSet.populate (convertedIndices );
339
388
SmallVector<Value> sliceIndices;
340
389
SmallVector<Value> sliceLengths;
341
390
SmallVector<Value> loadIndices;
342
391
SmallVector<int64_t > resultDims;
343
- for (auto index : adaptor. getIndices () ) {
392
+ for (auto index : convertedIndices ) {
344
393
// TODO(benvanik): support larger buffer slices.
345
394
sliceIndices.push_back (index);
346
395
sliceLengths.push_back (indexSet.get (1 ));
@@ -354,9 +403,8 @@ struct ConvertTensorLoadOp
354
403
op.getLoc (), resultEncoding, ValueRange{}, source.affinity );
355
404
auto sliceOp = rewriter.create <IREE::Stream::TensorSliceOp>(
356
405
op.getLoc (), source.resource .getType (), source.resource , sourceEncoding,
357
- adaptor.getSourceDims (), source.resourceSize , sliceIndices,
358
- sliceLengths, resultEncoding, ValueRange{}, resultSize,
359
- source.affinity );
406
+ convertedSourceDims, source.resourceSize , sliceIndices, sliceLengths,
407
+ resultEncoding, ValueRange{}, resultSize, source.affinity );
360
408
auto transferOp = rewriter.create <IREE::Stream::AsyncTransferOp>(
361
409
op.getLoc (), stagingType, sliceOp.getResult (), sliceOp.getResultSize (),
362
410
sliceOp.getResultSize (),
@@ -713,10 +761,10 @@ struct ConvertCollectiveSendRecvOp
713
761
};
714
762
715
763
struct ConvertDispatchOp
716
- : public AffinityOpConversionPattern <IREE::Flow::DispatchOp> {
717
- using AffinityOpConversionPattern::AffinityOpConversionPattern ;
764
+ : public AffinityOneToNOpConversionPattern <IREE::Flow::DispatchOp> {
765
+ using AffinityOneToNOpConversionPattern::AffinityOneToNOpConversionPattern ;
718
766
LogicalResult matchAndRewriteOnAffinity (
719
- IREE::Flow::DispatchOp op, OpAdaptor adaptor,
767
+ IREE::Flow::DispatchOp op, OneToNOpAdaptor adaptor,
720
768
IREE::Stream::AffinityAttr executionAffinityAttr,
721
769
ConversionPatternRewriter &rewriter) const override {
722
770
// Zero is going to be used for each operand to start.
@@ -729,8 +777,11 @@ struct ConvertDispatchOp
729
777
SmallVector<Value> dispatchOperandEnds;
730
778
SmallVector<Value> dispatchOperandLengths;
731
779
SmallVector<Value> operandSizes;
780
+
781
+ SmallVector<Value> convertedArguments =
782
+ getStreamResourcesFromOneToNOpOperandAdaptors (adaptor.getArguments ());
732
783
for (auto [oldOperand, newOperand] :
733
- llvm::zip_equal (op.getArguments (), adaptor. getArguments () )) {
784
+ llvm::zip_equal (op.getArguments (), convertedArguments )) {
734
785
if (llvm::isa<ShapedType>(oldOperand.getType ())) {
735
786
auto newOperandCast =
736
787
transferTensorOperand (op.getLoc (), oldOperand, newOperand,
@@ -774,10 +825,10 @@ struct ConvertDispatchOp
774
825
}
775
826
776
827
auto newOp = rewriter.replaceOpWithNewOp <IREE::Stream::AsyncDispatchOp>(
777
- op, resultTypes, adaptor.getWorkload (), adaptor. getEntryPointsAttr ( ),
778
- dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets ,
779
- dispatchOperandEnds, dispatchOperandLengths, resultSizes ,
780
- adaptor.getTiedOperandsAttr (), executionAffinityAttr);
828
+ op, resultTypes, flattenValues ( adaptor.getWorkload ()),
829
+ adaptor. getEntryPointsAttr (), dispatchOperands, dispatchOperandSizes ,
830
+ dispatchOperandOffsets, dispatchOperandEnds, dispatchOperandLengths ,
831
+ resultSizes, adaptor.getTiedOperandsAttr (), executionAffinityAttr);
781
832
newOp->setDialectAttrs (op->getDialectAttrs ());
782
833
return success ();
783
834
}
@@ -1105,8 +1156,8 @@ void populateFlowToStreamConversionPatterns(
1105
1156
RewritePatternSet &patterns) {
1106
1157
patterns
1107
1158
.insert <ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
1108
- ConvertTensorCastLikeOp <IREE::Flow::TensorReshapeOp>,
1109
- ConvertTensorCastLikeOp <IREE::Flow::TensorBitCastOp>,
1159
+ ConvertOneToNTensorCastLikeOp <IREE::Flow::TensorReshapeOp>,
1160
+ ConvertOneToNTensorCastLikeOp <IREE::Flow::TensorBitCastOp>,
1110
1161
ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
1111
1162
ConvertTensorCloneOp, ConvertTensorTransferOp,
1112
1163
ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
0 commit comments