Skip to content

Commit

Permalink
Merge pull request #566 from robertknight/conv-integer
Browse files Browse the repository at this point in the history
Implement `ConvInteger` operator
  • Loading branch information
robertknight authored Feb 3, 2025
2 parents 352c986 + ad266bd commit e48df48
Show file tree
Hide file tree
Showing 15 changed files with 759 additions and 125 deletions.
2 changes: 1 addition & 1 deletion rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def op_node_from_onnx_operator(
attrs.valueType = scalar_type
attrs.value = scalar

case "Conv":
case "Conv" | "ConvInteger":
attrs = sg.ConvAttrsT()
attrs.dilations = read_dilations(op_reader)
attrs.groups = op_reader.get_attr("group", "int", 1)
Expand Down
105 changes: 53 additions & 52 deletions rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class OperatorType(object):
DynamicQuantizeLinear = 107
MatMulInteger = 108
DepthToSpace = 109
ConvInteger = 110


class RNNDirection(object):
Expand Down Expand Up @@ -205,91 +206,91 @@ def OperatorAttrsCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == OperatorAttrs().ArgMaxAttrs:
if unionType == OperatorAttrs.ArgMaxAttrs:
return ArgMaxAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().AveragePoolAttrs:
if unionType == OperatorAttrs.AveragePoolAttrs:
return AveragePoolAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().BatchNormalizationAttrs:
if unionType == OperatorAttrs.BatchNormalizationAttrs:
return BatchNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().CastAttrs:
if unionType == OperatorAttrs.CastAttrs:
return CastAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConcatAttrs:
if unionType == OperatorAttrs.ConcatAttrs:
return ConcatAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConstantOfShapeAttrs:
if unionType == OperatorAttrs.ConstantOfShapeAttrs:
return ConstantOfShapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConvAttrs:
if unionType == OperatorAttrs.ConvAttrs:
return ConvAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConvTransposeAttrs:
if unionType == OperatorAttrs.ConvTransposeAttrs:
return ConvTransposeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().FlattenAttrs:
if unionType == OperatorAttrs.FlattenAttrs:
return FlattenAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GatherAttrs:
if unionType == OperatorAttrs.GatherAttrs:
return GatherAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GemmAttrs:
if unionType == OperatorAttrs.GemmAttrs:
return GemmAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GRUAttrs:
if unionType == OperatorAttrs.GRUAttrs:
return GRUAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LeakyReluAttrs:
if unionType == OperatorAttrs.LeakyReluAttrs:
return LeakyReluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LSTMAttrs:
if unionType == OperatorAttrs.LSTMAttrs:
return LSTMAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().MaxPoolAttrs:
if unionType == OperatorAttrs.MaxPoolAttrs:
return MaxPoolAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ReduceMeanAttrs:
if unionType == OperatorAttrs.ReduceMeanAttrs:
return ReduceMeanAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ReshapeAttrs:
if unionType == OperatorAttrs.ReshapeAttrs:
return ReshapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ResizeAttrs:
if unionType == OperatorAttrs.ResizeAttrs:
return ResizeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().SplitAttrs:
if unionType == OperatorAttrs.SplitAttrs:
return SplitAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().SoftmaxAttrs:
if unionType == OperatorAttrs.SoftmaxAttrs:
return SoftmaxAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TransposeAttrs:
if unionType == OperatorAttrs.TransposeAttrs:
return TransposeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ModAttrs:
if unionType == OperatorAttrs.ModAttrs:
return ModAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ScatterElementsAttrs:
if unionType == OperatorAttrs.ScatterElementsAttrs:
return ScatterElementsAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().OneHotAttrs:
if unionType == OperatorAttrs.OneHotAttrs:
return OneHotAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TopKAttrs:
if unionType == OperatorAttrs.TopKAttrs:
return TopKAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().HardSigmoidAttrs:
if unionType == OperatorAttrs.HardSigmoidAttrs:
return HardSigmoidAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TriluAttrs:
if unionType == OperatorAttrs.TriluAttrs:
return TriluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ScatterNDAttrs:
if unionType == OperatorAttrs.ScatterNDAttrs:
return ScatterNDAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().NonMaxSuppressionAttrs:
if unionType == OperatorAttrs.NonMaxSuppressionAttrs:
return NonMaxSuppressionAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LayerNormalizationAttrs:
if unionType == OperatorAttrs.LayerNormalizationAttrs:
return LayerNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomUniformAttrs:
if unionType == OperatorAttrs.RandomUniformAttrs:
return RandomUniformAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().EluAttrs:
if unionType == OperatorAttrs.EluAttrs:
return EluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomUniformLikeAttrs:
if unionType == OperatorAttrs.RandomUniformLikeAttrs:
return RandomUniformLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomNormalAttrs:
if unionType == OperatorAttrs.RandomNormalAttrs:
return RandomNormalAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomNormalLikeAttrs:
if unionType == OperatorAttrs.RandomNormalLikeAttrs:
return RandomNormalLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GatherNDAttrs:
if unionType == OperatorAttrs.GatherNDAttrs:
return GatherNDAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GeluAttrs:
if unionType == OperatorAttrs.GeluAttrs:
return GeluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().EinsumAttrs:
if unionType == OperatorAttrs.EinsumAttrs:
return EinsumAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().IfAttrs:
if unionType == OperatorAttrs.IfAttrs:
return IfAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().PadAttrs:
if unionType == OperatorAttrs.PadAttrs:
return PadAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().DequantizeLinearAttrs:
if unionType == OperatorAttrs.DequantizeLinearAttrs:
return DequantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().QuantizeLinearAttrs:
if unionType == OperatorAttrs.QuantizeLinearAttrs:
return QuantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().DepthToSpaceAttrs:
if unionType == OperatorAttrs.DepthToSpaceAttrs:
return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand All @@ -308,9 +309,9 @@ def ScalarCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == Scalar().IntScalar:
if unionType == Scalar.IntScalar:
return IntScalarT.InitFromBuf(table.Bytes, table.Pos)
if unionType == Scalar().FloatScalar:
if unionType == Scalar.FloatScalar:
return FloatScalarT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand Down Expand Up @@ -343,11 +344,11 @@ def NodeKindCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == NodeKind().OperatorNode:
if unionType == NodeKind.OperatorNode:
return OperatorNodeT.InitFromBuf(table.Bytes, table.Pos)
if unionType == NodeKind().ConstantNode:
if unionType == NodeKind.ConstantNode:
return ConstantNodeT.InitFromBuf(table.Bytes, table.Pos)
if unionType == NodeKind().ValueNode:
if unionType == NodeKind.ValueNode:
return ValueNodeT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand All @@ -363,13 +364,13 @@ def ConstantDataCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == ConstantData().FloatData:
if unionType == ConstantData.FloatData:
return FloatDataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().Int32Data:
if unionType == ConstantData.Int32Data:
return Int32DataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().Int8Data:
if unionType == ConstantData.Int8Data:
return Int8DataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().UInt8Data:
if unionType == ConstantData.UInt8Data:
return UInt8DataT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand Down
7 changes: 3 additions & 4 deletions src/gemm/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,13 @@ impl Im2Col<'_, i8> {
let out_ptr = out_ptr.add(out_offset + idx * K_TILE + i);
let src_elem = *img_ptr.add(offsets_array[idx] as usize);

// This should be compiled to a conditional move.
let elem = if pad_mask_array[idx] { src_elem } else { 0 };

if CAST_B_U8 {
let elem = shift_cast_i8_u8(elem);
let src_elem = shift_cast_i8_u8(src_elem);
let elem = if pad_mask_array[idx] { src_elem } else { 0 };
col_sums[idx] += elem as i32;
out_ptr.write(MaybeUninit::new(elem as i8));
} else {
let elem = if pad_mask_array[idx] { src_elem } else { 0 };
col_sums[idx] += elem as i32;
out_ptr.write(MaybeUninit::new(elem));
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ mod model_metadata;
mod number;
mod op_registry;
mod optimize;
mod shift_cast;
mod slice_cast;
mod slice_reductions;
mod tensor_pool;
Expand Down
16 changes: 15 additions & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,7 @@ mod tests {
let input_node = graph_builder.add_value("input", None, None);
let input_2d = graph_builder.add_value("input.2d", None, None);
let input_bool = graph_builder.add_value("input.bool", None, None);
let input_2d_u8 = graph_builder.add_value("input.2d.u8", None, None);

// 4D shape used as the primary input to test most operators (eg. NCHW image). A few
// require a different shape.
Expand All @@ -1257,6 +1258,9 @@ mod tests {
let kernel_val = Tensor::from_data(&[1, 1, 1, 1], vec![0.5]);
let kernel = graph_builder.add_constant(kernel_val.view());

let kernel_val_i8 = Tensor::from_data(&[1, 1, 1, 1], vec![0i8]);
let kernel_i8 = graph_builder.add_constant(kernel_val_i8.view());

// Names of all operator output nodes.
let mut op_outputs = Vec::new();

Expand Down Expand Up @@ -1340,7 +1344,12 @@ mod tests {
padding: [1, 1, 1, 1].into(),
strides: vec![1, 1],
});

add_operator!(ConvInteger, [input_2d_u8, kernel_i8], {
dilations: vec![1, 1],
groups: 1,
padding: [1, 1, 1, 1].into(),
strides: vec![1, 1],
});
add_operator!(ConvTranspose, [input_node, kernel], {
strides: vec![2, 2],
padding: [0, 0, 0, 0].into(),
Expand Down Expand Up @@ -1663,11 +1672,14 @@ mod tests {
// Most ops are tested with one of several standard inputs:
//
// - 4D float tensor (like an NCHW image)
// - Int8 NCHW tensor
// - Bool-ish int tensor
//
// A few require different shapes are tested separately.
let input = Tensor::from_data(&input_shape, vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
let input_bool_data: Tensor<i32> = Tensor::from([0, 1, 1]);
let input_u8_data = input.map(|&x| x as u8);

for output in op_outputs {
if [
"Gemm_out",
Expand Down Expand Up @@ -1695,6 +1707,7 @@ mod tests {
vec![
(input_node, input.view().into()),
(input_bool, input_bool_data.view().into()),
(input_2d_u8, input_u8_data.view().into()),
],
&[output_id],
None,
Expand All @@ -1712,6 +1725,7 @@ mod tests {
vec![
(input_node, input.clone().into()),
(input_bool, input_bool_data.clone().into()),
(input_2d_u8, input_u8_data.clone().into()),
],
&[output_id],
None,
Expand Down
27 changes: 21 additions & 6 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use crate::header::Header;
use crate::number::LeBytes;
use crate::ops::{
ArgMax, ArgMin, AveragePool, BatchNormalization, BoxOrder, Cast, Concat, ConstantOfShape, Conv,
ConvTranspose, CoordTransformMode, DataType, DepthToSpace, DepthToSpaceMode, DequantizeLinear,
Einsum, Elu, Flatten, Gather, GatherElements, GatherND, Gelu, Gemm, HardSigmoid,
InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax, MaxPool, Mod, NearestMode,
NonMaxSuppression, OneHot, Padding, QuantizeLinear, ReduceMax, ReduceMean, ReduceMin,
ReduceProd, ReduceSum, ReduceSumSquare, Reshape, Resize, ResizeMode, Scalar, ScatterElements,
ScatterReduction, Softmax, Split, TopK, Transpose, Trilu,
ConvInteger, ConvTranspose, CoordTransformMode, DataType, DepthToSpace, DepthToSpaceMode,
DequantizeLinear, Einsum, Elu, Flatten, Gather, GatherElements, GatherND, Gelu, Gemm,
HardSigmoid, InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax, MaxPool, Mod,
NearestMode, NonMaxSuppression, OneHot, Padding, QuantizeLinear, ReduceMax, ReduceMean,
ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare, Reshape, Resize, ResizeMode, Scalar,
ScatterElements, ScatterReduction, Softmax, Split, TopK, Transpose, Trilu,
};
use crate::schema_generated as sg;

Expand Down Expand Up @@ -44,6 +44,7 @@ pub enum OpType<'a> {
Concat(Concat),
ConstantOfShape(ConstantOfShape),
Conv(Conv),
ConvInteger(ConvInteger),
ConvTranspose(ConvTranspose),
Cos,
DequantizeLinear(DequantizeLinear),
Expand Down Expand Up @@ -497,6 +498,20 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
strides,
}
}),
OpType::ConvInteger(args) => op_with_attrs!(ConvInteger, ConvAttrs, {
let pad_args = pad_args_from_padding(args.padding);
let pads = self.create_vec(pad_args.pads, |pad| pad as u32);
let dilations = self.create_vec(Some(args.dilations), |d| d as u32);
let strides = self.create_vec(Some(args.strides), |s| s as u32);

sg::ConvAttrsArgs {
dilations,
groups: args.groups as u32,
auto_pad: pad_args.auto_pad,
pads,
strides,
}
}),
OpType::ConvTranspose(args) => op_with_attrs!(ConvTranspose, ConvTransposeAttrs, {
let pad_args = pad_args_from_padding(args.padding);
let pads = self.create_vec(pad_args.pads, |pad| pad as u32);
Expand Down
13 changes: 13 additions & 0 deletions src/op_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl OpRegistry {
register_op!(Clip);
register_op!(Concat);
register_op!(Conv);
register_op!(ConvInteger);
register_op!(ConstantOfShape);
register_op!(ConvTranspose);
register_op!(Cos);
Expand Down Expand Up @@ -452,6 +453,18 @@ impl_read_op!(Conv, attrs_as_conv_attrs, |attrs: sg::ConvAttrs| {
dilations,
})
});
impl_read_op!(ConvInteger, attrs_as_conv_attrs, |attrs: sg::ConvAttrs| {
let groups = attrs.groups() as usize;
let padding = padding_from_attrs(attrs.auto_pad(), attrs.pads());
let strides = vec_from_attr(attrs.strides(), &[1, 1]);
let dilations = vec_from_attr(attrs.dilations(), &[1, 1]);
Ok(ops::ConvInteger {
groups,
padding,
strides,
dilations,
})
});
impl_read_op!(
ConstantOfShape,
attrs_as_constant_of_shape_attrs,
Expand Down
Loading

0 comments on commit e48df48

Please sign in to comment.