[Mosaic:TPU][Relayout] Row shifts for packed types and non-native tilings #26754

Open · wants to merge 1 commit into base: main
192 changes: 140 additions & 52 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -52,6 +52,7 @@
#include "absl/types/span.h"
#include "llvm/include/llvm/ADT/APInt.h"
#include "llvm/include/llvm/Support/LogicalResult.h"
#include "llvm/include/llvm/Support/raw_ostream.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
@@ -5496,6 +5497,142 @@ void rotateLanes(OpBuilder &builder, xla::Array<Value> &vregs,
rotateVregs(builder, vregs, amount, 1);
}

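// Relayout src_vregs from src_layout to a layout that is identical except
// that the second-minor (row) offset is dst_row_offset. Supports packed
// types and non-native tilings by combining sublane rotations with
// in-sublane bit shifts.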
FailureOr<xla::Array<Value>> doRowShiftRelayout(
OpBuilder &builder, const Location loc, const ArrayRef<int64_t> shape,
xla::Array<Value> src_vregs, const VectorLayout &src_layout,
const int64_t dst_row_offset, const std::array<int64_t, 2> target_shape) {
constexpr int32_t kNativeBitwidth = 32;
const std::array<int64_t, 2> tiled_ishape =
src_layout.getImplicitTiledDims(shape, 1);
const std::array<int64_t, 2> tiling = src_layout.tiling();
const int64_t sublanes_per_tile = src_layout.sublanesPerTile(target_shape);
const int64_t tiles_per_vreg = src_layout.tilesPerVreg(target_shape);
const LayoutOffsets &src_offsets = src_layout.offsets();
CHECK(src_offsets[0].has_value());
CHECK_EQ(tiling[0] % sublanes_per_tile, 0);
const int64_t rows_per_sublane = tiling[0] / sublanes_per_tile;
const int64_t bits_per_row = kNativeBitwidth / rows_per_sublane;
const int64_t row_shift_amount =
dst_row_offset - *src_offsets[0] +
(dst_row_offset > *src_offsets[0] ? 0
: target_shape[0] * rows_per_sublane);
// How many whole sublanes to shift the original low bits:
const int64_t shift_sublanes = row_shift_amount / rows_per_sublane;
const int64_t shift_in_sublane = row_shift_amount % rows_per_sublane;
const int32_t bitshift = shift_in_sublane * bits_per_row;
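// Worked example (hypothetical values): for a packed type with
// rows_per_sublane = 4 (so bits_per_row = 32 / 4 = 8), a row_shift_amount of
// 6 decomposes into shift_sublanes = 6 / 4 = 1 whole sublane plus
// shift_in_sublane = 6 % 4 = 2 rows, i.e. a bitshift of 2 * 8 = 16 bits.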
const VectorType vreg_ty = cast<VectorType>(src_vregs.begin()->getType());
const VectorType i32_vreg_ty =
VectorType::get(target_shape, builder.getI32Type());
if (shift_in_sublane == 0) {
rotateSublanes(builder, src_vregs, shift_sublanes);
} else {
auto left_bitshift_const = builder.create<arith::ConstantOp>(
loc, i32_vreg_ty, DenseElementsAttr::get(i32_vreg_ty, bitshift));
auto right_bitshift_const = builder.create<arith::ConstantOp>(
loc, i32_vreg_ty,
DenseElementsAttr::get(i32_vreg_ty, kNativeBitwidth - bitshift));
// Note 1: Below, after shifting, the rotate+blend can be done in two ways:
// 1. Rotate low and high bits to their final positions, then OR them
// 2. Rotate low bits by 1, OR them, and then rotate low and high together
// to their final positions.
// Before optimizations, the second way performs fewer rotates on older TPU
// gens, but has a longer critical path. The first way is chosen here, and we
// rely on later optimizations to convert it to the second for older gens.
// Note 2: When rotating by less than 1 sublane, the part that rolls over
// into the next vreg is only src_high_shifted, so we *could* use only the
// src_high_shifted part instead of the fully ORed vreg.
src_vregs.Each([&](absl::Span<const int64_t> /*idxs*/, Value *v) {
// Low bits in src, high bits in dst:
Value v_i32 = builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, *v);
Value src_low_shifted =
builder.create<arith::ShLIOp>(loc, v_i32, left_bitshift_const);
src_low_shifted = builder.create<tpu::RotateOp>(
loc, src_low_shifted, shift_sublanes, 0, nullptr, nullptr);
// It is important that the shifted-in bits are zero, so use ShRUI.
// High bits in src, low bits in dst:
Value src_high_shifted =
builder.create<arith::ShRUIOp>(loc, v_i32, right_bitshift_const);
src_high_shifted = builder.create<tpu::RotateOp>(
loc, src_high_shifted, shift_sublanes + 1, 0, nullptr, nullptr);
// TODO: The OR may be unnecessary when rotating by less than one sublane
// (see Note 2 above).
*v = builder.create<arith::OrIOp>(loc, src_low_shifted, src_high_shifted);
});
}
// We've shifted and rotated so that the original low part (final high part)
// of each tile is in the right place. If there is more than one tile per
// vreg, we also need the original high part (final low part) to be in the
// right place.
// TODO(tlongeri): We could avoid allocating an extra array when there is only
// one tile per vreg.
// Holds what becomes the low part of the result tiles (the original high
// part of the source tiles):
xla::Array<Value> final_low_part = src_vregs;
// This is a no-op when tiles_per_vreg is 1:
rotateSublanes(builder, final_low_part,
               (tiles_per_vreg - 1) * sublanes_per_tile);
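// E.g. (hypothetical values): with 2 tiles per vreg and 4 sublanes per tile,
// this rotates by 4 sublanes, moving each tile's data into the other tile's
// slot within the vreg.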

// The mask selects the first shift_in_tile rows (full, half, quarter,
// etc. sublanes, depending on packing) of each tile that contains data.
Value mask = nullptr;
const VectorType vmask_ty =
getNativeVregOrVmaskType(builder.getI1Type(), bits_per_row, target_shape);
const VectorType select_vreg_ty =
getNativeVregType(builder.getIntegerType(bits_per_row), target_shape);
const int64_t shift_in_tile = row_shift_amount % tiling[0];
for (int64_t i = 0; i < tiles_per_vreg; ++i) {
const int64_t tile_start = i * tiling[0];
// Skip tiles that contain no data
if (*src_layout.offsets()[0] < tile_start + tiling[0] &&
tile_start < *src_layout.offsets()[0] + tiled_ishape[0]) {
Value tile_mask = builder.create<tpu::CreateSubelementMaskOp>(
loc, vmask_ty, tile_start, tile_start + shift_in_tile);
if (mask == nullptr) {
mask = tile_mask;
} else {
mask = builder.create<arith::OrIOp>(loc, mask, tile_mask);
}
}
}
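// E.g. (hypothetical values): with tiling[0] = 8 and shift_in_tile = 3, each
// tile holding data contributes a mask over its first 3 rows; in the select
// below, those rows take the wrapped-around data (final_low_part) and the
// remaining rows take the shifted src_vregs.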

xla::Array<Value> res_vregs(
VectorLayout(src_layout.bitwidth(), {dst_row_offset, src_offsets[1]},
src_layout.tiling(), src_layout.implicit_dim())
.tileArrayImplicitShape(shape, target_shape));
int64_t res_low_idx_delta = -1;
int64_t res_high_idx_delta = 0;
if (dst_row_offset < *src_offsets[0]) {
++res_low_idx_delta;
++res_high_idx_delta;
}
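// E.g.: by default, result vreg i takes its low part from final_low_part at
// index i - 1 and its high part from src_vregs at index i; when the shift
// wraps (dst_row_offset < *src_offsets[0]), it instead takes them from
// indices i and i + 1.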
res_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
Value low, high;
// idxs for the result low part
SmallVector<int64_t> low_idxs(toArrayRef(idxs));
*(low_idxs.end() - 2) += res_low_idx_delta;
if (0 <= *(low_idxs.end() - 2)) {
low = final_low_part(low_idxs);
}
// idxs for the result high part
SmallVector<int64_t> high_idxs(toArrayRef(idxs));
*(high_idxs.end() - 2) += res_high_idx_delta;
if (*(high_idxs.end() - 2) < *(src_vregs.dimensions().end() - 2)) {
high = src_vregs(high_idxs);
}
if (low != nullptr && high != nullptr) {
low = builder.create<tpu::BitcastVregOp>(loc, select_vreg_ty, low);
high = builder.create<tpu::BitcastVregOp>(loc, select_vreg_ty, high);
*v = builder.create<arith::SelectOp>(loc, mask, low, high);
} else if (low != nullptr) {
*v = low;
} else {
DCHECK(high != nullptr);
*v = high;
}

*v = builder.create<tpu::BitcastVregOp>(loc, vreg_ty, *v);
});

return res_vregs;
}

// Relayout src_vregs from layout src to layout dst, where dst is the same as
// src except that the column offset is dst_col_offset.
FailureOr<xla::Array<Value>> doColumnShiftRelayout(
@@ -5753,8 +5890,6 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
const auto &target_shape = ctx.target_shape;
const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(),
src.implicit_dim());
const int packing = src.packing();
const int8_t bitwidth = src.bitwidth();

int row_diff;
if (!src.offsets()[0].has_value()) {
@@ -5780,56 +5915,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
}
const SmallVector<int64_t> implicit_shape =
src.implicitShape(vty.getShape());
if (implicit_shape[implicit_shape.size() - 2] != 1) {
// Multi row shift
// TODO(mvoz): This should take the vregs array, not the value.
FAILUREOR_ASSIGN_OR_RETURN(
vregs, tpu_rotate_with_overflow(
builder, target_shape, loc, vty, std::move(vregs),
/*dim*/ implicit_shape.size() - 2, src, dst_offsets));
} else {
// Single row case
// TODO(mvoz): The single row case has a broader set of supported
// operations: non-native tiling, packed types, implicit dim. We should
// support these cases in tpu_rotate_with_overflow and remove this
// branch.
const int64_t src_sublane = *src.offsets()[0] / packing;
const int64_t dst_sublane = *dst_offsets[0] / packing;
if (int64_t sublane_diff = dst_sublane - src_sublane) {
if (sublane_diff < 0) {
sublane_diff += target_shape[0];
}
rotateSublanes(builder, vregs, sublane_diff);
}
const int src_subelem = *src.offsets()[0] % packing;
const int dst_subelem = *dst.offsets()[0] % packing;
if (src_subelem != dst_subelem) {
const int subelem_diff = dst_subelem - src_subelem;
const int shift_bits = bitwidth * std::abs(subelem_diff);
VectorType bits_vreg_ty =
VectorType::get(target_shape, builder.getI32Type());
auto shift_vreg = builder.create<arith::ConstantOp>(
loc, bits_vreg_ty,
DenseElementsAttr::get(bits_vreg_ty, shift_bits));
vregs.Each([&](absl::Span<const int64_t> /*idx*/, Value *tile) {
auto bit_tile =
builder.create<tpu::BitcastVregOp>(loc, bits_vreg_ty, *tile);
Operation *shift_tile;
if (subelem_diff > 0) {
shift_tile =
builder.create<arith::ShLIOp>(loc, bit_tile, shift_vreg);
} else { // subelem_diff < 0
CHECK_LT(subelem_diff, 0);
shift_tile =
builder.create<arith::ShRUIOp>(loc, bit_tile, shift_vreg);
}
*tile = builder
.create<tpu::BitcastVregOp>(loc, tile->getType(),
shift_tile->getResult(0))
.getResult();
});
}
}
FAILUREOR_ASSIGN_OR_RETURN(
vregs, doRowShiftRelayout(builder, loc, vty.getShape(), vregs, src,
*dst_offsets[0], ctx.target_shape));
}

// Rows are now correctly aligned. Time to offset columns.