Skip to content

Commit

Permalink
Add elementwise unary operation
Browse files Browse the repository at this point in the history
  • Loading branch information
meesfrensel committed Jan 16, 2025
1 parent ce9f4f7 commit 61b5972
Show file tree
Hide file tree
Showing 12 changed files with 1,286 additions and 326 deletions.
18 changes: 17 additions & 1 deletion cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,23 @@ include "cinm-mlir/Dialect/Cinm/IR/CinmBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/CommonAttrConstraints.td"


// Enumerates the unary functions that can be selected as the `$method` of
// `cinm.op.element_wise`.
//
// The CinmToCnm conversion turns this enum into a `linalg::UnaryFn` with a
// plain `static_cast`, so the integer value of every case below must stay
// equal to the value of the matching `linalg::UnaryFn` case.
// NOTE(review): re-check that parity whenever the upstream linalg enum
// changes or a case is added/reordered here.
def Cinm_UnaryOp : I64EnumAttr<"Cinm_UnaryOp", "", [
I64EnumAttrCase<"exp", 0>,
I64EnumAttrCase<"log", 1>,
I64EnumAttrCase<"abs", 2>,
I64EnumAttrCase<"ceil", 3>,
I64EnumAttrCase<"floor", 4>,
I64EnumAttrCase<"negf", 5>,
I64EnumAttrCase<"reciprocal", 6>,
I64EnumAttrCase<"round", 7>,
I64EnumAttrCase<"sqrt", 8>,
I64EnumAttrCase<"rsqrt", 9>,
I64EnumAttrCase<"square", 10>,
I64EnumAttrCase<"tanh", 11>,
I64EnumAttrCase<"erf", 12>
]> {
// C++ namespace the generated enum and its helpers are emitted into.
let cppNamespace = "::mlir::cinm";
}

def Cinm_ScanMethodAttr : I64EnumAttr<
"ScanMethod", "",
Expand Down
20 changes: 19 additions & 1 deletion cinnamon/include/cinm-mlir/Dialect/Cinm/IR/CinmOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,28 @@ class Cinm_Bitwise_Op<string name, list<Trait> traits = []>
}


// Generic element-wise unary operation: applies the function selected by
// `$method` (a `Cinm_UnaryOp` case) to every element of the ranked tensor
// `$input` and yields a tensor `$result`.
//
// The op carries `SameOperandsAndResultType`, and the assembly format only
// spells out `type($input)`; the result type is not written in the textual
// form.
def Cinm_Elementwise_Unary_Op : Cinm_Op<"op.element_wise", [Pure, SameOperandsAndResultType]> {
  let summary = "Generic elementwise unary operation on a tensor";
  let description = [{
    Perform a unary operation on each element of a tensor. Example:
    ```
    %r = cinm.compute attributes { workgroupShape = array<i64: 1,1,16> } -> tensor<512xf32> {
      %sqrts = cinm.op.element_wise sqrt (%input) : tensor<512xf32>
      cinm.yield %sqrts : tensor<512xf32>
    }
    ```
  }];

  let arguments = (ins
    Cinm_UnaryOp:$method,
    AnyRankedTensor:$input
  );

  let results = (outs
    AnyRankedTensor:$result
  );

  let assemblyFormat = "$method `(` $input `)` attr-dict `:` type($input)";
}


// Concrete op definitions
Expand Down
5 changes: 5 additions & 0 deletions cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

#pragma once

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down
39 changes: 39 additions & 0 deletions cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <mlir/Dialect/Utils/StructuredOpsUtils.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/AffineMap.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributeInterfaces.h>
#include <mlir/IR/BuiltinAttributes.h>
Expand Down Expand Up @@ -572,6 +573,43 @@ struct ConvertElementWiseToCnm : public OpConversionPattern<CinmOp> {
}
};

/// Conversion pattern that lowers `cinm::Elementwise_Unary_Op` by delegating
/// to `convertCinmToCnm`, with a `linalg::ElemwiseUnaryOp` as the per-workgroup
/// body. On success the original op is replaced by the values collected from
/// the conversion helper.
struct ConvertElementWiseUnaryToCnm
    : OpConversionPattern<cinm::Elementwise_Unary_Op> {
  explicit ConvertElementWiseUnaryToCnm(MLIRContext *ctx)
      : OpConversionPattern(ctx) {
    // Mark this pattern's rewrite recursion as bounded for the driver.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult
  matchAndRewrite(cinm::Elementwise_Unary_Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

    // The workgroup type is taken from the compute block enclosing the op.
    cinm::ComputeOp enclosingCompute = getEnclosingComputeBlock(op);
    auto workgroup = builder.create<cnm::WorkgroupOp>(
        enclosingCompute.getCnmWorkgroupType());

    // Zero-valued constant of the result type, passed to the helper as the
    // initial output value.
    auto zeroInit = builder.create<arith::ConstantOp>(
        op.getResult().getType(),
        builder.getZeroAttr(op.getResult().getType()));

    // Body emitted by the helper: one linalg element-wise unary op over the
    // inputs/outputs it hands us.
    auto emitUnaryBody = [&](ImplicitLocOpBuilder &bodyBuilder,
                             ValueRange inputs, ValueRange outputs) {
      MLIRContext *context = bodyBuilder.getContext();
      // The cinm method enum is reinterpreted as linalg::UnaryFn by value.
      // NOTE(review): this relies on both enums using the same integer
      // values for each case -- confirm when either enum changes.
      const auto unaryFn = linalg::UnaryFnAttr::get(
          context, static_cast<linalg::UnaryFn>(op.getMethod()));
      const auto castKind =
          linalg::TypeFnAttr::get(context, linalg::TypeFn::cast_signed);
      bodyBuilder.create<linalg::ElemwiseUnaryOp>(
          TypeRange{}, ValueRange(inputs), ValueRange(outputs), unaryFn,
          castKind);
    };

    SmallVector<Value, 1> loweredResults;
    const auto status = convertCinmToCnm(
        builder, op, workgroup.getResult(), enclosingCompute, {},
        adaptor.getOperands(), ValueRange{zeroInit}, op->getResults(),
        loweredResults, emitUnaryBody);
    if (status.failed())
      return failure();

    rewriter.replaceOp(op, loweredResults);
    return success();
  }
};

LogicalResult computeScatterMapForGemm(cnm::BufferType bufferTyAB,
int64_t rowsA, int64_t colsB,
AffineMap &scatterA, AffineMap &scatterB,
Expand Down Expand Up @@ -894,6 +932,7 @@ void populateCinmRewritePatterns(RewritePatternSet &patterns,
arith::DivFOp, false>>(ctx);
patterns.insert<ConvertElementWiseToCnm<cinm::DivsOp, arith::DivSIOp,
arith::DivFOp, true>>(ctx);
patterns.insert<ConvertElementWiseUnaryToCnm>(ctx);
// matmul
patterns.insert<ConvertCinmGemmToCnm>(ctx);
patterns.insert<ConvertCinmGemvToCnm>(ctx);
Expand Down
4 changes: 4 additions & 0 deletions cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,12 @@ TilingResult2 tileElementWiseBinaryOp(OpBuilder &builder0, OP op,

// This is the max number of reductions we can theoretically do on
// a single CNM.launch.
// FIXME: reduceClusterSize computes too large a value for (3, 8192, f32):
// it returns 4096, but the correct size is 2048. Until that is fixed, the
// check below forces any result of 4096 down to 2048.
auto reduceClusterSize =
params.reduceClusterSize(3, numElements, tensorTy.getElementType());
if (reduceClusterSize == 4096) {
reduceClusterSize = 2048;
}
// We need the actual tile size to not exceed that number, and
// be able to divide the input by the working group size.
if (reduceClusterSize * wgSize >= numElements) {
Expand Down
15 changes: 8 additions & 7 deletions cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,14 @@
#include "cinm-mlir/Target/UPMEMCpp/UPMEMCppEmitter.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/InitAllDialects.h"

using namespace mlir;
using namespace mlir::upmem_emitc;
Expand All @@ -45,6 +40,12 @@ void mlir::upmem_emitc::registerUPMEMCppTranslation() {
},
[](DialectRegistry &registry) {
// clang-format off
registry.insert<arith::ArithDialect>();
registry.insert<func::FuncDialect>();
registry.insert<LLVM::LLVMDialect>();
registry.insert<math::MathDialect>();
registry.insert<memref::MemRefDialect>();
registry.insert<scf::SCFDialect>();
registry.insert<upmem::UPMEMDialect>();
// clang-format on
});
Expand Down
41 changes: 23 additions & 18 deletions cinnamon/lib/Target/UPMEMCpp/UPMEMTranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,22 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <llvm/ADT/SmallVector.h>
#include <llvm/IR/Constant.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/IR/Visitors.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/LogicalResult.h>
#include <string>
#include <utility>
Expand Down Expand Up @@ -389,7 +381,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (arith::ConstantOp staticSize =
dyn_cast<arith::ConstantOp>(size.getDefiningOp())) {
size_t remainingElements =
staticSize.getValueAttr().dyn_cast<IntegerAttr>().getInt();
dyn_cast<IntegerAttr>(staticSize.getValueAttr()).getInt();
size_t offset = 0;
while (remainingElements > 0) {
size_t chunkSize = std::min(2048lu / elementSize, remainingElements);
Expand Down Expand Up @@ -648,17 +640,28 @@ static LogicalResult printOperation(CppEmitter &emitter, arith::XOrIOp op) {
return printBinaryOperation(emitter, op.getOperation(), "^");
}

/// Emits a one-argument call expression `<mathOp>(<name of operand 0>)` for
/// `op`, preceded by whatever `emitAssignPrefix` emits for the op.
///
/// \param emitter Emitter whose output stream receives the text.
/// \param op      Operation to print; its operand 0 is used as the call
///                argument, so it must have at least one operand.
/// \param mathOp  Name of the function to call in the generated code.
/// \returns failure if emitting the assign prefix fails, success otherwise.
static LogicalResult printMathOperation(CppEmitter &emitter, Operation *op,
                                        StringRef mathOp) {
  if (emitter.emitAssignPrefix(*op).failed()) {
    return failure();
  }

  emitter.ostream() << mathOp << "("
                    << emitter.getOrCreateName(op->getOperand(0)) << ")";

  return success();
}

/// Prints an `LLVM::ExpOp` as a call to `expf` on its operand.
static LogicalResult printOperation(CppEmitter &emitter, LLVM::ExpOp op) {
  Operation *expOp = op.getOperation();
  return printMathOperation(emitter, expOp, "expf");
}

/// Prints a `math::AbsFOp` as a call to `absf` on its operand.
/// NOTE(review): `absf` is not a standard C library function (the standard
/// float absolute value is `fabsf`); presumably the generated code is built
/// against a helper that defines `absf` -- confirm.
static LogicalResult printOperation(CppEmitter &emitter, math::AbsFOp op) {
  Operation *absOp = op.getOperation();
  return printMathOperation(emitter, absOp, "absf");
}

/// Prints a `math::RsqrtOp` as a call to `rsqrt` on its operand.
/// NOTE(review): `rsqrt` is not part of the standard C math library;
/// presumably the generated code is built against a helper that defines it
/// -- confirm.
static LogicalResult printOperation(CppEmitter &emitter, math::RsqrtOp op) {
  Operation *rsqrtOp = op.getOperation();
  return printMathOperation(emitter, rsqrtOp, "rsqrt");
}

static LogicalResult printOperation(CppEmitter &emitter,
cf::BranchOp branchOp) {
raw_ostream &os = emitter.ostream();
Expand Down Expand Up @@ -1332,13 +1335,13 @@ LogicalResult CppEmitter::emitLabel(Block &block) {
return success();
}

LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (dyn_cast<arith::ConstantOp>(op)) {
LogicalResult CppEmitter::emitOperation(Operation &operation, bool trailingSemicolon) {
if (dyn_cast<arith::ConstantOp>(operation)) {
return success();
}

LogicalResult status =
llvm::TypeSwitch<Operation *, LogicalResult>(&op)
llvm::TypeSwitch<Operation *, LogicalResult>(&operation)
// Builtin ops.
.Case<upmem::UPMEMModuleOp>(
[&](auto op) { return printOperation(*this, op); })
Expand Down Expand Up @@ -1456,6 +1459,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<arith::XOrIOp>(
[&](auto op) { return printOperation(*this, op); })
.Case<LLVM::ExpOp>([&](auto op) { return printOperation(*this, op); })
.Case<math::AbsFOp>([&](auto op) { return printOperation(*this, op); })
.Case<math::RsqrtOp>([&](auto op) { return printOperation(*this, op); })
.Case<upmem::TaskletIDOp>(
[&](auto op) { return printOperation(*this, op); })
.Case<upmem::BaseMRAMAddrOp>(
Expand All @@ -1470,7 +1475,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
[&](auto op) { return printOperation(*this, op); })
// [&](auto op) { skipSemicolon = true; return success(); })
.Default([&](Operation *) {
return op.emitOpError("unable to find printer for op");
return operation.emitOpError("unable to find printer for op");
});

if (failed(status))
Expand Down
Loading

0 comments on commit 61b5972

Please sign in to comment.