Skip to content

Commit

Permalink
Update README
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax authored Nov 26, 2024
1 parent 230450c commit 7554e1f
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,13 @@ The usage of a heteroskedastic BNN is straightforward and follows the same patte

For fully Bayesian heteroskedastic NN:
```python3
# Define a NN architecture (note that we use a 2-head model, where the second head is for noise)
architecture = nb.FlaxMLP2Head(hidden_dims, target_dim)

# Initialize HeteroskedasticBNN model
model = nb.HeteroskedasticBNN(architecture)
# Train
model.fit(X_measured, y_measured, num_warmup=2000, num_samples=2000)
model.fit(X_measured, y_measured, num_warmup=1000, num_samples=1000)
# Make a prediction
posterior_mean, posterior_var = model.predict(X_domain)
```
Expand All @@ -90,6 +93,25 @@ model.fit(X_measured, y_measured, sgd_epochs=5000, sgd_lr=5e-3, num_warmup=1000,
posterior_mean, posterior_var = model.predict(X_domain)
```

Sometimes in scientific and engineering applciations, domain experts may possess prior knowledge about how noise level varies with inputs. NeuroBayes enables the integration of such knowledge through a noise model-based heteroskedastic Bayesian Neural Network (BNN).

```python3
# Define a noise model based on prior knowledge:
def noise_model_fn(x, a, b):
return a * jnp.exp(b*x)

# Convert it to a format that neurobayes can work with
noise_model = nb.utils.set_fn(noise_model_fn)
noise_model_prior = nb.priors.auto_normal_priors(noise_model_fn)

# Deine architecture (note that here we use a single-head MLP)
architecture = nb.FlaxMLP(hidden_dims, target_dim)

# Initialize and train a noise model-based Heteroskedastic BNN
model = nb.VarianceModelHeteroskedasticBNN(architecture, noise_model, noise_model_prior)
model.fit(X_measured, y_measured, num_warmup=1000, num_samples=1000)
```

![hsk](https://github.com/user-attachments/assets/5a619361-74c0-4d03-9b1a-4aa995f1c540)

See example [here](https://github.com/ziatdinovmax/NeuroBayes/blob/main/examples/heteroskedastic.ipynb).
Expand Down

0 comments on commit 7554e1f

Please sign in to comment.