Skip to content

Commit

Permalink
Made conv2dk1_i8 more parameterizable for strix (#2046)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackl-xilinx authored Feb 18, 2025
1 parent 98bec8f commit ab9760e
Showing 1 changed file with 74 additions and 54 deletions.
128 changes: 74 additions & 54 deletions aie_kernels/aie2/conv2dk1_i8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,16 @@ void conv2dk1_i8_vector(int8_t *input, int8_t *kernels, int8_t *output,
const int32_t output_channels, const int scale) {
event0();

using MMUL4x8x8 = aie::mmul<4, 8, 8, int8, int8>;
constexpr int NUM_ACC = 8; // Number of accumulators
constexpr int MMUL_M = 4; // Matrix A M size in MxK (Input width)
constexpr int MMUL_K = 8;
constexpr int MMUL_N = 8;
constexpr int CHANNEL_FACTOR = MMUL_K;
constexpr int MMUL_MK = MMUL_M * MMUL_K;
constexpr int MMUL_KN = MMUL_K * MMUL_N;
constexpr int MMUL_MN = MMUL_M * MMUL_N;

using MMUL8x8x8 = aie::mmul<MMUL_M, MMUL_K, MMUL_N, int8, int8>;
::aie::set_saturation(
aie::saturation_mode::saturate); // Needed to saturate properly to uint8
::aie::set_rounding(aie::rounding_mode::symmetric_inf); // Needed to saturate
Expand All @@ -97,93 +106,104 @@ void conv2dk1_i8_vector(int8_t *input, int8_t *kernels, int8_t *output,

const int scaleT = scale;

MMUL4x8x8 acc_tmp[8];
for (int x = 0; x < 8; x++) {
acc_tmp[x] = aie::zeros<acc32, 32>();
MMUL8x8x8 acc_tmp[NUM_ACC];
for (int x = 0; x < NUM_ACC; x++) {
acc_tmp[x] = aie::zeros<acc32, MMUL_MN>();
}

// TODO: Keeping this variable gives wrong behavior and a bad schedule!
const int iw = input_width;
const int iw_32 = (input_width / 4) / 8;
const int iw_partial = (input_width / MMUL_M) / NUM_ACC;

// const int iw_32_rem = (input_width / 4) % 8;
// const int iw_32_rem = (32 / 4) % 8;
assert((input_width / 4) % 8 == 0);
const int iw_32_rem = 0; // TODO - See restriction
// const int iw_partial_rem = (input_width / MMUL_M) % NUM_ACC;
// const int iw_partial_rem = (32 / MMUL_M) % NUM_ACC;
assert((input_width / MMUL_M) % NUM_ACC == 0);
const int iw_partial_rem = 0; // TODO - See restriction

assert((input_channels / 8) > 2); // Assume IC >= 16
assert((input_channels / CHANNEL_FACTOR) > 2); // Assume IC >= 16

if (iw_32 > 0) {
int8_t *input_begin_ptr = input;
int8_t *input_rem_begin_ptr =
input + iw_partial * MMUL_M * NUM_ACC * CHANNEL_FACTOR;

for (int oc = 0; oc < (output_channels / 8); oc++) {
for (int iw_32c = 0; iw_32c < iw_32; iw_32c++) {
for (int ic = 0; ic < (input_channels / 8); ic++)
if (iw_partial > 0) {

for (int oc = 0; oc < (output_channels / CHANNEL_FACTOR); oc++) {
for (int iw_partialc = 0; iw_partialc < iw_partial; iw_partialc++) {
for (int ic = 0; ic < (input_channels / CHANNEL_FACTOR); ic++)
chess_prepare_for_pipelining chess_loop_range(2, ) {
aie::vector<int8, 64> in_b = aie::load_v<64>(kernels);
kernels += 64; // wts ic0..7(oc0..7)
aie::vector<int8, MMUL_KN> in_b = aie::load_v<MMUL_KN>(kernels);
kernels += MMUL_KN; // wts ic0..7(oc0..7)

for (int x = 0; x < 8; x++) {
aie::vector<int8, 32> in_a = aie::load_v<32>(input);
input += 32; // act oc0..3(ic0..7)
for (int x = 0; x < NUM_ACC; x++) { // 4 acc
aie::vector<int8, MMUL_MK> in_a = aie::load_v<MMUL_MK>(input);
input += MMUL_MK; // act oc0..3(ic0..7)
acc_tmp[x].mac(in_a, in_b);
}
input += (iw * 8) - 256; // Move to next ic/8 position
// Move to next ic/8 position but in the same input range
input += (iw * CHANNEL_FACTOR) - MMUL_MK * NUM_ACC;
}
// input ptr just moves to next section
for (int xx = 0; xx < 8; xx++) {
aie::vector<int8, 32> o1 = acc_tmp[xx].to_vector<int8>(scaleT);

for (int xx = 0; xx < NUM_ACC; xx++) {
aie::vector<int8, MMUL_MN> o1 = acc_tmp[xx].to_vector<int8>(scaleT);
aie::store_v(out_ptr, o1);
out_ptr += 32;
acc_tmp[xx] = aie::zeros<acc32, 32>();
out_ptr += MMUL_MN;
acc_tmp[xx] = aie::zeros<acc32, MMUL_MN>();
}
input -= ((input_channels / 8) * iw * 8) -
256; // reset to next input_width/32 block
kernels -=
(input_channels / 8) * 64; // reset kernel back to beginning of ic/8
// reset to next set of 64*NUM_ACC inputs
input -= (input_channels * iw) - MMUL_MK * NUM_ACC;
// reset kernel back to beginning of ic/8
kernels -= (input_channels / CHANNEL_FACTOR) * MMUL_KN;
}
input -= (iw_32) * 256; // 8*32, reset beginning of input ptr
kernels += (input_channels / 8) * 64; // move to next oc/8 weights
out_ptr += (iw_32_rem *
32); // move to next oc/8 (skip remainder section if present)
input = input_begin_ptr; // reset beginning of input ptr
kernels += (input_channels / CHANNEL_FACTOR) *
MMUL_KN; // move to next oc/8 weights
out_ptr +=
(iw_partial_rem *
MMUL_MN); // move to next oc/8 (skip remainder section if present)
}

} // if(iw_32 > 0) {
} // if(iw_partial > 0) {

if (iw_32_rem > 0) {
if (iw_partial_rem > 0) {

const int ocs = output_channels;
const int ics = input_channels;

for (int oc = 0; oc < (ocs / 8); oc++) {
for (int ic = 0; ic < (ics / 8); ic++)
for (int oc = 0; oc < (ocs / CHANNEL_FACTOR); oc++) {
for (int ic = 0; ic < (ics / CHANNEL_FACTOR); ic++)
chess_prepare_for_pipelining chess_loop_range(2, ) {
aie::vector<int8, 64> in_b = aie::load_v<64>(kernels);
kernels += 64; // wts ic0..7(oc0..7)
aie::vector<int8, MMUL_KN> in_b = aie::load_v<MMUL_KN>(kernels);
kernels += MMUL_KN; // wts ic0..7(oc0..7)

for (int x = 0; x < iw_32_rem; x++) {
aie::vector<int8, 32> in_a = aie::load_v<32>(input);
input += 32; // act oc0..3(ic0..7)
for (int x = 0; x < iw_partial_rem; x++) {
aie::vector<int8, MMUL_MK> in_a = aie::load_v<MMUL_MK>(input);
input += MMUL_MK; // act oc0..3(ic0..7)
acc_tmp[x].mac(in_a, in_b);
}
input += (iw * 8) - (iw_32_rem * 32); // Move to next ic/8 position
input += (iw * CHANNEL_FACTOR) -
(MMUL_MK * iw_partial_rem); // Move to next ic/8 position
}
// input ptr just moves to next section
for (int xx = 0; xx < iw_32_rem; xx++) {
aie::vector<int8, 32> o1 = acc_tmp[xx].to_vector<int8>(scaleT);
for (int xx = 0; xx < iw_partial_rem; xx++) {
aie::vector<int8, MMUL_MN> o1 = acc_tmp[xx].to_vector<int8>(scaleT);
aie::store_v(out_ptr, o1);
out_ptr += 32;
acc_tmp[xx] = aie::zeros<acc32, 32>();
out_ptr += MMUL_MN;
acc_tmp[xx] = aie::zeros<acc32, MMUL_MN>();
}
// input -= ((ics-1)/8)*(iw*8)+(iw_32_rem*32); // reset to beginning of
// input ptr for remainder
input -= 448; // reset to beginning of input ptr for remainder
// input -= ((ics-1)/8)*(iw*8)+(iw_partial_rem*32);
// reset to beginning of input ptr for remainder
input = input_rem_begin_ptr;
// kernel ptr already at next oc/8
out_ptr += (iw * 8) -
(iw_32_rem *
32); // move to next oc/8 (skip remainder section if present)
out_ptr +=
(iw * CHANNEL_FACTOR) -
(iw_partial_rem *
MMUL_MN); // move to next oc/8 (skip remainder section if present)
}

} // if(iw_32_rem > 0)
} // if(iw_partial_rem > 0)

event1();
}
Expand Down Expand Up @@ -218,4 +238,4 @@ void conv2dk1_i8(int8_t *input, int8_t *kernels, int8_t *output,
}
#endif // INT8_ACT
#endif // Vector
} // extern "C"
} // extern "C"

0 comments on commit ab9760e

Please sign in to comment.