diff --git a/ext/BFloat16sExt.jl b/ext/BFloat16sExt.jl index cec5e9f0..0bb4c5e6 100644 --- a/ext/BFloat16sExt.jl +++ b/ext/BFloat16sExt.jl @@ -1,7 +1,7 @@ module BFloat16sExt using LLVM -using LLVM: API +using LLVM: API, BFloatType using BFloat16s @@ -12,7 +12,7 @@ LLVM.ConstantFP(val::BFloat16) = ConstantFP(BFloatType(), val) Base.convert(::Type{BFloat16}, val::ConstantFP) = convert(BFloat16, API.LLVMConstRealGetDouble(val, Ref{API.LLVMBool}())) -ConstantDataArray(data::AbstractVector{BFloat16}) = +LLVM.ConstantDataArray(data::AbstractVector{BFloat16}) = ConstantDataArray(BFloatType(), data) end diff --git a/test/core_tests.jl b/test/core_tests.jl index 1d7ddd1b..002bf619 100644 --- a/test/core_tests.jl +++ b/test/core_tests.jl @@ -421,6 +421,8 @@ end typ = LLVM.BFloatType() c = ConstantFP(typ, BFloat16(1.1)) @test convert(BFloat16, c) == BFloat16(1.1) + d = ConstantFP(BFloat16(1.1)) + @test convert(BFloat16, d) == BFloat16(1.1) end let typ = LLVM.X86FP80Type() @@ -551,7 +553,7 @@ end @test size(vec) == size(cda) @test collect(cda) == ConstantInt.(vec) end - for T in [Float32, Float64] + for T in [Float32, Float64, BFloat16] vec = T[1,2,3,4] cda = ConstantDataArray(vec) @test cda isa ConstantDataArray