Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add varint to MDS #574

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from abc import ABC, abstractmethod
from decimal import Decimal
from io import BytesIO
from typing import Any, Optional, Set, Tuple
from typing import IO, Any, Optional, Set, Tuple

import numpy as np
from numpy import typing as npt
Expand Down Expand Up @@ -366,6 +366,77 @@ def __init__(self):
super().__init__(np.int64)


def _varuint_encode(obj: int) -> bytes:
if obj < 0:
raise ValueError(f'Expected non-negative integer, but got: {obj}.')
ret = []
while True:
byte = obj & 0x7F
obj >>= 7
if obj:
ret.append(0x80 | byte)
else:
ret.append(byte)
break
return bytes(ret)


def _varint_encode(obj: int) -> bytes:
if 0 <= obj:
obj = obj << 1
else:
obj = ((-obj) << 1) | 1
return _varuint_encode(obj)


def _varuint_decode(stream: IO[bytes]) -> int:
obj = 0
shift = 0
while True:
byte, = stream.read(1)
obj |= (byte & 0x7F) << shift
if 0x80 <= byte:
shift += 7
else:
break
return obj


def _varint_decode(stream: IO[bytes]) -> int:
obj = _varuint_decode(stream)
if obj & 1:
obj = -(obj >> 1)
else:
obj >>= 1
return obj


class VarUInt(Encoding):
"""Store an unsigned integer as a base-128 varint."""

@classmethod
def encode(cls, obj: int) -> bytes:
return _varuint_encode(obj)

@classmethod
def decode(cls, data: bytes) -> int:
stream = BytesIO(data)
return _varuint_decode(stream)


class VarInt(Encoding):
"""Store an integer as a base-128 varint."""

@classmethod
def encode(cls, obj: int) -> bytes:
return _varint_encode(obj)

@classmethod
def decode(cls, data: bytes) -> int:
stream = BytesIO(data)
return _varint_decode(stream)


class Float16(Scalar):
"""Store float16."""

Expand Down Expand Up @@ -531,6 +602,8 @@ def _is_valid(self, original: Any, converted: Any) -> None:
'int16': Int16,
'int32': Int32,
'int64': Int64,
'varuint': VarUInt,
'varint': VarInt,
'float16': Float16,
'float32': Float32,
'float64': Float64,
Expand Down
17 changes: 15 additions & 2 deletions tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@ def test_mds_StrDecimal(self, decoded: Decimal, encoded: bytes):
assert dec == decoded

def test_get_mds_encodings(self):
uints = {'uint8', 'uint16', 'uint32', 'uint64'}
ints = {'int8', 'int16', 'int32', 'int64', 'str_int'}
uints = {'uint8', 'uint16', 'uint32', 'uint64', 'varuint'}
ints = {'int8', 'int16', 'int32', 'int64', 'str_int', 'varint'}
floats = {'float16', 'float32', 'float64', 'str_float', 'str_decimal'}
scalars = uints | ints | floats
expected_encodings = {
Expand Down Expand Up @@ -488,6 +488,19 @@ def test_mds_scalar(self, encoding: str, decoded: Union[int, float], encoded: by
dec = mdsEnc.mds_decode(encoding, encoded)
assert dec == decoded

def test_varints(self):
from streaming.base.format.mds.encodings import mds_decode, mds_encode
for x in range(-700, 700, 7):
y = mds_encode('varint', x)
z = mds_decode('varint', y)
print(x, y, z)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need the print statements both here and below?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, can we test ints and their expected lengths that we would encode through tokenization? For example, the vocab_size of many models customers train is between 50 and 150k. Testing ints in that range would be useful

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your spurious print statement life debt is paid

assert x == z
for x in range(0, 700, 7):
y = mds_encode('varuint', x)
z = mds_decode('varuint', y)
print(x, y, z)
assert x == z

@pytest.mark.parametrize(('enc_name', 'data'), [('bytes', b'9'), ('int', 27),
('str', 'mosaicml')])
def test_mds_encode(self, enc_name: str, data: Any):
Expand Down