Skip to content

Commit

Permalink
Make classes/functions ### heading
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Jun 17, 2024
1 parent 1630a55 commit e331a12
Show file tree
Hide file tree
Showing 6 changed files with 843 additions and 785 deletions.
96 changes: 75 additions & 21 deletions docs/chains/doc_gen/generate_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,88 @@
]


def _list_imported_symbols(module: object) -> list[tuple[str, str]]:
imported_symbols = [
(
f"truss_chains.{name}",
SECTION_CHAINLET = (
"Chainlets",
"APIs for creating user-defined Chainlets.",
[
"truss_chains.ChainletBase",
"truss_chains.depends",
"truss_chains.depends_context",
"truss_chains.DeploymentContext",
"truss_chains.RPCOptions",
"truss_chains.mark_entrypoint",
],
)
SECTION_CONFIG = (
"Remote Configuration",
(
"These data structures specify for each chainlet how it gets deployed "
"remotely, e.g. dependencies and compute resources."
),
[
"truss_chains.RemoteConfig",
"truss_chains.DockerImage",
"truss_chains.Compute",
"truss_chains.Assets",
],
)
SECTION_UTILITIES = (
"Core",
"General framework and helper functions.",
[
"truss_chains.deploy_remotely",
"truss_chains.deploy.ChainService",
"truss_chains.make_abs_path_here",
"truss_chains.run_local",
"truss_chains.ServiceDescriptor",
"truss_chains.StubBase",
"truss_chains.RemoteErrorDetail",
# "truss_chains.ChainsRuntimeError",
],
)

SECTIONS = [SECTION_CHAINLET, SECTION_CONFIG, SECTION_UTILITIES]


def _list_imported_symbols(module: object) -> dict[str, str]:
imported_symbols = {
f"truss_chains.{name}": (
"autoclass"
if inspect.isclass(obj)
else "autofunction"
if inspect.isfunction(obj)
else "autodata",
else "autodata"
)
for name, obj in inspect.getmembers(module)
if not name.startswith("_") and not inspect.ismodule(obj)
]
}
# Extra classes that are not really exported as public API, but are still relevant.
imported_symbols.extend((sym, "autoclass") for sym in NON_PUBLIC_SYMBOLS)
# print(imported_symbols)
return sorted(imported_symbols, key=lambda x: x[0].split(".")[-1].lower())
imported_symbols.update({sym: "autoclass" for sym in NON_PUBLIC_SYMBOLS})
return imported_symbols


def _make_rst_structure(chains):
exported_symbols = _list_imported_symbols(chains)
rst_parts = ["API Reference"]
rst_parts.append("=" * len(rst_parts[-1]) + "\n")

for name, descr, symbols in SECTIONS:
rst_parts.append(name)
rst_parts.append("=" * len(rst_parts[-1]) + "\n")
rst_parts.append(descr)
rst_parts.append("\n")

for symbol in symbols:
kind = exported_symbols.pop(symbol)
rst_parts.append(f".. {kind}:: {symbol}")
rst_parts.append("\n")

if exported_symbols:
raise ValueError(
"All symbols must be mapped to a section. Left over:"
f"{list(exported_symbols.keys())}."
)
return "\n".join(rst_parts)


def _clean_build_directory(build_dir: Path) -> None:
Expand Down Expand Up @@ -118,18 +183,7 @@ def generate_sphinx_docs(
docs_dir.mkdir(parents=True, exist_ok=True)
(docs_dir / "conf.py").write_text(config_file.read_text())
(docs_dir / "index.rst").write_text(DUMMY_INDEX_RST)

exported_symbols = _list_imported_symbols(chains)
rst_parts = ["API Reference\n============="]
for symbol, kind in exported_symbols:
rst_parts.append(
f"""
.. {kind}:: {symbol}
"""
)

(docs_dir / "modules.rst").write_text("\n".join(rst_parts))
(docs_dir / "modules.rst").write_text(_make_rst_structure(chains))

app = application.Sphinx(
srcdir=str(docs_dir),
Expand Down
Loading

0 comments on commit e331a12

Please sign in to comment.