From 2ef1cc8ee616433d8679f30a8fbb2f1350bc082a Mon Sep 17 00:00:00 2001 From: NorthblueM Date: Sun, 28 Jan 2024 14:46:22 +0800 Subject: [PATCH] update: ch08.ipynb --- notebooks/ch08/ch08.ipynb | 31487 ++++++++++++++++++------------------ 1 file changed, 15953 insertions(+), 15534 deletions(-) diff --git a/notebooks/ch08/ch08.ipynb b/notebooks/ch08/ch08.ipynb index 1972608d..05895bc7 100644 --- a/notebooks/ch08/ch08.ipynb +++ b/notebooks/ch08/ch08.ipynb @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": { "pycharm": { "name": "#%%\n" @@ -77,12 +77,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-06-20T23:40:29.170098\n", + " 2024-01-13T11:40:24.048397\n", " image/svg+xml\n", " \n", " \n", @@ -97,8 +97,8 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -569,18 +569,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -590,18 +590,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -610,18 +610,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -630,18 +630,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -650,18 +650,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -670,7 +670,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -1681,106 +1676,23 @@ "from torch import nn\n", "from d2l import torch as d2l\n", "\n", - "T = 1000 # 总共产生1000个点\n", - "time = torch.arange(1, T + 1, dtype=torch.float32)\n", - "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", + "T = 1000 # 总共产生1000个点的时间范围\n", + "time = torch.arange(1, T + 1, dtype=torch.float32) # 根据点生成时间序列\n", + "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,)) # 按正弦函数生成序列x,并为其添加正态分布噪声\n", "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" ] }, { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "tau = 4\n", - "features = torch.zeros((T - tau, tau))\n", - "for i in range(tau):\n", - " features[:, i] = x[i: T - tau + i]\n", - "labels = x[tau:].reshape((-1, 1))\n", - "\n", - "batch_size, n_train = 16, 600\n", - "# 只有前n_train个样本用于训练\n", - "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", - " batch_size, is_train=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "# 初始化网络权重的函数\n", - "def init_weights(m):\n", - " if type(m) == nn.Linear:\n", - " nn.init.xavier_uniform_(m.weight)\n", - "\n", - "# 一个简单的多层感知机\n", - "def get_net():\n", - " net = nn.Sequential(nn.Linear(tau, 10),\n", - " nn.ReLU(),\n", - " nn.Linear(10, 1))\n", - " net.apply(init_weights)\n", - " return net\n", - "\n", - "# 平方损失。注意:MSELoss计算平方误差时不带系数1/2\n", - "loss = nn.MSELoss(reduction='none')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 1, loss: 0.054359\n", - "epoch 2, loss: 0.050225\n", - "epoch 3, loss: 0.049887\n", - "epoch 4, loss: 0.048766\n", - "epoch 5, loss: 0.048702\n" - ] - } - ], + "cell_type": "markdown", + "metadata": {}, "source": [ - "def train(net, train_iter, loss, epochs, lr):\n", - " trainer = torch.optim.Adam(net.parameters(), lr)\n", - " for epoch in range(epochs):\n", - " for X, y in train_iter:\n", - " trainer.zero_grad()\n", - " l = loss(net(X), y)\n", - " l.sum().backward()\n", - " trainer.step()\n", - " print(f'epoch {epoch + 1}, '\n", - " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", - "\n", - "net = get_net()\n", - "train(net, train_iter, loss, 5, 0.01)" + "  $x = torch.sin(0.01 * time)$ 中乘 $0.01$ 的目的在于对时间序列进行缩放,以便更好的进行观察。未缩放的结果如下图所示。" ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 2, + "metadata": {}, "outputs": [ { "data": { @@ -1788,12 +1700,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-06-20T23:40:59.981869\n", + " 2024-01-13T11:19:35.977536\n", " image/svg+xml\n", " \n", " \n", @@ -1808,8 +1720,8 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2280,18 +2192,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2301,18 +2213,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2321,18 +2233,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2341,18 +2253,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2363,11 +2275,11 @@ " \n", " \n", + "\" clip-path=\"url(#p36ba876075)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2381,7 +2293,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -4513,55 +3169,124 @@ } ], "source": [ - "onestep_preds = net(features)\n", - "d2l.plot([time, time[tau:]],\n", - " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", - " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", - " figsize=(6, 3))" + "import torch\n", + "from torch import nn\n", + "from d2l import torch as d2l\n", + "\n", + "T = 1000 # 总共产生1000个点的时间范围\n", + "time = torch.arange(1, T + 1, dtype=torch.float32) # 根据点生成时间序列\n", + "x = torch.sin(time) + torch.normal(0, 0.2, (T,)) # 按正弦函数生成序列x,并为其添加正态分布噪声\n", + "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 10, "metadata": { "pycharm": { - "name": "#%% md\n" + "name": "#%%\n" } }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([995, 1])\n" + ] + } + ], "source": [ - "  在代码中,模型通过一个形状为$(T-tau, tau)$ 的特征张量存储过去观测结果,此时 $T$ 为总观测数,$tau$为观测结果的历史长度,将$tau$设置为$4$后,模型考虑了过去$4$个观测结果。 \n", - "  此外由于预测是基于单变量的时序预测,即根据过去观测结果预测下一时间步的值,因此在代码中,真实值存储在$(T-tau, 1)$的张量中,生成$T-tau$个训练样本,又因为$n_train$中规定了训练样本数量为600,因此真实值数量为600个" + "# 定义滞后步数\n", + "tau = 5\n", + "# 初始化特征矩阵,形状为(T - tau, tau),用于存储滞后特征\n", + "features = torch.zeros((T - tau, tau))\n", + "for i in range(tau):\n", + " # 将每个滞后步数对应的原始序列存储到特征矩阵中\n", + " features[:, i] = x[i: T - tau + i]\n", + "\n", + "# 生成标签,并对生成形状进行设置 \n", + "labels = x[tau:].reshape((-1, 1))\n", + "print(labels.shape)\n", + "\n", + "# 设置训练批次与训练样本数\n", + "batch_size, n_train = 16, 600\n", + "# 只有前n_train个样本用于训练\n", + "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", + " batch_size, is_train=True)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 11, "metadata": { "pycharm": { - "name": "#%% md\n" + "name": "#%%\n" } }, + "outputs": [], "source": [ - "2. 如果没有噪音,需要多少个过去的观测结果?提示:把sin和cos写成微分方程" + "# 初始化网络权重的函数\n", + "def init_weights(m):\n", + " if type(m) == nn.Linear:\n", + " nn.init.xavier_uniform_(m.weight)\n", + "\n", + "# 一个简单的多层感知机\n", + "def get_net():\n", + " net = nn.Sequential(nn.Linear(tau, 20),\n", + " nn.ReLU(),\n", + " nn.Linear(20, 1))\n", + " net.apply(init_weights)\n", + " return net\n", + "\n", + "# 平方损失。注意:MSELoss计算平方误差时不带系数1/2\n", + "loss = nn.MSELoss(reduction='none')" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 12, "metadata": { "pycharm": { - "name": "#%% md\n" + "name": "#%%\n" } }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 1, loss: 0.059790\n", + "epoch 2, loss: 0.053361\n", + "epoch 3, loss: 0.056654\n", + "epoch 4, loss: 0.054220\n", + "epoch 5, loss: 0.049205\n" + ] + } + ], "source": [ - "  由于\n", - "$$\\frac{d}{dx}sin(x)=cos(x)$$\n", - "$$\\frac{d}{dx}cos(x)=-sin(x)$$\n", - "  因此$sin$与$cos$的导数可以相互表示,即可通过以下微分方程描述$sin$与$cos$二者间关系:\n", - "$$\\frac{dy}{dx}=cos(x)$$\n", - "  因此在没有噪声的情况下,由于可以对$sin$值和$cos$值通过微分互相计算,因此只需要一个过去观测结果就可通过微分方程来恢复未观测的值,得到完整信息。" + "def train(net, train_iter, loss, epochs, lr):\n", + " trainer = torch.optim.Adam(net.parameters(), lr) # 设置Adam优化器\n", + " for epoch in range(epochs):\n", + " for X, y in train_iter:\n", + " # 梯度清零\n", + " trainer.zero_grad()\n", + " # 计算MSELoss\n", + " l = loss(net(X), y)\n", + " # 进行求和并反向传播\n", + " l.sum().backward()\n", + " # 通过优化器进行参数更新\n", + " trainer.step()\n", + " print(f'epoch {epoch + 1}, '\n", + " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", + "\n", + "net = get_net() # 加载定义好的网络模型\n", + "train(net, train_iter, loss, 5, 0.01) # 调用train函数进行训练,训练5轮,学习率为0.01" ] }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 13, "metadata": { "pycharm": { "name": "#%%\n" @@ -4579,7 +3304,7 @@ " \n", " \n", " \n", - " 2023-06-20T23:04:51.495471\n", + " 2024-01-13T11:40:35.125582\n", " image/svg+xml\n", " \n", " \n", @@ -4615,16 +3340,16 @@ " \n", " \n", + "\" clip-path=\"url(#p7cf952e537)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4687,11 +3412,11 @@ " \n", " \n", + "\" clip-path=\"url(#p7cf952e537)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4728,11 +3453,11 @@ " \n", " \n", + "\" clip-path=\"url(#p7cf952e537)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4780,11 +3505,11 @@ " \n", " \n", + "\" clip-path=\"url(#p7cf952e537)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4841,11 +3566,11 @@ " \n", " \n", + "\" clip-path=\"url(#p7cf952e537)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4978,23 +3703,23 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", @@ -5150,119 +3914,1868 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5302,74 +6013,57 @@ } ], "source": [ - "T = 1000 # 总共产生1000个点\n", - "time = torch.arange(1, T + 1, dtype=torch.float32)\n", - "x = torch.sin(0.01 * time)\n", - "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" + "# 调用训练好的模型进行预测\n", + "onestep_preds = net(features)\n", + "# 纵轴为序列值x,横轴为时间,data为原始数据,1-step(单步预测):根据上一个过去值预测下一个值\n", + "d2l.plot([time, time[tau:]],\n", + " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", + " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", + " figsize=(6, 3))" ] }, { - "cell_type": "code", - "execution_count": 43, + "cell_type": "markdown", "metadata": { "pycharm": { - "name": "#%%\n" + "name": "#%% md\n" } }, - "outputs": [], "source": [ - "tau = 1\n", - "features = torch.zeros((T - tau, tau))\n", - "for i in range(tau):\n", - " features[:, i] = x[i: T - tau + i]\n", - "labels = x[tau:].reshape((-1, 1))\n", - "\n", - "batch_size, n_train = 16, 600\n", - "# 只有前n_train个样本用于训练\n", - "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", - " batch_size, is_train=True)" + "  在代码中,模型通过一个形状为$(T-tau, tau)$ 的特征张量存储过去观测结果,此时 $T$ 为总观测数,$tau$ 为观测结果的历史长度,将 $tau$ 设置为 $5$ 后,模型考虑了过去 $5$ 个观测结果作为滞后特征,即在当前时间点,模型在预测时考虑了过去5个时间点的观测结果。 \n", + "  由于预测是基于单变量的时序预测,即根据过去观测结果预测下一时间步的值,因此在代码中,真实值存储在$(T-tau, 1)$的张量中,且在代码中使用 lables=x[tau:].reshape((-1, 1)) 生成标签,因此真实值数量为总样本数减去滞后步数,即共有 $1000-5=995$ 个真实值" ] }, { - "cell_type": "code", - "execution_count": 44, + "cell_type": "markdown", "metadata": { "pycharm": { - "name": "#%%\n" + "name": "#%% md\n" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 1, loss: 0.033229\n", - "epoch 2, loss: 0.005169\n", - "epoch 3, loss: 0.002071\n", - "epoch 4, loss: 0.001557\n", - "epoch 5, loss: 0.001237\n" - ] + "source": [ + "2. 如果没有噪音,需要多少个过去的观测结果?提示:把sin和cos写成微分方程" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" } - ], + }, "source": [ - "def train(net, train_iter, loss, epochs, lr):\n", - " trainer = torch.optim.Adam(net.parameters(), lr)\n", - " for epoch in range(epochs):\n", - " for X, y in train_iter:\n", - " trainer.zero_grad()\n", - " l = loss(net(X), y)\n", - " l.sum().backward()\n", - " trainer.step()\n", - " print(f'epoch {epoch + 1}, '\n", - " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", - "\n", - "net = get_net()\n", - "train(net, train_iter, loss, 5, 0.01)" + "  根据正余弦函数导数,存在\n", + "$$\\displaystyle{\\frac{d}{dx}sin(x)=cos(x)}$$\n", + "$$\\displaystyle{\\frac{d}{dx}cos(x)=-sin(x)}$$\n", + "  因此$sin$与$cos$的导数可以相互表示,即可通过以下微分方程描述$sin$与$cos$二者间关系:\n", + "$$\\displaystyle{\\frac{dy}{dx}=cos(x)}$$\n", + "  因此在没有噪声的情况下,由于可以对$sin$值和$cos$值通过微分互相计算,因此只需要一个过去观测结果就可通过微分方程来恢复未观测的值,得到完整信息。" ] }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 42, "metadata": { "pycharm": { "name": "#%%\n" @@ -5387,7 +6081,7 @@ " \n", " \n", " \n", - " 2023-06-20T23:04:59.093951\n", + " 2023-06-20T23:04:51.495471\n", " image/svg+xml\n", " \n", " \n", @@ -5423,16 +6117,16 @@ " \n", " \n", + "\" clip-path=\"url(#paf89f37639)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5495,11 +6189,11 @@ " \n", " \n", + "\" clip-path=\"url(#paf89f37639)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5536,11 +6230,11 @@ " \n", " \n", + "\" clip-path=\"url(#paf89f37639)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5588,11 +6282,11 @@ " \n", " \n", + "\" clip-path=\"url(#paf89f37639)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5649,11 +6343,11 @@ " \n", " \n", + "\" clip-path=\"url(#paf89f37639)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5786,23 +6480,23 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5896,18 +6590,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5916,18 +6610,18 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5959,240 +6653,118 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + "L 211.350127 165.874546 \n", + "L 213.025802 165.600349 \n", + "L 215.036613 165.02316 \n", + "L 217.047424 164.177267 \n", + "L 219.058235 163.065711 \n", + "L 221.069046 161.692494 \n", + "L 223.079856 160.062559 \n", + "L 225.425802 157.844373 \n", + "L 227.771748 155.295645 \n", + "L 230.117694 152.428884 \n", + "L 232.798775 148.781181 \n", + "L 235.479856 144.7598 \n", + "L 238.496073 139.821198 \n", + "L 241.847424 133.866101 \n", + "L 245.53391 126.815889 \n", + "L 249.555532 118.62521 \n", + "L 254.582559 107.829627 \n", + "L 261.620397 92.112771 \n", + "L 271.674451 69.699119 \n", + "L 276.701478 59.063659 \n", + "L 280.7231 51.050379 \n", + "L 284.409586 44.198858 \n", + "L 287.760938 38.452368 \n", + "L 290.777154 33.722444 \n", + "L 293.458235 29.901806 \n", + "L 296.139316 26.467914 \n", + "L 298.485262 23.797831 \n", + "L 300.831208 21.453758 \n", + "L 303.177154 19.447188 \n", + "L 305.187965 18.003396 \n", + "L 307.198775 16.820008 \n", + "L 309.209586 15.901292 \n", + "L 311.220397 15.250541 \n", + "L 313.231208 14.870102 \n", + "L 314.906883 14.760577 \n", + "L 316.582559 14.840011 \n", + "L 318.258235 15.108206 \n", + "L 320.269046 15.678234 \n", + "L 322.279856 16.517021 \n", + "L 324.290667 17.621583 \n", + "L 326.301478 18.987902 \n", + "L 328.312289 20.611109 \n", + "L 330.658235 22.821626 \n", + "L 333.004181 25.36295 \n", + "L 335.350127 28.222664 \n", + "L 338.031208 31.862661 \n", + "L 340.712289 35.876846 \n", + "L 343.728505 40.808022 \n", + "L 347.079856 46.755643 \n", + "L 350.766343 53.798893 \n", + "L 354.787965 61.983338 \n", + "L 359.814992 72.773434 \n", + "L 366.852829 88.487043 \n", + "L 376.906883 110.905218 \n", + "L 381.93391 121.546849 \n", + "L 385.955532 129.566877 \n", + "L 386.960938 131.488159 \n", + "L 386.960938 131.488159 \n", + "\" clip-path=\"url(#paf89f37639)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -6430,27 +6804,74 @@ } ], "source": [ - "onestep_preds = net(features)\n", - "d2l.plot([time, time[tau:]],\n", - " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", - " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", - " figsize=(6, 3))" + "T = 1000 # 总共产生1000个点\n", + "time = torch.arange(1, T + 1, dtype=torch.float32)\n", + "x = torch.sin(0.01 * time)\n", + "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 43, "metadata": { "pycharm": { - "name": "#%% md\n" + "name": "#%%\n" } }, + "outputs": [], "source": [ - "3. 可以在保持特征总数不变的情况下合并旧的观察结果吗?这能提高正确度吗?为什么?" + "tau = 1\n", + "features = torch.zeros((T - tau, tau))\n", + "for i in range(tau):\n", + " features[:, i] = x[i: T - tau + i]\n", + "labels = x[tau:].reshape((-1, 1))\n", + "\n", + "batch_size, n_train = 16, 600\n", + "# 只有前n_train个样本用于训练\n", + "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", + " batch_size, is_train=True)" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 44, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 1, loss: 0.033229\n", + "epoch 2, loss: 0.005169\n", + "epoch 3, loss: 0.002071\n", + "epoch 4, loss: 0.001557\n", + "epoch 5, loss: 0.001237\n" + ] + } + ], + "source": [ + "def train(net, train_iter, loss, epochs, lr):\n", + " trainer = torch.optim.Adam(net.parameters(), lr)\n", + " for epoch in range(epochs):\n", + " for X, y in train_iter:\n", + " trainer.zero_grad()\n", + " l = loss(net(X), y)\n", + " l.sum().backward()\n", + " trainer.step()\n", + " print(f'epoch {epoch + 1}, '\n", + " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", + "\n", + "net = get_net()\n", + "train(net, train_iter, loss, 5, 0.01)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, "metadata": { "pycharm": { "name": "#%%\n" @@ -6468,7 +6889,7 @@ " \n", " \n", " \n", - " 2023-06-20T23:27:46.112018\n", + " 2023-06-20T23:04:59.093951\n", " image/svg+xml\n", " \n", " \n", @@ -6504,16 +6925,16 @@ " \n", " \n", + "\" clip-path=\"url(#p83de8faeb8)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -6576,11 +6997,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83de8faeb8)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -6617,11 +7038,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83de8faeb8)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -6669,11 +7090,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83de8faeb8)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -6730,11 +7151,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83de8faeb8)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -6867,23 +7288,23 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -7078,8909 +7460,6539 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "T = 1000 # 总共产生1000个点\n", - "time = torch.arange(1, T + 1, dtype=torch.float32)\n", - "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", - "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "tau = 4\n", - "features = torch.zeros((T - tau, tau))\n", - "for i in range(T - tau):\n", - " features[i] = x[i:i+tau]\n", - "labels = x[tau:].reshape((-1, 1))\n", - "\n", - "batch_size, n_train = 16, 600\n", - "# 只有前n_train个样本用于训练\n", - "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", - " batch_size, is_train=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 1, loss: 0.064830\n", - "epoch 2, loss: 0.060723\n", - "epoch 3, loss: 0.057784\n", - "epoch 4, loss: 0.057198\n", - "epoch 5, loss: 0.054965\n" - ] - } - ], - "source": [ - "def train(net, train_iter, loss, epochs, lr):\n", - " trainer = torch.optim.Adam(net.parameters(), lr)\n", - " for epoch in range(epochs):\n", - " for X, y in train_iter:\n", - " trainer.zero_grad()\n", - " l = loss(net(X), y)\n", - " l.sum().backward()\n", - " trainer.step()\n", - " print(f'epoch {epoch + 1}, '\n", - " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", - "\n", - "net = get_net()\n", - "train(net, train_iter, loss, 5, 0.01)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-06-20T23:28:58.280965\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.5.1, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "onestep_preds = net(features)\n", - "d2l.plot([time, time[tau:]],\n", - " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", - " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", - " figsize=(6, 3))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  可以采用滑动窗口的方式生成特征矩阵**feature**,这种方式在原有特征矩阵中移动窗口的起始位置,并没有增加新的特征维度。通过利用过去的观察结果作为输入特征,通过合并旧的观察结果,可以提供更丰富的特征信息,使得模型可以捕捉到序列数据中的时间相关性,从而有效提高预测准确度。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "4. 改变神经网络架构并评估其性能" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-06-20T23:46:32.659972\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.5.1, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import torch\n", - "from torch import nn\n", - "from d2l import torch as d2l\n", - "import matplotlib.pyplot as plt\n", - "\n", - "T = 1000 # 总共产生1000个点\n", - "time = torch.arange(1, T + 1, dtype=torch.float32)\n", - "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", - "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "tau = 4\n", - "features = torch.zeros((T - tau, tau))\n", - "for i in range(tau):\n", - " features[:, i] = x[i: T - tau + i]\n", - "labels = x[tau:].reshape((-1, 1))\n", - "\n", - "batch_size, n_train = 16, 600\n", - "# 只有前n_train个样本用于训练\n", - "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", - " batch_size, is_train=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "# 初始化网络权重的函数\n", - "def init_weights(m):\n", - " if type(m) == nn.Linear:\n", - " nn.init.xavier_uniform_(m.weight)\n", - "\n", - "# 一个简单的多层感知机\n", - "def get_net():\n", - " net = nn.Sequential(nn.Linear(4, 10),\n", - " nn.ReLU(),\n", - " nn.Linear(10, 1))\n", - " net.apply(init_weights)\n", - " return net\n", - "\n", - "# 平方损失。注意:MSELoss计算平方误差时不带系数1/2\n", - "loss = nn.MSELoss(reduction='none')" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "def new_get_net():\n", - " net = nn.Sequential(nn.Linear(tau, 50),\n", - " nn.ReLU(),\n", - " nn.Linear(50, 10),\n", - " nn.ReLU(),\n", - " nn.Linear(10, 1))\n", - " net.apply(init_weights)\n", - " return net" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "def train_draw(net, train_iter, loss, epochs, lr):\n", - " trainer = torch.optim.Adam(net.parameters(), lr)\n", - " losses = [] # 用于记录每个 epoch 的损失值\n", - " for epoch in range(epochs):\n", - " epoch_loss = 0.0\n", - " for X, y in train_iter:\n", - " trainer.zero_grad()\n", - " l = loss(net(X), y)\n", - " l.sum().backward()\n", - " trainer.step()\n", - " epoch_loss += l.mean().item()\n", - " epoch_loss /= len(train_iter)\n", - " losses.append(epoch_loss)\n", - " print(f'epoch {epoch + 1}, loss: {epoch_loss:f}')\n", - "\n", - " # 绘制损失曲线\n", - " plt.plot(losses)\n", - " plt.xlabel('Epoch')\n", - " plt.ylabel('Loss')\n", - " plt.title('Training Loss')\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 1, loss: 0.215900\n", - "epoch 2, loss: 0.065764\n", - "epoch 3, loss: 0.059391\n", - "epoch 4, loss: 0.056455\n", - "epoch 5, loss: 0.055573\n" - ] - }, - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-06-20T23:48:04.226450\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.5.1, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "net_1 = get_net()\n", - "train_draw(net_1, train_iter, loss, 5, 0.01)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-06-20T23:48:10.930074\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.5.1, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "onestep_preds = net(features)\n", + "d2l.plot([time, time[tau:]],\n", + " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", + " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", + " figsize=(6, 3))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "3. 可以在保持特征总数不变的情况下合并旧的观察结果吗?这能提高正确度吗?为什么?" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-06-20T23:27:46.112018\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.5.1, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "T = 1000 # 总共产生1000个点\n", + "time = torch.arange(1, T + 1, dtype=torch.float32)\n", + "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", + "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "tau = 4\n", + "features = torch.zeros((T - tau, tau))\n", + "for i in range(T - tau):\n", + " features[i] = x[i:i+tau]\n", + "labels = x[tau:].reshape((-1, 1))\n", + "\n", + "batch_size, n_train = 16, 600\n", + "# 只有前n_train个样本用于训练\n", + "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", + " batch_size, is_train=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 1, loss: 0.064830\n", + "epoch 2, loss: 0.060723\n", + "epoch 3, loss: 0.057784\n", + "epoch 4, loss: 0.057198\n", + "epoch 5, loss: 0.054965\n" + ] + } + ], + "source": [ + "def train(net, train_iter, loss, epochs, lr):\n", + " trainer = torch.optim.Adam(net.parameters(), lr)\n", + " for epoch in range(epochs):\n", + " for X, y in train_iter:\n", + " trainer.zero_grad()\n", + " l = loss(net(X), y)\n", + " l.sum().backward()\n", + " trainer.step()\n", + " print(f'epoch {epoch + 1}, '\n", + " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", + "\n", + "net = get_net()\n", + "train(net, train_iter, loss, 5, 0.01)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-06-20T23:28:58.280965\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.5.1, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "onestep_preds = net(features)\n", + "d2l.plot([time, time[tau:]],\n", + " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", + " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", + " figsize=(6, 3))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  可以采用滑动窗口的方式生成特征矩阵**feature**,这种方式在原有特征矩阵中移动窗口的起始位置,并没有增加新的特征维度。通过利用过去的观察结果作为输入特征,通过合并旧的观察结果,可以提供更丰富的特征信息,使得模型可以捕捉到序列数据中的时间相关性,从而有效提高预测准确度。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "4. 改变神经网络架构并评估其性能" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-06-20T23:46:32.659972\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.5.1, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -16218,16 +14032,122 @@ } ], "source": [ - "onestep_preds = net_1(features)\n", - "d2l.plot([time, time[tau:]],\n", - " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", - " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", - " figsize=(6, 3))" + "import torch\n", + "from torch import nn\n", + "from d2l import torch as d2l\n", + "import matplotlib.pyplot as plt\n", + "\n", + "T = 1000 # 总共产生1000个点\n", + "time = torch.arange(1, T + 1, dtype=torch.float32)\n", + "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", + "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "tau = 4\n", + "features = torch.zeros((T - tau, tau))\n", + "for i in range(tau):\n", + " features[:, i] = x[i: T - tau + i]\n", + "labels = x[tau:].reshape((-1, 1))\n", + "\n", + "batch_size, n_train = 16, 600\n", + "# 只有前n_train个样本用于训练\n", + "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", + " batch_size, is_train=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# 初始化网络权重的函数\n", + "def init_weights(m):\n", + " if type(m) == nn.Linear:\n", + " nn.init.xavier_uniform_(m.weight)\n", + "\n", + "# 一个简单的多层感知机\n", + "def get_net():\n", + " net = nn.Sequential(nn.Linear(4, 10),\n", + " nn.ReLU(),\n", + " nn.Linear(10, 1))\n", + " net.apply(init_weights)\n", + " return net\n", + "\n", + "# 平方损失。注意:MSELoss计算平方误差时不带系数1/2\n", + "loss = nn.MSELoss(reduction='none')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "def new_get_net():\n", + " net = nn.Sequential(nn.Linear(tau, 50),\n", + " nn.ReLU(),\n", + " nn.Linear(50, 10),\n", + " nn.ReLU(),\n", + " nn.Linear(10, 1))\n", + " net.apply(init_weights)\n", + " return net" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "def train_draw(net, train_iter, loss, epochs, lr):\n", + " trainer = torch.optim.Adam(net.parameters(), lr)\n", + " losses = [] # 用于记录每个 epoch 的损失值\n", + " for epoch in range(epochs):\n", + " epoch_loss = 0.0\n", + " for X, y in train_iter:\n", + " trainer.zero_grad()\n", + " l = loss(net(X), y)\n", + " l.sum().backward()\n", + " trainer.step()\n", + " epoch_loss += l.mean().item()\n", + " epoch_loss /= len(train_iter)\n", + " losses.append(epoch_loss)\n", + " print(f'epoch {epoch + 1}, loss: {epoch_loss:f}')\n", + "\n", + " # 绘制损失曲线\n", + " plt.plot(losses)\n", + " plt.xlabel('Epoch')\n", + " plt.ylabel('Loss')\n", + " plt.title('Training Loss')\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, "metadata": { "pycharm": { "name": "#%%\n" @@ -16238,11 +14158,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "epoch 1, loss: 0.125826\n", - "epoch 2, loss: 0.055088\n", - "epoch 3, loss: 0.059558\n", - "epoch 4, loss: 0.058091\n", - "epoch 5, loss: 0.053354\n" + "epoch 1, loss: 0.215900\n", + "epoch 2, loss: 0.065764\n", + "epoch 3, loss: 0.059391\n", + "epoch 4, loss: 0.056455\n", + "epoch 5, loss: 0.055573\n" ] }, { @@ -16251,12 +14171,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-06-20T23:48:42.399788\n", + " 2023-06-20T23:48:04.226450\n", " image/svg+xml\n", " \n", " \n", @@ -16272,18 +14192,18 @@ " \n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", @@ -16291,17 +14211,17 @@ " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -16429,12 +14349,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -16485,12 +14405,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -16549,12 +14469,12 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17000,37 +14806,37 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -17194,13 +15000,13 @@ } ], "source": [ - "net_2 = new_get_net()\n", - "train_draw(net_2, train_iter, loss, 5, 0.01)" + "net_1 = get_net()\n", + "train_draw(net_1, train_iter, loss, 5, 0.01)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "metadata": { "pycharm": { "name": "#%%\n" @@ -17218,7 +15024,7 @@ " \n", " \n", " \n", - " 2023-06-20T23:49:14.899483\n", + " 2023-06-20T23:48:10.930074\n", " image/svg+xml\n", " \n", " \n", @@ -17254,16 +15060,16 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17326,11 +15132,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17367,11 +15173,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17419,11 +15225,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17480,11 +15286,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17619,16 +15425,16 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17686,11 +15492,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17707,11 +15513,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17728,11 +15534,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17748,11 +15554,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17768,11 +15574,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -17788,11 +15594,11 @@ " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -18769,7 +16575,7 @@ "L 386.625802 108.320214 \n", "L 386.960938 112.595419 \n", "L 386.960938 112.595419 \n", - "\" clip-path=\"url(#pdd482f00e2)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", + "\" clip-path=\"url(#p83ec42125c)\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n", " \n", " \n", " \n", + "\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "onestep_preds = net_1(features)\n", + "d2l.plot([time, time[tau:]],\n", + " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", + " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", + " figsize=(6, 3))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 1, loss: 0.125826\n", + "epoch 2, loss: 0.055088\n", + "epoch 3, loss: 0.059558\n", + "epoch 4, loss: 0.058091\n", + "epoch 5, loss: 0.053354\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-06-20T23:48:42.399788\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.5.1, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "onestep_preds = net_1(features)\n", - "d2l.plot([time, time[tau:]],\n", - " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", - " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", - " figsize=(6, 3))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  通过增加多层感知机中隐藏层的大小和层数,可以发现loss相较于初始的多层感知机出现降低,说明复杂结构在某种程度上提升了模型性能,但是可以发现复杂结构的loss在第2个epoch出现增高,导致曲线振荡。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.1.2\n", - "\n", - "一位投资者想要找到一种好的证券来购买。他查看过去的回报,以决定哪一种可能是表现良好的。这一策略可能会出什么问题呢?" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  该投资者仅采用过去的回报作为参考信息来预测未来的证券表现,属于外推法。 \n", - "  外推法在预测未来数据时,存在假设局限性,即数据必须满足其过去历史可以用来预测未来的趋势。如果假设不成立,那么预测结果很可能会不准确。此外,外推法不能预测突发事件的发生,并且预测结果非常依赖数据质量。如果历史数据不准确,预测结果也会收到影响。 \n", - "  对于金融证券来说,决定证券回报的因素不仅仅包括过去的回报曲线,还包括政治事件、行业趋势和金融危机等外部因素,而这些因素往往是决定证券表现的重要因素。因此在使用外推法预测证券未来回报时,很难考虑到随时变化的市场环境以及突发的外部因素,从而带来产生巨大误差。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.1.3\n", - "\n", - "时间是向前推进的因果模型在多大程度上适用于文本呢?" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  时间是向前推进的因果模型是一种基于时间序列和因果关系的模型,该模型假设未来的结果受过去因素和时间序列影响,并且无法反过来影响过去因素。 \n", - "  但是在文本数据中,时间序列并非影响未来文本结果产生的唯一因素,并且某种程度上来讲,时序与文本内容的产生并无显著关联。因此采用时间前向推进的因果模型并不能很好的适用于文本场景。需要通过建立更加复杂的模型来考虑影响文本内容的多种因素。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.1.4\n", - "\n", - "举例说明什么时候可能需要隐变量自回归模型来捕捉数据的动力学模型。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  隐变量自回归模型是一种用于建模时间序列数据中的动态变化的自回归模型,该模型通过引入一些隐变量来捕捉数据的动态性质。 \n", - "  动力学模型则是一种描述系统状态随时间变化而变化的数学模型,在该模型中,系统的状态可以用一组状态变量来表示,并且这些状态变量随着时间的推移而发生变化。 \n", - "  因此,在实际场景中,股票价格变化通常受时间序列和市场因素的影响,可通过隐变量自回归模型自回归的建模股票在时间上的依赖关系,并通过隐变量来捕捉未观测到的市场信息,最终建立起股票价格的动力学模型。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "## 8.2 文本预处理 " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.2.1\n", - "\n", - "词元化是一个关键的预处理步骤,它因语言而异。尝试找到另外三种常用的词元化文本的方法。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  三种常用的词元化文本的方法如下:\n", - "1. BPE(Byte-Pair Encoding):字节对编码,该方法本质是一种贪心算法,具体流程如下:\n", - " - 在语料库中单词末尾添加
\\,并统计该单词出现的次数\n", - " - 将单词切分为单个字符作为子词,并以此构建出初始子词词表\n", - " - 在语料库中,统计单词内相邻子词对的频数\n", - " - 将频数最高的子词对合并为新的子词,并加入到子词词表中\n", - " - 重复上述两步,直到达到合并次数或子词词表中所有子词的频数均为1。\n", - " 通过该方法,对语料库实现了数据压缩,实现通过最少的token数表示一个corpus\n", - "\n", - "2. WordPiece:WordPiece与BPE方法类似,本质也是一种贪心,但是不同于BPE选择出现频率最高的两个子词合并为新的子词,WordPiece选择具有较强关联性的子词进行合并。具体流程如下:\n", - " - 将语料库中单词切分为单个字符作为初始化的子词词表,假设每个子词独立,此时语言模型似然值等价于子词概率积\n", - " - 两两拼接子词,并统计新子词加入词表后对语言模型似然值的提升程度\n", - " - 最终选择对语言模型似然度提升最大的字符加入到词表中\n", - " - 重复上述两步,直到词表大小达到指定大小。\n", - "\n", - "3. SentencePiece:不同于BPE和WordPiece中采用空格区分不同单词的方式,SentencePiece将空格也进行编码,不考虑句子中单独的单词,而是将整个句子作为整体进行拆分。再通过BPE或Unigram的方式构造词表。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.2.2\n", - "\n", - "在本节的实验中,将文本词元为单词和更改`Vocab`实例的`min_freq`参数。这对词表大小有何影响?" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "import collections\n", - "import re\n", - "\n", - "import numpy as np\n", - "from d2l import torch as d2l" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  词元设置为Char时,输出的语料库和词表大小为 $(170580, 28)$,结果如下:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(170580, 28)" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def load_corpus_time_machine(max_tokens=-1): #@save\n", - " \"\"\"返回时光机器数据集的词元索引列表和词表\"\"\"\n", - " lines = d2l.read_time_machine()\n", - " tokens = d2l.tokenize(lines, 'char')\n", - " vocab = d2l.Vocab(tokens)\n", - " # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n", - " # 所以将所有文本行展平到一个列表中\n", - " corpus = [vocab[token] for line in tokens for token in line]\n", - " if max_tokens > 0:\n", - " corpus = corpus[:max_tokens]\n", - " return corpus, vocab\n", - "\n", - "corpus, vocab = load_corpus_time_machine()\n", - "len(corpus), len(vocab)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  而当词元设置为Word时,输出的语料库和词表大小为 $(32775, 4580)$,结果如下:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "data": { + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], "text/plain": [ - "(32775, 4580)" + "
" ] }, - "execution_count": 3, "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def load_corpus_time_machine(max_tokens=-1): #@save\n", - " \"\"\"返回时光机器数据集的词元索引列表和词表\"\"\"\n", - " lines = d2l.read_time_machine()\n", - " tokens = d2l.tokenize(lines, 'word')\n", - " vocab = d2l.Vocab(tokens)\n", - " # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n", - " # 所以将所有文本行展平到一个列表中\n", - " corpus = [vocab[token] for line in tokens for token in line]\n", - " if max_tokens > 0:\n", - " corpus = corpus[:max_tokens]\n", - " return corpus, vocab\n", - "\n", - "corpus, vocab = load_corpus_time_machine()\n", - "len(corpus), len(vocab)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  在本节实验中,Vocab 实例的 **min_freq** 参数主要用来实现对低频词的过滤,默认 **min_freq=0**,结果如下:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "文本: ['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n", - "索引: [1, 19, 50, 40, 2183, 2184, 400]\n", - "文本: ['twinkled', 'and', 'his', 'usually', 'pale', 'face', 'was', 'flushed', 'and', 'animated', 'the']\n", - "索引: [2186, 3, 25, 1044, 362, 113, 7, 1421, 3, 1045, 1]\n" - ] - } - ], - "source": [ - "lines = d2l.read_time_machine()\n", - "tokens = d2l.tokenize(lines, 'word')\n", - "vocab=d2l.Vocab(tokens, min_freq=0)\n", - "\n", - "for i in [0, 10]:\n", - " print('文本:', tokens[i])\n", - " print('索引:', vocab[tokens[i]])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  设置 **min_freq=10**,过滤掉低频词,所得对词表的影响结果如下:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "文本: ['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n", - "索引: [1, 19, 50, 40, 0, 0, 0]\n", - "文本: ['twinkled', 'and', 'his', 'usually', 'pale', 'face', 'was', 'flushed', 'and', 'animated', 'the']\n", - "索引: [0, 3, 25, 0, 362, 113, 7, 0, 3, 0, 1]\n" - ] - } - ], - "source": [ - "vocab=d2l.Vocab(tokens, min_freq=10)\n", - "\n", - "for i in [0, 10]:\n", - " print('文本:', tokens[i])\n", - " print('索引:', vocab[tokens[i]])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "## 8.3 语言模型和数据集" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.3.1\n", - "\n", - "假设训练数据集中有$100,000$个单词。一个四元语法需要存储多少个词频和相邻多词频率?" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  对于$100000$个单词,四元组需要存储每个单词的词频,即$P(W_i), i\\in (1,100000)$,总计$100000$个。 \n", - "  计算相邻多词频率时,由于二元组、三元组以及四元组构成的词不同,因此需要将每个元组构成的词进行存储,例如“The War of Worlds”,若想知道$P(\\text{The War of Worlds})$的概率,需要知道$P(\\text{The}),P(\\text{War|The}),P(\\text{of|The War}),P(\\text{Worlds|The War of})$的概率。 \n", - "  因此对于$100000$个单词,其二相邻频率为 \n", - "$$(1 \\times 10^5)^2$$\n", - "  三相邻频率为 \n", - "$$(1 \\times 10^5)^3$$\n", - "  四相邻频率为 \n", - "$$(1 \\times 10^5)^4$$\n", - "  共需要存储$1 \\times 10^{10}+1 \\times 10^{15} + 1 \\times 10^{20}$个相邻多词频率。 " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.3.2\n", - "\n", - "我们如何对一系列对话建模?" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  N-Gram语法将文本看作一系列的词序列,假设当前词的出现概率与前$N-1$个词相关,根据出现的频率来估计下一个词出现的概率。因此对于对话建模问题,理论上可以将对话中的每个句子看作一个独立的文本序列,并对每个文本序列计算出对应的词频和N-Gram概率,在预测下一个词时,将前$N-1$个词视作上下文,选择出现概率最高的词作为下一个词。重复上述过程,直到生成整个对话。 \n", - "  但是N-Gram对于长距离的依赖关系并不能处理的很好,因此将对话语句处理为独立文本序列后,还可以采用RNN或LSTM进行建模。" - ] - }, - { - "cell_type": "code", - "execution_count": 113, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Hi, you you you you you you you you you you\n" - ] + "output_type": "display_data" } ], "source": [ - "from nltk import word_tokenize\n", - "from nltk.util import ngrams\n", - "from collections import Counter, defaultdict\n", - "\n", - "# 读取对话数据\n", - "dialogues = [\n", - " \"Hi,how are you?\",\n", - " \"I'm fine, thank you. And you?\",\n", - " \"I'm good too.\",\n", - " \"That's great!\",\n", - " \"Thanks.\",\n", - "]\n", - "\n", - "# 将对话数据转换为词序列\n", - "tokens = []\n", - "for dialogue in dialogues:\n", - " tokens += word_tokenize(dialogue)\n", - "\n", - "# 统计词频\n", - "word_freq = Counter(tokens)\n", - "\n", - "# 将word_freq字典的键的默认值设置为1\n", - "default_freq = 1\n", - "word_freq = defaultdict(int, {k: v for k, v in word_freq.items()})\n", - "word_freq.default_factory = lambda: default_freq\n", - "\n", - "# 计算N-gram概率\n", - "n = 3 # 选择3-gram模型\n", - "ngrams_freq = Counter(ngrams(tokens, n))\n", - "ngrams_prob = {}\n", - "for ngram in ngrams_freq:\n", - " context = ' '.join(ngram[:-1])\n", - " if context not in ngrams_prob:\n", - " ngrams_prob[context] = {}\n", - " ngrams_prob[context][ngram[-1]] = ngrams_freq[ngram] / word_freq[context]\n", - "\n", - "# 使用N-gram模型生成对话\n", - "start = 'Hi,'\n", - "dialogue = [start]\n", - "for i in range(10):\n", - " context = ' '.join(dialogue[-n+1:])\n", - " if context in ngrams_prob:\n", - " next_word = max(ngrams_prob[context], key=ngrams_prob[context].get)\n", - " else:\n", - " next_word = max(word_freq, key=word_freq.get)\n", - " dialogue.append(next_word)\n", - "print(' '.join(dialogue))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### 练习 8.3.3\n", - "\n", - "一元语法、二元语法和三元语法的齐普夫定律的指数是不一样的,能设法估计么?" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "**解答:**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  齐普夫定律认为,排名与频率之间存在的负相关的关系。即对于大规模文本,将其中每个词的词频进行统计,并由高到低排序标号,则这些单词的频数$F$和这些单词的序号$R$之间存在一个常数$C$,满足$F \\times R=C$,以\"The Time Machine\" 这篇文章为例,其各语法的齐普夫定律指数如下:" + "net_2 = new_get_net()\n", + "train_draw(net_2, train_iter, loss, 5, 0.01)" ] }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 16, "metadata": { "pycharm": { "name": "#%%\n" @@ -20527,12 +18715,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-04-14T21:02:05.473377\n", + " 2023-06-20T23:49:14.899483\n", " image/svg+xml\n", " \n", " \n", @@ -20547,587 +18735,2607 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\" transform=\"scale(0.015625)\"/>\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" ], "text/plain": [ - "
" + "
" ] }, "metadata": {}, "output_type": "display_data" } - ], + ], + "source": [ + "onestep_preds = net_1(features)\n", + "d2l.plot([time, time[tau:]],\n", + " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", + " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", + " figsize=(6, 3))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  通过增加多层感知机中隐藏层的大小和层数,可以发现loss相较于初始的多层感知机出现降低,说明复杂结构在某种程度上提升了模型性能,但是可以发现复杂结构的loss在第2个epoch出现增高,导致曲线振荡。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.1.2\n", + "\n", + "一位投资者想要找到一种好的证券来购买。他查看过去的回报,以决定哪一种可能是表现良好的。这一策略可能会出什么问题呢?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  该投资者仅采用过去的回报作为参考信息来预测未来的证券表现,属于外推法。 \n", + "  外推法在预测未来数据时,存在假设局限性,即数据必须满足其过去历史可以用来预测未来的趋势。如果假设不成立,那么预测结果很可能会不准确。且外推法不能预测突发事件的发生,并且预测结果非常依赖数据质量。如果历史数据不准确,预测结果也会受到影响。 \n", + "  对于金融证券来说,决定证券回报的因素不仅仅包括过去的回报曲线,还包括政治事件、行业趋势和金融危机等外部因素,而这些因素往往是决定证券表现的重要因素。因此在使用外推法预测证券未来回报时,很难考虑到随时变化的市场环境以及突发的外部因素,从而带来产生巨大误差。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.1.3\n", + "\n", + "时间是向前推进的因果模型在多大程度上适用于文本呢?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  时间是向前推进的因果模型是一种基于时间序列和因果关系的模型,该模型假设未来的结果受过去因素和时间序列影响,并且无法反过来影响过去因素。 \n", + "  在文本数据中,对于按时间顺序排列的文本数据,如新闻报道、历史日志等,这些内容在理解和预测发展时,时间因素至关重要,可以更好的帮助模型去学习。因此在舆情分析、时序预测等与时间因素强相关的任务中,时间是向前推进的因果模型可以起到很大作用。 \n", + "  但时间序列并非影响未来文本结果产生的唯一因素,在一些静态文本分类或情感分析任务中,文本本身可能足以进行准确的预测,而不需要考虑时间信息,此时时序与文本内容的产生并无显著关联。在这种情况下采用时间前向推进的因果模型并不能很好的适用于文本场景。需要通过建立更加复杂的模型来考虑影响文本内容的多种因素。 \n", + "  因此在处理文本数据时,需要根据具体任务细节来评估时间因素对于任务的重要性,以此来选择合适的模型结构和特征工程方法。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.1.4\n", + "\n", + "举例说明什么时候可能需要隐变量自回归模型来捕捉数据的动力学模型。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  隐变量自回归模型是一种用于建模时间序列数据中的动态变化的自回归模型,该模型通过引入一些隐变量来捕捉数据的动态性质。 \n", + "  动力学模型则是一种描述系统状态随时间变化而变化的数学模型,在该模型中,系统的状态可以用一组状态变量来表示,并且这些状态变量随着时间的推移而发生变化。 \n", + "  因此,在实际场景中,股票价格变化通常受时间序列和市场因素的影响,可通过隐变量自回归模型自回归的建模股票在时间上的依赖关系,并通过隐变量来捕捉未观测到的市场信息,最终建立起股票价格的动力学模型。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## 8.2 文本预处理 " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.2.1\n", + "\n", + "词元化是一个关键的预处理步骤,它因语言而异。尝试找到另外三种常用的词元化文本的方法。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  三种常用的词元化文本的方法如下:\n", + "1. BPE(Byte-Pair Encoding):字节对编码,该方法本质是一种贪心算法,具体流程如下:\n", + " - 在语料库中单词末尾添加
\\,并统计该单词出现的次数\n", + " - 将单词切分为单个字符作为子词,并以此构建出初始子词词表\n", + " - 在语料库中,统计单词内相邻子词对的频数\n", + " - 将频数最高的子词对合并为新的子词,并加入到子词词表中\n", + " - 重复上述两步,直到达到合并次数或子词词表中所有子词的频数均为1。\n", + " 通过该方法,对语料库实现了数据压缩,实现通过最少的token数表示一个corpus\n", + "\n", + "2. WordPiece:WordPiece与BPE方法类似,本质也是一种贪心,但是不同于BPE选择出现频率最高的两个子词合并为新的子词,WordPiece选择具有较强关联性的子词进行合并。具体流程如下:\n", + " - 将语料库中单词切分为单个字符作为初始化的子词词表,假设每个子词独立,此时语言模型似然值等价于子词概率积\n", + " - 两两拼接子词,并统计新子词加入词表后对语言模型似然值的提升程度\n", + " - 最终选择对语言模型似然度提升最大的字符加入到词表中\n", + " - 重复上述两步,直到词表大小达到指定大小。\n", + "\n", + "3. SentencePiece:不同于BPE和WordPiece中采用空格区分不同单词的方式,SentencePiece将空格也进行编码,不考虑句子中单独的单词,而是将整个句子作为整体进行拆分。再通过BPE或Unigram的方式构造词表。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.2.2\n", + "\n", + "在本节的实验中,将文本词元为单词和更改`Vocab`实例的`min_freq`参数。这对词表大小有何影响?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import collections\n", + "import re\n", + "\n", + "import numpy as np\n", + "from d2l import torch as d2l" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  词元设置为Char时,输出的语料库和词表大小为 $(170580, 28)$,结果如下:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "def load_corpus_time_machine(max_tokens=-1): #@save\n", + " \"\"\"返回时光机器数据集的词元索引列表和词表\"\"\"\n", + " lines = d2l.read_time_machine() # 获取时光机器数据集文本行的数据\n", + " tokens = d2l.tokenize(lines, 'char') # 对文本进行分词,将字符级转换为词元\n", + " vocab = d2l.Vocab(tokens) # 创建词汇表对象,将词元映射到索引\n", + " \n", + " # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n", + " # 所以将所有文本行展平到一个列表中\n", + " corpus = [vocab[token] for line in tokens for token in line]\n", + " \n", + " # max_tokens主要用于限制corpus长度\n", + " if max_tokens > 0:\n", + " corpus = corpus[:max_tokens]\n", + " \n", + " # 返回列表和词汇表\n", + " return corpus, vocab\n", + "\n", + "corpus, vocab = load_corpus_time_machine()\n", + "# 打印列表和词汇表大小\n", + "len(corpus), len(vocab)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "print(vocab.idx_to_token)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  而当词元设置为Word时,输出的语料库和词表大小为 $(32775, 4580)$,结果如下:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "def load_corpus_time_machine(max_tokens=-1): #@save\n", + " \"\"\"返回时光机器数据集的词元索引列表和词表\"\"\"\n", + " lines = d2l.read_time_machine()\n", + " tokens = d2l.tokenize(lines, 'word')\n", + " vocab = d2l.Vocab(tokens)\n", + " # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n", + " # 所以将所有文本行展平到一个列表中\n", + " corpus = [vocab[token] for line in tokens for token in line]\n", + " if max_tokens > 0:\n", + " corpus = corpus[:max_tokens]\n", + " return corpus, vocab\n", + "\n", + "corpus, vocab = load_corpus_time_machine()\n", + "len(corpus), len(vocab)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  在本节实验中,Vocab 实例的 **min_freq** 参数主要用来实现对低频词的过滤,默认 **min_freq=0**,结果如下:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "lines = d2l.read_time_machine()\n", + "tokens = d2l.tokenize(lines, 'word')\n", + "vocab=d2l.Vocab(tokens, min_freq=0)\n", + "\n", + "for i in [0, 10]:\n", + " print('文本:', tokens[i])\n", + " print('索引:', vocab[tokens[i]])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  设置 **min_freq=10**,过滤掉低频词,所得对词表的影响结果如下:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "vocab=d2l.Vocab(tokens, min_freq=10)\n", + "\n", + "for i in [0, 10]:\n", + " print('文本:', tokens[i])\n", + " print('索引:', vocab[tokens[i]])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  可以看到,将文本词元从字符(Char)改为单词(Word)后,词表大小由28个字符扩大到4580个单词,语料内容也从根据字符划分变更为根据单词划分。 \n", + "  min_frep 参数则是主要用于对低频词进行过滤,通过进一步过滤掉低频词减少词表大小,降低噪音影响,使得模型更加高效,有助于提高模型的学习效果。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## 8.3 语言模型和数据集" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.3.1\n", + "\n", + "假设训练数据集中有$100,000$个单词。一个四元语法需要存储多少个词频和相邻多词频率?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  对于$100000$个单词,四元组需要存储每个单词的词频,即$P(W_i), i\\in (1,100000)$,总计$100000$个。 \n", + "  计算相邻多词频率时,由于二元组、三元组以及四元组构成的词不同,因此需要将每个元组构成的词进行存储,例如“The War of Worlds”,若想知道$P(\\text{The War of Worlds})$的概率,需要知道$P(\\text{The}),P(\\text{War|The}),P(\\text{of|The War}),P(\\text{Worlds|The War of})$的概率。 \n", + "  因此对于$100000$个单词,其二相邻频率为 \n", + "$$(1 \\times 10^5)^2$$\n", + "  三相邻频率为 \n", + "$$(1 \\times 10^5)^3$$\n", + "  四相邻频率为 \n", + "$$(1 \\times 10^5)^4$$\n", + "  共需要存储$1 \\times 10^{10}+1 \\times 10^{15} + 1 \\times 10^{20}$个相邻多词频率。 " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.3.2\n", + "\n", + "我们如何对一系列对话建模?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  N-Gram语法将文本看作一系列的词序列,假设当前词的出现概率与前$N-1$个词相关,根据出现的频率来估计下一个词出现的概率。因此对于对话建模问题,理论上可以将对话中的每个句子看作一个独立的文本序列,并对每个文本序列计算出对应的词频和N-Gram概率,在预测下一个词时,将前$N-1$个词视作上下文,选择出现概率最高的词作为下一个词。重复上述过程,直到生成整个对话。 \n", + "  但是N-Gram对于长距离的依赖关系并不能处理的很好,因此将对话语句处理为独立文本序列后,还可以采用RNN或LSTM进行建模。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "from nltk import word_tokenize\n", + "from nltk.util import ngrams\n", + "from collections import Counter, defaultdict\n", + "\n", + "# 读取对话数据\n", + "dialogues = [\n", + " \"Hi,how are you?\",\n", + " \"I'm fine, thank you. And you?\",\n", + " \"I'm good too.\",\n", + " \"That's great!\",\n", + " \"Thanks.\",\n", + "]\n", + "\n", + "# 将对话数据转换为词序列\n", + "tokens = []\n", + "for dialogue in dialogues:\n", + " tokens += word_tokenize(dialogue)\n", + "\n", + "# 统计词频\n", + "word_freq = Counter(tokens)\n", + "\n", + "# 将word_freq字典的键的默认值设置为1\n", + "default_freq = 1\n", + "word_freq = defaultdict(int, {k: v for k, v in word_freq.items()}) # 创建默认字典\n", + "word_freq.default_factory = lambda: default_freq # 若访问不存在的键时,返回default_frep以避免KeyError\n", + "\n", + "# 计算N-gram概率\n", + "n = 3 # 选择3-gram模型\n", + "ngrams_freq = Counter(ngrams(tokens, n))\n", + "ngrams_prob = {}\n", + "for ngram in ngrams_freq:\n", + " context = ' '.join(ngram[:-1]) # 将前n-1个词连接,组成context\n", + " if context not in ngrams_prob:\n", + " # 若上下文内容不存在于字典中,则添加空字典作为条目\n", + " ngrams_prob[context] = {}\n", + " \n", + " # 计算当前N_gram的条件概率 \n", + " ngrams_prob[context][ngram[-1]] = ngrams_freq[ngram] / word_freq[context]\n", + "\n", + "# 使用N-gram模型生成对话\n", + "start = 'Hi,'\n", + "dialogue = [start]\n", + "for i in range(10):\n", + " context = ' '.join(dialogue[-n+1:])\n", + " if context in ngrams_prob:\n", + " next_word = max(ngrams_prob[context], key=ngrams_prob[context].get)\n", + " else:\n", + " next_word = max(word_freq, key=word_freq.get)\n", + " dialogue.append(next_word)\n", + "print(' '.join(dialogue))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 练习 8.3.3\n", + "\n", + "一元语法、二元语法和三元语法的齐普夫定律的指数是不一样的,能设法估计么?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  齐普夫定律认为,排名与频率之间存在的负相关的关系。即对于大规模文本,将其中每个词的词频进行统计,并由高到低排序标号,则这些单词的频数$F$和这些单词的序号$R$之间存在一个常数$C$,满足$F \\times R=C$,以\"The Time Machine\" 这篇文章为例,其各语法的齐普夫定律指数如下:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], "source": [ "tokens = d2l.tokenize(d2l.read_time_machine())\n", "# 因为每个文本行不一定是一个句子或一个段落,因此我们把所有文本行拼接到一起\n", @@ -21422,6 +22062,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21435,6 +22076,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21446,6 +22088,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21459,6 +22102,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21475,6 +22119,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21486,6 +22131,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21497,6 +22143,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21512,150 +22159,14 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "metadata": { + "collapsed": false, "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "i= 0\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 1\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 2\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 3\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 4\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 5\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 6\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 7\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 8\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n", - "i= 9\n", - "X: tensor([[0, 1, 2, 3, 4],\n", - " [5, 6, 7, 8, 9]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [ 6, 7, 8, 9, 10]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [15, 16, 17, 18, 19]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [16, 17, 18, 19, 20]])\n", - "X: tensor([[20, 21, 22, 23, 24],\n", - " [25, 26, 27, 28, 29]]) \n", - "Y: tensor([[21, 22, 23, 24, 25],\n", - " [26, 27, 28, 29, 30]])\n" - ] - } - ], + "outputs": [], "source": [ "my_seq = list(range(35))\n", "\n", @@ -21687,6 +22198,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21697,138 +22209,14 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": null, "metadata": { + "collapsed": false, "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "i= 0\n", - "X: tensor([[ 5, 6, 7, 8, 9],\n", - " [19, 20, 21, 22, 23]]) \n", - "Y: tensor([[ 6, 7, 8, 9, 10],\n", - " [20, 21, 22, 23, 24]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [24, 25, 26, 27, 28]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [25, 26, 27, 28, 29]])\n", - "i= 1\n", - "X: tensor([[ 0, 1, 2, 3, 4],\n", - " [17, 18, 19, 20, 21]]) \n", - "Y: tensor([[ 1, 2, 3, 4, 5],\n", - " [18, 19, 20, 21, 22]])\n", - "X: tensor([[ 5, 6, 7, 8, 9],\n", - " [22, 23, 24, 25, 26]]) \n", - "Y: tensor([[ 6, 7, 8, 9, 10],\n", - " [23, 24, 25, 26, 27]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [27, 28, 29, 30, 31]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [28, 29, 30, 31, 32]])\n", - "i= 2\n", - "X: tensor([[ 4, 5, 6, 7, 8],\n", - " [19, 20, 21, 22, 23]]) \n", - "Y: tensor([[ 5, 6, 7, 8, 9],\n", - " [20, 21, 22, 23, 24]])\n", - "X: tensor([[ 9, 10, 11, 12, 13],\n", - " [24, 25, 26, 27, 28]]) \n", - "Y: tensor([[10, 11, 12, 13, 14],\n", - " [25, 26, 27, 28, 29]])\n", - "X: tensor([[14, 15, 16, 17, 18],\n", - " [29, 30, 31, 32, 33]]) \n", - "Y: tensor([[15, 16, 17, 18, 19],\n", - " [30, 31, 32, 33, 34]])\n", - "i= 3\n", - "X: tensor([[ 5, 6, 7, 8, 9],\n", - " [19, 20, 21, 22, 23]]) \n", - "Y: tensor([[ 6, 7, 8, 9, 10],\n", - " [20, 21, 22, 23, 24]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [24, 25, 26, 27, 28]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [25, 26, 27, 28, 29]])\n", - "i= 4\n", - "X: tensor([[ 1, 2, 3, 4, 5],\n", - " [17, 18, 19, 20, 21]]) \n", - "Y: tensor([[ 2, 3, 4, 5, 6],\n", - " [18, 19, 20, 21, 22]])\n", - "X: tensor([[ 6, 7, 8, 9, 10],\n", - " [22, 23, 24, 25, 26]]) \n", - "Y: tensor([[ 7, 8, 9, 10, 11],\n", - " [23, 24, 25, 26, 27]])\n", - "X: tensor([[11, 12, 13, 14, 15],\n", - " [27, 28, 29, 30, 31]]) \n", - "Y: tensor([[12, 13, 14, 15, 16],\n", - " [28, 29, 30, 31, 32]])\n", - "i= 5\n", - "X: tensor([[ 1, 2, 3, 4, 5],\n", - " [17, 18, 19, 20, 21]]) \n", - "Y: tensor([[ 2, 3, 4, 5, 6],\n", - " [18, 19, 20, 21, 22]])\n", - "X: tensor([[ 6, 7, 8, 9, 10],\n", - " [22, 23, 24, 25, 26]]) \n", - "Y: tensor([[ 7, 8, 9, 10, 11],\n", - " [23, 24, 25, 26, 27]])\n", - "X: tensor([[11, 12, 13, 14, 15],\n", - " [27, 28, 29, 30, 31]]) \n", - "Y: tensor([[12, 13, 14, 15, 16],\n", - " [28, 29, 30, 31, 32]])\n", - "i= 6\n", - "X: tensor([[ 2, 3, 4, 5, 6],\n", - " [18, 19, 20, 21, 22]]) \n", - "Y: tensor([[ 3, 4, 5, 6, 7],\n", - " [19, 20, 21, 22, 23]])\n", - "X: tensor([[ 7, 8, 9, 10, 11],\n", - " [23, 24, 25, 26, 27]]) \n", - "Y: tensor([[ 8, 9, 10, 11, 12],\n", - " [24, 25, 26, 27, 28]])\n", - "X: tensor([[12, 13, 14, 15, 16],\n", - " [28, 29, 30, 31, 32]]) \n", - "Y: tensor([[13, 14, 15, 16, 17],\n", - " [29, 30, 31, 32, 33]])\n", - "i= 7\n", - "X: tensor([[ 4, 5, 6, 7, 8],\n", - " [19, 20, 21, 22, 23]]) \n", - "Y: tensor([[ 5, 6, 7, 8, 9],\n", - " [20, 21, 22, 23, 24]])\n", - "X: tensor([[ 9, 10, 11, 12, 13],\n", - " [24, 25, 26, 27, 28]]) \n", - "Y: tensor([[10, 11, 12, 13, 14],\n", - " [25, 26, 27, 28, 29]])\n", - "X: tensor([[14, 15, 16, 17, 18],\n", - " [29, 30, 31, 32, 33]]) \n", - "Y: tensor([[15, 16, 17, 18, 19],\n", - " [30, 31, 32, 33, 34]])\n", - "i= 8\n", - "X: tensor([[ 5, 6, 7, 8, 9],\n", - " [19, 20, 21, 22, 23]]) \n", - "Y: tensor([[ 6, 7, 8, 9, 10],\n", - " [20, 21, 22, 23, 24]])\n", - "X: tensor([[10, 11, 12, 13, 14],\n", - " [24, 25, 26, 27, 28]]) \n", - "Y: tensor([[11, 12, 13, 14, 15],\n", - " [25, 26, 27, 28, 29]])\n", - "i= 9\n", - "X: tensor([[ 2, 3, 4, 5, 6],\n", - " [18, 19, 20, 21, 22]]) \n", - "Y: tensor([[ 3, 4, 5, 6, 7],\n", - " [19, 20, 21, 22, 23]])\n", - "X: tensor([[ 7, 8, 9, 10, 11],\n", - " [23, 24, 25, 26, 27]]) \n", - "Y: tensor([[ 8, 9, 10, 11, 12],\n", - " [24, 25, 26, 27, 28]])\n", - "X: tensor([[12, 13, 14, 15, 16],\n", - " [28, 29, 30, 31, 32]]) \n", - "Y: tensor([[13, 14, 15, 16, 17],\n", - " [29, 30, 31, 32, 33]])\n" - ] - } - ], + "outputs": [], "source": [ "import random\n", "\n", @@ -21865,6 +22253,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21876,6 +22265,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21887,6 +22277,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21898,6 +22289,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21909,6 +22301,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21922,6 +22315,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21933,6 +22327,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21945,6 +22340,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21956,6 +22352,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21969,6 +22366,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -21980,6 +22378,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -22003,6 +22402,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -22013,24 +22413,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { + "collapsed": false, "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([5, 2, 28])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", @@ -22047,8 +22437,9 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": { + "collapsed": false, "pycharm": { "name": "#%%\n" } @@ -22113,23 +22504,14 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": { + "collapsed": false, "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "X.shape: torch.Size([5, 2, 28])\n", - "Y_list.shape: torch.Size([5, 2, 28])\n", - "state.shape: torch.Size([2, 256])\n" - ] - } - ], + "outputs": [], "source": [ "# 打印输出形状\n", "num_hiddens = 256\n", @@ -22144,6 +22526,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -22157,6 +22540,7 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -22168,17 +22552,21 @@ { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ - "  因为语言模型的目标是根据过去和当前的词元来预测出下一个词元,在循环神经网络中,通过引入隐状态的概念,使得RNN模型可以通过隐状态保留了直到当前时间步的历史信息,根据这些过去历史信息,循环神经网络可以预测出当前词元的生成。" + "  RNN实际上是一种递归循环神经网络,即在每个时间步接收输入和前一时间步的隐藏状态,并输出当前时间步的隐藏状态。这种结构使得RNN能够捕捉序列数据中的时间依赖关系,令先前的信息得以传递到后续的时间步。 \n", + "  因为RNN的隐藏状态中包含了模型对过去所有输入信息的表示,因此在每个时间步进行更新时,RNN的隐藏状态也会随之更新,使得当前输入与上一步的隐藏状态进行结合,形成新的隐藏状态。因此RNN在处理文本序列时,可以根据隐藏状态内的历史上下文信息来增强对当前词元的理解。 \n", + "  通过不断更新隐藏状态并传递历史信息,RNN可以在某个时间步对当前词元进行建模,并输出当前词元的条件概率这使得RNN能够更好地理解文本序列中词与词之间的长期依赖关系,有助于更准确地预测或生成序列中的下一个词。" ] }, { "cell_type": "markdown", "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } @@ -22189,6 +22577,30 @@ "如果基于一个长序列进行反向传播,梯度会发生什么状况?" ] }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "**解答:**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "  基于长序列进行反向传播时,假设序列长度为$\\text{T}$,则迭代计算$\\text{T}$个时间步的梯度时,会产生长度为$O(T)$的矩阵乘法链,造成类似多层感知机中层数为$\\text{T}$的情况,因此当序列长度过长时,会使得梯度值在反向传播进行梯度累乘时放大或缩小梯度值,使得梯度值出现指数级的放大或衰减,导致数值不稳定,出现梯度爆炸或者梯度消失等现象。" + ] + }, { "cell_type": "markdown", "metadata": { @@ -22208,7 +22620,7 @@ } }, "source": [ - "  基于长序列进行反向传播时,假设序列长度为$\\text{T}$,则迭代计算$\\text{T}$个时间步的梯度时,会产生长度为$O(T)$的矩阵乘法链,造成类似多层感知机中层数为$\\text{T}$的情况,因此当序列长度过长时,会使得梯度值在反向传播进行梯度累乘时放大或缩小梯度值,导致数值不稳定,出现梯度爆炸或者梯度消失等现象。" + "  基于长序列进行反向传播时,假设序列长度为$\\text{T}$,则迭代计算$\\text{T}$个时间步的梯度时,会产生长度为$O(T)$的矩阵乘法链,造成类似多层感知机中层数为$\\text{T}$的情况,因此当序列长度过长时,会使得梯度值在反向传播进行梯度累乘时放大或缩小梯度值,使得梯度值出现指数级的放大或衰减,导致数值不稳定,出现梯度爆炸或者梯度消失等现象。" ] }, { @@ -22292,8 +22704,9 @@ } }, "source": [ - "  one-hot编码可以将词表中的每个单词编码为一个独立符号,使得任何词在与词表同样大小的one-hot向量中都具有一个独一无二的维度,实现对不同词的不同嵌入表示。 \n", - "  而嵌入表示建立起一个低维稠密向量空间,也是一种将对象转为低维连续向量的方法,可以根据向量来目标对象间的相似性和差异。因此可将二者视为等价情况,即对于一个N维嵌入向量对表示的N个词,可以将其one-hot编码等价视为嵌入表示中每个词仅有唯一一个元素非零的特殊情况。" + "  独热(one-hot)编码是对每个对象都用一个长度等于对象总数的向量来表示。每个向量的所有元素都是零,只有与对象对应的索引处的元素是1,表示该对象的存在。 \n", + "  嵌入表示(embedding)是用固定大小的稠密向量表示对象,每个对象都与唯一的嵌入向量关联。 \n", + "  如果我们考虑一组对象的独热编码,每个对象将有一个唯一的二进制向量。类似地,如果我们考虑同一组对象的嵌入表示,每个类别将有一个唯一的稠密向量。当假设嵌入表示向量和独热编码向量维度相同时,独热编码可以等价于为每个对象选择了不同的嵌入表示。" ] }, { @@ -22344,7 +22757,7 @@ "source": [ "  困惑度常用来对语言模型好坏进行评估,其公式如下:\n", "$$\\text{perplexity} = exp\\bigg(- \\displaystyle{\\frac{1}{n} \\sum_{t=1}^n \\log P(x_t|x_{t-1},...,x-1)} \\bigg)$$\n", - "  在最好的情况下,模型可以完美估计标签词元的概率为1,即仅有一个词可供模型选择,此时困惑度为1;而模型正确预测次元的概率为0时,这种最坏情况下的困惑度为无穷大。" + "  在最好的情况下,模型可以完美估计标签词元的概率为1,此时困惑度为1;而模型正确预测标签词元的概率为0时,这种最坏情况下的困惑度为无穷大。" ] }, { @@ -22360,7 +22773,18 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from d2l import torch as d2l\n", + "import torch\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "metadata": { "pycharm": { "name": "#%%\n" @@ -22378,7 +22802,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 3, "metadata": { "pycharm": { "name": "#%%\n" @@ -22441,7 +22865,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 4, "metadata": { "pycharm": { "name": "#%%\n" @@ -22454,7 +22878,7 @@ "text": [ "hyper[0]:1.1\n", "hyper[1]:1.0\n", - "hyper[2]:3.1\n", + "hyper[2]:2.9\n", "hyper[3]:8.6\n" ] } @@ -24332,7 +24756,7 @@ } }, "source": [ - "  相较于时间机器,更换数据集后,文本数据量明显增加,最终预测结果更好。" + "  相较于“时间机器”数据集,更换数据集后,文本数据量明显增加,最终预测结果更好。" ] }, { @@ -25340,6 +25764,24 @@ "1. 会发生什么?" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "  改为采样方式后,每次生成文本时都以一定概率选择不同字符,导致文本在生成时的随机性大大增加,提高了多样性,但是也导致生成文本更加不稳定,出现困惑度爆炸的情况。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from d2l import torch as d2l\n", + "import torch\n", + "from torch import nn" + ] + }, { "cell_type": "code", "execution_count": 40, @@ -26287,18 +26729,14 @@ } }, "source": [ - "  可以看到,改为采样方式后,每次生成文本时逗号以一定概率选择不同字符,导致文本在生成时的随机性大大增加,提高了多样性,但是也导致生成文本更加不稳定,出现困惑度爆炸的情况。" + "2. 调整模型使之偏向更可能的输出,例如,当$\\alpha > 1$,从$q(x_t \\mid x_{t-1}, \\ldots, x_1) \\propto P(x_t \\mid x_{t-1}, \\ldots, x_1)^\\alpha$中采样。" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ - "2. 调整模型使之偏向更可能的输出,例如,当$\\alpha > 1$,从$q(x_t \\mid x_{t-1}, \\ldots, x_1) \\propto P(x_t \\mid x_{t-1}, \\ldots, x_1)^\\alpha$中采样。" + "  从$q(x_t \\mid x_{t-1}, \\ldots, x_1) \\propto P(x_t \\mid x_{t-1}, \\ldots, x_1)^\\alpha$中进行采样,较大的$\\alpha$会使得概率较高的字符更容易被选择,以此选择概率较高的字符作为输出。这种方法降低了生成文本的多样性,但同时也使得输出结果更加稳定,困惑度更低。" ] }, { @@ -27266,17 +27704,6 @@ "train_ch8_alpha(net, train_iter, vocab, lr, num_epochs, device)" ] }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  通过设置$\\alpha$来进一步调整分布形状,较大的$\\alpha$会从$q(x_t \\mid x_{t-1}, \\ldots, x_1) \\propto P(x_t \\mid x_{t-1}, \\ldots, x_1)^\\alpha$中进行采样,通过增加概率指数来使得概率较高的字符更容易被选择,以此选择概率较高的字符作为输出。这种方法降低了生成文本的多样性,但同时也使得输出结果更加稳定,困惑度更低。" - ] - }, { "cell_type": "markdown", "metadata": { @@ -27309,10 +27736,10 @@ } }, "source": [ - "  原文内容中指出:\n", + "  根据原文内容的说明:\n", "> 对于长度为$\\text{T}$的序列,我们在迭代中计算这$\\text{T}$个时间步上的梯度,将会在反向传播过程中产生长度为$O(T)$的矩阵乘法链。当$\\text{T}$较大时,它可能导致数值不稳定,例如可能导致梯度爆炸或梯度消失。因此,循环神经网络模型往往需要额外的方式来支持稳定训练。\n", "\n", - "  因此在不裁剪梯度的情况下,会出现数值不稳定的现象。使得相较于裁剪梯度,模型在训练时的困惑度曲线出现了波动,导致曲线出现不平滑的现象。" + "  因此在不裁剪梯度的情况下,会出现数值不稳定的现象。相较于裁剪梯度,模型在训练时的困惑度曲线出现了波动,导致曲线出现不平滑的现象,且难以收敛。" ] }, { @@ -29176,7 +29603,8 @@ } }, "source": [ - "  对顺序划分进行修改,改为先对输入和标签进行设备变换和形状变换,再进行前向计算和反向传播,将隐状态的分离操作放在更新之前,避免在更新中对隐状态进行计算,无需对隐状态进行修改,因此可以避免隐状态从计算图中分离的问题。" + "  先对输入和标签进行设备(device)变换和形状(reshape)变换,再进行前向计算和反向传播,将隐状态的分离操作放在更新之前,避免了更新中对隐状态进行计算,这样无需对隐状态进行修改,即可实现了不会从计算图中分离隐状态。 \n", + "  通过实验知,这样的修改加快了词元速度,使得运行时间得到了减少,但困惑度也随之增加。" ] }, { @@ -30838,17 +31266,6 @@ "train_ch8(net, train_iter, vocab, lr, num_epochs, device)" ] }, - { - "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "  可以看到,这样的修改加快了词元速度,使得运行时间得到了减少,但困惑度也随之增加。" - ] - }, { "cell_type": "markdown", "metadata": { @@ -30881,8 +31298,9 @@ } }, "source": [ - "  使用ReLU激活函数后,无需进行梯度裁剪。因为原始RNN中采用tanh作为激活函数,而tanh会使得梯度在反向传播时不断缩小或放大,导致出现梯度爆炸或梯度消失的现象。 \n", - "  因此需要进行梯度裁剪,避免因过长的求导链出现数值不稳定的现象。而ReLU作为激活函数时,其梯度在输入为正时恒为1,输入为负时恒为0,因此在反向传播时通常不会出现数值不稳定现象,无需进行梯度裁剪。" + "  原始RNN中采用tanh作为激活函数,tanh在输入很大或很小时,导数几乎为0,累乘会导致梯度消失。 \n", + "  ReLU作为激活函数时,在输入为正时,梯度恒为1,在一定程度上ReLU缓解了梯度消失或爆炸,但由于梯度计算除了激活函数导数外,还包括权重连乘,并没有彻底解决梯度消失或爆炸问题。此外,ReLU在输入为负时,梯度恒为0,会导致一部分神经元未激活。 \n", + "  此处,ReLu作为激活函数,会缓解反向传播时出现数值不稳定的现象,可以不进行梯度裁剪。但一般仅ReLU作为激活函数这一方法,并替代不了梯度裁剪解决梯度消失或爆炸问题。" ] }, { @@ -31827,18 +32245,14 @@ } }, "source": [ - "  在高级API中,针对循环神经网络进行了一系列优化,例如nn.RNN源码中增加了可开启的Dropout模块,以及可选的ReLU激活函数,d2l.train_ch8函数中也默认开启了梯度裁剪,这些措施均可以很好的缓解模型在训练时的过拟合问题,使得模型不会轻易发生梯度爆炸。 \n", - "  但是当学习率设置过大时,例如由1增加到10,过大的学习率会使得模型在优化时无法找到最优值,反而螺旋上升,发生梯度爆炸。" + "  在高级API中,针对循环神经网络进行了一系列优化,例如nn.RNN源码中增加了可开启的dropout模块,这些措施均可以很好的缓解模型在训练时的过拟合问题。 \n", + "  但是设置不合理的参数,仍可使得模型过拟合,例如对于小的训练集,设置大的模型规模,如当隐藏单元数是1024,隐层数是3时,训练集上模型很快降到困惑度为1,但这并不意味着此处模型预测性能很好,而是发生了过拟合。" ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 24, + "metadata": {}, "outputs": [], "source": [ "import torch\n", @@ -31849,26 +32263,23 @@ "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)\n", "\n", - "num_hiddens = 256\n", - "rnn_layer = nn.RNN(len(vocab), num_hiddens)" + "num_hiddens = 1024\n", + "num_layers = 3\n", + "rnn_layer = nn.RNN(len(vocab), num_hiddens, num_layers=num_layers)" ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 25, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "perplexity 62334304447965890753852937863168.0, 372993.5 tokens/sec on cuda:0\n", - "time travellereeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n", - "travellereeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n" + "perplexity 1.0, 25271.1 tokens/sec on cuda:0\n", + "time travelleryou can show black is white by argument said filby\n", + "travelleryou can show black is white by argument said filby\n" ] }, { @@ -31877,12 +32288,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-04-17T14:50:12.918199\n", + " 2024-01-28T13:36:10.349639\n", " image/svg+xml\n", " \n", " \n", @@ -31897,42 +32308,42 @@ " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", " \n", - " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -32738,7 +33157,7 @@ "device = d2l.try_gpu()\n", "net = d2l.RNNModel(rnn_layer, vocab_size=len(vocab))\n", "net = net.to(device)\n", - "num_epochs, lr = 500, 10\n", + "num_epochs, lr = 500, 1\n", "d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)" ] }, @@ -32774,7 +33193,7 @@ } }, "source": [ - "  增加隐藏层的数量会使得模型变得更加复杂,一定程度上可以提高模型预测效果,但过于复杂的模型也会导致出现过拟合现象。例如将隐藏层由1增加到3,相较于单层隐藏层,模型的困惑度由1.2降低到1,但是token的生成速度也出现了明显下降。" + "  增加隐藏层的数量会使得模型变得更加复杂,一定程度上可以提高模型预测效果,但过于复杂的模型也会导致出现过拟合现象。当隐藏层更多时,模型训练难以收敛,发生欠拟合。例如将隐藏层由1增加到3,相较于单层隐藏层,模型的困惑度由1.2降低到1,而增加到20层时,模型困惑度训练完后反倒是17.4。同时,更多的隐藏层,token的生成速度也出现了明显下降。" ] }, { @@ -34414,7 +34833,7 @@ } }, "source": [ - "  增加隐藏层数到20时,训练速度出现明显下降,模型也发生过拟合。" + "  增加隐藏层数到20时,训练速度出现明显下降,且模型无法收敛,发生欠拟合。" ] }, { @@ -34427,7 +34846,7 @@ }, "outputs": [], "source": [ - "# 隐藏层为100\n", + "# 隐藏层为20\n", "num_hiddens = 256\n", "num_layers = 20\n", "rnn_layer = nn.RNN(len(vocab), num_hiddens, num_layers=num_layers)" @@ -39953,20 +40372,19 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ - "  对于对称矩阵$M$,其特征值为$\\lambda_i$,且拥有特征向量$v_i$,因此该对称矩阵满足:\n", - "$$M \\cdot v_i=\\lambda_i \\cdot v_i$$\n", - "  将上述公式两边同时乘上对称矩阵$M$后,公式满足:\n", - "$$M \\cdot M \\cdot v_i = M \\cdot \\lambda_i \\cdot v_i$$\n", - "$$M^2 \\cdot v_i = \\lambda_i \\cdot (M \\cdot v_i) = {\\lambda_i}^2 \\cdot v_i$$\n", - "  重复上述步骤,对于$M^k$,存在:\n", - "$$M^k \\cdot v_i = {\\lambda_i}^k \\cdot v_i$$\n", - "  证得矩阵$M^k$具有特征值${\\lambda_i}^k$" + "  首先,由于$M$是对称矩阵,它可以被对角化为:\n", + "$$M = PDP^{-1}$$\n", + "  其中,$P$ 是包含$M$的特征向量的矩阵,而$D$是对角矩阵,其对角线上的元素是$M$的特征值。 \n", + "  现在,考虑$M^k$,我们有:\n", + "$$M^k = (PDP^{-1})^k = PD^kP^{-1}$$\n", + "  $D^k$是对角矩阵,其对角线上的元素是$M$的特征值$\\lambda$的$k$次方$\\lambda^k$。因此,我们可以得到:\n", + "$$D^k = \\begin{bmatrix} \\lambda_1^k & 0 & \\cdots & 0 \\\\ 0 & \\lambda_2^k & \\cdots & 0 \\\\ \\vdots & \\vdots & \\ddots & \\vdots \\\\ 0 & 0 & \\cdots & \\lambda_n^k \\end{bmatrix}$$\n", + "  其中,$\\lambda_i$是$M$的第$i$个特征值。 \n", + "  最后,我们得到$M^k$的对角化形式:\n", + "$$M^k = PD^kP^{-1}$$\n", + "  这意味着矩阵$M^k$与矩阵$M$具有相同的特征向量,而其特征值是$M$的特征值$\\lambda$的$k$次方$\\lambda^k$。因此,我们证明了如果$\\lambda$是$M$的特征值,那么$\\lambda^k$就是$M^k$的特征值。" ] }, { @@ -39989,7 +40407,7 @@ }, "source": [ "  对于随机向量$x \\in R^n$,可将其分解到特征向量$V$所在的向量空间中,即对于分解系数$\\alpha_i$,存在分解式:\n", - "$$x = \\alpha_1 v1+\\alpha_2 v2+ ... + \\alpha_n v_n = \\sum_{i=1}^n \\alpha_i v_i$$\n", + "$$x = \\alpha_1 v_1+\\alpha_2 v_2+ ... + \\alpha_n v_n = \\sum_{i=1}^n \\alpha_i v_i$$\n", "  将上述公式两边同乘$M^k$,存在:\n", "$$\\begin{align}\n", "M^k \\cdot x\n", @@ -39997,7 +40415,7 @@ "&= \\lambda_i^k \\cdot \\sum_{i=1}^n \\alpha_i v_i \\\\\n", "&= \\sum_{i=1}^n \\lambda_i^k \\alpha_i v_i\n", "\\end{align}$$\n", - "  又因为$M$的特征值$\\lambda_i$满足$|\\lambda_i| \\geq |\\lambda_{i+1}|$,因此$lambda_1^k >> lambda_i$,即$\\lambda_1^k$的权重最大。 \n", + "  又因为$M$的特征值$\\lambda_i$满足$|\\lambda_i| \\geq |\\lambda_{i+1}|$,因此$\\lambda_1^k$指数级最大,即$v_1$的权重最大。 \n", "  因此$M^k \\cdot x \\approx \\lambda_1^k \\alpha_1 v_1$,即存在较高概率与特征向量$v_1$在一条直线上。" ] }, @@ -40061,7 +40479,8 @@ "  除了进行梯度截断,还可采用下述方法解决循环神经网络中的梯度爆炸:\n", "- 长短期记忆网络:可以通过引入门控机制或注意力机制的方法,减少梯度异常发生的概率,对循环神经网络进行改进。\n", "- 对权重进行正则化,通过添加L1或L2正则项减少梯度爆炸发生的概率\n", - "- 修改激活函数,将激活函数更换为ReLU,可以有效避免梯度爆炸问题" + "- 修改激活函数,将激活函数更换为ReLU,可以有效减缓梯度爆炸问题\n", + "- 使用批次归一化(Batch Normalization)" ] } ], @@ -40081,9 +40500,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 1 -} \ No newline at end of file +}