diff --git a/compiler/circle-quantizer/src/QuantizeWeightsLLM.cpp b/compiler/circle-quantizer/src/QuantizeWeightsLLM.cpp index df3fb547925..493a13c0d4d 100644 --- a/compiler/circle-quantizer/src/QuantizeWeightsLLM.cpp +++ b/compiler/circle-quantizer/src/QuantizeWeightsLLM.cpp @@ -130,6 +130,12 @@ void QuantizeWeightsLLM::visit(luci::CircleFullyConnected *node) void QuantizeWeightsLLM::visit(luci::CircleGather *node) { + if (dynamic_cast(node->params()) == nullptr) + return; + + if (dynamic_cast(node->indices()) != nullptr) + return; + auto input = loco::must_cast(node->arg(0)); if (elementsize(input) < _skip_length) return; @@ -139,6 +145,11 @@ void QuantizeWeightsLLM::visit(luci::CircleGather *node) auto new_weights = _quant_type == Type::Q4_0 ? quantize_q4_block(input) : quantize_q8_block(input); node->params(new_weights); + + // Workaround: indices to INT32 type + auto indices = loco::must_cast(node->indices()); + if (indices->dtype() == loco::DataType::S64) + indices->dtype(loco::DataType::S32); } }