Replace second order argsort with permutation inverse #223
Conversation
Please benchmark and ship it if indeed faster.
Huge difference on the CPU, a bit faster on the GPU, though some results are surprising for 10k elements.

Source:

# Inverse permutation bench
```elixir
Mix.install([
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
  {:benchee, "~> 1.2"}
])
```
## Section
```elixir
defmodule Bench do
  import Nx.Defn

  defn inverse_argsort(indices) do
    Nx.argsort(indices, type: :u32)
  end

  defn inverse_indexed_put(indices) do
    shape = Nx.shape(indices)
    type = Nx.type(indices)

    Nx.indexed_put(
      Nx.broadcast(Nx.tensor(0, type: type), shape),
      Nx.new_axis(indices, -1),
      Nx.iota(shape, type: type)
    )
  end

  def run() do
    Nx.global_default_backend({EXLA.Backend, client: :host})

    key = Nx.Random.key(0)

    inputs = %{
      "100" =>
        Nx.Random.shuffle(key, Nx.iota({100}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "1000" =>
        Nx.Random.shuffle(key, Nx.iota({1_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "10000" =>
        Nx.Random.shuffle(key, Nx.iota({10_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "100000" =>
        Nx.Random.shuffle(key, Nx.iota({100_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "1000000" =>
        Nx.Random.shuffle(key, Nx.iota({1_000_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend)
    }

    Nx.global_default_backend(EXLA.Backend)
    Nx.Defn.global_default_options(compiler: EXLA)

    Benchee.run(
      %{
        "inverse_argsort" => fn x -> inverse_argsort(x) end,
        "inverse_indexed_put" => fn x -> inverse_indexed_put(x) end
      },
      inputs: inputs
    )
  end
end
```
```elixir
Bench.run()
```

*(Benchmark result charts for CPU and GPU omitted.)*
@jonatanklosko A bit late to the party. This can be slightly improved by pulling
The current version may be faster too. The whole purpose iota exists is so the compiler knows it is a monotonically increasing sequence, and letting it see exactly where that sequence is used may help with optimizations. Your guess is as good as mine. Your refactoring will be faster for backends like PyTorch, though.
Intuitively, I think keeping atomic operations like iota or broadcast close to where they are used gives the compiler a better hint for fusion and other optimisations. Purely a guess, but in this case I can imagine that if we have iota higher in the graph, so that it is clearly shared, then it may be materialized in GPU memory, while otherwise the compiler could have threads generate subsets of iota as they write the result. Actually, if having it materialized in memory and shared were clearly best, then the compiler could easily find all iotas of the same shape in the graph and share them (across the whole graph!). Either way, I wouldn't expect it to have a measurable effect.
`indices` is a permutation, and applying `argsort` to a permutation is the same as computing the permutation's inverse, which we can instead do directly in linear time.
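To illustrate why, here is a NumPy analogue of the two approaches (for illustration only; the PR itself uses `Nx.indexed_put/3` as shown in the benchmark above). `argsort` inverts a permutation in O(n log n), while scattering each index to the position named by the permutation does the same in O(n):

```python
import numpy as np

# A toy permutation of 0..n-1.
perm = np.array([2, 0, 3, 1])

# O(n log n): argsort of a permutation is its inverse.
inv_argsort = np.argsort(perm)

# O(n): scatter positions directly, i.e. inv[perm[i]] = i for every i.
# This mirrors the Nx.indexed_put/3 version in the benchmark module.
inv_scatter = np.empty_like(perm)
inv_scatter[perm] = np.arange(len(perm))

# Both produce the same inverse permutation.
assert np.array_equal(inv_argsort, inv_scatter)

# Composing a permutation with its inverse yields the identity.
assert np.array_equal(perm[inv_scatter], np.arange(len(perm)))
```

The scatter form only works because `perm` is a true permutation (every index appears exactly once); for general sort keys, `argsort` remains necessary.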