Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU lowering #290

Merged
merged 6 commits into from
Jan 29, 2025
Merged

CPU lowering #290

merged 6 commits into from
Jan 29, 2025

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented Jan 28, 2025

No description provided.

@wsmoses
Copy link
Member Author

wsmoses commented Jan 28, 2025

I'm clearly missing some basic conversion thing here:

If anyone has any idea/can take a look, it would be apprecaited

(base) wmoses@hydra:~/git/Enzyme-JaX$ ./bazel-bin/enzymexlamlir-opt --pass-pipeline="builtin.module(lower-kernel{backend=cpu})" ./test/lit_tests/lowering/cpu.mlir 
 lowered submod: module @cpuoffload1 {
  llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} {
    llvm.unreachable
  }
  llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
    %0 = llvm.mlir.constant(63 : i32) : i32
    %1 = llvm.mlir.constant(0 : index) : i64
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = llvm.mlir.constant(40 : index) : i64
    %4 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr<1>
    omp.parallel {
      omp.wsloop {
        omp.loop_nest (%arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : i64 = (%1, %1, %1, %1, %1, %1) to (%2, %2, %2, %2, %2, %3) step (%2, %2, %2, %2, %2, %2) {
          %5 = llvm.intr.stacksave : !llvm.ptr
          llvm.br ^bb1
        ^bb1:  // pred: ^bb0
          llvm.br ^bb2
        ^bb2:  // pred: ^bb1
          %6 = llvm.trunc %arg6 : i64 to i32
          %7 = llvm.icmp "ugt" %6, %0 : i32
          llvm.cond_br %7, ^bb4, ^bb3
        ^bb3:  // pred: ^bb2
          %8 = llvm.zext %6 : i32 to i64
          %9 = llvm.getelementptr inbounds %4[%8] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64
          %10 = llvm.load %9 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
          %11 = llvm.mul %10, %10 : i64
          llvm.store %11, %9 {alignment = 1 : i64} : i64, !llvm.ptr<1>
          llvm.br ^bb5
        ^bb4:  // pred: ^bb2
          llvm.call fastcc @throw_boundserror_2676() : () -> ()
          llvm.br ^bb5
        ^bb5:  // 2 preds: ^bb3, ^bb4
          llvm.intr.stackrestore %5 : !llvm.ptr
          llvm.br ^bb6
        ^bb6:  // pred: ^bb5
          omp.yield
        }
      }
      omp.terminator
    }
    llvm.return
  }
}
./test/lit_tests/lowering/cpu.mlir:27:10: error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: omp.parallel
    %0 = enzymexla.kernel_call @kern blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c40) shmem=%c0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
         ^
./test/lit_tests/lowering/cpu.mlir:27:10: note: see current operation: 
"omp.parallel"() <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 0>}> ({
  "omp.wsloop"() <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 0, 0>}> ({
    "omp.loop_nest"(%1, %1, %1, %1, %1, %1, %2, %2, %2, %2, %2, %3, %2, %2, %2, %2, %2, %2) ({
    ^bb0(%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64, %arg8: i64):
      %5 = "llvm.intr.stacksave"() : () -> !llvm.ptr
      "llvm.br"()[^bb1] : () -> ()
    ^bb1:  // pred: ^bb0
      "llvm.br"()[^bb2] : () -> ()
    ^bb2:  // pred: ^bb1
      %6 = "llvm.trunc"(%arg6) <{overflowFlags = #llvm.overflow<none>}> : (i64) -> i32
      %7 = "llvm.icmp"(%6, %0) <{predicate = 8 : i64}> : (i32, i32) -> i1
      "llvm.cond_br"(%7)[^bb4, ^bb3] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (i1) -> ()
    ^bb3:  // pred: ^bb2
      %8 = "llvm.zext"(%6) : (i32) -> i64
      %9 = "llvm.getelementptr"(%4, %8) <{elem_type = i64, inbounds, rawConstantIndices = array<i32: -2147483648>}> : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>
      %10 = "llvm.load"(%9) <{alignment = 1 : i64, ordering = 0 : i64}> : (!llvm.ptr<1>) -> i64
      %11 = "llvm.mul"(%10, %10) <{overflowFlags = #llvm.overflow<none>}> : (i64, i64) -> i64
      "llvm.store"(%11, %9) <{alignment = 1 : i64, ordering = 0 : i64}> : (i64, !llvm.ptr<1>) -> ()
      "llvm.br"()[^bb5] : () -> ()
    ^bb4:  // pred: ^bb2
      "llvm.call"() <{CConv = #llvm.cconv<fastcc>, TailCallKind = #llvm.tailcallkind<none>, callee = @throw_boundserror_2676, fastmathFlags = #llvm.fastmath<none>, op_bundle_sizes = array<i32>, operandSegmentSizes = array<i32: 0, 0>}> : () -> ()
      "llvm.br"()[^bb5] : () -> ()
    ^bb5:  // 2 preds: ^bb3, ^bb4
      "llvm.intr.stackrestore"(%5) : (!llvm.ptr) -> ()
      "llvm.br"()[^bb6] : () -> ()
    ^bb6:  // pred: ^bb5
      "omp.yield"() : () -> ()
    }) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> ()
  }) : () -> ()
  "omp.terminator"() : () -> ()
}) : () -> ()
modOp: module @cpuoffload1 {
  llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} {
    llvm.unreachable
  }
  llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
    %0 = llvm.mlir.constant(63 : i32) : i32
    %1 = llvm.mlir.constant(0 : index) : i64
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = llvm.mlir.constant(40 : index) : i64
    %4 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr<1>
    omp.parallel {
      omp.wsloop {
        omp.loop_nest (%arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : i64 = (%1, %1, %1, %1, %1, %1) to (%2, %2, %2, %2, %2, %3) step (%2, %2, %2, %2, %2, %2) {
          %5 = llvm.intr.stacksave : !llvm.ptr
          llvm.br ^bb1
        ^bb1:  // pred: ^bb0
          llvm.br ^bb2
        ^bb2:  // pred: ^bb1
          %6 = llvm.trunc %arg6 : i64 to i32
          %7 = llvm.icmp "ugt" %6, %0 : i32
          llvm.cond_br %7, ^bb4, ^bb3
        ^bb3:  // pred: ^bb2
          %8 = llvm.zext %6 : i32 to i64
          %9 = llvm.getelementptr inbounds %4[%8] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64
          %10 = llvm.load %9 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
          %11 = llvm.mul %10, %10 : i64
          llvm.store %11, %9 {alignment = 1 : i64} : i64, !llvm.ptr<1>
          llvm.br ^bb5
        ^bb4:  // pred: ^bb2
          llvm.call fastcc @throw_boundserror_2676() : () -> ()
          llvm.br ^bb5
        ^bb5:  // 2 preds: ^bb3, ^bb4
          llvm.intr.stackrestore %5 : !llvm.ptr
          llvm.br ^bb6
        ^bb6:  // pred: ^bb5
          omp.yield
        }
      }
      omp.terminator
    }
    llvm.return
  }
}
could not convert to LLVM IR
module {
  llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} {
    llvm.unreachable
  }
  func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<40> : tensor<i64>
    %0 = stablehlo.custom_call @enzymexla_compile_gpu(%arg0) {api_version = 4 : i32, backend_config = {attr = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00"}, output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
    return %0 : tensor<64xi64>
  }
}

@ftynse
Copy link
Collaborator

ftynse commented Jan 28, 2025

Open MP Translation registration was missing

@wsmoses wsmoses merged commit 264115e into main Jan 29, 2025
4 of 8 checks passed
@wsmoses wsmoses deleted the cpu branch January 29, 2025 00:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants