From a3fd43833ff01be4e10e4db4fa3878678fea6f19 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Apr 2024 15:12:59 -0400 Subject: [PATCH] don't rely on TensorCast being in caller's scope --- Project.toml | 1 - src/macro.jl | 128 +++++++++++++++++++++++++-------------------------- 2 files changed, 64 insertions(+), 65 deletions(-) diff --git a/Project.toml b/Project.toml index 3c229f9..ad7f094 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,6 @@ LazyArrays = "0.21, 0.22, 1" LazyStack = "0.1.0" MacroTools = "0.5" StaticArrays = "1.3" -Strided = "1.1, 2" TransmuteDims = "0.1.13" julia = "1.6" diff --git a/src/macro.jl b/src/macro.jl index 9b0606a..838b23f 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -261,7 +261,7 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) else # A = :( TensorCast.rview($A, $(ijcolon...)) ) perm = filter(!isnothing, ntuple(d -> ijcolon[d]==(:) ? d : nothing, length(ijcolon))) - A = :( TensorCast.transmute($A, Base.Val($perm)) ) + A = :( $transmute($A, $Val($perm)) ) end ijk = filter(!isconstant, ijk) # remove actual constants from list, end @@ -271,11 +271,11 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) code = Tuple(map(i -> iscolon(i) ? (:) : (*), ijk)) if static sizeorcode = maybestaticsizes(ijk, code, call) - A = :( TensorCast.static_slice($A, $sizeorcode) ) + A = :( $static_slice($A, $sizeorcode) ) elseif (:lazy_0 in call.flags) || (:collect in call.flags) - A = :( TensorCast.slicecopy($A, $code) ) + A = :( $slicecopy($A, $code) ) else - A = :( TensorCast.sliceview($A, $code) ) + A = :( $sliceview($A, $code) ) end ijk = filter(!iscolon, ijk) elseif static @@ -290,9 +290,9 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) # Diagonal extraction A[i,i] if length(ijk)==2 && ijk[1]==ijk[2] if (:lazy_0 in call.flags) && !LHS # don't do this for in-place output - A = :( TensorCast.diag($A) ) # LinearAlgebra really + A = :( $LinearAlgebra.diag($A) ) # LinearAlgebra really else - A = :( TensorCast.diagview($A) ) + A = :( $diagview($A) ) end pop!(ijk) end @@ -304,7 +304,7 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) # Combined indices A[i,(j,k)] if any(istensor, ijk) flatsize = map(axwrap, flat) - A = :( Base.reshape($A, ($(flatsize...),)) ) + A = :( $reshape($A, ($(flatsize...),)) ) append!(store.need, flat) end @@ -312,7 +312,7 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) if !isempty(reversed) A = maybepush(A, store, :prereverse) rind = map(1:length(flat)) do d - flat[d] in reversed ? :($reverse(Base.axes($A,$d))) : (:) + flat[d] in reversed ? :($reverse($axes($A,$d))) : (:) end rdims = Tuple(indexin(reversed, flat)) if (:lazy_0 in call.flags) && !LHS @@ -333,7 +333,7 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) if !isempty(shuffled) A = maybepush(A, store, :preshuffle) sind = map(1:length(flat)) do d - flat[d] in shuffled ? :($shuffle(Base.axes($A,$d))) : (:) + flat[d] in shuffled ? :($shuffle($axes($A,$d))) : (:) end if (:lazy_0 in call.flags) && !LHS if length(flat) == 1 @@ -408,7 +408,7 @@ function standardglue(ex, target, store::NamedTuple, call::CallInfo) else # ex = :( $Bsym = @__dot__ TensorCast.rview($B, $(ijcolon...)) ) perm = filter(!isnothing, ntuple(d -> ijcolon[d]==(:) ? d : nothing, length(ijcolon))) - ex = :( $Bsym = TensorCast.transmute.($B, Base.Val($perm)) ) + ex = :( $Bsym = $transmute.($B, $Val($perm)) ) end push!(store.main, ex) B = Bsym @@ -424,13 +424,13 @@ function standardglue(ex, target, store::NamedTuple, call::CallInfo) ijk = vcat(inner, outer) if static - AB = :( TensorCast.static_glue($B) ) + AB = :( $static_glue($B) ) pop!(call.flags, :collected, :ok) elseif :lazy_0 in call.flags - AB = :( TensorCast.eagerstack($B) ) # really from Base/Compat + AB = :( $eagerstack($B) ) # really from Base/Compat push!(call.flags, :collected) else # if :lazy in call.flags - AB = :( TensorCast.lazystack($B) ) # really from LazyStack + AB = :( $lazystack($B) ) # really from LazyStack pop!(call.flags, :collected, :ok) end @@ -494,7 +494,7 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) end # Some things ought to apply elementwise: conjugation, - @capture(ex, A_[ijk__]') && return :( Base.adjoint($A[$(ijk...)]) ) + @capture(ex, A_[ijk__]') && return :( $adjoint($A[$(ijk...)]) ) # .fields... only one deep for now, @capture(ex, A_[ijk__].field_ ) && return :( getproperty($A[$(ijk...)], $(QuoteNode(field))) ) @@ -537,10 +537,10 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) perm = ntuple(d -> findfirst(isequal(d), dims), maximum(dims)) if perm != ntuple(identity, maximum(dims)) if :lazy_0 in call.flags - A = :( TensorCast.transmutedims($A, $perm) ) + A = :( $transmutedims($A, $perm) ) push!(call.flags, :collected) else - A = :( TensorCast.transmute($A, Base.Val($perm)) ) + A = :( $transmute($A, $Val($perm)) ) if ! increasing_or_zero(perm) # thus not just a reshape pop!(call.flags, :collected, :ok) end @@ -852,7 +852,7 @@ function reduceparse(ex1, ex2, store::NamedTuple, call::CallInfo) # Then parse redlist, decoding ranges like sum(i:10,j) which specify sizes reduced = [] for item in tensorprimetidy(redlist) # normalise θ' - i = @capture(item, j_:s_) ? saveonesize(j, :(Base.OneTo($s)), store) : item + i = @capture(item, j_:s_) ? saveonesize(j, :($Base.OneTo($s)), store) : item push!(reduced, i) end checknorepeats(reduced, call, " in the reduction") # catches sum(i,i) B[i,j,k] @@ -904,7 +904,7 @@ function indexparse(A, ijk::Vector, store=nothing, call=nothing; save=false) if iscolon(i) if i isa QuoteNode && A != :_ str = "fixed size in $A[" * join(ijk, ", ") * "]" # DimensionMismatch("fixed size in M[i, \$(QuoteNode(5))]: size(M, 2) == 5") TODO print more nicely - pushboundscheck!(store.mustassert, :( Base.size($A,$d)==$(i.value) || throw(DimensionMismatch($str))) ) + pushboundscheck!(store.mustassert, :( $size($A,$d)==$(i.value) || throw(DimensionMismatch($str))) ) end continue end @@ -913,7 +913,7 @@ function indexparse(A, ijk::Vector, store=nothing, call=nothing; save=false) push!(outaxes, Base.OneTo(1)) if i == :_ && A != :_ && save str = "underscore in $A[" * join(ijk, ", ") * "]" - pushboundscheck!(store.mustassert, :( Base.size($A,$d)==1 || throw(DimensionMismatch($str))) ) + pushboundscheck!(store.mustassert, :( $size($A,$d)==1 || throw(DimensionMismatch($str))) ) end continue end @@ -922,7 +922,7 @@ function indexparse(A, ijk::Vector, store=nothing, call=nothing; save=false) stripminustilde!(ii, reversed, shuffled) append!(flat, ii) push!(outaxes, axwrap(ii)) - save && A != :_ && saveonesize(ii, :(Base.axes($A, $d)), store) + save && A != :_ && saveonesize(ii, :($axes($A, $d)), store) elseif @capture(i, B_[klm__]) innerparse(B, klm, store, call) # called just for error on tensor/colon/constant @@ -932,7 +932,7 @@ function indexparse(A, ijk::Vector, store=nothing, call=nothing; save=false) elseif i isa Symbol push!(flat, i) push!(outaxes, axwrap(i)) - save && A != :_ && saveonesize(i, :(Base.axes($A, $d)), store) + save && A != :_ && saveonesize(i, :($axes($A, $d)), store) else throw(MacroError("don't understand index $i", call)) end @@ -942,10 +942,10 @@ function indexparse(A, ijk::Vector, store=nothing, call=nothing; save=false) N = length(ijk) if N == 1 str = "expected a vector or tuple $A[" * join(ijk, ", ") * "]" - pushboundscheck!(store.assert, :( $A isa Tuple || Base.ndims($A)==$N || Base.throw(ArgumentError($str))) ) + pushboundscheck!(store.assert, :( $A isa Tuple || $ndims($A)==$N || $throw(ArgumentError($str))) ) else str = "expected a $N-tensor $A[" * join(ijk, ", ") * "]" - pushboundscheck!(store.assert, :( Base.ndims($A)==$N || Base.throw(ArgumentError($str))) ) + pushboundscheck!(store.assert, :( $ndims($A)==$N || $throw(ArgumentError($str))) ) end end @@ -997,13 +997,13 @@ function innerparse(firstA, ijk, store::NamedTuple, call::CallInfo; save=false) if @capture(i, j_:s_) push!(innerflat, j) - saveonesize(j, :(Base.OneTo($s)), store) # save=true on LHS only for in-place, save this anyway + saveonesize(j, :( $Base.OneTo($s)), store) # save=true on LHS only for in-place, save this anyway elseif isconstant(i) - i == :_ && save && pushboundscheck!(store.mustassert, :( size($firstA, $d)==1 || throw(DimensionMismatch("inner underscore"))) ) + i == :_ && save && pushboundscheck!(store.mustassert, :( $size($firstA, $d)==1 || $throw(DimensionMismatch("inner underscore"))) ) else push!(innerflat, i) end - save && saveonesize(i, :(Base.axes($firstA, $d)), store) + save && saveonesize(i, :($axes($firstA, $d)), store) end checknorepeats(innerflat, call) @@ -1021,21 +1021,21 @@ function optionparse(opt, store::NamedTuple, call::CallInfo) if @capture(opt, i_ in ax_) || @capture(opt, i_ ∈ ax_) if @capture(ax, 1:s_) - saveonesize(tensorprimetidy(i), :(Base.OneTo($s)), store) + saveonesize(tensorprimetidy(i), :($Base.OneTo($s)), store) elseif ax isa Number @warn "did you mean `$i in 1:$ax`, not `$i in $ax`?" - saveonesize(tensorprimetidy(i), :(Base.OneTo($ax)), store) + saveonesize(tensorprimetidy(i), :($Base.OneTo($ax)), store) else ax1 = maybepushtop(ax, store, :axis) - push!(store.top, :($ax1 isa AbstractUnitRange || Base.throw(DimensionMismatch("index ranges must have step 1")))) + push!(store.top, :($ax1 isa AbstractUnitRange || $throw(DimensionMismatch("index ranges must have step 1")))) if isdefined(call.mod, :OffsetArrays) ax2, off = gensym(:axis), gensym(:offset) push!(store.top, :(local $off = first($ax1)-1)) - push!(store.top, :(local $ax2 = $ax1 isa Base.OneTo ? $ax1 : OffsetArrays.IdOffsetRange($ax1 .- $off, $off))) + push!(store.top, :(local $ax2 = $ax1 isa $Base.OneTo ? $ax1 : OffsetArrays.IdOffsetRange($ax1 .- $off, $off))) saveonesize(tensorprimetidy(i), ax2, store) else - pushboundscheck!(store.top, :(first($ax1)==1 || Base.throw(ArgumentError("you must load OffsetArrays to allow index ranges not starting at 1")))) - saveonesize(tensorprimetidy(i), :(Base.OneTo($ax1)), store) + pushboundscheck!(store.top, :(first($ax1)==1 || $throw(ArgumentError("you must load OffsetArrays to allow index ranges not starting at 1")))) + saveonesize(tensorprimetidy(i), :($Base.OneTo($ax1)), store) end end push!(call.flags, :assert) @@ -1043,7 +1043,7 @@ function optionparse(opt, store::NamedTuple, call::CallInfo) push!(call.flags, Symbol(:lazy_, Int(val))) elseif @capture(opt, i_:s_) @warn "please replace index ranges like `i:3` with `i in 1:3` or `i ∈ 1:3`" call.string maxlog=3 - saveonesize(tensorprimetidy(i), :(Base.OneTo($s)), store) + saveonesize(tensorprimetidy(i), :($Base.OneTo($s)), store) push!(call.flags, :assert) elseif opt in (:strided, :avx) @warn "postfix option $opt is deprecated, please write @cast @$opt A[i] := ..." call.string maxlog=3 @@ -1078,7 +1078,7 @@ function saveonesize(ind, ax, store::NamedTuple) elseif store.dict[ind] != ax # no need to save identical expressions if isa(ind, Symbol) str = "range of index $ind must agree" - pushboundscheck!(store.assert, :( $(store.dict[ind]) == $ax || Base.throw(DimensionMismatch($str))) ) + pushboundscheck!(store.assert, :( $(store.dict[ind]) == $ax || $throw(DimensionMismatch($str))) ) end end ind @@ -1122,18 +1122,18 @@ function sizeinfer(store::NamedTuple, call::CallInfo) if length(denfacts) > 1 # den = :( prod(length, ($(denfacts...),)) ) longs = map(takelength, denfacts) - den = :( Base.:*($(longs...)) ) + den = :( $(*)($(longs...)) ) else den = takelength(denfacts[1]) end - rat = :( Base.OneTo($num ÷ $den) ) + rat = :( $Base.OneTo($num ÷ $den) ) i = pair.first[.!known][1] d = findfirst(isequal(i), store.need) d != nothing && (sizes[d] = rat) str = "expected integer multiples, when calculating range of $i from range of $(join(pair.first, " ⊗ "))" - pushboundscheck!(store.mustassert, :( ($num % $den)==0 || Base.throw(ArgumentError($str))) ) + pushboundscheck!(store.mustassert, :( ($num % $den)==0 || $throw(ArgumentError($str))) ) end end end @@ -1154,9 +1154,9 @@ function takelength(ex) ex.args[2] elseif Meta.isexpr(ex, :call) && ex.args[1] in (:axes, axes, :(Base.axes)) @assert length(ex.args) == 3 - :(Base.size($(ex.args[2:end]...))) + :($size($(ex.args[2:end]...))) else - :(Base.length($ex)) + :($length($ex)) end end @@ -1171,7 +1171,7 @@ function maybestaticsizes(ijk::Vector, code::Tuple, call::CallInfo) length(ijk) == length(code) || error("wrong length of code!") staticsize = Any[ i.value for i in ijk if i isa QuoteNode ] if length(staticsize) == count(iscolon, ijk) - return :( TensorCast.Size($(staticsize...)) ) # really StaticArrays. + return :( $Size($(staticsize...)) ) # really StaticArrays. else return code end @@ -1194,7 +1194,7 @@ function maybestaticsizes(ijk::Vector, code::Tuple, store::NamedTuple, call::Cal return code end end - return :( TensorCast.Size($(staticsize...)) ) + return :( $Size($(staticsize...)) ) end """ @@ -1264,7 +1264,7 @@ axwrap(i::Symbol) = Symbol(:ax_,i) function axwrap(ijk::Vector) length(ijk) == 0 && return nothing length(ijk) == 1 && return axwrap(first(ijk)) - return :( TensorCast.star($(map(axwrap, ijk)...)) ) + return :( $star($(map(axwrap, ijk)...)) ) end isconstant(n::Int) = true @@ -1358,7 +1358,7 @@ function checkallseen(ex, canon, store, call) ex else fake = map(i -> recursemacro(i, canon, store, call), left) - :( TensorCast.onlyfirst($ex, $(fake...)) ) + :( $onlyfirst($ex, $(fake...)) ) end end @@ -1400,16 +1400,16 @@ function matrixshape(ex, left::Vector, right::Vector, store::NamedTuple, call::C # Deal with simple matrix, and with empty right of V in M*V length(left) == 1 && length(right) <= 1 && return ex # and empty right because it's M * vec(T) - isempty(right) && return :( Base.reshape($ex, :) ) + isempty(right) && return :( $reshape($ex, :) ) # Deal with empty left of V in V'*M, or perhaps V=vec(T) first if isempty(left) if length(right) == 1 # return :( TensorCast.PermuteDims($ex) ) - return :( Base.transpose($ex) ) + return :( $transpose($ex) ) else # return :( TensorCast.PermuteDims(reshape($ex, :)) ) - return :( Base.transpose(Base.reshape($ex, :)) ) + return :( $transpose($reshape($ex, :)) ) end end @@ -1419,7 +1419,7 @@ function matrixshape(ex, left::Vector, right::Vector, store::NamedTuple, call::C append!(store.need, left) append!(store.need, right) # push!(call.flags, :reshaped) - return :( Base.reshape($ex, ($left_sz,$right_sz)) ) + return :( $reshape($ex, ($left_sz,$right_sz)) ) end function unmatrixshape(ex, left::Vector, right::Vector, store::NamedTuple, call::CallInfo) @@ -1429,16 +1429,16 @@ function unmatrixshape(ex, left::Vector, right::Vector, store::NamedTuple, call: # For V' * V, did we want a scalar or not? What we will get is unknown to the macro: if length(left) == 0 && length(right) == 0 if :scalar in call.flags - return :( Base.first($ex) ) + return :( $first($ex) ) # If you had arrays of arrays, then PermuteDims would have permutedims-ed, and * would make an array, so this is still OK. else - return :( Base.fill($ex) ) # zero-dim array + return :( $fill($ex) ) # zero-dim array end end # For V'*M, you may get a Transpose row-vector, for which this is .parent: if length(left) == 0 - ex = :( TensorCast.transmute($ex, Base.Val((2,))) ) + ex = :( $transmute($ex, $Val((2,))) ) length(right) == 1 && return ex # literally V'*M done, # but for V'*T we should reshape this .parent end @@ -1448,7 +1448,7 @@ function unmatrixshape(ex, left::Vector, right::Vector, store::NamedTuple, call: append!(store.need, left) append!(store.need, right) # push!(call.flags, :reshaped) - return :( Base.reshape($ex, ($(sizes...),)) ) + return :( $reshape($ex, ($(sizes...),)) ) end function increasing_or_zero(tup::Tuple, prev=0) # strictly increasing, allows nothing or 0 @@ -1509,7 +1509,7 @@ function newoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) dims = length(parsed.rdims)>1 ? Tuple(parsed.rdims) : parsed.rdims[1] # perm = Tuple(filter(d -> !(d in parsed.rdims), 1:length(canon))) # ex = :( TensorCast.transmute($(parsed.redfun)($ex, dims=$dims), $perm) ) - ex = :( Base.dropdims($(parsed.redfun)($ex, dims=$dims), dims=$dims) ) + ex = :( $dropdims($(parsed.redfun)($ex, dims=$dims), dims=$dims) ) if :strided in call.flags pop!(call.flags, :collected, :ok) # makes stridedview(... end @@ -1525,12 +1525,12 @@ function newoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) code = Tuple(Any[ i in parsed.innerflat ? (:) : (*) for i in canon ]) if parsed.static sizeorcode = maybestaticsizes(canon, code, store, call) - ex = :( TensorCast.static_slice($ex, $sizeorcode) ) + ex = :( $static_slice($ex, $sizeorcode) ) elseif :collect in call.flags - ex = :( TensorCast.slicecopy($ex, $code) ) + ex = :( $slicecopy($ex, $code) ) push!(call.flags, :collected) else - ex = :( TensorCast.sliceview($ex, $code) ) + ex = :( $sliceview($ex, $code) ) end # Now I allow fixing output indices if any(isconstant, parsed.inner) @@ -1541,16 +1541,16 @@ function newoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) perm = Tuple(map(i -> isconstant(i) ? nothing : (_d+=1), parsed.inner)) # ex = :(TensorCast.orient.($Asafe, Ref($code)) ) # @. would need a dollar # refperm = maybepush(:( Ref() ), store, :zzz) - ex = :(TensorCast.transmute.($Asafe, Base.Val($perm)) ) + ex = :( $transmute.($Asafe, $Val($perm)) ) end end # Must we collect? Do this now, as reshape(TransmutedDimsArray(...)) is awful. if :collect in call.flags if :strided in call.flags - ex = :( Base.collect($ex) ) + ex = :( $collect($ex) ) elseif !(:collected in call.flags) - ex = :( Base.identity.($ex) ) + ex = :( $identity.($ex) ) end end @@ -1558,21 +1558,21 @@ function newoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) if any(i -> istensor(i) || isconstant(i), parsed.outer) any(i -> isconstant(i) && !(i == :_ || i == 1), parsed.outer) && throw(MacroError("can't fix output index to $i, only to 1", call)) if any(istensor, parsed.outer) - ex = :( Base.reshape($ex, ($(parsed.outaxes...),)) ) + ex = :( $reshape($ex, ($(parsed.outaxes...),)) ) append!(store.need, parsed.flat) else _d = 0 perm = Tuple(map(i -> isconstant(i) ? nothing : (_d+=1), parsed.outer)) - ex = :( TensorCast.transmute($ex, Base.Val($perm)) ) + ex = :( $transmute($ex, $Val($perm)) ) end end # Is the result Diagonal or friends? Doesn't allow Z[i,i,1] or Z[i,-i] but that's OK if length(parsed.outer)==2 && parsed.outer[1]==parsed.outer[2] if :lazy_0 in call.flags - ex = :( TensorCast.diagm(0 => $ex) ) + ex = :( $diagm(0 => $ex) ) else - ex = :( TensorCast.Diagonal($ex) ) + ex = :( $Diagonal($ex) ) end end @@ -1595,7 +1595,7 @@ function inplaceoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) zed isa Symbol || @capture(zed, ZZ_.field_) || error("wtf") newleft = parsed.left str = "expected a 0-tensor $zed[]" - pushboundscheck!(store.mustassert, :( Base.ndims($zed)==0 || Base.throw(ArgumentError($str))) ) + pushboundscheck!(store.mustassert, :( $ndims($zed)==0 || $throw(ArgumentError($str))) ) else newleft = standardise(parsed.left, store, call) @capture(newleft, zed_[ijk__]) || throw(MacroError("failed to parse LHS correctly, $(parsed.left) -> $newleft")) @@ -1620,7 +1620,7 @@ function inplaceoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) zmul = matrixshape(zmul, ex[3], ex[4], store, call) zmul isa Symbol || push!(call.flags, :showfinal) # zed = maybepush(zed, out, :mul!) - push!(out, :( TensorCast.mul!($zmul, $(ex[1]), $(ex[2])) ) ) + push!(out, :( $mul!($zmul, $(ex[1]), $(ex[2])) ) ) else push!(out, :( $zed .= $ex ) ) end