Behavior on vcat of scalars #172

Open
gdalle opened this issue Jan 29, 2025 · 2 comments
Labels
bug Something isn't working

Comments

@gdalle

gdalle commented Jan 29, 2025

Here's an inconsistency I discovered based on two ways of constructing vectors from floats:

julia> using Tracker

julia> f1(x) = [x, x];

julia> f2(x) = vcat(x, x);

julia> x = 1.0;

julia> dy = ones(2);

julia> y1, pb1 = Tracker.forward(f1, x);

julia> y2, pb2 = Tracker.forward(f2, x);

julia> y1 == y2
true

julia> pb1(dy) == pb2(dy)
false

julia> pb1(dy)  # returns a scalar
(2.0,)

julia> pb2(dy)  # returns an array
([2.0],)
@mcabbott mcabbott added the bug Something isn't working label Jan 29, 2025
@mcabbott
Member

mcabbott commented Jan 29, 2025

Shorter example:

julia> Tracker.gradient(sumvcat, [1.0, 2.0], [3.0])  # fine
([1.0, 1.0] (tracked), [1.0] (tracked))

julia> Tracker.gradient(sumvcat, 1.0, 2.0)  # creates a vector where it wants a scalar
ERROR: MethodError: Cannot `convert` an object of type Vector{Float64} to an object of type Float64

Stacktrace:
  [1] setproperty!(x::Tracker.Tracked{Float64}, f::Symbol, v::Vector{Float64})
    @ Base ./Base.jl:52
  [2] back(x::Tracker.Tracked{Float64}, Δ::Vector{Float64}, once::Bool)
    @ Tracker ~/.julia/packages/Tracker/6rnwO/src/back.jl:48
  [3] (::Tracker.var"#707#708"{Bool})(x::Tracker.Tracked{Float64}, d::Vector{Float64})
    @ Tracker ~/.julia/packages/Tracker/6rnwO/src/back.jl:38

julia> Tracker.gradient(sumvcat, [1.0, 2.0], 3.0)  # same error with mix of vector & scalar args
ERROR: MethodError: Cannot `convert` an object of type Vector{Float64} to an object of type Float64
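The thread does not show how `sumvcat` is defined. The sketch below assumes `sumvcat(xs...) = sum(vcat(xs...))`, which is consistent with the vector-argument output above. Without involving Tracker at all, a central finite difference shows what a correct gradient looks like for two scalar inputs: one `Float64` per argument, each close to 1.0, rather than the vector Tracker tries to store. `fd_partial` is a hypothetical helper written only for this illustration.

```julia
# Assumed definition: the thread uses `sumvcat` without showing it.
sumvcat(xs...) = sum(vcat(xs...))

# Hypothetical helper: central finite difference of `f` with respect to
# its i-th (scalar) positional argument.
function fd_partial(f, args::Tuple, i::Int; h = 1e-6)
    up = Base.setindex(args, args[i] + h, i)  # copy of args with args[i] + h
    dn = Base.setindex(args, args[i] - h, i)  # copy of args with args[i] - h
    return (f(up...) - f(dn...)) / (2h)
end

args = (1.0, 2.0)
g = ntuple(i -> fd_partial(sumvcat, args, i), length(args))
# g is a tuple of two Float64 scalars, each approximately 1.0
```

The point of the check is the shape of the result: a scalar input should receive a scalar gradient, which is exactly what the `Tracked{Float64}` field in the stacktrace expects.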

Note that I wouldn't call this "inconsistent"; it seems to always go wrong. The other function, f1(x) = [x, x], calls Base.vect rather than vcat (see e.g. Meta.@lower [1, 2]). The syntax for vcat uses a semicolon: [1; 2].
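The distinction can be checked at parse time using only Base. `Meta.parse` returns an expression without evaluating it, and its `head` shows which form each bracket syntax becomes; lowering then turns `:vect` into a call to `Base.vect` and `:vcat` into a call to `Base.vcat`. For plain floats both produce equal vectors, which is why `y1 == y2` held in the original report even though only the `vcat` path goes through Tracker's array rule.

```julia
# Parse (without evaluating) the two bracket syntaxes and inspect the head.
ex_comma = Meta.parse("[x, x]")   # comma-separated literal
ex_semi  = Meta.parse("[x; x]")   # semicolon-separated literal

println(ex_comma.head)  # vect
println(ex_semi.head)   # vcat

# On untracked floats the two calls agree on the value.
println(Base.vect(1.0, 1.0) == vcat(1.0, 1.0))  # true
```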

@mcabbott mcabbott changed the title Inconsistent behavior on vcat of scalars Behavior on vcat of scalars Jan 29, 2025
@ToucheSir
Member

From Tracker.jl/src/lib/array.jl, lines 214 to 220 at commit ce231f9:

Δs = [begin
        x = map(_ -> :, size(xsi))
        i = isempty(x) ? x : Base.tail(x)
        d = Δ[start+1:start+size(xsi,1), i...]
        start += size(xsi, 1)
        d
      end for xsi in xs]
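This adjoint looks to be tailored to arrays only. For a scalar xsi, size(xsi) is (), so i is empty, while size(xsi, 1) is 1; the slice therefore becomes Δ[start+1:start+1], which indexes with a range and returns a 1-element vector instead of a scalar. hcat has a similar code path and thus a similar problem. Base.vect does not appear to have a rule, so Tracker traces through it transparently and the individual tracked scalars remain intact.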
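A minimal sketch of how the slicing could treat scalars, written in plain Julia without Tracker. `vcat_cotangents` is a hypothetical helper, not part of Tracker's API and not necessarily how the issue would be fixed upstream; it only illustrates indexing a scalar's row with an integer rather than a range, while keeping the existing array logic quoted above.

```julia
# Hypothetical helper: split the cotangent Δ of `vcat(xs...)` back into
# pieces shaped like each input in `xs`.
function vcat_cotangents(Δ::AbstractArray, xs...)
    out = Any[]
    start = 0
    for xsi in xs
        n = size(xsi, 1)  # rows contributed by this input (1 for a Number)
        if xsi isa Number
            # Scalar input: integer row index (and index 1 in any trailing
            # dimensions) so the result is a scalar, not Δ[start+1:start+1].
            d = Δ[start + 1, ntuple(_ -> 1, ndims(Δ) - 1)...]
        else
            # Array input: same slicing as the quoted adjoint, i.e. this
            # input's block of rows and all remaining dimensions.
            cs = map(_ -> Colon(), size(xsi))
            trailing = isempty(cs) ? cs : Base.tail(cs)
            d = Δ[start+1:start+n, trailing...]
        end
        push!(out, d)
        start += n
    end
    return Tuple(out)
end

vcat_cotangents(ones(2), 1.0, 2.0)                    # (1.0, 1.0)
vcat_cotangents([1.0, 2.0, 3.0], [1.0, 2.0], 3.0)     # ([1.0, 2.0], 3.0)
```

The analogous change for hcat would index a scalar's column with an integer in the second dimension instead of a range.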
