Skip to content

Commit

Permalink
init alg
Browse files Browse the repository at this point in the history
  • Loading branch information
Zibing Zhang committed Dec 10, 2022
1 parent 05383ca commit 4ada4b5
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 10 deletions.
69 changes: 69 additions & 0 deletions src/integration_tests/algorithmic_style_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any, Callable

from latexify import frontend


def check_algorithm(
fn: Callable[..., Any],
latex: str,
**kwargs,
) -> None:
"""Helper to check if the obtained function has the expected LaTeX form.
Args:
fn: Function to check.
latex: LaTeX form of `fn`.
**kwargs: Arguments passed to `frontend.get_latex`.
"""
# Checks the syntax:
# def fn(...):
# ...
# latexified = get_latex(fn, style=ALGORITHM, **kwargs)
latexified = frontend.get_latex(fn, style=frontend.Style.ALGORITHMIC, **kwargs)
assert latexified == latex


def test_factorial() -> None:
def fact(n):
if n == 0:
return 1
else:
return n * fact(n - 1)

latex = (
r"\begin{algorithmic} "
r"\If{$n = 0$} "
r"\State \Return $1$ "
r"\Else "
r"\State \Return $n \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$ "
r"\EndIf "
r"\end{algorithmic}"
)
check_algorithm(fact, latex)


def test_collatz() -> None:
def collatz(n):
iterations = 0
while n > 1:
if n % 2 == 0:
n = n / 2
else:
n = 3 * n + 1
iterations = iterations + 1
return iterations

latex = (
r"\begin{algorithmic} "
r"\State $\mathrm{iterations} \gets 0$ "
r"\While{$n > 1$} "
r"\If{$n \mathbin{\%} 2 = 0$} "
r"\State $n \gets \frac{n}{2}$ "
r"\Else \State $n \gets 3 n + 1$ "
r"\EndIf "
r"\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$ "
r"\EndWhile "
r"\State \Return $\mathrm{iterations}$ "
r"\end{algorithmic}"
)

check_algorithm(collatz, latex)
3 changes: 2 additions & 1 deletion src/latexify/codegen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Package latexify.codegen."""

from latexify.codegen import expression_codegen, function_codegen
from latexify.codegen import algorithmic_codegen, expression_codegen, function_codegen

AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen
ExpressionCodegen = expression_codegen.ExpressionCodegen
FunctionCodegen = function_codegen.FunctionCodegen
73 changes: 73 additions & 0 deletions src/latexify/codegen/algorithmic_codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Codegen for single algorithms."""
import ast

from latexify import exceptions
from latexify.codegen import codegen_utils, expression_codegen


class AlgorithmicCodegen(ast.NodeVisitor):
"""Codegen for single algorithms."""

def __init__(
self, *, use_math_symbols: bool = False, use_set_symbols: bool = False
) -> None:
"""Initializer.
Args:
use_math_symbols: Whether to convert identifiers with a math symbol surface
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_set_symbols: Whether to use set symbols or not.
"""
self._expression_codegen = expression_codegen.ExpressionCodegen(
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols
)

def generic_visit(self, node: ast.AST) -> str:
raise exceptions.LatexifyNotSupportedError(
f"Unsupported AST: {type(node).__name__}"
)

def visit_Assign(self, node: ast.Assign) -> str:
operands: list[str] = [
self._expression_codegen.visit(target) for target in node.targets
]
operands.append(self._expression_codegen.visit(node.value))
operands_latex = r" \gets ".join(operands)
return rf"\State ${operands_latex}$"

def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
body_strs: list[str] = [self.visit(stmt) for stmt in node.body]
return rf"\begin{{algorithmic}} {' '.join(body_strs)} \end{{algorithmic}}"

def visit_If(self, node: ast.If) -> str:
cond_latex = self._expression_codegen.visit(node.test)
body_latex = " ".join(self.visit(stmt) for stmt in node.body)

latex = rf"\If{{${cond_latex}$}} {body_latex}"

if node.orelse:
latex += r" \Else "
latex += " ".join(self.visit(stmt) for stmt in node.orelse)

return latex + r" \EndIf"

def visit_Module(self, node: ast.Module) -> str:
return self.visit(node.body[0])

def visit_Return(self, node: ast.Return) -> str:
return (
rf"\State \Return ${self._expression_codegen.visit(node.value)}$"
if node.value is not None
else codegen_utils.convert_constant(None)
)

def visit_While(self, node: ast.While) -> str:
cond_latex = self._expression_codegen.visit(node.test)
body_latex = " ".join(self.visit(stmt) for stmt in node.body)

latex = rf"\While{{${cond_latex}$}} {body_latex}"

if node.orelse:
latex += r" \Else "
latex += " ".join(self.visit(stmt) for stmt in node.orelse)

return latex + r" \EndWhile"
35 changes: 26 additions & 9 deletions src/latexify/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import enum
from collections.abc import Callable
from typing import Any, overload

Expand All @@ -16,20 +17,27 @@
_COMMON_PREFIXES = {"math", "numpy", "np"}


# TODO(odashi): move expand_functions to Config.
class Style(str, enum.Enum):
EXPRESSION = "expression"
FUNCTION = "function"
ALGORITHMIC = "algorithmic"


def get_latex(
fn: Callable[..., Any],
*,
style: Style = Style.FUNCTION,
config: cfg.Config | None = None,
**kwargs,
) -> str:
"""Obtains LaTeX description from the function's source.
Args:
fn: Reference to a function to analyze.
config: use defined Config object, if it is None, it will be automatic assigned
style: Style of the LaTeX description, the default is FUNCTION.
config: Use defined Config object, if it is None, it will be automatic assigned
with default value.
**kwargs: dict of Config field values that could be defined individually
**kwargs: Dict of Config field values that could be defined individually
by users.
Returns:
Expand All @@ -38,6 +46,9 @@ def get_latex(
Raises:
latexify.exceptions.LatexifyError: Something went wrong during conversion.
"""
if style == Style.EXPRESSION:
kwargs["use_signature"] = kwargs.get("use_signature", False)

merged_config = cfg.Config.defaults().merge(config=config, **kwargs)

# Obtains the source AST.
Expand All @@ -56,11 +67,17 @@ def get_latex(
tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree)

# Generates LaTeX.
return codegen.FunctionCodegen(
use_math_symbols=merged_config.use_math_symbols,
use_signature=merged_config.use_signature,
use_set_symbols=merged_config.use_set_symbols,
).visit(tree)
if style == Style.ALGORITHMIC:
return codegen.AlgorithmicCodegen(
use_math_symbols=merged_config.use_math_symbols,
use_set_symbols=merged_config.use_set_symbols,
).visit(tree)
else:
return codegen.FunctionCodegen(
use_math_symbols=merged_config.use_math_symbols,
use_signature=merged_config.use_signature,
use_set_symbols=merged_config.use_set_symbols,
).visit(tree)


class LatexifiedFunction:
Expand Down Expand Up @@ -173,7 +190,7 @@ def expression(
This function is a shortcut for `latexify.function` with the default parameter
`use_signature=False`.
"""
kwargs["use_signature"] = kwargs.get("use_signature", False)
kwargs["style"] = Style.EXPRESSION
if fn is not None:
return function(fn, **kwargs)
else:
Expand Down

0 comments on commit 4ada4b5

Please sign in to comment.