Skip to content

Commit

Permalink
Merge pull request #15 from karenl7/main
Browse files Browse the repository at this point in the history
incorporate many changes
  • Loading branch information
karenl7 authored Nov 27, 2024
2 parents 14e858a + 698c8c5 commit ff13b31
Show file tree
Hide file tree
Showing 8 changed files with 2,211 additions and 2,606 deletions.
29 changes: 16 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,27 @@ A toolbox to compute the robustness of STL formulas using computation graphs. Th

Requires Python 3.10+

Clone the repo.
Install the repo:

Make a venv and activate it

`python3 -m venv stljax_venv`

`source stljax_venv/bin/activate`
```
pip install git+https://github.com/UW-CTRL/stljax.git
```

Go into the `stljax` folder. Then to install:
Alternatively, if you like to install the package in editable mode,

`pip install -e .`
```
git clone https://github.com/UW-CTRL/stljax.git
cd stljax
pip install -e .
```
(Best to use a virtual environment.)


## Usage
`demo.ipynb` is an IPython jupyter notebook that showcases the basic functionality of the toolbox:
* Setting up signals for the formulas, including the use of Expressions and Predicates
* Defining STL formulas and visualizing them
* Evaluating STL robustness, and robustness trace
* Gradient descent on STL parameters and signal parameters.


## (New) Features
Expand All @@ -35,13 +37,13 @@ stljax leverages to benefits of jax and automatic differentiation!
Aside from using jax as the backend, stljax is more recent and tidier implementation of stlcg which was originally implemented in PyTorch back ~2019.

- Removed the `distributed_mean` hack from original stlcg implementation. jax keeps track of multiple max/min values and will distribute the gradients across all max/min values!
- Incorporation of the smooth max/min presented in [Optimization with Temporal and Logical Specifications via Generalized Mean-based Smooth Robustness Measures](https://arxiv.org/abs/2405.10996) by Samet Uzun, Purnanand Elango, Pierre-Loic Garoche, Behcet Acikmese
- Use `approx_method="gmsr"` and `temperature=(eps, p)`


## Tags

| Tags 🏷️ | Description |
| --------- | ----------- |
| v.1.1.0 | General code improvements. Included recurrent implementation and example notebooks. |
| v.1.0.0 | Removed awkward expected signal dimension & leverage vmap for batched inputs. Masking for temporal operations & remove need to reverse signals. |
| v0.0.0 | A transfer from the 2019 PyTorch implementation to Jax + some tidying + adding Predicates + reversing signal automatically. |

Expand Down Expand Up @@ -107,11 +109,12 @@ We can use `jax.vmap` to handle multiple signals at once.
`jax.vmap(formula)(signals) # signals is shape [bs, time_dim,...]`



NOTE: Need to take care for formulas defined with Expressions and need multiple inputs. Need a wrapper since `jax.vmap` doesn't like tuples in a single argument.



## TODOs
- re-implement stlcg (PyTorch) with the latest version of PyTorch.
- manage reversing of signals internally for recurrent cases.


## Publications
Expand Down
2,459 changes: 384 additions & 2,075 deletions demo.ipynb

Large diffs are not rendered by default.

287 changes: 287 additions & 0 deletions examples/parametric_time_interval.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit ff13b31

Please sign in to comment.