diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp
index 243077c1376..7d19e15ee11 100644
--- a/enzyme/Enzyme/CallDerivatives.cpp
+++ b/enzyme/Enzyme/CallDerivatives.cpp
@@ -3312,6 +3312,9 @@ bool AdjointGenerator::handleKnownCallDerivatives(
       }
 #endif
       Value *replacement = B.CreateAlloca(elTy, Size);
+      for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", "enzymejl_allocart"})
+        if (auto M = call.getMetadata(MD))
+          cast<AllocaInst>(replacement)->setMetadata(MD, M);
       if (I)
         replacement->takeName(I);
       else
diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp
index 3cd0203c08a..065f9f4f81e 100644
--- a/enzyme/Enzyme/FunctionUtils.cpp
+++ b/enzyme/Enzyme/FunctionUtils.cpp
@@ -508,7 +508,7 @@ UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode,
         {ConstantAsMetadata::get(ConstantInt::get(
             IntegerType::get(AI->getContext(), 64), align))}));
 
-    for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type"})
+    for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", "enzymejl_allocart"})
       if (auto M = AI->getMetadata(MD))
         CI->setMetadata(MD, M);
 
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp
index 7c012506e37..6496a041336 100644
--- a/enzyme/Enzyme/GradientUtils.cpp
+++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -3280,6 +3280,9 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) {
     auto replacement = NB.CreateAlloca(
         Type::getInt8Ty(I.getContext()),
         lookupM(getNewFromOriginal(I.getOperand(0)), NB, available));
+    for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", "enzymejl_allocart"})
+      if (auto M = I.getMetadata(MD))
+        replacement->setMetadata(MD, M);
     auto Alignment =
         cast<ConstantInt>(
             cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
@@ -3524,6 +3527,9 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) {
       auto rule = [&](Value *anti) {
         AllocaInst *replacement = NB.CreateAlloca(
             Type::getInt8Ty(orig->getContext()), args[0]);
+        for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type", "enzymejl_allocart"})
+          if (auto M = I.getMetadata(MD))
+            replacement->setMetadata(MD, M);
         replacement->takeName(anti);
         auto Alignment = cast<ConstantInt>(cast<ConstantAsMetadata>(
                                                MD->getOperand(0))
diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp
index 11a0fb3180f..1e505fe7b90 100644
--- a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchPass.cpp
@@ -30,7 +30,11 @@
 namespace {
 static mlir::TensorType applyBatchSizes(mlir::Type Ty,
                                         llvm::ArrayRef<int64_t> batchSizes) {
-  auto T = cast<TensorType>(Ty);
+  auto T = dyn_cast<TensorType>(Ty);
+  if (!T) {
+    return RankedTensorType::get(batchSizes, Ty);
+  }
+
   SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
   shape.append(T.getShape().begin(), T.getShape().end());
   auto T2 = T.clone(shape);
diff --git a/enzyme/test/MLIR/Batch/batched_scalar.mlir b/enzyme/test/MLIR/Batch/batched_scalar.mlir
new file mode 100644
index 00000000000..16cd543873e
--- /dev/null
+++ b/enzyme/test/MLIR/Batch/batched_scalar.mlir
@@ -0,0 +1,21 @@
+// RUN: %eopt --enzyme-batch %s | FileCheck %s
+
+module {
+  func.func @square(%x : f64) -> f64 {
+    %y = math.sin %x : f64
+    return %y : f64
+  }
+  func.func @dsq(%x : tensor<10x2xf64>) -> tensor<10x2xf64> {
+    %r = enzyme.batch @square(%x) { batch_shape=array<i64: 10, 2> } : (tensor<10x2xf64>) -> (tensor<10x2xf64>)
+    return %r : tensor<10x2xf64>
+  }
+}
+
+// CHECK: func.func @dsq(%arg0: tensor<10x2xf64>) -> tensor<10x2xf64> {
+// CHECK-NEXT: %0 = call @batched_square(%arg0) : (tensor<10x2xf64>) -> tensor<10x2xf64>
+// CHECK-NEXT: return %0 : tensor<10x2xf64>
+// CHECK-NEXT: }
+// CHECK: func.func private @batched_square(%arg0: tensor<10x2xf64>) -> tensor<10x2xf64> {
+// CHECK-NEXT: %0 = math.sin %arg0 : tensor<10x2xf64>
+// CHECK-NEXT: return %0 : tensor<10x2xf64>
+// CHECK-NEXT: }
\ No newline at end of file