-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
92 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from os import linesep | ||
from typing import List | ||
from functools import partial | ||
from pathlib import Path | ||
|
||
from google.protobuf.reflection import GeneratedProtocolMessageType | ||
from google.protobuf.descriptor import Descriptor, FieldDescriptor | ||
from google.protobuf.struct_pb2 import Struct | ||
from jinja2 import Environment, PackageLoader | ||
|
||
tab = " " * 4 | ||
one_line, two_lines = linesep * 2, linesep * 3 | ||
type_mapping = { | ||
FieldDescriptor.TYPE_DOUBLE: float, | ||
FieldDescriptor.TYPE_FLOAT: float, | ||
FieldDescriptor.TYPE_INT64: int, | ||
FieldDescriptor.TYPE_UINT64: int, | ||
FieldDescriptor.TYPE_INT32: int, | ||
FieldDescriptor.TYPE_FIXED64: float, | ||
FieldDescriptor.TYPE_FIXED32: float, | ||
FieldDescriptor.TYPE_BOOL: bool, | ||
FieldDescriptor.TYPE_STRING: str, | ||
FieldDescriptor.TYPE_BYTES: str, | ||
FieldDescriptor.TYPE_UINT32: int, | ||
FieldDescriptor.TYPE_SFIXED32: float, | ||
FieldDescriptor.TYPE_SFIXED64: float, | ||
FieldDescriptor.TYPE_SINT32: int, | ||
FieldDescriptor.TYPE_SINT64: int, | ||
} | ||
|
||
|
||
def convert_field(level: int, field: FieldDescriptor) -> str: | ||
level += 1 | ||
field_type = field.type | ||
extra = None | ||
|
||
if field_type == FieldDescriptor.TYPE_ENUM: | ||
type_statement = field.name | ||
class_statement = f"{tab * level}class {field.name}(IntEnum):" | ||
field_statements = map( | ||
lambda value: f"{tab * (level + 1)}{value.name} = {value.index}", | ||
field.enum_type.values, | ||
) | ||
extra = linesep.join([class_statement, *field_statements]) | ||
elif field_type == FieldDescriptor.TYPE_MESSAGE: | ||
type_statement = field.message_type.name | ||
if type_statement == Struct.__name__: | ||
type_statement = "Dict" | ||
else: | ||
extra = msg2pydantic(level, field.message_type) | ||
else: | ||
type_statement = type_mapping[field_type].__name__ | ||
|
||
if field.label == FieldDescriptor.LABEL_REPEATED: | ||
type_statement = f"List[{type_statement}]" | ||
|
||
field_statement = f"{tab * level}{field.name}: {type_statement}" | ||
if not extra: | ||
return field_statement | ||
return linesep + extra + one_line + field_statement | ||
|
||
|
||
def msg2pydantic(level: int, msg: Descriptor) -> str: | ||
class_statement = f"{tab * level}class {msg.name}(BaseModel):" | ||
field_statements = map(partial(convert_field, level), msg.fields) | ||
return linesep.join([class_statement, *field_statements]) | ||
|
||
|
||
def pb2_to_pydantic(module) -> str: | ||
pydantic_models: List[str] = [] | ||
for i in dir(module): | ||
obj = getattr(module, i) | ||
if not isinstance(obj, GeneratedProtocolMessageType): | ||
continue | ||
model_string = msg2pydantic(0, obj.DESCRIPTOR) | ||
pydantic_models.append(model_string) | ||
|
||
env = Environment(loader=PackageLoader(Path.cwd().name)) | ||
template = env.get_template("pydantic.jinja2") | ||
return template.render(models=pydantic_models) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,14 @@ | ||
from os import linesep | ||
from typing import List | ||
from functools import partial | ||
from pathlib import Path | ||
from importlib import import_module | ||
|
||
from google.protobuf.reflection import GeneratedProtocolMessageType | ||
from google.protobuf.descriptor import Descriptor, FieldDescriptor | ||
from google.protobuf.struct_pb2 import Struct | ||
from jinja2 import Environment, PackageLoader | ||
from typer import Typer | ||
|
||
tab = " " * 4 | ||
one_line, two_lines = linesep * 2, linesep * 3 | ||
type_mapping = { | ||
FieldDescriptor.TYPE_DOUBLE: float, | ||
FieldDescriptor.TYPE_FLOAT: float, | ||
FieldDescriptor.TYPE_INT64: int, | ||
FieldDescriptor.TYPE_UINT64: int, | ||
FieldDescriptor.TYPE_INT32: int, | ||
FieldDescriptor.TYPE_FIXED64: float, | ||
FieldDescriptor.TYPE_FIXED32: float, | ||
FieldDescriptor.TYPE_BOOL: bool, | ||
FieldDescriptor.TYPE_STRING: str, | ||
FieldDescriptor.TYPE_BYTES: str, | ||
FieldDescriptor.TYPE_UINT32: int, | ||
FieldDescriptor.TYPE_SFIXED32: float, | ||
FieldDescriptor.TYPE_SFIXED64: float, | ||
FieldDescriptor.TYPE_SINT32: int, | ||
FieldDescriptor.TYPE_SINT64: int, | ||
} | ||
from protobuf2pydantic import biz | ||
|
||
app = Typer() | ||
|
||
def convert_field(level: int, field: FieldDescriptor) -> str: | ||
level += 1 | ||
field_type = field.type | ||
extra = None | ||
|
||
if field_type == FieldDescriptor.TYPE_ENUM: | ||
type_statement = field.name | ||
class_statement = f"{tab * level}class {field.name}(IntEnum):" | ||
field_statements = map( | ||
lambda value: f"{tab * (level + 1)}{value.name} = {value.index}", | ||
field.enum_type.values, | ||
) | ||
extra = linesep.join([class_statement, *field_statements]) | ||
elif field_type == FieldDescriptor.TYPE_MESSAGE: | ||
type_statement = field.message_type.name | ||
if type_statement == Struct.__name__: | ||
type_statement = "Dict" | ||
else: | ||
extra = msg2pydantic(level, field.message_type) | ||
else: | ||
type_statement = type_mapping[field_type].__name__ | ||
|
||
if field.label == FieldDescriptor.LABEL_REPEATED: | ||
type_statement = f"List[{type_statement}]" | ||
|
||
field_statement = f"{tab * level}{field.name}: {type_statement}" | ||
if not extra: | ||
return field_statement | ||
return linesep + extra + one_line + field_statement | ||
|
||
|
||
def msg2pydantic(level: int, msg: Descriptor) -> str: | ||
class_statement = f"{tab * level}class {msg.name}(BaseModel):" | ||
field_statements = map(partial(convert_field, level), msg.fields) | ||
return linesep.join([class_statement, *field_statements]) | ||
|
||
|
||
def pb2_to_pydantic(module) -> str: | ||
pydantic_models: List[str] = [] | ||
for i in dir(module): | ||
obj = getattr(module, i) | ||
if not isinstance(obj, GeneratedProtocolMessageType): | ||
continue | ||
model_string = msg2pydantic(0, obj.DESCRIPTOR) | ||
pydantic_models.append(model_string) | ||
|
||
env = Environment(loader=PackageLoader(Path.cwd().name)) | ||
template = env.get_template("pydantic.jinja2") | ||
return template.render(models=pydantic_models) | ||
# fixme | ||
@app.command() | ||
def pydantic(pb2_filename: str): | ||
module = import_module(pb2_filename) | ||
biz.pb2_to_pydantic(module) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
grpcio>=1.32.0 | ||
grpcio-tools>=1.32.0 | ||
jinja2>=2.11.2 | ||
typer>=0.3.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters