llvm-project/mlir/test/Transforms/test-commutativity-utils.mlir
srishti-cb b508c5649f [MLIR] Add a utility to sort the operands of commutative ops
Added a commutativity utility pattern and a function to populate it. The pattern sorts the operands of an op in ascending order of the "key" associated with each operand iff the op is commutative. This sorting is stable.

The function is intended to be used inside passes to simplify the matching of commutative operations. Once the above-mentioned pattern has been applied, the operands of each commutative op occur in a deterministic order, so matching large DAGs becomes much simpler: a user needs to write far fewer checks in their pattern-matching function.

The "key" associated with an operand is the list of the "AncestorKeys" associated with the ancestors of this operand, in a breadth-first order.

The operand of any op is produced by a set of ops and block arguments. Each of these ops and block arguments is called an "ancestor" of this operand.

Now, the "AncestorKey" associated with:
1. A block argument is `{type: BLOCK_ARGUMENT, opName: ""}`.
2. A non-constant-like op, for example, `arith.addi`, is `{type: NON_CONSTANT_OP, opName: "arith.addi"}`.
3. A constant-like op, for example, `arith.constant`, is `{type: CONSTANT_OP, opName: "arith.constant"}`.

So, if an operand, say `A`, was produced as follows:

```
`<block argument>`  `<block argument>`
             \          /
              \        /
              `arith.subi`           `arith.constant`
                         \            /
                         `arith.addi`
                                |
                           returns `A`
```

Then the block arguments and operations present in the backward slice of `A`, in breadth-first order, are:
`arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and `<block argument>`.

Thus, the "key" associated with operand `A` is:
```
{
 {type: NON_CONSTANT_OP, opName: "arith.addi"},
 {type: NON_CONSTANT_OP, opName: "arith.subi"},
 {type: CONSTANT_OP, opName: "arith.constant"},
 {type: BLOCK_ARGUMENT, opName: ""},
 {type: BLOCK_ARGUMENT, opName: ""}
}
```
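As a rough illustration of the breadth-first traversal described above, the following C++ sketch builds the key of `A` on a toy DAG. It is not the code added by this patch: `Node`, `computeKey`, and `keyOfA` are hypothetical names that use only the standard library, and the sketch does not deduplicate ancestors that are reachable through several operands, which may differ from the actual utility.

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <vector>

// Hypothetical stand-ins for the concepts in the text; not the MLIR classes.
enum class AncestorType { BLOCK_ARGUMENT, NON_CONSTANT_OP, CONSTANT_OP };

struct AncestorKey {
  AncestorType type;
  std::string opName; // Empty for block arguments.
};

// A toy ancestor: either a block argument or an op with operands.
struct Node {
  AncestorType type;
  std::string opName;
  std::vector<const Node *> operands;
};

// Collects the AncestorKeys of `root` and of everything in its backward
// slice, visiting the nodes in breadth-first order.
std::vector<AncestorKey> computeKey(const Node *root) {
  assert(root && "expected a non-null defining node");
  std::vector<AncestorKey> key;
  std::deque<const Node *> worklist{root};
  while (!worklist.empty()) {
    const Node *node = worklist.front();
    worklist.pop_front();
    key.push_back({node->type, node->opName});
    for (const Node *operand : node->operands)
      worklist.push_back(operand);
  }
  return key;
}

// Rebuilds the DAG from the diagram above and returns the key of `A`.
std::vector<AncestorKey> keyOfA() {
  Node arg0{AncestorType::BLOCK_ARGUMENT, "", {}};
  Node arg1{AncestorType::BLOCK_ARGUMENT, "", {}};
  Node sub{AncestorType::NON_CONSTANT_OP, "arith.subi", {&arg0, &arg1}};
  Node cst{AncestorType::CONSTANT_OP, "arith.constant", {}};
  Node add{AncestorType::NON_CONSTANT_OP, "arith.addi", {&sub, &cst}};
  return computeKey(&add);
}
```

Under these assumptions, `keyOfA()` yields the five AncestorKeys listed above, in the same order.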

Now, if "keyA" is the key associated with operand `A` and "keyB" is the key associated with operand `B`, then:
"keyA" < "keyB" iff:
1. In the first unequal pair of corresponding AncestorKeys, the AncestorKey from "keyA" is smaller, or,
2. The AncestorKeys in every corresponding pair are the same and "keyA" contains fewer AncestorKeys than "keyB".

AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller ones are the ones with smaller op names (lexicographically).
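The comparison rules above amount to a lexicographic comparison. A minimal sketch, assuming the three types are modeled as an enum whose enumerator order encodes their ranking (the names `AncestorType`, `Key`, and `keyLess` are illustrative stand-ins, not necessarily what the patch uses):

```cpp
#include <algorithm>
#include <string>
#include <tuple>
#include <vector>

// Enumerator order encodes the ranking: block arguments are the smallest,
// constant-like ops the largest.
enum class AncestorType { BLOCK_ARGUMENT, NON_CONSTANT_OP, CONSTANT_OP };

struct AncestorKey {
  AncestorType type;
  std::string opName; // Empty for block arguments.
};

// Compares by type first, then lexicographically by op name.
bool operator<(const AncestorKey &lhs, const AncestorKey &rhs) {
  return std::tie(lhs.type, lhs.opName) < std::tie(rhs.type, rhs.opName);
}

using Key = std::vector<AncestorKey>;

// Rule 1: the first unequal pair of AncestorKeys decides.
// Rule 2: if one key is a prefix of the other, the shorter key is smaller.
// std::lexicographical_compare implements exactly these two rules.
bool keyLess(const Key &a, const Key &b) {
  return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
}
```

With this model, a key starting with `{NON_CONSTANT_OP, "foo.mul"}` compares smaller than one starting with `{CONSTANT_OP, "foo.const"}`, regardless of their lengths.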

---

Some examples of such a sorting:

Assume that the sorting is being applied to `foo.commutative`, which is a commutative op.

Example 1:

> %1 = foo.const 0
> %2 = foo.mul <block argument>, <block argument>
> %3 = foo.commutative %1, %2

Here,
1. The key associated with %1 is:
```
    {
     {CONSTANT_OP, "foo.const"}
    }
```
2. The key associated with %2 is:
```
    {
     {NON_CONSTANT_OP, "foo.mul"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```

The key of %2 < the key of %1
Thus, the sorted `foo.commutative` is:
> %3 = foo.commutative %2, %1
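Example 1 can be reproduced with a small standalone model. This is only a sketch under stated assumptions: `Operand`, `sortOperands`, and `sortedExample1` are hypothetical names, the keys are written out by hand rather than computed from IR, and `std::stable_sort` stands in for whatever stable sorting the pattern actually performs.

```cpp
#include <algorithm>
#include <string>
#include <tuple>
#include <vector>

enum class AncestorType { BLOCK_ARGUMENT, NON_CONSTANT_OP, CONSTANT_OP };

struct AncestorKey {
  AncestorType type;
  std::string opName; // Empty for block arguments.
};

bool operator<(const AncestorKey &lhs, const AncestorKey &rhs) {
  return std::tie(lhs.type, lhs.opName) < std::tie(rhs.type, rhs.opName);
}

// A toy operand: its printed name plus its precomputed key.
struct Operand {
  std::string name;
  std::vector<AncestorKey> key;
};

// Sorts operands in ascending key order. std::stable_sort keeps operands
// with equal keys in their original relative order.
void sortOperands(std::vector<Operand> &operands) {
  std::stable_sort(operands.begin(), operands.end(),
                   [](const Operand &a, const Operand &b) {
                     return std::lexicographical_compare(
                         a.key.begin(), a.key.end(),
                         b.key.begin(), b.key.end());
                   });
}

// Models `%3 = foo.commutative %1, %2` from Example 1 and returns the
// operand names after sorting.
std::vector<std::string> sortedExample1() {
  const AncestorKey arg{AncestorType::BLOCK_ARGUMENT, ""};
  const AncestorKey cst{AncestorType::CONSTANT_OP, "foo.const"};
  const AncestorKey mul{AncestorType::NON_CONSTANT_OP, "foo.mul"};
  std::vector<Operand> operands = {
      {"%1", {cst}},
      {"%2", {mul, arg, arg}},
  };
  sortOperands(operands);
  std::vector<std::string> names;
  for (const Operand &operand : operands)
    names.push_back(operand.name);
  return names;
}
```

In this model, `sortedExample1()` returns the names in the order `%2`, `%1`, which matches the sorted `foo.commutative` above.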

Example 2:

> %1 = foo.const 0
> %2 = foo.mul <block argument>, <block argument>
> %3 = foo.mul %2, %1
> %4 = foo.add %2, %1
> %5 = foo.commutative %1, %2, %3, %4

Here,
1. The key associated with %1 is:
```
    {
     {CONSTANT_OP, "foo.const"}
    }
```
2. The key associated with %2 is:
```
    {
     {NON_CONSTANT_OP, "foo.mul"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```
3. The key associated with %3 is:
```
    {
     {NON_CONSTANT_OP, "foo.mul"},
     {NON_CONSTANT_OP, "foo.mul"},
     {CONSTANT_OP, "foo.const"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```
4. The key associated with %4 is:
```
    {
     {NON_CONSTANT_OP, "foo.add"},
     {NON_CONSTANT_OP, "foo.mul"},
     {CONSTANT_OP, "foo.const"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```

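Comparing the keys: %4 is the smallest because "foo.add" < "foo.mul" in the first AncestorKey. %2 and %3 agree on the first AncestorKey, and the second one is a `BLOCK_ARGUMENT` for %2 but a `NON_CONSTANT_OP` for %3, so the key of %2 < the key of %3. %1 is the largest because its first AncestorKey is a `CONSTANT_OP`.
Thus, the sorted `foo.commutative` is:
> %5 = foo.commutative %4, %2, %3, %1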

Signed-off-by: Srishti Srivastava <srishti.srivastava@polymagelabs.com>

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D124750
2022-07-30 19:25:18 -04:00


// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s

// CHECK-LABEL: @test_small_pattern_1
func.func @test_small_pattern_1(%arg0 : i32) -> i32 {
  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
  %0 = arith.constant 45 : i32
  // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi"
  %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
  %2 = arith.addi %arg0, %arg0 : i32
  // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli
  %3 = arith.muli %arg0, %arg0 : i32
  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]])
  %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32
  // CHECK-NEXT: return %[[RESULT]]
  return %result : i32
}

// CHECK-LABEL: @test_small_pattern_2
// CHECK-SAME: (%[[ARG0:.*]]: i32
func.func @test_small_pattern_2(%arg0 : i32) -> i32 {
  // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant"
  %0 = "test.constant"() {value = 0 : i32} : () -> i32
  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
  %1 = arith.constant 0 : i32
  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
  %2 = arith.addi %arg0, %arg0 : i32
  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[ARITH_CONST]], %[[TEST_CONST]])
  %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32
  // CHECK-NEXT: return %[[RESULT]]
  return %result : i32
}

// CHECK-LABEL: @test_large_pattern
func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 {
  // CHECK-NEXT: arith.divsi
  %0 = arith.divsi %arg0, %arg1 : i32
  // CHECK-NEXT: arith.divsi
  %1 = arith.divsi %0, %arg0 : i32
  // CHECK-NEXT: arith.divsi
  %2 = arith.divsi %1, %arg1 : i32
  // CHECK-NEXT: arith.addi
  %3 = arith.addi %1, %arg1 : i32
  // CHECK-NEXT: arith.subi
  %4 = arith.subi %2, %3 : i32
  // CHECK-NEXT: "test.addi"
  %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
  // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi
  %6 = arith.divsi %4, %5 : i32
  // CHECK-NEXT: arith.divsi
  %7 = arith.divsi %1, %arg1 : i32
  // CHECK-NEXT: %[[VAL8:.*]] = arith.muli
  %8 = arith.muli %1, %arg1 : i32
  // CHECK-NEXT: %[[VAL9:.*]] = arith.subi
  %9 = arith.subi %7, %8 : i32
  // CHECK-NEXT: "test.addi"
  %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
  // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi
  %11 = arith.divsi %9, %10 : i32
  // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi
  %12 = arith.divsi %6, %arg1 : i32
  // CHECK-NEXT: arith.subi
  %13 = arith.subi %arg1, %arg0 : i32
  // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]])
  %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32
  // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi
  %15 = arith.divsi %13, %14 : i32
  // CHECK-NEXT: %[[VAL16:.*]] = arith.addi
  %16 = arith.addi %2, %15 : i32
  // CHECK-NEXT: arith.subi
  %17 = arith.subi %16, %arg1 : i32
  // CHECK-NEXT: "test.addi"
  %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
  // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi
  %19 = arith.divsi %17, %18 : i32
  // CHECK-NEXT: "test.addi"
  %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32
  // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi
  %21 = arith.divsi %17, %20 : i32
  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], %[[VAL15]])
  %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32
  // CHECK-NEXT: return %[[RESULT]]
  return %result : i32
}