Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][ABI][Lowering] Fixes calls with union type #1119

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ struct MissingFeatures {
static bool X86TypeClassification() { return false; }

static bool ABIClangTypeKind() { return false; }
static bool ABIEnterStructForCoercedAccess() { return false; }
static bool ABIFuncPtr() { return false; }
static bool ABIInRegAttribute() { return false; }
static bool ABINestedRecordLayout() { return false; }
Expand Down
13 changes: 6 additions & 7 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,13 +519,12 @@ void StructType::computeSizeAndAlignment(

// Found a nested union: recurse into it to fetch its largest member.
auto structMember = mlir::dyn_cast<StructType>(ty);
if (structMember && structMember.isUnion()) {
auto candidate = structMember.getLargestMember(dataLayout);
if (dataLayout.getTypeSize(candidate) > largestMemberSize) {
largestMember = candidate;
largestMemberSize = dataLayout.getTypeSize(largestMember);
}
} else if (dataLayout.getTypeSize(ty) > largestMemberSize) {
if (!largestMember ||
dataLayout.getTypeABIAlignment(ty) >
dataLayout.getTypeABIAlignment(largestMember) ||
(dataLayout.getTypeABIAlignment(ty) ==
dataLayout.getTypeABIAlignment(largestMember) &&
dataLayout.getTypeSize(ty) > largestMemberSize)) {
largestMember = ty;
largestMemberSize = dataLayout.getTypeSize(largestMember);
}
Expand Down
41 changes: 30 additions & 11 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ mlir::Value createCoercedBitcast(mlir::Value Src, mlir::Type DestTy,
CastKind::bitcast, Src);
}

// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}

/// Given a struct pointer that we are accessing some number of bytes out of it,
/// try to gep into the struct to get at its inner goodness. Dive as deep as
/// possible without entering an element with an in-memory size smaller than
Expand All @@ -67,6 +73,9 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,

mlir::Type FirstElt = SrcSTy.getMembers()[0];

if (SrcSTy.isUnion())
FirstElt = SrcSTy.getLargestMember(CGF.LM.getDataLayout().layout);

// If the first elt is at least as large as what we're looking for, or if the
// first element is the same size as the whole struct, we can enter it. The
// comparison must be made on the store size and not the alloca size. Using
Expand All @@ -76,10 +85,26 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,
FirstEltSize < CGF.LM.getDataLayout().getTypeStoreSize(SrcSTy))
return SrcPtr;

cir_cconv_assert_or_abort(
!cir::MissingFeatures::ABIEnterStructForCoercedAccess(), "NYI");
return SrcPtr; // FIXME: This is a temporary workaround for the assertion
// above.
auto &rw = CGF.getRewriter();
auto *ctxt = rw.getContext();
auto ptrTy = PointerType::get(ctxt, FirstElt);
if (mlir::isa<StructType>(SrcPtr.getType())) {
auto addr = SrcPtr;
if (auto load = mlir::dyn_cast<LoadOp>(SrcPtr.getDefiningOp()))
addr = load.getAddr();
cir_cconv_assert(mlir::isa<PointerType>(addr.getType()));
// we can not use getMemberOp here since we need a pointer to the first
// element. And in the case of unions we pick a type of the largest elt,
// that may or may not be the first one. Thus, getMemberOp verification
// may fail.
auto cast = createBitcast(addr, ptrTy, CGF);
SrcPtr = rw.create<LoadOp>(SrcPtr.getLoc(), cast);
}

if (auto sty = mlir::dyn_cast<StructType>(SrcPtr.getType()))
return enterStructPointerForCoercedAccess(SrcPtr, sty, DstSize, CGF);

return SrcPtr;
}

/// Convert a value Val to the specific Ty where both
Expand Down Expand Up @@ -141,12 +166,6 @@ static mlir::Value coerceIntOrPtrToIntOrPtr(mlir::Value val, mlir::Type typ,
return val;
}

// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}

AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) {
auto &rw = LF.getRewriter();
auto *ctxt = rw.getContext();
Expand Down Expand Up @@ -302,7 +321,7 @@ mlir::Value createCoercedValue(mlir::Value Src, mlir::Type Ty,
// extension or truncation to the desired type.
if ((mlir::isa<IntType>(Ty) || mlir::isa<PointerType>(Ty)) &&
(mlir::isa<IntType>(SrcTy) || mlir::isa<PointerType>(SrcTy))) {
cir_cconv_unreachable("NYI");
return coerceIntOrPtrToIntOrPtr(Src, Ty, CGF);
}

// If load is legal, just bitcast the src pointer.
Expand Down
32 changes: 31 additions & 1 deletion clang/test/CIR/CallConvLowering/AArch64/union.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,34 @@ void foo(U u) {}
U init() {
U u;
return u;
}
}

typedef union {

struct {
short a;
char b;
char c;
};

int x;
} A;

void passA(A x) {}

// CIR: cir.func {{.*@callA}}()
// CIR: %[[#V0:]] = cir.alloca !ty_A, !cir.ptr<!ty_A>, ["x"] {alignment = 4 : i64}
// CIR: %[[#V1:]] = cir.cast(bitcast, %[[#V0:]] : !cir.ptr<!ty_A>), !cir.ptr<!s32i>
// CIR: %[[#V2:]] = cir.load %[[#V1]] : !cir.ptr<!s32i>, !s32i
// CIR: %[[#V3:]] = cir.cast(integral, %[[#V2]] : !s32i), !u64i
// CIR: cir.call @passA(%[[#V3]]) : (!u64i) -> ()

// LLVM: void @callA()
// LLVM: %[[#V0:]] = alloca %union.A, i64 1, align 4
// LLVM: %[[#V1:]] = load i32, ptr %[[#V0]], align 4
// LLVM: %[[#V2:]] = sext i32 %[[#V1]] to i64
// LLVM: call void @passA(i64 %[[#V2]])
void callA() {
A x;
passA(x);
}
2 changes: 1 addition & 1 deletion clang/test/CIR/Lowering/unions.cir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module {
cir.global external @u2 = #cir.zero : !ty_U2_
cir.global external @u3 = #cir.zero : !ty_U3_
// CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)>
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)>
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (struct<"union.U1", (i32)>)>

// CHECK: llvm.func @test
cir.func @test(%arg0: !cir.ptr<!ty_U1_>) {
Expand Down