
Replace second order argsort with permutation inverse #223

Merged
merged 2 commits into main on Dec 18, 2023

Conversation

jonatanklosko
Member

@jonatanklosko jonatanklosko commented Dec 18, 2023

`indices` is a permutation, and applying argsort to a permutation is the same as computing the permutation's inverse, which we can do in linear time instead.
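For intuition, here is a minimal plain-Elixir sketch (not from the PR; `PermDemo` and its function names are illustrative) showing that argsorting a permutation yields its inverse, and that the inverse can instead be built in linear time by writing each index `i` at position `p[i]`:

```elixir
defmodule PermDemo do
  # Inverse via argsort: the positions that put p's values in sorted
  # order. Sorting costs O(n log n).
  def inverse_argsort(p) do
    p
    |> Enum.with_index()
    |> Enum.sort_by(fn {value, _index} -> value end)
    |> Enum.map(fn {_value, index} -> index end)
  end

  # Linear-time inverse: record index i under key p[i], then read the
  # keys 0..n-1 back in order. This mirrors what Nx.indexed_put does.
  def inverse_linear(p) do
    inv = p |> Enum.with_index() |> Map.new()
    Enum.map(0..(length(p) - 1), fn i -> inv[i] end)
  end
end

p = [2, 0, 3, 1]
IO.inspect(PermDemo.inverse_argsort(p))
# => [1, 3, 0, 2]
IO.inspect(PermDemo.inverse_linear(p))
# => [1, 3, 0, 2]
```

Both produce the inverse `q` with `q[p[i]] == i`; only the second avoids a sort.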

Contributor

@josevalim josevalim left a comment


Please benchmark and ship it if indeed faster.

@jonatanklosko
Member Author

jonatanklosko commented Dec 18, 2023

Huge difference on the CPU and a bit faster on the GPU, though some GPU results are surprising: for 10k elements indexed_put is 2.5x faster, but for 1M it is suddenly a tiny bit slower (1.1x on one run, 1.01x on another, so effectively the same). Overall an improvement for sure.

Source
# Inverse permutation bench

```elixir
Mix.install([
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
  {:benchee, "~> 1.2"}
])
```

## Section

```elixir
defmodule Bench do
  import Nx.Defn

  defn inverse_argsort(indices) do
    Nx.argsort(indices, type: :u32)
  end

  defn inverse_indexed_put(indices) do
    shape = Nx.shape(indices)
    type = Nx.type(indices)

    Nx.indexed_put(
      Nx.broadcast(Nx.tensor(0, type: type), shape),
      Nx.new_axis(indices, -1),
      Nx.iota(shape, type: type)
    )
  end

  def run() do
    Nx.global_default_backend({EXLA.Backend, client: :host})

    key = Nx.Random.key(0)

    inputs = %{
      "100" =>
        Nx.Random.shuffle(key, Nx.iota({100}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "1000" =>
        Nx.Random.shuffle(key, Nx.iota({1_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "10000" =>
        Nx.Random.shuffle(key, Nx.iota({10_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "100000" =>
        Nx.Random.shuffle(key, Nx.iota({100_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend),
      "1000000" =>
        Nx.Random.shuffle(key, Nx.iota({1_000_000}, type: :u32))
        |> elem(0)
        |> Nx.backend_transfer(EXLA.Backend)
    }

    Nx.global_default_backend(EXLA.Backend)
    Nx.Defn.global_default_options(compiler: EXLA)

    Benchee.run(
      %{
        "inverse_argsort" => fn x -> inverse_argsort(x) end,
        "inverse_indexed_put" => fn x -> inverse_indexed_put(x) end
      },
      inputs: inputs
    )
  end
end
```

```elixir
Bench.run()
```
CPU
Operating System: macOS
CPU Information: Apple M1 Pro
Number of Available Cores: 10
Available memory: 32 GB
Elixir 1.15.2
Erlang 26.0.2

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: 100, 1000, 10000, 100000, 1000000
Estimated total run time: 1.17 min

Benchmarking inverse_argsort with input 100 ...
Benchmarking inverse_argsort with input 1000 ...
Benchmarking inverse_argsort with input 10000 ...
Benchmarking inverse_argsort with input 100000 ...
Benchmarking inverse_argsort with input 1000000 ...
Benchmarking inverse_indexed_put with input 100 ...
Benchmarking inverse_indexed_put with input 1000 ...
Benchmarking inverse_indexed_put with input 10000 ...
Benchmarking inverse_indexed_put with input 100000 ...
Benchmarking inverse_indexed_put with input 1000000 ...

##### With input 100 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put       62.37 K       16.03 μs    ±73.45%       13.67 μs       47.08 μs
inverse_argsort           40.77 K       24.53 μs    ±34.16%       22.17 μs       57.46 μs

Comparison: 
inverse_indexed_put       62.37 K
inverse_argsort           40.77 K - 1.53x slower +8.50 μs

##### With input 1000 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put       54.67 K       18.29 μs    ±47.13%       17.21 μs       51.92 μs
inverse_argsort            7.91 K      126.42 μs    ±14.12%      121.92 μs      177.19 μs

Comparison: 
inverse_indexed_put       54.67 K
inverse_argsort            7.91 K - 6.91x slower +108.13 μs

##### With input 10000 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put       25.04 K      0.0399 ms   ±253.10%      0.0258 ms        0.46 ms
inverse_argsort            0.51 K        1.94 ms    ±16.82%        1.84 ms        2.91 ms

Comparison: 
inverse_indexed_put       25.04 K
inverse_argsort            0.51 K - 48.65x slower +1.90 ms

##### With input 100000 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put        5.57 K       0.180 ms    ±87.46%       0.145 ms        1.00 ms
inverse_argsort          0.0452 K       22.10 ms     ±2.80%       22.01 ms       23.77 ms

Comparison: 
inverse_indexed_put        5.57 K
inverse_argsort          0.0452 K - 123.01x slower +21.92 ms

##### With input 1000000 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put        520.74        1.92 ms    ±34.22%        1.65 ms        4.33 ms
inverse_argsort              3.51      285.27 ms     ±8.63%      277.20 ms      365.77 ms

Comparison: 
inverse_indexed_put        520.74
inverse_argsort              3.51 - 148.55x slower +283.35 ms
GPU
Operating System: Linux
CPU Information: AMD EPYC 7713 64-Core Processor
Number of Available Cores: 8
Available memory: 29.39 GB
Elixir 1.15.7
Erlang 26.1.2

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: 100, 1000, 10000, 100000, 1000000
Estimated total run time: 1.17 min

Benchmarking inverse_argsort with input 100 ...
Benchmarking inverse_argsort with input 1000 ...
Benchmarking inverse_argsort with input 10000 ...
Benchmarking inverse_argsort with input 100000 ...
Benchmarking inverse_argsort with input 1000000 ...
Benchmarking inverse_indexed_put with input 100 ...
Benchmarking inverse_indexed_put with input 1000 ...
Benchmarking inverse_indexed_put with input 10000 ...
Benchmarking inverse_indexed_put with input 100000 ...
Benchmarking inverse_indexed_put with input 1000000 ...

##### With input 100 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put        7.76 K      128.89 μs    ±17.85%      123.06 μs      193.89 μs
inverse_argsort            6.38 K      156.82 μs    ±26.62%      145.02 μs      264.14 μs

Comparison: 
inverse_indexed_put        7.76 K
inverse_argsort            6.38 K - 1.22x slower +27.92 μs

##### With input 1000 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put        7.36 K      135.82 μs    ±15.64%      130.02 μs      204.83 μs
inverse_argsort            6.29 K      159.02 μs    ±24.38%      142.54 μs      257.40 μs

Comparison: 
inverse_indexed_put        7.36 K
inverse_argsort            6.29 K - 1.17x slower +23.20 μs

##### With input 10000 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put        7.60 K      131.57 μs    ±17.51%      125.22 μs      197.56 μs
inverse_argsort            2.94 K      339.59 μs    ±11.16%      331.50 μs      445.78 μs

Comparison: 
inverse_indexed_put        7.60 K
inverse_argsort            2.94 K - 2.58x slower +208.03 μs

##### With input 100000 #####
Name                          ips        average  deviation         median         99th %
inverse_indexed_put        6.34 K      157.84 μs    ±23.32%      145.42 μs      240.85 μs
inverse_argsort            5.39 K      185.62 μs    ±13.12%      178.81 μs      268.58 μs

Comparison: 
inverse_indexed_put        6.34 K
inverse_argsort            5.39 K - 1.18x slower +27.78 μs

##### With input 1000000 #####
Name                          ips        average  deviation         median         99th %
inverse_argsort            5.54 K      180.48 μs    ±12.74%      173.94 μs      257.51 μs
inverse_indexed_put        4.86 K      205.66 μs    ±22.16%      189.76 μs      317.96 μs

Comparison: 
inverse_argsort            5.54 K
inverse_indexed_put        4.86 K - 1.14x slower +25.18 μs

@jonatanklosko jonatanklosko merged commit 6c60968 into main Dec 18, 2023
2 checks passed
@jonatanklosko jonatanklosko deleted the jk-second-order-argsort branch December 18, 2023 20:21
@krstopro
Member

@jonatanklosko A bit late to the party. This could be slightly improved by pulling Nx.iota out of the function, since the shape is the same on every call. Is it worth it, though?
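The hoisting described above could look roughly like the sketch below (this is not the merged code; the module and function names are hypothetical). The caller builds the iota once and passes it in, rather than the defn creating it on every invocation:

```elixir
defmodule Hoisted do
  import Nx.Defn

  # Same computation as inverse_indexed_put in the benchmark, but the
  # iota tensor is supplied by the caller instead of built inside the
  # defn, so it can be created once and reused across iterations.
  defn inverse_indexed_put(indices, iota) do
    shape = Nx.shape(indices)
    type = Nx.type(indices)

    Nx.indexed_put(
      Nx.broadcast(Nx.tensor(0, type: type), shape),
      Nx.new_axis(indices, -1),
      iota
    )
  end
end

# Caller side (sketch):
#   iota = Nx.iota({n}, type: :u32)
#   Hoisted.inverse_indexed_put(indices, iota)
```

Whether this helps depends on the backend, as the follow-up comments discuss: a compiler like XLA may fuse or share the iota either way.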

@josevalim
Contributor

The current version may be faster too. The whole reason iota exists is so the compiler knows it is a monotonically increasing sequence, and letting it see exactly where that sequence is used may enable optimizations. Your guess is as good as mine. Your refactoring would be faster for backends like PyTorch, though.

@jonatanklosko
Member Author

jonatanklosko commented Dec 21, 2023

Intuitively, I think keeping atomic operations like iota or broadcast close to where they are used gives the compiler a better hint for fusion and other optimisations. Purely a guess, but in this case I can imagine that if iota sits higher in the graph, where it is clearly shared, it may get materialized in GPU memory, whereas otherwise the compiler could have threads generate subsets of iota as they write the result. And if materializing it in memory and sharing it were clearly best, the compiler could easily find all iotas of the same shape in the graph and share them (across the whole graph!). Either way, I wouldn't expect a measurable effect.
