pyzag is a library for efficiently training generic models defined with a recursive nonlinear function. Full documentation is available here.
The library is available as open source code with an MIT license.
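A nonlinear recursive function has the form

$$ f(x_{i-1}, x_i; p) = 0, \quad i = 1, 2, \ldots, N $$

with $x_i$ the model state at step $i$, $p$ the model parameters, and $x_0$ a given initial condition.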
While this form seems abstract, it describes a large number of interesting and useful models. For example, consider the ordinary differential equation defined by

$$ \dot{x} = h(x; p) $$

We can convert this into a nonlinear recursive equation by applying a numerical time integration scheme, for example the backward Euler method:

$$ x_i = x_{i-1} + h(x_i; p) \, \Delta t_i $$

where $\Delta t_i = t_i - t_{i-1}$ is the time step. This algebraic equation has our standard form for a nonlinear recursive model:

$$ f(x_{i-1}, x_i; p) = x_i - x_{i-1} - h(x_i; p) \, \Delta t_i = 0 $$
However, defining our time series with an algebraic equation, rather than a differential equation, provides access to a range of models that cannot be expressed as ODEs, for example difference equations.
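As a minimal sketch of this conversion (plain Python, not the pyzag API), the block below writes the backward Euler residual for the example ODE $\dot{x} = -p x$ and generates the sequence by solving the residual equation at every step with Newton's method. The function names `residual`, `d_residual_dx`, and `solve_sequence` are hypothetical and chosen only for illustration.

```python
import math


def residual(x_prev, x, p, dt):
    """Backward Euler residual for dx/dt = -p * x, written as f(x_prev, x; p)."""
    return x - x_prev + p * x * dt


def d_residual_dx(x, p, dt):
    """Derivative of the residual with respect to the unknown state x."""
    return 1.0 + p * dt


def solve_sequence(x0, p, dt, n_steps, tol=1e-12, max_iter=20):
    """Generate x_0, ..., x_N by solving f(x_prev, x; p) = 0 at each step."""
    xs = [x0]
    for _step in range(n_steps):
        x_prev = xs[-1]
        x = x_prev  # start Newton's method from the previous state
        for _it in range(max_iter):
            r = residual(x_prev, x, p, dt)
            if abs(r) < tol:
                break
            x -= r / d_residual_dx(x, p, dt)
        xs.append(x)
    return xs


xs = solve_sequence(x0=1.0, p=2.0, dt=0.01, n_steps=100)
# For this linear ODE each backward Euler step has the closed form x_prev / (1 + p * dt).
closed_form = 1.0 / (1.0 + 2.0 * 0.01) ** 100
```

Because the loop only ever calls the residual and its derivative, the same structure applies to a nonlinear right-hand side or to a difference equation that does not come from an ODE at all.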
The goal of training is to find the parameters $p$ for which the simulated sequence $x_1, \ldots, x_N$ matches a target sequence.
pyzag provides a few building block methods for efficiently generating sequences and their derivatives:
- pyzag can vectorize sequence simulation both over independent instantiations of the same model (i.e. batch vectorization) and over some number of steps $i$. This paper describes the basic idea, but pyzag extends the concept to general nonlinear recursive models. The advantage of the approach is that it can increase the calculation bandwidth if batch parallelism alone is not enough to fully utilize the compute device.
- pyzag implements the parameter gradient calculation with the adjoint method. For long sequences this approach is much more memory efficient than automatic differentiation and is also generally more computationally efficient.
- pyzag also provides several methods for solving the resulting batched, time-chunked nonlinear and linear equations, as well as predictors that start the nonlinear solves from previously simulated pieces of the sequence.
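The time-chunking idea in the list above can be illustrated with a small plain-Python sketch. It is not pyzag's implementation, and the names `step_chunk` and `simulate` are hypothetical. The example model is the backward Euler discretization of $\dot{x} = -p x^3$. Instead of solving one step at a time, `step_chunk` treats `k` consecutive states as the unknowns of a single nonlinear system and updates all of them in each Newton iteration. The Jacobian of that system is lower bidiagonal, and this sketch solves it by simple forward substitution for clarity.

```python
def step_chunk(x_prev, p, dt, k, tol=1e-12, max_iter=50):
    """Solve k consecutive backward Euler steps of dx/dt = -p * x**3 at once."""
    y = [x_prev] * k  # predictor: every unknown starts at the last known state
    for _it in range(max_iter):
        prev = [x_prev] + y[:-1]
        # Residual of each step in the chunk, evaluated for all steps together.
        r = [y[j] - prev[j] + p * dt * y[j] ** 3 for j in range(k)]
        if max(abs(v) for v in r) < tol:
            break
        diag = [1.0 + 3.0 * p * dt * y[j] ** 2 for j in range(k)]
        # The Jacobian has diag on the diagonal and -1 on the subdiagonal.
        # Solve J * delta = -r by forward substitution.
        delta = []
        carry = 0.0
        for j in range(k):
            carry = (-r[j] + carry) / diag[j]
            delta.append(carry)
        y = [y[j] + delta[j] for j in range(k)]
    return y


def simulate(x0, p, dt, n_steps, k):
    """Generate the whole sequence in chunks of at most k steps."""
    xs = [x0]
    while len(xs) - 1 < n_steps:
        size = min(k, n_steps - (len(xs) - 1))
        xs.extend(step_chunk(xs[-1], p, dt, size))
    return xs


sequential = simulate(x0=1.0, p=1.0, dt=0.1, n_steps=20, k=1)
chunked = simulate(x0=1.0, p=1.0, dt=0.1, n_steps=20, k=5)
max_diff = max(abs(a - b) for a, b in zip(sequential, chunked))
```

Both calls solve the same equations, so `sequential` and `chunked` agree to solver tolerance. The benefit of chunking only appears when the residual, the Jacobian, and the linear solve for the whole chunk are evaluated as vectorized operations on a device such as a GPU, which this scalar sketch does not attempt to show.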
pyzag is built on top of PyTorch, integrating the adjoint calculation into PyTorch AD. Users can seamlessly define and train deterministic models using PyTorch primitives.
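To make the adjoint calculation concrete without PyTorch, here is a minimal plain-Python sketch for the scalar recursion obtained from backward Euler applied to $\dot{x} = -p x^3$. The least-squares loss and the names `forward`, `loss`, and `adjoint_gradient` are assumptions made for illustration, not part of pyzag. The backward pass walks the stored states once in reverse, solving one small linear equation per step for the adjoint variable and accumulating the parameter gradient.

```python
def forward(x0, p, dt, n_steps, tol=1e-13, max_iter=50):
    """Generate the sequence by solving x - x_prev + p * dt * x**3 = 0 at each step."""
    xs = [x0]
    for _step in range(n_steps):
        x_prev = xs[-1]
        x = x_prev
        for _it in range(max_iter):
            r = x - x_prev + p * dt * x ** 3
            if abs(r) < tol:
                break
            x -= r / (1.0 + 3.0 * p * dt * x ** 2)
        xs.append(x)
    return xs


def loss(xs, targets):
    """Least-squares mismatch between simulated and target states (steps 1..N)."""
    return 0.5 * sum((x - t) ** 2 for x, t in zip(xs[1:], targets))


def adjoint_gradient(xs, targets, p, dt):
    """Compute d(loss)/dp with one backward sweep over the stored states."""
    grad = 0.0
    lam_next = 0.0  # adjoint variable of step i + 1, zero past the last step
    for i in range(len(xs) - 1, 0, -1):
        dloss_dx = xs[i] - targets[i - 1]
        df_dx = 1.0 + 3.0 * p * dt * xs[i] ** 2
        # The residual of step i + 1 depends on x_i with derivative -1.
        lam = -(dloss_dx - lam_next) / df_dx
        grad += lam * dt * xs[i] ** 3  # df_i/dp = dt * x_i**3
        lam_next = lam
    return grad


x0, dt, n_steps = 1.0, 0.1, 20
targets = forward(x0, 1.5, dt, n_steps)[1:]  # synthetic data from p = 1.5
p = 1.0
grad_adjoint = adjoint_gradient(forward(x0, p, dt, n_steps), targets, p, dt)

# Central finite difference as an independent check.
eps = 1e-6
grad_fd = (loss(forward(x0, p + eps, dt, n_steps), targets)
           - loss(forward(x0, p - eps, dt, n_steps), targets)) / (2.0 * eps)
```

The backward sweep needs only the converged states `xs`, not a record of every Newton iteration, which is the source of the memory savings for long sequences.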
The library also provides helper classes to convert a deterministic model, defined as a nonlinear recursive relation implemented with a PyTorch model, into a statistical model using the pyro library. Specifically, pyzag provides methods for automatically converting the deterministic model to a stochastic model by replacing deterministic parameters with prior distributions, as well as methods for converting models into a hierarchical statistical format to provide dependence across multiple sequences.