Skip to content

Commit

Permalink
support repeat and required
Browse files Browse the repository at this point in the history
  • Loading branch information
Ed-XCF committed Jan 16, 2021
1 parent e9c1bca commit 4bceef8
Show file tree
Hide file tree
Showing 4 changed files with 431 additions and 15 deletions.
3 changes: 3 additions & 0 deletions protobuf2pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from enum import IntEnum # noqa

from google.protobuf.reflection import GeneratedProtocolMessageType
from google.protobuf.struct_pb2 import Struct # noqa
from google.protobuf.timestamp_pb2 import Timestamp # noqa
from google.protobuf.duration_pb2 import Duration # noqa
from pydantic import BaseModel, Field # noqa

from protobuf2pydantic.biz import msg2pydantic
Expand Down
53 changes: 38 additions & 15 deletions protobuf2pydantic/biz.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from os import linesep
from typing import List
from functools import partial
from enum import IntEnum

from pydantic import BaseModel, Field
from google.protobuf.reflection import GeneratedProtocolMessageType
from google.protobuf.descriptor import Descriptor, FieldDescriptor, EnumDescriptor
from google.protobuf.struct_pb2 import Struct
from google.protobuf.timestamp_pb2 import Timestamp
from google.protobuf.duration_pb2 import Duration

GOOGLE_MESSAGE = [Struct, Timestamp, Duration]
GOOGLE_MESSAGE_STR = [i.__name__ for i in GOOGLE_MESSAGE]
tab = " " * 4
one_line, two_lines = linesep * 2, linesep * 3
type_mapping = {
Expand All @@ -29,48 +31,66 @@
}


def m(field: FieldDescriptor) -> str:
return type_mapping[field.type].__name__


def convert_field(level: int, field: FieldDescriptor) -> str:
level += 1
field_type = field.type
field_label = field.label
extra = None

if field_type == FieldDescriptor.TYPE_ENUM:
enum_type: EnumDescriptor = field.enum_type
type_statement = enum_type.name
class_statement = f"{tab * level}class {enum_type.name}({IntEnum.__name__}):"
class_statement = f"{tab * level}class {enum_type.name}(IntEnum):"
field_statements = map(
lambda value: f"{tab * (level + 1)}{value.name} = {value.index}",
enum_type.values,
)
extra = linesep.join([class_statement, *field_statements])
factory = int.__name__
factory = "int"
elif field_type == FieldDescriptor.TYPE_MESSAGE:
type_statement = field.message_type.name
if type_statement == Struct.__name__:
type_statement = "Dict"
factory = dict.__name__
type_statement: str = field.message_type.name
if type_statement in GOOGLE_MESSAGE_STR:
factory = type_statement
elif type_statement.endswith("Entry"):
key, value = field.message_type.fields # type: FieldDescriptor
type_statement = f"Dict[{m(key)}, {m(value)}]"
factory = "dict"
else:
extra = msg2pydantic(level, field.message_type)
factory = type_statement
else:
type_statement = type_mapping[field_type].__name__
type_statement = m(field)
factory = type_statement

if field.label == FieldDescriptor.LABEL_REPEATED:
if field_label == FieldDescriptor.LABEL_REPEATED:
type_statement = f"List[{type_statement}]"
factory = list.__name__
factory = "list"

default_statement = f"{Field.__name__}(default_factory={factory})"
field_statement = f"{tab * level}{field.name}: {type_statement} = {default_statement}"
default_statement = f" = Field(default_factory={factory})"
if field_label == FieldDescriptor.LABEL_REQUIRED:
default_statement = ""

field_statement = f"{tab * level}{field.name}: {type_statement}{default_statement}"
if not extra:
return field_statement
return linesep + extra + one_line + field_statement


def msg2pydantic(level: int, msg: Descriptor) -> str:
class_statement = f"{tab * level}class {msg.name}({BaseModel.__name__}):"
class_statement = f"{tab * level}class {msg.name}(BaseModel):"
field_statements = map(partial(convert_field, level), msg.fields)
return linesep.join([class_statement, *field_statements])
return linesep.join([class_statement, *field_statements]) + linesep + get_config(level)


def get_config(level: int):
level += 1
class_statement = f"{tab * level}class Config:"
attribute_statement = f"{tab * (level + 1)}arbitrary_types_allowed = True"
return linesep + class_statement + linesep + attribute_statement


def pb2_to_pydantic(module) -> str:
Expand All @@ -86,6 +106,9 @@ def pb2_to_pydantic(module) -> str:
from enum import IntEnum
from pydantic import BaseModel, Field
from google.protobuf.struct_pb2 import Struct
from google.protobuf.timestamp_pb2 import Timestamp
from google.protobuf.duration_pb2 import Duration
"""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_biz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ def test_with_celery_task(self):
def test_with_test(self):
from . import test_pb2
pb2_to_pydantic(test_pb2)

def test_with_map(self):
from . import test_map_pb2
pb2_to_pydantic(test_map_pb2)
Loading

0 comments on commit 4bceef8

Please sign in to comment.