Skip to content

Commit

Permalink
Work on pytensor backend
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Nov 24, 2023
1 parent bbf4387 commit e405cb2
Show file tree
Hide file tree
Showing 26 changed files with 14,264 additions and 1,420 deletions.
19 changes: 10 additions & 9 deletions REQUIREMENTS.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
numpy~=1.24.4
scipy~=1.11.2
sympy~=1.12
pandas~=2.1.1
numba~=0.57.1
IPython~=8.15.0
latextable~=1.0.0
pytest~=7.4.2
pytest~=7.4.3
setuptools~=68.2.2
texttable~=1.6.7
numba~=0.58.1
numpy~=1.25.2
sympy~=1.12
scipy~=1.11.3
pandas~=2.1.3
ipython~=8.17.2
pytensor~=2.18
latextable~=1.0.1
texttable~=1.7.0
99 changes: 69 additions & 30 deletions cge_modeling/base/cge.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def __init__(
parameters: Optional[Union[list[Parameter], dict[str, Parameter]]] = None,
equations: Optional[Union[list[Equation], dict[str, Equation]]] = None,
numeraire: Optional[Variable] = None,
parse_equations_to_sympy: bool = True,
):
self.numeraire = None
self.coords = {} if coords is None else coords
self.parse_equations_to_sympy = parse_equations_to_sympy
self._symbolic_coords = {sp.Idx(k): v for k, v in self.coords.items()}

self._variables = {}
Expand All @@ -76,7 +78,8 @@ def __init__(
self._initialize_group(parameters, "parameters")
self._initialize_group(equations, "equations")

self._simplify_unpacked_sympy_representation()
if self.parse_equations_to_sympy:
self._simplify_unpacked_sympy_representation()

if numeraire:
del self._variables[numeraire]
Expand Down Expand Up @@ -228,8 +231,13 @@ def _unpack_objects(self, objs: list[Variable | Parameter | Equation]):
list[Variable | Parameter | Equation]
A list of objects with coordinate indices.
"""

expanded_objects = []
for obj in objs:
if isinstance(obj, Equation) and not self.parse_equations_to_sympy:
expanded_objects.append(obj)
continue

expanded_objs = expand_obj_by_indices(
obj, self.coords, dims=None, on_unused_dim="ignore"
)
Expand Down Expand Up @@ -288,32 +296,50 @@ def add_equation(self, equation: Equation, overwrite: bool = False):

local_dict = var_dict | param_dict | str_dim_to_symbol
fancy_dict = fancy_var_dict | fancy_param_dict | str_dim_to_symbol
if self.parse_equations_to_sympy:
# TODO: Should i call substitute_reduce_ops here to remove the sum/product over dummy indices in the equation
# lists? Downside: it will make very long expressions if the dim labels are long.
try:
sympy_eq = sp.parse_expr(
equation.equation, local_dict=local_dict, transformations="all"
)
fancy_eq = sp.parse_expr(
equation.equation, local_dict=fancy_dict, transformations="all"
)
except Exception as e:
raise ValueError(
f"""Could not parse equation "{equation.name}":\n{equation.equation}\n\nEncountered the """
f"following error:\n{e}"
)

# TODO: Should i call substitute_reduce_ops here to remove the sum/product over dummy indices in the equation
# lists? Downside: it will make very long expressions if the dim labels are long.
sympy_eq = sp.parse_expr(equation.equation, local_dict=local_dict, transformations="all")
fancy_eq = sp.parse_expr(equation.equation, local_dict=fancy_dict, transformations="all")

if self.numeraire:
x = self.numeraire.to_sympy()
sympy_eq = sympy_eq.subs({x: 1})

# Standardize equation
standard_eq = substitute_reduce_ops(sympy_eq.lhs - sympy_eq.rhs, self.coords)

eq_id = equation.eq_id
if eq_id is None:
eq_id = len(self.equations) + 1

new_eq = _SympyEquation(
name=equation.name,
equation=equation.equation,
symbolic_eq=sympy_eq,
_eq=standard_eq,
_fancy_eq=fancy_eq,
dims=find_equation_dims(standard_eq, list(str_dim_to_symbol.values())),
eq_id=eq_id,
)
if self.numeraire:
x = self.numeraire.to_sympy()
sympy_eq = sympy_eq.subs({x: 1})

# Standardize equation
try:
standard_eq = substitute_reduce_ops(sympy_eq.lhs - sympy_eq.rhs, self.coords)
except Exception as e:
raise ValueError(
f"""Could not standardize equation "{equation.name}":\n{sympy_eq}\n\nEncountered the """
f"following error:\n{e}"
)

eq_id = equation.eq_id
if eq_id is None:
eq_id = len(self.equations) + 1

new_eq = _SympyEquation(
name=equation.name,
equation=equation.equation,
symbolic_eq=sympy_eq,
_eq=standard_eq,
_fancy_eq=fancy_eq,
dims=find_equation_dims(standard_eq, list(str_dim_to_symbol.values())),
eq_id=eq_id,
)
else:
new_eq = equation

self._add_object(new_eq, "equations", overwrite)

Expand Down Expand Up @@ -567,6 +593,13 @@ def simulate(

return result

def print_residuals(self, res):
n_vars = len(self.unpacked_variable_names)
endog, exog = res.fitted_values[:n_vars], res.fitted_values[n_vars:]
errors = self.f_system(endog, exog)
for eq, val in zip(self.unpacked_equation_names, errors):
print(f"{eq:<75}: {val:<10.3f}")


def recursive_solve_symbolic(equations, known_values=None, max_iter=100):
"""
Expand Down Expand Up @@ -600,10 +633,16 @@ def recursive_solve_symbolic(equations, known_values=None, max_iter=100):
if len(unknowns) == 1:
unknown = unknowns[0]
solution = sp.solve(eq, unknown)
if solution:
known_values[unknown] = solution[0].subs(known_values).evalf()
new_solution_found = True
remove.append(eq)

if isinstance(solution, list):
if len(solution) == 0:
solution = sp.core.numbers.Zero()
else:
solution = solution[0]

known_values[unknown] = solution.subs(known_values).evalf()
new_solution_found = True
remove.append(eq)
elif len(unknowns) == 0:
remove.append(eq)
for eq in remove:
Expand Down
Empty file.
Loading

0 comments on commit e405cb2

Please sign in to comment.