
feat: don't unroll Recurrence #1209

Draft · avik-pal wants to merge 2 commits into main from ap/loop_rnn_reactant
Conversation

avik-pal (Member) commented:
Needs EnzymeAD/Reactant.jl#565.

AD doesn't seem to work (enzyme.init doesn't lower through XLA; see the enzyme.init/push/pop cache ops in the module below). cc @wsmoses

module {
  func.func private @"diffeConst{typeof(sumabs2)}(Main.sumabs2)_autodiff"(%arg0: tensor<6x2x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>, %arg5: tensor<2xui64>, %arg6: tensor<f32>, %arg7: tensor<2xui64>, %arg8: tensor<3x3xf32>, %arg9: tensor<3x3xf32>, %arg10: tensor<3xf32>, %arg11: tensor<3xf32>) -> (tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<3x2xf32>
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<6> : tensor<i64>
    %c_1 = stablehlo.constant dense<2> : tensor<i64>
    %c_2 = stablehlo.constant dense<1> : tensor<i64>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<3x2xf32>
    %0 = "enzyme.init"() : () -> !enzyme.Cache<tensor<3x2xf32>>
    %1 = "enzyme.init"() : () -> !enzyme.Cache<tensor<3x3xf32>>
    %2 = "enzyme.init"() : () -> !enzyme.Cache<tensor<2x3xf32>>
    %3 = "enzyme.init"() : () -> !enzyme.Cache<tensor<3x2xf32>>
    %4 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<6x2x3xf32>) -> tensor<3x2x6xf32>
    %5 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %6 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %7 = stablehlo.slice %4 [0:3, 0:2, 0:1] : (tensor<3x2x6xf32>) -> tensor<3x2x1xf32>
    %8 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<3x2x1xf32>) -> tensor<1x2x3xf32>
    %9 = stablehlo.reshape %8 : (tensor<1x2x3xf32>) -> tensor<2x3xf32>
    %10 = stablehlo.broadcast_in_dim %arg4, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
    %11 = stablehlo.dot_general %arg1, %9, contracting_dims = [0] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
    %12 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
    %13 = stablehlo.add %11, %12 : tensor<3x2xf32>
    %14 = stablehlo.add %10, %13 : tensor<3x2xf32>
    %15 = stablehlo.tanh %14 : tensor<3x2xf32>
    %16:10 = stablehlo.while(%iterArg = %c, %iterArg_5 = %5, %iterArg_6 = %6, %iterArg_7 = %arg3, %iterArg_8 = %arg4, %iterArg_9 = %arg5, %iterArg_10 = %c_0, %iterArg_11 = %15, %iterArg_12 = %4, %iterArg_13 = %c) : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<i64>, tensor<3x2xf32>, tensor<3x2x6xf32>, tensor<i64>
     cond {
      %41 = stablehlo.subtract %iterArg_10, %c_1 : tensor<i64>
      %42 = stablehlo.divide %41, %c_2 : tensor<i64>
      %43 = stablehlo.add %42, %c_2 : tensor<i64>
      %44 = stablehlo.compare  LT, %iterArg, %43 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %44 : tensor<i1>
    } do {
      %41 = stablehlo.add %iterArg_13, %c_2 : tensor<i64>
      %42 = stablehlo.add %c_1, %iterArg : tensor<i64>
      %43 = stablehlo.subtract %42, %c_2 : tensor<i64>
      %44 = stablehlo.dynamic_slice %iterArg_12, %c, %c, %43, sizes = [3, 2, 1] : (tensor<3x2x6xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<3x2x1xf32>
      %45 = stablehlo.transpose %44, dims = [2, 1, 0] : (tensor<3x2x1xf32>) -> tensor<1x2x3xf32>
      %46 = stablehlo.reshape %45 : (tensor<1x2x3xf32>) -> tensor<2x3xf32>
      "enzyme.push"(%1, %iterArg_6) : (!enzyme.Cache<tensor<3x3xf32>>, tensor<3x3xf32>) -> ()
      "enzyme.push"(%0, %iterArg_11) : (!enzyme.Cache<tensor<3x2xf32>>, tensor<3x2xf32>) -> ()
      %47 = stablehlo.dot_general %iterArg_6, %iterArg_11, contracting_dims = [1] x [0] : (tensor<3x3xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
      %48 = stablehlo.broadcast_in_dim %iterArg_8, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
      %49 = stablehlo.add %47, %48 : tensor<3x2xf32>
      "enzyme.push"(%2, %46) : (!enzyme.Cache<tensor<2x3xf32>>, tensor<2x3xf32>) -> ()
      %50 = stablehlo.dot_general %iterArg_5, %46, contracting_dims = [1] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
      %51 = stablehlo.broadcast_in_dim %iterArg_7, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
      %52 = stablehlo.add %50, %51 : tensor<3x2xf32>
      %53 = stablehlo.add %49, %52 : tensor<3x2xf32>
      "enzyme.push"(%3, %53) : (!enzyme.Cache<tensor<3x2xf32>>, tensor<3x2xf32>) -> ()
      %54 = stablehlo.tanh %53 : tensor<3x2xf32>
      %55 = stablehlo.add %iterArg, %c_2 : tensor<i64>
      stablehlo.return %55, %iterArg_5, %iterArg_6, %iterArg_7, %iterArg_8, %iterArg_9, %iterArg_10, %54, %iterArg_12, %41 : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<i64>, tensor<3x2xf32>, tensor<3x2x6xf32>, tensor<i64>
    }
    %17 = stablehlo.abs %16#7 : tensor<3x2xf32>
    %18 = stablehlo.transpose %16#8, dims = [2, 1, 0] : (tensor<3x2x6xf32>) -> tensor<6x2x3xf32>
    %19 = stablehlo.transpose %16#1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %20 = stablehlo.transpose %16#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %21 = stablehlo.transpose %arg9, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %22 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %23 = stablehlo.broadcast_in_dim %arg6, dims = [] : (tensor<f32>) -> tensor<3x2xf32>
    %24 = stablehlo.multiply %23, %17 : tensor<3x2xf32>
    %25 = stablehlo.add %24, %24 : tensor<3x2xf32>
    %26 = stablehlo.compare  GE, %16#7, %cst_4 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xi1>
    %27 = stablehlo.negate %25 : tensor<3x2xf32>
    %28 = stablehlo.select %26, %25, %27 : tensor<3x2xi1>, tensor<3x2xf32>
    %29:6 = stablehlo.while(%iterArg = %c, %iterArg_5 = %22, %iterArg_6 = %21, %iterArg_7 = %arg10, %iterArg_8 = %arg11, %iterArg_9 = %28) : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3x2xf32>
     cond {
      %41 = stablehlo.compare  LT, %iterArg, %16#9 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %41 : tensor<i1>
    } do {
      %41 = stablehlo.add %iterArg, %c_2 : tensor<i64>
      %42 = "enzyme.pop"(%3) : (!enzyme.Cache<tensor<3x2xf32>>) -> tensor<3x2xf32>
      %43 = stablehlo.tanh %42 : tensor<3x2xf32>
      %44 = stablehlo.multiply %43, %43 : tensor<3x2xf32>
      %45 = stablehlo.subtract %cst, %44 : tensor<3x2xf32>
      %46 = stablehlo.multiply %iterArg_9, %45 : tensor<3x2xf32>
      %47 = stablehlo.reduce(%46 init: %cst_3) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
      %48 = stablehlo.add %iterArg_7, %47 : tensor<3xf32>
      %49 = "enzyme.pop"(%2) : (!enzyme.Cache<tensor<2x3xf32>>) -> tensor<2x3xf32>
      %50 = stablehlo.dot_general %46, %49, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
      %51 = stablehlo.add %iterArg_5, %50 : tensor<3x3xf32>
      %52 = stablehlo.add %iterArg_8, %47 : tensor<3xf32>
      %53 = "enzyme.pop"(%1) : (!enzyme.Cache<tensor<3x3xf32>>) -> tensor<3x3xf32>
      %54 = "enzyme.pop"(%0) : (!enzyme.Cache<tensor<3x2xf32>>) -> tensor<3x2xf32>
      %55 = stablehlo.dot_general %46, %54, contracting_dims = [1] x [1] : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x3xf32>
      %56 = stablehlo.add %iterArg_6, %55 : tensor<3x3xf32>
      %57 = stablehlo.dot_general %53, %46, contracting_dims = [0] x [0] : (tensor<3x3xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
      stablehlo.return %41, %51, %56, %48, %52, %57 : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3x2xf32>
    }
    %30 = stablehlo.multiply %15, %15 : tensor<3x2xf32>
    %31 = stablehlo.subtract %cst, %30 : tensor<3x2xf32>
    %32 = stablehlo.multiply %29#5, %31 : tensor<3x2xf32>
    %33 = stablehlo.reduce(%32 init: %cst_3) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
    %34 = stablehlo.add %29#3, %33 : tensor<3xf32>
    %35 = stablehlo.dot_general %32, %9, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
    %36 = stablehlo.reduce(%32 init: %cst_3) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
    %37 = stablehlo.add %29#4, %36 : tensor<3xf32>
    %38 = stablehlo.transpose %29#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %39 = stablehlo.add %35, %29#1 : tensor<3x3xf32>
    %40 = stablehlo.transpose %39, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    return %18, %19, %20, %16#3, %16#4, %arg5, %40, %38, %34, %37 : tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>
  }
  func.func @main(%arg0: tensor<6x2x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>, %arg5: tensor<2xui64>) -> (tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>) {
    %c = stablehlo.constant dense<1> : tensor<2xui64>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<3x3xf32>
    %0:10 = call @"diffeConst{typeof(sumabs2)}(Main.sumabs2)_autodiff"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %cst, %c, %cst_1, %cst_1, %cst_0, %cst_0) : (tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<f32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>)
    return %0#6, %0#7, %0#8, %0#9, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>
  }
}
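
For reference, a minimal Julia sketch of the kind of call that produces a module like the one above. The loss name comes from the MLIR function name; the model construction, ordering, and gradient call are assumptions, not this PR's actual test code:

using Lux, Reactant, Enzyme, Random

# Loss taken from the MLIR function name; the argument order is an assumption.
sumabs2(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

rng = Random.default_rng()
# 3 => 3 RNNCell with the default tanh activation, iterated over time
# (time-last ordering so T = size(x, ndims(x)), as in the diff hunk below).
model = Recurrence(RNNCell(3 => 3); ordering=Lux.TimeLastIndex())
ps, st = Reactant.to_rarray.(Lux.setup(rng, model))
# (features, batch, time) = (3, 2, 6) in Julia's column-major layout,
# matching the tensor<6x2x3xf32> argument above.
x = Reactant.to_rarray(rand(Float32, 3, 2, 6))

# Compiling the reverse pass is what emits the enzyme.init/push/pop cache
# ops; this is the step that currently fails to lower through XLA.
grad_fn = Reactant.@compile Enzyme.gradient(
    Reverse, sumabs2, Const(model), x, ps, Const(st)
)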

avik-pal force-pushed the ap/loop_rnn_reactant branch from 8fdb7ef to 4ae54ef on January 18, 2025 04:38
github-actions bot (Contributor) commented Jan 18, 2025

Benchmark Results (ASV)

                    main                0bbc7d9...           main/0bbc7d9f004b86...
basics/overhead     0.127 ± 0.0014 μs   0.126 ± 0.0011 μs    1.01
time_to_load        0.904 ± 0.012 s     0.903 ± 0.0065 s     1

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

avik-pal marked this pull request as draft on January 18, 2025 05:01
wsmoses (Contributor) commented Jan 18, 2025

cc @Pangoraw if an IR test case is helpful!

Pangoraw commented Jan 18, 2025

Thank you for the sample MLIR. Our current loop analysis cannot figure out the static number of iterations from this condition:

    %16:10 = stablehlo.while(%iterArg = %c, %iterArg_5 = %5, %iterArg_6 = %6, %iterArg_7 = %arg3, %iterArg_8 = %arg4, %iterArg_9 = %arg5, %iterArg_10 = %c_0, %iterArg_11 = %15, %iterArg_12 = %4, %iterArg_13 = %c) : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<i64>, tensor<3x2xf32>, tensor<3x2x6xf32>, tensor<i64>
     cond {
      %41 = stablehlo.subtract %iterArg_10, %c_1 : tensor<i64>
      %42 = stablehlo.divide %41, %c_2 : tensor<i64>
      %43 = stablehlo.add %42, %c_2 : tensor<i64>
      %44 = stablehlo.compare  LT, %iterArg, %43 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %44 : tensor<i1>
    } do {

So we should probably update the codegen from @trace for as well, or add something like EnzymeAD/Enzyme-JAX#173.
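
One way to make the count visible (a sketch of the idea only; static_trip_count is a hypothetical helper, not existing Reactant API): compute the trip count at trace time, where the range bounds are plain Julia integers, and emit it as a constant, so the cond block reduces to a single compare:

# Hypothetical helper for the @trace codegen; equivalent to length(start:step:stop).
function static_trip_count(start::Integer, step::Integer, stop::Integer)
    return max(0, fld(stop - start, step) + 1)
end

static_trip_count(2, 1, 6)  # == 5, matching the (iterArg_10 - 2) ÷ 1 + 1 in the cond above

# The emitted condition would then be a single
#   stablehlo.compare LT, %iterArg, %c_trip
# against a stablehlo.constant, instead of recomputing the trip count from
# loop-carried values on every iteration.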

Comment on lines 12 to 13
T = size(x, ndims(x))
@trace for i in 2:T
avik-pal (Member, Author) commented:

@Pangoraw this is how I am doing the iterations. It is statically inferable from the size of the array. Should I rewrite this in some way?
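
For what it's worth, the bound really is static under tracing: Reactant arrays carry fixed shapes, so size folds to an ordinary Int at trace time (a sketch with illustrative values):

x = Reactant.to_rarray(rand(Float32, 3, 2, 6))
T = size(x, ndims(x))  # 6 — a plain Julia Int, not a traced value
@assert T isa Int      # so the range 2:T has a trace-time-known length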

wsmoses (Contributor) commented Jan 18, 2025

EnzymeAD/Enzyme-JAX#173

I'm currently fixing fires on weird execution stuff. @Pangoraw, if you have cycles to take/finish up the while dead-code-elimination PR, be my guest! It would be super helpful (especially for differentiation).


function (r::Lux.Recurrence{True})(x::AnyTracedRArray, ps, st::NamedTuple)
    if r.ordering isa Lux.TimeLastIndex ||
        (r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change (indentation only):
-        (r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
+       (r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
