Commit

Merge pull request #2392 from banda-larga:patch-1
PiperOrigin-RevId: 467673436
Flax Authors committed Aug 15, 2022
2 parents 99294fb + 43c3c29 commit 47418b3
Showing 4 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/guides/flax_basics.ipynb
@@ -256,6 +256,7 @@
"### Gradient descent\n",
"\n",
"If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\\{(x_i,y_i), i\\in \\{1,\\ldots, k\\}, x_i\\in\\mathbb{R}^n,y_i\\in\\mathbb{R}^m\\}$, we try to find a set of parameters $W\\in \\mathcal{M}_{m,n}(\\mathbb{R}), b\\in\\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:\n",
"\n",
"$$\\mathcal{L}(W,b)\\rightarrow\\frac{1}{k}\\sum_{i=1}^{k} \\frac{1}{2}\\|y_i-f_{W,b}(x_i)\\|^2_2$$\n",
"\n",
"Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example."
1 change: 1 addition & 0 deletions docs/guides/flax_basics.md
@@ -141,6 +141,7 @@ model.apply(params, x)
### Gradient descent

If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:

$$\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2$$

Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example.
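
As a rough illustration of that setup, here is a minimal sketch of what generating such fake data can look like in JAX (the dimensions, PRNG key, and variable names below are illustrative, not the guide's exact code):

```python
import jax
import jax.numpy as jnp

# Illustrative problem sizes: k samples, x in R^n, y in R^m.
n_samples, x_dim, y_dim = 20, 10, 5

# Draw a ground-truth W and b, then sample noisy observations y = Wx + b + noise.
# W is stored as (x_dim, y_dim), matching the Dense kernel layout, so the batched
# model reads x @ W + b.
key = jax.random.PRNGKey(0)
k_w, k_b, k_x, k_noise = jax.random.split(key, 4)
W = jax.random.normal(k_w, (x_dim, y_dim))
b = jax.random.normal(k_b, (y_dim,))
x_samples = jax.random.normal(k_x, (n_samples, x_dim))
y_samples = x_samples @ W + b + 0.1 * jax.random.normal(k_noise, (n_samples, y_dim))
print(x_samples.shape, y_samples.shape)  # (20, 10) (20, 5)
```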
6 changes: 6 additions & 0 deletions docs/guides/jax_for_the_impatient.ipynb
@@ -474,6 +474,7 @@
"$$f(x) = \\frac{1}{2} x^T x$$\n",
"\n",
"with the (known) gradient:\n",
"\n",
"$$\\nabla f(x) = x$$"
]
},
@@ -560,11 +561,13 @@
"### Jacobian-Vector product\n",
"\n",
"Let's consider a map $f:\\mathbb{R}^n\\rightarrow\\mathbb{R}^m$. As a reminder, the differential of f is the map $df:\\mathbb{R}^n \\rightarrow \\mathcal{L}(\\mathbb{R}^n,\\mathbb{R}^m)$ where $\\mathcal{L}(\\mathbb{R}^n,\\mathbb{R}^m)$ is the space of linear maps from $\\mathbb{R}^n$ to $\\mathbb{R}^m$ (hence $df(x)$ is often represented as a Jacobian matrix). The linear approximation of f at point $x$ reads:\n",
"\n",
"$$f(x+v) = f(x) + df(x)\\bullet v + o(v)$$\n",
"\n",
"The $\\bullet$ operator means you are applying the linear map $df(x)$ to the vector v.\n",
"\n",
"Even though you are rarely interested in computing the full Jacobian matrix representing the linear map $df(x)$ in a standard basis, you are often interested in the quantity $df(x)\\bullet v$. This is exactly what `jax.jvp` is for, and `jax.jvp(f, (x,), (v,))` returns the tuple:\n",
"\n",
"$$(f(x), df(x)\\bullet v)$$"
]
},
@@ -621,11 +624,13 @@
"source": [
"### Vector-Jacobian product\n",
"Keeping our $f:\\mathbb{R}^n\\rightarrow\\mathbb{R}^m$ it's often the case (for example, when you are working with a scalar loss function) that you are interested in the composition $x\\rightarrow\\phi\\circ f(x)$ where $\\phi :\\mathbb{R}^m\\rightarrow\\mathbb{R}$. In that case, the gradient reads:\n",
"\n",
"$$\\nabla(\\phi\\circ f)(x) = J_f(x)^T\\nabla\\phi(f(x))$$\n",
"\n",
"Where $J_f(x)$ is the Jacobian matrix of f evaluated at x, meaning that $df(x)\\bullet v = J_f(x)v$.\n",
"\n",
"`jax.vjp(f,x)` returns the tuple:\n",
"\n",
"$$(f(x),v\\rightarrow v^TJ_f(x))$$\n",
"\n",
"Keeping the same example as previously, using $v=(1,\\ldots,1)$, applying the VJP function returned by JAX should return the $x$ value:"
@@ -795,6 +800,7 @@
"## Full example: linear regression\n",
"\n",
"Let's implement one of the simplest models using everything we have seen so far: a linear regression. From a set of data points $\\{(x_i,y_i), i\\in \\{1,\\ldots, k\\}, x_i\\in\\mathbb{R}^n,y_i\\in\\mathbb{R}^m\\}$, we try to find a set of parameters $W\\in \\mathcal{M}_{m,n}(\\mathbb{R}), b\\in\\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:\n",
"\n",
"$$\\mathcal{L}(W,b)\\rightarrow\\frac{1}{k}\\sum_{i=1}^{k} \\frac{1}{2}\\|y_i-f_{W,b}(x_i)\\|^2_2$$\n",
"\n",
"(Note: depending on how you cast the regression problem you might end up with different setups. Theoretically we should be minimizing the expectation of the loss wrt to the data distribution, however for the sake of simplicity here we consider only the sampled loss)."
6 changes: 6 additions & 0 deletions docs/guides/jax_for_the_impatient.md
@@ -226,6 +226,7 @@ JAX provides first-class support for gradients and automatic differentiation in
$$f(x) = \frac{1}{2} x^T x$$

with the (known) gradient:

$$\nabla f(x) = x$$
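
As a quick sanity check, here is a minimal sketch of how `jax.grad` recovers that gradient (the concrete input vector is just an illustration):

```python
import jax
import jax.numpy as jnp

# f(x) = 0.5 * x^T x, whose gradient is x itself.
def f(x):
    return 0.5 * jnp.dot(x, x)

x = jnp.asarray([1.0, 2.0, 3.0])  # illustrative input
print(jax.grad(f)(x))             # expected output: [1. 2. 3.]
```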

```{code-cell}
@@ -270,6 +271,13 @@ As previously mentioned, `jax.grad` only works for scalar-valued functions. JAX
### Jacobian-Vector product

Let's consider a map $f:\mathbb{R}^n\rightarrow\mathbb{R}^m$. As a reminder, the differential of f is the map $df:\mathbb{R}^n \rightarrow \mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)$ where $\mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)$ is the space of linear maps from $\mathbb{R}^n$ to $\mathbb{R}^m$ (hence $df(x)$ is often represented as a Jacobian matrix). The linear approximation of f at point $x$ reads:

$$f(x+v) = f(x) + df(x)\bullet v + o(v)$$

The $\bullet$ operator means you are applying the linear map $df(x)$ to the vector $v$.

Even though you are rarely interested in computing the full Jacobian matrix representing the linear map $df(x)$ in a standard basis, you are often interested in the quantity $df(x)\bullet v$. This is exactly what `jax.jvp` is for, and `jax.jvp(f, (x,), (v,))` returns the tuple:

$$(f(x), df(x)\bullet v)$$
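
A minimal sketch of that call, reusing the same $f(x)=\frac{1}{2}x^Tx$ from above (the concrete $x$ and $v$ are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return 0.5 * jnp.dot(x, x)  # for this f, df(x).v = x^T v

x = jnp.asarray([1.0, 2.0, 3.0])  # point where f is linearized
v = jnp.asarray([1.0, 0.0, 0.0])  # tangent vector

# jax.jvp returns (f(x), df(x).v); here that is (7.0, x^T v) = (7.0, 1.0).
y, tangent_out = jax.jvp(f, (x,), (v,))
print(y, tangent_out)
```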

+++ {"id": "F5nI_gbeqj2y"}
@@ -303,11 +306,13 @@ print(jax.jvp(f, (x,),(v,)))

### Vector-Jacobian product
Keeping our map $f:\mathbb{R}^n\rightarrow\mathbb{R}^m$, it's often the case (for example, when you are working with a scalar loss function) that you are interested in the composition $x\rightarrow\phi\circ f(x)$, where $\phi:\mathbb{R}^m\rightarrow\mathbb{R}$. In that case, the gradient reads:

$$\nabla(\phi\circ f)(x) = J_f(x)^T\nabla\phi(f(x))$$

where $J_f(x)$ is the Jacobian matrix of $f$ evaluated at $x$, meaning that $df(x)\bullet v = J_f(x)v$.

`jax.vjp(f,x)` returns the tuple:

$$(f(x),v\rightarrow v^TJ_f(x))$$

Keeping the same example as previously, using $v=(1,\ldots,1)$, applying the VJP function returned by JAX should return the $x$ value:
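
A minimal sketch of that check, again with $f(x)=\frac{1}{2}x^Tx$ (the concrete $x$ is illustrative; since this $f$ is scalar-valued, $v=(1,\ldots,1)$ reduces to the scalar $1$):

```python
import jax
import jax.numpy as jnp

def f(x):
    return 0.5 * jnp.dot(x, x)  # scalar-valued, so J_f(x) = x^T

x = jnp.asarray([1.0, 2.0, 3.0])  # illustrative input
y, vjp_fun = jax.vjp(f, x)        # vjp_fun maps a cotangent v to v^T J_f(x)

v = jnp.ones_like(y)              # v = 1.0 for a scalar output
print(vjp_fun(v))                 # a 1-tuple containing x, i.e. [1. 2. 3.]
```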
@@ -398,6 +403,7 @@ print("Batched example shape: ", jax.vmap(apply_matrix)(batched_x).shape)
## Full example: linear regression

Let's implement one of the simplest models using everything we have seen so far: a linear regression. From a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:

$$\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2$$

(Note: depending on how you cast the regression problem, you might end up with different setups. Theoretically, we should be minimizing the expectation of the loss with respect to the data distribution; however, for the sake of simplicity, here we consider only the sampled loss).
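
To make the loss and the update concrete, here is a minimal sketch with a tiny hand-made dataset (the data, dimensions, learning rate, and step count are illustrative, not the guide's actual example):

```python
import jax
import jax.numpy as jnp

# Tiny illustrative dataset: k = 3 points, x in R^2, y in R^1, with y = 2*x1 + 3*x2.
xs = jnp.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.asarray([[2.0], [3.0], [5.0]])

def mse(params, xs, ys):
    W, b = params
    preds = xs @ W.T + b                                        # f_{W,b}(x) = Wx + b, batched over rows
    return jnp.mean(0.5 * jnp.sum((ys - preds) ** 2, axis=-1))  # (1/k) * sum of 0.5 * ||y - f(x)||^2

params = (jnp.zeros((1, 2)), jnp.zeros(1))  # W in M_{1,2}(R), b in R^1
lr = 0.3
for _ in range(500):
    grads = jax.grad(mse)(params, xs, ys)   # gradient of the sampled loss w.r.t. (W, b)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

print(params)  # W should be close to [[2., 3.]] and b close to [0.]
```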
