Skip to content

Commit

Permalink
Merge pull request #60 from chriselrod/matmulunrolloffsizeoft
Browse files Browse the repository at this point in the history
Unroll based off of sizeof(eltype(C))
  • Loading branch information
ChrisRackauckas authored Apr 12, 2021
2 parents 6c45d5d + 46add07 commit 8f93399
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 42 deletions.
104 changes: 64 additions & 40 deletions src/exp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,59 +120,83 @@ function naivemul!(C, A, B)
mstep = step(Maxis)
# I don't want to deal with axes having non-unit step
if nstep == mstep == 1
naivemul!(C, A, B, Maxis, Naxis)
if sizeof(eltype(C)) > 256
naivemul!(C, A, B, Maxis, Naxis, Val(1), Val(1))
elseif sizeof(eltype(C)) > 128
naivemul!(C, A, B, Maxis, Naxis, Val(2), Val(1))
elseif sizeof(eltype(C)) > 96
naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(1))
elseif sizeof(eltype(C)) > 64
naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(2))
else
naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(3))
end
else
mul!(C,A,B)
end
end
"""
    _const(A)

Return `A` wrapped in `Base.Experimental.Const` when `A` is an `Array`;
any other argument is returned unchanged.
"""
function _const(A)
    return A
end
function _const(A::Array)
    return Base.Experimental.Const(A)
end
# Separated to make it easier to test.
function naivemul!(C::AbstractMatrix{T}, A, B, Maxis, Naxis) where {T}
N = last(Naxis)
M = last(Maxis)
Kaxis = axes(B,1)
Base.Experimental.@aliasscope begin
n = first(Naxis)-1
@inbounds begin
while n < N - 1
m = first(Maxis)-1
while m < M - 3
Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> Cmn_i_j = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> Cmn_i_j = muladd(_const(A)[m+i,k],_const(B)[k,n+j],Cmn_i_j)
end
Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> C[m+i,n+j] = Cmn_i_j
m += 4
end
for mm ∈ 1+m:M
Base.Cartesian.@nexprs 2 j -> Cmn_j = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs 2 j -> Cmn_j = muladd(_const(A)[mm,k],_const(B)[k,n+j],Cmn_j)
end
Base.Cartesian.@nexprs 2 j -> C[mm,n+j] = Cmn_j
end
n += 2
@generated function naivemul!(C::AbstractMatrix{T}, A, B, Maxis, Naxis, ::Val{MU}, ::Val{NU}) where {T,MU,NU}
nrem_body = quote
m = first(Maxis)-1
while m < M - $(MU-1)
Base.Cartesian.@nexprs $MU i -> Cmn_i = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs $MU i -> Cmn_i = muladd(_const(A)[m+i,k],_const(B)[k,nn],Cmn_i)
end
m = first(Maxis)-1
while m < M - 3
Base.Cartesian.@nexprs 4 i -> Cmn_i = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs 4 i -> Cmn_i = muladd(_const(A)[m+i,k],_const(B)[k,N],Cmn_i)
end
Base.Cartesian.@nexprs 4 i -> C[m+i,N] = Cmn_i
m += 4
Base.Cartesian.@nexprs $MU i -> C[m+i,nn] = Cmn_i
m += $MU
end
for mm ∈ 1+m:M
Cmn = zero(T)
for k ∈ Kaxis
Cmn = muladd(_const(A)[mm,k], _const(B)[k,nn], Cmn)
end
for mm ∈ 1+m:M
Cmn = zero(T)
for k ∈ Kaxis
Cmn = muladd(_const(A)[mm,k], _const(B)[k,N], Cmn)
C[mm,nn] = Cmn
end
end
nrem_quote = if NU > 2
:(for nn ∈ 1+n:N; $nrem_body; end)
else
:(let nn = N; $nrem_body; end)
end
quote
N = last(Naxis)
M = last(Maxis)
Kaxis = axes(B,1)
Base.Experimental.@aliasscope begin
n = first(Naxis)-1
@inbounds begin
while n < N - $(NU-1)
m = first(Maxis)-1
while m < M - $(MU-1)
Base.Cartesian.@nexprs $NU j -> Base.Cartesian.@nexprs $MU i -> Cmn_i_j = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs $MU i -> Ak_i = _const(A)[m+i,k]
Base.Cartesian.@nexprs $NU j -> begin
Bk_j = _const(B)[k,n+j]
Base.Cartesian.@nexprs $MU i -> Cmn_i_j = muladd(Ak_i, Bk_j, Cmn_i_j)
end
end
Base.Cartesian.@nexprs $NU j -> Base.Cartesian.@nexprs $MU i -> C[m+i,n+j] = Cmn_i_j
m += $MU
end
for mm ∈ 1+m:M
Base.Cartesian.@nexprs $NU j -> Cmn_j = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs $NU j -> Cmn_j = muladd(_const(A)[mm,k],_const(B)[k,n+j],Cmn_j)
end
Base.Cartesian.@nexprs $NU j -> C[mm,n+j] = Cmn_j
end
n += $NU
end
C[mm,N] = Cmn
$(NU > 1 ? nrem_quote : nothing)
end
end
C
end
C
end

"""
Expand Down
9 changes: 7 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,17 @@ end
A = rand(n,n);
B = rand(n,n);
C = similar(A);
@test ExponentialUtilities.naivemul!(C, A, B, axes(C)...) ≈ A*B
AB = A*B;
@test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(2), Val(1)) ≈ AB
@test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(4), Val(2)) ≈ AB
@test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(4), Val(3)) ≈ AB
if n ≤ 16
Am = MMatrix{n,n}(A)
Bm = MMatrix{n,n}(B)
Cm = MMatrix{n,n}(A)
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm)...) ≈ A*B
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(2), Val(2)) ≈ AB
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(4), Val(2)) ≈ AB
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(4), Val(3)) ≈ AB
end
end
A = @SMatrix rand(7,7);
Expand Down

0 comments on commit 8f93399

Please sign in to comment.