From f9746ac8bc20aeda76bd0e676c6aa4b28dff3bbe Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 25 Dec 2024 12:53:55 +0100 Subject: [PATCH] unstructured control flow test --- .../test/MLIR/ForwardMode/batched_branch.mlir | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 enzyme/test/MLIR/ForwardMode/batched_branch.mlir diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir new file mode 100644 index 00000000000..f20989aa424 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %y : f64) -> f64 { + %c = arith.cmpf ult, %x, %y : f64 + cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64) + + ^blk2(%r : f64): + return %r : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64 +// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>) +// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0 +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: }