From 890800b0d25b6bdd9633d1c0fac170d3b99aa796 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 16 Dec 2024 18:19:00 -0600 Subject: [PATCH] fixup --- src/enzyme_ad/jax/Implementations/HLODerivatives.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index c3b37bef6..85221c2c8 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -681,7 +681,7 @@ def FftLength : GlobalExprgetResult(0).getType().cast(); - auto lengths = op.getFftLength(); + auto lengths = op.getFftLengthAttr().getValues(); auto N = std::accumulate(lengths.begin(), lengths.end(), llvm::APInt(64, 1, true), std::multiplies{}).getSExtValue(); double value = N; @@ -707,15 +707,15 @@ def FftMultiplier : GlobalExpr(op.getLoc(), SplatElementsAttr::get( RT, FloatAttr::get(resTy.getElementType(), 0))); auto end_constant = builder.create(op.getLoc(), SplatElementsAttr::get( - RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1))); + RT, FloatAttr::get(resTy.getElementType(), lengths[lengths.size()-1]-1))); auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64)); Value start[] = { - builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0))) + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(0))) }; Value end[] = { - builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1))) + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(lengths.size()-1))) }; ret_constant = builder.create(op.getLoc(), resTy, ret_constant, zero_constant, start); ret_constant = builder.create(op.getLoc(), resTy, ret_constant, end_constant, end);