diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index a455ae5e4221..7c72ec7e4996 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -3865,20 +3865,20 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { case Intrinsic::smin: case Intrinsic::umax: case Intrinsic::umin: - BaseType type; - if (isa(I.getOperand(0)->getType()) && isa(I.getOperand(1)->getType()) ) { - type = BaseType::Pointer; - } else { - type = BaseType::Integer; + if (direction & UP) { + auto returnType = getAnalysis(&I)[{-1}]; + if (returnType == BaseType::Integer || returnType == BaseType::Pointer) { + updateAnalysis(I.getOperand(0), TypeTree(returnType).Only(-1, &I), &I); + updateAnalysis(I.getOperand(1), TypeTree(returnType).Only(-1, &I), &I); + } + } + if (direction & DOWN) { + auto opType0 = getAnalysis(I.getOperand(0))[{-1}]; + auto opType1 = getAnalysis(I.getOperand(1))[{-1}]; + if (opType0 == opType1 && (opType0 == BaseType::Integer || opType0 == BaseType::Pointer)){ + updateAnalysis(&I, TypeTree(opType0).Only(-1, &I), &I); + } } - // No direction check as always valid - updateAnalysis(&I, TypeTree(type).Only(-1, &I), &I); - // No direction check as always valid - updateAnalysis(I.getOperand(0), TypeTree(type).Only(-1, &I), - &I); - // No direction check as always valid - updateAnalysis(I.getOperand(1), TypeTree(type).Only(-1, &I), - &I); return; #endif case Intrinsic::umul_with_overflow: