{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 402 RNN\n",
"\n",
"View more, visit my tutorial page: https://mofanpy.com/tutorials/\n",
"My Youtube Channel: https://www.youtube.com/user/MorvanZhou\n",
"\n",
"Dependencies:\n",
"* torch: 0.1.11\n",
"* matplotlib\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"import torchvision.datasets as dsets\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f5864059930>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.manual_seed(1) # reproducible"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Hyper Parameters\n",
"EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch\n",
"BATCH_SIZE = 64\n",
"TIME_STEP = 28 # rnn time step / image height\n",
"INPUT_SIZE = 28 # rnn input size / image width\n",
"LR = 0.01 # learning rate\n",
"DOWNLOAD_MNIST = True # set to True if haven't download the data"
]
},
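{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the two shape parameters: each 28x28 MNIST image is treated as a sequence of 28 rows, so `TIME_STEP` is the image height (one row per time step) and `INPUT_SIZE` is the image width (28 pixel values per step). A minimal sketch of that mapping (not part of the tutorial; the variable names are just for illustration):\n",
"\n",
"```python\n",
"import torch\n",
"image = torch.zeros(28, 28)        # (image height, image width)\n",
"sequence = image.view(1, 28, 28)   # (batch=1, time_step=28, input_size=28) -- what the RNN expects\n",
"```"
]
},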
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Mnist digital dataset\n",
"train_data = dsets.MNIST(\n",
" root='./mnist/',\n",
" train=True, # this is training data\n",
" transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to\n",
" # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]\n",
" download=DOWNLOAD_MNIST, # download it if you don't have it\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([60000, 28, 28])\n",
"torch.Size([60000])\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADrpJREFUeJzt3X2sVHV+x/HPp6hpxAekpkhYLYsxGDWWbRAbQ1aNYX2I\nRlFjltSERiP7hyRu0pAa+sdqWqypD81SzQY26kKzdd1EjehufKiobGtCvCIq4qKu0SzkCjWIAj5Q\nuN/+cYftXb3zm8vMmTnD/b5fyeTOnO+cOd+c8OE8zvwcEQKQz5/U3QCAehB+ICnCDyRF+IGkCD+Q\nFOEHkiL8QFKEH6Oy/aLtL23vaTy21N0TqkX4UbI4Io5pPGbW3QyqRfiBpAg/Sv7Z9se2/9v2BXU3\ng2qZe/sxGtvnStosaZ+k70u6T9KsiPhdrY2hMoQfY2L7aUm/ioh/q7sXVIPdfoxVSHLdTaA6hB/f\nYHuS7Ytt/6ntI2z/jaTvSnq67t5QnSPqbgB96UhJ/yTpdEkHJP1W0lUR8U6tXaFSHPMDSbHbDyRF\n+IGkCD+QFOEHkurp2X7bnF0EuiwixnQ/RkdbftuX2N5i+z3bt3byWQB6q+1LfbYnSHpH0jxJWyW9\nImlBRGwuzMOWH+iyXmz550h6LyLej4h9kn4h6coOPg9AD3US/mmSfj/i9dbGtD9ie5HtAdsDHSwL\nQMW6fsIvIlZKWimx2w/0k062/NsknTzi9bca0wAcBjoJ/yuSTrP9bdtHafgHH9ZU0xaAbmt7tz8i\n9tteLOkZSRMkPRgRb1XWGYCu6um3+jjmB7qvJzf5ADh8EX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQf\nSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKE\nH0iK8ANJEX4gKcIPJEX4gaQIP5BU20N04/AwYcKEYv3444/v6vIXL17ctHb00UcX5505c2axfvPN\nNxfrd999d9PaggULivN++eWXxfqdd95ZrN9+++3Fej/oKPy2P5C0W9IBSfsjYnYVTQHoviq2/BdG\nxMcVfA6AHuKYH0iq0/CHpGdtv2p70WhvsL3I9oDtgQ6XBaBCne72z42Ibbb/XNJztn8bEetGviEi\nVkpaKUm2o8PlAahIR1v+iNjW+LtD0uOS5lTRFIDuazv8tifaPvbgc0nfk7SpqsYAdFcnu/1TJD1u\n++Dn/EdEPF1JV+PMKaecUqwfddRRxfp5551XrM+dO7dpbdKkScV5r7nmmmK9Tlu3bi3Wly9fXqzP\nnz+/aW337t3FeV9//fVi/aWXXirWDwdthz8i3pf0lxX2AqCHuNQHJEX4gaQIP5AU4QeSIvxAUo7o\n3U134/UOv1mzZhXra9euLda7/bXafjU0NFSs33DDDcX6nj172l724OBgsf7JJ58U61u2bGl72d0W\nER7L+9jyA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSXOevwOTJk4v19evXF+szZsyosp1Ktep9165d\nxfqFF17YtLZv377ivFnvf+gU1/kBFBF+ICnCDyRF+IGkCD+QFOEHkiL8QFIM0V2BnTt3FutLliwp\n1i+//PJi/bXXXivWW/2EdcnGjRuL9Xnz5hXre/fuLdbPPPPMprVbbrmlOC+6iy0/kBThB5Ii/EBS\nhB9IivADSRF+ICnCDyTF9/n7wHHHHVestxpOesWKFU1rN954Y3He66+/vlh/+OGHi3X0n8q+z2/7\nQds7bG8aMW2y7edsv9v4e0InzQLovbHs9v9M0iVfm3arpOcj4jRJzzdeAziMtAx/RKyT9PX7V6+U\ntKrxfJWkqyruC0CXtXtv/5SIODjY2UeSpjR7o+1Fkha1uRwAXdLxF3siIkon8iJipaSVEif8gH7S\n7qW+7banSlLj747qWgLQC+2Gf42khY3nCyU9UU07AHql5W6/7YclXSDpRNtbJf1I0p2Sfmn7Rkkf\nSrqum02Od5999llH83/66adtz3vTTTcV64888kixPjQ01PayUa+W4Y+IBU1KF1XcC4Ae4vZeICnC\nDyRF+IGkCD+QFOEHkuIrvePAxIkTm9aefPLJ4rznn39+sX7ppZcW688++2yxjt5jiG4ARYQfSIrw\nA0kRfiApwg8kRfiBpAg/kBTX+ce5U089tVjfsGFDsb5r165i/YUXXijWBwYGmtbuv//+4ry9/Lc5\nnnCdH0AR4QeSIvxAUoQfSIrwA0kRfiApwg8kxXX+5ObPn1+sP/TQQ8X6scce2/ayly5dWqyvXr26\nWB8cHCzWs+I6P4Aiwg8kRfiBpAg/kBThB5Ii/EBShB9Iiuv8KDrrrLOK9XvvvbdYv+ii9gdzXrFi\nRbG+bNmyYn3btm1tL/twVtl1ftsP2t5he9OIabfZ3mZ7Y+NxWSfNAui9sez2/0zSJaNM/9eImNV4\n/LratgB0W8vwR8Q6STt70AuAHurkhN9i2280DgtOaPYm24tsD9hu/mNuAHqu3fD/RNKpkmZJGpR0\nT7M3RsTKiJgdEbPbXBaALmgr/BGxPSIORMSQpJ9KmlNtWwC6ra3w25464uV8SZuavRdAf2p5nd/2\nw5IukHSipO2SftR4PUtSSPpA0g8iouWXq7nOP/5MmjSpWL/iiiua1lr9VoBdvly9du3aYn3evHnF\n+ng11uv8R4zhgxaMMvmBQ+4IQF/h9l4gKcIPJEX4gaQIP5AU4QeS4iu9qM1XX31VrB9xRPli1P79\n+4v1iy++uGntxRdfLM57OOOnuwEUEX4gKcIPJEX4gaQIP5AU4QeSIvxAUi2/1Yfczj777GL92muv\nLdbPOeecprVW1/Fb2bx5c7G+bt26jj5/vGPLDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcZ1/nJs5\nc2axvnjx4mL96quvLtZPOumkQ+5prA4cOFCsDw6Wfy1+aGioynbGHbb8QFKEH0iK8ANJEX4gKcIP\nJEX4gaQIP5BUy+v8tk+WtFrSFA0Pyb0yIn5se7KkRyRN1/Aw3ddFxCfdazWvVtfSFywYbSDlYa2u\n40+fPr2dlioxMDBQrC9btqxYX7NmTZXtpDOWLf9+SX8XEWdI+mtJN9s+Q9Ktkp6PiNMkPd94DeAw\n0TL8ETEYERsaz3dLelvSNElXSlrVeNsqSVd1q0kA1TukY37b0yV9R9J6SVMi4uD9lR9p+LAAwGFi\nzPf22z5G0qOSfhgRn9n/PxxYRESzcfhsL5K0qNNGAVRrTFt+20dqOPg/j4jHGpO3257aqE+VtGO0\neSNiZUTMjojZVTQMoBotw+/hTfwDkt6OiHtHlNZIWth4vlDSE9W3B6BbWg7RbXuupN9IelPSwe9I\nLtXwcf8vJZ0i6UMNX+rb2eKzUg7RPWVK+XTIGWecUazfd999xfrpp59+yD1VZf369cX6XXfd1bT2\nxBPl7QVfyW3PWIfobnnMHxH/JanZh11
0KE0B6B/c4QckRfiBpAg/kBThB5Ii/EBShB9Iip/uHqPJ\nkyc3ra1YsaI476xZs4r1GTNmtNVTFV5++eVi/Z577inWn3nmmWL9iy++OOSe0Bts+YGkCD+QFOEH\nkiL8QFKEH0iK8ANJEX4gqTTX+c8999xifcmSJcX6nDlzmtamTZvWVk9V+fzzz5vWli9fXpz3jjvu\nKNb37t3bVk/of2z5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpNNf558+f31G9E5s3by7Wn3rqqWJ9\n//79xXrpO/e7du0qzou82PIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKOiPIb7JMlrZY0RVJIWhkR\nP7Z9m6SbJP1P461LI+LXLT6rvDAAHYsIj+V9Ywn/VElTI2KD7WMlvSrpKknXSdoTEXePtSnCD3Tf\nWMPf8g6/iBiUNNh4vtv225Lq/ekaAB07pGN+29MlfUfS+sakxbbfsP2g7ROazLPI9oDtgY46BVCp\nlrv9f3ijfYyklyQti4jHbE+R9LGGzwP8o4YPDW5o8Rns9gNdVtkxvyTZPlLSU5KeiYh7R6lPl/RU\nRJzV4nMIP9BlYw1/y91+25b0gKS3Rwa/cSLwoPmSNh1qkwDqM5az/XMl/UbSm5KGGpOXSlogaZaG\nd/s/kPSDxsnB0mex5Qe6rNLd/qoQfqD7KtvtBzA+EX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrw\nA0kRfiApwg8kRfiBpAg/kBThB5Lq9RDdH0v6cMTrExvT+lG/9tavfUn01q4qe/uLsb6xp9/n/8bC\n7YGImF1bAwX92lu/9iXRW7vq6o3dfiApwg8kVXf4V9a8/JJ+7a1f+5LorV219FbrMT+A+tS95QdQ\nE8IPJFVL+G1fYnuL7fds31pHD83Y/sD2m7Y31j2+YGMMxB22N42YNtn2c7bfbfwddYzEmnq7zfa2\nxrrbaPuymno72fYLtjfbfsv2LY3pta67Ql+1rLeeH/PbniDpHUnzJG2V9IqkBRGxuaeNNGH7A0mz\nI6L2G0Jsf1fSHkmrDw6FZvtfJO2MiDsb/3GeEBF/3ye93aZDHLa9S701G1b+b1XjuqtyuPsq1LHl\nnyPpvYh4PyL2SfqFpCtr6KPvRcQ6STu/NvlKSasaz1dp+B9PzzXprS9ExGBEbGg83y3p4LDyta67\nQl+1qCP80yT9fsTrrapxBYwiJD1r+1Xbi+puZhRTRgyL9pGkKXU2M4qWw7b30teGle+bddfOcPdV\n44TfN82NiL+SdKmkmxu7t30pho/Z+ula7U8knarhMRwHJd1TZzONYeUflfTDiPhsZK3OdTdKX7Ws\ntzrCv03SySNef6sxrS9ExLbG3x2SHtfwYUo/2X5whOTG3x019/MHEbE9Ig5ExJCkn6rGddcYVv5R\nST+PiMcak2tfd6P1Vdd6qyP8r0g6zfa3bR8l6fuS1tTQxzfYntg4ESPbEyV9T/039PgaSQsbzxdK\neqLGXv5Ivwzb3mxYedW87vpuuPuI6PlD0mUaPuP/O0n/UEcPTfqaIen1xuOtunuT9LCGdwP/V8Pn\nRm6U9GeSnpf0rqT/lDS5j3r7dw0P5f6GhoM2tabe5mp4l/4NSRsbj8vqXneFvmpZb9zeCyTFCT8g\nKcIPJEX4gaQIP5AU4QeSIvxAUoQfSOr/AH6evjIXWuv8AAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f580f9626a0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot one example\n",
"print(train_data.train_data.size()) # (60000, 28, 28)\n",
"print(train_data.train_labels.size()) # (60000)\n",
"plt.imshow(train_data.train_data[0].numpy(), cmap='gray')\n",
"plt.title('%i' % train_data.train_labels[0])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Data Loader for easy mini-batch return in training\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# convert test data into Variable, pick 2000 samples to speed up testing\n",
"test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())\n",
"test_x = Variable(test_data.test_data, volatile=True).type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)\n",
"test_y = test_data.test_labels.numpy().squeeze()[:2000] # covert to numpy array"
]
},
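{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why divide by 255? `test_data.test_data` holds the raw images as `uint8` values in `[0, 255]`, while `transforms.ToTensor()` already scales the training images to `[0, 1]`, so the division keeps both sets on the same scale. A small sanity check (a sketch, assuming the dataset above has been downloaded):\n",
"\n",
"```python\n",
"print(test_data.test_data.type())   # torch.ByteTensor, values 0-255\n",
"print(test_x.data.max())            # 1.0 after the conversion above\n",
"```"
]
},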
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class RNN(nn.Module):\n",
" def __init__(self):\n",
" super(RNN, self).__init__()\n",
"\n",
" self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns\n",
" input_size=INPUT_SIZE,\n",
" hidden_size=64, # rnn hidden unit\n",
" num_layers=1, # number of rnn layer\n",
" batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)\n",
" )\n",
"\n",
" self.out = nn.Linear(64, 10)\n",
"\n",
" def forward(self, x):\n",
" # x shape (batch, time_step, input_size)\n",
" # r_out shape (batch, time_step, output_size)\n",
" # h_n shape (n_layers, batch, hidden_size)\n",
" # h_c shape (n_layers, batch, hidden_size)\n",
" r_out, (h_n, h_c) = self.rnn(x, None) # None represents zero initial hidden state\n",
"\n",
" # choose r_out at the last time step\n",
" out = self.out(r_out[:, -1, :])\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RNN (\n",
" (rnn): LSTM(28, 64, batch_first=True)\n",
" (out): Linear (64 -> 10)\n",
")\n"
]
}
],
"source": [
"rnn = RNN()\n",
"print(rnn)"
]
},
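{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check of the forward pass may help: for this single-layer, unidirectional LSTM, `r_out` stacks the hidden state at every time step, and its last slice `r_out[:, -1, :]` is the same vector as `h_n[-1]`. A hedged sketch with a dummy batch (the names here are just for illustration):\n",
"\n",
"```python\n",
"dummy = Variable(torch.zeros(2, TIME_STEP, INPUT_SIZE))   # (batch=2, time_step=28, input_size=28)\n",
"r_out, (h_n, h_c) = rnn.rnn(dummy, None)\n",
"print(r_out.size())   # (2, 28, 64) -- hidden state at every time step\n",
"print(h_n.size())     # (1, 2, 64)  -- final hidden state; h_n[-1] equals r_out[:, -1, :] here\n",
"```"
]
},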
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters\n",
"loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted"
]
},
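{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the comment above says, `nn.CrossEntropyLoss` takes the raw class scores (logits) from the network and integer class labels, not one-hot vectors; it applies log-softmax and negative log-likelihood internally. A minimal sketch with made-up numbers:\n",
"\n",
"```python\n",
"logits = Variable(torch.randn(3, 10))           # hypothetical scores for 3 samples, 10 classes\n",
"labels = Variable(torch.LongTensor([1, 0, 4]))  # integer class indices, no one-hot encoding\n",
"print(nn.CrossEntropyLoss()(logits, labels))    # loss averaged over the batch\n",
"```"
]
},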
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 | train loss: 2.3127 | test accuracy: 0.14\n",
"Epoch: 0 | train loss: 1.1182 | test accuracy: 0.56\n",
"Epoch: 0 | train loss: 0.9374 | test accuracy: 0.68\n",
"Epoch: 0 | train loss: 0.6279 | test accuracy: 0.75\n",
"Epoch: 0 | train loss: 0.8004 | test accuracy: 0.75\n",
"Epoch: 0 | train loss: 0.3407 | test accuracy: 0.88\n",
"Epoch: 0 | train loss: 0.3342 | test accuracy: 0.89\n",
"Epoch: 0 | train loss: 0.3200 | test accuracy: 0.91\n",
"Epoch: 0 | train loss: 0.3430 | test accuracy: 0.92\n",
"Epoch: 0 | train loss: 0.1234 | test accuracy: 0.93\n",
"Epoch: 0 | train loss: 0.2714 | test accuracy: 0.92\n",
"Epoch: 0 | train loss: 0.1043 | test accuracy: 0.93\n",
"Epoch: 0 | train loss: 0.2893 | test accuracy: 0.93\n",
"Epoch: 0 | train loss: 0.0635 | test accuracy: 0.94\n",
"Epoch: 0 | train loss: 0.2412 | test accuracy: 0.94\n",
"Epoch: 0 | train loss: 0.1203 | test accuracy: 0.92\n",
"Epoch: 0 | train loss: 0.0807 | test accuracy: 0.94\n",
"Epoch: 0 | train loss: 0.1307 | test accuracy: 0.94\n",
"Epoch: 0 | train loss: 0.1166 | test accuracy: 0.95\n"
]
}
],
"source": [
"# training and testing\n",
"for epoch in range(EPOCH):\n",
" for step, (x, y) in enumerate(train_loader): # gives batch data\n",
" b_x = Variable(x.view(-1, 28, 28)) # reshape x to (batch, time_step, input_size)\n",
" b_y = Variable(y) # batch y\n",
"\n",
" output = rnn(b_x) # rnn output\n",
" loss = loss_func(output, b_y) # cross entropy loss\n",
" optimizer.zero_grad() # clear gradients for this training step\n",
" loss.backward() # backpropagation, compute gradients\n",
" optimizer.step() # apply gradients\n",
"\n",
" if step % 50 == 0:\n",
" test_output = rnn(test_x) # (samples, time_step, input_size)\n",
" pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()\n",
" accuracy = sum(pred_y == test_y) / float(test_y.size)\n",
" print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)\n"
]
},
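{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two details of the loop above: `x` comes out of the `DataLoader` with shape `(batch, 1, 28, 28)` (a channel dimension added by `ToTensor()`), so `x.view(-1, 28, 28)` drops the channel to match the RNN's `(batch, time_step, input_size)` input; and the reported accuracy is simply the fraction of matching labels. A tiny illustration of that accuracy line with made-up arrays:\n",
"\n",
"```python\n",
"import numpy as np\n",
"pred = np.array([7, 2, 1, 0])\n",
"true = np.array([7, 2, 3, 0])\n",
"print(sum(pred == true) / float(true.size))   # 0.75 -- 3 of the 4 predictions match\n",
"```"
]
},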
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[7 2 1 0 4 1 4 9 5 9] prediction number\n",
"[7 2 1 0 4 1 4 9 5 9] real number\n"
]
}
],
"source": [
"# print 10 predictions from test data\n",
"test_output = rnn(test_x[:10].view(-1, 28, 28))\n",
"pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()\n",
"print(pred_y, 'prediction number')\n",
"print(test_y[:10], 'real number')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}