llvm-project/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Quentin Colombet 64f99842a6 [mlir][ExpandStridedMetadata] Handle collapse_shape of dim of size 1 gracefully
Collapsing dimensions of size 1 with random strides (a.k.a.
non-contiguous w.r.t. the collapsed dimensions) is a grey area that we'd
like to clean up. (See https://reviews.llvm.org/D136483#3909856)

That said, the implementation in `memref-to-llvm` currently skips
dimensions of size 1 when computing the stride of a group.

While longer term we may want to clean that up, for now we match this
behavior, at least in the static case.
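For instance, collapsing memref<3x1xf32, strided<[2, 1]>> into a 1-D
memref yields a stride of 2: the stride of the size-1 dimension is
ignored (see @simplify_collapse_with_dim_of_size1 below).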

For the dynamic case, this patch sticks to `min(group strides)`.
However, if we want to handle the dynamic cases correctly while allowing
a non-truly-contiguous dynamic size of 1, we would need to `if-then-else`
every dynamic size. In other words, compute `min(stride_i, for all i in
group with dim_i != 1)`.
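
As an illustration, a minimal sketch of that `if-then-else` in IR,
assuming hypothetical values %size_i/%stride_i for one dimension of the
group and %acc for the running minimum (not what this patch emits):

  %c1 = arith.constant 1 : index
  // Ignore the strides of dimensions of size 1.
  %is_one = arith.cmpi eq, %size_i, %c1 : index
  %min = arith.minsi %acc, %stride_i : index
  %acc_next = arith.select %is_one, %acc, %min : index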

I didn't implement that in this patch since `memref-to-llvm` is
technically broken in the general case here anyway. (It currently
produces something sensible only for row-major tensors.)

Differential Revision: https://reviews.llvm.org/D139329
2022-12-08 07:32:01 +00:00


// RUN: mlir-opt --expand-strided-metadata -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: func @extract_strided_metadata_constants
// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
-> (memref<f32>, index, index, index, index, index) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
memref<5x4xf32, strided<[4,1], offset:2>>
-> memref<f32>, index, index, index, index, index
// CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]]
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<f32>, index, index, index, index, index
}
// -----
// Check that we simplify subview(src) into:
// base, offset, sizes, strides = extract_strided_metadata src
// final_sizes = subSizes
// final_strides = <some math> strides
// final_offset = <some math> offset
// reinterpret_cast base to final_offset, final_sizes, final_strides
//
// Orig strides: [s0, s1, s2]
// Sub strides: [subS0, subS1, subS2]
// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
// ==> 1 affine map (used for each stride) with two values.
//
// Orig offset: origOff
// Sub offsets: [subO0, subO1, subO2]
// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
// ==> 1 affine map with (rank * 2 + 1) symbols
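//
// For example, with hypothetical values [s0, s1, s2] = [100, 10, 1],
// sub strides [2, 3, 4], sub offsets [5, 6, 7], and origOff = 42:
// New strides: [100 * 2, 10 * 3, 1 * 4] = [200, 30, 4]
// Final offset: 42 + 100 * 5 + 10 * 6 + 1 * 7 = 609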
//
// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
// CHECK-LABEL: func @simplify_subview_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
//
// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
//
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[FINAL_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]], strides: [%[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]]
//
// CHECK: return %[[RES]]
func.func @simplify_subview_all_dynamic(
%base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
%offset0: index, %offset1: index, %offset2: index,
%size0: index, %size1: index, %size2: index,
%stride0: index, %stride1: index, %stride2: index)
-> memref<?x?x?xf32, strided<[?,?,?], offset:?>> {
%subview = memref.subview %base[%offset0, %offset1, %offset2]
[%size0, %size1, %size2]
[%stride0, %stride1, %stride2] :
memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
return %subview : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}
// -----
// Check that we simplify extract_strided_metadata of subview to
// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
// strides = base_stride_i * subview_stride_i
// offset = base_offset + sum(subview_offsets_i * base_strides_i).
//
// This test also checks that we don't create useless arith operations
// when subview_offsets_i is 0.
//
// CHECK-LABEL: func @extract_strided_metadata_of_subview
// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32>)
//
// Materialize the offset for dimension 1.
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
//
// Plain extract_strided_metadata.
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// Final offset is:
// origOffset + (== 0)
// base_stride0 * subview_offset0 + (== 4 * 0 == 0)
// base_stride1 * subview_offset1 (== 1 * 2)
// == 2
//
// Return the new tuple.
// CHECK: return %[[BASE]], %[[C2]], %[[C2]], %[[C2]], %[[C4]], %[[C1]]
func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>)
-> (memref<f32>, index, index, index, index, index) {
%subview = memref.subview %base[0, 2][2, 2][1, 1] :
memref<5x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 2>>
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
memref<2x2xf32, strided<[4,1], offset:2>>
-> memref<f32>, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<f32>, index, index, index, index, index
}
// -----
// Check that we simplify extract_strided_metadata of subview properly
// when dynamic sizes are involved.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
// Orig strides: [64, 4, 1]
// Sub strides: [1, 1, 1]
// => New strides: [64, 4, 1]
//
// Orig offset: 0
// Sub offsets: [3, 4, 2]
// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
//
// Final sizes == subview sizes == [%size, 6, 3]
//
// CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size
// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
// CHECK-SAME: %[[DYN_SIZE:.*]]: index)
//
// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]]
func.func @extract_strided_metadata_of_subview_with_dynamic_size(
%base: memref<8x16x4xf32>, %size: index)
-> (memref<f32>, index, index, index, index, index, index, index) {
%subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] :
memref<8x16x4xf32> to memref<?x6x3xf32, strided<[64, 4, 1], offset: 210>>
%base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
memref<?x6x3xf32, strided<[64,4,1], offset: 210>>
-> memref<f32>, index, index, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
memref<f32>, index, index, index, index, index, index, index
}
// -----
// Check that we simplify extract_strided_metadata of subview properly
// when the subview reduces the ranks.
// In particular the returned strides must come from #1 and #2 of the %strides
// value of the new extract_strided_metadata_of_subview, not #0 and #1.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
//
// Orig strides: [64, 4, 1]
// Sub strides: [1, 1, 1]
// => New strides: [64, 4, 1]
// Final strides == filterOutReducedDim(new strides, 0) == [4, 1]
//
// Orig offset: 0
// Sub offsets: [3, 4, 2]
// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
//
// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
//
// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview
// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>)
//
// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>)
-> (memref<f32>, index, index, index, index, index) {
%subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] :
memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
memref<6x3xf32, strided<[4,1], offset: 210>>
-> memref<f32>, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<f32>, index, index, index, index, index
}
// -----
// Check that we simplify extract_strided_metadata of subview properly
// when the subview reduces the rank and some of the strides are variable.
// In particular, we check that:
// A. The dynamic stride is multiplied with the base stride to create the new
// stride for dimension 1.
// B. The first returned stride is the value computed in #A.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
//
// Orig strides: [64, 4, 1]
// Sub strides: [1, %stride, 1]
// => New strides: [64, 4 * %stride, 1]
// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride, 1]
//
// Orig offset: 0
// Sub offsets: [3, 4, 2]
// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
//
// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides
// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
// CHECK-SAME: %[[DYN_STRIDE:.*]]: index)
//
// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]]
//
// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]]
func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
%base: memref<8x16x4xf32>, %stride: index)
-> (memref<f32>, index, index, index, index, index) {
%subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] :
memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
memref<6x3xf32, strided<[?, 1], offset: 210>>
-> memref<f32>, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<f32>, index, index, index, index, index
}
// -----
// Check that we simplify extract_strided_metadata of subview properly
// when the subview uses variable offsets.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
//
// Orig strides: [128, 1]
// Sub strides: [1, 1]
// => New strides: [128, 1]
//
// Orig offset: 0
// Sub offsets: [%arg1, %arg2]
// => Final offset: 128 * %arg1 + 1 * %arg2 + 0
//
// CHECK-DAG: #[[$OFFSETS_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)>
// CHECK-LABEL: func @extract_strided_metadata_of_subview_w_variable_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<384x128xf32>,
// CHECK-SAME: %[[DYN_OFFSET0:.*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:.*]]: index)
//
// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSETS_MAP]]()[%[[DYN_OFFSET0]], %[[DYN_OFFSET1]]]
//
// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C64]], %[[C64]], %[[C128]], %[[C1]]
func.func @extract_strided_metadata_of_subview_w_variable_offset(
%arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index)
-> (memref<f32>, index, index, index, index, index) {
%subview = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] :
memref<384x128xf32> to memref<64x64xf32, strided<[128, 1], offset: ?>>
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
memref<64x64xf32, strided<[128, 1], offset: ?>> -> memref<f32>, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<f32>, index, index, index, index, index
}
// -----
// Check that all the math is correct for all types of computations.
// We achieve that by using dynamic values for all the different types:
// - Offsets
// - Sizes
// - Strides
//
// Orig strides: [s0, s1, s2]
// Sub strides: [subS0, subS1, subS2]
// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
// ==> 1 affine map (used for each stride) with two values.
//
// Orig offset: origOff
// Sub offsets: [subO0, subO1, subO2]
// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
// ==> 1 affine map with (rank * 2 + 1) symbols
//
// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
// CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
//
// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
//
// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]
func.func @extract_strided_metadata_of_subview_all_dynamic(
%base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
%offset0: index, %offset1: index, %offset2: index,
%size0: index, %size1: index, %size2: index,
%stride0: index, %stride1: index, %stride2: index)
-> (memref<f32>, index, index, index, index, index, index, index) {
%subview = memref.subview %base[%offset0, %offset1, %offset2]
[%size0, %size1, %size2]
[%stride0, %stride1, %stride2] :
memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
%base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
-> memref<f32>, index, index, index, index, index, index, index
return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
memref<f32>, index, index, index, index, index, index, index
}
// -----
// Check that we properly simplify expand_shape into:
// reinterpret_cast(extract_strided_metadata) + <some math>
//
// Here we have:
// For the group applying to dim0:
// size 0 = baseSizes#0 / (all static sizes in that group)
// = baseSizes#0 / (7 * 8 * 9)
// = baseSizes#0 / 504
// size 1 = 7
// size 2 = 8
// size 3 = 9
// stride 0 = baseStrides#0 * 7 * 8 * 9
// = baseStrides#0 * 504
// stride 1 = baseStrides#0 * 8 * 9
// = baseStrides#0 * 72
// stride 2 = baseStrides#0 * 9
// stride 3 = baseStrides#0
//
// For the group applying to dim1:
// size 4 = 10
// size 5 = 2
// size 6 = baseSizes#1 / (all static sizes in that group)
// = baseSizes#1 / (10 * 2 * 3)
// = baseSizes#1 / 60
// size 7 = 3
// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
// = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
// = baseStrides#1 * (baseSizes#1 / 60) * 6
// and since we know that baseSizes#1 is a multiple of 60:
// = baseStrides#1 * (baseSizes#1 / 10)
// stride 5 = baseStrides#1 * size 6 * size 7
// = baseStrides#1 * (baseSizes#1 / 60) * 3
// = baseStrides#1 * (baseSizes#1 / 20)
// stride 6 = baseStrides#1 * size 7
// = baseStrides#1 * 3
// stride 7 = baseStrides#1
//
// Base and offset are unchanged.
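//
// For example, with hypothetical runtime values baseSizes = [1008, 120]
// and baseStrides = [120, 1], the maps below would produce:
// sizes = [2, 7, 8, 9, 10, 2, 2, 3]
// strides = [60480, 8640, 1080, 120, 12, 6, 3, 1]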
//
// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
//
// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-LABEL: func @simplify_expand_shape
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
//
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
//
// CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
//
// CHECK: return %[[REINTERPRET_CAST]]
func.func @simplify_expand_shape(
%base: memref<?x?xf32, strided<[?,?], offset:?>>,
%offset0: index, %offset1: index, %offset2: index,
%size0: index, %size1: index, %size2: index,
%stride0: index, %stride1: index, %stride2: index)
-> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {
%subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
memref<?x?xf32, strided<[?,?], offset: ?>> into
memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
return %subview :
memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
}
// -----
// Check that we properly simplify extract_strided_metadata of expand_shape
// into:
// baseBuffer, baseOffset, baseSizes, baseStrides =
// extract_strided_metadata(memref)
// sizes#reassIdx =
// baseSizes#reassDim / product(expandShapeSizes#j,
// for j in group excluding reassIdx)
// strides#reassIdx =
// baseStrides#reassDim * product(expandShapeSizes#j, for j in
// reassIdx+1..reassIdx+group.size)
//
// Here we have:
// For the group applying to dim0:
// size 0 = 3
// size 1 = 5
// size 2 = 2
// stride 0 = baseStrides#0 * 5 * 2
// = 4 * 5 * 2
// = 40
// stride 1 = baseStrides#0 * 2
// = 4 * 2
// = 8
// stride 2 = baseStrides#0
// = 4
//
// For the group applying to dim1:
// size 3 = 2
// size 4 = 2
// stride 3 = baseStrides#1 * 2
// = 1 * 2
// = 2
// stride 4 = baseStrides#1
// = 1
//
// Base and offset are unchanged.
//
// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static
// CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>)
//
// CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref<i16>, index, index, index, index, index
//
// CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_static(
%arg : memref<30x4xi16>)
-> (memref<i16>, index,
index, index, index, index, index,
index, index, index, index, index) {
%expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] :
memref<30x4xi16> into memref<3x5x2x2x2xi16>
%base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
memref<3x5x2x2x2xi16>
-> memref<i16>, index,
index, index, index, index, index,
index, index, index, index, index
return %base, %offset,
%sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
%strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
memref<i16>, index,
index, index, index, index, index,
index, index, index, index, index
}
// -----
// Check that we properly simplify extract_strided_metadata of expand_shape
// when dynamic sizes, strides, and offsets are involved.
// See extract_strided_metadata_of_expand_shape_all_static for an explanation
// of the expansion.
//
// One of the important characteristics of this test is that the dynamic
// dimensions produced by the expand_shape appear both in the first dimension
// (for group 1) and in a non-first dimension (second dimension for group 2).
// The idea is to make sure that:
// 1. We properly account for dynamic shapes even when the strides are not
// affected by them. (When the dynamic dimension is the first one.)
// 2. We properly compute the strides affected by dynamic shapes. (When the
// dynamic dimension is not the first one.)
//
// Here we have:
// For the group applying to dim0:
// size 0 = baseSizes#0 / (all static sizes in that group)
// = baseSizes#0 / (7 * 8 * 9)
// = baseSizes#0 / 504
// size 1 = 7
// size 2 = 8
// size 3 = 9
// stride 0 = baseStrides#0 * 7 * 8 * 9
// = baseStrides#0 * 504
// stride 1 = baseStrides#0 * 8 * 9
// = baseStrides#0 * 72
// stride 2 = baseStrides#0 * 9
// stride 3 = baseStrides#0
//
// For the group applying to dim1:
// size 4 = 10
// size 5 = 2
// size 6 = baseSizes#1 / (all static sizes in that group)
// = baseSizes#1 / (10 * 2 * 3)
// = baseSizes#1 / 60
// size 7 = 3
// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
// = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
// = baseStrides#1 * (baseSizes#1 / 60) * 6
// and since we know that baseSizes#1 is a multiple of 60:
// = baseStrides#1 * (baseSizes#1 / 10)
// stride 5 = baseStrides#1 * size 6 * size 7
// = baseStrides#1 * (baseSizes#1 / 60) * 3
// = baseStrides#1 * (baseSizes#1 / 20)
// stride 6 = baseStrides#1 * size 7
// = baseStrides#1 * 3
// stride 7 = baseStrides#1
//
// Base and offset are unchanged.
//
// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
//
// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
//
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
//
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
%base: memref<?x?xf32, strided<[?,?], offset:?>>,
%offset0: index, %offset1: index, %offset2: index,
%size0: index, %size1: index, %size2: index,
%stride0: index, %stride1: index, %stride2: index)
-> (memref<f32>, index,
index, index, index, index, index, index, index, index,
index, index, index, index, index, index, index, index) {
%subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
memref<?x?xf32, strided<[?,?], offset: ?>> into
memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
%base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %subview :
memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
-> memref<f32>, index,
index, index, index, index, index, index, index, index,
index, index, index, index, index, index, index, index
return %base_buffer, %offset,
%sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6, %sizes#7,
%strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6, %strides#7 :
memref<f32>, index,
index, index, index, index, index, index, index, index,
index, index, index, index, index, index, index, index
}
// -----
// Check that we properly handle extract_strided_metadata of expand_shape for
// 0-D input.
// The 0-D case is pretty boring:
// All expanded sizes are 1, likewise for the strides, and we keep the
// original base and offset.
// We still have a test for it because, since the input reassociation map
// of the expand_shape is empty, the handling of such a shape hits a
// corner case.
// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static_0_rank
// CHECK-SAME: (%[[ARG:.*]]: memref<i16, strided<[], offset: ?>>)
//
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]] : memref<i16, strided<[], offset: ?>> -> memref<i16>, index
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
%arg : memref<i16, strided<[], offset: ?>>)
-> (memref<i16>, index,
index, index, index, index, index,
index, index, index, index, index) {
%expand_shape = memref.expand_shape %arg[] :
memref<i16, strided<[], offset: ?>> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
%base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
-> memref<i16>, index,
index, index, index, index, index,
index, index, index, index, index
return %base, %offset,
%sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
%strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
memref<i16>, index,
index, index, index, index, index,
index, index, index, index, index
}
// -----
// Check that we simplify extract_strided_metadata(alloc)
// into simply the alloc plus the information extracted from
// the memref type and the arguments of the alloc.
//
// baseBuffer = reinterpret_cast alloc
// offset = 0
// sizes = shape(memref)
// strides = strides(memref)
//
// For dynamic shapes, we simply use the values that feed the alloc.
//
// Simple rank 0 test: we don't need a reinterpret_cast here.
// CHECK-LABEL: func @extract_strided_metadata_of_alloc_all_static_0_rank
//
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
// CHECK: return %[[ALLOC]], %[[C0]] : memref<i16>, index
func.func @extract_strided_metadata_of_alloc_all_static_0_rank()
-> (memref<i16>, index) {
%A = memref.alloc() : memref<i16>
%base, %offset = memref.extract_strided_metadata %A :
memref<i16>
-> memref<i16>, index
return %base, %offset :
memref<i16>, index
}
// -----
// Simplification of extract_strided_metadata(alloc).
// Check that we properly use the dynamic sizes to
// create the new sizes and strides.
// size 0 = dyn_size0
// size 1 = 4
// size 2 = dyn_size2
// size 3 = dyn_size3
//
// stride 0 = size 1 * size 2 * size 3
// = 4 * dyn_size2 * dyn_size3
// stride 1 = size 2 * size 3
// = dyn_size2 * dyn_size3
// stride 2 = size 3
// = dyn_size3
// stride 3 = 1
//
// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-LABEL: extract_strided_metadata_of_alloc_dyn_size
// CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_SIZE3:.*]]: index)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[DYN_SIZE0]], %[[DYN_SIZE2]], %[[DYN_SIZE3]])
//
// CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
// CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
//
// CHECK-DAG: %[[CASTED_ALLOC:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<?x4x?x?xi16> to memref<i16>
//
// CHECK: return %[[CASTED_ALLOC]], %[[C0]], %[[DYN_SIZE0]], %[[C4]], %[[DYN_SIZE2]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
func.func @extract_strided_metadata_of_alloc_dyn_size(
%dyn_size0 : index, %dyn_size2 : index, %dyn_size3 : index)
-> (memref<i16>, index,
index, index, index, index,
index, index, index, index) {
%A = memref.alloc(%dyn_size0, %dyn_size2, %dyn_size3) : memref<?x4x?x?xi16>
%base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
memref<?x4x?x?xi16>
-> memref<i16>, index,
index, index, index, index,
index, index, index, index
return %base, %offset,
%sizes#0, %sizes#1, %sizes#2, %sizes#3,
%strides#0, %strides#1, %strides#2, %strides#3 :
memref<i16>, index,
index, index, index, index,
index, index, index, index
}
// -----
// Same check as extract_strided_metadata_of_alloc_dyn_size but with alloca
// instead of alloc, just to make sure we handle allocas the same way we
// handle allocs.
// While at it, test a slightly different shape than
// extract_strided_metadata_of_alloc_dyn_size.
//
// size 0 = dyn_size0
// size 1 = dyn_size1
// size 2 = 4
// size 3 = dyn_size3
//
// stride 0 = size 1 * size 2 * size 3
// = dyn_size1 * 4 * dyn_size3
// stride 1 = size 2 * size 3
// = 4 * dyn_size3
// stride 2 = size 3
// = dyn_size3
// stride 3 = 1
//
// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-LABEL: extract_strided_metadata_of_alloca_dyn_size
// CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE3:.*]]: index)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca(%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE3]])
//
// CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE1]], %[[DYN_SIZE3]]]
// CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE3]]]
//
// CHECK-DAG: %[[CASTED_ALLOCA:.*]] = memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [], strides: [] : memref<?x?x4x?xi16> to memref<i16>
//
// CHECK: return %[[CASTED_ALLOCA]], %[[C0]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[C4]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
func.func @extract_strided_metadata_of_alloca_dyn_size(
%dyn_size0 : index, %dyn_size1 : index, %dyn_size3 : index)
-> (memref<i16>, index,
index, index, index, index,
index, index, index, index) {
%A = memref.alloca(%dyn_size0, %dyn_size1, %dyn_size3) : memref<?x?x4x?xi16>
%base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
memref<?x?x4x?xi16>
-> memref<i16>, index,
index, index, index, index,
index, index, index, index
return %base, %offset,
%sizes#0, %sizes#1, %sizes#2, %sizes#3,
%strides#0, %strides#1, %strides#2, %strides#3 :
memref<i16>, index,
index, index, index, index,
index, index, index, index
}
// -----
// The following few alloc tests are negative tests (the simplification
// doesn't happen) to make sure non-trivial memref types are treated
// as "not normalized".
// CHECK-LABEL: extract_strided_metadata_of_alloc_with_variable_offset
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
#map0 = affine_map<(d0)[s0] -> (d0 + s0)>
func.func @extract_strided_metadata_of_alloc_with_variable_offset(%arg : index)
-> (memref<i16>, index, index, index) {
%A = memref.alloc()[%arg] : memref<4xi16, #map0>
%base, %offset, %size, %stride = memref.extract_strided_metadata %A :
memref<4xi16, #map0>
-> memref<i16>, index, index, index
return %base, %offset, %size, %stride :
memref<i16>, index, index, index
}
// -----
// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
#map0 = affine_map<(d0) -> (d0 + 12)>
func.func @extract_strided_metadata_of_alloc_with_cst_offset(%arg : index)
-> (memref<i16>, index, index, index) {
%A = memref.alloc() : memref<4xi16, #map0>
%base, %offset, %size, %stride = memref.extract_strided_metadata %A :
memref<4xi16, #map0>
-> memref<i16>, index, index, index
return %base, %offset, %size, %stride :
memref<i16>, index, index, index
}
// -----
// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset_in_type
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
func.func @extract_strided_metadata_of_alloc_with_cst_offset_in_type(%arg : index)
-> (memref<i16>, index, index, index) {
%A = memref.alloc() : memref<4xi16, strided<[1], offset : 10>>
%base, %offset, %size, %stride = memref.extract_strided_metadata %A :
memref<4xi16, strided<[1], offset : 10>>
-> memref<i16>, index, index, index
return %base, %offset, %size, %stride :
memref<i16>, index, index, index
}
// -----
// CHECK-LABEL: extract_strided_metadata_of_alloc_with_strided
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
func.func @extract_strided_metadata_of_alloc_with_strided(%arg : index)
-> (memref<i16>, index, index, index) {
%A = memref.alloc() : memref<4xi16, strided<[12]>>
%base, %offset, %size, %stride = memref.extract_strided_metadata %A :
memref<4xi16, strided<[12]>>
-> memref<i16>, index, index, index
return %base, %offset, %size, %stride :
memref<i16>, index, index, index
}
// -----
// CHECK-LABEL: extract_aligned_pointer_as_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<f32>
func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
// CHECK-NOT: memref.subview
// CHECK: memref.extract_aligned_pointer_as_index %[[ARG0]]
%c = memref.subview %arg0[] [] [] : memref<f32> to memref<f32>
%r = memref.extract_aligned_pointer_as_index %arg0: memref<f32> -> index
return %r : index
}
// -----
// Check that we simplify collapse_shape into
// reinterpret_cast(extract_strided_metadata) + <some math>
//
// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
// Size 0 = origSize0
// Size 1 = origSize1 * origSize2 * origSize3
// = origSize1 * 4 * origSize3
// Size 2 = origSize4 * origSize5
// = 6 * 7
// = 42
// Stride 0 = min(origStride0)
// (right now the folder of affine.min is not smart
// enough to fold this to just origStride0)
// Stride 1 = min(origStride1, origStride2, origStride3)
// = min(origStride1, origStride2, 42)
// Stride 2 = min(origStride4, origStride5)
// = min(7, 1)
// = 1
//
// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
// CHECK-LABEL: func @simplify_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
-> memref<?x?x42xi32> {
%collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
return %collapsed_view : memref<?x?x42xi32>
}
// -----
// Check that we simplify collapse_shape into
// reinterpret_cast(extract_strided_metadata) + <some math>
// when there are dimensions of size 1 involved.
//
// We transform: 3x1 to [0, 1]
//
// The tricky bit here is that the strides between dimensions 0 and 1
// are not truly contiguous, but since we are dealing with a dimension of
// size 1 this is actually fine (i.e., we are not going to jump around.)
//
// As a result the resulting stride needs to ignore the strides of the
// dimensions of size 1.
//
// Size 0 = origSize0 * origSize1
// = 3 * 1
// = 3
// Stride 0 = min(origStride_i, for all i in reassociation group and dim_i != 1)
// = min(origStride0)
// = min(2)
// = 2
//
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
// CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
//
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
memref<3x1xf32, strided<[2, 1]>> into memref<3xf32, strided<[2]>>
memref.copy %collapse_shape, %arg1 : memref<3xf32, strided<[2]>> to memref<3xf32>
return
}
// -----
// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
//
// The tricky bit here is that the resulting stride is meaningless, yet we
// still have to please the type system.
//
// In this case, we're collapsing two strides of respectively 2 and 1, and
// the resulting type wants a stride of 2.
//
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_non_1_stride(
// CHECK-SAME: %[[ARG:.*]]: memref<1x1xi32, strided<[2, 1]
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<1x1xi32, strided<[2, 1], offset: ?>>
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [2]
func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
(%arg0: memref<1x1xi32, strided<[2, 1], offset: ?>>)
-> memref<1xi32, strided<[2], offset: ?>> {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
memref<1x1xi32, strided<[2, 1], offset: ?>>
into memref<1xi32, strided<[2], offset: ?>>
return %collapse_shape : memref<1xi32, strided<[2], offset: ?>>
}
// -----
// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
// We also have a couple of collapsed dimensions before the 1x1x...x1 group
// to make sure we properly index into the dynamic strides based on the
// group ID.
//
// The tricky bit in this test is that the 1x1x...x1 group stride is dynamic,
// so we have to propagate one of the dynamic strides for this group.
//
// For this test we have:
// Size0 = origSize0 * origSize1
// = 2 * 3
// = 6
// Size1 = origSize2 * origSize3 * origSize4
// = 1 * 1 * 1
// = 1
//
// Stride0 = min(origStride0, origStride1)
// Stride1 = unknown; the strides in this group are dynamic and we don't
// know which one to pick.
// We just return the first dynamic one for this group.
//
//
// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
// CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
//
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
(%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
-> memref<6x1xi32, strided<[?, ?], offset: ?>> {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1], [2, 3, 4]] :
memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
into memref<6x1xi32, strided<[?, ?], offset: ?>>
return %collapse_shape : memref<6x1xi32, strided<[?, ?], offset: ?>>
}
// -----
// Check that we simplify extract_strided_metadata of collapse_shape.
//
// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
// Size 0 = origSize0
// Size 1 = origSize1 * origSize2 * origSize3
// = origSize1 * 4 * origSize3
// Size 2 = origSize4 * origSize5
// = 6 * 7
// = 42
// Stride 0 = origStride0
// Stride 1 = origStride3 (orig stride of the innermost dimension)
// = 42
// Stride 2 = origStride5
// = 1
//
// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
//
// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
-> (memref<i32>, index,
index, index, index,
index, index, index) {
%collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
%base, %offset, %sizes:3, %strides:3 =
memref.extract_strided_metadata %collapsed_view : memref<?x?x42xi32>
-> memref<i32>, index,
index, index, index,
index, index, index
return %base, %offset,
%sizes#0, %sizes#1, %sizes#2,
%strides#0, %strides#1, %strides#2 :
memref<i32>, index,
index, index, index,
index, index, index
}
// -----
// Check that we simplify extract_strided_metadata of collapse_shape to
// a 0-ranked shape.
// CHECK-LABEL: func @extract_strided_metadata_of_collapse_to_rank0(
// CHECK-SAME: %[[ARG:.*]]: memref<1x1x1x1x1x1xi32>)
//
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<1x1x1x1x1x1xi32>
//
// CHECK: return %[[BASE]], %[[C0]]
func.func @extract_strided_metadata_of_collapse_to_rank0(%arg : memref<1x1x1x1x1x1xi32>)
-> (memref<i32>, index) {
%collapsed_view = memref.collapse_shape %arg [] :
memref<1x1x1x1x1x1xi32> into memref<i32>
%base, %offset =
memref.extract_strided_metadata %collapsed_view : memref<i32>
-> memref<i32>, index
return %base, %offset :
memref<i32>, index
}
// -----
// Check that we simplify extract_strided_metadata of
// extract_strided_metadata.
//
// CHECK-LABEL: func @extract_strided_metadata_of_extract_strided_metadata(
// CHECK-SAME: %[[ARG:.*]]: memref<i32>)
//
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C0]]
func.func @extract_strided_metadata_of_extract_strided_metadata(%arg : memref<i32>)
-> (memref<i32>, index) {
%base, %offset =
memref.extract_strided_metadata %arg:memref<i32>
-> memref<i32>, index
%base2, %offset2 =
memref.extract_strided_metadata %base:memref<i32>
-> memref<i32>, index
return %base2, %offset2 :
memref<i32>, index
}
// -----
// Check that we simplify extract_strided_metadata of reinterpret_cast
// when the source of the reinterpret_cast is compatible with what
// `extract_strided_metadata` accepts.
//
// When we apply the transformation, the resulting offset, sizes, and strides
// should come straight from the inputs of the reinterpret_cast.
//
// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
//
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
func.func @extract_strided_metadata_of_reinterpret_cast(
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>,
%offset: index,
%size0 : index, %size1 : index,
%stride0 : index, %stride1 : index)
-> (memref<i32>, index,
index, index,
index, index) {
%cast =
memref.reinterpret_cast %arg to
offset: [%offset],
sizes: [%size0, %size1],
strides: [%stride0, %stride1] :
memref<?x?xi32, strided<[?, ?], offset: ?>> to
memref<?x?xi32, strided<[?, ?], offset: ?>>
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}
// -----
// Check that we don't simplify extract_strided_metadata of
// reinterpret_cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_unranked
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
//
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[DYN_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]]], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]]]
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_reinterpret_cast_unranked(
%arg : memref<*xi32>,
%offset: index,
%size0 : index, %size1 : index,
%stride0 : index, %stride1 : index)
-> (memref<i32>, index,
index, index,
index, index) {
%cast =
memref.reinterpret_cast %arg to
offset: [%offset],
sizes: [%size0, %size1],
strides: [%stride0, %stride1] :
memref<*xi32> to
memref<?x?xi32, strided<[?, ?], offset: ?>>
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}
// -----
// Similar to @extract_strided_metadata_of_reinterpret_cast, just make sure
// we handle 0-D properly.
//
// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_rank0
// CHECK-SAME: %[[ARG:.*]]: memref<i32, strided<[], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
//
// CHECK: %[[BASE:.*]], %[[BASE_OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
func.func @extract_strided_metadata_of_reinterpret_cast_rank0(
%arg : memref<i32, strided<[], offset:?>>,
%offset: index,
%size0 : index, %size1 : index,
%stride0 : index, %stride1 : index)
-> (memref<i32>, index,
index, index,
index, index) {
%cast =
memref.reinterpret_cast %arg to
offset: [%offset],
sizes: [%size0, %size1],
strides: [%stride0, %stride1] :
memref<i32, strided<[], offset: ?>> to
memref<?x?xi32, strided<[?, ?], offset: ?>>
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}