diff --git a/README.md b/README.md
index 25c34ce..2bc3db3 100644
--- a/README.md
+++ b/README.md
@@ -8,17 +8,20 @@ A toolbox to compute the robustness of STL formulas using computation graphs. Th
Requires Python 3.10+
-Clone the repo.
+Install the repo:
- Make a venv and activate it
-
-`python3 -m venv stljax_venv`
-
-`source stljax_venv/bin/activate`
+```
+pip install git+https://github.com/UW-CTRL/stljax.git
+```
-Go into the `stljax` folder. Then to install:
+Alternatively, to install the package in editable mode:
-`pip install -e .`
+```
+git clone https://github.com/UW-CTRL/stljax.git
+cd stljax
+pip install -e .
+```
+(Using a virtual environment is recommended.)
## Usage
@@ -26,7 +29,6 @@ Go into the `stljax` folder. Then to install:
* Setting up signals for the formulas, including the use of Expressions and Predicates
* Defining STL formulas and visualizing them
* Evaluating STL robustness, and robustness trace
-* Gradient descent on STL parameters and signal parameters.
## (New) Features
@@ -35,13 +37,13 @@ stljax leverages to benefits of jax and automatic differentiation!
Aside from using jax as the backend, stljax is more recent and tidier implementation of stlcg which was originally implemented in PyTorch back ~2019.
- Removed the `distributed_mean` hack from original stlcg implementation. jax keeps track of multiple max/min values and will distribute the gradients across all max/min values!
-- Incorporation of the smooth max/min presented in [Optimization with Temporal and Logical Specifications via Generalized Mean-based Smooth Robustness Measures](https://arxiv.org/abs/2405.10996) by Samet Uzun, Purnanand Elango, Pierre-Loic Garoche, Behcet Acikmese
- - Use `approx_method="gmsr"` and `temperature=(eps, p)`
+
## Tags
| Tags 🏷️ | Description |
| --------- | ----------- |
+| v.1.1.0 | General code improvements. Added a recurrent implementation and example notebooks. |
| v.1.0.0 | Removed awkward expected signal dimension & leverage vmap for batched inputs. Masking for temporal operations & remove need to reverse signals. |
| v0.0.0 | A transfer from the 2019 PyTorch implementation to Jax + some tidying + adding Predicates + reversing signal automatically. |
@@ -107,11 +109,12 @@ We can use `jax.vmap` to handle multiple signals at once.
`jax.vmap(formula)(signals) # signals is shape [bs, time_dim,...]`
-
NOTE: Need to take care for formulas defined with Expressions and need multiple inputs. Need a wrapper since `jax.vmap` doesn't like tuples in a single argument.
+
+
## TODOs
-- re-implement stlcg (PyTorch) with the latest version of PyTorch.
+- manage signal reversal internally for recurrent cases.
## Publications
diff --git a/demo.ipynb b/demo.ipynb
index aa2b98d..f265524 100644
--- a/demo.ipynb
+++ b/demo.ipynb
@@ -1,2080 +1,389 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "view-in-github"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Ym9wI1PDEH34"
- },
- "source": [
- "If running in colab, in a cell, run the following command:\n",
- "\n",
- "`!pip install --upgrade git+https://github.com/UW-CTRL/stljax.git`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "bP3ZmMocCRef",
- "outputId": "e8ce448d-f09a-4ac2-fb0b-75a9e927aafe"
- },
- "outputs": [],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from stljax.formula import *\n",
- "from stljax.viz import *\n",
- "\n",
- "import functools"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nSqJZK4OCReg"
- },
- "source": [
- "## NOTE\n",
- "If using Expressions to define formulas, `stljax` expects input signals to be of size `[time_dim]`.\n",
- "If using Predicates to define formulas, `stljax` expects input signals to be of size `[time_dim, state_dim]` where `state_dim` is the expected input size of your predicate function.\n",
- "\n",
- "Note: With the `mask` version, we do not need to worry about reversing the signal."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "8zoZeNnFCReg"
- },
- "outputs": [],
- "source": [
- "# some helper functions\n",
- "\n",
- "@jax.jit\n",
- "def dynamics_discrete_step(state, control, dt=0.1):\n",
- " '''Single integrator 2d dynamics'''\n",
- " return state + control * dt\n",
- "\n",
- "@jax.jit\n",
- "def simulate_dynamics(controls, state0, dt):\n",
- " T = controls.shape[0]\n",
- " _states = [state0]\n",
- " for t in range(T):\n",
- " _states.append(dynamics_discrete_step(_states[-1], controls[t,:], dt))\n",
- " return jnp.concatenate(_states, 0)\n",
- "\n",
- "@jax.jit\n",
- "def compute_distance_to_origin(states):\n",
- " return jnp.linalg.norm(states[...,:2], axis=-1, keepdims=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 499
- },
- "id": "enRG3dDrCReh",
- "outputId": "af9bcb65-ca49-48bb-b268-9bb5dc04d8f6"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(np.float64(-2.8336225271224977),\n",
- " np.float64(0.658743929862976),\n",
- " np.float64(-2.5266865134239196),\n",
- " np.float64(0.6441279292106629))"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# In this example, using a control sequence to generate a state trajectory\n",
- "T = 25 # time horizon\n",
- "dt = 0.1\n",
- "np.random.seed(123)\n",
- "controls = jnp.array(np.random.randn(T,2)) # generate control sequence\n",
- "state0 = jnp.array(np.random.randn(1,2)) - 1.0 # initial state\n",
- "states = simulate_dynamics(controls, state0, dt) # simulate state trajectory\n",
- "\n",
- "# plotting the trajectory (should look noisy/random)\n",
- "fig, ax = plt.subplots()\n",
- "ax.plot(*states.T)\n",
- "ax.scatter(states[0,:1], states[0,1:], label=\"start\")\n",
- "ax.scatter(states[-1,:1], states[-1,1:], label=\"end\")\n",
- "circle1 = plt.Circle((0, 0), 0.5, color='C2', alpha=0.4)\n",
- "ax.add_patch(circle1)\n",
- "\n",
- "plt.xlim([-5,1])\n",
- "plt.ylim([-4,1])\n",
- "plt.legend()\n",
- "plt.grid()\n",
- "plt.axis(\"equal\")\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "oYow-oBfCReh"
- },
- "source": [
- "## Using Expressions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "pgkKR6e4CReh",
- "outputId": "8d27aa1b-022e-4689-d001-3b0af39b8ae5"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Robustness trace: [-2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175\n",
- " -2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175\n",
- " -2.256407 -2.3951554 -2.4112952 -2.477519 -2.5186005 -2.5186005\n",
- " -2.5186005 -2.5186005 -2.5186005 -2.5186005 -2.5186005 -2.5186005\n",
- " -2.5186005 -2.5186005]\n",
- "Robustness value: -2.2557175\n"
- ]
- }
- ],
- "source": [
- "# first define Expression (with None value)\n",
- "# NOTE: Expressions are used for setting up predicates.\n",
- "# The values associated with an Expression is more for convenience.\n",
- "# You can use a jnp.array directly when evaluating a formula, rather than using an Expression with values populated.\n",
- "distance_to_origin_exp = Expression(\"magnitude\", value=None)\n",
- "\n",
- "# formula Eventually distance to origin is less than 0.5\n",
- "formula_exp = Eventually(distance_to_origin_exp < 0.5)\n",
- "\n",
- "\n",
- "# this will throw error since the Expression value is None.\n",
- "# (commented out for convenience in running the notebook)\n",
- "# formula(distance_to_origin_exp)\n",
- "\n",
- "# setting value for\n",
- "states_norm = compute_distance_to_origin(states) # compute distance to origin, size [1, 26, 1]\n",
- "\n",
- "distance_to_origin_exp.set_value(states_norm) # set value for Expression\n",
- "\n",
- "\n",
- "# compute robustness trace and value\n",
- "# inputs are Expression objects\n",
- "# since reverse=False, stljax will automatically time reverse the input signal, and warn user about it.\n",
- "\n",
- "# robustness trace\n",
- "print(\"Robustness trace: \", formula_exp(distance_to_origin_exp).squeeze())\n",
- "\n",
- "# robustness value\n",
- "print(\"Robustness value: \", formula_exp.robustness(distance_to_origin_exp).squeeze())\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "id": "S3xuqsYvCRei"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "5.36 μs ± 186 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
- "8.62 μs ± 274 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
- ]
- }
- ],
- "source": [
- "# we can jit the robustness function and its gradient.\n",
- "f = jax.jit(formula_exp.robustness)\n",
- "g = jax.jit(jax.grad(f))\n",
- "\n",
- "# measure the time it takes to compute a forward and backward pass\n",
- "# very fast!\n",
- "%timeit f(distance_to_origin_exp.value)\n",
- "%timeit g(distance_to_origin_exp.value)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(Array([0.4814521 , 0.49946868, 0.45444584, 0.44981503], dtype=float32),\n",
- " Array([[-0., -0., -0., -0., -0., -0., -0., -0., -0., -1., -0., -0., -0.,\n",
- " -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],\n",
- " [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,\n",
- " -0., -0., -0., -0., -0., -0., -0., -0., -1., -0., -0., -0.],\n",
- " [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,\n",
- " -0., -0., -0., -0., -0., -0., -0., -0., -1., -0., -0., -0.],\n",
- " [-0., -0., -0., -0., -0., -0., -0., -1., -0., -0., -0., -0., -0.,\n",
- " -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.]], dtype=float32))"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# we can batch the computation\n",
- "bs = 4\n",
- "key = jax.random.PRNGKey(123) # Random seed is explicit in JAX\n",
- "signals = jax.random.uniform(key, shape=(bs, T)) \n",
- "\n",
- "jax.vmap(f)(signals), jax.vmap(g)(signals)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([[-0.02582416, -0.0514056 , -0.01276034, -0.04622918, -0.03386938,\n",
- " -0.06998569, -0.05163663, -0.06342276, -0.03815436, -0.07147011,\n",
- " -0.01940775, -0.05362437, -0.04669334, -0.03519493, -0.0166717 ,\n",
- " -0.02356261, -0.02027061, -0.00810847, -0.04404066, -0.06902726,\n",
- " -0.04640958, -0.04664499, -0.03303727, -0.03973469, -0.03281348],\n",
- " [-0.06602133, -0.0323073 , -0.01701606, -0.02895677, -0.01675591,\n",
- " -0.03107198, -0.02742128, -0.04578061, -0.01740972, -0.07078583,\n",
- " -0.04380497, -0.06857068, -0.01038632, -0.00757802, -0.04489319,\n",
- " -0.0553256 , -0.04845608, -0.0576579 , -0.01687001, -0.04463371,\n",
- " -0.02515552, -0.07106859, -0.06927916, -0.00785302, -0.07494054],\n",
- " [-0.0698043 , -0.01172489, -0.03137346, -0.07852152, -0.03564867,\n",
- " -0.06827941, -0.01655721, -0.03844461, -0.01342334, -0.03133292,\n",
- " -0.0134404 , -0.02768762, -0.03235616, -0.03376853, -0.02919685,\n",
- " -0.06583954, -0.05004779, -0.02036106, -0.01206894, -0.01921945,\n",
- " -0.07753026, -0.0851762 , -0.05762396, -0.03871985, -0.04185297],\n",
- " [-0.03141921, -0.02370912, -0.01929593, -0.04155137, -0.01883044,\n",
- " -0.05980545, -0.03337944, -0.09266003, -0.03492882, -0.03324058,\n",
- " -0.05988078, -0.06018983, -0.02499974, -0.01668277, -0.02550981,\n",
- " -0.0155309 , -0.07851798, -0.01412536, -0.06788316, -0.04530248,\n",
- " -0.02928593, -0.02955182, -0.01892633, -0.03209555, -0.09269717]], dtype=float32)"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# We can apply different max/min approximations\n",
- "bs = 4\n",
- "key = jax.random.PRNGKey(123) # Random seed is explicit in JAX\n",
- "signals = jax.random.uniform(key, shape=(bs, T)) # shape [bs, T]\n",
- "\n",
- "# Default: approx_method=\"true\" -- taking gradient using the true min/max function\n",
- "jax.vmap(jax.grad(formula_exp.robustness))(signals)\n",
- "\n",
- "# taking gradient using a min/max approximation method (specify method and temperature), and the gradients should be \"spread\" to other indices as well\n",
- "\n",
- "# logsumexp approximation\n",
- "foo = functools.partial(formula_exp.robustness, approx_method=\"logsumexp\", temperature=1.) \n",
- "jax.vmap(jax.grad(foo))(signals)\n",
- "\n",
- "# # softmax approximation\n",
- "foo = functools.partial(formula_exp.robustness, approx_method=\"softmax\", temperature=1.) \n",
- "jax.vmap(jax.grad(foo))(signals)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "bUfGQOrGCRei"
- },
- "source": [
- "## Using Predicates\n",
- "\n",
- "Alternatively, we can define the predicate function of an STL formula where predicate function is μ: Rⁿ → R and the input is an n-dimensional signal."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "zdhRJXV1CRei",
- "outputId": "82fdef6f-4f62-4f20-d447-65ba69843e25"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Robustness trace: [-2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175\n",
- " -2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175 -2.2557175\n",
- " -2.256407 -2.3951554 -2.4112952 -2.477519 -2.5186005 -2.5186005\n",
- " -2.5186005 -2.5186005 -2.5186005 -2.5186005 -2.5186005 -2.5186005\n",
- " -2.5186005 -2.5186005]\n",
- "Robustness value: -2.2557175\n"
- ]
- }
- ],
- "source": [
- "distance_to_origin_pred = Predicate(\"magnitude\", predicate_function=compute_distance_to_origin)\n",
- "formula_pred = Eventually(distance_to_origin_pred < 0.5)\n",
- "\n",
- "formula_pred(states).squeeze()\n",
- "\n",
- "# compute robustness trace and value\n",
- "# inputs are jnp.arrays\n",
- "# robustness trace\n",
- "print(\"Robustness trace: \", formula_pred(states).squeeze() )\n",
- "\n",
- "# robustness value\n",
- "print(\"Robustness value: \", formula_pred.robustness(states).squeeze())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "id": "CHpFtaVnCRei"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "6.39 μs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
- "9.53 μs ± 44.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
- ]
- }
- ],
- "source": [
- "# we can jit the robustness function and its gradient.\n",
- "f = jax.jit(formula_pred.robustness)\n",
- "g = jax.jit(jax.grad(f))\n",
- "\n",
- "# measure the time it takes to compute a forward and backward pass\n",
- "# very fast!\n",
- "%timeit f(states)\n",
- "%timeit g(states)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "M170nntGCRei",
- "outputId": "6e2a4aea-6f7f-47f8-c319-d05cadb5a2cb"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(Array([0.4808055 , 0.16095048, 0.35008496, 0.2274302 ], dtype=float32),\n",
- " Array([[[-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0.5642942 , -0.82557374],\n",
- " [-0. , -0. ]],\n",
- " \n",
- " [[-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0.46162838, -0.8870734 ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ]],\n",
- " \n",
- " [[-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0.9965015 , -0.08357489],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ]],\n",
- " \n",
- " [[-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0. , -0. ],\n",
- " [-0.5923696 , -0.8056663 ],\n",
- " [-0. , -0. ]]], dtype=float32))"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# we can batch the computation\n",
- "bs = 4\n",
- "key = jax.random.PRNGKey(123) # Random seed is explicit in JAX\n",
- "signals = jax.random.uniform(key, shape=(bs, T, 2)) \n",
- "\n",
- "jax.vmap(f)(signals), jax.vmap(g)(signals)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([[[-0.02352086, -0.02653968],\n",
- " [-0.01435679, -0.00693523],\n",
- " [-0.03308038, -0.01781628],\n",
- " [-0.01722912, -0.02483274],\n",
- " [-0.04262575, -0.06899171],\n",
- " [-0.0105945 , -0.01726518],\n",
- " [-0.02208826, -0.01126937],\n",
- " [-0.02735491, -0.02297586],\n",
- " [-0.01796402, -0.00646728],\n",
- " [-0.02368197, -0.00658443],\n",
- " [-0.01846283, -0.01260185],\n",
- " [-0.05555077, -0.02745588],\n",
- " [-0.00655061, -0.0230543 ],\n",
- " [-0.07459132, -0.04124878],\n",
- " [-0.01915025, -0.02961909],\n",
- " [-0.01502226, -0.00939996],\n",
- " [-0.01998683, -0.01625802],\n",
- " [-0.00708358, -0.00567479],\n",
- " [-0.01812643, -0.01658712],\n",
- " [-0.01440577, -0.00744342],\n",
- " [-0.05441307, -0.07690613],\n",
- " [-0.00256752, -0.01885604],\n",
- " [-0.08005733, -0.02827522],\n",
- " [-0.07011323, -0.10257705],\n",
- " [-0.02251863, -0.02088918]],\n",
- "\n",
- " [[-0.0288817 , -0.04554031],\n",
- " [-0.02986922, -0.03051575],\n",
- " [-0.06544525, -0.0120479 ],\n",
- " [-0.05150944, -0.03064507],\n",
- " [-0.01373286, -0.01273497],\n",
- " [-0.01047864, -0.02845001],\n",
- " [-0.0125201 , -0.03490231],\n",
- " [-0.02127206, -0.01206193],\n",
- " [-0.00927833, -0.02261019],\n",
- " [-0.02714585, -0.05306787],\n",
- " [-0.01446127, -0.00606696],\n",
- " [-0.00537809, -0.02302884],\n",
- " [-0.0235438 , -0.00923518],\n",
- " [-0.0158476 , -0.00633811],\n",
- " [-0.03734772, -0.01784552],\n",
- " [-0.04300226, -0.04666027],\n",
- " [-0.03669007, -0.02190299],\n",
- " [-0.05762413, -0.00766001],\n",
- " [-0.02510785, -0.01035139],\n",
- " [-0.01190622, -0.00796054],\n",
- " [-0.00665629, -0.01547243],\n",
- " [-0.01100077, -0.01839606],\n",
- " [-0.03262926, -0.06270097],\n",
- " [-0.01855177, -0.05273074],\n",
- " [-0.07838853, -0.04113434]],\n",
- "\n",
- " [[-0.01167339, -0.02134102],\n",
- " [-0.05936034, -0.03223828],\n",
- " [-0.01480538, -0.02571858],\n",
- " [-0.00940478, -0.02955442],\n",
- " [-0.02773137, -0.00883445],\n",
- " [-0.00646909, -0.05561024],\n",
- " [-0.04886941, -0.02358606],\n",
- " [-0.02293241, -0.01577896],\n",
- " [-0.01866045, -0.02070966],\n",
- " [-0.01848573, -0.02131906],\n",
- " [-0.01669237, -0.01865314],\n",
- " [-0.03718995, -0.01123359],\n",
- " [-0.01240873, -0.00944648],\n",
- " [-0.0190175 , -0.03663538],\n",
- " [-0.02860874, -0.00258437],\n",
- " [-0.01791863, -0.00819362],\n",
- " [-0.01327199, -0.01249256],\n",
- " [-0.05429133, -0.08259395],\n",
- " [-0.0474679 , -0.00997441],\n",
- " [-0.00964289, -0.01495654],\n",
- " [-0.01428238, -0.0178033 ],\n",
- " [-0.10850457, -0.00910009],\n",
- " [-0.02777518, -0.05727986],\n",
- " [-0.01082244, -0.01753074],\n",
- " [-0.04805903, -0.02802218]],\n",
- "\n",
- " [[-0.01940999, -0.03582613],\n",
- " [-0.02017951, -0.01958448],\n",
- " [-0.04156577, -0.02017255],\n",
- " [-0.02964688, -0.00185195],\n",
- " [-0.00766283, -0.07021606],\n",
- " [-0.00951366, -0.01608 ],\n",
- " [-0.00285715, -0.02878935],\n",
- " [-0.02630167, -0.0143526 ],\n",
- " [-0.01493212, -0.01433359],\n",
- " [-0.00281592, -0.02805223],\n",
- " [-0.01263587, -0.02154564],\n",
- " [-0.00839119, -0.01035679],\n",
- " [-0.00894236, -0.01246331],\n",
- " [-0.00970442, -0.02231573],\n",
- " [-0.03043569, -0.04590179],\n",
- " [-0.0218413 , -0.0091681 ],\n",
- " [-0.0386564 , -0.08243714],\n",
- " [-0.0380757 , -0.00910007],\n",
- " [-0.02410734, -0.04115387],\n",
- " [-0.00632994, -0.06267686],\n",
- " [-0.02284043, -0.02331491],\n",
- " [-0.0064186 , -0.03530146],\n",
- " [-0.02465762, -0.02741359],\n",
- " [-0.06220399, -0.08460201],\n",
- " [-0.04395318, -0.02402879]]], dtype=float32)"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# We can apply different max/min approximations\n",
- "bs = 4\n",
- "key = jax.random.PRNGKey(123) # Random seed is explicit in JAX\n",
- "signals = jax.random.uniform(key, shape=(bs, T, 2)) # shape [bs, T, state_dim]\n",
- "\n",
- "# Default: approx_method=\"true\" -- taking gradient using the true min/max function\n",
- "jax.vmap(jax.grad(formula_pred.robustness))(signals)\n",
- "\n",
- "# taking gradient using a min/max approximation method (specify method and temperature), and the gradients should be \"spread\" to other indices as well\n",
- "\n",
- "# logsumexp approximation\n",
- "foo = functools.partial(formula_pred.robustness, approx_method=\"logsumexp\", temperature=1.) \n",
- "jax.vmap(jax.grad(foo))(signals)\n",
- "\n",
- "# # softmax approximation\n",
- "foo = functools.partial(formula_pred.robustness, approx_method=\"softmax\", temperature=1.) \n",
- "jax.vmap(jax.grad(foo))(signals)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pMmyk6x2CRej"
- },
- "source": [
- "## Gradient descent to optimize control inputs\n",
- "\n",
- "Now, we can perform gradient descent on the control inputs to make progress towards the formula being true."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pLreo31uCRej"
- },
- "source": [
- "### Using `Expression`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "id": "CGRcCX-kCRej"
- },
- "outputs": [],
- "source": [
- "# set random initial state and control\n",
- "np.random.seed(123)\n",
- "T = 51 # time horizon\n",
- "dt = 0.1 # time step size\n",
- "ts = jnp.array([t * dt for t in range(T)])\n",
- "umax = 1.0 # max control limit\n",
- "\n",
- "controls = jnp.array(np.random.randn(1,T,2))\n",
- "state0 = jnp.ones(2).reshape([1,2]) * 3.\n",
- "obstacle_center = jnp.ones([1,2]) * 2.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 340
- },
- "id": "kSKcXrjWCRej",
- "outputId": "99541049-5385-4a64-e9f3-3e71bf7e2007"
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# defining formula\n",
- "distance_to_origin = Expression(\"magnitude\", None)\n",
- "distance_to_obstacle = Expression(\"distance_to_obs\", None)\n",
- "reach = Eventually(distance_to_origin < 0.5)\n",
- "avoid = Always(distance_to_obstacle > 0.5)\n",
- "formula = reach & avoid\n",
- "\n",
- "make_stl_graph(formula)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "h9D-hcY8CRej",
- "outputId": "de267035-5188-41a2-de4a-fc5e29e78c00"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(Array(6.845453, dtype=float32), Array(6.995544, dtype=float32))"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "@functools.partial(jax.jit, static_argnames=(\"approx_method\"))\n",
- "def loss(controls, state0, umax, dt, coeffs=[1., 0.1, 5.], approx_method=\"true\", temperature=None):\n",
- " # generate trajectory from control sequence \n",
- " traj = simulate_dynamics(controls, state0, dt)\n",
- " # compute distance_to_origin and distance_to_obstacle\n",
- " distance_to_origin_signal = jnp.linalg.norm(traj, axis=-1)\n",
- " distance_to_obstacle_signal = jnp.linalg.norm(traj - obstacle_center, axis=-1)\n",
- " # loss functions\n",
- " input_signal = (distance_to_origin_signal, distance_to_obstacle_signal)\n",
- " loss_robustness = jax.nn.relu(-formula.robustness(input_signal, approx_method=approx_method, temperature=temperature))\n",
- " loss_control_smoothness = (jnp.diff(controls, axis=1)**2).sum(-1).mean() # make controls smoother\n",
- " loss_control_limits = jax.nn.relu(jnp.linalg.norm(controls, axis=-1) - umax).mean() # penalize control limit violation\n",
- " return coeffs[0] * loss_robustness + coeffs[1] * loss_control_smoothness + coeffs[2] * loss_control_limits\n",
- "\n",
- "# @jax.jit\n",
- "def true_robustness(controls, state0, dt):\n",
- " # generate trajectory from control sequence and reverse along time dimension\n",
- " # user has to manually reverse it since we are inputing jnp.array to compute robustness, instead of an Expression\n",
- " traj = simulate_dynamics(controls, state0, dt)\n",
- " # compute distance_to_origin and distance_to_obstacle\n",
- " distance_to_origin_signal = jnp.linalg.norm(traj, axis=-1)\n",
- " distance_to_obstacle_signal = jnp.linalg.norm(traj - obstacle_center, axis=-1)\n",
- " # loss functions\n",
- " input_signal = (distance_to_origin_signal, distance_to_obstacle_signal)\n",
- " return formula.robustness(input_signal)\n",
- "\n",
- "def temperature_schedule(i, i_max, start_temp, end_temp, scale=5):\n",
- " i_ = i\n",
- " center = i_max / 2\n",
- " return jax.nn.sigmoid((i_ - center) / scale) * (end_temp - start_temp) + start_temp\n",
- "\n",
- "# compare true value with max/min approximation\n",
- "loss(controls, state0, umax, dt), loss(controls, state0, umax, dt, approx_method=\"softmax\", temperature=5)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "id": "clBy6pVHCRej"
- },
- "outputs": [],
- "source": [
- "controls = jnp.array(np.random.randn(T,2))\n",
- "\n",
- "states_ = [simulate_dynamics(controls, state0, dt)]\n",
- "lr = 0.5 # learning rate\n",
- "approx_method = \"true\"\n",
- "n_steps = 1000 # number of gradient steps\n",
- "n_steps_extra = 10\n",
- "coeffs = [1., 0.1, 5.]\n",
- "\n",
- "# jit the gradient function to speed things up (by A LOT).\n",
- "grad_jit = jax.jit(jax.grad(loss, 0), static_argnames=\"approx_method\")\n",
- "# temperature schedule parameters\n",
- "start_temp = 50\n",
- "end_temp = 500\n",
- "scale = 5\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "knWf1b-6CRej",
- "outputId": "ceff89d5-88ce-4951-c202-dcf375e2f00b"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " 0 -- true robustness: -3.68 smoothness: 1.81 control limits: 0.34\n",
- " 50 -- true robustness: -0.04 smoothness: 0.36 control limits: 0.03\n",
- "100 -- true robustness: 0.07 smoothness: 0.26 control limits: 0.00\n",
- "150 -- true robustness: 0.05 smoothness: 0.20 control limits: 0.00\n",
- "200 -- true robustness: 0.07 smoothness: 0.16 control limits: 0.00\n",
- "250 -- true robustness: 0.06 smoothness: 0.13 control limits: 0.00\n",
- "300 -- true robustness: 0.04 smoothness: 0.11 control limits: 0.00\n",
- "350 -- true robustness: 0.04 smoothness: 0.09 control limits: 0.00\n",
- "400 -- true robustness: 0.04 smoothness: 0.08 control limits: 0.00\n",
- "450 -- true robustness: 0.03 smoothness: 0.07 control limits: 0.00\n",
- "500 -- true robustness: 0.03 smoothness: 0.07 control limits: 0.00\n",
- "550 -- true robustness: 0.03 smoothness: 0.06 control limits: 0.00\n",
- "600 -- true robustness: 0.02 smoothness: 0.06 control limits: 0.00\n",
- "650 -- true robustness: 0.03 smoothness: 0.06 control limits: 0.00\n",
- "700 -- true robustness: 0.02 smoothness: 0.06 control limits: 0.00\n",
- "750 -- true robustness: 0.01 smoothness: 0.06 control limits: 0.00\n",
- "800 -- true robustness: 0.01 smoothness: 0.05 control limits: 0.00\n",
- "850 -- true robustness: 0.01 smoothness: 0.06 control limits: 0.00\n",
- "900 -- true robustness: 0.03 smoothness: 0.07 control limits: 0.00\n",
- "950 -- true robustness: 0.07 smoothness: 0.08 control limits: 0.00\n"
- ]
- }
- ],
- "source": [
- "temperatures = temperature_schedule(jnp.arange(n_steps), n_steps, start_temp, end_temp, scale)\n",
- "for i in range(n_steps):\n",
- " temperature = temperature_schedule(i, n_steps, start_temp, end_temp, scale)\n",
- " g = grad_jit(controls, state0, umax, dt, coeffs, approx_method, temperatures[i]) # take gradient\n",
- " controls -= g * lr\n",
- " states_.append(simulate_dynamics(controls, state0, dt))\n",
- " if (i % 50) == 0:\n",
- " print(\"%3i -- true robustness: %.2f smoothness: %.2f control limits: %.2f\"%(i, true_robustness(controls, state0, dt), loss(controls, state0, umax, dt, coeffs=[0., 1., 0.]), loss(controls, state0, umax, dt, coeffs=[0., 0., 1.])))\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 530
- },
- "id": "krMVpNDNCRej",
- "outputId": "123b59bc-a063-45d5-b576-d420ea714a8e"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Text(0, 0.5, 'Controls')"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig, axs = plt.subplots(1,3, figsize=(15,4)) # note we must use plt.subplots, not plt.subplot\n",
- "\n",
- "ax = axs[0]\n",
- "circle1 = plt.Circle((0, 0), 0.5, color='C2', alpha=0.4)\n",
- "circle2 = plt.Circle(obstacle_center[0], 0.5, color='C3', alpha=0.4)\n",
- "\n",
- "ax.add_patch(circle1)\n",
- "ax.add_patch(circle2)\n",
- "\n",
- "N = 100\n",
- "[ax.plot(*s.T, color=\"k\", alpha=0.2) for s in states_[::N]]\n",
- "[ax.plot(*s.T, color=\"blue\", label=\"Initial traj\") for s in states_[:1]]\n",
- "[ax.plot(*s.T, color=\"r\", label=\"Final traj\") for s in states_[-1:]]\n",
- "\n",
- "ax.scatter(states_[-1][0,:1], states_[-1][0,1:], zorder=10, label=\"start\", color=\"red\")\n",
- "ax.scatter(states_[-1][-1,:1], states_[-1][-1,1:], zorder=10, label=\"end\", color=\"green\")\n",
- "\n",
- "ax.set_xlabel(\"x position\")\n",
- "ax.set_ylabel(\"y position\")\n",
- "ax.grid()\n",
- "ax.legend()\n",
- "ax.axis(\"equal\")\n",
- "\n",
- "# plot x, y\n",
- "ax = axs[1]\n",
- "ax.plot(ts, states_[-1][:-1,:1], label=\"x\")\n",
- "ax.plot(ts, states_[-1][:-1,1:], label=\"y\")\n",
- "ax.grid()\n",
- "ax.axis(\"equal\")\n",
- "ax.legend()\n",
- "ax.set_xlabel(\"Time (s)\")\n",
- "ax.set_ylabel(\"Position\")\n",
- "\n",
- "\n",
- "# plot control signal\n",
- "ax = axs[2]\n",
- "ax.plot(ts, controls[:,:1], label=\"x control\")\n",
- "ax.plot(ts, controls[:,1:], label=\"y control\")\n",
- "ax.plot(ts, jnp.linalg.norm(controls, axis=-1).squeeze(), label=\"control norm\")\n",
- "ax.grid()\n",
- "ax.axis(\"equal\")\n",
- "ax.legend()\n",
- "ax.set_xlabel(\"Time (s)\")\n",
- "ax.set_ylabel(\"Controls\")\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Scv5PRbdCRej"
- },
- "source": [
- "### Using Predicates"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 102,
- "metadata": {
- "id": "7aMIbtVjCRej"
- },
- "outputs": [],
- "source": [
- "# set random initial state and control\n",
- "np.random.seed(123)\n",
- "T = 51 # time horizon\n",
- "dt = 0.1 # time step size\n",
- "ts = jnp.array([t * dt for t in range(T)])\n",
- "umax = 1.0 # max control limit\n",
- "\n",
- "controls = jnp.array(np.random.randn(T,2))\n",
- "state0 = jnp.ones(2).reshape([1,2]) * 3.\n",
- "obstacle_center = jnp.ones([1,2]) * 2.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 103,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 340
- },
- "id": "JVViQbSrCRej",
- "outputId": "146e982c-2fba-4ff7-d3a5-2f4a03a0eb54"
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 103,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# defining formula\n",
- "\n",
- "def compute_distance_to_point(states, point):\n",
- " return jnp.linalg.norm(states[...,:2] - point, axis=-1, keepdims=True)\n",
- "\n",
- "def compute_distance_to_origin(states):\n",
- " return compute_distance_to_point(states, jnp.zeros(2))\n",
- "\n",
- "\n",
- "\n",
- "distance_to_origin = Predicate(\"magnitude\", compute_distance_to_origin)\n",
- "distance_to_obstacle = Predicate(\"distance_to_obs\", lambda x: compute_distance_to_point(x, obstacle_center))\n",
- "reach = Eventually(distance_to_origin < 0.2)\n",
- "avoid = Always(distance_to_obstacle > 0.5)\n",
- "formula = reach & avoid\n",
- "formula = Until(distance_to_obstacle > 0.5, Always(distance_to_origin < 0.5), interval=[40,45])\n",
- "\n",
- "make_stl_graph(formula)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 104,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "Mo_cXd9pCRej",
- "outputId": "c581c88a-8e59-44ae-fa60-7803eea32827"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(Array(7.334007, dtype=float32),\n",
- " Array(7.138606, dtype=float32),\n",
- " Array(-4.1455383, dtype=float32))"
- ]
- },
- "execution_count": 104,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "@functools.partial(jax.jit, static_argnames=(\"approx_method\"))\n",
- "def loss(controls, state0, umax, dt, coeffs=[1., 0.1, 5.], approx_method=\"true\", temperature=None):\n",
- " # generate trajectory from control sequence and reverse along time dimension\n",
- " traj = simulate_dynamics(controls, state0, dt)\n",
- " # loss functions\n",
- " loss_robustness = jax.nn.relu(-formula.robustness(traj,approx_method=approx_method, temperature=temperature))\n",
- " loss_control_smoothness = jnp.abs(jnp.diff(controls, axis=1)).sum(-1).mean() + (controls**2).sum(-1).mean() # make controls smoother\n",
- " loss_control_limits = jax.nn.relu(jnp.linalg.norm(controls, axis=-1) - umax).mean() # penalize control limit violation\n",
- " return coeffs[0] * loss_robustness + coeffs[1] * loss_control_smoothness + coeffs[2] * loss_control_limits\n",
- "\n",
- "@jax.jit\n",
- "def true_robustness(controls, state0, dt):\n",
- " # generate trajectory from control sequence and reverse along time dimension\n",
- " traj = simulate_dynamics(controls, state0, dt)\n",
- " # loss functions\n",
- " return formula.robustness(traj).mean()\n",
- "\n",
- "def temperature_schedule(i, i_max, start_temp, end_temp, scale=5):\n",
- " i_ = i\n",
- " center = i_max / 2\n",
- " return jax.nn.sigmoid((i_ - center) / scale) * (end_temp - start_temp) + start_temp\n",
- "\n",
- "# compare true value with max/min approximation\n",
- "loss(controls, state0, umax, dt), loss(controls, state0, umax, dt, approx_method=\"softmax\", temperature=5), true_robustness(controls, state0, dt)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 106,
- "metadata": {
- "id": "MInn9AycCRek"
- },
- "outputs": [],
- "source": [
- "states_ = [simulate_dynamics(controls, state0, dt)]\n",
- "lr = 1E-1 # learning rate\n",
- "approx_method = \"logsumexp\"\n",
- "n_steps = 1000 # number of gradient steps\n",
- "n_steps_extra = 10\n",
- "coeffs = [1., 0.5, 5.]\n",
- "\n",
- "# jit the gradient function to speed things up (by A LOT).\n",
- "grad_jit = jax.jit(jax.grad(loss, 0), static_argnames=\"approx_method\")\n",
- "grad_jit(controls, state0, umax, dt, coeffs, approx_method, 0.2) \n",
- "# temperature schedule parameters\n",
- "start_temp = 50\n",
- "end_temp = 500\n",
- "scale = 5"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 110,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "_W4TYjG2CRek",
- "outputId": "0e2b8e16-34ce-4512-adff-dfe6f98bd3f6"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " 0 -- true robustness: -0.01 smoothness: 0.89 control limits: 0.00\n",
- " 50 -- true robustness: 0.01 smoothness: 0.90 control limits: 0.00\n",
- "100 -- true robustness: -0.00 smoothness: 0.89 control limits: 0.00\n",
- "150 -- true robustness: 0.00 smoothness: 0.89 control limits: 0.00\n",
- "200 -- true robustness: -0.00 smoothness: 0.90 control limits: 0.00\n",
- "250 -- true robustness: 0.00 smoothness: 0.89 control limits: 0.00\n",
- "300 -- true robustness: 0.01 smoothness: 0.89 control limits: 0.00\n",
- "350 -- true robustness: 0.01 smoothness: 0.90 control limits: 0.00\n",
- "400 -- true robustness: 0.01 smoothness: 0.89 control limits: 0.00\n",
- "450 -- true robustness: -0.01 smoothness: 0.89 control limits: 0.00\n",
- "500 -- true robustness: -0.02 smoothness: 0.88 control limits: 0.00\n",
- "550 -- true robustness: -0.00 smoothness: 0.88 control limits: 0.00\n",
- "600 -- true robustness: -0.00 smoothness: 0.88 control limits: 0.00\n",
- "650 -- true robustness: -0.01 smoothness: 0.87 control limits: 0.00\n",
- "700 -- true robustness: -0.00 smoothness: 0.88 control limits: 0.00\n",
- "750 -- true robustness: 0.00 smoothness: 0.87 control limits: 0.00\n",
- "800 -- true robustness: -0.02 smoothness: 0.87 control limits: 0.00\n",
- "850 -- true robustness: 0.01 smoothness: 0.88 control limits: 0.00\n",
- "900 -- true robustness: -0.00 smoothness: 0.87 control limits: 0.00\n",
- "950 -- true robustness: 0.01 smoothness: 0.89 control limits: 0.00\n"
- ]
- }
- ],
- "source": [
- "temperatures = temperature_schedule(jnp.arange(n_steps), n_steps, start_temp, end_temp, scale)\n",
- "for i in range(n_steps):\n",
- " g = grad_jit(controls, state0, umax, dt, coeffs, approx_method, temperatures[i]) # take gradient\n",
- " # g = jax.grad(loss, 0)(controls, state0, umax, approx_method, temperature) # not jitting\n",
- " controls -= g * lr\n",
- " states_.append(simulate_dynamics(controls, state0, dt))\n",
- " if (i % 50) == 0:\n",
- " print(\"%3i -- true robustness: %.2f smoothness: %.2f control limits: %.2f\"%(i, true_robustness(controls, state0, dt), loss(controls, state0, umax, dt, coeffs=[0., 1., 0.]), loss(controls, state0, umax, dt, coeffs=[0., 0., 1.])))\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 111,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 408
- },
- "id": "mYWVasKKCRek",
- "outputId": "78a2194c-17e3-431b-9535-828058aba881"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Text(0, 0.5, 'Controls')"
- ]
- },
- "execution_count": 111,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig, axs = plt.subplots(1,3, figsize=(15,4)) # note we must use plt.subplots, not plt.subplot\n",
- "\n",
- "ax = axs[0]\n",
- "circle1 = plt.Circle((0, 0), 0.5, color='C2', alpha=0.4)\n",
- "circle2 = plt.Circle(obstacle_center[0], 0.5, color='C3', alpha=0.4)\n",
- "\n",
- "ax.add_patch(circle1)\n",
- "ax.add_patch(circle2)\n",
- "\n",
- "N = 100\n",
- "[ax.plot(*s.T, color=\"k\", alpha=0.2) for s in states_[::N]]\n",
- "[ax.plot(*s.T, color=\"blue\", label=\"Initial traj\") for s in states_[:1]]\n",
- "[ax.plot(*s.T, color=\"r\", label=\"Final traj\") for s in states_[-1:]]\n",
- "\n",
- "ax.scatter(states_[-1][0,:1], states_[-1][0,1:], zorder=10, label=\"start\", color=\"red\")\n",
- "ax.scatter(states_[-1][-1,:1], states_[-1][-1,1:], zorder=10, label=\"end\", color=\"green\")\n",
- "\n",
- "ax.set_xlabel(\"x position\")\n",
- "ax.set_ylabel(\"y position\")\n",
- "ax.grid()\n",
- "ax.legend()\n",
- "ax.axis(\"equal\")\n",
- "\n",
- "# plot x, y\n",
- "ax = axs[1]\n",
- "ax.plot(ts, states_[-1][:-1,:1], label=\"x\")\n",
- "ax.plot(ts, states_[-1][:-1,1:], label=\"y\")\n",
- "ax.plot(ts, distance_to_origin.predicate_function(states_[-1][1:]).squeeze())\n",
- "ax.grid()\n",
- "ax.axis(\"equal\")\n",
- "ax.legend()\n",
- "ax.set_xlabel(\"Time (s)\")\n",
- "ax.set_ylabel(\"Position\")\n",
- "\n",
- "\n",
- "# plot control signal\n",
- "ax = axs[2]\n",
- "ax.plot(ts, controls[:,:1], label=\"x control\")\n",
- "ax.plot(ts, controls[:,1:], label=\"y control\")\n",
- "ax.plot(ts, jnp.linalg.norm(controls, axis=-1).squeeze(), label=\"control norm\")\n",
- "ax.grid()\n",
- "ax.axis(\"equal\")\n",
- "ax.legend()\n",
- "ax.set_xlabel(\"Time (s)\")\n",
- "ax.set_ylabel(\"Controls\")\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "OlK7HyGfCRek"
- },
- "source": [
- "## Parametric STL\n",
- "We can take gradients with respect to the RHS constant in the <, >, == STL formulas\n",
- "\n",
- "As a simple example, we using the final trajectory in the example above to find the value of `c` such that the formula `Always(distance_to_obstacle > c)` is as tight as possible."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "vCF9DvcnCRek",
- "outputId": "0aa83f60-507b-4dfa-aeae-2f05b4fc38f3"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " 0 -- true robustness: 1.7771 \n",
- " 10 -- true robustness: 0.2160 \n",
- " 20 -- true robustness: 0.0263 \n",
- " 30 -- true robustness: 0.0032 \n",
- " 40 -- true robustness: 0.0004 \n",
- " 50 -- true robustness: 0.0000 \n",
- " 60 -- true robustness: 0.0000 \n",
- " 70 -- true robustness: 0.0000 \n",
- " 80 -- true robustness: 0.0000 \n",
- " 90 -- true robustness: 0.0000 \n"
- ]
- }
- ],
- "source": [
- "def parametric_stl_robustness(c, signal):\n",
- " return Always(distance_to_obstacle > c).robustness(signal)**2\n",
- "\n",
- "traj = states_[-1]\n",
- "grad_jit = jax.jit(jax.grad(lambda c: parametric_stl_robustness(c, traj)))\n",
- "\n",
- "c = 2.\n",
- "cs = [c]\n",
- "lr = 0.05\n",
- "for i in range(100):\n",
- " g = grad_jit(c)\n",
- " c -= g * lr\n",
- " cs.append(c)\n",
- " if (i % 10) == 0:\n",
- " print(\"%3i -- true robustness: %.4f \"%(i, parametric_stl_robustness(c, traj)))\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 449
- },
- "id": "JEZh75FCCRek",
- "outputId": "32c64842-65dd-4726-9dae-5c8493b8ef5f"
- },
- "outputs": [
- {
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "\n",
- "plt.plot(ts, jnp.linalg.norm(traj - obstacle_center, axis=-1, keepdims=True).squeeze()[:-1], linewidth=3, label=\"distance to obstacle of trajectory\", zorder=5)\n",
- "plt.hlines(cs[0], ts[0], ts[-1], linewidth=2, color=\"red\", label=\"initial c\", zorder=4)\n",
- "plt.hlines(cs[-1], ts[0], ts[-1], linewidth=2, color=\"green\", label=\"final c\", zorder=4)\n",
- "\n",
- "plt.hlines(cs, ts[0], ts[-1], alpha=0.2, color=\"black\")\n",
- "plt.xlabel(\"Time [s]\")\n",
- "plt.ylabel(\"Distance to obstacle\")\n",
- "plt.legend()\n",
- "\n",
- "plt.grid()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7t97PeSZCRek"
- },
- "source": [
- "## Examples of other formulas\n",
- "\n",
- "Below are examples of how to apply different STL operations given the reach and avoid predicates.\n",
- "(Don't read too much into the meaning behind each formula with the read and avoid predicates. Just treat them as placeholders)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "id": "Tm2SVG3vCReq"
- },
- "outputs": [],
- "source": [
- "# using values from above\n",
- "traj = states_[-1]\n",
- "\n",
- "# compute distance_to_origin and distance_to_obstacle\n",
- "distance_to_origin_signal = jnp.linalg.norm(traj, axis=-1)\n",
- "distance_to_obstacle_signal = jnp.linalg.norm(traj - obstacle_center, axis=-1)\n",
- "\n",
- "distance_to_origin = Expression(\"magnitude\", distance_to_origin_signal)\n",
- "distance_to_obstacle = Expression(\"distance_to_obs\", distance_to_obstacle_signal)\n",
- "reach = Eventually(distance_to_origin < 0.5)\n",
- "avoid = Always(distance_to_obstacle > 0.5)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KeoUSbAnCRer"
- },
- "source": [
- "### Eventually Always"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 340
- },
- "id": "D2UKqsfDCRer",
- "outputId": "43e320bf-7eb4-48bd-af55-ebaf2bf9fa57"
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 27,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# nested temporal operators\n",
- "# NOTE: temporal operatos pad signal with the value at the last time step\n",
- "ϕ = Eventually(Always(distance_to_origin < 0.5, interval=[0, 5]))\n",
- "\n",
- "\n",
- "ϕ(distance_to_origin);\n",
- "make_stl_graph(ϕ)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "V1bpm2LVCRer"
- },
- "source": [
- "### Until"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {
- "id": "AbQB7WW0CRer"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([0.01881605, 0.01881605, 0.01881605, 0.01881605, 0.01881605,\n",
- " 0.01881605, 0.01881605, 0.01881605, 0.01881605, 0.01881605,\n",
- " 0.01881605, 0.01881605, 0.01881605, 0.01881605, 0.01881605,\n",
- " 0.01881605, 0.01889771, 0.02217305, 0.0412991 , 0.07640123,\n",
- " 0.08110067, 0.08110067, 0.08110067, 0.08110067, 0.08110067,\n",
- " 0.08110067, 0.08110067, 0.08110067, 0.08110067, 0.08110067,\n",
- " 0.08110067, 0.08110067, 0.08110067, 0.08110067, 0.08110067,\n",
- " 0.08110067, 0.08110067, 0.08110067, 0.08110067, 0.08110067,\n",
- " 0.08110067, 0.08110067, 0.08110067, 0.08110067, 0.08110067,\n",
- " 0.08110067, 0.08110067, 0.08110067, 0.08110067, 0.08110067,\n",
- " 0.08110067, 0.08110067], dtype=float32)"
- ]
- },
- "execution_count": 28,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "ϕ = Until(avoid, reach, interval=None)\n",
- "\n",
- "ϕ((distance_to_obstacle, distance_to_origin))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "l-bt2v27CRer"
- },
- "source": [
- "### Multiple And (Or)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 340
- },
- "id": "DlNhLq9JCRer",
- "outputId": "193392e8-c698-43a3-9a94-8b48e866d620"
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 29,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "ψ = distance_to_obstacle > 0.5\n",
- "ϕ = ψ & ψ & ψ # equivalent to (ψ & ψ) & ψ this formula is redundant, but just demonstrating functionality\n",
- "ϕ(((distance_to_obstacle,distance_to_obstacle), distance_to_obstacle));\n",
- "make_stl_graph(ϕ)\n",
- "\n",
- "# similarly, you can do this with Or --> ψ | ψ | ψ\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "laUYW_xQCRer"
- },
- "source": [
- "### Implies"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 340
- },
- "id": "n60vd2tmCRer",
- "outputId": "7262cfa1-5044-4d3e-a79a-297c4f194275"
- },
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 30,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "ϕ = Implies(avoid, reach)\n",
- "ϕ((distance_to_obstacle,distance_to_obstacle));\n",
- "make_stl_graph(ϕ)\n"
- ]
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from stljax.formula import *\n",
+ "from stljax.viz import *\n",
+ "\n",
+ "import functools"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## NOTE\n",
+ "If using Expressions to define formulas, `stljax` expects input signals to be of size `[time_dim]`.\n",
+ "If using Predicates to define formulas, `stljax` expects input signals to be of size `[time_dim, state_dim]` where `state_dim` is the expected input size of your predicate function.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_distance_to_origin(states):\n",
+ " return jnp.linalg.norm(states[...,:2], axis=-1, keepdims=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Array([1.4142135, 1.4142135, 1.4142135, 1.4142135, 1.4142135, 1.4142135,\n",
+ " 1.4142135, 1.4142135, 1.4142135, 1.4142135], dtype=float32)"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "T = 10\n",
+ "compute_distance_to_origin(jnp.ones([T, 2]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Using Expressions\n",
+ "Expressions are placeholders for input signals. Specifically, an Expression assumes the signal is already a 1D array, such as the output of a predicate function.\n",
+ "\n",
+ "This is useful if you have signals from predicates computed already. \n",
+ "\n",
+ "In general, this is useful for readability and visualization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "AssertionError",
+ "evalue": "Input Expression does not have numerical values",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[4], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m distance_to_origin_exp \u001b[38;5;241m=\u001b[39m Expression(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmagnitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# can define an Expression without setting values for the expression right now\u001b[39;00m\n\u001b[1;32m 2\u001b[0m formula_exp \u001b[38;5;241m=\u001b[39m Eventually(distance_to_origin_exp \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0.5\u001b[39m) \u001b[38;5;66;03m# can define an STL formula given an expression, again, the value of the expression does not need to be set yet\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mformula_exp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdistance_to_origin_exp\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# <---- this will throw an error since the expression does not have values set yet\u001b[39;00m\n",
+ "File \u001b[0;32m~/repos/stljax/stljax/formula.py:79\u001b[0m, in \u001b[0;36mSTL_Formula.__call__\u001b[0;34m(self, signal, **kwargs)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, signal, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 73\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;124;03m Evaluates the robustness_trace given the input. The input is converted to the numerical value first.\u001b[39;00m\n\u001b[1;32m 75\u001b[0m \n\u001b[1;32m 76\u001b[0m \u001b[38;5;124;03m See STL_Formula.robustness_trace\u001b[39;00m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 79\u001b[0m inputs \u001b[38;5;241m=\u001b[39m \u001b[43mconvert_to_input_values\u001b[49m\u001b[43m(\u001b[49m\u001b[43msignal\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrobustness_trace(inputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
+ "File \u001b[0;32m~/repos/stljax/stljax/formula.py:866\u001b[0m, in \u001b[0;36mconvert_to_input_values\u001b[0;34m(inputs)\u001b[0m\n\u001b[1;32m 864\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(inputs, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 865\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(inputs, Expression):\n\u001b[0;32m--> 866\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m inputs\u001b[38;5;241m.\u001b[39mvalue \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput Expression does not have numerical values\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 867\u001b[0m \u001b[38;5;66;03m# if Expression is not time reversed\u001b[39;00m\n\u001b[1;32m 868\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m inputs\u001b[38;5;241m.\u001b[39mvalue\n",
+ "\u001b[0;31mAssertionError\u001b[0m: Input Expression does not have numerical values"
+ ]
+ }
+ ],
+ "source": [
+ "distance_to_origin_exp = Expression(\"magnitude\", value=None) # can define an Expression without setting values for the expression right now\n",
+ "formula_exp = Eventually(distance_to_origin_exp < 0.5) # can define an STL formula given an expression, again, the value of the expression does not need to be set yet\n",
+ "\n",
+ "\n",
+ "formula_exp(distance_to_origin_exp) # <---- this will throw an error since the expression does not have values set yet\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Array([ 0.24982172, 0.24982172, 0.24982172, 0.24982172, -0.9981775 ], dtype=float32)"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# so let's go ahead and set a value for the expression\n",
+ "T = 5\n",
+ "states = jnp.array(np.random.randn(T, 2))\n",
+ "states_norm = compute_distance_to_origin(states) # compute distance to origin\n",
+ "\n",
+ "distance_to_origin_exp.set_value(states_norm) # set value for Expression\n",
+ "\n",
+ "# compute robustness trace\n",
+ "formula_exp(distance_to_origin_exp) # <---- this will no longer throw an error since the expression has a value set\n",
+ "\n",
+ "# alternatively, we can directly plug in any jnp.array and evaluate the robustness without setting the Expression's value\n",
+ "states2 = jnp.array(np.random.randn(T, 2))\n",
+ "states_norm2 = compute_distance_to_origin(states2) # compute distance to origin\n",
+ "formula_exp(states_norm2) \n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can compute the robustness value (instead of trace) and take the derivative"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Robustness value: -0.041\n",
+ "\n",
+ "Gradient of robustness value w.r.t. input:\n",
+ " [-0. -0. -1. -0. -0.]\n"
+ ]
}
- ],
- "metadata": {
- "colab": {
- "include_colab_link": true,
- "provenance": []
- },
- "kernelspec": {
- "display_name": "stljax",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.4"
+ ],
+ "source": [
+ "robustness = formula_exp.robustness(states_norm) \n",
+ "print(f\"Robustness value: {robustness:.3f}\\n\")\n",
+ "\n",
+ "gradient = jax.grad(formula_exp.robustness)(states_norm) \n",
+ "print(f\"Gradient of robustness value w.r.t. input:\\n {gradient}\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can apply a smooth max/min approximation by selecting an `approx_method` and a `temperature`.\n",
+ "The default `approx_method` is `true`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Robustness value: 1.030\n",
+ "\n",
+ "Gradient of robustness value w.r.t. input:\n",
+ " [-0.09700805 -0.2143781 -0.34265578 -0.23554076 -0.1104174 ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "approx_method = \"logsumexp\" # or \"softmax\"\n",
+ "temperature = 1. # needs to be > 0\n",
+ "\n",
+ "robustness = formula_exp.robustness(states_norm, approx_method=approx_method, temperature=temperature) \n",
+ "print(f\"Robustness value: {robustness:.3f}\\n\")\n",
+ "\n",
+ "gradient = jax.grad(formula_exp.robustness)(states_norm, approx_method=approx_method, temperature=temperature) \n",
+ "print(f\"Gradient of robustness value w.r.t. input:\\n {gradient}\") # <----- gradients are spread across different values"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For formulas that are defined with two different Expressions, we need to be careful about the signals we are feeding in."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Array([-1.8956004 , -1.1781825 , -0.04106313, -0.41590554, -1.1735218 ], dtype=float32)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# if both subformulas use the same signal, then we can do this\n",
+ "phi = (distance_to_origin_exp > 0) & (distance_to_origin_exp < 0.5) \n",
+ "phi(states_norm)\n",
+ "\n",
+ "\n",
+ "# if the formula depends on two different signals, then we need to provide the two signals as a tuple\n",
+ "distance_to_origin_exp = Expression(\"magnitude\", value=None)\n",
+ "speed_exp = Expression(\"speed\", value=None)\n",
+ "\n",
+ "phi = (distance_to_origin_exp > 0) & (speed_exp < 0.5) \n",
+ "\n",
+ "phi(states_norm) # <--- Will give WRONG ANSWER\n",
+ "\n",
+ "\n",
+ "speed = jnp.array(np.random.randn(T))\n",
+ "input_correct_order = (states_norm, speed)\n",
+ "input_wrong_order = (speed, states_norm)\n",
+ "phi(input_correct_order) # <--- Will give desired answer\n",
+ "phi(input_wrong_order) # <--- Will give WRONG ANSWER since the ordering of the input does not correspond to how phi is defined\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Using Predicates\n",
+ "A Predicate wraps a function that is applied to an N-D signal; the function's output is then passed through each operation of the STL formula.\n",
+ "We can construct an STL formula by specifying the predicate functions together with the connectives and temporal operators.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Array([ True, True, True, True, True], dtype=bool)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "distance_to_origin_pred = Predicate(\"magnitude\", predicate_function=compute_distance_to_origin) # define a Predicate with a name and its predicate function\n",
+ "formula_pred = Eventually(distance_to_origin_pred < 0.5) # define the STL formula\n",
+ "\n",
+ "# so let's go ahead and set a value for the input N-D array which will be the input into the predicate function.\n",
+ "T = 5\n",
+ "states = jnp.array(np.random.randn(T, 2)) # 2D signal\n",
+ "output_from_using_predicate = formula_pred(states) # distance to origin is computed INSIDE the formula\n",
+ "\n",
+ "\n",
+ "# NOTE: this is equivalent to the following with expressions\n",
+ "states_norm = compute_distance_to_origin(states) # distance to origin is computed OUTSIDE the formula\n",
+ "output_from_using_expression = formula_exp(states_norm) \n",
+ "\n",
+ "\n",
+ "# check if we get the same answer\n",
+ "jnp.isclose(output_from_using_predicate, output_from_using_expression)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Similarly, we can compute the robustness value (instead of trace) and take the derivative. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Robustness value: 0.951\n",
+ "\n",
+ "Gradient of robustness value w.r.t. input:\n",
+ " [[-0.05231676 0.04973718]\n",
+ " [-0.02268155 -0.20005907]\n",
+ " [ 0.23647855 0.29368952]\n",
+ " [ 0.21631472 -0.10630771]\n",
+ " [-0.0618272 0.08902159]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "approx_method = \"logsumexp\" # or \"softmax\"\n",
+ "temperature = 1. # needs to be > 0\n",
+ "\n",
+ "robustness = formula_pred.robustness(states, approx_method=approx_method, temperature=temperature) \n",
+ "print(f\"Robustness value: {robustness:.3f}\\n\")\n",
+ "\n",
+ "gradient = jax.grad(formula_pred.robustness)(states, approx_method=approx_method, temperature=temperature) \n",
+ "print(f\"Gradient of robustness value w.r.t. input:\\n {gradient}\") # <----- gradients are spread across different values"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that when taking gradients with formulas defined with predicates, the input is the N-D signal, which is passed into the predicate function and other robustness formulas. That is to say, the gradient will be influenced by the choice of the predicate. \n",
+ "\n",
+ "To get the same gradient output when using Expressions, we need to do the following:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Array([[-0.05231676, 0.04973718],\n",
+ " [-0.02268155, -0.20005907],\n",
+ " [ 0.23647855, 0.29368952],\n",
+ " [ 0.21631472, -0.10630771],\n",
+ " [-0.0618272 , 0.08902159]], dtype=float32)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
}
+ ],
+ "source": [
+ "def foo(states):\n",
+ " states_norm = compute_distance_to_origin(states) # compute distance to origin\n",
+ " return formula_exp.robustness(states_norm, approx_method=approx_method, temperature=temperature) \n",
+ "\n",
+ "jax.grad(foo)(states)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "test",
+ "language": "python",
+ "name": "python3"
},
- "nbformat": 4,
- "nbformat_minor": 0
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
}
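
The notebook above remarks that, with `approx_method="logsumexp"`, gradients are spread across different values. A minimal plain-Python sketch of why this happens, assuming the smooth maximum has the form (1/t) * log(sum(exp(t * x_i))) with temperature t > 0. The helper names below are hypothetical and are not part of the stljax API; the exact form stljax uses internally is an assumption here.

```python
import math

def logsumexp_max(values, temperature):
    """Smooth maximum (assumed form): (1/t) * log(sum(exp(t * v)))."""
    m = max(values)  # subtract the max before exponentiating, for numerical stability
    total = sum(math.exp(temperature * (v - m)) for v in values)
    return m + math.log(total) / temperature

def logsumexp_max_grad(values, temperature):
    """Gradient of the smooth maximum w.r.t. each value: the softmax weights."""
    m = max(values)
    weights = [math.exp(temperature * (v - m)) for v in values]
    total = sum(weights)
    return [w / total for w in weights]

values = [0.3, 1.0, 0.9]  # toy signal, chosen only for illustration
grad_soft = logsumexp_max_grad(values, temperature=1.0)     # every entry gets a share
grad_sharp = logsumexp_max_grad(values, temperature=100.0)  # nearly all weight on the max
```

At a low temperature every entry receives a non-negligible share of the gradient; as the temperature grows, the weights concentrate on the largest entry and the value approaches the true maximum.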
diff --git a/examples/parametric_time_interval.ipynb b/examples/parametric_time_interval.ipynb
new file mode 100644
index 0000000..63ce89b
--- /dev/null
+++ b/examples/parametric_time_interval.ipynb
@@ -0,0 +1,287 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from ipywidgets import interact\n",
+ "import ipywidgets as widgets\n",
+ "\n",
+ "from stljax.formula import *\n",
+ "from stljax.viz import *\n",
+ "from stljax.utils import anneal\n",
+ "\n",
+ "from matplotlib import rc\n",
+ "rc('font',**{'family':'serif','serif':['Palatino']})\n",
+ "rc('text', usetex=True)\n",
+ "\n",
+ "jax.config.update(\"jax_enable_x64\", True)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Generate data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "T = 20\n",
+ "fontsize = 14\n",
+ "true_t_start = 0.21\n",
+ "true_t_end = 0.59\n",
+ "bs = 32\n",
+ "\n",
+ "\n",
+ "key = jax.random.key(1701)\n",
+ "signal_data = jax.vmap(smooth_mask, [None, 0, None, None])(T, true_t_start + 0.02 * jax.random.normal(key, shape=(bs,)), true_t_end, 3.) + jax.random.normal(key, shape=(bs, T,)) * 0.1 - 0.5\n",
+ "plt.figure(figsize=(5,2))\n",
+ "plt.plot(jnp.linspace(0,1,T), signal_data.T, color=\"black\", alpha=0.2)\n",
+ "plt.xlabel(\"Normalized time\", fontsize=fontsize, labelpad=-1)\n",
+ "plt.ylabel(\"Signal\", fontsize=fontsize, labelpad=-1)\n",
+ "plt.grid()\n",
+ "plt.tight_layout()\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the STL formula with differentiable time intervals, and the loss function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred = Predicate('x', lambda x: x)\n",
+ "phi = DifferentiableAlways(pred > 0.)\n",
+ "phi_robustness_jit = jax.jit(phi.robustness, static_argnames=(\"approx_method\"))\n",
+ "\n",
+ "\n",
+ "@functools.partial(jax.jit, static_argnames=(\"approx_method\"))\n",
+ "def loss(signal_data, t_start, t_end, scale, approx_method, temperature, coeff):\n",
+ " rob_partial = functools.partial(phi_robustness_jit, t_start=t_start, t_end=t_end, scale=scale, approx_method=approx_method, temperature=temperature)\n",
+ " robustness_ = jax.vmap(rob_partial, [0])(signal_data)\n",
+ " robustness = jax.nn.relu(-jnp.where(t_start < (t_end - 0.05), robustness_, jnp.nan)).mean()\n",
+ " return robustness + coeff * (t_start - t_end)\n",
+ "\n",
+ "\n",
+ "grad_loss = jax.jit(jax.grad(loss, [1,2]), static_argnames=(\"approx_method\"))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set up the gradient descent routine. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "approx_method = \"logsumexp\"\n",
+ "temperature = 1.\n",
+ "scale = 1.\n",
+ "a = -2.\n",
+ "b = 2.\n",
+ "lr = 1E-2\n",
+ "max_steps = 5000\n",
+ "scale_start = 0.1\n",
+ "scale_end = 20\n",
+ "temperature_start = 0.1\n",
+ "temperature_end = 20\n",
+ "a_list = [a]\n",
+ "b_list = [b]\n",
+ "coeff_start = 0.1\n",
+ "coeff_end = 0.\n",
+ "\n",
+ "# Gradient descent!\n",
+ "for i in range(max_steps):\n",
+ " j = (i / max_steps)\n",
+ " s = (1 - j) * scale_start + j * scale_end\n",
+ " t = (1 - j) * temperature_start + j * temperature_end\n",
+ " c = (1 - j) * coeff_start + j * coeff_end\n",
+ " a_ = jax.nn.sigmoid(a)\n",
+ " b_ = jax.nn.sigmoid(b)\n",
+ " g = grad_loss(signal_data, a_, b_, s, approx_method, t, c)\n",
+ " a -= lr * g[0] * a_ * (1 - a_)\n",
+ " b -= lr * g[1] * b_ * (1 - b_)\n",
+ " a_list.append(a)\n",
+ " b_list.append(b)\n",
+ " # print(a,b)\n",
+ "a_list = jnp.stack(a_list)\n",
+ "b_list = jnp.stack(b_list)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "step_size = 50\n",
+ "coeff = 0.1\n",
+ "\n",
+ "def visualize_loss_landscape(loss_func, signal, scale, approx_method, temperature, coeff):\n",
+ " N = 100\n",
+ " fontsize = 14\n",
+ " levels = 10\n",
+ " T = signal.shape[0]\n",
+ " starts, ends = jnp.meshgrid(jnp.linspace(0,1, N), jnp.linspace(0,1, N))\n",
+ " losses = jax.vmap(loss_func, [None, 0, 0, None, None, None, None])(signal, starts.reshape([-1,1]), ends.reshape([-1,1]), scale, approx_method, temperature, coeff).reshape([N,N])\n",
+ " \n",
+ " plt.contourf(starts, ends, losses, levels=levels, cmap=\"jet\", alpha=0.4)\n",
+ " plt.colorbar()\n",
+ " if approx_method != \"true\":\n",
+ " match approx_method:\n",
+ " case \"logsumexp\":\n",
+ " app = \"LSE\"\n",
+ " case \"softmax\":\n",
+ " app = \"soft\"\n",
+ " plt.title(\"Loss landscape \\n (c = %.2f, $\\\\tau_\\\\mathrm{%s}$ = %.2f)\"%(scale, app, temperature))\n",
+ " else:\n",
+ " plt.title(\"Loss landscape \\n (c = %.2f)\"%(scale))\n",
+ " plt.xlabel(\"$a$\", fontsize=fontsize, labelpad=-3)\n",
+ " plt.ylabel(\"$b$\", fontsize=fontsize, labelpad=-3)\n",
+ " plt.grid(zorder=-5, alpha=0.2)\n",
+ " \n",
+ " \n",
+ "def visualize_results(i, approx_method):\n",
+ " a, b = jax.nn.sigmoid(a_list[i]), jax.nn.sigmoid(b_list[i])\n",
+ " j = (i / max_steps)\n",
+ " s = (1 - j) * scale_start + j * scale_end\n",
+ " t = (1 - j) * temperature_start + j * temperature_end\n",
+ " ell = loss(signal_data, a, b, s, approx_method, t, coeff)\n",
+ " \n",
+ " # plt.figure(figsize=(4,3))\n",
+ " \n",
+ " visualize_loss_landscape(loss, signal_data, s, approx_method, t, coeff)\n",
+ " plt.plot(jax.nn.sigmoid(a_list[::step_size]), jax.nn.sigmoid(b_list[::step_size]), \"o-\", markersize=3, color=\"black\", linewidth=1)\n",
+ " current_loss = loss(signal_data, a, b, s, approx_method, t, coeff)\n",
+ " gt_loss = loss(signal_data, true_t_start, true_t_end, s, approx_method, t, coeff)\n",
+ " plt.scatter([a], [b], marker=\"o\", s=40, color=\"magenta\", label=\"Current: %.3f\"%current_loss, edgecolor=\"black\", zorder=4)\n",
+ " plt.scatter([true_t_start], [true_t_end], marker=\"*\", s=100, color=\"orange\", label=\"Ground truth: %.3f\"%gt_loss, edgecolor=\"black\", zorder=4)\n",
+ " \n",
+ "\n",
+ " plt.vlines(a, 0, 1, color=\"red\", linestyle='--')\n",
+ " plt.hlines(b, 0, 1, color=\"blue\", linestyle='--')\n",
+ " if approx_method != \"true\":\n",
+ " match approx_method:\n",
+ " case \"logsumexp\":\n",
+ " app = \"LSE\"\n",
+ " case \"softmax\":\n",
+ " app = \"soft\"\n",
+ " plt.title(\"Loss = %.5f (c = %.2f, $\\\\tau_\\\\mathrm{%s}$ = %.2f)\"%(ell, s, app, t))\n",
+ " else:\n",
+ " plt.title(\"Loss = %.5f (c = %.2f)\"%(ell, s))\n",
+ " plt.legend(loc=\"lower right\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c307d3fbc6e5413bbbc0562a41b129b4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "interactive(children=(IntSlider(value=0, description='i: ', max=4999, step=50), Output()), _dom_classes=('widg…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "index_slider = widgets.IntSlider(value=0, min=0, max=max_steps-1, step=step_size, description='i: ')\n",
+ "interact(visualize_results, i=index_slider, approx_method=widgets.fixed(approx_method))\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "test",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
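
The gradient descent loop in the notebook above updates unconstrained variables `a` and `b` and maps them through `jax.nn.sigmoid`, multiplying each gradient by `a_ * (1 - a_)`. A minimal plain-Python sketch of that reparameterization, under the assumption that the goal is to keep the interval endpoints inside (0, 1). The toy quadratic loss and the `descend` helper are hypothetical and stand in for the notebook's robustness-based loss.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def descend(grad_fn, z0, lr=0.5, steps=2000):
    """Gradient descent on an unconstrained z, where x = sigmoid(z) stays in (0, 1)."""
    z = z0
    for _ in range(steps):
        x = sigmoid(z)
        # chain rule: dL/dz = dL/dx * dx/dz, with dx/dz = x * (1 - x)
        z -= lr * grad_fn(x) * x * (1.0 - x)
    return sigmoid(z)

# hypothetical toy loss L(x) = (x - target)^2, whose gradient is 2 * (x - target)
target = 0.21
x_final = descend(lambda x: 2.0 * (x - target), z0=-2.0)
```

Because the update is applied to `z` rather than to `x`, no projection or clipping step is needed to keep the endpoint within the normalized time range.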
diff --git a/examples/unconstrained_trajectory_optimization.ipynb b/examples/unconstrained_trajectory_optimization.ipynb
new file mode 100644
index 0000000..beadad7
--- /dev/null
+++ b/examples/unconstrained_trajectory_optimization.ipynb
@@ -0,0 +1,733 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from stljax.formula import *\n",
+ "from stljax.viz import *\n",
+ "from stljax.utils import anneal\n",
+ "\n",
+ "from matplotlib import rc\n",
+ "rc('font',**{'family':'serif','serif':['Palatino']})\n",
+ "rc('text', usetex=True)\n",
+ "\n",
+ "jax.config.update(\"jax_enable_x64\", True)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Some helper functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "@jax.jit\n",
+ "def dynamics_discrete_step(state, control, dt=0.1):\n",
+ " '''Single integrator dynamics'''\n",
+ " return state + control * dt\n",
+ "\n",
+ "@jax.jit\n",
+ "def simulate_dynamics(controls, state0, dt):\n",
+ " T = controls.shape[0]\n",
+ " _states = [state0]\n",
+ " for t in range(T):\n",
+ " _states.append(dynamics_discrete_step(_states[-1], controls[t,:], dt))\n",
+ " return jnp.concatenate(_states, 0)\n",
+ "\n",
+ "@jax.jit\n",
+ "def compute_distance_to_point(states, point):\n",
+ " return jnp.linalg.norm(states[...,:2] - point, axis=-1, keepdims=True)\n",
+ "\n",
+ "@jax.jit\n",
+ "def compute_distance_to_origin(states):\n",
+ " return compute_distance_to_point(states, jnp.zeros(2))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set up the STL formulas"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# environment parameters\n",
+ "obstacle_center = jnp.array([[0,2]]) # obstacle location\n",
+ "reach_radius = 0.1\n",
+ "obstacle_radius = 0.5\n",
+ "\n",
+ "\n",
+ "distance_to_origin = Predicate(\"magnitude\", compute_distance_to_origin)\n",
+ "distance_to_obstacle = Predicate(\"distance_to_obs\", lambda x: compute_distance_to_point(x, obstacle_center))\n",
+ "\n",
+ "reach = Eventually(distance_to_origin < reach_radius)\n",
+ "avoid = Always(distance_to_obstacle > obstacle_radius)\n",
+ "# stay = Eventually(Always(distance_to_obstacle < 0.5, interval=[0, 7]), interval=[0,20]) # if you don't want to have differentiable time intervals\n",
+ "stay = DifferentiableAlways(distance_to_obstacle < 0.5)\n",
+ "\n",
+ "formula = reach & stay\n",
+ "# formula = Until(distance_to_obstacle > 0.5, Always(distance_to_origin < 0.5), interval=[40,45])\n",
+ "\n",
+ "make_stl_graph(formula)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the cost function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def exponenial_penalty(x):\n",
+ " return jnp.exp(x)\n",
+ "\n",
+ "@functools.partial(jax.jit, static_argnames=(\"approx_method\"))\n",
+ "def loss(controls, t_start, t_end, scale, state0, umax, dt, coeffs=[1., 0.1, 5., 0.], approx_method=\"true\", temperature=None):\n",
+ " # see paper for more details on loss function\n",
+ " # generate trajectory from control sequence\n",
+ " traj = simulate_dynamics(controls, state0, dt)\n",
+ " # loss functions\n",
+ " loss_robustness = jax.nn.relu(-formula.robustness(traj, t_start=t_start, t_end=t_end, scale=scale, approx_method=approx_method, temperature=temperature))\n",
+ " loss_control_smoothness = 0 * jnp.abs(jnp.diff(controls, axis=1)).sum(-1).mean() + (controls**2).sum(-1).mean() # penalize control effort (the difference-based smoothness term has weight 0)\n",
+ " loss_control_limits = jax.nn.relu(jnp.linalg.norm(controls, axis=-1) - umax).mean() # penalize control limit violation\n",
+ " min_interval = 0.2\n",
+ " interval_difference = min_interval - (t_end - t_start) # negative is good\n",
+ " cost_array = jnp.array([\n",
+ " loss_robustness,\n",
+ " loss_control_smoothness,\n",
+ " loss_control_limits,\n",
+ " exponenial_penalty(2 * interval_difference)\n",
+ " ])\n",
+ " return jnp.dot(jnp.array(coeffs), cost_array)\n",
+ " # return coeffs[0] * loss_robustness + coeffs[1] * loss_control_smoothness + coeffs[2] * loss_control_limits + coeffs[3] * exponenial_penalty(2 * interval_difference)\n",
+ " \n",
+ "grad_jit = jax.jit(jax.grad(loss, [0,1,2]), static_argnames=\"approx_method\")\n",
+ "\n",
+ "@jax.jit\n",
+ "def true_robustness(controls, t_start, t_end, scale, state0, dt):\n",
+ " traj = simulate_dynamics(controls, state0, dt)\n",
+ " return formula.robustness(traj, t_start=t_start, t_end=t_end, scale=scale).mean()\n",
+ "\n",
+ "@jax.jit\n",
+ "def schedule(i, i_max, start, end):\n",
+ " j = (i / i_max)\n",
+ " temp = anneal(j)\n",
+ " return temp * (end - start) + start\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Setting up parameters to begin the gradient descent routine"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.random.seed(123)\n",
+ "T = 51 # time horizon\n",
+ "dt = 0.1 # time step size\n",
+ "ts = jnp.array([t * dt for t in range(T)]) # time step array\n",
+ "umax = 2. # max control limit\n",
+ "\n",
+ "controls = jnp.array(np.random.randn(T,2)) # initial random control sequence\n",
+ "state0 = jnp.ones(2).reshape([1,2]) * 3. # initial state\n",
+ "states_ = [simulate_dynamics(controls, state0, dt)] # list to collect all the state trajectories at each gradient descent step\n",
+ "\n",
+ "# initial values for time interval (before passing through sigmoid)\n",
+ "t_start = -1.8\n",
+ "t_end = 1.5\n",
+ "\n",
+ "lr = 1E-2 # learning rate\n",
+ "approx_method = \"logsumexp\"\n",
+ "n_steps = 10000 # number of gradient steps\n",
+ "\n",
+ "# start and end values for annealing temperature and scale\n",
+ "start_temp = 1\n",
+ "end_temp = 100\n",
+ "\n",
+ "start_scale = 10\n",
+ "end_scale = 100\n",
+ "\n",
+ "coeffs = [1.1, 0.5, 2., 0.05] # coefficients for loss function"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run the functions to test them out\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "scale = 0.1\n",
+ "loss(controls, t_start, t_end, scale, state0, umax, dt)\n",
+ "loss(controls, t_start, t_end, scale, state0, umax, dt, approx_method=\"softmax\", temperature=5)\n",
+ "true_robustness(controls, t_start, t_end, scale, state0, dt)\n",
+ "grad_jit(controls, t_start, t_end, scale, state0, umax, dt, coeffs, approx_method, 0.2);\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run optimization loop!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " 0 -- true robustness: -3.73 smoothness: 2.58 control limits: 0.07 interval: 0.39 t_start: 0.14 t_end: 0.82\n",
+ " 50 -- true robustness: -3.64 smoothness: 2.51 control limits: 0.07 interval: 0.40 t_start: 0.14 t_end: 0.80\n",
+ "100 -- true robustness: -3.56 smoothness: 2.45 control limits: 0.06 interval: 0.40 t_start: 0.14 t_end: 0.80\n",
+ "150 -- true robustness: -3.48 smoothness: 2.39 control limits: 0.05 interval: 0.40 t_start: 0.14 t_end: 0.80\n",
+ "200 -- true robustness: -3.40 smoothness: 2.34 control limits: 0.05 interval: 0.41 t_start: 0.15 t_end: 0.80\n",
+ "250 -- true robustness: -3.32 smoothness: 2.29 control limits: 0.04 interval: 0.43 t_start: 0.16 t_end: 0.78\n",
+ "300 -- true robustness: -3.25 smoothness: 2.24 control limits: 0.04 interval: 0.43 t_start: 0.16 t_end: 0.78\n",
+ "350 -- true robustness: -3.17 smoothness: 2.20 control limits: 0.03 interval: 0.43 t_start: 0.16 t_end: 0.78\n",
+ "400 -- true robustness: -3.10 smoothness: 2.16 control limits: 0.03 interval: 0.44 t_start: 0.16 t_end: 0.78\n",
+ "450 -- true robustness: -3.03 smoothness: 2.12 control limits: 0.03 interval: 0.46 t_start: 0.18 t_end: 0.77\n",
+ "500 -- true robustness: -2.96 smoothness: 2.09 control limits: 0.02 interval: 0.47 t_start: 0.18 t_end: 0.76\n",
+ "550 -- true robustness: -2.89 smoothness: 2.06 control limits: 0.02 interval: 0.47 t_start: 0.18 t_end: 0.76\n",
+ "600 -- true robustness: -2.82 smoothness: 2.03 control limits: 0.02 interval: 0.49 t_start: 0.20 t_end: 0.76\n",
+ "650 -- true robustness: -2.75 smoothness: 2.01 control limits: 0.02 interval: 0.50 t_start: 0.20 t_end: 0.75\n",
+ "700 -- true robustness: -2.69 smoothness: 1.99 control limits: 0.02 interval: 0.52 t_start: 0.22 t_end: 0.74\n",
+ "750 -- true robustness: -2.61 smoothness: 1.97 control limits: 0.02 interval: 0.55 t_start: 0.24 t_end: 0.74\n",
+ "800 -- true robustness: -2.52 smoothness: 1.96 control limits: 0.02 interval: 0.57 t_start: 0.26 t_end: 0.74\n",
+ "850 -- true robustness: -2.42 smoothness: 1.95 control limits: 0.02 interval: 0.60 t_start: 0.27 t_end: 0.73\n",
+ "900 -- true robustness: -2.33 smoothness: 1.94 control limits: 0.02 interval: 0.63 t_start: 0.29 t_end: 0.72\n",
+ "950 -- true robustness: -2.24 smoothness: 1.93 control limits: 0.02 interval: 0.64 t_start: 0.30 t_end: 0.72\n",
+ "1000 -- true robustness: -2.15 smoothness: 1.92 control limits: 0.02 interval: 0.69 t_start: 0.32 t_end: 0.71\n",
+ "1050 -- true robustness: -2.07 smoothness: 1.92 control limits: 0.02 interval: 0.72 t_start: 0.34 t_end: 0.70\n",
+ "1100 -- true robustness: -1.99 smoothness: 1.91 control limits: 0.02 interval: 0.74 t_start: 0.35 t_end: 0.70\n",
+ "1150 -- true robustness: -1.91 smoothness: 1.91 control limits: 0.02 interval: 0.77 t_start: 0.37 t_end: 0.70\n",
+ "1200 -- true robustness: -1.83 smoothness: 1.91 control limits: 0.02 interval: 0.78 t_start: 0.38 t_end: 0.70\n",
+ "1250 -- true robustness: -1.76 smoothness: 1.91 control limits: 0.02 interval: 0.81 t_start: 0.39 t_end: 0.70\n",
+ "1300 -- true robustness: -1.69 smoothness: 1.91 control limits: 0.02 interval: 0.81 t_start: 0.39 t_end: 0.70\n",
+ "1350 -- true robustness: -1.63 smoothness: 1.91 control limits: 0.02 interval: 0.80 t_start: 0.39 t_end: 0.70\n",
+ "1400 -- true robustness: -1.58 smoothness: 1.92 control limits: 0.02 interval: 0.80 t_start: 0.39 t_end: 0.71\n",
+ "1450 -- true robustness: -1.52 smoothness: 1.92 control limits: 0.02 interval: 0.78 t_start: 0.39 t_end: 0.72\n",
+ "1500 -- true robustness: -1.48 smoothness: 1.92 control limits: 0.02 interval: 0.77 t_start: 0.39 t_end: 0.72\n",
+ "1550 -- true robustness: -1.44 smoothness: 1.92 control limits: 0.02 interval: 0.74 t_start: 0.39 t_end: 0.74\n",
+ "1600 -- true robustness: -1.42 smoothness: 1.92 control limits: 0.02 interval: 0.73 t_start: 0.39 t_end: 0.75\n",
+ "1650 -- true robustness: -1.41 smoothness: 1.90 control limits: 0.02 interval: 0.71 t_start: 0.39 t_end: 0.76\n",
+ "1700 -- true robustness: -1.40 smoothness: 1.88 control limits: 0.02 interval: 0.70 t_start: 0.39 t_end: 0.76\n",
+ "1750 -- true robustness: -1.41 smoothness: 1.86 control limits: 0.02 interval: 0.68 t_start: 0.38 t_end: 0.78\n",
+ "1800 -- true robustness: -1.41 smoothness: 1.83 control limits: 0.02 interval: 0.67 t_start: 0.38 t_end: 0.78\n",
+ "1850 -- true robustness: -1.41 smoothness: 1.81 control limits: 0.01 interval: 0.66 t_start: 0.38 t_end: 0.78\n",
+ "1900 -- true robustness: -1.42 smoothness: 1.79 control limits: 0.01 interval: 0.65 t_start: 0.38 t_end: 0.79\n",
+ "1950 -- true robustness: -1.43 smoothness: 1.77 control limits: 0.01 interval: 0.64 t_start: 0.37 t_end: 0.80\n",
+ "2000 -- true robustness: -1.43 smoothness: 1.74 control limits: 0.01 interval: 0.64 t_start: 0.37 t_end: 0.80\n",
+ "2050 -- true robustness: -1.44 smoothness: 1.72 control limits: 0.01 interval: 0.63 t_start: 0.37 t_end: 0.81\n",
+ "2100 -- true robustness: -1.44 smoothness: 1.71 control limits: 0.01 interval: 0.62 t_start: 0.37 t_end: 0.81\n",
+ "2150 -- true robustness: -1.44 smoothness: 1.69 control limits: 0.01 interval: 0.61 t_start: 0.37 t_end: 0.82\n",
+ "2200 -- true robustness: -1.44 smoothness: 1.67 control limits: 0.01 interval: 0.61 t_start: 0.37 t_end: 0.82\n",
+ "2250 -- true robustness: -1.44 smoothness: 1.65 control limits: 0.01 interval: 0.61 t_start: 0.37 t_end: 0.82\n",
+ "2300 -- true robustness: -1.43 smoothness: 1.64 control limits: 0.01 interval: 0.60 t_start: 0.37 t_end: 0.83\n",
+ "2350 -- true robustness: -1.43 smoothness: 1.62 control limits: 0.01 interval: 0.59 t_start: 0.37 t_end: 0.83\n",
+ "2400 -- true robustness: -1.41 smoothness: 1.61 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.83\n",
+ "2450 -- true robustness: -1.39 smoothness: 1.59 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.84\n",
+ "2500 -- true robustness: -1.36 smoothness: 1.58 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.84\n",
+ "2550 -- true robustness: -1.33 smoothness: 1.56 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.84\n",
+ "2600 -- true robustness: -1.28 smoothness: 1.55 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.84\n",
+ "2650 -- true robustness: -1.23 smoothness: 1.54 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.84\n",
+ "2700 -- true robustness: -1.18 smoothness: 1.52 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.84\n",
+ "2750 -- true robustness: -1.12 smoothness: 1.51 control limits: 0.00 interval: 0.59 t_start: 0.37 t_end: 0.84\n",
+ "2800 -- true robustness: -1.06 smoothness: 1.50 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "2850 -- true robustness: -1.01 smoothness: 1.49 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "2900 -- true robustness: -0.96 smoothness: 1.48 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "2950 -- true robustness: -0.91 smoothness: 1.47 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3000 -- true robustness: -0.86 smoothness: 1.46 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3050 -- true robustness: -0.82 smoothness: 1.45 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3100 -- true robustness: -0.78 smoothness: 1.44 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3150 -- true robustness: -0.74 smoothness: 1.43 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3200 -- true robustness: -0.71 smoothness: 1.42 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3250 -- true robustness: -0.68 smoothness: 1.41 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3300 -- true robustness: -0.65 smoothness: 1.40 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3350 -- true robustness: -0.63 smoothness: 1.39 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3400 -- true robustness: -0.60 smoothness: 1.38 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3450 -- true robustness: -0.58 smoothness: 1.37 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3500 -- true robustness: -0.56 smoothness: 1.36 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3550 -- true robustness: -0.54 smoothness: 1.35 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3600 -- true robustness: -0.52 smoothness: 1.34 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3650 -- true robustness: -0.51 smoothness: 1.33 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3700 -- true robustness: -0.49 smoothness: 1.32 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3750 -- true robustness: -0.48 smoothness: 1.31 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3800 -- true robustness: -0.46 smoothness: 1.30 control limits: 0.00 interval: 0.58 t_start: 0.37 t_end: 0.84\n",
+ "3850 -- true robustness: -0.45 smoothness: 1.29 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "3900 -- true robustness: -0.44 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "3950 -- true robustness: -0.43 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "4000 -- true robustness: -0.41 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "4050 -- true robustness: -0.40 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "4100 -- true robustness: -0.39 smoothness: 1.25 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "4150 -- true robustness: -0.38 smoothness: 1.24 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "4200 -- true robustness: -0.37 smoothness: 1.24 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4250 -- true robustness: -0.37 smoothness: 1.23 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4300 -- true robustness: -0.36 smoothness: 1.22 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4350 -- true robustness: -0.35 smoothness: 1.22 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4400 -- true robustness: -0.34 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4450 -- true robustness: -0.33 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4500 -- true robustness: -0.32 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4550 -- true robustness: -0.32 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4600 -- true robustness: -0.31 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4650 -- true robustness: -0.30 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4700 -- true robustness: -0.30 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4750 -- true robustness: -0.29 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4800 -- true robustness: -0.28 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4850 -- true robustness: -0.28 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4900 -- true robustness: -0.27 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "4950 -- true robustness: -0.27 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5000 -- true robustness: -0.26 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5050 -- true robustness: -0.26 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5100 -- true robustness: -0.25 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5150 -- true robustness: -0.25 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5200 -- true robustness: -0.24 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5250 -- true robustness: -0.24 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5300 -- true robustness: -0.23 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5350 -- true robustness: -0.23 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5400 -- true robustness: -0.22 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5450 -- true robustness: -0.22 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5500 -- true robustness: -0.22 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5550 -- true robustness: -0.21 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5600 -- true robustness: -0.21 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5650 -- true robustness: -0.20 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5700 -- true robustness: -0.20 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5750 -- true robustness: -0.19 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5800 -- true robustness: -0.19 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5850 -- true robustness: -0.19 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5900 -- true robustness: -0.18 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "5950 -- true robustness: -0.18 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6000 -- true robustness: -0.18 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6050 -- true robustness: -0.17 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6100 -- true robustness: -0.17 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6150 -- true robustness: -0.16 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6200 -- true robustness: -0.16 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6250 -- true robustness: -0.16 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6300 -- true robustness: -0.15 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6350 -- true robustness: -0.15 smoothness: 1.18 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6400 -- true robustness: -0.15 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6450 -- true robustness: -0.14 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6500 -- true robustness: -0.14 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6550 -- true robustness: -0.14 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6600 -- true robustness: -0.13 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6650 -- true robustness: -0.13 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6700 -- true robustness: -0.13 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6750 -- true robustness: -0.12 smoothness: 1.19 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6800 -- true robustness: -0.12 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6850 -- true robustness: -0.12 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6900 -- true robustness: -0.11 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "6950 -- true robustness: -0.11 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7000 -- true robustness: -0.11 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7050 -- true robustness: -0.10 smoothness: 1.20 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7100 -- true robustness: -0.10 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7150 -- true robustness: -0.10 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7200 -- true robustness: -0.10 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7250 -- true robustness: -0.09 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7300 -- true robustness: -0.09 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7350 -- true robustness: -0.09 smoothness: 1.21 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7400 -- true robustness: -0.08 smoothness: 1.22 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7450 -- true robustness: -0.08 smoothness: 1.22 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7500 -- true robustness: -0.08 smoothness: 1.22 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7550 -- true robustness: -0.08 smoothness: 1.22 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7600 -- true robustness: -0.07 smoothness: 1.22 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7650 -- true robustness: -0.07 smoothness: 1.23 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7700 -- true robustness: -0.07 smoothness: 1.23 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7750 -- true robustness: -0.06 smoothness: 1.23 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7800 -- true robustness: -0.06 smoothness: 1.23 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7850 -- true robustness: -0.06 smoothness: 1.23 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7900 -- true robustness: -0.06 smoothness: 1.24 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "7950 -- true robustness: -0.05 smoothness: 1.24 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8000 -- true robustness: -0.05 smoothness: 1.24 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8050 -- true robustness: -0.05 smoothness: 1.24 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8100 -- true robustness: -0.05 smoothness: 1.24 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8150 -- true robustness: -0.04 smoothness: 1.25 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8200 -- true robustness: -0.04 smoothness: 1.25 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8250 -- true robustness: -0.04 smoothness: 1.25 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8300 -- true robustness: -0.04 smoothness: 1.25 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8350 -- true robustness: -0.04 smoothness: 1.25 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8400 -- true robustness: -0.03 smoothness: 1.25 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8450 -- true robustness: -0.03 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.85\n",
+ "8500 -- true robustness: -0.03 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8550 -- true robustness: -0.03 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8600 -- true robustness: -0.03 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8650 -- true robustness: -0.02 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8700 -- true robustness: -0.02 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8750 -- true robustness: -0.02 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8800 -- true robustness: -0.02 smoothness: 1.26 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8850 -- true robustness: -0.02 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8900 -- true robustness: -0.02 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "8950 -- true robustness: -0.02 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9000 -- true robustness: -0.02 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9050 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9100 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9150 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9200 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9250 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9300 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9350 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9400 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9450 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9500 -- true robustness: -0.01 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9550 -- true robustness: -0.00 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9600 -- true robustness: -0.00 smoothness: 1.27 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9650 -- true robustness: -0.00 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9700 -- true robustness: -0.00 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9750 -- true robustness: -0.00 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9800 -- true robustness: 0.00 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9850 -- true robustness: 0.00 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9900 -- true robustness: 0.00 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n",
+ "9950 -- true robustness: 0.00 smoothness: 1.28 control limits: 0.00 interval: 0.57 t_start: 0.37 t_end: 0.84\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "for i in range(n_steps):\n",
+ " temperature = schedule(i, n_steps, start_temp, end_temp)\n",
+ " scale = schedule(i, n_steps, start_scale, end_scale)\n",
+ " \n",
+ " t_start_ = jax.nn.sigmoid(t_start)\n",
+ " t_end_ = jax.nn.sigmoid(t_end)\n",
+ " g1, g2, g3 = grad_jit(controls, t_start_, t_end_, scale, state0, umax, dt, coeffs, approx_method, temperature) # take gradient\n",
+ " if ((jnp.linalg.norm(g1)/ T / 2) < 5E-6) or (jnp.isnan(g1).sum() > 0):\n",
+ " break\n",
+ " # g = jax.grad(loss, 0)(controls, state0, umax, approx_method, temperature) # not jitting\n",
+ " controls -= g1 * lr\n",
+ " t_start -= g2 * lr * t_start_ * (1 - t_start_) \n",
+ " t_end -= g3 * lr * t_end_ * (1 - t_end_)\n",
+ " # print(g2, g3, g3 * lr * t_end_ * (1 - t_end_))\n",
+ " \n",
+ " states_.append(simulate_dynamics(controls, state0, dt))\n",
+ " if (i % 50) == 0:\n",
+ " t_start_ = jax.nn.sigmoid(t_start)\n",
+ " t_end_ = jax.nn.sigmoid(t_end)\n",
+ " print(\"%3i -- true robustness: %.2f smoothness: %.2f control limits: %.2f interval: %.2f t_start: %.2f t_end: %.2f\"%(i, true_robustness(controls, t_start_, t_end_, 1000., state0, dt), loss(controls, t_start_, t_end_, 1000., state0, umax, dt, coeffs=[0., 1., 0., 0.]), loss(controls, t_start_, t_end_, 1000., state0, umax, dt, coeffs=[0., 0., 1., 0.]), loss(controls, t_start_, t_end_, 1000., state0, umax, dt, coeffs=[0., 0., 0., 1.]), t_start_, t_end_))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1,3, figsize=(15,4)) \n",
+ "\n",
+ "ax = axs[0]\n",
+ "circle1 = plt.Circle((0, 0), 0.2, color='C2', alpha=0.4)\n",
+ "circle2 = plt.Circle(obstacle_center[0], 0.5, color='C1', alpha=0.4)\n",
+ "\n",
+ "ax.add_patch(circle1)\n",
+ "ax.add_patch(circle2)\n",
+ "\n",
+ "N = 250\n",
+ "[ax.plot(*s.T, color=\"k\", alpha=0.2) for s in states_[::N]]\n",
+ "[ax.plot(*s.T, color=\"blue\", label=\"Initial traj\") for s in states_[:1]]\n",
+ "[ax.plot(*s.T, '.-', color=\"r\", markersize=10, label=\"Final traj\") for s in states_[-1:]]\n",
+ "\n",
+ "ax.scatter(states_[-1][0,:1], states_[-1][0,1:], marker=\"^\", c='yellow', edgecolor=\"k\", s=100, label=\"start\", zorder=4)\n",
+ "ax.scatter(states_[-1][-1,:1], states_[-1][-1,1:], marker=\"*\", c='yellow', edgecolor=\"k\", s=200, label=\"end\", zorder=4)\n",
+ "\n",
+ "ax.set_xlabel(\"x position\")\n",
+ "ax.set_ylabel(\"y position\")\n",
+ "ax.grid()\n",
+ "ax.legend()\n",
+ "ax.axis(\"equal\")\n",
+ "ax.set_title(\"Trajectory\")\n",
+ "\n",
+ "# plot x, y\n",
+ "ax = axs[1]\n",
+ "ax.plot(ts, states_[-1][:-1,:1], label=\"x\")\n",
+ "ax.plot(ts, states_[-1][:-1,1:], label=\"y\")\n",
+ "ax.plot(ts, distance_to_origin.predicate_function(states_[-1][1:]).squeeze(), label=\"distance to origin\")\n",
+ "ax.grid()\n",
+ "ax.axis(\"equal\")\n",
+ "ax.legend()\n",
+ "ax.set_xlabel(\"Time (s)\")\n",
+ "ax.set_ylabel(\"Position\")\n",
+ "ax.set_title(\"Position over time\")\n",
+ "\n",
+ "\n",
+ "\n",
+ "# plot control signal\n",
+ "ax = axs[2]\n",
+ "ax.plot(ts, controls[:,:1], label=\"x control\")\n",
+ "ax.plot(ts, controls[:,1:], label=\"y control\")\n",
+ "ax.plot(ts, jnp.linalg.norm(controls, axis=-1).squeeze(), label=\"control norm\")\n",
+ "ax.grid()\n",
+ "ax.axis(\"equal\")\n",
+ "ax.legend(ncols=3)\n",
+ "ax.set_xlabel(\"Time (s)\")\n",
+ "ax.set_ylabel(\"Controls\")\n",
+ "ax.set_title(\"Control sequence\")\n",
+ "\n",
+ "\n",
+ "\n",
+ "plt.tight_layout()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fontsize = 14\n",
+ "fig, ax = plt.subplots(figsize=(5, 5))\n",
+ "\n",
+ "circle1 = plt.Circle((0, 0), 0.2, color='C2', alpha=0.7)\n",
+ "circle2 = plt.Circle(obstacle_center[0], 0.5, color='C1', alpha=0.7)\n",
+ "\n",
+ "N = 250 # show intermediate solutions at every N iterations\n",
+ "\n",
+ "[ax.plot(*s.T, color=\"blue\", label=\"Initial traj\", alpha=0.6) for s in states_[:1]]\n",
+ "[ax.plot(*s.T, '.-', color=\"r\", label=\"Final traj\", zorder=10, linewidth=2, markersize=8) for s in states_[-1:]]\n",
+ "ax.scatter(states_[-1][0,:1], states_[-1][0,1:], zorder=10, label=\"start\", color=\"yellow\", edgecolor=\"black\", marker=\"^\", s=100)\n",
+ "ax.scatter(states_[-1][-1,:1], states_[-1][-1,1:], zorder=10, label=\"end\", color=\"yellow\", edgecolor=\"black\", marker=\"*\", s=200)\n",
+ "[ax.plot(*s.T, color=\"k\", alpha=0.2, label=\"Iterations\", zorder=0) for s in states_[::N]]\n",
+ "\n",
+ "ax.add_patch(circle1)\n",
+ "ax.add_patch(circle2)\n",
+ "\n",
+ "ax.annotate(\"Goal\", (-0.4, -0.4), fontsize=fontsize-2)\n",
+ "ax.annotate(\"Target\", (-0.4, 2.7), fontsize=fontsize-2)\n",
+ "\n",
+ "ax.set_xlabel(\"$x$ position [m]\", fontsize=fontsize, labelpad=-2)\n",
+ "ax.set_ylabel(\"$y$ position [m]\", fontsize=fontsize)\n",
+ "ax.set_title(\"Robustness $\\\\rho$ = %.2f\"%formula.robustness(states_[-1], t_start=jax.nn.sigmoid(t_start), t_end=jax.nn.sigmoid(t_end), scale=1000.), fontsize=fontsize)\n",
+ "ax.grid(zorder=-6, alpha=0.5)\n",
+ "ax.legend([\"Initial guess\", \"Final trajectory\", \"Start\", \"End\"], ncol=1, fontsize=fontsize-3)\n",
+ "ax.axis(\"equal\")\n",
+ "plt.tight_layout()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "test",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/stljax/formula.py b/stljax/formula.py
index c793414..da0bf71 100644
--- a/stljax/formula.py
+++ b/stljax/formula.py
@@ -5,186 +5,7 @@
import functools
warnings.simplefilter("default")
-
-
-@jax.jit
-def bar_plus(signal, p=2):
- '''max(0,signal)**p'''
- return jax.nn.relu(signal) ** p
-
-
-@jax.jit
-def bar_minus(signal, p=2):
- '''min(0,signal)**p'''
- return (-jax.nn.relu(-signal)) ** p
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
-def M0(signal, eps, weights=None, axis=1, keepdims=True):
- '''Used in gsmr approx method, Eq 4(a) in https://arxiv.org/abs/2405.10996'''
- if weights is None:
- weights = jnp.ones_like(signal)
- sum_w = weights.sum(axis, keepdims=keepdims)
- return (
- eps**sum_w + jnp.prod(signal**weights, axis=axis, keepdims=keepdims)
- ) ** (1 / sum_w)
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
-def Mp(signal, eps, p, weights=None, axis=1, keepdims=True):
- '''Used in gsmr approx method, Eq 4(b) in https://arxiv.org/abs/2405.10996'''
- if weights is None:
- weights = jnp.ones_like(signal)
- sum_w = weights.sum(axis, keepdims=keepdims)
- return (
- eps**p + 1 / sum_w * jnp.sum(weights * signal**p, axis=axis, keepdims=keepdims)
- ) ** (1 / p)
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
-def gmsr_min(signal, eps, p, weights=None, axis=1, keepdims=True):
- '''Used in gsmr approx method, Eq 3 in https://arxiv.org/abs/2405.10996'''
-
- return (
- M0(bar_plus(signal, 2), eps, weights=weights, axis=axis, keepdims=keepdims)
- ** 0.5
- - Mp(
- bar_minus(signal, 2), eps, p, weights=weights, axis=axis, keepdims=keepdims
- )
- ** 0.5
- )
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
-def gmsr_max(signal, eps, p, weights=None, axis=1, keepdims=True):
- '''Used in gsmr approx method, Eq 4(a) but for max in https://arxiv.org/abs/2405.10996'''
-
- return -gmsr_min(-signal, eps, p, weights=weights, axis=axis, keepdims=keepdims)
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
-def gmsr_min_turbo(signal, eps, p, weights=None, axis=1, keepdims=True):
- # TODO: (norrisg) make actually turbo (faster than normal `gmsr_min`)
- pos_idx = signal > 0.0
- neg_idx = ~pos_idx
-
- return jnp.where(
- neg_idx.sum(axis, keepdims=keepdims) > 0,
- eps**0.5
- - Mp(
- bar_minus(signal, 2),
- eps,
- p,
- weights=weights,
- axis=axis,
- keepdims=keepdims,
- )
- ** 0.5,
- M0(bar_plus(signal, 2), eps, weights=weights, axis=axis, keepdims=keepdims)
- ** 0.5
- - eps**0.5,
- )
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
-def gmsr_max_turbo(signal, eps, p, weights=None, axis=1, keepdims=True):
- return -gmsr_min_turbo(
- -signal, eps, p, weights=weights, axis=axis, keepdims=keepdims
- )
-
-
-def gmsr_min_fast(signal, eps, p):
- # TODO: (norrisg) allow `axis` specification
-
- # Split indices into positive and non-positive values
- pos_idx = signal > 0.0
- neg_idx = ~pos_idx
-
- weights = jnp.ones_like(signal)
-
- # Sum of all weights
- sum_w = weights.sum()
-
- # If there exists a negative element
- if signal[neg_idx].size > 0:
- sums = 0.0
- sums = jnp.sum(weights[neg_idx] * (signal[neg_idx] ** (2 * p)))
- Mp = (eps**p + (sums / sum_w)) ** (1 / p)
- h_min = eps**0.5 - Mp**0.5
-
- # If all values are positive
- else:
- mult = 1.0
- mult = jnp.prod(signal[pos_idx] ** (2 * weights[pos_idx]))
- M0 = (eps**sum_w + mult) ** (1 / sum_w)
- h_min = M0**0.5 - eps**0.5
-
- return jnp.reshape(h_min, (1, 1, 1))
-
-
-def gmsr_max_fast(signal, eps, p):
- return -gmsr_min_fast(-signal, eps, p)
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims", "approx_method", "padding"))
-def maxish(signal, axis, keepdims=True, approx_method="true", temperature=None, **kwargs):
- """
- Function to compute max(ish) along an axis.
-
- Args:
- signal: A jnp.array or an Expression
- axis: (Int) axis along to compute max(ish)
- keepdims: (Bool) whether to keep the original array size. Defaults to True
- approx_method: (String) argument to choose the type of max(ish) approximation. possible choices are "true", "logsumexp", "softmax", "gmsr" (https://arxiv.org/abs/2405.10996).
- temperature: Optional, required for approx_method not True.
-
- Returns:
- jnp.array corresponding to the maxish
-
- Raises:
- If Expression does not have a value, or invalid approx_method
-
- """
-
- if isinstance(signal, Expression):
- assert (
- signal.value is not None
- ), "Input Expression does not have numerical values"
- signal = signal.value
-
- match approx_method:
- case "true":
- """jax keeps track of multiple max values and will distribute the gradients across all max values!
- e.g., jax.grad(jnp.max)(jnp.array([0.01, 0.0, 0.01])) # --> Array([0.5, 0. , 0.5], dtype=float32)
- """
- return jnp.max(signal, axis, keepdims=keepdims)
-
- case "logsumexp":
- """https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.logsumexp.html"""
- assert temperature is not None, "need a temperature value"
- return (
- jax.scipy.special.logsumexp(
- temperature * signal, axis=axis, keepdims=keepdims
- )
- / temperature
- )
-
- case "softmax":
- assert temperature is not None, "need a temperature value"
- return (jax.nn.softmax(temperature * signal, axis) * signal).sum(
- axis, keepdims=keepdims
- )
-
- case "gmsr":
- assert (
- temperature is not None
- ), "temperature tuple containing (eps, p) is required"
- (eps, p) = temperature
- return gmsr_max(signal, eps, p, axis=axis, keepdims=keepdims)
-
- case _:
- raise ValueError("Invalid approx_method")
-
-@functools.partial(jax.jit, static_argnames=("axis", "keepdims", "approx_method", "padding"))
-def minish(signal, axis, keepdims=True, approx_method="true", temperature=None, **kwargs):
- '''
- Same as maxish
- '''
- return -maxish(-signal, axis, keepdims, approx_method, temperature, **kwargs)
-
+from .utils import *
class STL_Formula:
'''
@@ -559,10 +380,6 @@ def separate_or(formula, input_, **kwargs):
formula: STL_formula
input_: input of STL_formula
"""
- # if isinstance(input_, tuple):
- # return jnp.concatenate([Or.separate_or(formula.subformula1, input_[0], **kwargs), Or.separate_or(formula.subformula2, input_[1], **kwargs)], axis=-1)
- # else:
- # return jnp.concatenate([Or.separate_or(formula.subformula1, input_, **kwargs), Or.separate_or(formula.subformula2, input_, **kwargs)], axis=-1)
if formula.__class__.__name__ != "Or":
return jnp.expand_dims(formula(input_, **kwargs), -1)
@@ -635,212 +452,117 @@ def _next_function(self):
def __str__(self):
return "(" + str(self.subformula1) + ") ⇒ (" + str(self.subformula2) + ")"
-class Temporal_Operator(STL_Formula):
- """
- Class to compute Eventually and Always. This builds a recurrent cell to perform dynamic programming
+class TemporalOperator(STL_Formula):
- Args:
- subformula: The subformula that the temporal operator is applied to.
- interval: The time interval that the temporal operator operates on. Default: None which means [0, jnp.inf]. Other options car be: [a, b] (b < jnp.inf), [a, jnp.inf] (a > 0)
-
- NOTE: Assume that the interval is describing the INDICES of the desired time interval. The user is responsible for converting the time interval (in time units) into indices (integers) using knowledge of the time step size.
- """
def __init__(self, subformula, interval=None):
super().__init__()
self.subformula = subformula
self.interval = interval
- self._interval = [0, jnp.inf] if self.interval is None else self.interval
- self.hidden_dim = 1 if not self.interval else self.interval[-1] # hidden_dim=1 if interval is [0, ∞) otherwise hidden_dim=end of interval
- if self.hidden_dim == jnp.inf:
- self.hidden_dim = self.interval[0]
- self.steps = 1 if not self.interval else self.interval[-1] - self.interval[0] + 1 # steps=1 if interval is [0, ∞) otherwise steps=length of interval
+
+ if self.interval is None:
+ self.hidden_dim = None
+ self._interval = None
+ elif interval[1] == jnp.inf:
+ self.hidden_dim = None
+ self._interval = [interval[0], interval[1]]
+ else:
+ self.hidden_dim = interval[1] + 1
+ self._interval = [interval[0], interval[1]]
+
+
+ self.LARGE_NUMBER = 1E9
self.operation = None
- # Matrices that shift a vector and add a new entry at the end.
- self.M = jnp.diag(jnp.ones(self.hidden_dim-1), k=1)
- self.b = jnp.zeros(self.hidden_dim)
- self.b = self.b.at[-1].set(1)
+ def _get_interval_indices(self):
+ start_idx = -self.hidden_dim
+ end_idx = -self._interval[0]
- def _initialize_hidden_state(self, signal):
- """
- Compute the initial hidden state.
+ return start_idx, (None if end_idx == 0 else end_idx)
- Args:
- signal: the input signal. Expected size [time_dim,]
+ def _run_cell(self, signal, padding=None, **kwargs):
- Returns:
- h0: initial hidden state is [hidden_dim,]
+ hidden_state = self._initialize_hidden_state(signal, padding=padding) # [hidden_dim]
+ def f_(hidden, state):
+ hidden, o = self._cell(state, hidden, **kwargs)
+ return hidden, o
- Notes:
- Initializing the hidden state requires padding on the signal. Currently, the default is to extend the last value.
- TODO: have option on this padding
+ _, outputs_stack = jax.lax.scan(f_, hidden_state, signal)
+ return outputs_stack
- """
- # Case 1, 2, 4
- # TODO: make this less hard-coded. Assumes signal is [bs, time_dim, signal_dim], and already reversed
- # pads with the signal value at the last time step.
- y = jax.lax.stop_gradient(signal[:1])
- h0 = jnp.ones([self.hidden_dim, *signal.shape[1:]])*y
-
- # Case 3: if self.interval is [a, jnp.inf), then the hidden state is a tuple (like in an LSTM)
- if (self._interval[1] == jnp.inf) & (self._interval[0] > 0):
- c0 = signal[:1]
- return (c0, h0)
- return h0
+ def _initialize_hidden_state(self, signal, padding=None):
+ if padding == "last":
+ pad_value = jax.lax.stop_gradient(signal)[0]
+ elif padding == "mean":
+ pad_value = jax.lax.stop_gradient(signal).mean(0)
+ else:
+ pad_value = -self.LARGE_NUMBER
- def _cell(self, x, hidden_state, **kwargs):
- """
- This function describes the operation that takes place at each recurrent step.
- Args:
- x: the input state at time t [batch_size, 1, ...]
- hidden_state: the hidden state. It is either a tensor, or a tuple of tensors, depending on the interval chosen and other arguments. Generally, the hidden state is of size [batch_size, hidden_dim,...]
+ n_time_steps = signal.shape[0]
- Return:
- output and next hidden_state
- """
- raise NotImplementedError("_cell is not implemented")
+ # infer hidden_dim from the signal length when the interval is unbounded
+ if (self.interval is None) or (self.interval[1] == jnp.inf):
+ self.hidden_dim = n_time_steps
+ if self.interval is None:
+ self._interval = [0, n_time_steps - 1]
+ elif self.interval[1] == jnp.inf:
+ self._interval[1] = n_time_steps - 1
- def _run_cell(self, signal, **kwargs):
- """
- Function to run a signal through a cell T times, where T is the length of the signal in the time dimension
+ self.M = jnp.diag(jnp.ones(self.hidden_dim-1), k=1)
+ self.b = jnp.zeros(self.hidden_dim)
+ self.b = self.b.at[-1].set(1)
- Args:
- signal: input signal, size = [time_dim,]
- time_dim: axis corresponding to time_dim. Default: 0
- kwargs: Other arguments including time_dim, approx_method, temperature
+ if (self.interval is None) or (self.interval[1] == jnp.inf):
+ pad_value = jnp.concatenate([jnp.ones(self._interval[0] + 1) * pad_value, jnp.ones(self.hidden_dim - self._interval[0] - 1) * self.sign * pad_value])
- Return:
- outputs: list of outputs
- states: list of hidden_states
- """
- time_dim = 0 # assuming signal is [time_dim,...]
- outputs = []
- states = []
- hidden_state = self._initialize_hidden_state(signal) # [hidden_dim]
- signal_split = jnp.split(signal, signal.shape[time_dim], time_dim) # list of x at each time step
- for i in range(signal.shape[time_dim]):
- o, hidden_state = self._cell(signal_split[i], hidden_state, **kwargs)
- outputs.append(o)
- states.append(hidden_state)
- return outputs, states
+ h0 = jnp.ones(self.hidden_dim) * pad_value
+ return h0
- def robustness_trace(self, signal, **kwargs):
- """
- Function to compute robustness trace of a temporal STL formula
- First, compute the robustness trace of the subformula, and use that as the input for the recurrent computation
+ def _cell(self, state, hidden, **kwargs):
- Args:
- signal: input signal, size = [bs, time_dim, ...]
- time_dim: axis corresponding to time_dim. Default: 1
- kwargs: Other arguments including time_dim, approx_method, temperature
+ h_new = self.M @ hidden + self.b * state
+ start_idx, end_idx = self._get_interval_indices()
+ output = self.operation(h_new[start_idx:end_idx], axis=0, keepdims=False, **kwargs)
+
+ return h_new, output
+
+
+ def robustness_trace(self, signal, padding=None, **kwargs):
- Returns:
- robustness_trace: jnp.array. Same size as signal.
- """
- time_dim = 0 # assuming signal is [time_dim,...]
trace = self.subformula(signal, **kwargs)
- outputs, _ = self._run_cell(trace, **kwargs)
- return jnp.concatenate(outputs, axis=time_dim) # [time_dim, ]
+ outputs = self._run_cell(trace, padding, **kwargs)
+ return outputs
+
+ def robustness(self, signal, **kwargs):
+ return self.__call__(signal, **kwargs)[-1]
+
def _next_function(self):
- """ next function is the input subformula. For visualization purposes """
return [self.subformula]
-class AlwaysRecurrent(Temporal_Operator):
- """
- The Always STL formula □_[a,b] subformula
- The robustness value is the minimum value of the input trace over a prespecified time interval
+class AlwaysRecurrent(TemporalOperator):
- Args:
- subformula: subformula that the Always operation is applied on
- interval: time interval [a,b] where a, b are indices along the time dimension. It is up to the user to keep track of what the timestep size is.
- """
def __init__(self, subformula, interval=None):
super().__init__(subformula=subformula, interval=interval)
-
- def _cell(self, x, hidden_state, **kwargs):
- """
- see Temporal_Operator._cell
- """
- time_dim = 0 # assuming signal is [time_dim,...]
- # Case 1, interval = [0, inf]
- if self.interval is None:
- input_ = jnp.concatenate([hidden_state, x], axis=time_dim) # [rnn_dim+1,]
- output = minish(input_, time_dim, keepdims=True, **kwargs) # [1,]
- return output, output
-
- # Case 3: self.interval is [a, np.inf)
- if (self._interval[1] == jnp.inf) & (self._interval[0] > 0):
- c, h = hidden_state
- ch = jnp.concatenate([c, h[:1]], axis=time_dim) # [2,]
- output = minish(ch, time_dim, keepdims=True, **kwargs) # [1,]
- hidden_state_ = (output, self.M @ h + self.b * x)
-
- # Case 2 and 4: self.interval is [a, b]
- else:
- hidden_state_ = self.M @ hidden_state + self.b * x
- hx = jnp.concatenate([hidden_state, x], axis=time_dim) # [rnn_dim+1,]
- input_ = hx[:self.steps] # [self.steps,]
- output = minish(input_, time_dim, **kwargs) # [1,]
- return output, hidden_state_
+ self.operation = minish
+ self.sign = -1.
def __str__(self):
return "◻ " + str(self._interval) + "( " + str(self.subformula) + " )"
+class EventuallyRecurrent(TemporalOperator):
-class EventuallyRecurrent(Temporal_Operator):
- """
- The Eventually STL formula ♢_[a,b] subformula
- The robustness value is the minimum value of the input trace over a prespecified time interval
-
- Args:
- subformula: subformula that the Eventually operation is applied on
- interval: time interval [a,b] where a, b are indices along the time dimension. It is up to the user to keep track of what the timestep size is.
- """
def __init__(self, subformula, interval=None):
super().__init__(subformula=subformula, interval=interval)
-
- def _cell(self, x, hidden_state, **kwargs):
- """
- see Temporal_Operator._cell
- """
- time_dim = 0 # assuming signal is [time_dim,...]
- # Case 1, interval = [0, inf]
- if self.interval is None:
- input_ = jnp.concatenate([hidden_state, x], axis=time_dim) # [rnn_dim+1, ]
- output = maxish(input_, time_dim, keepdims=True, **kwargs) # [1, ]
- return output, output
-
- # Case 3: self.interval is [a, np.inf)
- if (self._interval[1] == jnp.inf) & (self._interval[0] > 0):
- c, h = hidden_state
- ch = jnp.concatenate([c, h[:1]], axis=time_dim) # [2, ]
- output = maxish(ch, time_dim, keepdims=True, **kwargs) # [1, ]
- hidden_state_ = (output, self.M @ h + self.b * x)
-
- # Case 2 and 4: self.interval is [a, b]
- else:
- hidden_state_ = self.M @ hidden_state + self.b * x
- hx = jnp.concatenate([hidden_state, x], axis=time_dim) # [rnn_dim+1, ]
- input_ = hx[:self.steps] # [self.steps, ]
- output = maxish(input_, time_dim, **kwargs) # [1, ]
- return output, hidden_state_
+ self.operation = maxish
+ self.sign = 1.
def __str__(self):
return "♢ " + str(self._interval) + "( " + str(self.subformula) + " )"
class UntilRecurrent(STL_Formula):
- """
- The Until STL operator U. Subformula1 U_[a,b] subformula2
- Arg:
- subformula1: subformula for lhs of the Until operation
- subformula2: subformula for rhs of the Until operation
- interval: time interval [a,b] where a, b are indices along the time dimension. It is up to the user to keep track of what the timestep is.
- overlap: If overlap=True, then the last time step that ϕ is true, ψ starts being true. That is, sₜ ⊧ ϕ and sₜ ⊧ ψ at a common time t. If overlap=False, when ϕ stops being true, ψ starts being true. That is sₜ ⊧ ϕ and sₜ+₁ ⊧ ψ, but sₜ ¬⊧ ψ
- """
def __init__(self, subformula1, subformula2, interval=None, overlap=True):
super().__init__()
@@ -849,64 +571,115 @@ def __init__(self, subformula1, subformula2, interval=None, overlap=True):
self.interval = interval
if overlap == False:
self.subformula2 = Eventually(subformula=subformula2, interval=[0,1])
- self.LARGE_NUMBER = 1E6
+ self.LARGE_NUMBER = 1E9
+ # self.Alw = AlwaysRecurrent(subformula=Identity(name=str(self.subformula1))
+ self.Alw = AlwaysRecurrent(Predicate('x', lambda x: x) > 0.)
- def robustness_trace(self, signal, **kwargs):
- """
- Computing robustness trace of subformula1 U subformula2 (see paper)
-
- Args:
- signal: input signal for the formula. If using Expressions to define the formula, then inputs a tuple of signals corresponding to each subformula. If using Predicates to define the formula, then inputs is just a single jnp.array. Not need for different signals for each subformula. Expected signal is size [batch_size, time_dim, x_dim]
- time_dim: axis for time_dim. Default: 1
- kwargs: Other arguments including time_dim, approx_method, temperature
-
- Returns:
- robustness_trace: jnp.array. Same size as signal.
- """
+ if self.interval is None:
+ self.hidden_dim = None
+ elif interval[1] == jnp.inf:
+ self.hidden_dim = None
+ else:
+ self.hidden_dim = interval[1] + 1
- # TODO (karenl7) this really assumes axis=1 is the time dimension. Can this be generalized?
+ def _initialize_hidden_state(self, signal, padding=None, **kwargs):
time_dim = 0 # assuming signal is [time_dim,...]
- LARGE_NUMBER = self.LARGE_NUMBER
if isinstance(signal, tuple):
# for formula defined using Expression
assert signal[0].shape[time_dim] == signal[1].shape[time_dim]
trace1 = self.subformula1(signal[0], **kwargs)
trace2 = self.subformula2(signal[1], **kwargs)
- n_time_steps = signal[0].shape[time_dim] # TODO: WIP
+ n_time_steps = signal[0].shape[time_dim]
else:
# for formula defined using Predicate
trace1 = self.subformula1(signal, **kwargs)
trace2 = self.subformula2(signal, **kwargs)
- n_time_steps = signal.shape[time_dim] # TODO: WIP
-
- Alw = Always(subformula=Identity(name=str(self.subformula1)))
- LHS = jnp.permute_dims(jnp.repeat(jnp.expand_dims(trace2, -1), n_time_steps, axis=-1), [1,0]) # [sub_signal, t_prime]
- RHS = jnp.ones_like(LHS) * -LARGE_NUMBER # [sub_signal, t_prime]
-
- # Case 1, interval = [0, inf]
- if self.interval == None:
- for i in range(n_time_steps):
- RHS = RHS.at[i:,i].set(Alw(trace1[i:]))
-
- # Case 2 and 4: self.interval is [a, b], a ≥ 0, b < ∞
- elif self.interval[1] < jnp.inf:
- a = self.interval[0]
- b = self.interval[1]
- for i in range(n_time_steps):
- end = i+b+1
- RHS = RHS.at[i+a:end,i].set(Alw(trace1[i:end])[a:])
-
- # Case 3: self.interval is [a, np.inf), a ≂̸ 0
+ n_time_steps = signal.shape[time_dim]
+
+ # hidden_dim could not be set in __init__ (interval is None or unbounded), so derive it from the signal length
+ if self.hidden_dim is None:
+ self.hidden_dim = n_time_steps
+ if self.interval is None:
+ self.interval = [0, n_time_steps - 1]
+ elif self.interval[1] == jnp.inf:
+ self.interval[1] = n_time_steps - 1
+
+ self.ones_array = jnp.ones(self.hidden_dim)
+
+ # set shift operations given hidden_dim
+ self.M = jnp.diag(jnp.ones(self.hidden_dim-1), k=1)
+ self.b = jnp.zeros(self.hidden_dim)
+ self.b = self.b.at[-1].set(1)
+
+ if self.hidden_dim == n_time_steps:
+ pad_value = self.LARGE_NUMBER
else:
- a = self.interval[0]
- for i in range(n_time_steps):
- RHS = RHS.at[i+a:,i].set(Alw(trace1[i:])[a:])
+ pad_value = -self.LARGE_NUMBER
+
+ h1 = pad_value * self.ones_array
+ h2 = -self.LARGE_NUMBER * self.ones_array
+ return (h1, h2), trace1, trace2
+
+ def _get_interval_indices(self):
+ start_idx = -self.hidden_dim
+ end_idx = -self.interval[0]
+
+ return start_idx, (None if end_idx == 0 else end_idx)
+
+ def _cell(self, state, hidden, **kwargs):
+ x1, x2 = state
+ h1, h2 = hidden
+ h1_new = self.M @ h1 + self.b * x1
+ h1_min = jnp.flip(self.Alw(jnp.flip(h1_new), **kwargs))
+ h2_new = self.M @ h2 + self.b * x2
+ start_idx, end_idx = self._get_interval_indices()
+ z = minish(jnp.stack([h1_min, h2_new]), axis=0, keepdims=False, **kwargs)[start_idx:end_idx]
+
+ def g_(carry, x):
+ carry = maxish(jnp.array([carry, x]), axis=0, keepdims=False, **kwargs)
+ return carry, carry
+
+ output, _ = jax.lax.scan(g_, -self.LARGE_NUMBER, z)
- return maxish(minish(jnp.stack([LHS, RHS], axis=-1), axis=-1, keepdims=False, **kwargs), axis=-1, keepdims=False, **kwargs)
+ return output, (h1_new, h2_new)
+
+ def robustness_trace(self, signal, padding=None, **kwargs):
+ """
+ Function to run a signal through a cell T times, where T is the length of the signal in the time dimension
+
+ Args:
+ signal: input signal, size = [time_dim,] (or a tuple of two such signals when using Expressions)
+ padding: passed through to _initialize_hidden_state
+ kwargs: Other arguments, including approx_method and temperature
+
+ Return:
+ outputs_stack: jnp.array of per-step cell outputs, stacked along the time dimension
+ (the hidden states are not returned)
+ """
+ hidden_state, trace1, trace2 = self._initialize_hidden_state(signal, padding=padding, **kwargs)
+ def f_(hidden, state):
+ o, hidden = self._cell(state, hidden, **kwargs)
+ return hidden, o
+ _, outputs_stack = jax.lax.scan(f_, hidden_state, jnp.stack([trace1, trace2], axis=1))
+ return outputs_stack
+
+
+ def robustness(self, signal, **kwargs):
+ """
+ Computes the robustness value. Extracts the last entry along time_dim of robustness trace.
+
+ Args:
+ signal: jnp.array or Expression. Expected size [time_dim, ...] (time is assumed to be axis 0)
+ kwargs: Other arguments, including approx_method and temperature
+
+ Return: jnp.array, same as input with the time_dim removed.
+ """
+ return self.__call__(signal, **kwargs)[-1]
+ # return jnp.rollaxis(self.__call__(signal, **kwargs), time_dim)[-1]
def _next_function(self):
""" next function is the input subformulas. For visualization purposes """
return [self.subformula1, self.subformula2]
@@ -916,7 +689,6 @@ def __str__(self):
-
class Expression:
name: str
value: jnp.array
@@ -1110,13 +882,17 @@ def __init__(self, subformula, interval=None):
self.subformula = subformula
self._interval = [0, jnp.inf] if self.interval is None else self.interval
- def robustness_trace(self, signal, padding="last", large_number=1E9, **kwargs):
+ def robustness_trace(self, signal, padding=None, large_number=1E9, **kwargs):
time_dim = 0 # assuming signal is [time_dim,...]
- signal = self.subformula(signal, time_dim=time_dim, padding=padding, large_number=large_number, **kwargs)
+ signal = self.subformula(signal, padding=padding, large_number=large_number, **kwargs)
T = signal.shape[time_dim]
mask_value = -large_number
+ offset = 0
if self.interval is None:
interval = [0,T-1]
+ elif self.interval[1] == jnp.inf:
+ interval = [self.interval[0], T-1]
+ offset = self.interval[0]
else:
interval = self.interval
signal_matrix = signal.reshape([T,1]) @ jnp.ones([1,T])
@@ -1125,11 +901,11 @@ def robustness_trace(self, signal, padding="last", large_number=1E9, **kwargs):
elif padding == "mean":
pad_value = signal.mean(time_dim)
else:
- pad_value = padding
+ pad_value = mask_value
signal_pad = jnp.ones([interval[1]+1, T]) * pad_value
signal_padded = jnp.concatenate([signal_matrix, signal_pad], axis=time_dim)
subsignal_mask = jnp.tril(jnp.ones([T + interval[1]+1,T]))
- time_interval_mask = jnp.triu(jnp.ones([T + interval[1]+1,T]), -interval[-1]) * jnp.tril(jnp.ones([T + interval[1]+1,T]), -interval[0])
+ time_interval_mask = jnp.triu(jnp.ones([T + interval[1]+1,T]), -interval[-1]-offset) * jnp.tril(jnp.ones([T + interval[1]+1,T]), -interval[0])
masked_signal_matrix = jnp.where(time_interval_mask * subsignal_mask, signal_padded, mask_value)
return maxish(masked_signal_matrix, axis=time_dim, keepdims=False, **kwargs)
@@ -1149,26 +925,40 @@ def __init__(self, subformula, interval=None):
self.subformula = subformula
self._interval = [0, jnp.inf] if self.interval is None else self.interval
- def robustness_trace(self, signal, padding="last", large_number=1E9, **kwargs):
+ def robustness_trace(self, signal, padding=None, large_number=1E9, **kwargs):
time_dim = 0 # assuming signal is [time_dim,...]
signal = self.subformula(signal, padding=padding, large_number=large_number, **kwargs)
T = signal.shape[time_dim]
mask_value = large_number
- if self.interval is None:
- interval = [0,T-1]
- else:
- interval = self.interval
+ sign = 1.
+ offset = 0
+ # if self.interval is None:
+ # interval = [0,T-1]
+ # sign = -1.
+ def true_func(_interval, T):
+ return [_interval[0], T-1], -1., _interval[0]
+ def false_func(_interval, T):
+ return _interval, 1., 0
+ operands = (self._interval, T,)
+ interval, sign, offset = cond(self._interval[1] == jnp.inf, true_func, false_func, *operands)
+ # if self._interval[1] == jnp.inf:
+ # interval = [self.interval[0], T-1]
+ # sign = -1.
+ # offset = self.interval[0]
+ # else:
+ # interval = self.interval
signal_matrix = signal.reshape([T,1]) @ jnp.ones([1,T])
if padding == "last":
pad_value = signal[-1]
elif padding == "mean":
pad_value = signal.mean(time_dim)
else:
- pad_value = padding
- signal_pad = jnp.ones([interval[1]+1, T]) * pad_value
+ pad_value = -large_number
+ signal_pad = jnp.concatenate([jnp.ones([interval[1], T]) * sign * pad_value, jnp.ones([1, T]) * pad_value], axis=time_dim)
signal_padded = jnp.concatenate([signal_matrix, signal_pad], axis=time_dim)
subsignal_mask = jnp.tril(jnp.ones([T + interval[1]+1,T]))
- time_interval_mask = jnp.triu(jnp.ones([T + interval[1]+1,T]), -interval[-1]) * jnp.tril(jnp.ones([T + interval[1]+1,T]), -interval[0])
+
+ time_interval_mask = jnp.triu(jnp.ones([T + interval[1]+1,T]), -interval[-1]-offset) * jnp.tril(jnp.ones([T + interval[1]+1,T]), -interval[0])
masked_signal_matrix = jnp.where(time_interval_mask * subsignal_mask, signal_padded, mask_value)
return minish(masked_signal_matrix, axis=time_dim, keepdims=False, **kwargs)
@@ -1189,7 +979,7 @@ def __init__(self, subformula1, subformula2, interval=None):
self._interval = [0, jnp.inf] if self.interval is None else self.interval
- def robustness_trace(self, signal, padding="last", large_number=1E9, **kwargs):
+ def robustness_trace(self, signal, padding=None, large_number=1E9, **kwargs):
time_dim = 0 # assuming signal is [time_dim,...]
if isinstance(signal, tuple):
signal1, signal2 = signal
@@ -1205,6 +995,8 @@ def robustness_trace(self, signal, padding="last", large_number=1E9, **kwargs):
mask_value = large_number
if self.interval is None:
interval = [0,T-1]
+ elif self.interval[1] == jnp.inf:
+ interval = [self.interval[0], T-1]
else:
interval = self.interval
signal1_matrix = signal1.reshape([T,1]) @ jnp.ones([1,T])
@@ -1216,8 +1008,8 @@ def robustness_trace(self, signal, padding="last", large_number=1E9, **kwargs):
signal1_pad = jnp.ones([interval[1]+1, T]) * signal1.mean(time_dim)
signal2_pad = jnp.ones([interval[1]+1, T]) * signal2.mean(time_dim)
else:
- signal1_pad = jnp.ones([interval[1]+1, T]) * padding
- signal2_pad = jnp.ones([interval[1]+1, T]) * padding
+ signal1_pad = jnp.ones([interval[1]+1, T]) * -mask_value
+ signal2_pad = jnp.ones([interval[1]+1, T]) * -mask_value
signal1_padded = jnp.concatenate([signal1_matrix, signal1_pad], axis=time_dim)
signal2_padded = jnp.concatenate([signal2_matrix, signal2_pad], axis=time_dim)
@@ -1227,11 +1019,89 @@ def robustness_trace(self, signal, padding="last", large_number=1E9, **kwargs):
phi2_mask = jnp.stack([jnp.triu(jnp.ones([T + interval[1]+1,T]), -end_idx) * jnp.tril(jnp.ones([T + interval[1]+1,T]), -end_idx) for end_idx in range(interval[0], interval[-1]+1)], 0)
phi1_masked_signal = jnp.stack([jnp.where(m1, signal1_padded, mask_value) for m1 in phi1_mask], 0)
phi2_masked_signal = jnp.stack([jnp.where(m2, signal2_padded, mask_value) for m2 in phi2_mask], 0)
- return maxish(jnp.stack([minish(jnp.stack([minish(s1, axis=0, keepdims=False), minish(s2, axis=0, keepdims=False)], axis=0), axis=0, keepdims=False) for (s1, s2) in zip(phi1_masked_signal, phi2_masked_signal)], axis=0), axis=0, keepdims=False)
+ return maxish(jnp.stack([minish(jnp.stack([minish(s1, axis=0, keepdims=False, **kwargs), minish(s2, axis=0, keepdims=False, **kwargs)], axis=0), axis=0, keepdims=False, **kwargs) for (s1, s2) in zip(phi1_masked_signal, phi2_masked_signal)], axis=0), axis=0, keepdims=False, **kwargs)
+
+
def _next_function(self):
""" next function is the input subformula. For visualization purposes """
return [self.subformula1, self.subformula2]
def __str__(self):
- return "(" + str(self.subformula1) + ")" + " U " + str(self._interval) + "(" + str(self.subformula2) + ")"
\ No newline at end of file
+ return "(" + str(self.subformula1) + ")" + " U " + str(self._interval) + "(" + str(self.subformula2) + ")"
+
+
+class DifferentiableAlways(STL_Formula):
+ def __init__(self, subformula, interval=None):
+ super().__init__()
+
+ self.interval = interval
+ self.subformula = subformula
+ # self._interval = [0, jnp.inf] if self.interval is None else self.interval
+
+ def robustness_trace(self, signal, t_start, t_end, scale=1.0, padding=None, large_number=1E9, delta=1E-3, **kwargs):
+ time_dim = 0 # assuming signal is [time_dim,...]
+ signal = self.subformula(signal, padding=padding, large_number=large_number, **kwargs)
+ T = signal.shape[time_dim]
+ mask_value = large_number
+ signal_matrix = signal.reshape([T,1]) @ jnp.ones([1,T])
+ if padding == "last":
+ pad_value = signal[-1]
+ elif padding == "mean":
+ pad_value = signal.mean(time_dim)
+ else:
+ pad_value = -mask_value
+ signal_pad = jnp.ones([T, T]) * pad_value
+ signal_padded = jnp.concatenate([signal_matrix, signal_pad], axis=time_dim)
+ smooth_time_mask = smooth_mask(T, t_start, t_end, scale)
+ padded_smooth_time_mask = jnp.zeros([2 * T, T])
+ for t in range(T):
+ padded_smooth_time_mask = padded_smooth_time_mask.at[t:t+T,t].set(smooth_time_mask)
+
+ masked_signal_matrix = jnp.where(padded_smooth_time_mask > delta, signal_padded * padded_smooth_time_mask, mask_value)
+ return minish(masked_signal_matrix, axis=time_dim, keepdims=False, **kwargs)
+
+ def _next_function(self):
+ """ next function is the input subformula. For visualization purposes """
+ return [self.subformula]
+
+ def __str__(self):
+ return "◻ [a,b] ( " + str(self.subformula) + " )"
+
+
+class DifferentiableEventually(STL_Formula):
+ def __init__(self, subformula, interval=None):
+ super().__init__()
+
+ self.interval = interval
+ self.subformula = subformula
+ self._interval = [0, jnp.inf] if self.interval is None else self.interval
+
+ def robustness_trace(self, signal, t_start, t_end, scale=1.0, padding=None, large_number=1E9, delta=1E-3, **kwargs):
+ time_dim = 0 # assuming signal is [time_dim,...]
+ signal = self.subformula(signal, padding=padding, large_number=large_number, **kwargs)
+ T = signal.shape[time_dim]
+ mask_value = -large_number
+ signal_matrix = signal.reshape([T,1]) @ jnp.ones([1,T])
+ if padding == "last":
+ pad_value = signal[-1]
+ elif padding == "mean":
+ pad_value = signal.mean(time_dim)
+ else:
+ pad_value = mask_value
+ signal_pad = jnp.ones([T, T]) * pad_value
+ signal_padded = jnp.concatenate([signal_matrix, signal_pad], axis=time_dim)
+ smooth_time_mask = smooth_mask(T, t_start, t_end, scale)
+ padded_smooth_time_mask = jnp.zeros([2 * T, T])
+ for t in range(T):
+ padded_smooth_time_mask = padded_smooth_time_mask.at[t:t+T,t].set(smooth_time_mask)
+
+ masked_signal_matrix = jnp.where(padded_smooth_time_mask > delta, signal_padded * padded_smooth_time_mask, mask_value)
+ return maxish(masked_signal_matrix, axis=time_dim, keepdims=False, **kwargs)
+
+ def _next_function(self):
+ """ next function is the input subformula. For visualization purposes """
+ return [self.subformula]
+
+ def __str__(self):
+ return "♢ [a,b] ( " + str(self.subformula) + " )"
\ No newline at end of file
diff --git a/stljax/tests.py b/stljax/tests.py
new file mode 100644
index 0000000..69afcd9
--- /dev/null
+++ b/stljax/tests.py
@@ -0,0 +1,310 @@
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from stljax.formula import *
+from stljax.viz import *
+import functools
+
+jax.config.update("jax_enable_x64", True)
+
+def test_always(signal, interval, verbose=True, **kwargs):
+ def true_robustness_trace(signal, interval, **kwargs):
+ T = len(signal)
+ if (interval is None):
+ return jnp.stack([minish(signal[i:], axis=0, keepdims=False, **kwargs) for i in range(len(signal))])
+ else:
+ if interval[1] == jnp.inf:
+ large_number = 1E9
+ start = interval[0]
+ signal_padded = jnp.concat([signal, jnp.ones(T-1) * large_number, jnp.ones(start) * -large_number])
+ return jnp.stack([minish(signal_padded[i+start:i+start+T], axis=0, keepdims=False, **kwargs) for i in range(len(signal))])
+ else:
+ large_number = 1E9
+ signal_padded = jnp.concat([signal, jnp.ones(interval[1]+1) * -large_number])
+ return jnp.stack([minish(signal_padded[i:][interval[0]:interval[1]+1], axis=0, keepdims=False, **kwargs) for i in range(len(signal))])
+
+ def true_robustness(signal, interval, **kwargs):
+ return true_robustness_trace(signal, interval, **kwargs)[0]
+
+ pred = Predicate("identity", lambda x: x)
+ phi = Always(pred > 0., interval=interval)
+ rob = phi.robustness(signal, **kwargs)
+ rob_trace = phi(signal, **kwargs)
+ rob_grad = jax.grad(phi.robustness)(signal, **kwargs)
+
+ true_trace = true_robustness_trace(signal, interval, **kwargs)
+ true_rob = true_robustness(signal, interval, **kwargs)
+ true_grad = jax.grad(true_robustness)(signal, interval, **kwargs)
+
+ rob_correct = jnp.isclose(rob, true_rob)
+ trace_correct = jnp.all(jnp.isclose(rob_trace, true_trace))
+ grad_correct = jnp.all(jnp.isclose(rob_grad, true_grad))
+
+ pass_n = 0
+
+ if rob_correct:
+ if verbose: print("\u2713 Always robustness value matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Always robustness value does not match the expected answer")
+
+ if trace_correct:
+ if verbose: print("\u2713 Always robustness trace matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Always robustness trace does not match the expected answer")
+
+ if grad_correct:
+ if verbose: print("\u2713 Always robustness gradient matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Always robustness gradient does not match the expected answer")
+
+ print("%i/3 tests passed for Always formula\n"%pass_n)
+
+
+def test_eventually(signal, interval, verbose=True, **kwargs):
+ def true_robustness_trace(signal, interval, **kwargs):
+ T = len(signal)
+ if (interval is None):
+ return jnp.stack([maxish(signal[i:], axis=0, keepdims=False, **kwargs) for i in range(len(signal))])
+ else:
+ if interval[1] == jnp.inf:
+ large_number = 1E9
+ start = interval[0]
+ signal_padded = jnp.concat([signal, jnp.ones(T) * -large_number])
+ return jnp.stack([maxish(signal_padded[i+start:i+T], axis=0, keepdims=False, **kwargs) for i in range(len(signal))])
+ else:
+ large_number = 1E9
+ signal_padded = jnp.concat([signal, jnp.ones(interval[1]+1) * -large_number])
+ return jnp.stack([maxish(signal_padded[i:][interval[0]:interval[1]+1], axis=0, keepdims=False, **kwargs) for i in range(len(signal))])
+
+ def true_robustness(signal, interval, **kwargs):
+ return true_robustness_trace(signal, interval, **kwargs)[0]
+
+ pred = Predicate("identity", lambda x: x)
+ phi = Eventually(pred > 0., interval=interval)
+ rob = phi.robustness(signal, **kwargs)
+ rob_trace = phi(signal, **kwargs)
+ rob_grad = jax.grad(phi.robustness)(signal, **kwargs)
+
+ true_trace = true_robustness_trace(signal, interval, **kwargs)
+ true_rob = true_robustness(signal, interval, **kwargs)
+ true_grad = jax.grad(true_robustness)(signal, interval, **kwargs)
+
+ rob_correct = jnp.isclose(rob, true_rob)
+ trace_correct = jnp.all(jnp.isclose(rob_trace, true_trace))
+ grad_correct = jnp.all(jnp.isclose(rob_grad, true_grad))
+
+ pass_n = 0
+
+ if rob_correct:
+ if verbose: print("\u2713 Eventually robustness value matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Eventually robustness value does not match the expected answer")
+
+ if trace_correct:
+ if verbose: print("\u2713 Eventually robustness trace matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Eventually robustness trace does not match the expected answer")
+
+ if grad_correct:
+ if verbose: print("\u2713 Eventually robustness gradient matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Eventually robustness gradient does not match the expected answer")
+
+ print("%i/3 tests passed for Eventually formula\n"%pass_n)
+
+
+def test_until(signal, interval, verbose=True, **kwargs):
+ pred = Predicate("identity", lambda x: x)
+ phi = Until(pred > 0., pred < 0., interval=interval)
+
+ def true_robustness_trace(signal, interval, **kwargs):
+ signal1, signal2 = phi.subformula1(signal), phi.subformula2(signal)
+ T = len(signal1)
+ if interval is None:
+ interval = [0,T-1]
+ elif interval[1] == jnp.inf:
+ interval = [interval[0], T-1]
+ large_number = 1E9
+ signal1_padded = jnp.concat([signal1, jnp.ones_like(signal1) * -large_number])
+ signal2_padded = jnp.concat([signal2, jnp.ones_like(signal2) * -large_number])
+ return jnp.stack([maxish(jnp.stack([minish(jnp.stack([minish(signal1_padded[i:][:t+1], axis=0, keepdims=False, **kwargs), signal2_padded[i:][t]]), axis=0, keepdims=False, **kwargs) for t in range(interval[0],interval[-1]+1)]), axis=0, keepdims=False, **kwargs) for i in range(T)])
+
+ def true_robustness(signal, interval, **kwargs):
+ return true_robustness_trace(signal, interval, **kwargs)[0]
+
+
+ pred = Predicate("identity", lambda x: x)
+ phi = Until(pred > 0., pred < 0, interval=interval)
+ rob = phi.robustness(signal, **kwargs)
+ rob_trace = phi(signal, **kwargs)
+ rob_grad = jax.grad(phi.robustness)(signal, **kwargs)
+
+ true_trace = true_robustness_trace(signal, interval, **kwargs)
+ true_rob = true_robustness(signal, interval, **kwargs)
+ true_grad = jax.grad(true_robustness)(signal, interval, **kwargs)
+
+ rob_correct = jnp.isclose(rob, true_rob, atol=1E-5)
+ trace_correct = jnp.all(jnp.isclose(rob_trace, true_trace, atol=1E-5))
+ grad_diff = jnp.linalg.norm(rob_grad - true_grad)
+ grad_correct = jnp.all(jnp.isclose(rob_grad, true_grad, atol=1E-5))
+
+ pass_n = 0
+
+ if rob_correct:
+ if verbose: print("\u2713 Until robustness value matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Until robustness value does not match the expected answer")
+
+ if trace_correct:
+ if verbose: print("\u2713 Until robustness trace matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Until robustness trace does not match the expected answer")
+
+ if grad_correct:
+ if verbose: print("\u2713 Until robustness gradient matches the expected answer")
+ pass_n += 1
+ else:
+ print("\u274c Until robustness gradient does not match the expected answer")
+ print(grad_diff)
+ print(rob_grad, true_grad)
+
+ print("%i/3 tests passed for Until formula\n"%pass_n)
+
+
+def test_always_mask_recurrent(signal, interval, verbose=True, **kwargs):
+ signal_flip = jnp.flip(signal)
+ pred = Predicate("identity", lambda x: x)
+ mask = Always(pred > 0., interval)
+ rec = AlwaysRecurrent(pred > 0., interval)
+
+ rob_correct = jnp.isclose(mask.robustness(signal, **kwargs), rec.robustness(signal_flip, **kwargs))
+ trace_correct = jnp.all(jnp.isclose(mask(signal, **kwargs), jnp.flip(rec(signal_flip, **kwargs))))
+ grad_correct = jnp.all(jnp.isclose(jax.grad(mask.robustness)(signal, **kwargs), jnp.flip(jax.grad(rec.robustness)(signal_flip, **kwargs))))
+
+ pass_n = 0
+
+ if rob_correct:
+ if verbose: print("\u2713 Always mask and recurrent robustness values match")
+ pass_n += 1
+ else:
+ print("\u274c Always mask and recurrent robustness values do not match")
+
+ if trace_correct:
+ if verbose: print("\u2713 Always mask and recurrent robustness traces match")
+ pass_n += 1
+ else:
+ print("\u274c Always mask and recurrent robustness traces do not match")
+
+ if grad_correct:
+ if verbose: print("\u2713 Always mask and recurrent robustness gradients match")
+ pass_n += 1
+ else:
+ print("\u274c Always mask and recurrent robustness gradients do not match")
+
+
+ print("%i/3 tests passed for Always mask vs recurrent formula\n"%pass_n)
+
+
+def test_eventually_mask_recurrent(signal, interval, verbose=True, **kwargs):
+ signal_flip = jnp.flip(signal)
+ pred = Predicate("identity", lambda x: x)
+ mask = Eventually(pred > 0., interval)
+ rec = EventuallyRecurrent(pred > 0., interval)
+
+ rob_correct = jnp.isclose(mask.robustness(signal, **kwargs), rec.robustness(signal_flip, **kwargs))
+ trace_correct = jnp.all(jnp.isclose(mask(signal, **kwargs), jnp.flip(rec(signal_flip, **kwargs))))
+ grad_correct = jnp.all(jnp.isclose(jax.grad(mask.robustness)(signal, **kwargs), jnp.flip(jax.grad(rec.robustness)(signal_flip, **kwargs))))
+
+ pass_n = 0
+
+ if rob_correct:
+ if verbose: print("\u2713 Eventually mask and recurrent robustness values match")
+ pass_n += 1
+ else:
+ print("\u274c Eventually mask and recurrent robustness values do not match")
+
+ if trace_correct:
+ if verbose: print("\u2713 Eventually mask and recurrent robustness traces match")
+ pass_n += 1
+ else:
+ print("\u274c Eventually mask and recurrent robustness traces do not match")
+
+ if grad_correct:
+ if verbose: print("\u2713 Eventually mask and recurrent robustness gradients match")
+ pass_n += 1
+ else:
+ print("\u274c Eventually mask and recurrent robustness gradients do not match")
+
+
+ print("%i/3 tests passed for Eventually mask vs recurrent formula\n"%pass_n)
+
+
+def test_until_mask_recurrent(signal, interval, verbose=True, **kwargs):
+ signal_flip = jnp.flip(signal)
+ pred = Predicate("identity", lambda x: x)
+ mask = Until(pred > 0., pred < 0., interval)
+ rec = UntilRecurrent(pred > 0., pred < 0., interval)
+
+ rob_correct = jnp.isclose(mask.robustness(signal, **kwargs), rec.robustness(signal_flip, **kwargs))
+ trace_correct = jnp.all(jnp.isclose(mask(signal, **kwargs), jnp.flip(rec(signal_flip, **kwargs))))
+ grad_correct = jnp.all(jnp.isclose(jax.grad(mask.robustness)(signal, **kwargs), jnp.flip(jax.grad(rec.robustness)(signal_flip, **kwargs))))
+
+ pass_n = 0
+
+ if rob_correct:
+ if verbose: print("\u2713 Until mask and recurrent robustness values match")
+ pass_n += 1
+ else:
+ print("\u274c Until mask and recurrent robustness values do not match")
+
+ if trace_correct:
+ if verbose: print("\u2713 Until mask and recurrent robustness traces match")
+ pass_n += 1
+ else:
+ print("\u274c Until mask and recurrent robustness traces do not match")
+
+ if grad_correct:
+ if verbose: print("\u2713 Until mask and recurrent robustness gradients match")
+ pass_n += 1
+ else:
+ print("\u274c Until mask and recurrent robustness gradients do not match")
+
+
+ print("%i/3 tests passed for Until mask vs recurrent formula\n"%pass_n)
+
+def test_all_settings(test_func, verbose=True, T=10):
+ assert T > 0, "Pick a T larger than 0"
+ signal = jnp.array(np.random.randn(T)) * 1.0
+ # None, [0,b], [a,b], [0, inf], [a, inf]
+ interval_list = [None, [0, T//3], [T//4, T//2], [0, jnp.inf], [T//4, jnp.inf]]
+ approx_method_list = ["true", "logsumexp"]
+ temperature_list = [1., 10., 20., 100.]
+
+ for interval in interval_list:
+ for approx_method in approx_method_list:
+ for temperature in temperature_list:
+ kwargs = {"approx_method": approx_method,
+ "temperature": temperature
+ }
+ print(f"int={interval}\t temp={temperature} \t approx={approx_method} ")
+ test_func(signal, interval, verbose, **kwargs)
+
+if __name__ == "__main__":
+
+ test_all_settings(test_always, verbose=False, T=10)
+ test_all_settings(test_eventually, verbose=False, T=10)
+ test_all_settings(test_until, verbose=False, T=10)
+ test_all_settings(test_always_mask_recurrent, verbose=False, T=10)
+ test_all_settings(test_eventually_mask_recurrent, verbose=False, T=10)
+ test_all_settings(test_until_mask_recurrent, verbose=False, T=10)
+
+
diff --git a/stljax/utils.py b/stljax/utils.py
new file mode 100644
index 0000000..1413399
--- /dev/null
+++ b/stljax/utils.py
@@ -0,0 +1,206 @@
+import jax
+import jax.numpy as jnp
+import functools
+
+def cond(pred, true_fun, false_fun, *operands):
+ if pred:
+ return true_fun(*operands)
+ else:
+ return false_fun(*operands)
+
+def scan(f, init, xs, length=None):
+ if xs is None:
+ xs = [None] * length
+ carry = init
+ ys = []
+ for x in xs:
+ carry, y = f(carry, x)
+ ys.append(y)
+ return carry, jnp.stack(ys)
+
+
+def smooth_mask(T, t_start, t_end, scale):
+ xs = jnp.arange(T) * 1.
+ return jax.nn.sigmoid(scale * (xs - t_start * T)) - jax.nn.sigmoid(scale * (xs - t_end * T))
+
+def anneal(i):
+ return jax.nn.sigmoid(15 * (i - 0.5))
+
+
+@jax.jit
+def bar_plus(signal, p=2):
+ '''max(0,signal)**p'''
+ return jax.nn.relu(signal) ** p
+
+
+@jax.jit
+def bar_minus(signal, p=2):
+ '''min(0,signal)**p'''
+ return (-jax.nn.relu(-signal)) ** p
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
+def M0(signal, eps, weights=None, axis=1, keepdims=True):
+ '''Used in gmsr approx method, Eq 4(a) in https://arxiv.org/abs/2405.10996'''
+ if weights is None:
+ weights = jnp.ones_like(signal)
+ sum_w = weights.sum(axis, keepdims=keepdims)
+ return (
+ eps**sum_w + jnp.prod(signal**weights, axis=axis, keepdims=keepdims)
+ ) ** (1 / sum_w)
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
+def Mp(signal, eps, p, weights=None, axis=1, keepdims=True):
+ '''Used in gmsr approx method, Eq 4(b) in https://arxiv.org/abs/2405.10996'''
+ if weights is None:
+ weights = jnp.ones_like(signal)
+ sum_w = weights.sum(axis, keepdims=keepdims)
+ return (
+ eps**p + 1 / sum_w * jnp.sum(weights * signal**p, axis=axis, keepdims=keepdims)
+ ) ** (1 / p)
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
+def gmsr_min(signal, eps, p, weights=None, axis=1, keepdims=True):
+ '''Used in gmsr approx method, Eq 3 in https://arxiv.org/abs/2405.10996'''
+
+ return (
+ M0(bar_plus(signal, 2), eps, weights=weights, axis=axis, keepdims=keepdims)
+ ** 0.5
+ - Mp(
+ bar_minus(signal, 2), eps, p, weights=weights, axis=axis, keepdims=keepdims
+ )
+ ** 0.5
+ )
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
+def gmsr_max(signal, eps, p, weights=None, axis=1, keepdims=True):
+ '''Used in gmsr approx method, Eq 4(a) but for max in https://arxiv.org/abs/2405.10996'''
+
+ return -gmsr_min(-signal, eps, p, weights=weights, axis=axis, keepdims=keepdims)
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
+def gmsr_min_turbo(signal, eps, p, weights=None, axis=1, keepdims=True):
+ # TODO: (norrisg) make actually turbo (faster than normal `gmsr_min`)
+ pos_idx = signal > 0.0
+ neg_idx = ~pos_idx
+
+ return jnp.where(
+ neg_idx.sum(axis, keepdims=keepdims) > 0,
+ eps**0.5
+ - Mp(
+ bar_minus(signal, 2),
+ eps,
+ p,
+ weights=weights,
+ axis=axis,
+ keepdims=keepdims,
+ )
+ ** 0.5,
+ M0(bar_plus(signal, 2), eps, weights=weights, axis=axis, keepdims=keepdims)
+ ** 0.5
+ - eps**0.5,
+ )
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims"))
+def gmsr_max_turbo(signal, eps, p, weights=None, axis=1, keepdims=True):
+ return -gmsr_min_turbo(
+ -signal, eps, p, weights=weights, axis=axis, keepdims=keepdims
+ )
+
+
+def gmsr_min_fast(signal, eps, p):
+ # TODO: (norrisg) allow `axis` specification
+
+ # Split indices into positive and non-positive values
+ pos_idx = signal > 0.0
+ neg_idx = ~pos_idx
+
+ weights = jnp.ones_like(signal)
+
+ # Sum of all weights
+ sum_w = weights.sum()
+
+ # If there exists a negative element
+ if signal[neg_idx].size > 0:
+ sums = 0.0
+ sums = jnp.sum(weights[neg_idx] * (signal[neg_idx] ** (2 * p)))
+ Mp = (eps**p + (sums / sum_w)) ** (1 / p)
+ h_min = eps**0.5 - Mp**0.5
+
+ # If all values are positive
+ else:
+ mult = 1.0
+ mult = jnp.prod(signal[pos_idx] ** (2 * weights[pos_idx]))
+ M0 = (eps**sum_w + mult) ** (1 / sum_w)
+ h_min = M0**0.5 - eps**0.5
+
+ return jnp.reshape(h_min, (1, 1, 1))
+
+
+def gmsr_max_fast(signal, eps, p):
+ return -gmsr_min_fast(-signal, eps, p)
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims", "approx_method", "padding"))
+def maxish(signal, axis, keepdims=True, approx_method="true", temperature=None, **kwargs):
+ """
+ Function to compute max(ish) along an axis.
+
+ Args:
+ signal: A jnp.array or an Expression
+ axis: (Int) axis along which to compute max(ish)
+ keepdims: (Bool) whether to keep the original array size. Defaults to True
+ approx_method: (String) argument to choose the type of max(ish) approximation. Possible choices are "true", "logsumexp", "softmax", "gmsr" (https://arxiv.org/abs/2405.10996).
+ temperature: Optional, but required when approx_method is not "true". For "gmsr", a tuple (eps, p).
+
+ Returns:
+ jnp.array corresponding to the maxish
+
+ Raises:
+ ValueError if approx_method is invalid; AssertionError if a required temperature is missing
+
+ """
+
+ # if isinstance(signal, Expression):
+ # assert (
+ # signal.value is not None
+ # ), "Input Expression does not have numerical values"
+ # signal = signal.value
+
+ match approx_method:
+ case "true":
+ """jax keeps track of multiple max values and will distribute the gradients across all max values!
+ e.g., jax.grad(jnp.max)(jnp.array([0.01, 0.0, 0.01])) # --> Array([0.5, 0. , 0.5], dtype=float32)
+ """
+ return jnp.max(signal, axis, keepdims=keepdims)
+
+ case "logsumexp":
+ """https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.logsumexp.html"""
+ assert temperature is not None, "need a temperature value"
+ return (
+ jax.scipy.special.logsumexp(
+ temperature * signal, axis=axis, keepdims=keepdims
+ )
+ / temperature
+ )
+
+ case "softmax":
+ assert temperature is not None, "need a temperature value"
+ return (jax.nn.softmax(temperature * signal, axis) * signal).sum(
+ axis, keepdims=keepdims
+ )
+
+ case "gmsr":
+ assert (
+ temperature is not None
+ ), "temperature tuple containing (eps, p) is required"
+ (eps, p) = temperature
+ return gmsr_max(signal, eps, p, axis=axis, keepdims=keepdims)
+
+ case _:
+ raise ValueError("Invalid approx_method")
+
+@functools.partial(jax.jit, static_argnames=("axis", "keepdims", "approx_method", "padding"))
+def minish(signal, axis, keepdims=True, approx_method="true", temperature=None, **kwargs):
+ '''
+ Same as maxish, but computes min(ish) along an axis, evaluated as -maxish(-signal).
+ '''
+ return -maxish(-signal, axis, keepdims, approx_method, temperature, **kwargs)
diff --git a/test_timing.py b/test_timing.py
deleted file mode 100644
index 4ebc53c..0000000
--- a/test_timing.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import jax
-import jax.numpy as jnp
-import numpy as np
-from stljax.formula import *
-from stljax.viz import *
-import matplotlib.pyplot as plt
-import timeit
-import statistics
-import pickle
-import sys
-
-if __name__ == "__main__":
-
- args = sys.argv[1:]
- filename = args[0]
- max_T = int(args[1])
-
-
- axis = 0
- pred_rev = Predicate('x', lambda x: x)
- interval = [2, 5]
- mask = Always(pred_rev > 4, interval=interval)
- recurrent = AlwaysRecurrent(pred_rev > 4, interval=interval)
-
-
-
- def grad_mask(signal):
- return jax.vmap(jax.grad(lambda x: mask(x).mean()))(signal)
- def grad_recurrent(signal):
- return jax.vmap(jax.grad(lambda x: recurrent(x).mean()))(signal)
-
- def mask_(signal):
- return jax.vmap(lambda x: mask(x).mean())(signal)
- def recurrent_(signal):
- return jax.vmap(lambda x: recurrent(x).mean())(signal)
-
-
- @jax.jit
- def grad_mask_jit(signal):
- return jax.vmap(jax.grad(lambda x: mask(x).mean()))(signal)
- @jax.jit
- def grad_recurrent_jit(signal):
- return jax.vmap(jax.grad(lambda x: recurrent(x).mean()))(signal)
- @jax.jit
- def mask_jit(signal):
- return jax.vmap(lambda x: mask(x).mean())(signal)
- @jax.jit
- def recurrent_jit(signal):
- return jax.vmap(lambda x: recurrent(x).mean())(signal)
-
- # Number of loops per run
- loops = 100
- # Number of runs
- runs = 25
- T = 2
-
- bs = 256
- means = []
- stds = []
- data = {}
-
- functions = ["mask_", "recurrent_", "grad_mask", "grad_recurrent", "mask_jit", "recurrent_jit", "grad_mask_jit", "grad_recurrent_jit"]
- # functions = ["mask_jit", "recurrent_jit", "grad_mask_jit", "grad_recurrent_jit"]
- # functions = ["mask_jit", "grad_mask_jit"]
-
- Ts = []
- data["functions"] = functions
- data["runs"] = runs
- data["loops"] = loops
- while T <= max_T:
- Ts.append(T)
- data['Ts'] = Ts
- print("running ", T)
- signal = jnp.array(np.random.random([bs, T]))
- times_list = []
- data[str(T)] = {}
-
- for f in functions:
- print("timing ", f)
- timeit.repeat(f + "(signal)", globals=globals(), repeat=1, number=1)
- times = timeit.repeat(f + "(signal)", globals=globals(), repeat=runs, number=loops)
- times_list.append(times)
- print("timing: ", statistics.mean(times), statistics.stdev(times))
- data[str(T)][f] = times
- with open(filename + '.pkl', 'wb') as f:
- pickle.dump(data, f)
-
- T *= 2
-
-
- # means = []
- # stds = []
- # for k in loaded_dict.keys():
- # if k in ["Ts", "functions"]:
- # break
- # mus = []
- # sts = []
- # for f in loaded_dict[k].keys():
- # mus.append(statistics.mean(loaded_dict[k][f])/loaded_dict["loops"])
- # sts.append(statistics.stdev(loaded_dict[k][f])/loaded_dict["loops"])
-
- # means.append(mus)
- # stds.append(sts)
- # means = np.array(means)
- # stds = np.array(stds)
-
- # plt.plot(loaded_dict["Ts"], means * 1E3)
- # plt.yscale("log")
- # plt.legend(loaded_dict["functions"])
- # plt.grid()
- # plt.xlabel("signal length")
- # plt.ylabel("computation time [ms]")
-