Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reference Doc (semi-manual). #978

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/chains/doc_gen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
This generation process of the documentation is *extremely* scrappy and just
an interim solution. It requires significant manual oversight and the code
quality in this directory is non-existent.

The general process is:
1. Document as much as possible in the code, including usage examples, links
etc.
2. Auto-generate `generated-reference.mdx` with `poetry run python
docs/chains/doc_gen/generate_reference.py`. This applies the patch file and,
if there are conflicts, launches meld to resolve them.
3. Proofread `docs/snippets/chains/API-reference.mdx`.
4. If proofreading leads to edits or the upstream docstrings changed a lot,
update the patch file: `diff -u \
docs/chains/doc_gen/generated-reference.mdx \
docs/snippets/chains/API-reference.mdx > \
docs/chains/doc_gen/reference.patch`

For questions, please reach out to @marius-baseten.
215 changes: 215 additions & 0 deletions docs/chains/doc_gen/generate_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# type: ignore # This tool is only for Marius.
"""Script to auto-generate the API reference for Truss Chains."""
import inspect
import pathlib
import shutil
import subprocess
import tempfile
from pathlib import Path

import truss_chains as chains
from sphinx import application

# Minimal root document for the Sphinx build. Its only job is to pull
# `modules.rst` (written next to it by `generate_sphinx_docs`) into a toctree.
# NOTE: reST requires directive options and content to be indented under the
# directive line. Without the indentation `:maxdepth:` / `:caption:` are read
# as a plain field list and `modules` as a paragraph, leaving the toctree empty.
DUMMY_INDEX_RST = """
.. Dummy

Welcome to Truss Chains's documentation!
========================================

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   modules
"""


# Name of the Sphinx builder handed to `application.Sphinx` in
# `generate_sphinx_docs`. The copy-and-patch post-processing step there only
# runs when this is "mdx_adapter".
BUILDER = "mdx_adapter"  # "html"
# Fully-qualified symbols that `_list_imported_symbols` adds on top of the
# names found on the `truss_chains` module itself. They are always registered
# with the `autoclass` directive.
NON_PUBLIC_SYMBOLS = [
    # "truss_chains.definitions.AssetSpec",
    # "truss_chains.definitions.ComputeSpec",
    "truss_chains.deploy.ChainService",
]


# Each section is a `(title, description, symbols)` tuple, consumed in that
# order by `_make_rst_structure`. `symbols` are fully-qualified names; every
# symbol returned by `_list_imported_symbols` must be listed in a section,
# otherwise `_make_rst_structure` raises a `ValueError`.
SECTION_CHAINLET = (
    "Chainlets",
    "APIs for creating user-defined Chainlets.",
    [
        "truss_chains.ChainletBase",
        "truss_chains.depends",
        "truss_chains.depends_context",
        "truss_chains.DeploymentContext",
        "truss_chains.RPCOptions",
        "truss_chains.mark_entrypoint",
    ],
)
SECTION_CONFIG = (
    "Remote Configuration",
    (
        "These data structures specify for each chainlet how it gets deployed "
        "remotely, e.g. dependencies and compute resources."
    ),
    [
        "truss_chains.RemoteConfig",
        "truss_chains.DockerImage",
        "truss_chains.Compute",
        "truss_chains.Assets",
    ],
)
SECTION_UTILITIES = (
    "Core",
    "General framework and helper functions.",
    [
        "truss_chains.deploy_remotely",
        # Not exported from `truss_chains` directly; made available through
        # `NON_PUBLIC_SYMBOLS`.
        "truss_chains.deploy.ChainService",
        "truss_chains.make_abs_path_here",
        "truss_chains.run_local",
        "truss_chains.ServiceDescriptor",
        "truss_chains.StubBase",
        "truss_chains.RemoteErrorDetail",
        # "truss_chains.ChainsRuntimeError",
    ],
)

# Sections are emitted into the generated reference in this order.
SECTIONS = [SECTION_CHAINLET, SECTION_CONFIG, SECTION_UTILITIES]


def _list_imported_symbols(module: object) -> dict[str, str]:
    """Map every documentable symbol to the Sphinx autodoc directive to use.

    Members of `module` whose name starts with an underscore, and members that
    are themselves modules, are skipped. The remaining names are qualified
    with the ``truss_chains.`` prefix. Entries from `NON_PUBLIC_SYMBOLS` are
    added afterwards and always use ``autoclass``.

    Args:
        module: The object (normally the `truss_chains` package) to inspect.

    Returns:
        A dict from fully-qualified symbol name to one of ``"autoclass"``,
        ``"autofunction"`` or ``"autodata"``.
    """

    def _directive_for(member: object) -> str:
        # Classes and plain functions get dedicated directives; everything
        # else is documented as data.
        if inspect.isclass(member):
            return "autoclass"
        if inspect.isfunction(member):
            return "autofunction"
        return "autodata"

    directives: dict[str, str] = {}
    for member_name, member in inspect.getmembers(module):
        if member_name.startswith("_") or inspect.ismodule(member):
            continue
        directives[f"truss_chains.{member_name}"] = _directive_for(member)

    # Extra classes that are not really exported as public API, but are still
    # relevant.
    for extra_symbol in NON_PUBLIC_SYMBOLS:
        directives[extra_symbol] = "autoclass"
    return directives


def _make_rst_structure(chains):
    """Build the reST source of the API reference page.

    The page starts with an "API Reference" title, followed by one heading and
    description per entry in `SECTIONS`, each listing an autodoc directive for
    every symbol of that section.

    Args:
        chains: The module whose public symbols are documented (passed on to
            `_list_imported_symbols`).

    Returns:
        The complete reST document as a single string.

    Raises:
        KeyError: If a section lists a symbol that is not known to
            `_list_imported_symbols`, or that an earlier section entry has
            already consumed.
        ValueError: If some known symbol is not listed in any section.
    """
    exported_symbols = _list_imported_symbols(chains)
    title = "API Reference"
    rst_parts = [title, "=" * len(title) + "\n"]

    for name, descr, symbols in SECTIONS:
        rst_parts.extend([name, "=" * len(name) + "\n", descr, "\n"])

        for symbol in symbols:
            # Popping doubles as bookkeeping: whatever is left at the end was
            # not assigned to any section.
            try:
                kind = exported_symbols.pop(symbol)
            except KeyError:
                raise KeyError(
                    f"Symbol `{symbol}` in section `{name}` is not an exported "
                    "symbol, or it is listed more than once."
                ) from None
            rst_parts.append(f".. {kind}:: {symbol}")
            rst_parts.append("\n")

    if exported_symbols:
        raise ValueError(
            "All symbols must be mapped to a section. Left over:"
            f"{list(exported_symbols.keys())}."
        )
    return "\n".join(rst_parts)


def _clean_build_directory(build_dir: Path) -> None:
    """Remove `build_dir` recursively if it is a directory, then (re)create it.

    Missing parent directories are created as well.
    """
    # `is_dir()` is already False for a path that does not exist, so no
    # separate existence check is needed.
    if build_dir.is_dir():
        shutil.rmtree(build_dir)
    build_dir.mkdir(parents=True, exist_ok=True)


def _apply_patch(
    original_file_path: str, patch_file_path: str, output_file_path: str
) -> None:
    """Apply a patch file to the original and store the result at the output.

    `patch` is run with `-o` pointing at a temporary file, so the original is
    not modified. On success the temporary result is copied to
    `output_file_path`. If `patch` exits non-zero and a `.rej` file exists next
    to the temporary file, `meld` is launched on the original, the current
    output file and the temporary result for manual resolution. Otherwise the
    failure is only printed, not raised. The temporary file is always removed.

    Args:
        original_file_path: File the patch is applied to.
        patch_file_path: Patch to apply.
        output_file_path: Destination of the patched result.

    Raises:
        subprocess.CalledProcessError: If `meld` itself exits non-zero.
    """
    source_path = Path(original_file_path)
    patch_path = Path(patch_file_path)
    destination = Path(output_file_path)

    # The temp file is created only to reserve a unique path, which is then
    # handed to `patch -o`.
    with tempfile.NamedTemporaryFile(delete=False) as scratch:
        patched_tmp = Path(scratch.name)

    patch_cmd = [
        "patch",
        str(source_path),
        "-o",
        str(patched_tmp),
        str(patch_path),
    ]
    try:
        subprocess.run(patch_cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        rejects = patched_tmp.with_suffix(".rej")
        if not rejects.exists():
            print(f"Patch failed: {err.stderr}")
        else:
            print(f"Conflicts found, saved to {rejects}")
            merge_cmd = [
                "meld",
                str(original_file_path),
                str(destination),
                str(patched_tmp),
            ]
            subprocess.run(merge_cmd, check=True)
    else:
        # Copy temp file to final output if no errors
        shutil.copy(patched_tmp, destination)
    finally:
        if patched_tmp.exists():
            patched_tmp.unlink()


def generate_sphinx_docs(
    output_dir: pathlib.Path,
    snippets_dir: pathlib.Path,
) -> None:
    """Run the Sphinx build and, for the MDX builder, patch the result.

    `output_dir` is wiped first. A Sphinx source tree is then written to
    `output_dir / "docs"`: `conf.py` is copied from `sphinx_config.py` next to
    this script, alongside the dummy index and the generated `modules.rst`.
    Sphinx builds into `output_dir`.

    When `BUILDER` is "mdx_adapter", the built `modules.mdx` is copied to
    `generated-reference.mdx` next to this script, and `reference.patch` is
    applied to it to produce `snippets_dir / "chains/API-reference.mdx"`.

    Args:
        output_dir: Scratch directory for the build; existing content is
            deleted.
        snippets_dir: Docs snippets directory that receives the final file.
    """
    _clean_build_directory(output_dir)

    script_dir = pathlib.Path(__file__).parent
    source_dir = output_dir / "docs"
    source_dir.mkdir(parents=True, exist_ok=True)

    sphinx_conf = (script_dir / "sphinx_config.py").read_text()
    (source_dir / "conf.py").write_text(sphinx_conf)
    (source_dir / "index.rst").write_text(DUMMY_INDEX_RST)
    (source_dir / "modules.rst").write_text(_make_rst_structure(chains))

    # Sources and config share one directory; doctrees live beneath it.
    sphinx_app = application.Sphinx(
        srcdir=str(source_dir),
        confdir=str(source_dir),
        outdir=str(Path(output_dir).resolve()),
        doctreedir=str(source_dir / "doctrees"),
        buildername=BUILDER,
    )
    sphinx_app.build()

    if BUILDER != "mdx_adapter":
        return

    doc_gen_dir = script_dir.absolute()
    generated_reference = doc_gen_dir / "generated-reference.mdx"
    shutil.copy(output_dir / "modules.mdx", generated_reference)

    # Apply the stored manual edits to the freshly generated reference.
    _apply_patch(
        str(generated_reference),
        str(doc_gen_dir / "reference.patch"),
        str(snippets_dir / "chains/API-reference.mdx"),
    )


if __name__ == "__main__":
    # Three levels above this script; the `snippets` directory is expected
    # directly underneath.
    docs_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    snippets_dir = docs_root / "snippets"
    generate_sphinx_docs(
        output_dir=pathlib.Path("/tmp/doc_gen"),
        snippets_dir=snippets_dir,
    )
Loading
Loading