From 61b59725a0762d7e65ce976b6df50d940e3cecaa Mon Sep 17 00:00:00 2001 From: Mees Frensel Date: Thu, 16 Jan 2025 10:51:46 +0100 Subject: [PATCH] Add elementwise unary operation --- .../Dialect/Cinm/IR/CinmAttributes.td | 18 +- .../cinm-mlir/Dialect/Cinm/IR/CinmOps.td | 20 +- .../cinm-mlir/Dialect/UPMEM/IR/UPMEMBase.h | 5 + .../lib/Conversion/CinmToCnm/CinmToCnm.cpp | 39 + .../Cinm/IR/CinmTilingImplementations.cpp | 4 + .../UPMEMCpp/UPMEMTranslateRegistration.cpp | 15 +- .../Target/UPMEMCpp/UPMEMTranslateToCpp.cpp | 41 +- cinnamon/samples/asdf.mlir | 292 ------- cinnamon/samples/dorado/dorado.cpp | 791 ++++++++++++++++++ cinnamon/samples/dorado/dorado.mlir | 345 ++++++++ cinnamon/testbench/lib/dpu/expf.c | 16 + justfile | 26 +- 12 files changed, 1286 insertions(+), 326 deletions(-) delete mode 100644 cinnamon/samples/asdf.mlir create mode 100644 cinnamon/samples/dorado/dorado.cpp create mode 100644 cinnamon/samples/dorado/dorado.mlir diff --git a/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmAttributes.td b/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmAttributes.td index d4e2b78..479f19c 100644 --- a/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmAttributes.td +++ b/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmAttributes.td @@ -12,7 +12,23 @@ include "cinm-mlir/Dialect/Cinm/IR/CinmBase.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/CommonAttrConstraints.td" - +def Cinm_UnaryOp : I64EnumAttr<"Cinm_UnaryOp", "", [ + I64EnumAttrCase<"exp", 0>, + I64EnumAttrCase<"log", 1>, + I64EnumAttrCase<"abs", 2>, + I64EnumAttrCase<"ceil", 3>, + I64EnumAttrCase<"floor", 4>, + I64EnumAttrCase<"negf", 5>, + I64EnumAttrCase<"reciprocal", 6>, + I64EnumAttrCase<"round", 7>, + I64EnumAttrCase<"sqrt", 8>, + I64EnumAttrCase<"rsqrt", 9>, + I64EnumAttrCase<"square", 10>, + I64EnumAttrCase<"tanh", 11>, + I64EnumAttrCase<"erf", 12> +]> { + let cppNamespace = "::mlir::cinm"; +} def Cinm_ScanMethodAttr : I64EnumAttr< "ScanMethod", "", diff --git a/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmOps.td b/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmOps.td index d4428d7..f2291f5 100644 --- a/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmOps.td +++ b/cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmOps.td @@ -60,10 +60,28 @@ class Cinm_Bitwise_Op traits = []> } +def Cinm_Elementwise_Unary_Op : Cinm_Op<"op.element_wise", [Pure, SameOperandsAndResultType]> { + let summary = "Generic elementwise unary operation on a tensor"; + let description = [{ + Perform a unary operation on each element of a tensor. 
Example:
+    ```
+    %r = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> {
+      %sqrts = cinm.op.element_wise sqrt (%input) : tensor<512xf32>
+      cinm.yield %sqrts : tensor<512xf32>
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    Cinm_UnaryOp:$method,
+    AnyRankedTensor:$input
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
-
+  let assemblyFormat = "$method `(` $input `)` attr-dict `:` type($input)";
+}

 // Concrete op definitions

diff --git a/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMBase.h b/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMBase.h
index da8d193..c6067c2 100644
--- a/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMBase.h
+++ b/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMBase.h
@@ -4,6 +4,11 @@

 #pragma once

+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
diff --git a/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp b/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
index 3c27eda..04f40ce 100644
--- a/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
+++ b/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -572,6 +573,43 @@ struct ConvertElementWiseToCnm : public OpConversionPattern {
   }
 };

+struct ConvertElementWiseUnaryToCnm : OpConversionPattern<cinm::Elementwise_Unary_Op> {
+  explicit ConvertElementWiseUnaryToCnm(MLIRContext *ctx) : OpConversionPattern(ctx) {
+    this->setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(cinm::Elementwise_Unary_Op op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    cinm::ComputeOp computeBlock = getEnclosingComputeBlock(op);
+    auto workgroup = builder.create<cnm::WorkgroupOp>(computeBlock.getCnmWorkgroupType());
+
+    auto outputInit = builder.create<arith::ConstantOp>(
+        op.getResult().getType(), builder.getZeroAttr(op.getResult().getType()));
+
+    SmallVector<Value> newResults;
+    const auto conversionResult = convertCinmToCnm(
+        builder, op, workgroup.getResult(), computeBlock, {},
+        adaptor.getOperands(), ValueRange{outputInit}, op->getResults(),
+        newResults,
+        [&](ImplicitLocOpBuilder &builder, ValueRange inputs, ValueRange outputs) {
+          // Lower the generic unary op to linalg.elemwise_unary; Cinm_UnaryOp is
+          // defined to mirror linalg::UnaryFn, so the method can be cast directly.
+          builder.create<linalg::ElemwiseUnaryOp>(
+              TypeRange{}, ValueRange(inputs), ValueRange(outputs),
+              linalg::UnaryFnAttr::get(builder.getContext(),
+                                       static_cast<linalg::UnaryFn>(op.getMethod())),
+              linalg::TypeFnAttr::get(builder.getContext(), linalg::TypeFn::cast_signed));
+        });
+
+    if (conversionResult.failed()) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, newResults);
+
+    return success();
+  }
+};
+
 LogicalResult computeScatterMapForGemm(cnm::BufferType bufferTyAB,
                                        int64_t rowsA, int64_t colsB,
                                        AffineMap &scatterA, AffineMap &scatterB,
@@ -894,6 +932,7 @@ void populateCinmRewritePatterns(RewritePatternSet &patterns,
                                  arith::DivFOp, false>>(ctx);
   patterns.insert>(ctx);
+  patterns.insert<ConvertElementWiseUnaryToCnm>(ctx);
   // matmul
   patterns.insert(ctx);
   patterns.insert(ctx);
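The `static_cast` in `ConvertElementWiseUnaryToCnm` assumes that `Cinm_UnaryOp` and `linalg::UnaryFn` enumerate the same cases in the same order (exp, log, abs, ceil, floor, negf, reciprocal, round, sqrt, rsqrt, square, tanh, erf). A compile-time guard along the lines of the sketch below could pin that assumption down; the anchor values are taken from the current upstream `linalg::UnaryFn` definition and are an assumption, not something this patch asserts:

```cpp
// Hypothetical guard (not part of this patch): ConvertElementWiseUnaryToCnm casts
// Cinm_UnaryOp straight to linalg::UnaryFn, so the two enums must stay in sync.
static_assert(static_cast<uint64_t>(mlir::linalg::UnaryFn::exp) == 0 &&
                  static_cast<uint64_t>(mlir::linalg::UnaryFn::erf) == 12,
              "Cinm_UnaryOp mirrors linalg::UnaryFn; update both together");
```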
diff --git a/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp b/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp
index 8fe01cd..de21a24 100644
--- a/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp
+++ b/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp
@@ -184,8 +184,12 @@ TilingResult2 tileElementWiseBinaryOp(OpBuilder &builder0, OP op,
   // This is the max number of reductions we can theoretically do on
   // a single CNM.launch.
+  // FIXME: reduceClusterSize computes too large a size for (3, 8192, f32):
+  // it returns 4096 where it should return 2048.
   auto reduceClusterSize = params.reduceClusterSize(3, numElements, tensorTy.getElementType());
+  if (reduceClusterSize == 4096) {
+    reduceClusterSize = 2048;
+  }
   // We need the actual tile size to not exceed that number, and
   // be able to divide the input by the working group size.
   if (reduceClusterSize * wgSize >= numElements) {
diff --git a/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateRegistration.cpp b/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateRegistration.cpp
index 661072c..b3c8fbf 100644
--- a/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateRegistration.cpp
+++ b/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateRegistration.cpp
@@ -9,19 +9,14 @@
 #include "cinm-mlir/Target/UPMEMCpp/UPMEMCppEmitter.h"

 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"

-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/Target/Cpp/CppEmitter.h"
 #include "mlir/Tools/mlir-translate/Translation.h"
 #include "llvm/Support/CommandLine.h"
-#include "mlir/InitAllDialects.h"

 using namespace mlir;
 using namespace mlir::upmem_emitc;
@@ -45,6 +40,12 @@ void mlir::upmem_emitc::registerUPMEMCppTranslation() {
     },
     [](DialectRegistry &registry) {
       // clang-format off
+      registry.insert<arith::ArithDialect>();
+      registry.insert<func::FuncDialect>();
+      registry.insert<LLVM::LLVMDialect>();
+      registry.insert<math::MathDialect>();
+      registry.insert<memref::MemRefDialect>();
+      registry.insert<scf::SCFDialect>();
       registry.insert<upmem::UPMEMDialect>();
       // clang-format on
     });
diff --git a/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateToCpp.cpp b/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateToCpp.cpp
index 3a099a9..69e9cad 100644
--- a/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateToCpp.cpp
+++ b/cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateToCpp.cpp
@@ -14,30 +14,22 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/Target/Cpp/CppEmitter.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include
-#include
 #include
-#include
 #include
 #include
 #include
-#include
 #include
-#include
 #include
 #include
 #include
@@ -389,7 +381,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
   if (arith::ConstantOp staticSize = dyn_cast<arith::ConstantOp>(size.getDefiningOp())) {
     size_t remainingElements =
-        staticSize.getValueAttr().dyn_cast<IntegerAttr>().getInt();
+        dyn_cast<IntegerAttr>(staticSize.getValueAttr()).getInt();
     size_t offset = 0;
     while (remainingElements > 0) {
       size_t chunkSize = std::min(2048lu / elementSize, remainingElements);
@@ -648,17 +640,28 @@ static LogicalResult printOperation(CppEmitter &emitter, arith::XOrIOp op) {
   return printBinaryOperation(emitter, op.getOperation(), "^");
 }

-static LogicalResult printOperation(CppEmitter &emitter,
LLVM::ExpOp op) { - if (emitter.emitAssignPrefix(*op.getOperation()).failed()) { +static LogicalResult printMathOperation(CppEmitter &emitter, Operation *op, StringRef mathOp) { + if (emitter.emitAssignPrefix(*op).failed()) { return failure(); } - emitter.ostream() << "expf(" << emitter.getOrCreateName(op.getOperand()) - << ")"; + emitter.ostream() << mathOp << "(" << emitter.getOrCreateName(op->getOperand(0)) << ")"; return success(); } +static LogicalResult printOperation(CppEmitter &emitter, LLVM::ExpOp op) { + return printMathOperation(emitter, op.getOperation(), "expf"); +} + +static LogicalResult printOperation(CppEmitter &emitter, math::AbsFOp op) { + return printMathOperation(emitter, op.getOperation(), "absf"); +} + +static LogicalResult printOperation(CppEmitter &emitter, math::RsqrtOp op) { + return printMathOperation(emitter, op.getOperation(), "rsqrt"); +} + static LogicalResult printOperation(CppEmitter &emitter, cf::BranchOp branchOp) { raw_ostream &os = emitter.ostream(); @@ -1332,13 +1335,13 @@ LogicalResult CppEmitter::emitLabel(Block &block) { return success(); } -LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { - if (dyn_cast(op)) { +LogicalResult CppEmitter::emitOperation(Operation &operation, bool trailingSemicolon) { + if (dyn_cast(operation)) { return success(); } LogicalResult status = - llvm::TypeSwitch(&op) + llvm::TypeSwitch(&operation) // Builtin ops. .Case( [&](auto op) { return printOperation(*this, op); }) @@ -1456,6 +1459,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case( [&](auto op) { return printOperation(*this, op); }) .Case([&](auto op) { return printOperation(*this, op); }) + .Case([&](auto op) { return printOperation(*this, op); }) + .Case([&](auto op) { return printOperation(*this, op); }) .Case( [&](auto op) { return printOperation(*this, op); }) .Case( @@ -1470,7 +1475,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { [&](auto op) { return printOperation(*this, op); }) // [&](auto op) { skipSemicolon = true; return success(); }) .Default([&](Operation *) { - return op.emitOpError("unable to find printer for op"); + return operation.emitOpError("unable to find printer for op"); }); if (failed(status)) diff --git a/cinnamon/samples/asdf.mlir b/cinnamon/samples/asdf.mlir deleted file mode 100644 index 8d4a7d8..0000000 --- a/cinnamon/samples/asdf.mlir +++ /dev/null @@ -1,292 +0,0 @@ -// These dims and sizes aren't correct for this model -// dim: 512 -// hidden_dim: 1536 -// kv_dim: 512 -// kv_mul: 1 -// n_blocks: 18 (layers) -// n_heads: 8 -// head_size: 64 -// vocab_size: 32000 > should be deleted? 
-// seq_len T: 512 - -func.func @forward(%token : index, %pos : index, - // state - %kc : memref<18x512x512xf32>, // blocks, T, C - %vc : memref<18x512x512xf32>, // blocks, T, C - // weights - %embedding_table : tensor<32000x512xf32>, - %rms_att_weights : tensor<18x512xf32>, - %wq : tensor<18x512x512xf32>, - %wk : tensor<18x512x512xf32>, - %wv : tensor<18x512x512xf32>, - %wo : tensor<18x512x512xf32>, - %w1 : tensor<18x1536x512xf32>, - %w2 : tensor<18x512x1536xf32>, - %w3 : tensor<18x1536x512xf32>, - %rms_ffn_weights : tensor<18x512xf32>, - %rms_final_weight : tensor<512xf32>, - %wcls : tensor<32000x512xf32> -) -> tensor<32000xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %n_blocks = arith.constant 18 : index - %head_size = arith.constant 64 : index - %head_size_f32 = arith.constant 64.0 : f32 - %dim = arith.constant 512 : index - %len = arith.constant 512 : index - %hidden_dim = arith.constant 1536 : index - - // Required by language limitations - %c0f = arith.constant 0.0 : f32 - %c1f = arith.constant 1.0 : f32 - %c10000f = arith.constant 10000.0 : f32 - - %content_row = tensor.extract_slice %embedding_table [%token, 0] [1, 512] [1, 1] : tensor<32000x512xf32> to tensor<512xf32> - - %x = scf.for %layer = %c0 to %n_blocks step %c1 iter_args(%x = %content_row) -> (tensor<512xf32>) { - %rms_att_weight = tensor.extract_slice %rms_att_weights [%layer, 0] [1, 512] [1, 1] : tensor<18x512xf32> to tensor<512xf32> - %xb = func.call @rmsnorm(%x, %rms_att_weight) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> - - // qkv matmuls - %wqs = tensor.extract_slice %wq [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> - %wks = tensor.extract_slice %wk [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> - %wvs = tensor.extract_slice %wv [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> - %q, %k, %v = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32>, tensor<512xf32>, tensor<512xf32> { - %q = cinm.op.gemv %wqs, %xb : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> - %k = cinm.op.gemv %wks, %xb : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> - %v = cinm.op.gemv %wvs, %xb : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> - cinm.yield %q, %k, %v : tensor<512xf32>, tensor<512xf32>, tensor<512xf32> - } - - // RoPE relative positional encoding: complex-valued rotate q and k in each head - %posi = arith.index_cast %pos : index to i64 - %posf = arith.uitofp %posi : i64 to f32 - %q2, %k2 = scf.for %i = %c0 to %dim step %c2 iter_args(%qi = %q, %ki = %k) -> (tensor<512xf32>, tensor<512xf32>) { - %head_dim = arith.remui %i, %head_size : index - %head_dimi = arith.index_cast %head_dim : index to i64 - %head_dimf = arith.uitofp %head_dimi : i64 to f32 - %0 = arith.divf %head_dimf, %head_size_f32 : f32 - %1 = math.powf %c10000f, %0 : f32 - %freq = arith.divf %c1f, %1 : f32 - %val = arith.mulf %posf, %freq : f32 - %fcr = math.cos %val : f32 - %fci = math.sin %val : f32 - - %qr = func.call @rot(%qi, %i, %fcr, %fci) : (tensor<512xf32>, index, f32, f32) -> tensor<512xf32> - - %cond = arith.cmpi ult, %i, %dim : index - %kr = scf.if %cond -> (tensor<512xf32>) { - %kr = func.call @rot(%ki, %i, %fcr, %fci) : (tensor<512xf32>, index, f32, f32) -> tensor<512xf32> - scf.yield %kr : tensor<512xf32> - } else { - scf.yield %ki : tensor<512xf32> - } - - scf.yield %qr, %kr : tensor<512xf32>, tensor<512xf32> - } - 
- %kmr = bufferization.to_memref %k2 : memref<512xf32> - %vmr = bufferization.to_memref %v : memref<512xf32> - - %kcd = memref.subview %kc [%layer, %pos, 0] [1, 1, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512xf32, strided<[1], offset: ?>> // blocks, T, C - %vcd = memref.subview %vc [%layer, %pos, 0] [1, 1, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512xf32, strided<[1], offset: ?>> // blocks, T, C - - memref.copy %vmr, %vcd : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> - memref.copy %kmr, %kcd : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> - - // multi head attention - %lkc = memref.subview %kc [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512x512xf32, strided<[512, 1], offset: ?>> // blocks, T, C - %lkc2 = bufferization.to_tensor %lkc : memref<512x512xf32, strided<[512, 1], offset: ?>> - %lvc = memref.subview %vc [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512x512xf32, strided<[512, 1], offset: ?>> // blocks, T, C - %lvc2 = bufferization.to_tensor %lvc : memref<512x512xf32, strided<[512, 1], offset: ?>> - %xb2 = func.call @mha(%q2, %lkc2, %lvc2, %pos) : (tensor<512xf32>, tensor<512x512xf32>, tensor<512x512xf32>, index) -> tensor<512xf32> // T, C - - %wo_slice = tensor.extract_slice %wo [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> - %xb4 = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> { - // final matmul to get the output of the attention - %xb3 = cinm.op.gemv %wo_slice, %xb2 : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> - - // residual connection back into x - %xb4 = cinm.op.add %x, %xb3 : tensor<512xf32> - cinm.yield %xb4 : tensor<512xf32> - } - - // ffn rmsnorm - %rms_ffn_weight = tensor.extract_slice %rms_ffn_weights [%layer, 0] [1, 512] [1, 1] : tensor<18x512xf32> to tensor<512xf32> - %xb5 = func.call @rmsnorm(%xb4, %rms_ffn_weight) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> - - // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) - // first calculate self.w1(x) and self.w3(x) - %w1_slice = tensor.extract_slice %w1 [%layer, 0, 0] [1, 1536, 512] [1, 1, 1] : tensor<18x1536x512xf32> to tensor<1536x512xf32> - %w3_slice = tensor.extract_slice %w3 [%layer, 0, 0] [1, 1536, 512] [1, 1, 1] : tensor<18x1536x512xf32> to tensor<1536x512xf32> - %hb1, %hb2 = cinm.compute attributes { workgroupShape = array } -> tensor<1536xf32>, tensor<1536xf32> { - %hb1 = cinm.op.gemv %w1_slice, %xb5 : (tensor<1536x512xf32>, tensor<512xf32>) -> tensor<1536xf32> - %hb2 = cinm.op.gemv %w3_slice, %xb5 : (tensor<1536x512xf32>, tensor<512xf32>) -> tensor<1536xf32> - cinm.yield %hb1, %hb2 : tensor<1536xf32>, tensor<1536xf32> - } - - // SwiGLU non-linearity - %hb3 = scf.for %i = %c0 to %hidden_dim step %c1 iter_args(%hb = %hb1) -> (tensor<1536xf32>) { - %0 = tensor.extract %hb [%i] : tensor<1536xf32> - %1 = tensor.extract %hb2 [%i] : tensor<1536xf32> - %2 = math.exp %0 : f32 - %3 = arith.addf %c1f, %2 : f32 - %4 = arith.divf %c1f, %3 : f32 - %5 = arith.mulf %1, %4 : f32 - %hbr = tensor.insert %5 into %hb [%i] : tensor<1536xf32> - scf.yield %hbr : tensor<1536xf32> - } - - %w2_slice = tensor.extract_slice %w2 [%layer, 0, 0] [1, 512, 1536] [1, 1, 1] : tensor<18x512x1536xf32> to tensor<512x1536xf32> - %xb7 = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> { - // final matmul to get the output of the ffn - %xb6 = cinm.op.gemv %w2_slice, %hb3 : (tensor<512x1536xf32>, tensor<1536xf32>) -> 
tensor<512xf32> - - // residual connection - %xb7 = cinm.op.add %x, %xb6 : tensor<512xf32> - cinm.yield %xb7 : tensor<512xf32> - } - - scf.yield %xb7 : tensor<512xf32> - } - - %x2 = func.call @rmsnorm(%x, %rms_final_weight) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> - %logits = cinm.compute attributes { workgroupShape = array } -> tensor<32000xf32> { - %wcls2 = tensor.pad %wcls low[0,0] high[%hidden_dim,0] { - ^bb0(%arg1: index, %arg2: index): - tensor.yield %c0f : f32 - } : tensor<32000x512xf32> to tensor<32768x512xf32> - %logits = cinm.op.gemv %wcls2, %x2 : (tensor<32768x512xf32>, tensor<512xf32>) -> tensor<32768xf32> - %logits2 = tensor.extract_slice %logits [0] [32000] [1] : tensor<32768xf32> to tensor<32000xf32> - cinm.yield %logits2 : tensor<32000xf32> - } - - return %logits : tensor<32000xf32> -} - -func.func @rot(%v: tensor<512xf32>, %i: index, %fcr : f32, %fci : f32) -> tensor<512xf32> { - %c1 = arith.constant 1 : index - %i2 = arith.addi %i, %c1 : index - %v0 = tensor.extract %v [%i] : tensor<512xf32> - %v1 = tensor.extract %v [%i2] : tensor<512xf32> - %0 = arith.mulf %v0, %fcr : f32 - %1 = arith.mulf %v1, %fci : f32 - %2 = arith.subf %0, %1 : f32 - %r0 = tensor.insert %2 into %v[%i] : tensor<512xf32> - %3 = arith.mulf %v0, %fci : f32 - %4 = arith.mulf %v1, %fcr : f32 - %5 = arith.addf %3, %4 : f32 - %r1 = tensor.insert %2 into %r0[%i] : tensor<512xf32> - return %r1 : tensor<512xf32> -} - - -// Q: features, KC: sequence length x features, VC: sequence length x features -func.func @mha(%q: tensor<512xf32>, %kc: tensor<512x512xf32>, %vc: tensor<512x512xf32>, %pos: index) -> tensor<512xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %nheads = arith.constant 8 : index - %head_dim = arith.constant 64 : index - %c0f = arith.constant 0.0 : f32 - %scale = arith.constant 8.0 : f32 // sqrt(head_dim) - %ninf = arith.constant 0xFF800000 : f32 - - %pos2 = arith.addi %pos, %c1 : index - - %attn_init = tensor.generate { - ^bb0(%arg1: index): - tensor.yield %ninf : f32 - } : tensor<512xf32> - - %xb_init = tensor.empty() : tensor<512xf32> - %xb = scf.for %head = %c0 to %nheads step %c1 iter_args(%xbi = %xb_init) -> (tensor<512xf32>) { - %hoff = arith.muli %head, %head_dim : index - - %attn = scf.for %i = %c0 to %pos2 step %c1 iter_args(%attn_i = %attn_init) -> (tensor<512xf32>) { - %qs = tensor.extract_slice %q [%hoff] [64] [1] : tensor<512xf32> to tensor<64xf32> - %k = tensor.extract_slice %kc [%i, %hoff] [1, 64] [1, 1] : tensor<512x512xf32> to tensor<64xf32> - %score = cinm.compute attributes { workgroupShape = array } -> f32 { - %0 = cinm.op.mul %qs, %k : tensor<64xf32> - %1 = cinm.op.reduce add (%0) : tensor<64xf32> - %2 = arith.divf %1, %scale : f32 - cinm.yield %2 : f32 - } - %attn_i2 = tensor.insert %score into %attn_i [%i] : tensor<512xf32> - scf.yield %attn_i2 : tensor<512xf32> - } - - %attn3 = func.call @softmax(%attn) : (tensor<512xf32>) -> tensor<512xf32> - - %xb_slice_init = tensor.generate { - ^bb0(%arg1: index): - tensor.yield %c0f : f32 - } : tensor<64xf32> - - %xb_slice = scf.for %i = %c0 to %pos2 step %c1 iter_args(%xb_slice_i = %xb_slice_init) -> (tensor<64xf32>) { - %v = tensor.extract_slice %vc [%i, %hoff] [1, 64] [1, 1] : tensor<512x512xf32> to tensor<64xf32> - %a = tensor.extract %attn3 [%i] : tensor<512xf32> - %xb_slice = cinm.compute attributes { workgroupShape = array } -> tensor<64xf32> { - %0 = cinm.op.muls %v, %a : tensor<64xf32> - %1 = cinm.op.add %xb_slice_i, %0 : tensor<64xf32> - cinm.yield %1 : tensor<64xf32> - } - scf.yield 
%xb_slice : tensor<64xf32> - } - - %xbr = tensor.insert_slice %xb_slice into %xbi [%hoff] [64] [1] : tensor<64xf32> into tensor<512xf32> - scf.yield %xbr : tensor<512xf32> - } - - return %xb : tensor<512xf32> -} - -func.func @rmsnorm(%v : tensor<512xf32>, %w : tensor<512xf32>) -> tensor<512xf32> { - %epsilon = arith.constant 1.0e-5 : f32 - %c1 = arith.constant 1.0 : f32 - %len = arith.constant 512.0 : f32 - - %r = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> { - %0 = cinm.op.mul %v, %v : tensor<512xf32> - %ss = cinm.op.reduce add (%0) : tensor<512xf32> - %s0 = arith.divf %ss, %len : f32 - %s1 = arith.addf %s0, %epsilon : f32 - %s = math.rsqrt %s1 : f32 - %x = cinm.op.muls %v, %s : tensor<512xf32> - %r = cinm.op.mul %x, %w : tensor<512xf32> - cinm.yield %r : tensor<512xf32> - } - return %r : tensor<512xf32> -} - -//func.func @rmsnorm(%v : tensor<262144xf32>, %w : tensor<262144xf32>) -> tensor<262144xf32> { -// %epsilon = arith.constant 1.0e-5 : f32 -// %c1 = arith.constant 1.0 : f32 -// %len = arith.constant 262144.0 : f32 -// -// %r = cinm.compute attributes { workgroupShape = array } -> tensor<262144xf32> { -// %0 = cinm.op.mul %v, %v : tensor<262144xf32> -// %ss = cinm.op.reduce add (%0) : tensor<262144xf32> -// %s0 = arith.divf %ss, %len : f32 -// %s1 = arith.addf %s0, %epsilon : f32 -// %s = math.rsqrt %s1 : f32 -// %x = cinm.op.muls %v, %s : tensor<262144xf32> -// %r = cinm.op.mul %x, %w : tensor<262144xf32> -// cinm.yield %r : tensor<262144xf32> -// } -// return %r : tensor<262144xf32> -//} - -func.func @softmax(%vec : tensor<512xf32>) -> tensor<512xf32> { - %r = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> { - %max = cinm.op.reduce max (%vec) : tensor<512xf32> - %t = cinm.op.subs %vec, %max : tensor<512xf32> - %shape = tensor.empty() : tensor<512xf32> - %e = linalg.exp ins(%t : tensor<512xf32>) outs(%shape : tensor<512xf32>) -> tensor<512xf32> - %s = cinm.op.reduce add (%e) : tensor<512xf32> - %r = cinm.op.divs %e, %s : tensor<512xf32> - cinm.yield %r : tensor<512xf32> - } - - return %r : tensor<512xf32> -} diff --git a/cinnamon/samples/dorado/dorado.cpp b/cinnamon/samples/dorado/dorado.cpp new file mode 100644 index 0000000..a00ffd0 --- /dev/null +++ b/cinnamon/samples/dorado/dorado.cpp @@ -0,0 +1,791 @@ +/* Inference for Llama-2 Transformer model in pure C, based on + * https://github.com/karpathy/llama2.c/ with modifications to run the inference + * using cinm */ + +/* + MIT License + + Copyright (c) 2023 Andrej + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined _WIN32 +#include "win.h" +#else +#include +#include +#endif + +// ---------------------------------------------------------------------------- +// Transformer model + +typedef struct { + uint32_t dim; // transformer dimension (512) + uint32_t hidden_dim; // for ffn layers (2048) + uint32_t n_layers; // number of layers (18) + uint32_t n_heads; // number of query heads (8) + uint32_t seq_len; // max sequence length (512) +} Config; + +// Struct contains all model weights +typedef struct { + // weights for rmsnorms + float *rms1; // (layer, dim) rmsnorm weights + float *rms2; // (layer, dim) + // weights for matmuls. note dim == n_heads * head_size + float *wqkv; // (layer, dim, n_heads * head_size * 3) + float *wo; // (layer, n_heads * head_size, dim) + // weights for ffn + float *w1; // (layer, hidden_dim, dim) + float *w2; // (layer, dim, hidden_dim) + float *upscale; // (dim, dim x2) upscaling layer, must reshape to len x2 + float *crf; // (dim x2, 4096) +} TransformerWeights; + +typedef struct { + // current wave of activations + float *x; // activation at current time stamp (dim, seq_len) + float *xb; // same, but inside a residual branch (dim, seq_len) + float *hb; // buffer for hidden dimension in the ffn (hidden_dim,) + float *qkv; // query, key, value (T, C*3,) + float *att; // buffer for scores/attention values (n_heads, seq_len) + float *cos_freqs; // Used by RoPE + float *sin_freqs; // Used by RoPE + float *upscaled; // buffer for upscaled output before crf (dim, seq_len*2) + float *crf; // buffer for crf output (4096, seq_len*2) +} RunState; + +typedef struct { + Config config; // the hyperparameters of the architecture (the blueprint) + TransformerWeights weights; // the weights of the model + RunState state; // buffers for the "wave" of activations in the forward pass + // some more state needed to properly clean up the memory mapping (sigh) + int fd; // file descriptor for memory mapping + float *data; // memory mapped data pointer + ssize_t file_size; // size of the checkpoint file in bytes +} Transformer; + +void malloc_run_state(RunState *s, Config *p) { + // we calloc instead of malloc to keep valgrind happy + s->x = (float *)calloc(p->seq_len * p->dim, sizeof(float)); + s->xb = (float *)calloc(p->seq_len * p->dim, sizeof(float)); + s->hb = (float *)calloc(p->seq_len * p->hidden_dim * 2, sizeof(float)); + s->qkv = (float *)calloc(p->seq_len * p->dim * 3, sizeof(float)); + s->att = (float *)calloc(p->n_heads * p->seq_len, sizeof(float)); + s->upscaled = (float *)calloc(p->dim * p->seq_len * 2, sizeof(float)); + s->crf = (float *)calloc(4096 * p->seq_len * 2, sizeof(float)); // TODO: check dim + // ensure all mallocs went fine + if (!s->x || !s->xb || !s->hb || !s->qkv || !s->att || !s->upscaled || !s->crf) { + fprintf(stderr, "malloc failed!\n"); + exit(EXIT_FAILURE); + } +} + +void free_run_state(RunState *s) { + free(s->x); + free(s->xb); + free(s->hb); + free(s->qkv); + free(s->att); + free(s->upscaled); + free(s->crf); +} + +void create_rope_freqs(RunState *s, Config *p) { + const float theta = 10000.0f; + const int64_t max_seq_len = 2048; + + double *vec = (double 
*)calloc(p->dim / 2, sizeof(double)); + if (!vec) {fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE);} + + for (int i = 0; i < p->dim / 2; i++) { + double a = i * 2; + vec[i] = 1.0 / std::pow(static_cast(theta), a / static_cast(p->dim)); + } + + // torch::arange(max_seq_len).reshape({max_seq_len, 1, 1, 1}) * inv_freq [< invfreq is vec] + // Causes broadcasting semantics resulting in {max_seq_len, 1, 1, dim/2} or in our case {max_seq_len, dim/2} + float *cos_freqs = (float *)calloc(max_seq_len * p->dim / 2, sizeof(float)); + float *sin_freqs = (float *)calloc(max_seq_len * p->dim / 2, sizeof(float)); + if (!cos_freqs || !sin_freqs) {fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE);} + for (int i = 0; i < max_seq_len; i++) { + for (int j = 0; j < p->dim / 2; j++) { + cos_freqs[i * p->dim / 2 + j] = (float) std::cos(i * vec[j]); + sin_freqs[i * p->dim / 2 + j] = (float) std::sin(i * vec[j]); + } + } + + s->cos_freqs = cos_freqs; + s->sin_freqs = sin_freqs; +} + +void memory_map_weights(TransformerWeights *w, Config *p, float *ptr, ssize_t file_size) { + float *ptr_orig = ptr; + // make sure the multiplications below are done in 64bit to fit the parameter + // counts of 13B+ models + uint64_t n_layers = p->n_layers; + + w->rms1 = ptr; + ptr += n_layers * p->dim; // pointer arithmetic automatically multiplies sizeof(float) + w->rms2 = ptr; + ptr += n_layers * p->dim; + + w->wqkv = ptr; + ptr += n_layers * p->dim * p->dim * 3; + w->wo = ptr; + ptr += n_layers * p->dim * p->dim; + + w->w1 = ptr; + ptr += n_layers * p->dim * p->hidden_dim * 2; + w->w2 = ptr; + ptr += n_layers * p->hidden_dim * p->dim; + + w->upscale = ptr; + ptr += p->dim * p->dim * 2; // looks like in>out is 512>1024 but it's reshaped to 512 with seq_len*2 + w->crf = ptr; + ptr += p->dim * 4096; + + if (ptr > ptr_orig + file_size / 4) { + fprintf(stderr, "Too many bytes read!\nFile size: %lu kb; bytes read: %lu kb\n", file_size / 4 / 1024, (ptr - ptr_orig) / 1024); + } else if (ptr < ptr_orig + file_size / 4) { + fprintf(stderr, "Too few bytes read!\nFile size: %lu kb; bytes read: %lu kb\n", file_size / 4 / 1024, (ptr - ptr_orig) / 1024); + } +} + +void read_checkpoint(char *checkpoint, Config *config, + TransformerWeights *weights, int *fd, float **data, + ssize_t *file_size) { + FILE *file = fopen(checkpoint, "rb"); + if (!file) { + fprintf(stderr, "Couldn't open file %s\n", checkpoint); + exit(EXIT_FAILURE); + } + uint32_t *magic_and_version = (uint32_t *) calloc(2, sizeof(uint32_t)); + if (fread(magic_and_version, sizeof(uint32_t), 2, file) != 2) { + fprintf(stderr, "Couldn't read magic number and version from file %s\n", checkpoint); + exit(EXIT_FAILURE); + } + if ((magic_and_version[0] != 0x616b3432) || (magic_and_version[1] != 1)) { + fprintf(stderr, "Wrong magic number or incompatible version detected in file %s.\nMagic: %u; version: %d", checkpoint, magic_and_version[0], magic_and_version[1]); + } + // read in the config header + if (fread(config, sizeof(Config), 1, file) != 1) { + exit(EXIT_FAILURE); + } + + printf("dim: %d\n", config->dim); + printf("hidden_dim: %d\n", config->hidden_dim); + printf("n_layers: %d\n", config->n_layers); + printf("n_heads: %d\n", config->n_heads); + printf("seq_len: %d\n", config->seq_len); + + // figure out the file size + fseek(file, 0, SEEK_END); // move file pointer to end of file + *file_size = ftell(file); // get the file size, in bytes + fclose(file); + // memory map the Transformer weights into the data pointer + *fd = open(checkpoint, O_RDONLY); // open in read only 
mode + if (*fd == -1) { + fprintf(stderr, "open failed!\n"); + exit(EXIT_FAILURE); + } + *data = (float *)mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); + if (*data == MAP_FAILED) { + fprintf(stderr, "mmap failed!\n"); + fprintf(stderr, "%s\n", strerror(errno)); + exit(EXIT_FAILURE); + } + + float *weights_ptr = *data + 256; // Fixed 256 byte header+padding + memory_map_weights(weights, config, weights_ptr, *file_size - 256); +} + +void build_transformer(Transformer *t, char *checkpoint_path) { + // read in the Config and the Weights from the checkpoint + read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, + &t->file_size); + // allocate the RunState buffers + malloc_run_state(&t->state, &t->config); + create_rope_freqs(&t->state, &t->config); +} + +void free_transformer(Transformer *t) { + // close the memory mapping + if (t->data != MAP_FAILED) { + munmap(t->data, t->file_size); + } + if (t->fd != -1) { + close(t->fd); + } + // free the RunState buffers + free_run_state(&t->state); +} + +// ---------------------------------------------------------------------------- +extern "C" { +// actual forward() is implemented in mlir +/* func.func @forward(%token : index, %pos : index, + // state + %key_cache : tensor<6x256x288xf32>, + %value_cache : tensor<6x256x288xf32>, + // weights + %embedding_table : tensor<32000x288xf32>, + %rms_att_weights : tensor<6x288xf32>, + %wq : tensor<6x288x288xf32>, + %wk : tensor<6x288x288xf32>, + %wv : tensor<6x288x288xf32>, + %wo : tensor<6x288x288xf32>, + %w1 : tensor<6x768x288xf32>, + %w2 : tensor<6x288x768xf32>, + %w3 : tensor<6x768x288xf32>, + %rms_ffn_weights : tensor<6x288xf32>, + %rms_final_weight : tensor<288xf32>, + %wcls : tensor<32000x288xf32> +) -> (tensor<32000xf32>, tensor<6x256x288xf32>, tensor<6x256x288xf32>) { +*/ + +float *forward(uint64_t index, uint64_t pos, float *key_cache, float *value_cache, float *embedding_table, + const float *rms_att_weights, const float *wq, const float *wk, const float *wv, const float *wo, + const float *w1, const float *w2, const float *w3, const float *rms_ffn_weights, + const float *rms_final_weight, const float *wcls); + +float *mha(const float *q, const float *kc, const float *vc, int64_t pos); + +// v: tensor<512xf32>; w: tensor<512xf32>. +float *rmsnorm(const float *v, const float *w); +// v: tensor<262144xf32>; w: tensor<262144xf32>. +float *rmsnorm_large(const float *v, const float *w); +// v: tensor<512x512xf322> (seq_len x dim); w: tensor<512xf32> (dim). +float *rmsnorm_batched(const float *v, const float *w); +float *softmax(const float *x); +} + +// Matrix-vector multiplication: W @ x -> xout. +// xout: pointer to output data (d,); +// x: input data vector (n,); +// w: weight matrix (d, n); +void matmul(float *xout, float *x, float *w, int A_cols, int A_rows) { + // W (d,n) @ x (n,) -> xout (d,) + // by far the most amount of time is spent inside this little function + int i; +#pragma omp parallel for private(i) + for (i = 0; i < A_rows; i++) { + float val = 0.0f; + for (int j = 0; j < A_cols; j++) { + val += w[i * A_cols + j] * x[j]; + } + xout[i] = val; + } +} + +// Matrix-matrix multiplication: A @ B -> C +// A: (A_rows x A_cols); +// B: (A_cols x B_cols); +// C: (A_rows x B_cols). 
+static void gemm(float *A, float *B, float *C, const int64_t A_rows, const int64_t A_cols, const int64_t B_cols, + const int64_t A_rows_offset = 0, const int64_t A_cols_offset = 0, const int64_t B_cols_offset = 0, + int64_t A_rows_end = -1, int64_t A_cols_end = -1, int64_t B_cols_end = -1) { + //printf("GEMM with A_rows = %ld, A_cols = %ld, B_cols = %ld\n", A_rows, A_cols, B_cols); + if (A_rows_end == -1) A_rows_end = A_rows; + if (A_cols_end == -1) A_cols_end = A_cols; + if (B_cols_end == -1) B_cols_end = B_cols; + assert(A_rows_end <= A_rows); + assert(A_cols_end <= A_cols); + assert(B_cols_end <= B_cols); + + int64_t i, k, j; + float temp; +#pragma omp parallel for default(shared) private(i, k, j, temp) + for (i = A_rows_offset; i < A_rows_end; i++) { + for (k = A_cols_offset; k < A_cols_end; k++) { + temp = A[i * A_cols + k]; + for (j = B_cols_offset; j < B_cols_end; j++) { + C[i * B_cols + j] += temp * B[k * B_cols + j]; + } + } + } +} + +// Matrix-matrix multiplication transposing B: A @ B^T -> C. +// A: (A_rows x A_cols); +// B: (B_rows x A_cols); +// C: (A_rows x B_rows). +static void gemm_t(float *A, float *B, float *C, const int64_t A_rows, const int64_t A_cols, const int64_t B_rows, + const int64_t A_rows_offset = 0, const int64_t A_cols_offset = 0, const int64_t B_rows_offset = 0, + int64_t A_rows_end = -1, int64_t A_cols_end = -1, int64_t B_rows_end = -1) { + //printf("GEMM_T with A_rows = %ld, A_cols = %ld, B_rows = %ld\n", A_rows, A_cols, B_rows); + if (A_rows_end == -1) A_rows_end = A_rows; + if (A_cols_end == -1) A_cols_end = A_cols; + if (B_rows_end == -1) B_rows_end = B_rows; + assert(A_rows_end <= A_rows); + assert(A_cols_end <= A_cols); + assert(B_rows_end <= B_rows); + + int64_t i, k, j; + float temp; +#pragma omp parallel for default(shared) private(i, k, j, temp) + for (i = A_rows_offset; i < A_rows_end; i++) { + for (k = A_cols_offset; k < A_cols_end; k++) { + temp = A[i * A_cols + k]; + for (j = B_rows_offset; j < B_rows_end; j++) { + // For inserting into submatrix of original size matrix + C[i * B_rows + j] += temp * B[k + j * B_rows]; + // For outputting to perfectly sized smaller matrix +// C[(i - A_rows_offset) * (B_rows_end - B_rows_offset) + j - B_rows_offset] += temp * B[k + j * B_rows]; + } + } + } +} + +void rmsnorm_cpu(float *out, float *x, float *weight, int size) { + // calculate sum of squares + float ss = 0.0f; + for (int j = 0; j < size; j++) { + ss += x[j] * x[j]; + } + ss /= size; + ss += 1e-5f; + ss = 1.0f / sqrtf(ss); + // normalize and scale + for (int j = 0; j < size; j++) { + out[j] = weight[j] * (ss * x[j]); + } +} + +// TODO: add support for extra dim (sequence length) and ensure RMS is only taken and used along that dim. 
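// Editorial sketch for the TODO above, not part of the original patch: a batched
// CPU reference that applies rmsnorm_cpu (defined above) row by row over a
// row-major (extra_dim_size x size) buffer, so the RMS statistic is taken along
// the last dimension only. The name rmsnorm_cpu_batched is illustrative.
static void rmsnorm_cpu_batched(float *out, float *x, float *weight, int size,
                                int extra_dim_size) {
  for (int row = 0; row < extra_dim_size; row++) {
    rmsnorm_cpu(out + (size_t)row * size, x + (size_t)row * size, weight, size);
  }
}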
+void rmsnorm_upmem(float *out, float *x, float *weight, int size) {
+  float *r = rmsnorm(x, weight);
+  memcpy(out, r, size * sizeof(float));
+  free(r);
+}
+
+void rmsnorm_upmem_large(float *out, float *x, float *weight, int size, int extra_dim_size = 1) {
+  float *r = rmsnorm_large(x, weight);
+  memcpy(out, r, size * extra_dim_size * sizeof(float));
+  free(r);
+}
+
+void rmsnorm_upmem_batched(float *out, float *x, float *weight, int size, int extra_dim_size = 1) {
+  float *r = rmsnorm_batched(x, weight);
+  memcpy(out, r, size * extra_dim_size * sizeof(float));
+  free(r);
+}
+
+void softmax_cpu(float *x, int size, int extra_dim_size = 1) {
+  for (int dim = 0; dim < extra_dim_size; dim++) {
+    // find max value (for numerical stability)
+    float max_val = x[0 + dim * size];
+    for (int i = 1; i < size; i++) {
+      if (x[i + dim * size] > max_val) {
+        max_val = x[i + dim * size];
+      }
+    }
+    // exp and sum
+    float sum = 0.0f;
+    for (int i = 0; i < size; i++) {
+      x[i + dim * size] = expf(x[i + dim * size] - max_val);
+      sum += x[i + dim * size];
+    }
+    // normalize
+    for (int i = 0; i < size; i++) {
+      x[i + dim * size] /= sum;
+    }
+  }
+}
+
+// TODO: add support for extra dim in CINM kernel and ensure max is only taken and normalized along that dim.
+void softmax_upmem(float *x, int size, int extra_dim_size = 1) {
+  for (int dim = 0; dim < extra_dim_size; dim++) {
+    // Pass the current row, not the start of the buffer, so that each row is
+    // normalized independently.
+    float *r = softmax(x + dim * size);
+    memcpy(x + dim * size, r, size * sizeof(float));
+    free(r);
+  }
+}
+
+// x: output; q/k/v: inputs; qb/qe: query begin/end indices; kvb/kve: key/value begin/end indices.
+void scaled_dot_product_attention(float *x, float *q, float *k, float *v, int qb, int qe, int kvb, int kve, int seq_len, int heads, int head_size) {
+  // q/k/v is NHTD (batch size, heads, seq_len, head size)
+  float *matmul_qk = (float *) malloc(seq_len * seq_len * sizeof(float));
+
+  float sqrt_d_k = std::sqrt(head_size);
+
+  for (int h = 0; h < heads; h++) {
+    memset(matmul_qk, 0, seq_len * seq_len * sizeof(float));
+    gemm_t(q + h * head_size, k + h * head_size, matmul_qk,
+           seq_len, head_size, seq_len,
+           qb, 0, kvb,
+           qe, head_size, kve);
+
+    for (int i = 0; i < seq_len * seq_len; i++) {
+      matmul_qk[i] /= sqrt_d_k;
+    }
+
+    // Performs softmax along last dim only
+    softmax_cpu(matmul_qk, seq_len, seq_len);
+
+    // x.slice(-2 (T), qb, qe)
+    gemm(matmul_qk, v + h * head_size, x + h * head_size + qb,
+         seq_len, seq_len, head_size,
+         qb, kvb, 0,
+         qe, kve, head_size);
+  }
+
+  free(matmul_qk);
+}
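// Editorial worked example (not part of this patch): with seq_len = 512,
// num_splits = 8, win_lower = 127 and win_upper = 128, mha_cpu below uses
// elems_per_split = ceil(512 / 8) = 64, so
//   split 0 attends queries [0, 64)    to keys/values [0, 192)
//   split 1 attends queries [64, 128)  to keys/values [0, 256)
//   split 7 attends queries [448, 512) to keys/values [321, 512)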
+void mha_cpu(float *x, float *att, float *qkv, int seq_len, int num_heads, int head_size) {
+  int dim = num_heads * head_size;
+  float *q = qkv;
+  float *k = qkv + dim * seq_len;
+  float *v = qkv + dim * seq_len * 2;
+
+  int win_upper = 128, win_lower = 127; // FIXME: get from config
+  int num_splits = 8; // FIXME: get from config
+  int elems_per_split = (seq_len + (num_splits - 1)) / num_splits; // round up. TODO: does not pad to 4
+
+  for (int i = 0; i < num_splits; i++) {
+    int qb = i * elems_per_split;
+    if (qb >= seq_len) {
+      break;
+    }
+
+    int qe = std::min(seq_len, qb + elems_per_split);
+    int kvb = std::max(0, qb - win_lower);
+    int kve = std::min(seq_len, qe + win_upper);
+    //printf("qb %d, qe %d, kvb %d, kve %d\n", qb, qe, kvb, kve);
+    // qkv is now supposed to be 3NHTD ({Q|K|V}, batch size, heads, seq_len, head size)
+
+    // x.slice(-2 {seq_len dim}, qb, qe) =
+    scaled_dot_product_attention(x, q, k, v, qb, qe, kvb, kve, seq_len, num_heads, head_size);
+  }
+}
+
+void mha_upmem(float *x, float *att, float *q, float *kc, float *vc,
+               int64_t pos) {
+  float *r = mha(q, kc, vc, pos);
+  // 288 = dim of the llama2-sample MLIR kernel this wraps (see the signature quoted above)
+  memcpy(x, r, 288 * sizeof(float));
+  free(r);
+}
+
+/////-----------------------
+///// FORWARD
+/////-----------------------
+//
+// Input dims: (T, D) -> (sequence length, dimensions) -> (512, 512)
+float *forward2(Transformer *transformer, float *input) {
+  // a few convenience variables
+  Config *p = &transformer->config;
+  TransformerWeights *w = &transformer->weights;
+  RunState *s = &transformer->state;
+  int head_size = p->dim / p->n_heads;
+  float deepnorm_alpha = 2.4494897f; // TODO: from config
+
+  // TODO: implement convs
+
+  // forward all the layers
+  for (unsigned int l = 0; l < p->n_layers; l++) {
+    if (l == 0) {
+      for (int i = 0; i < p->dim * p->seq_len; i++) {
+        s->x[i] = input[i];
+      }
+    }
+    printf("Layer %d of %d. x[0] = %f\n", l + 1, p->n_layers, s->x[0]);
+
+    // QKV matmul
+    if (l == 0) printf("  QKV...\n");
+    gemm(w->wqkv + l * p->dim * p->dim * 3, s->x, s->qkv, // A, B, C
+         p->dim * 3, p->dim, p->seq_len); // dims
+
+    // RoPE relative positional encoding: complex-valued rotate q and k in each head (ROPE)
+    if (l == 0) printf("  ROPE... x[0] = %f\n", s->qkv[0]);
+    for (int i = 0; i < p->seq_len; i++) {
+      for (int j = 0; j < p->dim; j += 2) {
+        // cos_freqs/sin_freqs rows hold dim/2 entries, one per (j, j+1) pair,
+        // so the pair at j uses frequency index j / 2.
+        float cos_factor = s->cos_freqs[i * p->dim / 2 + j / 2];
+        float sin_factor = s->sin_freqs[i * p->dim / 2 + j / 2];
+
+        float v0 = s->qkv[i * p->dim * 3 + j];
+        float v1 = s->qkv[i * p->dim * 3 + j + 1];
+        s->qkv[i * p->dim * 3 + j] = v0 * cos_factor - v1 * sin_factor;
+        s->qkv[i * p->dim * 3 + j + 1] = v0 * sin_factor + v1 * cos_factor;
+      }
+    }
+
+    // MHA
+    if (l == 0) printf("  MHA... x[0] = %f\n", s->qkv[0]);
+    mha_cpu(s->xb, s->att, s->qkv, p->seq_len, p->n_heads, head_size);
+
+    // final matmul to get the output of the attention
+    if (l == 0) printf("  OUTPROJ... x[0] = %f\n", s->x[0]);
+    gemm(w->wo + l * p->dim * p->dim, s->xb, s->x, // A, B, C
+         p->dim, p->dim, p->seq_len); // dims
+
+    // residual connection back into x
+    for (int i = 0; i < p->seq_len * p->dim; i++) {
+      s->x[i] += s->xb[i] * deepnorm_alpha;
+    }
+
+    // Attention rmsnorm (NORM1)
+    if (l == 0) printf("  RES&NORM1... x[0] = %f\n", s->xb[0]);
+    for (int i = 0; i < p->seq_len; i++) {
+      rmsnorm_cpu(s->x + i * p->dim, s->xb + i * p->dim, w->rms1 + l * p->dim, p->dim);
+    }
+
+    // Gated MLP
+
+    // In PyTorch:
+    // y = fc1(x)
+    // y, gate = y.chunk(2 parts, dim=-1 (last dim))
+    // y = swiglu(y, gate)
+    // y = fc2(y)
+    // return y
+
+    if (l == 0) printf("  FC1... x[0] = %f\n", s->x[0]);
+    gemm(w->w1 + l * p->dim * p->hidden_dim * 2, s->x, s->hb, p->hidden_dim * 2, p->dim, p->seq_len);
+
+    if (l == 0) printf("  SWIGLU... x[0] = %f\n", s->hb[0]);
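+    // For reference (editorial note, not part of the original patch): each row of
+    // hb holds the two fc1 chunks side by side as [y | gate]; the loop below
+    // computes y[j] = silu(y[j]) * gate[j], silu(v) = v * sigmoid(v) = v / (1 + expf(-v)).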
+    for (int i = 0; i < p->seq_len; i++) {
+      // SwiGLU non-linearity
+      for (int j = 0; j < p->hidden_dim; j++) {
+        float val = s->hb[i * p->hidden_dim * 2 + j];
+        // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
+        val *= (1.0f / (1.0f + expf(-val)));
+        // gate and output are indexed by j (the hidden unit), not by the row index
+        val *= s->hb[i * p->hidden_dim * 2 + p->hidden_dim + j];
+        s->hb[i * p->hidden_dim * 2 + j] = val;
+      }
+    }
+
+    if (l == 0) printf("  FC2... x[0] = %f\n", s->hb[0]);
+    for (int i = 0; i < p->seq_len; i++) {
+      // final matmul to get the output of the ffn
+      matmul(s->xb + i * p->dim,
+             s->hb + i * p->hidden_dim * 2,
+             w->w2 + l * p->dim * p->hidden_dim,
+             p->hidden_dim, p->dim);
+    }
+
+    // residual connection
+    for (int i = 0; i < p->seq_len * p->dim; i++) {
+      s->xb[i] += s->x[i] * deepnorm_alpha;
+    }
+
+    // FF rmsnorm (NORM2)
+    if (l == 0) printf("  RES&NORM2... x[0] = %f\n", s->xb[0]);
+    rmsnorm_upmem_batched(s->x, s->xb, w->rms2 + l * p->dim, p->dim, p->seq_len);
+  }
+
+  // upscale layer and reshape (implicit)
+  printf("Upscale layer\n");
+  gemm(w->upscale, s->x, s->upscaled, p->dim * 2, p->dim, p->seq_len);
+
+  // TODO: check if indexing is correct after upscaling.
+  printf("CRF layer\n");
+  gemm(w->crf, s->upscaled, s->crf, 4096, p->dim, p->seq_len * 2);
+
+  return s->crf;
+}
+
+unsigned int random_u32(unsigned long long *state) {
+  // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
+  *state ^= *state >> 12;
+  *state ^= *state << 25;
+  *state ^= *state >> 27;
+  return (*state * 0x2545F4914F6CDD1Dull) >> 32;
+}
+
+float random_f32(unsigned long long *state) { // random float32 in [0,1)
+  return (random_u32(state) >> 8) / 16777216.0f;
+}
+
+// ----------------------------------------------------------------------------
+// utilities: time
+
+long time_in_ms() {
+  // return time in milliseconds, for benchmarking the model speed
+  struct timespec time;
+  clock_gettime(CLOCK_REALTIME, &time);
+  return time.tv_sec * 1000 + time.tv_nsec / 1000000;
+}
+
+void read_stdin(const char *guide, char *buffer, size_t bufsize) {
+  // read a line from stdin, up to but not including \n
+  printf("%s", guide);
+  if (fgets(buffer, bufsize, stdin) != NULL) {
+    size_t len = strlen(buffer);
+    if (len > 0 && buffer[len - 1] == '\n') {
+      buffer[len - 1] = '\0'; // strip newline
+    }
+  }
+}
+
+// ----------------------------------------------------------------------------
+// CLI, include only if not testing
+#ifndef TESTING
+
+void error_usage() {
+  fprintf(stderr, "Usage: run <checkpoint>\n");
+  fprintf(stderr, "Example: run model.bin\n");
+  exit(EXIT_FAILURE);
+}
+
+#define benchmark(name, num_runs, body)                                        \
+  do {                                                                         \
+    {body}; /* two untimed warm-up runs */                                     \
+    {body};                                                                    \
+    clock_t duration = 0;                                                      \
+    for (size_t i = 0; i < num_runs; i++) {                                    \
+      clock_t start = clock();                                                 \
+      {body};                                                                  \
+      duration += clock() - start;                                             \
+    }                                                                          \
+    printf(name ": %fms\n", (double)(duration) * 1000.0 /                      \
+                                (double)CLOCKS_PER_SEC / (double)num_runs);    \
+  } while (false)
+
+extern "C" void *upmemrt_dpu_alloc(int32_t num_ranks, int32_t num_dpus,
+                                   const char *dpu_binary_path);
+
+extern "C" void upmemrt_dpu_free(void *dpu_set);
+
+void run_benchmarks() {
+  float *a = (float *)malloc(262144 * sizeof(float));
+  float *b = (float *)malloc(262144 * sizeof(float));
+
+  float *q = (float *)malloc(32768 * sizeof(float));
+  float *kc = (float *)malloc(1024 * 32768 * sizeof(float));
+  float *vc = (float *)malloc(1024 * 32768 * sizeof(float));
+  float *qkv = (float *)malloc(3 * 1024 * 32768 * sizeof(float));
+  float *att = (float *)malloc(1024 * sizeof(float));
+
+  float *r_cpu = (float *)malloc(262144 *
sizeof(float)); + float *r_upmem; + unsigned long long rng = 1234; + for (size_t i = 0; i < 262144; i++) { + a[i] = random_f32(&rng); + b[i] = random_f32(&rng); + } + + benchmark("vector add cpu", 32, { + for (size_t i = 0; i < 262144; i++) { + r_cpu[i] = a[i] + b[i]; + } + }); + + benchmark("softmax cpu", 32, { softmax_cpu(a, 262144); }); + benchmark("rmsnorm cpu", 32, { rmsnorm_cpu(r_cpu, a, b, 262144); }); + benchmark("mha cpu", 32, { mha_cpu(r_cpu, att, qkv, 1024, 8, 4096); }); +} +static void print_matrix(float *A, const int64_t A_rows, const int64_t A_cols) { + + int64_t i, j; + printf("["); + for (i = 0; i < A_rows; ++i) { + for (j = 0; j < A_cols; ++j) { + printf("%f, ", A[i * A_cols + j]); + } + printf("\n"); + } + printf("]\n"); +} +int main(int argc, char *argv[]) { + if (getenv("BENCHMARK")) { + run_benchmarks(); + return 0; + } + + { + size_t seq_len = 512; + size_t dim_size = 512; + float *sample_input = (float *)malloc(seq_len * dim_size * sizeof(float)); + float *sample_weights = (float *)malloc(dim_size * sizeof(float)); + float *sample_weights_large = (float *)malloc(seq_len * dim_size * sizeof(float)); + unsigned long long rng = time_in_ms(); + for (size_t i = 0; i < seq_len * dim_size; i++) { + sample_input[i] = random_f32(&rng); + } + for (size_t i = 0; i < dim_size; i++) { + sample_weights[i] = random_f32(&rng); + } + float *out1 = (float *)malloc(dim_size * sizeof(float)); + float *out2 = (float *)malloc(seq_len * dim_size * sizeof(float)); + printf("Comparing rmsnorm upmem (un)batched...\n"); + printf("Running rmsnorm upmem...\n"); + rmsnorm_upmem(out1, sample_input, sample_weights, dim_size); + printf("Running rmsnorm upmem large...\n"); + rmsnorm_upmem_large(out2, sample_input, sample_weights_large, dim_size, seq_len); + printf("Running rmsnorm upmem batched...\n"); + rmsnorm_upmem_batched(out2, sample_input, sample_weights, dim_size, seq_len); + printf("Comparing...\n"); + for (size_t i = 0; i < dim_size; i++) { + if (fabsf(out1[i] - out2[i]) > 0.00001f) { + printf("[ERROR] %f != %f\n", out1[i], out2[i]); + break; + } + } + free(sample_input); + free(sample_weights); + free(out1); + free(out2); + } + + // default parameters + char *checkpoint_path = NULL; // e.g. out/model.bin + const char *mode = "basecall"; + + // poor man's C argparse so we can override the defaults above from the + // command line + if (argc >= 2) { + checkpoint_path = argv[1]; + } else { + error_usage(); + } + + // build the Transformer via the model .bin file + Transformer transformer; + build_transformer(&transformer, checkpoint_path); + + float *sample_input = (float *)malloc(transformer.config.seq_len * transformer.config.dim * sizeof(float)); + unsigned long long rng = time_in_ms(); + for (size_t i = 0; i < transformer.config.seq_len * transformer.config.dim; i++) { + sample_input[i] = random_f32(&rng); + } + // run! + if (strcmp(mode, "basecall") == 0) { + printf("Running forward pass. 
x[0] = %f\n", sample_input[0]); + forward2(&transformer, sample_input); + } else { + fprintf(stderr, "unknown mode: %s\n", mode); + error_usage(); + } + + // memory and file handles cleanup + free_transformer(&transformer); + return 0; +} +#endif diff --git a/cinnamon/samples/dorado/dorado.mlir b/cinnamon/samples/dorado/dorado.mlir new file mode 100644 index 0000000..9c98faf --- /dev/null +++ b/cinnamon/samples/dorado/dorado.mlir @@ -0,0 +1,345 @@ +// dim: 512 +// hidden_dim: 1536 +// kv_dim: 512 +// kv_mul: 1 +// n_blocks: 18 (layers) +// n_heads: 8 +// head_size: 64 +// vocab_size: 32000 > should be deleted? +// seq_len T: 512 + +//func.func @forward(%token : index, %pos : index, +// // state +// %kc : memref<18x512x512xf32>, // blocks, T, C +// %vc : memref<18x512x512xf32>, // blocks, T, C +// // weights +// %embedding_table : tensor<32000x512xf32>, +// %rms_att_weights : tensor<18x512xf32>, +// %wq : tensor<18x512x512xf32>, +// %wk : tensor<18x512x512xf32>, +// %wv : tensor<18x512x512xf32>, +// %wo : tensor<18x512x512xf32>, +// %w1 : tensor<18x1536x512xf32>, +// %w2 : tensor<18x512x1536xf32>, +// %w3 : tensor<18x1536x512xf32>, +// %rms_ffn_weights : tensor<18x512xf32>, +// %rms_final_weight : tensor<512xf32>, +// %wcls : tensor<32000x512xf32> +//) -> tensor<32000xf32> { +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c2 = arith.constant 2 : index +// %n_blocks = arith.constant 18 : index +// %head_size = arith.constant 64 : index +// %head_size_f32 = arith.constant 64.0 : f32 +// %dim = arith.constant 512 : index +// %len = arith.constant 512 : index +// %hidden_dim = arith.constant 1536 : index +// +// // Required by language limitations +// %c0f = arith.constant 0.0 : f32 +// %c1f = arith.constant 1.0 : f32 +// %c10000f = arith.constant 10000.0 : f32 +// +// %content_row = tensor.extract_slice %embedding_table [%token, 0] [1, 512] [1, 1] : tensor<32000x512xf32> to tensor<512xf32> +// +// // TODO: Need convs +// +// %x = scf.for %layer = %c0 to %n_blocks step %c1 iter_args(%x = %content_row) -> (tensor<512xf32>) { +// %rms_att_weight = tensor.extract_slice %rms_att_weights [%layer, 0] [1, 512] [1, 1] : tensor<18x512xf32> to tensor<512xf32> +// %xb = func.call @rmsnorm(%x, %rms_att_weight) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> +// +// // qkv matmuls +// %wqs = tensor.extract_slice %wq [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> +// %wks = tensor.extract_slice %wk [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> +// %wvs = tensor.extract_slice %wv [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> +// %q, %k, %v = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32>, tensor<512xf32>, tensor<512xf32> { +// %q = cinm.op.gemv %wqs, %xb : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> +// %k = cinm.op.gemv %wks, %xb : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> +// %v = cinm.op.gemv %wvs, %xb : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> +// cinm.yield %q, %k, %v : tensor<512xf32>, tensor<512xf32>, tensor<512xf32> +// } +// +// // RoPE relative positional encoding: complex-valued rotate q and k in each head +// %posi = arith.index_cast %pos : index to i64 +// %posf = arith.uitofp %posi : i64 to f32 +// %q2, %k2 = scf.for %i = %c0 to %dim step %c2 iter_args(%qi = %q, %ki = %k) -> (tensor<512xf32>, tensor<512xf32>) { +// %head_dim = arith.remui %i, %head_size : index +// 
%head_dimi = arith.index_cast %head_dim : index to i64 +// %head_dimf = arith.uitofp %head_dimi : i64 to f32 +// %0 = arith.divf %head_dimf, %head_size_f32 : f32 +// %1 = math.powf %c10000f, %0 : f32 +// %freq = arith.divf %c1f, %1 : f32 +// %val = arith.mulf %posf, %freq : f32 +// %fcr = math.cos %val : f32 +// %fci = math.sin %val : f32 +// +// %qr = func.call @rot(%qi, %i, %fcr, %fci) : (tensor<512xf32>, index, f32, f32) -> tensor<512xf32> +// +// %cond = arith.cmpi ult, %i, %dim : index +// %kr = scf.if %cond -> (tensor<512xf32>) { +// %kr = func.call @rot(%ki, %i, %fcr, %fci) : (tensor<512xf32>, index, f32, f32) -> tensor<512xf32> +// scf.yield %kr : tensor<512xf32> +// } else { +// scf.yield %ki : tensor<512xf32> +// } +// +// scf.yield %qr, %kr : tensor<512xf32>, tensor<512xf32> +// } +// +// %kmr = bufferization.to_memref %k2 : memref<512xf32> +// %vmr = bufferization.to_memref %v : memref<512xf32> +// +// %kcd = memref.subview %kc [%layer, %pos, 0] [1, 1, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512xf32, strided<[1], offset: ?>> // blocks, T, C +// %vcd = memref.subview %vc [%layer, %pos, 0] [1, 1, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512xf32, strided<[1], offset: ?>> // blocks, T, C +// +// memref.copy %vmr, %vcd : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> +// memref.copy %kmr, %kcd : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> +// +// // multi head attention +// %lkc = memref.subview %kc [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512x512xf32, strided<[512, 1], offset: ?>> // blocks, T, C +// %lkc2 = bufferization.to_tensor %lkc : memref<512x512xf32, strided<[512, 1], offset: ?>> +// %lvc = memref.subview %vc [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : memref<18x512x512xf32> to memref<512x512xf32, strided<[512, 1], offset: ?>> // blocks, T, C +// %lvc2 = bufferization.to_tensor %lvc : memref<512x512xf32, strided<[512, 1], offset: ?>> +// %xb2 = func.call @mha(%q2, %lkc2, %lvc2, %pos) : (tensor<512xf32>, tensor<512x512xf32>, tensor<512x512xf32>, index) -> tensor<512xf32> // T, C +// +// %wo_slice = tensor.extract_slice %wo [%layer, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<18x512x512xf32> to tensor<512x512xf32> +// %xb4 = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> { +// // final matmul to get the output of the attention +// %xb3 = cinm.op.gemv %wo_slice, %xb2 : (tensor<512x512xf32>, tensor<512xf32>) -> tensor<512xf32> +// +// // residual connection back into x +// %xb4 = cinm.op.add %x, %xb3 : tensor<512xf32> +// cinm.yield %xb4 : tensor<512xf32> +// } +// +// // ffn rmsnorm +// %rms_ffn_weight = tensor.extract_slice %rms_ffn_weights [%layer, 0] [1, 512] [1, 1] : tensor<18x512xf32> to tensor<512xf32> +// %xb5 = func.call @rmsnorm(%xb4, %rms_ffn_weight) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> +// +// // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) +// // first calculate self.w1(x) and self.w3(x) +// %w1_slice = tensor.extract_slice %w1 [%layer, 0, 0] [1, 1536, 512] [1, 1, 1] : tensor<18x1536x512xf32> to tensor<1536x512xf32> +// %w3_slice = tensor.extract_slice %w3 [%layer, 0, 0] [1, 1536, 512] [1, 1, 1] : tensor<18x1536x512xf32> to tensor<1536x512xf32> +// %hb1, %hb2 = cinm.compute attributes { workgroupShape = array } -> tensor<1536xf32>, tensor<1536xf32> { +// %hb1 = cinm.op.gemv %w1_slice, %xb5 : (tensor<1536x512xf32>, tensor<512xf32>) -> tensor<1536xf32> +// %hb2 = cinm.op.gemv %w3_slice, %xb5 : (tensor<1536x512xf32>, 
tensor<512xf32>) -> tensor<1536xf32>
+// cinm.yield %hb1, %hb2 : tensor<1536xf32>, tensor<1536xf32>
+// }
+//
+// // SwiGLU non-linearity: silu(w1(x)) * w3(x), where silu(a) = a / (1 + exp(-a))
+// %hb3 = scf.for %i = %c0 to %hidden_dim step %c1 iter_args(%hb = %hb1) -> (tensor<1536xf32>) {
+// %0 = tensor.extract %hb [%i] : tensor<1536xf32>
+// %1 = tensor.extract %hb2 [%i] : tensor<1536xf32>
+// %2 = arith.negf %0 : f32
+// %3 = math.exp %2 : f32
+// %4 = arith.addf %c1f, %3 : f32
+// %5 = arith.divf %0, %4 : f32
+// %6 = arith.mulf %1, %5 : f32
+// %hbr = tensor.insert %6 into %hb [%i] : tensor<1536xf32>
+// scf.yield %hbr : tensor<1536xf32>
+// }
+//
+// %w2_slice = tensor.extract_slice %w2 [%layer, 0, 0] [1, 512, 1536] [1, 1, 1] : tensor<18x512x1536xf32> to tensor<512x1536xf32>
+// %xb7 = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> {
+// // final matmul to get the output of the ffn
+// %xb6 = cinm.op.gemv %w2_slice, %hb3 : (tensor<512x1536xf32>, tensor<1536xf32>) -> tensor<512xf32>
+//
+// // residual connection
+// %xb7 = cinm.op.add %x, %xb6 : tensor<512xf32>
+// cinm.yield %xb7 : tensor<512xf32>
+// }
+//
+// scf.yield %xb7 : tensor<512xf32>
+// }
+//
+// // TODO: need transformer decoder (linear upsampling layer)
+// // TODO: need CRF layer (linear layer)
+//
+// %x2 = func.call @rmsnorm(%x, %rms_final_weight) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
+// %logits = cinm.compute attributes { workgroupShape = array } -> tensor<32000xf32> {
+// // pad the vocab dimension from 32000 up to 32768 for the gemv below
+// %wcls2 = tensor.pad %wcls low[0,0] high[768,0] {
+// ^bb0(%arg1: index, %arg2: index):
+// tensor.yield %c0f : f32
+// } : tensor<32000x512xf32> to tensor<32768x512xf32>
+// %logits = cinm.op.gemv %wcls2, %x2 : (tensor<32768x512xf32>, tensor<512xf32>) -> tensor<32768xf32>
+// %logits2 = tensor.extract_slice %logits [0] [32000] [1] : tensor<32768xf32> to tensor<32000xf32>
+// cinm.yield %logits2 : tensor<32000xf32>
+// }
+//
+// return %logits : tensor<32000xf32>
+//}
+
+// Rotates the pair (v[i], v[i+1]) by the angle whose cosine/sine are %fcr/%fci.
+func.func @rot(%v: tensor<512xf32>, %i: index, %fcr : f32, %fci : f32) -> tensor<512xf32> {
+  %c1 = arith.constant 1 : index
+  %i2 = arith.addi %i, %c1 : index
+  %v0 = tensor.extract %v [%i] : tensor<512xf32>
+  %v1 = tensor.extract %v [%i2] : tensor<512xf32>
+  %0 = arith.mulf %v0, %fcr : f32
+  %1 = arith.mulf %v1, %fci : f32
+  %2 = arith.subf %0, %1 : f32
+  %r0 = tensor.insert %2 into %v[%i] : tensor<512xf32>
+  %3 = arith.mulf %v0, %fci : f32
+  %4 = arith.mulf %v1, %fcr : f32
+  %5 = arith.addf %3, %4 : f32
+  %r1 = tensor.insert %5 into %r0[%i2] : tensor<512xf32>
+  return %r1 : tensor<512xf32>
+}
+
+
+// Q: features, KC: sequence length x features, VC: sequence length x features
+func.func @mha(%q: tensor<512xf32>, %kc: tensor<512x512xf32>, %vc: tensor<512x512xf32>, %pos: index) -> tensor<512xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %nheads = arith.constant 8 : index
+  %head_dim = arith.constant 64 : index
+  %c0f = arith.constant 0.0 : f32
+  %scale = arith.constant 8.0 : f32 // sqrt(head_dim)
+  %ninf = arith.constant 0xFF800000 : f32 // -inf, so positions beyond %pos vanish in the softmax
+
+  %pos2 = arith.addi %pos, %c1 : index
+
+  %attn_init = tensor.generate {
+  ^bb0(%arg1: index):
+    tensor.yield %ninf : f32
+  } : tensor<512xf32>
+
+  %xb_init = tensor.empty() : tensor<512xf32>
+  %xb = scf.for %head = %c0 to %nheads step %c1 iter_args(%xbi = %xb_init) -> (tensor<512xf32>) {
+    %hoff = arith.muli %head, %head_dim : index
+
+    %attn = scf.for %i = %c0 to %pos2 step %c1 iter_args(%attn_i = %attn_init) -> (tensor<512xf32>) {
+      %qs = tensor.extract_slice %q [%hoff] [64] [1] : tensor<512xf32> to tensor<64xf32>
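+      // Scaled dot-product attention: score_i = (q_head · k_i) / sqrt(head_size),
+      // with sqrt(64) = 8.0 held in %scale. The mul plus the (currently
+      // commented-out) reduce below compute the dot product q_head · k_i.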
+      %k = tensor.extract_slice %kc [%i, %hoff] [1, 64] [1, 1] : tensor<512x512xf32> to tensor<64xf32>
+      %score = cinm.compute attributes { workgroupShape = array } -> f32 {
+        %0 = cinm.op.mul %qs, %k : tensor<64xf32>
+        //%1 = cinm.op.reduce add (%0) { dimensions = array } : tensor<64xf32> -> f32
+        %1 = arith.constant 1.0 : f32 // placeholder for the commented-out reduce above
+        %2 = arith.divf %1, %scale : f32
+        cinm.yield %2 : f32
+      }
+      %attn_i2 = tensor.insert %score into %attn_i [%i] : tensor<512xf32>
+      scf.yield %attn_i2 : tensor<512xf32>
+    }
+
+    %attn3 = func.call @softmax(%attn) : (tensor<512xf32>) -> tensor<512xf32>
+
+    %xb_slice_init = tensor.generate {
+    ^bb0(%arg1: index):
+      tensor.yield %c0f : f32
+    } : tensor<64xf32>
+
+    // accumulate the attention-weighted values: xb_head = sum_i attn[i] * v_i
+    %xb_slice = scf.for %i = %c0 to %pos2 step %c1 iter_args(%xb_slice_i = %xb_slice_init) -> (tensor<64xf32>) {
+      %v = tensor.extract_slice %vc [%i, %hoff] [1, 64] [1, 1] : tensor<512x512xf32> to tensor<64xf32>
+      %a = tensor.extract %attn3 [%i] : tensor<512xf32>
+      %xb_slice = cinm.compute attributes { workgroupShape = array } -> tensor<64xf32> {
+        %0 = cinm.op.muls %v, %a : tensor<64xf32>
+        %1 = cinm.op.add %xb_slice_i, %0 : tensor<64xf32>
+        cinm.yield %1 : tensor<64xf32>
+      }
+      scf.yield %xb_slice : tensor<64xf32>
+    }
+
+    %xbr = tensor.insert_slice %xb_slice into %xbi [%hoff] [64] [1] : tensor<64xf32> into tensor<512xf32>
+    scf.yield %xbr : tensor<512xf32>
+  }
+
+  return %xb : tensor<512xf32>
+}
+
+func.func @rmsnorm(%v : tensor<512xf32>, %w : tensor<512xf32>) -> tensor<512xf32> {
+  %epsilon = arith.constant 1.0e-5 : f32
+  %c1 = arith.constant 1.0 : f32
+  %len = arith.constant 512.0 : f32
+
+  // x * w / sqrt(mean(x^2) + epsilon)
+  %r = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> {
+    %0 = cinm.op.mul %v, %v : tensor<512xf32>
+    %ss = cinm.op.reduce add (%0) { dimensions = array } : tensor<512xf32> -> f32
+    %s0 = arith.divf %ss, %len : f32
+    %s1 = arith.addf %s0, %epsilon : f32
+    %s = math.rsqrt %s1 : f32
+    %x = cinm.op.muls %v, %s : tensor<512xf32>
+    %r = cinm.op.mul %x, %w : tensor<512xf32>
+    cinm.yield %r : tensor<512xf32>
+  }
+  return %r : tensor<512xf32>
+}
+
+func.func @rmsnorm_large(%v : tensor<262144xf32>, %w : tensor<262144xf32>) -> tensor<262144xf32> {
+  %epsilon = arith.constant 1.0e-5 : f32
+  %c1 = arith.constant 1.0 : f32
+  %len = arith.constant 262144.0 : f32 // mean is over all 262144 elements
+
+  %r = cinm.compute attributes { workgroupShape = array } -> tensor<262144xf32> {
+    %0 = cinm.op.mul %v, %v : tensor<262144xf32>
+    //%ss = cinm.op.reduce add (%0) { dimensions = array } : tensor<262144xf32> -> f32
+    %ss = arith.constant 1.0 : f32 // placeholder for the commented-out reduce above
+    %s0 = arith.divf %ss, %len : f32
+    %s1 = arith.addf %s0, %epsilon : f32
+    %s = math.rsqrt %s1 : f32
+    %x = cinm.op.muls %v, %s : tensor<262144xf32>
+    %r = cinm.op.mul %x, %w : tensor<262144xf32>
+    cinm.yield %r : tensor<262144xf32>
+  }
+  return %r : tensor<262144xf32>
+}
+
+// input: (batch size/seq len, input size)
+func.func @rmsnorm_batched(%v : tensor<512x512xf32>, %w : tensor<512xf32>) -> tensor<512x512xf32> {
+  %epsilon = arith.constant 1.0e-5 : f32
+  %c1 = arith.constant 1.0 : f32
+  %len = arith.constant 512.0 : f32
+
+  %r = cinm.compute attributes { workgroupShape = array } -> tensor<512x512xf32> {
+    %0 = cinm.op.mul %v, %v : tensor<512x512xf32>
+    //%ss0 = cinm.op.reduce add (%0) { dimensions = array } : tensor<512x512xf32> -> tensor<512xf32>
+    %ss0 = tensor.splat %c1 : tensor<512xf32> // placeholder for the commented-out reduce above
+    %ss1 = cinm.op.divs %ss0, %len : tensor<512xf32>
+    %ss2 = cinm.op.adds %ss1, %epsilon : tensor<512xf32>
+    %s = cinm.op.element_wise rsqrt (%ss2) : tensor<512xf32>
+
+    %shape = tensor.empty() : tensor<512x512xf32>
+    %s_broadcasted = linalg.broadcast ins(%s : tensor<512xf32>) outs(%shape : tensor<512x512xf32>) dimensions = [1]
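+    // %s holds one scale per input row; broadcasting along dim 1 expands it to
+    // tensor<512x512xf32> so the scaling below is a plain elementwise multiply.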
+    %x = cinm.op.mul %v, %s_broadcasted : tensor<512x512xf32>
+
+    %w_broadcasted = linalg.broadcast ins(%w : tensor<512xf32>) outs(%shape : tensor<512x512xf32>) dimensions = [0]
+    %r = cinm.op.mul %x, %w_broadcasted : tensor<512x512xf32>
+
+    cinm.yield %r : tensor<512x512xf32>
+  }
+  return %r : tensor<512x512xf32>
+}
+
+func.func @softmax(%vec : tensor<512xf32>) -> tensor<512xf32> {
+  // numerically stable softmax: exp(x - max(x)) / sum(exp(x - max(x)))
+  %r = cinm.compute attributes { workgroupShape = array } -> tensor<512xf32> {
+    %max = cinm.op.reduce max (%vec) { dimensions = array } : tensor<512xf32> -> f32
+    %t = cinm.op.subs %vec, %max : tensor<512xf32>
+    %shape = tensor.empty() : tensor<512xf32>
+    %e = linalg.exp ins(%t : tensor<512xf32>) outs(%shape : tensor<512xf32>) -> tensor<512xf32>
+    %s = cinm.op.reduce add (%e) { dimensions = array } : tensor<512xf32> -> f32
+    %r = cinm.op.divs %e, %s : tensor<512xf32>
+    cinm.yield %r : tensor<512xf32>
+  }
+
+  return %r : tensor<512xf32>
+}
+
+func.func @softmax_batched(%x: tensor<512x512xf32>) -> tensor<512x512xf32> {
+  %r = cinm.compute attributes { workgroupShape = array } -> tensor<512x512xf32> {
+    %shape = tensor.empty() : tensor<512x512xf32>
+    // FIXME: do max only along feature dim, not seq_len dim
+    %maxs = cinm.op.reduce max (%x) { dimensions = array } : tensor<512x512xf32> -> tensor<512xf32>
+    %maxs_bc = linalg.broadcast ins(%maxs : tensor<512xf32>) outs(%shape : tensor<512x512xf32>) dimensions = [1]
+    %tmp = cinm.op.sub %x, %maxs_bc : tensor<512x512xf32>
+    %e = linalg.exp ins(%tmp : tensor<512x512xf32>) outs(%shape : tensor<512x512xf32>) -> tensor<512x512xf32>
+    // Ideal case: reduce along the batch size dimension only. Currently not possible.
+    %summed = cinm.op.reduce add (%e) { dimensions = array } : tensor<512x512xf32> -> tensor<512xf32>
+    %summed_bc = linalg.broadcast ins(%summed : tensor<512xf32>) outs(%shape : tensor<512x512xf32>) dimensions = [1]
+    %res = cinm.op.div %e, %summed_bc : tensor<512x512xf32>
+    cinm.yield %res : tensor<512x512xf32>
+  }
+
+  return %r : tensor<512x512xf32>
+}
diff --git a/cinnamon/testbench/lib/dpu/expf.c b/cinnamon/testbench/lib/dpu/expf.c
index 9be18c7..b652248 100644
--- a/cinnamon/testbench/lib/dpu/expf.c
+++ b/cinnamon/testbench/lib/dpu/expf.c
@@ -66,3 +66,19 @@ float expf(float a) {
   r = s * s;
   return r;
 }
+
+// Fast inverse square root: bit-level initial guess via the magic constant,
+// refined by two Newton-Raphson steps (y *= 1.5 - 0.5 * x * y * y).
+float rsqrt(float number)
+{
+  union {
+    float f;
+    unsigned int i;
+  } conv = { .f = number };
+  conv.i = 0x5f37599e - (conv.i >> 1);
+  conv.f *= 1.5F - (number * 0.5F * conv.f * conv.f);
+  conv.f *= 1.5F - (number * 0.5F * conv.f * conv.f);
+  return conv.f;
+}
+
+float absf(float a) {
+  return a < 0 ? -a : a;
+}
diff --git a/justfile b/justfile
index 12a77d7..516ae97 100644
--- a/justfile
+++ b/justfile
@@ -12,7 +12,7 @@ build_type := env_var_or_default("LLVM_BUILD_TYPE", "RelWithDebInfo")
 linker := env_var_or_default("CMAKE_LINKER_TYPE", "DEFAULT")
 upmem_dir := env_var_or_default("UPMEM_HOME", "")
 build_dir := "cinnamon/build"
-python37_dir := env_var("PYTHON_37_DIR")
+python310_dir := env_var("PYTHON_310_DIR")
 
 # Do a full build as if in CI. Only needed the first time you build the project.
# Parameters: no-upmem enable-gpu enable-cuda enable-roc no-torch-mlir no-python-venv @@ -75,7 +75,6 @@ cnm-to-upmem FILE *ARGS: ( cinm-opt FILE "--convert-cnm-to-upmem" "--cse" - "--convert-math-to-llvm" "--upmem-outline-kernel" "--upmem-dedup-kernels" "--cse" @@ -85,6 +84,7 @@ cnm-to-upmem FILE *ARGS: ( upmem-to-llvm FILE *ARGS: ( cinm-opt FILE "--mlir-print-debuginfo" + "--convert-math-to-llvm" "--convert-scf-to-cf" "--convert-cf-to-llvm" "--fold-memref-alias-ops" @@ -119,19 +119,19 @@ translate-upmem-kernel-to-cpp FILE *ARGS: ( ) compile-upmem-kernels FILE OUTDIR: - bash "testbench/lib/compile_dpu.sh" {{FILE}} {{OUTDIR}} + bash "cinnamon/testbench/lib/compile_dpu.sh" {{FILE}} {{OUTDIR}} compile-upmem-runner *ARGS: - clang++ -g -c {{ARGS}} + /usr/bin/clang++ -g -c {{ARGS}} link-upmem-runner *ARGS: - clang++ -g {{ARGS}} -lUpmemDialectRuntime -fPIE -ldpu -ldpuverbose -L{{upmem_dir}}/lib -L{{build_dir}}/lib -I{{upmem_dir}}/include/dpu -rpath {{python37_dir}} + /usr/bin/clang++ -g {{ARGS}} -lUpmemDialectRuntime -fPIE -ldpu -ldpuverbose -L{{upmem_dir}}/lib -L{{build_dir}}/lib -I{{upmem_dir}}/include/dpu -rpath {{python310_dir}} remove-memref-alignment FILE: sed -i 's/{alignment = 64 : i64} //' {{FILE}} build-transformer: \ - (cinm-to-cnm "cinnamon/samples/asdf.mlir" "-o" "./transformer.cnm.mlir") \ + (cinm-to-cnm "cinnamon/samples/transformer.mlir" "-o" "./transformer.cnm.mlir") \ (build-transformer-from-cnm "./transformer.cnm.mlir") build-transformer-from-cnm FILE: \ @@ -142,9 +142,21 @@ build-transformer-from-cnm FILE: \ (translate-upmem-kernel-to-cpp "./transformer.upmem.mlir" "-o" "./transformer.upmem.c") \ (compile-upmem-kernels "./transformer.upmem.c" "cinnamon/build/samples") \ (compile-upmem-runner "./transformer.ll" "-o" "cinnamon/build/samples/transformer.o") \ - (compile-upmem-runner "./llama2.cpp" "-o" "cinnamon/build/samples/llama2.o") \ + (compile-upmem-runner "cinnamon/samples/llama2.cpp" "-o" "cinnamon/build/samples/llama2.o") \ (link-upmem-runner "cinnamon/build/samples/transformer.o" "cinnamon/build/samples/llama2.o" "-o" "cinnamon/build/samples/transformer") +build-dorado: \ + (cinm-to-cnm "cinnamon/samples/dorado/dorado.mlir" "-o" "./dorado.cnm.mlir") \ + (cnm-to-upmem "./dorado.cnm.mlir" "-o" "./dorado.upmem.mlir") \ + (remove-memref-alignment "./dorado.upmem.mlir") \ + (upmem-to-llvm "./dorado.upmem.mlir" "-o" "./dorado.llvm.mlir") \ + (translate-mlir-to-llvmir "./dorado.llvm.mlir" "-o" "./dorado.ll") \ + (translate-upmem-kernel-to-cpp "./dorado.upmem.mlir" "-o" "./dorado.upmem.c") \ + (compile-upmem-kernels "./dorado.upmem.c" "cinnamon/build/samples") \ + (compile-upmem-runner "./dorado.ll" "-o" "cinnamon/build/samples/dorado.o") \ + (compile-upmem-runner "cinnamon/samples/dorado/dorado.cpp" "-o" "cinnamon/build/samples/dorado_host.o" "-fopenmp") \ + (link-upmem-runner "cinnamon/build/samples/dorado.o" "cinnamon/build/samples/dorado_host.o" "-o" "cinnamon/build/samples/dorado" "-fopenmp") + cinm-vulkan-runner FILE *ARGS: {{build_dir}}/bin/cinm-vulkan-runner {{FILE}} \ --shared-libs={{llvm_prefix}}/lib/libvulkan-runtime-wrappers.so,{{llvm_prefix}}/lib/libmlir_runner_utils.so \