diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 62fa622ae512..1c121f5fdac4 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -52,6 +52,7 @@
 #include "absl/types/span.h"
 #include "llvm/include/llvm/ADT/APInt.h"
 #include "llvm/include/llvm/Support/LogicalResult.h"
+#include "llvm/include/llvm/Support/raw_ostream.h"
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/include/mlir/Dialect/Math/IR/Math.h"
@@ -5496,6 +5497,142 @@ void rotateLanes(OpBuilder &builder, xla::Array<Value> &vregs,
   rotateVregs(builder, vregs, amount, 1);
 }
 
+// Relayout src_vregs from src_layout to a layout that is identical except for
+// its row (second-minor) offset, which becomes dst_row_offset.
+FailureOr<xla::Array<Value>> doRowShiftRelayout(
+    OpBuilder &builder, const Location loc, const ArrayRef<int64_t> shape,
+    xla::Array<Value> src_vregs, const VectorLayout &src_layout,
+    const int64_t dst_row_offset, const std::array<int64_t, 2> target_shape) {
+  constexpr int32_t kNativeBitwidth = 32;
+  const std::array<int64_t, 2> tiled_ishape =
+      src_layout.getImplicitTiledDims(shape, 1);
+  const std::array<int64_t, 2> tiling = src_layout.tiling();
+  const int64_t sublanes_per_tile = src_layout.sublanesPerTile(target_shape);
+  const int64_t tiles_per_vreg = src_layout.tilesPerVreg(target_shape);
+  const LayoutOffsets &src_offsets = src_layout.offsets();
+  CHECK(src_offsets[0].has_value());
+  CHECK_EQ(tiling[0] % sublanes_per_tile, 0);
+  const int64_t rows_per_sublane = tiling[0] / sublanes_per_tile;
+  const int64_t bits_per_row = kNativeBitwidth / rows_per_sublane;
+  const int64_t row_shift_amount =
+      dst_row_offset - *src_offsets[0] +
+      (dst_row_offset > *src_offsets[0] ? 0
+                                        : target_shape[0] * rows_per_sublane);
+  // How many whole sublanes to shift the original low bits:
+  const int64_t shift_sublanes = row_shift_amount / rows_per_sublane;
+  const int64_t shift_in_sublane = row_shift_amount % rows_per_sublane;
+  const int32_t bitshift = shift_in_sublane * bits_per_row;
+  const VectorType vreg_ty = cast<VectorType>(src_vregs.begin()->getType());
+  const VectorType i32_vreg_ty =
+      VectorType::get(target_shape, builder.getI32Type());
+  if (shift_in_sublane == 0) {
+    rotateSublanes(builder, src_vregs, shift_sublanes);
+  } else {
+    auto left_bitshift_const = builder.create<arith::ConstantOp>(
+        loc, i32_vreg_ty, DenseElementsAttr::get(i32_vreg_ty, bitshift));
+    auto right_bitshift_const = builder.create<arith::ConstantOp>(
+        loc, i32_vreg_ty,
+        DenseElementsAttr::get(i32_vreg_ty, kNativeBitwidth - bitshift));
+    // Note 1: Below, after shifting, the rotate+blend can be done in two ways:
+    // 1. Rotate low and high bits to their final positions, then OR them.
+    // 2. Rotate low bits by 1, OR them, and then rotate low and high together
+    //    to their final positions.
+    // Before optimizations, the second way performs fewer rotates on older TPU
+    // generations, but has a longer critical path. We pick the first way and
+    // rely on later optimizations to convert it into the second on older
+    // generations.
+    // Note 2: When rotating by less than 1 sublane, the part that rolls over
+    // into the next vreg is only src_high_shifted, so we *could* use only the
+    // src_high_shifted part instead of the fully ORed vreg.
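+    // For example, assuming 16-bit elements with (16, 128) tiling on an
+    // (8, 128) vreg: rows_per_sublane = 2 and bits_per_row = 16. Moving the
+    // row offset from 0 to 5 gives row_shift_amount = 5, so shift_sublanes =
+    // 2, shift_in_sublane = 1 and bitshift = 16. Each vreg is then bitcast to
+    // i32, the row held in the low 16 bits of every sublane is shifted into
+    // the high 16 bits and rotated by 2 sublanes, the row held in the high 16
+    // bits is shifted (with zero fill) into the low 16 bits and rotated by 3
+    // sublanes, and the two halves are ORed back together.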
+    src_vregs.Each([&](absl::Span<const int64_t> /*idxs*/, Value *v) {
+      // Low bits in src, high bits in dst:
+      Value v_i32 = builder.create<tpu::BitcastVregOp>(loc, i32_vreg_ty, *v);
+      Value src_low_shifted =
+          builder.create<arith::ShLIOp>(loc, v_i32, left_bitshift_const);
+      src_low_shifted = builder.create<tpu::RotateOp>(
+          loc, src_low_shifted, shift_sublanes, 0, nullptr, nullptr);
+      // It is important that the shifted-in bits are 0, so use ShRUI.
+      // High bits in src, low bits in dst:
+      Value src_high_shifted =
+          builder.create<arith::ShRUIOp>(loc, v_i32, right_bitshift_const);
+      src_high_shifted = builder.create<tpu::RotateOp>(
+          loc, src_high_shifted, shift_sublanes + 1, 0, nullptr, nullptr);
+      // DO NOT SUBMIT: We may not need to OR for few rows
+      *v = builder.create<arith::OrIOp>(loc, src_low_shifted,
+                                        src_high_shifted);
+    });
+  }
+  // We've shifted and rotated so that the original low part (final high part)
+  // of the tiles is in the right place. If there is more than one tile per
+  // vreg, we also need the original high part (final low part) to be in the
+  // right place.
+  // TODO(tlongeri): We could avoid allocating an extra array when there is
+  // only one tile per vreg.
+  // DO NOT SUBMIT: Name of this variable
+  xla::Array<Value> high_part = src_vregs;
+  // This is a no-op when tiles_per_vreg is 1:
+  rotateSublanes(builder, high_part, (tiles_per_vreg - 1) * sublanes_per_tile);
+
+  // The mask selects the first shift_in_tile rows (full/half/quarter/etc.
+  // sublanes) of each tile that contains data.
+  Value mask = nullptr;
+  const VectorType vmask_ty = getNativeVregOrVmaskType(
+      builder.getI1Type(), bits_per_row, target_shape);
+  const VectorType select_vreg_ty =
+      getNativeVregType(builder.getIntegerType(bits_per_row), target_shape);
+  const int64_t shift_in_tile = row_shift_amount % tiling[0];
+  for (int64_t i = 0; i < tiles_per_vreg; ++i) {
+    const int64_t tile_start = i * tiling[0];
+    // Skip tiles that contain no data.
+    if (*src_layout.offsets()[0] < tile_start + tiling[0] &&
+        tile_start < *src_layout.offsets()[0] + tiled_ishape[0]) {
+      Value tile_mask = builder.create<tpu::CreateSubelementMaskOp>(
+          loc, vmask_ty, tile_start, tile_start + shift_in_tile);
+      if (mask == nullptr) {
+        mask = tile_mask;
+      } else {
+        mask = builder.create<arith::OrIOp>(loc, mask, tile_mask);
+      }
+    }
+  }
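+  // Continuing the illustrative 16-bit example above (row shift of 5, one
+  // (16, 128) tile per vreg): the mask covers rows [0, 5) of the tile. Below,
+  // those rows of each result vreg take the data that wrapped around from the
+  // previous vreg along the second-minor dimension (read from high_part),
+  // while the remaining rows take the down-shifted data from src_vregs.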
+  xla::Array<Value> res_vregs(
+      VectorLayout(src_layout.bitwidth(), {dst_row_offset, src_offsets[1]},
+                   src_layout.tiling(), src_layout.implicit_dim())
+          .tileArrayImplicitShape(shape, target_shape));
+  int64_t res_low_idx_delta = -1;
+  int64_t res_high_idx_delta = 0;
+  if (dst_row_offset < *src_offsets[0]) {
+    ++res_low_idx_delta;
+    ++res_high_idx_delta;
+  }
+  res_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
+    Value low, high;
+    // idxs for the result low part
+    SmallVector<int64_t> low_idxs(toArrayRef(idxs));
+    *(low_idxs.end() - 2) += res_low_idx_delta;
+    if (0 <= *(low_idxs.end() - 2)) {
+      low = high_part(low_idxs);
+    }
+    // idxs for the result high part
+    SmallVector<int64_t> high_idxs(toArrayRef(idxs));
+    *(high_idxs.end() - 2) += res_high_idx_delta;
+    if (*(high_idxs.end() - 2) < *(src_vregs.dimensions().end() - 2)) {
+      high = src_vregs(high_idxs);
+    }
+    if (low != nullptr && high != nullptr) {
+      low = builder.create<tpu::BitcastVregOp>(loc, select_vreg_ty, low);
+      high = builder.create<tpu::BitcastVregOp>(loc, select_vreg_ty, high);
+      *v = builder.create<arith::SelectOp>(loc, mask, low, high);
+    } else if (low != nullptr) {
+      *v = low;
+    } else {
+      DCHECK(high != nullptr);
+      *v = high;
+    }
+
+    *v = builder.create<tpu::BitcastVregOp>(loc, vreg_ty, *v);
+  });
+
+  return res_vregs;
+}
+
 // Relayout src_vregs from layout src to layout dst, where dst is the same as
 // src except that the column offset is dst_col_offset.
 FailureOr<xla::Array<Value>> doColumnShiftRelayout(
@@ -5753,8 +5890,6 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
   const auto &target_shape = ctx.target_shape;
   const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(),
                          src.implicit_dim());
-  const int packing = src.packing();
-  const int8_t bitwidth = src.bitwidth();
 
   int row_diff;
   if (!src.offsets()[0].has_value()) {
@@ -5780,56 +5915,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
     }
     const SmallVector<int64_t> implicit_shape =
        src.implicitShape(vty.getShape());
-    if (implicit_shape[implicit_shape.size() - 2] != 1) {
-      // Multi row shift
-      // TODO(mvoz): This should take the vregs array, not the value.
-      FAILUREOR_ASSIGN_OR_RETURN(
-          vregs, tpu_rotate_with_overflow(
-                     builder, target_shape, loc, vty, std::move(vregs),
-                     /*dim*/ implicit_shape.size() - 2, src, dst_offsets));
-    } else {
-      // Single row case
-      // TODO(mvoz): The single row case has a broader set of supported
-      // operations: non-native tiling, packed types, implicit dim. We should
-      // support these cases in tpu_rotate_with_overflow and remove this
-      // branch.
-      const int64_t src_sublane = *src.offsets()[0] / packing;
-      const int64_t dst_sublane = *dst_offsets[0] / packing;
-      if (int64_t sublane_diff = dst_sublane - src_sublane) {
-        if (sublane_diff < 0) {
-          sublane_diff += target_shape[0];
-        }
-        rotateSublanes(builder, vregs, sublane_diff);
-      }
-      const int src_subelem = *src.offsets()[0] % packing;
-      const int dst_subelem = *dst.offsets()[0] % packing;
-      if (src_subelem != dst_subelem) {
-        const int subelem_diff = dst_subelem - src_subelem;
-        const int shift_bits = bitwidth * std::abs(subelem_diff);
-        VectorType bits_vreg_ty =
-            VectorType::get(target_shape, builder.getI32Type());
-        auto shift_vreg = builder.create<arith::ConstantOp>(
-            loc, bits_vreg_ty,
-            DenseElementsAttr::get(bits_vreg_ty, shift_bits));
-        vregs.Each([&](absl::Span<const int64_t> /*idx*/, Value *tile) {
-          auto bit_tile =
-              builder.create<tpu::BitcastVregOp>(loc, bits_vreg_ty, *tile);
-          Operation *shift_tile;
-          if (subelem_diff > 0) {
-            shift_tile =
-                builder.create<arith::ShLIOp>(loc, bit_tile, shift_vreg);
-          } else {  // subelem_diff < 0
-            CHECK_LT(subelem_diff, 0);
-            shift_tile =
-                builder.create<arith::ShRUIOp>(loc, bit_tile, shift_vreg);
-          }
-          *tile = builder
-                      .create<tpu::BitcastVregOp>(loc, tile->getType(),
-                                                  shift_tile->getResult(0))
-                      .getResult();
-        });
-      }
-    }
+    FAILUREOR_ASSIGN_OR_RETURN(
+        vregs, doRowShiftRelayout(builder, loc, vty.getShape(), vregs, src,
+                                  *dst_offsets[0], ctx.target_shape));
   }
 
   // Rows are now correctly aligned. Time to offset columns.