Skip to content

Commit 437c621

Browse files
[mlir][memref] Remove redundant memref.tensor_store op (#71010)
`bufferization.materialize_in_destination` should be used instead. Both ops bufferize to a memcpy. This change also conceptually cleans up the memref dialect a bit: the memref dialect no longer contains ops that operate on tensor values.
1 parent 6529c9a commit 437c621

File tree

15 files changed

+53
-191
lines changed

15 files changed

+53
-191
lines changed

mlir/docs/LangRef.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,12 @@ func.func @mul(%A: tensor<100x?xf32>, %B: tensor<?x50xf32>) -> (tensor<100x50xf3
7777
7878
// Allocate addressable "buffers" and copy tensors %A and %B into them.
7979
%A_m = memref.alloc(%n) : memref<100x?xf32>
80-
memref.tensor_store %A to %A_m : memref<100x?xf32>
80+
bufferization.materialize_in_destination %A in writable %A_m
81+
: (tensor<100x?xf32>, memref<100x?xf32>) -> ()
8182
8283
%B_m = memref.alloc(%n) : memref<?x50xf32>
83-
memref.tensor_store %B to %B_m : memref<?x50xf32>
84+
bufferization.materialize_in_destination %B in writable %B_m
85+
: (tensor<?x50xf32>, memref<?x50xf32>) -> ()
8486
8587
// Call function @multiply passing memrefs as arguments,
8688
// and getting returned the result of the multiplication.

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
103103

104104
```
105105
%alloc = memref.alloc() : memref<10xf32>
106-
memref.tensor_store %dest, %alloc : memref<10xf32>
106+
bufferization.materialize_in_destination %dest in %alloc
107107
memref.store %f, %alloc[%pos] : memref<10xf32>
108108
%0 = bufferization.to_tensor %alloc restrict writable : memref<10xf32>
109109
```
@@ -118,15 +118,16 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
118118
An optional memory space attribute can be specified for the materialized
119119
buffer allocation.
120120

121-
If a memory copy is needed, a "memref.tensor_store" is used when possible.
122-
This is an op with tensor semantics that will bufferize to a memory copy
123-
later. Which concrete op will be used for the memory copy is up to the
124-
bufferization framework. Alternatively, a custom memcpy op can be specified
125-
via `memcpy_op`. Currently supported are "memref.copy" and "linalg.copy".
126-
In that case, the source of each memcpy must not have a custom memory space.
127-
Furthermore, because the future buffer layout unknown for a given tensor,
128-
a fully dynamic layout is assumed for best compatibility. Users should use
129-
"memref.tensor_store" when possible.
121+
If a memory copy is needed, a "bufferization.materialize_in_destination" is
122+
used when possible. This is an op with tensor semantics that will bufferize
123+
to a memory copy later. Which concrete op will be used for the memory copy
124+
is up to the bufferization framework. Alternatively, a custom memcpy op can
125+
be specified via `memcpy_op`. Currently supported are "memref.copy" and
126+
"linalg.copy". In that case, the source of each memcpy must not have a
127+
custom memory space. Furthermore, because the future buffer layout unknown
128+
for a given tensor, a fully dynamic layout is assumed for best
129+
compatibility. Users should use "bufferization.materialize_in_destination"
130+
when possible.
130131

131132
"memref.alloc" is used for new buffer allocations. The buffer is deallocated
132133
at the end of the block if the "emit_dealloc" attribute is present. If this
@@ -148,7 +149,8 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
148149

149150
let arguments = (ins TransformHandleTypeInterface:$target,
150151
OptionalAttr<AnyAttr>:$memory_space,
151-
DefaultValuedAttr<StrAttr, "\"memref.tensor_store\"">:
152+
DefaultValuedAttr<StrAttr,
153+
"\"bufferization.materialize_in_destination\"">:
152154
$memcpy_op,
153155
DefaultValuedAttr<StrAttr, "\"memref.alloc\"">:
154156
$alloc_op,

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ struct BufferizeToAllocationOptions {
5252
enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 };
5353
AllocOp allocOp = AllocOp::MemrefAlloc;
5454

55-
enum class MemcpyOp { MemrefTensorStore = 0, MemrefCopy = 1, LinalgCopy = 2 };
56-
MemcpyOp memcpyOp = MemcpyOp::MemrefTensorStore;
55+
enum class MemcpyOp {
56+
MaterializeInDestination = 0,
57+
MemrefCopy = 1,
58+
LinalgCopy = 2
59+
};
60+
MemcpyOp memcpyOp = MemcpyOp::MaterializeInDestination;
5761

5862
/// If set to "true", only the destination tensor operands are bufferized to
5963
/// a new allocation (and wrapped in "bufferization.to_tensor"), but not the
@@ -68,7 +72,8 @@ struct BufferizeToAllocationOptions {
6872
};
6973

7074
/// Materialize a buffer allocation for the given tensor.pad op and lower the
71-
/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
75+
/// op to linalg.fill/linalg.generic + bufferization.materialize_in_destination.
76+
/// E.g.:
7277
///
7378
/// %0 = tensor.pad low[%l] high[%h] %t ...
7479
///
@@ -77,7 +82,7 @@ struct BufferizeToAllocationOptions {
7782
/// %alloc = memref.alloc
7883
/// linalg.fill ... outs(%alloc)
7984
/// %subview = memref.subview %alloc [%l] [...] [1]
80-
/// memref.tensor_store %t, %subview
85+
/// bufferization.materialize_in_destination %t in %subview
8186
/// %0 = bufferization.to_tensor %alloc restrict writable
8287
///
8388
/// In addition to rewriting the IR as shown above, this function returns the
@@ -98,7 +103,7 @@ Value bufferizeToAllocation(RewriterBase &rewriter,
98103
/// is lowered to:
99104
///
100105
/// %alloc = memref.alloc
101-
/// memref.tensor_store %t, %subview
106+
/// bufferization.materialize_in_destination %t in %subview
102107
/// vector.mask {
103108
/// vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
104109
/// } : vector<16xi1>

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,37 +2095,6 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20952095
let hasVerifier = 1;
20962096
}
20972097

2098-
//===----------------------------------------------------------------------===//
2099-
// TensorStoreOp
2100-
//===----------------------------------------------------------------------===//
2101-
2102-
def TensorStoreOp : MemRef_Op<"tensor_store",
2103-
[SameOperandsShape, SameOperandsElementType,
2104-
TypesMatchWith<"type of 'value' matches tensor equivalent of 'memref'",
2105-
"memref", "tensor",
2106-
"getTensorTypeFromMemRefType($_self)">]> {
2107-
let summary = "tensor store operation";
2108-
let description = [{
2109-
Stores the contents of a tensor into a memref. The first operand is a value
2110-
of tensor type, the second operand is a value of memref type. The shapes and
2111-
element types of these must match, and are specified by the memref type.
2112-
2113-
Example:
2114-
2115-
```mlir
2116-
%9 = dim %8, 1 : tensor<4x?xf32>
2117-
%10 = memref.alloc(%9) : memref<4x?xf32, #layout, memspace0>
2118-
memref.tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
2119-
```
2120-
}];
2121-
2122-
let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
2123-
"the reference to store to",
2124-
[MemWriteAt<0, FullEffect>]>:$memref);
2125-
2126-
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
2127-
}
2128-
21292098
//===----------------------------------------------------------------------===//
21302099
// TransposeOp
21312100
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h

Lines changed: 0 additions & 21 deletions
This file was deleted.

mlir/include/mlir/InitAllDialects.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
5454
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
5555
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
56-
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
5756
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
5857
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
5958
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -157,7 +156,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
157156
linalg::registerTilingInterfaceExternalModels(registry);
158157
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
159158
memref::registerAllocationOpInterfaceExternalModels(registry);
160-
memref::registerBufferizableOpInterfaceExternalModels(registry);
161159
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
162160
memref::registerValueBoundsOpInterfaceExternalModels(registry);
163161
memref::registerMemorySlotExternalModels(registry);

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,11 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
585585
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
586586
buffer = getDest();
587587
}
588-
rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
588+
auto srcBuffer = getBuffer(rewriter, getSource(), options);
589+
if (failed(srcBuffer))
590+
return failure();
591+
if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
592+
return failure();
589593
replaceOpWithBufferizedValues(rewriter, getOperation(),
590594
tensorDest ? ValueRange(buffer) : ValueRange());
591595
return success();
@@ -682,8 +686,9 @@ LogicalResult MaterializeInDestinationOp::verify() {
682686
void MaterializeInDestinationOp::build(OpBuilder &builder,
683687
OperationState &state, Value source,
684688
Value dest) {
685-
assert(isa<TensorType>(dest.getType()) && "expected tensor type");
686-
build(builder, state, /*result=*/dest.getType(), source, dest);
689+
auto destTensorType = dyn_cast<TensorType>(dest.getType());
690+
build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
691+
source, dest);
687692
}
688693

689694
bool MaterializeInDestinationOp::isWritable(Value value,

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,9 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
241241
rewriter.setListener(&newOpsListener);
242242

243243
linalg::BufferizeToAllocationOptions options;
244-
if (getMemcpyOp() == "memref.tensor_store") {
245-
options.memcpyOp =
246-
linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefTensorStore;
244+
if (getMemcpyOp() == "bufferization.materialize_in_destination") {
245+
options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
246+
MaterializeInDestination;
247247
} else if (getMemcpyOp() == "memref.copy") {
248248
options.memcpyOp =
249249
linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
@@ -296,7 +296,7 @@ void transform::BufferizeToAllocationOp::getEffects(
296296
}
297297

298298
LogicalResult transform::BufferizeToAllocationOp::verify() {
299-
if (getMemcpyOp() != "memref.tensor_store" &&
299+
if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
300300
getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
301301
return emitOpError() << "unsupported memcpy op";
302302
if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,14 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
6363
assert(memrefDest.getType().isa<MemRefType>() && "expected ranked memref");
6464

6565
switch (options.memcpyOp) {
66-
case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefTensorStore:
66+
case linalg::BufferizeToAllocationOptions::MemcpyOp::
67+
MaterializeInDestination: {
6768
// Note: This is the preferred way of memcpy'ing because no layout map
6869
// and/or memory space must be specified for the source.
69-
b.create<memref::TensorStoreOp>(loc, tensorSource, memrefDest);
70-
break;
70+
auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
71+
loc, tensorSource, memrefDest);
72+
materializeOp.setWritable(true);
73+
} break;
7174
case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
7275
// TODO: Support custom memory space on source.
7376
// We do not know the layout map of the source yet, so use a fully dynamic
@@ -238,7 +241,7 @@ Value linalg::bufferizeToAllocation(
238241
rewriter.setInsertionPointAfter(fillOp);
239242
}
240243

241-
// Create memref.tensor_store.
244+
// Create memcpy.
242245
SmallVector<OpFoldResult> sizes =
243246
getMixedSizes(rewriter, loc, padOp.getSource());
244247
SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),

mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)