-
Notifications
You must be signed in to change notification settings - Fork 393
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Zibing Zhang
committed
Dec 10, 2022
1 parent
05383ca
commit 4ada4b5
Showing
4 changed files
with
170 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from typing import Any, Callable | ||
|
||
from latexify import frontend | ||
|
||
|
||
def check_algorithm( | ||
fn: Callable[..., Any], | ||
latex: str, | ||
**kwargs, | ||
) -> None: | ||
"""Helper to check if the obtained function has the expected LaTeX form. | ||
Args: | ||
fn: Function to check. | ||
latex: LaTeX form of `fn`. | ||
**kwargs: Arguments passed to `frontend.get_latex`. | ||
""" | ||
# Checks the syntax: | ||
# def fn(...): | ||
# ... | ||
# latexified = get_latex(fn, style=ALGORITHM, **kwargs) | ||
latexified = frontend.get_latex(fn, style=frontend.Style.ALGORITHMIC, **kwargs) | ||
assert latexified == latex | ||
|
||
|
||
def test_factorial() -> None: | ||
def fact(n): | ||
if n == 0: | ||
return 1 | ||
else: | ||
return n * fact(n - 1) | ||
|
||
latex = ( | ||
r"\begin{algorithmic} " | ||
r"\If{$n = 0$} " | ||
r"\State \Return $1$ " | ||
r"\Else " | ||
r"\State \Return $n \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$ " | ||
r"\EndIf " | ||
r"\end{algorithmic}" | ||
) | ||
check_algorithm(fact, latex) | ||
|
||
|
||
def test_collatz() -> None: | ||
def collatz(n): | ||
iterations = 0 | ||
while n > 1: | ||
if n % 2 == 0: | ||
n = n / 2 | ||
else: | ||
n = 3 * n + 1 | ||
iterations = iterations + 1 | ||
return iterations | ||
|
||
latex = ( | ||
r"\begin{algorithmic} " | ||
r"\State $\mathrm{iterations} \gets 0$ " | ||
r"\While{$n > 1$} " | ||
r"\If{$n \mathbin{\%} 2 = 0$} " | ||
r"\State $n \gets \frac{n}{2}$ " | ||
r"\Else \State $n \gets 3 n + 1$ " | ||
r"\EndIf " | ||
r"\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$ " | ||
r"\EndWhile " | ||
r"\State \Return $\mathrm{iterations}$ " | ||
r"\end{algorithmic}" | ||
) | ||
|
||
check_algorithm(collatz, latex) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""Package latexify.codegen.""" | ||
|
||
from latexify.codegen import expression_codegen, function_codegen | ||
from latexify.codegen import algorithmic_codegen, expression_codegen, function_codegen | ||
|
||
AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen | ||
ExpressionCodegen = expression_codegen.ExpressionCodegen | ||
FunctionCodegen = function_codegen.FunctionCodegen |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Codegen for single algorithms.""" | ||
import ast | ||
|
||
from latexify import exceptions | ||
from latexify.codegen import codegen_utils, expression_codegen | ||
|
||
|
||
class AlgorithmicCodegen(ast.NodeVisitor): | ||
"""Codegen for single algorithms.""" | ||
|
||
def __init__( | ||
self, *, use_math_symbols: bool = False, use_set_symbols: bool = False | ||
) -> None: | ||
"""Initializer. | ||
Args: | ||
use_math_symbols: Whether to convert identifiers with a math symbol surface | ||
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). | ||
use_set_symbols: Whether to use set symbols or not. | ||
""" | ||
self._expression_codegen = expression_codegen.ExpressionCodegen( | ||
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols | ||
) | ||
|
||
def generic_visit(self, node: ast.AST) -> str: | ||
raise exceptions.LatexifyNotSupportedError( | ||
f"Unsupported AST: {type(node).__name__}" | ||
) | ||
|
||
def visit_Assign(self, node: ast.Assign) -> str: | ||
operands: list[str] = [ | ||
self._expression_codegen.visit(target) for target in node.targets | ||
] | ||
operands.append(self._expression_codegen.visit(node.value)) | ||
operands_latex = r" \gets ".join(operands) | ||
return rf"\State ${operands_latex}$" | ||
|
||
def visit_FunctionDef(self, node: ast.FunctionDef) -> str: | ||
body_strs: list[str] = [self.visit(stmt) for stmt in node.body] | ||
return rf"\begin{{algorithmic}} {' '.join(body_strs)} \end{{algorithmic}}" | ||
|
||
def visit_If(self, node: ast.If) -> str: | ||
cond_latex = self._expression_codegen.visit(node.test) | ||
body_latex = " ".join(self.visit(stmt) for stmt in node.body) | ||
|
||
latex = rf"\If{{${cond_latex}$}} {body_latex}" | ||
|
||
if node.orelse: | ||
latex += r" \Else " | ||
latex += " ".join(self.visit(stmt) for stmt in node.orelse) | ||
|
||
return latex + r" \EndIf" | ||
|
||
def visit_Module(self, node: ast.Module) -> str: | ||
return self.visit(node.body[0]) | ||
|
||
def visit_Return(self, node: ast.Return) -> str: | ||
return ( | ||
rf"\State \Return ${self._expression_codegen.visit(node.value)}$" | ||
if node.value is not None | ||
else codegen_utils.convert_constant(None) | ||
) | ||
|
||
def visit_While(self, node: ast.While) -> str: | ||
cond_latex = self._expression_codegen.visit(node.test) | ||
body_latex = " ".join(self.visit(stmt) for stmt in node.body) | ||
|
||
latex = rf"\While{{${cond_latex}$}} {body_latex}" | ||
|
||
if node.orelse: | ||
latex += r" \Else " | ||
latex += " ".join(self.visit(stmt) for stmt in node.orelse) | ||
|
||
return latex + r" \EndWhile" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters