@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND_ARG0]] :
// CHECK: return %[[GENERIC]]
+
+ // -----
+
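+ // The padded dimensions (0, 3 and 7) each map to a unit reassociation group of the
+ // expand_shape, so the pad is expected to fold onto the collapsed source, with the
+ // expand_shape moved after it.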
+ func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+ return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+ }
+ // CHECK: func @fuse_by_collapsing_pad(
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+ // CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+ // CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+ // CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+ // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+ // CHECK-SAME: output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+ // CHECK: return %[[EXPAND]]
+
+ // -----
+
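+ // Dimension 1 of the expanded tensor is padded, but it belongs to the non-unit
+ // reassociation group [1, 2], so the pad cannot be rewritten on the collapsed
+ // source and no fusion is expected.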
+ func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+ return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+ }
+ // CHECK: func @no_fuse_by_collapsing_pad(
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+ // CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+ // CHECK-SAME: output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+ // CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+ // CHECK-SAME: low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+ // CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+ // CHECK: return %[[PAD]]
+
+ // -----
+
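+ // Dynamic variant: the padded dimensions 0 and 3 again map to unit reassociation
+ // groups, so the pad folds onto the collapsed source and the padded output sizes
+ // are recomputed with affine.apply.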
+ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+ %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+ %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+ %cst = arith.constant 0.0 : f32
+ %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+ return %padded_0 : tensor<?x?x?x?x?x?xf32>
+ }
+ // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+ // CHECK: func @fuse_by_collapsing_dynamic_pad(
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+ // CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+ // CHECK: %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+ // CHECK: %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+ // CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+ // CHECK-SAME: low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+ // CHECK: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+ // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+ // CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+ // CHECK: return %[[EXPAND]]