Skip to content

Commit

Permalink
Support an array of XY data plots (#632)
Browse files Browse the repository at this point in the history
* Add serialization strategy for array of DoubleXYData

* Generate new stubs for test.proto with an array of XYData field

* Get serializer tests passing with DoubleXYData array

* Fix serialization and deserialization of array with multiple xy data message values

* Add test_service tests for DoubleXYDataArray1D

* Fix Black errors

* Fix lint error

* Fix mypy errors

* Revert unintentional changes

* Fix tests
  • Loading branch information
dixonjoel authored Mar 7, 2024
1 parent c3fa33d commit 6c70e0f
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 49 deletions.
5 changes: 5 additions & 0 deletions ni_measurementlink_service/_datatypeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ def get_type_info(data_type: DataType) -> DataTypeInfo:
DataType.PinArray1D: DataTypeInfo(type_pb2.Field.TYPE_STRING, True, TypeSpecialization.Pin),
DataType.PathArray1D: DataTypeInfo(type_pb2.Field.TYPE_STRING, True, TypeSpecialization.Path),
DataType.EnumArray1D: DataTypeInfo(type_pb2.Field.TYPE_ENUM, True, TypeSpecialization.Enum),
DataType.DoubleXYDataArray1D: DataTypeInfo(
type_pb2.Field.TYPE_MESSAGE,
True,
message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
),
}
113 changes: 86 additions & 27 deletions ni_measurementlink_service/_internal/parameter/_message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import struct
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)

from google.protobuf.internal import encoder, wire_format
from google.protobuf.message import Message
Expand All @@ -14,7 +25,7 @@

def _message_encoder_constructor(
field_index: int, is_repeated: bool, is_packed: bool
) -> Callable[[WriteFunction, Message, bool], int]:
) -> Callable[[WriteFunction, Union[Message, List[Message]], bool], int]:
"""Mimics google.protobuf.internal.MessageEncoder.
This function was forked in order to call SerializeToString instead of _InternalSerialize.
Expand All @@ -26,13 +37,31 @@ def _message_encoder_constructor(
tag = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED)
encode_varint = _varint_encoder()

def _encode_message(write: WriteFunction, value: Message, deterministic: bool) -> int:
write(tag)
bytes = value.SerializeToString()
encode_varint(write, len(bytes), deterministic)
return write(bytes)
if is_repeated:

def _encode_repeated_message(
write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool
) -> int:
bytes_written = 0
for element in cast(List[Message], value):
write(tag)
bytes = element.SerializeToString()
encode_varint(write, len(bytes), deterministic)
bytes_written += write(bytes)
return bytes_written

return _encode_repeated_message
else:

return _encode_message
def _encode_message(
write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool
) -> int:
write(tag)
bytes = cast(Message, value).SerializeToString()
encode_varint(write, len(bytes), deterministic)
return write(bytes)

return _encode_message


def _varint_encoder() -> Callable[[WriteFunction, int, Optional[bool]], int]:
Expand Down Expand Up @@ -67,25 +96,55 @@ def _message_decoder_constructor(
(like DoubleXYData) are defined in .proto files, so they use whichever protobuf implementation
that google.protobuf.internal.api_implementation chooses (usually upb).
"""

def _decode_message(
buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any]
) -> int:
decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int)
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = decode_varint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise ValueError("Error decoding a message. Message is truncated.")
parsed_bytes = value.ParseFromString(buffer[pos:new_pos])
if parsed_bytes != size:
raise ValueError("Parsed incorrect number of bytes.")
return new_pos

return _decode_message
if is_repeated:
tag_bytes = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)

def _decode_repeated_message(
buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any]
) -> int:
decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int)
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, [])
while 1:
parsed_value = new_default(message)
# Read length.
(size, pos) = decode_varint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise ValueError("Error decoding a message. Message is truncated.")
parsed_bytes = parsed_value.ParseFromString(buffer[pos:new_pos])
if parsed_bytes != size:
raise ValueError("Parsed incorrect number of bytes.")
value.append(parsed_value)
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos

return _decode_repeated_message
else:

def _decode_message(
buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any]
) -> int:
decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int)
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = decode_varint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise ValueError("Error decoding a message. Message is truncated.")
parsed_bytes = value.ParseFromString(buffer[pos:new_pos])
if parsed_bytes != size:
raise ValueError("Parsed incorrect number of bytes.")
return new_pos

return _decode_message


T = TypeVar("T", bound="int")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ def vector_encoder(field_index: int) -> Encoder:
return vector_encoder


def _unsupported_encoder(field_index: int, is_repeated: bool, is_packed: bool) -> Encoder:
raise NotImplementedError(f"Unsupported data type for field {field_index}")


def _scalar_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor:
"""Constructs a scalar decoder constructor.
Expand Down Expand Up @@ -103,7 +99,9 @@ def vector_decoder(field_index: int, key: Key) -> Decoder:
return vector_decoder


def _double_xy_data_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor:
def _double_xy_data_decoder(
decoder: DecoderConstructor, is_repeated: bool
) -> PartialDecoderConstructor:
"""Constructs a DoubleXYData decoder constructor.
Takes a field index and a key and returns a Decoder for DoubleXYData.
Expand All @@ -113,7 +111,6 @@ def _new_default(unused_message: Optional[Message] = None) -> Any:
return xydata_pb2.DoubleXYData()

def message_decoder(field_index: int, key: Key) -> Decoder:
is_repeated = True
is_packed = True
return decoder(field_index, is_repeated, is_packed, key, _new_default)

Expand All @@ -136,7 +133,9 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
UIntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.UInt32Encoder))
BoolArrayEncoder = _vector_encoder(encoder.BoolEncoder)
StringArrayEncoder = _vector_encoder(encoder.StringEncoder, is_packed=False)
UnsupportedMessageArrayEncoder = _vector_encoder(_unsupported_encoder)
MessageArrayEncoder = _vector_encoder(
cast(EncoderConstructor, _message._message_encoder_constructor)
)

# Cast works around this issue in typeshed
# https://github.com/python/typeshed/issues/10697
Expand All @@ -148,7 +147,7 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
UInt64Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.UInt64Decoder))
BoolDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.BoolDecoder))
StringDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.StringDecoder))
XYDataDecoder = _double_xy_data_decoder(_message._message_decoder_constructor)
XYDataDecoder = _double_xy_data_decoder(_message._message_decoder_constructor, is_repeated=False)

FloatArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.FloatDecoder))
DoubleArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.DoubleDecoder))
Expand All @@ -160,6 +159,9 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
StringArrayDecoder = _vector_decoder(
cast(DecoderConstructor, decoder.StringDecoder), is_packed=False
)
XYDataArrayDecoder = _double_xy_data_decoder(
_message._message_decoder_constructor, is_repeated=True
)


_FIELD_TYPE_TO_ENCODER_MAPPING = {
Expand All @@ -172,7 +174,7 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
type_pb2.Field.TYPE_BOOL: (BoolEncoder, BoolArrayEncoder),
type_pb2.Field.TYPE_STRING: (StringEncoder, StringArrayEncoder),
type_pb2.Field.TYPE_ENUM: (IntEncoder, IntArrayEncoder),
type_pb2.Field.TYPE_MESSAGE: (MessageEncoder, UnsupportedMessageArrayEncoder),
type_pb2.Field.TYPE_MESSAGE: (MessageEncoder, MessageArrayEncoder),
}

_FIELD_TYPE_TO_DECODER_MAPPING = {
Expand Down Expand Up @@ -203,6 +205,10 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataDecoder,
}

_ARRAY_MESSAGE_TYPE_TO_DECODER = {
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataArrayDecoder,
}


def get_encoder(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> PartialEncoderConstructor:
"""Get the appropriate partial encoder constructor for the specified type.
Expand All @@ -227,8 +233,9 @@ def get_decoder(
return array_decoder if repeated else scalar_decoder
elif type == type_pb2.Field.Kind.TYPE_MESSAGE:
if repeated:
raise ValueError(f"Repeated message types are not supported '{message_type}'")
decoder = _MESSAGE_TYPE_TO_DECODER.get(message_type)
decoder = _ARRAY_MESSAGE_TYPE_TO_DECODER.get(message_type)
else:
decoder = _MESSAGE_TYPE_TO_DECODER.get(message_type)
if decoder is None:
raise ValueError(f"Unknown message type '{message_type}'")
return decoder
Expand Down
1 change: 1 addition & 0 deletions ni_measurementlink_service/measurement/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,4 @@ class DataType(enum.Enum):
PinArray1D = 108
PathArray1D = 109
EnumArray1D = 110
DoubleXYDataArray1D = 111
8 changes: 8 additions & 0 deletions tests/unit/test_serialization_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
(type_pb2.Field.TYPE_STRING, False, serialization_strategy.StringEncoder),
(type_pb2.Field.TYPE_ENUM, False, serialization_strategy.IntEncoder),
(type_pb2.Field.TYPE_MESSAGE, False, serialization_strategy.MessageEncoder),
(type_pb2.Field.TYPE_MESSAGE, True, serialization_strategy.MessageArrayEncoder),
],
)
def test___serialization_strategy___get_encoder___returns_expected_encoder(
Expand Down Expand Up @@ -48,6 +49,12 @@ def test___serialization_strategy___get_encoder___returns_expected_encoder(
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
serialization_strategy.XYDataDecoder,
),
(
type_pb2.Field.TYPE_MESSAGE,
True,
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
serialization_strategy.XYDataArrayDecoder,
),
],
)
def test___serialization_strategy___get_decoder___returns_expected_decoder(
Expand All @@ -71,6 +78,7 @@ def test___serialization_strategy___get_decoder___returns_expected_decoder(
(type_pb2.Field.TYPE_STRING, False, ""),
(type_pb2.Field.TYPE_ENUM, False, 0),
(type_pb2.Field.TYPE_MESSAGE, False, None),
(type_pb2.Field.TYPE_MESSAGE, True, []),
],
)
def test___serialization_strategy___get_default_value___returns_type_defaults(
Expand Down
22 changes: 21 additions & 1 deletion tests/unit/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
from google.protobuf import any_pb2, type_pb2


from ni_measurementlink_service._annotations import (
ENUM_VALUES_KEY,
TYPE_SPECIALIZATION_KEY,
Expand Down Expand Up @@ -43,6 +42,12 @@ class Countries(IntEnum):
double_xy_data.x_data.append(4)
double_xy_data.y_data.append(6)

double_xy_data2 = xydata_pb2.DoubleXYData()
double_xy_data2.x_data.append(8)
double_xy_data2.y_data.append(10)

double_xy_data_array = [double_xy_data, double_xy_data2]

# This should match the number of fields in bigmessage.proto.
BIG_MESSAGE_SIZE = 100

Expand Down Expand Up @@ -72,6 +77,7 @@ class Countries(IntEnum):
Countries.AUSTRALIA,
[Countries.AUSTRALIA, Countries.CANADA],
double_xy_data,
double_xy_data_array,
],
[
-0.9999,
Expand All @@ -95,6 +101,7 @@ class Countries(IntEnum):
Countries.AUSTRALIA,
[Countries.AUSTRALIA, Countries.CANADA],
double_xy_data,
double_xy_data_array,
],
],
)
Expand Down Expand Up @@ -133,6 +140,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu
Countries.AUSTRALIA,
[Countries.AUSTRALIA, Countries.CANADA],
double_xy_data,
double_xy_data_array,
],
[
-0.9999,
Expand All @@ -156,6 +164,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu
Countries.AUSTRALIA,
[Countries.AUSTRALIA, Countries.CANADA],
double_xy_data,
double_xy_data_array,
],
],
)
Expand Down Expand Up @@ -193,6 +202,7 @@ def test___serializer___serialize_default_parameter___successful_serialization(d
Countries.AUSTRALIA,
[Countries.AUSTRALIA, Countries.CANADA],
double_xy_data,
double_xy_data_array,
]
],
)
Expand Down Expand Up @@ -230,6 +240,7 @@ def test___empty_buffer___deserialize_parameters___returns_zero_or_empty():
Countries.AUSTRALIA,
[Countries.AUSTRALIA, Countries.CANADA],
double_xy_data,
double_xy_data_array,
]
parameter = _get_test_parameter_by_id(nonzero_defaults)
parameter_value_by_id = serializer.deserialize_parameters(parameter, bytes())
Expand Down Expand Up @@ -449,6 +460,14 @@ def _get_test_parameter_by_id(default_values):
annotations={},
message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
),
22: ParameterMetadata(
display_name="xy_data_array",
type=type_pb2.Field.TYPE_MESSAGE,
repeated=True,
default_value=default_values[21],
annotations={},
message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
),
}
return parameter_by_id

Expand Down Expand Up @@ -477,6 +496,7 @@ def _get_test_grpc_message(test_values):
parameter.int_enum_array_data.extend(list(map(lambda x: x.value, test_values[19])))
parameter.xy_data.x_data.append(test_values[20].x_data[0])
parameter.xy_data.y_data.append(test_values[20].y_data[0])
parameter.xy_data_array.extend(test_values[21])
return parameter


Expand Down
Loading

0 comments on commit 6c70e0f

Please sign in to comment.