Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasting a splat #44

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/src/basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,23 @@ true
```

This one can also be done as `reinterpret(reshape, Tri{Int64}, M)`.
But what would be smarter in the general case is to do one splat, not many:

```julia-repl
julia> Tri.(eachrow(M)...)
4-element Vector{Tri{Int64}}:
Tri{Int64}(1, 2, 3)
Tri{Int64}(4, 5, 6)
Tri{Int64}(7, 8, 9)
Tri{Int64}(10, 11, 12)

julia> @btime Base.splat(tuple).(eachcol(m)) setup=(m=rand(4,100));
38.041 μs (1411 allocations: 48.33 KiB)

julia> @btime tuple.(eachrow(m)...) setup=(m=rand(4,100));
824.256 ns (12 allocations: 4.06 KiB)
```

## Arrays of functions

Besides arrays of numbers (and arrays of arrays) you can also broadcast an array of functions,
Expand Down
45 changes: 35 additions & 10 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,13 +520,6 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo)
# and arrays of functions, using apply:
@capture(ex, funs_[ijk__](args__) ) &&
return :( Core._apply($funs[$(ijk...)], $(args...) ) )
# splats
@capture(ex, fun_(pre__, arg_...)) && containsindexing(arg) && begin
@gensym splat ys
xs = [gensym(Symbol(:x, i)) for i in 1:length(pre)]
push!(store.main, :( local $splat($(xs...), $ys) = $fun($(xs...), $ys...) ))
return :( $splat($(pre...), $arg) )
end

# Apart from those, readycast acts only on lone tensors:
@capture(ex, A_[ijk__]) || return ex
Expand Down Expand Up @@ -627,11 +620,13 @@ end
"""
recursemacro(@reduce sum(i) A[i,j]) -> G[j]

Walks itself over RHS to look for `@reduce ...`, and replace with result,
Walks itself over RHS, originally to look for `@reduce ...`, and replace with result,
pushing calculation steps into store.

Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`.
And to handle naked indices, `i` => `axes(M,1)[i]` but not exactly like that.
Starts from the outside and works in, which makes it useful for other things:
* Handle naked indices, `i` => `axes(M,1)[i]` but not exactly like that, stopping before this sees `A[i]`.
* Catch splats so that `f(M[:,c]...)` can become `f.(eachrow(M)...)` not `(splat(f)).(eachcol(M))`.
* Tidy all indices, including e.g. `fun(M[:,j], N[j]).same[i']`.
"""
function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo)

Expand All @@ -658,6 +653,36 @@ function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo)
ex = scalar ? :($name) : :($name[$(ind...)])
end

# Handle splatted slices -- walking from inside outwards would slice the wrong way.
if @capture(ex, fun_(args__)) && any(a -> @capture(a, (A_[ijk__]...)), args) && any(iscolon, ijk)
newargs = map(args) do arg
if @capture(arg, (A_[ijk__]...)) && any(iscolon, ijk)
indpost = filter(!iscolon, ijk)
if indexin(indpost, canon) == 1:length(indpost)
Aperm = A
revcode = map(i -> iscolon(i) ? :* : :(:), ijk)
else
perm = indexin(canon, ijk)
while isnothing(last(perm)) # trim nothings off end
pop!(perm)
end
indpost = canon[1:length(perm)]
revcode = vcat(map(_ -> :*, perm), fill(:(:), count(iscolon, ijk)))
for (d,i) in enumerate(ijk) # append positions of colons
iscolon(i) && push!(perm, d)
end
Aperm = :( TensorCast.transmute($A, $(Tuple(perm))) )
end
sliced = :( TensorCast.sliceview($Aperm, ($(revcode...),)) )
sym = maybepush(sliced, store)
:(($sym[$(indpost...)])...)
else
recursemacro(arg, canon, store, call)
end
end
return :( $fun($(newargs...)) )
end

# Tidy up indices, A[i,j][k] will be hit on different rounds...
if @capture(ex, A_[ijk__])
if !(A isa Symbol) # this check allows some tests which have c[c] etc.
Expand Down