Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/unknown enumeration branches #134

Merged
merged 18 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
lint:
uv run ruff format --check .
uv run ruff check .
uv run mypy .
uv run pyright-python .
uv run deptry .
uv run ruff format --check src tests
uv run ruff check src tests
uv run mypy src tests
uv run pyright-python src tests
uv run deptry src tests

format:
uv run ruff format .
uv run ruff check . --fix
uv run ruff format src tests
uv run ruff check src tests --fix
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@ ignore_missing_imports = True
[mypy-pyd.*]
ignore_missing_imports = True

[mypy-pytest_snapshot.*]
ignore_missing_imports = True

[mypy-tyd.*]
ignore_missing_imports = True
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [
"aiochannel>=1.2.1",
"grpcio-tools>=1.59.3",
"grpcio>=1.59.3",
"msgpack-types>=0.3.0",
"msgpack>=1.0.7",
"nanoid>=2.0.0",
"protobuf>=5.28.3",
Expand All @@ -37,6 +36,7 @@ dependencies = [
[tool.uv]
dev-dependencies = [
"deptry>=0.14.0",
"msgpack-types>=0.3.0",
"mypy>=1.4.0",
"mypy-protobuf>=3.5.0",
"pytest>=7.4.0",
Expand All @@ -48,11 +48,12 @@ dev-dependencies = [
"types-protobuf>=4.24.0.20240311",
"types-nanoid>=2.0.0.20240601",
"pyright>=1.1.389",
"pytest-snapshot>=0.9.0",
]

[tool.ruff]
lint.select = ["F", "E", "W", "I001"]
exclude = ["*/generated/*"]
exclude = ["*/generated/*", "*/snapshots/*"]

# Should be kept in sync with mypy.ini in the project root.
# The VSCode mypy extension can only read /mypy.ini.
Expand Down Expand Up @@ -91,3 +92,6 @@ ignore_missing_imports = true
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/replit_river"]
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from textwrap import dedent
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
OrderedDict,
Sequence,
Set,
TextIO,
Tuple,
Union,
cast,
Expand All @@ -32,6 +34,7 @@
TypeExpression,
TypeName,
UnionTypeExpr,
UnknownTypeExpr,
ensure_literal_type,
extract_inner_type,
render_type_expr,
Expand Down Expand Up @@ -80,6 +83,7 @@
Literal,
Optional,
Mapping,
NewType,
NotRequired,
Union,
Tuple,
Expand Down Expand Up @@ -160,6 +164,7 @@ def encode_type(
prefix: TypeName,
base_model: str,
in_module: list[ModuleName],
permit_unknown_members: bool,
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
encoder_name: Optional[str] = None # defining this up here to placate mypy
chunks: List[FileContents] = []
Expand Down Expand Up @@ -256,6 +261,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
TypeName(f"{pfx}{i}"),
base_model,
in_module,
permit_unknown_members=permit_unknown_members,
)
one_of.append(type_name)
chunks.extend(contents)
Expand Down Expand Up @@ -283,7 +289,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
else:
oneof_t = oneof_ts[0]
type_name, _, contents, _ = encode_type(
oneof_t, TypeName(pfx), base_model, in_module
oneof_t,
TypeName(pfx),
base_model,
in_module,
permit_unknown_members=permit_unknown_members,
)
one_of.append(type_name)
chunks.extend(contents)
Expand All @@ -301,6 +311,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
else
""",
)
if permit_unknown_members:
unknown_name = TypeName(f"{prefix}AnyOf__Unknown")
chunks.append(
FileContents(
f"{unknown_name} = NewType({repr(unknown_name)}, object)"
)
)
one_of.append(UnknownTypeExpr(unknown_name))
chunks.append(
FileContents(
f"{prefix} = {render_type_expr(UnionTypeExpr(one_of))}"
Expand Down Expand Up @@ -336,7 +354,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
typeddict_encoder = []
for i, t in enumerate(type.anyOf):
type_name, _, contents, _ = encode_type(
t, TypeName(f"{prefix}AnyOf_{i}"), base_model, in_module
t,
TypeName(f"{prefix}AnyOf_{i}"),
base_model,
in_module,
permit_unknown_members=permit_unknown_members,
)
any_of.append(type_name)
chunks.extend(contents)
Expand All @@ -363,6 +385,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
typeddict_encoder.append(
f"encode_{ensure_literal_type(other)}(x)"
)
if permit_unknown_members:
unknown_name = TypeName(f"{prefix}AnyOf__Unknown")
chunks.append(
FileContents(f"{unknown_name} = NewType({repr(unknown_name)}, object)")
)
any_of.append(UnknownTypeExpr(unknown_name))
if is_literal(type):
typeddict_encoder = ["x"]
chunks.append(
Expand Down Expand Up @@ -404,6 +432,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
prefix,
base_model,
in_module,
permit_unknown_members=permit_unknown_members,
)
elif isinstance(type, RiverConcreteType):
typeddict_encoder = list[str]()
Expand Down Expand Up @@ -446,7 +475,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
return (TypeName("datetime.datetime"), [], [], set())
elif type.type == "array" and type.items:
type_name, module_info, type_chunks, encoder_names = encode_type(
type.items, prefix, base_model, in_module
type.items,
prefix,
base_model,
in_module,
permit_unknown_members=permit_unknown_members,
)
typeddict_encoder.append("TODO: dstewart")
return (ListTypeExpr(type_name), module_info, type_chunks, encoder_names)
Expand All @@ -460,6 +493,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
prefix,
base_model,
in_module,
permit_unknown_members=permit_unknown_members,
)
# TODO(dstewart): This structure changed since we were incorrectly leaking
# ListTypeExprs into codegen. This generated code is
Expand Down Expand Up @@ -494,7 +528,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
) in sorted(list(type.properties.items()), key=lambda xs: xs[0]):
typeddict_encoder.append(f"{repr(name)}:")
type_name, _, contents, _ = encode_type(
prop, TypeName(prefix + name.title()), base_model, in_module
prop,
TypeName(prefix + name.title()),
base_model,
in_module,
permit_unknown_members=permit_unknown_members,
)
encoder_name = None
chunks.extend(contents)
Expand Down Expand Up @@ -685,7 +723,7 @@ def generate_common_client(
chunks.extend(
[
f"from .{model_name} import {class_name}"
for model_name, class_name in modules
for model_name, class_name in sorted(modules, key=lambda kv: kv[1])
]
)
chunks.extend(handshake_chunks)
Expand Down Expand Up @@ -732,6 +770,7 @@ def __init__(self, client: river.Client[Any]):
TypeName(f"{name.title()}Init"),
input_base_class,
module_names,
permit_unknown_members=False,
)
serdes.append(
(
Expand All @@ -745,6 +784,7 @@ def __init__(self, client: river.Client[Any]):
TypeName(f"{name.title()}Input"),
input_base_class,
module_names,
permit_unknown_members=False,
)
serdes.append(
(
Expand All @@ -758,6 +798,7 @@ def __init__(self, client: river.Client[Any]):
TypeName(f"{name.title()}Output"),
"BaseModel",
module_names,
permit_unknown_members=True,
)
serdes.append(
(
Expand All @@ -772,6 +813,7 @@ def __init__(self, client: river.Client[Any]):
TypeName(f"{name.title()}Errors"),
"RiverError",
module_names,
permit_unknown_members=True,
)
if error_type == "None":
error_type = TypeName("RiverError")
Expand Down Expand Up @@ -822,9 +864,9 @@ def __init__(self, client: river.Client[Any]):
.validate_python
"""

assert (
init_type is None or render_init_method
), f"Unable to derive the init encoder from: {input_type}"
assert init_type is None or render_init_method, (
f"Unable to derive the init encoder from: {input_type}"
)

# Input renderer
render_input_method: Optional[str] = None
Expand Down Expand Up @@ -862,9 +904,9 @@ def __init__(self, client: river.Client[Any]):
):
render_input_method = "lambda x: x"

assert (
render_input_method
), f"Unable to derive the input encoder from: {input_type}"
assert render_input_method, (
f"Unable to derive the input encoder from: {input_type}"
)

if output_type == "None":
parse_output_method = "lambda x: None"
Expand Down Expand Up @@ -1038,7 +1080,7 @@ async def {name}(
emitted_files[file_path] = FileContents("\n".join([existing] + contents))

rendered_imports = [
f"from .{dotted_modules} import {', '.join(names)}"
f"from .{dotted_modules} import {', '.join(sorted(names))}"
for dotted_modules, names in imports.items()
]

Expand All @@ -1063,7 +1105,11 @@ def generate_river_client_module(
handshake_chunks: list[str] = []
if schema_root.handshakeSchema is not None:
_handshake_type, _, contents, _ = encode_type(
schema_root.handshakeSchema, TypeName("HandshakeSchema"), "BaseModel", []
schema_root.handshakeSchema,
TypeName("HandshakeSchema"),
"BaseModel",
[],
permit_unknown_members=False,
)
handshake_chunks.extend(contents)
handshake_type = HandshakeType(render_type_expr(_handshake_type))
Expand All @@ -1090,25 +1136,29 @@ def generate_river_client_module(


def schema_to_river_client_codegen(
    read_schema: Callable[[], TextIO],
    target_path: str,
    client_name: str,
    typed_dict_inputs: bool,
    file_opener: Callable[[Path], TextIO],
) -> None:
    """Generate a River client from a JSON schema and write its modules out.

    The schema is loaded from the stream returned by ``read_schema``. Each
    module produced by ``generate_river_client_module`` is piped through
    ``ruff format -`` and the formatted text is written to the handle that
    ``file_opener`` returns for that module's path.

    Args:
        read_schema: Zero-argument callable returning a readable text stream
            containing the JSON schema. The stream is used as a context
            manager and closed once parsed.
        target_path: Directory under which generated modules are placed;
            missing parent directories are created with mode ``0o755``.
        client_name: Forwarded to ``generate_river_client_module``.
        typed_dict_inputs: Forwarded to ``generate_river_client_module``.
        file_opener: Callable taking the destination ``Path`` and returning
            a writable text stream, used as a context manager.

    Raises:
        subprocess.CalledProcessError: If ``ruff format`` exits non-zero.
        OSError: If ``ruff`` cannot be launched (e.g. it is not on ``PATH``).

    In both failure cases the unformatted module text is written to the
    output stream before the exception propagates, so the generated code is
    still available for inspection.
    """
    with read_schema() as schema_file:
        schemas = RiverSchemaFile(json.load(schema_file))
    for subpath, contents in generate_river_client_module(
        client_name, schemas.root, typed_dict_inputs
    ).items():
        module_path = Path(target_path).joinpath(subpath)
        module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        with file_opener(module_path) as out:
            try:
                # check=True turns a failed format into an exception instead
                # of silently writing whatever partial (usually empty) output
                # ruff produced before it bailed out.
                completed = subprocess.run(
                    ["ruff", "format", "-"],
                    input=contents.encode(),
                    stdout=subprocess.PIPE,
                    check=True,
                )
                formatted = completed.stdout.decode("utf-8")
            except BaseException:
                # Deliberately broad (matches the previous bare ``except``):
                # whatever went wrong, leave the raw generated source behind
                # and then re-raise so the caller still sees the failure.
                out.write(contents)
                raise
            # Written outside the ``try`` so a failing write cannot trigger
            # the fallback and append a second copy of the module.
            out.write(formatted)
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import os.path
from pathlib import Path
from typing import TextIO

from .client import schema_to_river_client_codegen
from .schema import proto_to_river_schema_codegen
Expand Down Expand Up @@ -50,8 +52,16 @@ def main() -> None:
elif args.command == "client":
schema_path = os.path.abspath(args.schema)
target_path = os.path.abspath(args.output)

def file_opener(path: Path) -> TextIO:
    """Open ``path`` for writing generated source, truncating existing content.

    The encoding is pinned to UTF-8 rather than left to the platform locale:
    the text written here is Python source (UTF-8 by default) and the
    formatter output fed into it is decoded as UTF-8, so a locale-dependent
    encoding could corrupt or reject non-ASCII characters.
    """
    return open(path, "w", encoding="utf-8")

schema_to_river_client_codegen(
schema_path, target_path, args.client_name, args.typed_dict_inputs
lambda: open(schema_path),
target_path,
args.client_name,
args.typed_dict_inputs,
file_opener,
)
else:
raise NotImplementedError(f"Unknown command {args.command}")
File renamed without changes.
File renamed without changes.
Loading
Loading