Skip to content

Commit

Permalink
[spirv] Add spv.ReturnValue
Browse files Browse the repository at this point in the history
This CL adds the spv.ReturnValue op and its tests. Also adds a
InFunctionScope trait to make sure that the op stays inside
a function. To be consistent, ModuleOnly trait is changed to
InModuleScope.

PiperOrigin-RevId: 264193081
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Aug 19, 2019
1 parent a9f4ae1 commit 636c414
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 20 deletions.
16 changes: 10 additions & 6 deletions include/mlir/Dialect/SPIRV/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;

def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
Expand All @@ -146,7 +147,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn
SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
Expand Down Expand Up @@ -778,12 +779,15 @@ def SPV_SamplerUseAttr:
// SPIR-V OpTrait definitions
//===----------------------------------------------------------------------===//

// Check that an op can only be used with SPIR-V ModuleOp
def IsModuleOnlyPred :
CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">;
// Check that an op can only be used within the scope of a FuncOp.
def InFunctionScope : PredOpTrait<
"op must appear in a 'func' block",
CPred<"llvm::isa_and_nonnull<FuncOp>($_op.getParentOp())">>;

def ModuleOnly :
PredOpTrait<"op can only be used in a 'spv.module' block", IsModuleOnlyPred>;
// Check that an op can only be used within the scope of a SPIR-V ModuleOp.
def InModuleScope : PredOpTrait<
"op must appear in a 'spv.module' block",
CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">>;

//===----------------------------------------------------------------------===//
// SPIR-V op definitions
Expand Down
36 changes: 34 additions & 2 deletions include/mlir/Dialect/SPIRV/SPIRVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {

// -----

def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
let summary = "Declare an execution mode for an entry point.";

let description = [{
Expand Down Expand Up @@ -599,7 +599,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> {

// -----

def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
let summary = "Return with no value from a function with void return type.";

let description = [{
Expand All @@ -624,6 +624,38 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {

// -----

def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
let summary = "Return a value from a function.";

let description = [{
Value is the value returned, by copy, and must match the Return Type
operand of the OpTypeFunction type of the OpFunction body this return
instruction is in.

This instruction must be the last instruction in a block.

### Custom assembly form

``` {.ebnf}
return-value-op ::= `spv.ReturnValue` ssa-use `:` spirv-type
```

For example:

```
spv.ReturnValue %0 : f32
```
}];

let arguments = (ins
SPV_Type:$value
);

let results = (outs);
}

// -----

def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";

Expand Down
8 changes: 4 additions & 4 deletions include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
include "mlir/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE

def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
let summary = "Get the address of a global variable.";

let description = [{
Expand Down Expand Up @@ -66,7 +66,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
let hasOpcode = 0;
}

def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
let summary = [{
Declare an entry point, its execution model, and its interface.
}];
Expand Down Expand Up @@ -122,7 +122,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
}


def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> {
def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
let summary = [{
Allocate an object in memory at module scope. The object is
referenced using a symbol name.
Expand Down Expand Up @@ -264,7 +264,7 @@ def SPV_ModuleOp : SPV_Op<"module",
}];
}

def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> {
def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> {
let summary = "The pseudo op that ends a SPIR-V module";

let description = [{
Expand Down
42 changes: 38 additions & 4 deletions lib/Dialect/SPIRV/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1042,10 +1042,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
//===----------------------------------------------------------------------===//

static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
auto funcOp = llvm::dyn_cast<FuncOp>(returnOp.getOperation()->getParentOp());
if (!funcOp)
return returnOp.emitOpError("must appear in a 'func' op");

auto funcOp = llvm::cast<FuncOp>(returnOp.getParentOp());
auto numOutputs = funcOp.getType().getNumResults();
if (numOutputs != 0)
return returnOp.emitOpError("cannot be used in functions returning value")
Expand All @@ -1054,6 +1051,43 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
return success();
}

//===----------------------------------------------------------------------===//
// spv.ReturnValue
//===----------------------------------------------------------------------===//

static ParseResult parseReturnValueOp(OpAsmParser *parser,
OperationState *state) {
OpAsmParser::OperandType retValInfo;
Type retValType;
return failure(
parser->parseOperand(retValInfo) || parser->parseColonType(retValType) ||
parser->resolveOperand(retValInfo, retValType, state->operands));
}

static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) {
*printer << spirv::ReturnValueOp::getOperationName() << ' ';
printer->printOperand(retValOp.value());
*printer << " : " << retValOp.value()->getType();
}

static LogicalResult verify(spirv::ReturnValueOp retValOp) {
auto funcOp = llvm::cast<FuncOp>(retValOp.getParentOp());
auto numFnResults = funcOp.getType().getNumResults();
if (numFnResults != 1)
return retValOp.emitOpError(
"returns 1 value but enclosing function requires ")
<< numFnResults << " results";

auto operandType = retValOp.value()->getType();
auto fnResultType = funcOp.getType().getResult(0);
if (operandType != fnResultType)
return retValOp.emitOpError(" return value's type (")
<< operandType << ") mismatch with function's result type ("
<< fnResultType << ")";

return success();
}

//===----------------------------------------------------------------------===//
// spv.StoreOp
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 21 additions & 0 deletions test/Dialect/SPIRV/Serialization/terminator.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s

func @spirv_terminator() -> () {
spv.module "Logical" "GLSL450" {
// CHECK-LABEL: @ret
func @ret() -> () {
// CHECK: spv.Return
spv.Return
}

// CHECK-LABEL: @ret_val
func @ret_val() -> (i32) {
%0 = spv.Variable : !spv.ptr<i32, Function>
%1 = spv.Load "Function" %0 : i32
// CHECK: spv.ReturnValue {{.*}} : i32
spv.ReturnValue %1 : i32
}
}
return
}

41 changes: 38 additions & 3 deletions test/Dialect/SPIRV/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ spv.module "Logical" "VulkanKHR" {

spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
// expected-error @+1 {{'spv.EntryPoint' op failed to verify that op can only be used in a 'spv.module' block}}
// expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.EntryPoint "GLCompute" @do_something
}
}
Expand Down Expand Up @@ -451,7 +451,7 @@ spv.module "Logical" "VulkanKHR" {

spv.module "Logical" "VulkanKHR" {
func @foo() {
// expected-error @+1 {{op failed to verify that op can only be used in a 'spv.module' block}}
// expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.globalVariable !spv.ptr<f32, Input> @var0
spv.Return
}
Expand Down Expand Up @@ -767,7 +767,7 @@ spv.module "Logical" "VulkanKHR" {
//===----------------------------------------------------------------------===//

"foo.function"() ({
// expected-error @+1 {{must appear in a 'func' op}}
// expected-error @+1 {{op must appear in a 'func' block}}
spv.Return
}) : () -> ()

Expand All @@ -783,6 +783,41 @@ spv.module "Logical" "VulkanKHR" {

// -----

//===----------------------------------------------------------------------===//
// spv.ReturnValue
//===----------------------------------------------------------------------===//

func @ret_val() -> (i32) {
%0 = spv.constant 42 : i32
spv.ReturnValue %0 : i32
}

// -----

"foo.function"() ({
%0 = spv.constant true
// expected-error @+1 {{op must appear in a 'func' block}}
spv.ReturnValue %0 : i1
}) : () -> ()

// -----

func @value_count_mismatch() -> () {
%0 = spv.constant 42 : i32
// expected-error @+1 {{op returns 1 value but enclosing function requires 0 results}}
spv.ReturnValue %0 : i32
}

// -----

func @value_type_mismatch() -> (f32) {
%0 = spv.constant 42 : i32
// expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}}
spv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spv.SDiv
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 13 additions & 1 deletion test/Dialect/SPIRV/structure-ops.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

//===----------------------------------------------------------------------===//
// spv._address_of
//===----------------------------------------------------------------------===//

spv.module "Logical" "GLSL450" {
spv.globalVariable !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input> @var
// expected-error @+1 {{op must appear in a 'func' block}}
%1 = spv._address_of @var : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
}

// -----

//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -171,6 +183,6 @@ spv.module "Logical" "VulkanKHR" {
//===----------------------------------------------------------------------===//

func @module_end_not_in_module() -> () {
// expected-error @+1 {{can only be used in a 'spv.module' block}}
// expected-error @+1 {{op must appear in a 'spv.module' block}}
spv._module_end
}

0 comments on commit 636c414

Please sign in to comment.