[mlir][spirv] Convert math.ctlz to spv.GLSL.FindUMsb
Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D127582
This commit is contained in:
parent
f1c84d0ff0
commit
cc020a2236
|
@ -1221,4 +1221,20 @@ def SPV_GLSLFMixOp :
|
|||
let hasVerifier = 0;
|
||||
}
|
||||
|
||||
// Declares the spv.GLSL.FindUMsb op as a unary arithmetic op whose operand and
// result are restricted to SPV_Int32. The second template argument (75) is
// presumably the GLSL.std.450 extended-instruction number -- confirm against
// SPV_GLSLUnaryArithmeticOp.
def SPV_GLSLFindUMsbOp : SPV_GLSLUnaryArithmeticOp<"FindUMsb", 75, SPV_Int32> {
|
||||
let summary = "Unsigned-integer most-significant bit";
|
||||
|
||||
let description = [{
|
||||
Results in the bit number of the most-significant 1-bit in the binary
|
||||
representation of Value. If Value is 0, the result is -1.
|
||||
|
||||
Result Type and the type of Value must both be integer scalar or
|
||||
integer vector types. Result Type and operand types must have the
|
||||
same number of components with the same component width. Results are
|
||||
computed per component.
|
||||
|
||||
This instruction is currently limited to 32-bit width components.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS
|
||||
|
|
|
@ -16,12 +16,35 @@
|
|||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "math-to-spirv-pattern"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
|
||||
/// given type is not a 32-bit scalar/vector type.
///
/// The constant is emitted as a spirv::ConstantOp of `type` at `loc` through
/// `builder`. For a vector type, `value` is repeated in every element.
|
||||
static Value getScalarOrVectorI32Constant(Type type, int value,
|
||||
OpBuilder &builder, Location loc) {
|
||||
// Vector case: only vectors whose element type is a 32-bit integer qualify.
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
||||
if (!vectorType.getElementType().isInteger(32))
|
||||
return nullptr;
|
||||
// One copy of `value` per vector element.
SmallVector<int> values(vectorType.getNumElements(), value);
|
||||
return builder.create<spirv::ConstantOp>(loc, type,
|
||||
builder.getI32VectorAttr(values));
|
||||
}
|
||||
// Scalar case: a plain 32-bit integer.
if (type.isInteger(32))
|
||||
return builder.create<spirv::ConstantOp>(loc, type,
|
||||
builder.getI32IntegerAttr(value));
|
||||
|
||||
// Any other type is unsupported; signal that with a null Value.
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -92,6 +115,42 @@ class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Converts math.ctlz to SPIR-V ops.
|
||||
///
|
||||
/// SPIR-V does not have a direct operation for counting leading zeros. If
|
||||
/// Shader capability is supported, we can leverage GLSL FindUMsb to calculate
|
||||
/// it.
///
/// The result is computed as `31 - FindUMsb(x)`. Only types that convert to a
/// 32-bit integer scalar or a vector of 32-bit integers are handled; for
/// anything else the pattern fails to match and the op is left as is.
|
||||
class CountLeadingZerosPattern final
|
||||
: public OpConversionPattern<math::CountLeadingZerosOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Bail out if the type converter cannot convert the result type.
auto type = getTypeConverter()->convertType(countOp.getType());
|
||||
if (!type)
|
||||
return failure();
|
||||
|
||||
// We can only support 32-bit integer types for now.
// `bitwidth` stays 0 for any type that is neither an integer nor a vector,
// so such types are rejected by the check below as well.
|
||||
unsigned bitwidth = 0;
|
||||
if (type.isa<IntegerType>())
|
||||
bitwidth = type.getIntOrFloatBitWidth();
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
bitwidth = vectorType.getElementTypeBitWidth();
|
||||
if (bitwidth != 32)
|
||||
return failure();
|
||||
|
||||
Location loc = countOp.getLoc();
|
||||
Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
|
||||
Value msb =
|
||||
rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
|
||||
// We need to subtract from 31 given that the index is from the least
|
||||
// significant bit.
// FindUMsb is documented to return -1 for a zero input, so the subtraction
// yields 32 in that case.
|
||||
rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts math.expm1 to SPIR-V ops.
|
||||
///
|
||||
/// SPIR-V does not have a direct operation for exp(x)-1. Explicitly lower to
|
||||
|
@ -148,7 +207,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
|
||||
// GLSL patterns
|
||||
patterns
|
||||
.add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
|
||||
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
|
||||
ExpM1OpPattern<spirv::GLSLExpOp>,
|
||||
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
|
||||
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
|
||||
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
|
||||
|
|
|
@ -36,6 +36,17 @@ void ConvertMathToSPIRVPass::runOnOperation() {
|
|||
|
||||
SPIRVTypeConverter typeConverter(targetAttr);
|
||||
|
||||
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
|
||||
// in patterns for other dialects.
|
||||
auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
|
||||
Location loc) {
|
||||
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
|
||||
return Optional<Value>(cast.getResult(0));
|
||||
};
|
||||
typeConverter.addSourceMaterialization(addUnrealizedCast);
|
||||
typeConverter.addTargetMaterialization(addUnrealizedCast);
|
||||
target->addLegalOp<UnrealizedConversionCastOp>();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
populateMathToSPIRVPatterns(typeConverter, patterns);
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: @float32_unary_scalar
|
||||
func.func @float32_unary_scalar(%arg0: f32) {
|
||||
|
@ -91,4 +93,56 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ctlz_scalar
|
||||
// CHECK-SAME: (%[[VAL:.+]]: i32)
|
||||
func.func @ctlz_scalar(%val: i32) -> i32 {
|
||||
// CHECK: %[[V31:.+]] = spv.Constant 31 : i32
|
||||
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
|
||||
// CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
|
||||
// CHECK: return %[[SUB]]
|
||||
%0 = math.ctlz %val : i32
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ctlz_vector1
|
||||
func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
|
||||
// CHECK: spv.GLSL.FindUMsb
|
||||
// CHECK: spv.ISub
|
||||
%0 = math.ctlz %val : vector<1xi32>
|
||||
return %0 : vector<1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ctlz_vector2
|
||||
// CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
|
||||
func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
|
||||
// CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
|
||||
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
|
||||
// CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
|
||||
// CHECK: return %[[SUB]]
|
||||
%0 = math.ctlz %val : vector<2xi32>
|
||||
return %0 : vector<2xi32>
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, Int64, Int16], []>, #spv.resource_limits<>>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: @ctlz_scalar
|
||||
func.func @ctlz_scalar(%val: i64) -> i64 {
|
||||
// CHECK: math.ctlz
|
||||
%0 = math.ctlz %val : i64
|
||||
return %0 : i64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ctlz_vector2
|
||||
func.func @ctlz_vector2(%val: vector<2xi16>) -> vector<2xi16> {
|
||||
// CHECK: math.ctlz
|
||||
%0 = math.ctlz %val : vector<2xi16>
|
||||
return %0 : vector<2xi16>
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
|
|
@ -494,10 +494,34 @@ func.func @fmix(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> () {
|
|||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @fmix_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) -> () {
|
||||
// CHECK: {{%.*}} = spv.GLSL.FMix {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32> -> vector<3xf32>
|
||||
%0 = spv.GLSL.FMix %arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32> -> vector<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.GLSL.FindUMsb
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func.func @findumsb(%arg0 : i32) -> () {
|
||||
// CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
|
||||
%2 = spv.GLSL.FindUMsb %arg0 : i32
|
||||
return
|
||||
}
|
||||
|
||||
func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () {
|
||||
// CHECK: spv.GLSL.FindUMsb {{%.*}} : vector<3xi32>
|
||||
%2 = spv.GLSL.FindUMsb %arg0 : vector<3xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @findumsb(%arg0 : i64) -> () {
|
||||
// expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
|
||||
%2 = spv.GLSL.FindUMsb %arg0 : i64
|
||||
return
|
||||
}
|
||||
|
|
|
@ -75,4 +75,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
|||
%13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @findumsb(%arg0 : i32) "None" {
|
||||
// CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
|
||||
%2 = spv.GLSL.FindUMsb %arg0 : i32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user