From 7102922f2cfe3f82f3b9f989f02fd7ed5d8abaa5 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 15:53:27 -0800 Subject: [PATCH 01/19] Replace tutorial with fitting sine function with third order polynomial --- ...net_autograd.py => polynomial_autograd.py} | 50 ++++----- .../polynomial_custom_function.py | 104 ++++++++++++++++++ .../examples_autograd/tf_two_layer_net.py | 79 ------------- .../two_layer_net_custom_function.py | 97 ---------------- .../examples_tensor/polynomial_numpy.py | 53 +++++++++ ...yer_net_tensor.py => polynomial_tensor.py} | 43 ++++---- .../examples_tensor/two_layer_net_numpy.py | 51 --------- beginner_source/pytorch_with_examples.rst | 31 +++--- 8 files changed, 221 insertions(+), 287 deletions(-) rename beginner_source/examples_autograd/{two_layer_net_autograd.py => polynomial_autograd.py} (63%) create mode 100755 beginner_source/examples_autograd/polynomial_custom_function.py delete mode 100755 beginner_source/examples_autograd/tf_two_layer_net.py delete mode 100755 beginner_source/examples_autograd/two_layer_net_custom_function.py create mode 100755 beginner_source/examples_tensor/polynomial_numpy.py rename beginner_source/examples_tensor/{two_layer_net_tensor.py => polynomial_tensor.py} (56%) delete mode 100755 beginner_source/examples_tensor/two_layer_net_numpy.py diff --git a/beginner_source/examples_autograd/two_layer_net_autograd.py b/beginner_source/examples_autograd/polynomial_autograd.py similarity index 63% rename from beginner_source/examples_autograd/two_layer_net_autograd.py rename to beginner_source/examples_autograd/polynomial_autograd.py index ebbc98b2bb8..bd423ae6244 100755 --- a/beginner_source/examples_autograd/two_layer_net_autograd.py +++ b/beginner_source/examples_autograd/polynomial_autograd.py @@ -3,8 +3,8 @@ PyTorch: Tensors and autograd ------------------------------- -A fully-connected ReLU network with one hidden layer and no biases, trained to -predict y from x by minimizing squared Euclidean distance. +A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` +to :math:`pi` by minimizing squared Euclidean distance. This implementation computes the forward pass using operations on PyTorch Tensors, and uses PyTorch autograd to compute gradients. @@ -15,42 +15,34 @@ holding the gradient of ``x`` with respect to some scalar value. """ import torch +import math dtype = torch.float device = torch.device("cpu") # device = torch.device("cuda:0") # Uncomment this to run on GPU -# torch.backends.cuda.matmul.allow_tf32 = False # Uncomment this to run on GPU -# The above line disables TensorFloat32. This a feature that allows -# networks to run at a much faster speed while sacrificing precision. -# Although TensorFloat32 works well on most real models, for our toy model -# in this tutorial, the sacrificed precision causes convergence issue. -# For more information, see: -# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 - -# Create random Tensors to hold input and outputs. +# Create Tensors to hold input and outputs. # Setting requires_grad=False indicates that we do not need to compute gradients # with respect to these Tensors during the backward pass. -x = torch.randn(N, D_in, device=device, dtype=dtype) -y = torch.randn(N, D_out, device=device, dtype=dtype) +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) -# Create random Tensors for weights. +# Create random Tensors for weights. For a third order polynomial, we need +# 4 weights: y = a + b x + c x^2 + d x^3 # Setting requires_grad=True indicates that we want to compute gradients with # respect to these Tensors during the backward pass. -w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True) -w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True) +a = torch.randn((), device=device, dtype=dtype, requires_grad=True) +b = torch.randn((), device=device, dtype=dtype, requires_grad=True) +c = torch.randn((), device=device, dtype=dtype, requires_grad=True) +d = torch.randn((), device=device, dtype=dtype, requires_grad=True) learning_rate = 1e-6 -for t in range(500): +for t in range(2000): # Forward pass: compute predicted y using operations on Tensors; these # are exactly the same operations we used to compute the forward pass using # Tensors, but we do not need to keep references to intermediate values since # we are not implementing the backward pass by hand. - y_pred = x.mm(w1).clamp(min=0).mm(w2) + y_pred = a + b * x + c * x ** 2 + d * x ** 3 # Compute and print loss using operations on Tensors. # Now loss is a Tensor of shape (1,) @@ -73,9 +65,15 @@ # tensor, but doesn't track history. # You can also use torch.optim.SGD to achieve this. with torch.no_grad(): - w1 -= learning_rate * w1.grad - w2 -= learning_rate * w2.grad + a -= learning_rate * a.grad + b -= learning_rate * b.grad + c -= learning_rate * c.grad + d -= learning_rate * d.grad # Manually zero the gradients after updating weights - w1.grad.zero_() - w2.grad.zero_() + a.grad = None + b.grad = None + c.grad = None + d.grad = None + +print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') diff --git a/beginner_source/examples_autograd/polynomial_custom_function.py b/beginner_source/examples_autograd/polynomial_custom_function.py new file mode 100755 index 00000000000..894a8c6fb05 --- /dev/null +++ b/beginner_source/examples_autograd/polynomial_custom_function.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +""" +PyTorch: Defining New autograd Functions +---------------------------------------- + +A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` +to :math:`pi` by minimizing squared Euclidean distance. Instead of writing the +polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as +:math:`y=a+b*P_3(c+dx)` where :math:`P_3(x)=\frac{1/2}\left(5x^3-3x\right)` is +the `Legendre polynomial`_ of degree three. + +.. _Legendre polynomial: + https://en.wikipedia.org/wiki/Legendre_polynomials + +This implementation computes the forward pass using operations on PyTorch +Tensors, and uses PyTorch autograd to compute gradients. + +In this implementation we implement our own custom autograd function to perform +:math:`P_3'(x)`. By mathematics, :math:`P_3'(x)=\frac{3/2}\left(5x^2-1\right)` +""" +import torch +import math + + +class LegendrePolynomial3(torch.autograd.Function): + """ + We can implement our own custom autograd Functions by subclassing + torch.autograd.Function and implementing the forward and backward passes + which operate on Tensors. + """ + + @staticmethod + def forward(ctx, input): + """ + In the forward pass we receive a Tensor containing the input and return + a Tensor containing the output. ctx is a context object that can be used + to stash information for backward computation. You can cache arbitrary + objects for use in the backward pass using the ctx.save_for_backward method. + """ + ctx.save_for_backward(input) + return 0.5 * (5 * input ** 3 - 3 * input) + + @staticmethod + def backward(ctx, grad_output): + """ + In the backward pass we receive a Tensor containing the gradient of the loss + with respect to the output, and we need to compute the gradient of the loss + with respect to the input. + """ + input, = ctx.saved_tensors + return grad_output * 1.5 * (5 * input ** 2 - 1) + + +dtype = torch.float +device = torch.device("cpu") +# device = torch.device("cuda:0") # Uncomment this to run on GPU + +# Create Tensors to hold input and outputs. +# Setting requires_grad=False indicates that we do not need to compute gradients +# with respect to these Tensors during the backward pass. +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) + +# Create random Tensors for weights. For this example, we need +# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized +# not too far from the correct result to ensure convergence. +# Setting requires_grad=True indicates that we want to compute gradients with +# respect to these Tensors during the backward pass. +a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) +b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True) +c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) +d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) + +learning_rate = 5e-6 +for t in range(2000): + # To apply our Function, we use Function.apply method. We alias this as 'P3'. + P3 = LegendrePolynomial3.apply + + # Forward pass: compute predicted y using operations; we compute + # P3 using our custom autograd operation. + y_pred = a + b * P3(c + d * x) + + # Compute and print loss + loss = (y_pred - y).pow(2).sum() + if t % 100 == 99: + print(t, loss.item()) + + # Use autograd to compute the backward pass. + loss.backward() + + # Update weights using gradient descent + with torch.no_grad(): + a -= learning_rate * a.grad + b -= learning_rate * b.grad + c -= learning_rate * c.grad + d -= learning_rate * d.grad + + # Manually zero the gradients after updating weights + a.grad = None + b.grad = None + c.grad = None + d.grad = None + +print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)') diff --git a/beginner_source/examples_autograd/tf_two_layer_net.py b/beginner_source/examples_autograd/tf_two_layer_net.py deleted file mode 100755 index 1caf36e89f4..00000000000 --- a/beginner_source/examples_autograd/tf_two_layer_net.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -""" -TensorFlow: Static Graphs -------------------------- - -A fully-connected ReLU network with one hidden layer and no biases, trained to -predict y from x by minimizing squared Euclidean distance. - -This implementation uses basic TensorFlow operations to set up a computational -graph, then executes the graph many times to actually train the network. - -One of the main differences between TensorFlow and PyTorch is that TensorFlow -uses static computational graphs while PyTorch uses dynamic computational -graphs. - -In TensorFlow we first set up the computational graph, then execute the same -graph many times. -""" -import tensorflow as tf -import numpy as np - -# First we set up the computational graph: - -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 - -# Create placeholders for the input and target data; these will be filled -# with real data when we execute the graph. -x = tf.placeholder(tf.float32, shape=(None, D_in)) -y = tf.placeholder(tf.float32, shape=(None, D_out)) - -# Create Variables for the weights and initialize them with random data. -# A TensorFlow Variable persists its value across executions of the graph. -w1 = tf.Variable(tf.random_normal((D_in, H))) -w2 = tf.Variable(tf.random_normal((H, D_out))) - -# Forward pass: Compute the predicted y using operations on TensorFlow Tensors. -# Note that this code does not actually perform any numeric operations; it -# merely sets up the computational graph that we will later execute. -h = tf.matmul(x, w1) -h_relu = tf.maximum(h, tf.zeros(1)) -y_pred = tf.matmul(h_relu, w2) - -# Compute loss using operations on TensorFlow Tensors -loss = tf.reduce_sum((y - y_pred) ** 2.0) - -# Compute gradient of the loss with respect to w1 and w2. -grad_w1, grad_w2 = tf.gradients(loss, [w1, w2]) - -# Update the weights using gradient descent. To actually update the weights -# we need to evaluate new_w1 and new_w2 when executing the graph. Note that -# in TensorFlow the the act of updating the value of the weights is part of -# the computational graph; in PyTorch this happens outside the computational -# graph. -learning_rate = 1e-6 -new_w1 = w1.assign(w1 - learning_rate * grad_w1) -new_w2 = w2.assign(w2 - learning_rate * grad_w2) - -# Now we have built our computational graph, so we enter a TensorFlow session to -# actually execute the graph. -with tf.Session() as sess: - # Run the graph once to initialize the Variables w1 and w2. - sess.run(tf.global_variables_initializer()) - - # Create numpy arrays holding the actual data for the inputs x and targets - # y - x_value = np.random.randn(N, D_in) - y_value = np.random.randn(N, D_out) - for t in range(500): - # Execute the graph many times. Each time it executes we want to bind - # x_value to x and y_value to y, specified with the feed_dict argument. - # Each time we execute the graph we want to compute the values for loss, - # new_w1, and new_w2; the values of these Tensors are returned as numpy - # arrays. - loss_value, _, _ = sess.run([loss, new_w1, new_w2], - feed_dict={x: x_value, y: y_value}) - if t % 100 == 99: - print(t, loss_value) diff --git a/beginner_source/examples_autograd/two_layer_net_custom_function.py b/beginner_source/examples_autograd/two_layer_net_custom_function.py deleted file mode 100755 index 2d2a0875669..00000000000 --- a/beginner_source/examples_autograd/two_layer_net_custom_function.py +++ /dev/null @@ -1,97 +0,0 @@ -# -*- coding: utf-8 -*- -""" -PyTorch: Defining New autograd Functions ----------------------------------------- - -A fully-connected ReLU network with one hidden layer and no biases, trained to -predict y from x by minimizing squared Euclidean distance. - -This implementation computes the forward pass using operations on PyTorch -Variables, and uses PyTorch autograd to compute gradients. - -In this implementation we implement our own custom autograd function to perform -the ReLU function. -""" -import torch - - -class MyReLU(torch.autograd.Function): - """ - We can implement our own custom autograd Functions by subclassing - torch.autograd.Function and implementing the forward and backward passes - which operate on Tensors. - """ - - @staticmethod - def forward(ctx, input): - """ - In the forward pass we receive a Tensor containing the input and return - a Tensor containing the output. ctx is a context object that can be used - to stash information for backward computation. You can cache arbitrary - objects for use in the backward pass using the ctx.save_for_backward method. - """ - ctx.save_for_backward(input) - return input.clamp(min=0) - - @staticmethod - def backward(ctx, grad_output): - """ - In the backward pass we receive a Tensor containing the gradient of the loss - with respect to the output, and we need to compute the gradient of the loss - with respect to the input. - """ - input, = ctx.saved_tensors - grad_input = grad_output.clone() - grad_input[input < 0] = 0 - return grad_input - - -dtype = torch.float -device = torch.device("cpu") -# device = torch.device("cuda:0") # Uncomment this to run on GPU -# torch.backends.cuda.matmul.allow_tf32 = False # Uncomment this to run on GPU - -# The above line disables TensorFloat32. This a feature that allows -# networks to run at a much faster speed while sacrificing precision. -# Although TensorFloat32 works well on most real models, for our toy model -# in this tutorial, the sacrificed precision causes convergence issue. -# For more information, see: -# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 - -# Create random Tensors to hold input and outputs. -x = torch.randn(N, D_in, device=device, dtype=dtype) -y = torch.randn(N, D_out, device=device, dtype=dtype) - -# Create random Tensors for weights. -w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True) -w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True) - -learning_rate = 1e-6 -for t in range(500): - # To apply our Function, we use Function.apply method. We alias this as 'relu'. - relu = MyReLU.apply - - # Forward pass: compute predicted y using operations; we compute - # ReLU using our custom autograd operation. - y_pred = relu(x.mm(w1)).mm(w2) - - # Compute and print loss - loss = (y_pred - y).pow(2).sum() - if t % 100 == 99: - print(t, loss.item()) - - # Use autograd to compute the backward pass. - loss.backward() - - # Update weights using gradient descent - with torch.no_grad(): - w1 -= learning_rate * w1.grad - w2 -= learning_rate * w2.grad - - # Manually zero the gradients after updating weights - w1.grad.zero_() - w2.grad.zero_() diff --git a/beginner_source/examples_tensor/polynomial_numpy.py b/beginner_source/examples_tensor/polynomial_numpy.py new file mode 100755 index 00000000000..8fe6cadac73 --- /dev/null +++ b/beginner_source/examples_tensor/polynomial_numpy.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +""" +Warm-up: numpy +-------------- + +A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` +to :math:`pi` by minimizing squared Euclidean distance. + +This implementation uses numpy to manually compute the forward pass, loss, and +backward pass. + +A numpy array is a generic n-dimensional array; it does not know anything about +deep learning or gradients or computational graphs, and is just a way to perform +generic numeric computations. +""" +import numpy as np +import math + +# Create random input and output data +x = np.linspace(-math.pi, math.pi, 2000) +y = np.sin(x) + +# Randomly initialize weights +a = np.random.randn() +b = np.random.randn() +c = np.random.randn() +d = np.random.randn() + +learning_rate = 1e-6 +for t in range(2000): + # Forward pass: compute predicted y + # y = a + b x + c x^2 + d x^3 + y_pred = a + b * x + c * x ** 2 + d * x ** 3 + + # Compute and print loss + loss = np.square(y_pred - y).sum() + if t % 100 == 99: + print(t, loss) + + # Backprop to compute gradients of w1 and w2 with respect to loss + grad_y_pred = 2.0 * (y_pred - y) + grad_a = grad_y_pred.sum() + grad_b = (grad_y_pred * x).sum() + grad_c = (grad_y_pred * x ** 2).sum() + grad_d = (grad_y_pred * x ** 3).sum() + + # Update weights + a -= learning_rate * grad_a + b -= learning_rate * grad_b + c -= learning_rate * grad_c + d -= learning_rate * grad_d + +print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3') diff --git a/beginner_source/examples_tensor/two_layer_net_tensor.py b/beginner_source/examples_tensor/polynomial_tensor.py similarity index 56% rename from beginner_source/examples_tensor/two_layer_net_tensor.py rename to beginner_source/examples_tensor/polynomial_tensor.py index 3eacae42702..3dade5b1b3e 100755 --- a/beginner_source/examples_tensor/two_layer_net_tensor.py +++ b/beginner_source/examples_tensor/polynomial_tensor.py @@ -3,8 +3,8 @@ PyTorch: Tensors ---------------- -A fully-connected ReLU network with one hidden layer and no biases, trained to -predict y from x by minimizing squared Euclidean distance. +A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` +to :math:`pi` by minimizing squared Euclidean distance. This implementation uses PyTorch tensors to manually compute the forward pass, loss, and backward pass. @@ -19,30 +19,27 @@ """ import torch +import math dtype = torch.float device = torch.device("cpu") # device = torch.device("cuda:0") # Uncomment this to run on GPU -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 - # Create random input and output data -x = torch.randn(N, D_in, device=device, dtype=dtype) -y = torch.randn(N, D_out, device=device, dtype=dtype) +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) # Randomly initialize weights -w1 = torch.randn(D_in, H, device=device, dtype=dtype) -w2 = torch.randn(H, D_out, device=device, dtype=dtype) +a = torch.randn((), device=device, dtype=dtype) +b = torch.randn((), device=device, dtype=dtype) +c = torch.randn((), device=device, dtype=dtype) +d = torch.randn((), device=device, dtype=dtype) learning_rate = 1e-6 -for t in range(500): +for t in range(2000): # Forward pass: compute predicted y - h = x.mm(w1) - h_relu = h.clamp(min=0) - y_pred = h_relu.mm(w2) + y_pred = a + b * x + c * x ** 2 + d * x ** 3 # Compute and print loss loss = (y_pred - y).pow(2).sum().item() @@ -51,12 +48,16 @@ # Backprop to compute gradients of w1 and w2 with respect to loss grad_y_pred = 2.0 * (y_pred - y) - grad_w2 = h_relu.t().mm(grad_y_pred) - grad_h_relu = grad_y_pred.mm(w2.t()) - grad_h = grad_h_relu.clone() - grad_h[h < 0] = 0 - grad_w1 = x.t().mm(grad_h) + grad_a = grad_y_pred.sum() + grad_b = (grad_y_pred * x).sum() + grad_c = (grad_y_pred * x ** 2).sum() + grad_d = (grad_y_pred * x ** 3).sum() # Update weights using gradient descent - w1 -= learning_rate * grad_w1 - w2 -= learning_rate * grad_w2 + a -= learning_rate * grad_a + b -= learning_rate * grad_b + c -= learning_rate * grad_c + d -= learning_rate * grad_d + + +print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') diff --git a/beginner_source/examples_tensor/two_layer_net_numpy.py b/beginner_source/examples_tensor/two_layer_net_numpy.py deleted file mode 100755 index f003d0f002b..00000000000 --- a/beginner_source/examples_tensor/two_layer_net_numpy.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Warm-up: numpy --------------- - -A fully-connected ReLU network with one hidden layer and no biases, trained to -predict y from x using Euclidean error. - -This implementation uses numpy to manually compute the forward pass, loss, and -backward pass. - -A numpy array is a generic n-dimensional array; it does not know anything about -deep learning or gradients or computational graphs, and is just a way to perform -generic numeric computations. -""" -import numpy as np - -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 - -# Create random input and output data -x = np.random.randn(N, D_in) -y = np.random.randn(N, D_out) - -# Randomly initialize weights -w1 = np.random.randn(D_in, H) -w2 = np.random.randn(H, D_out) - -learning_rate = 1e-6 -for t in range(500): - # Forward pass: compute predicted y - h = x.dot(w1) - h_relu = np.maximum(h, 0) - y_pred = h_relu.dot(w2) - - # Compute and print loss - loss = np.square(y_pred - y).sum() - print(t, loss) - - # Backprop to compute gradients of w1 and w2 with respect to loss - grad_y_pred = 2.0 * (y_pred - y) - grad_w2 = h_relu.T.dot(grad_y_pred) - grad_h_relu = grad_y_pred.dot(w2.T) - grad_h = grad_h_relu.copy() - grad_h[h < 0] = 0 - grad_w1 = x.T.dot(grad_h) - - # Update weights - w1 -= learning_rate * grad_w1 - w2 -= learning_rate * grad_w2 diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index a9f56268b25..d10aa350147 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -11,8 +11,8 @@ At its core, PyTorch provides two main features: - An n-dimensional Tensor, similar to numpy but can run on GPUs - Automatic differentiation for building and training neural networks -We will use a fully-connected ReLU network as our running example. The -network will have a single hidden layer, and will be trained with +We will use a problem of fitting :math:`y=\sin(x)` with a third order polynomial +as our running example. The network will have four parameters, and will be trained with gradient descent to fit random data by minimizing the Euclidean distance between the network output and the true output. @@ -39,7 +39,7 @@ learning, or gradients. However we can easily use numpy to fit a two-layer network to random data by manually implementing the forward and backward passes through the network using numpy operations: -.. includenodoc:: /beginner/examples_tensor/two_layer_net_numpy.py +.. includenodoc:: /beginner/examples_tensor/polynomial_numpy.py PyTorch: Tensors @@ -62,11 +62,11 @@ Also unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it to a new datatype. -Here we use PyTorch Tensors to fit a two-layer network to random data. +Here we use PyTorch Tensors to fit a third order polynomial to sine function. Like the numpy example above we need to manually implement the forward and backward passes through the network: -.. includenodoc:: /beginner/examples_tensor/two_layer_net_tensor.py +.. includenodoc:: /beginner/examples_tensor/polynomial_tensor.py Autograd @@ -95,11 +95,11 @@ represents a node in a computational graph. If ``x`` is a Tensor that has ``x.requires_grad=True`` then ``x.grad`` is another Tensor holding the gradient of ``x`` with respect to some scalar value. -Here we use PyTorch Tensors and autograd to implement our two-layer -network; now we no longer need to manually implement the backward pass -through the network: +Here we use PyTorch Tensors and autograd to implement our fitting sine wave +with third order polynomial example; now we no longer need to manually +implement the backward pass through the network: -.. includenodoc:: /beginner/examples_autograd/two_layer_net_autograd.py +.. includenodoc:: /beginner/examples_autograd/polynomial_autograd.py PyTorch: Defining new autograd functions ---------------------------------------- @@ -117,11 +117,16 @@ and ``backward`` functions. We can then use our new autograd operator by constructing an instance and calling it like a function, passing Tensors containing input data. -In this example we define our own custom autograd function for -performing the ReLU nonlinearity, and use it to implement our two-layer -network: +In this example we define our model as :math:`y=a+b*P_3(c+dx)` instead of +:math:`y=a+bx+cx^2+dx^3`, where :math:`P_3(x)=\frac{1/2}\left(5x^3-3x\right)` +is the `Legendre polynomial`_ of degree three. We write our own custom autograd +function for computing forward and backward of P3, and use it to implement our +model: + +.. _Legendre polynomial: + https://en.wikipedia.org/wiki/Legendre_polynomials -.. includenodoc:: /beginner/examples_autograd/two_layer_net_custom_function.py +.. includenodoc:: /beginner/examples_autograd/polynomial_custom_function.py `nn` module =========== From f89d4a0dd47c11a8c562e98dacc955448cbf1002 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 16:39:11 -0800 Subject: [PATCH 02/19] more --- ...yer_net_module.py => polynomial_module.py} | 45 +++++++++-------- .../{two_layer_net_nn.py => polynomial_nn.py} | 50 +++++++++++-------- ...layer_net_optim.py => polynomial_optim.py} | 38 +++++++------- beginner_source/pytorch_with_examples.rst | 41 +++++++-------- 4 files changed, 93 insertions(+), 81 deletions(-) rename beginner_source/examples_nn/{two_layer_net_module.py => polynomial_module.py} (54%) rename beginner_source/examples_nn/{two_layer_net_nn.py => polynomial_nn.py} (59%) rename beginner_source/examples_nn/{two_layer_net_optim.py => polynomial_optim.py} (66%) diff --git a/beginner_source/examples_nn/two_layer_net_module.py b/beginner_source/examples_nn/polynomial_module.py similarity index 54% rename from beginner_source/examples_nn/two_layer_net_module.py rename to beginner_source/examples_nn/polynomial_module.py index 29d27274d25..3e1ea195e0e 100755 --- a/beginner_source/examples_nn/two_layer_net_module.py +++ b/beginner_source/examples_nn/polynomial_module.py @@ -3,25 +3,28 @@ PyTorch: Custom nn Modules -------------------------- -A fully-connected ReLU network with one hidden layer, trained to predict y from x -by minimizing squared Euclidean distance. +A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` +to :math:`pi` by minimizing squared Euclidean distance. This implementation defines the model as a custom Module subclass. Whenever you want a model more complex than a simple sequence of existing Modules you will need to define your model this way. """ import torch +import math -class TwoLayerNet(torch.nn.Module): - def __init__(self, D_in, H, D_out): +class Polynomial3(torch.nn.Module): + def __init__(self,): """ - In the constructor we instantiate two nn.Linear modules and assign them as + In the constructor we instantiate four parameters and assign them as member variables. """ - super(TwoLayerNet, self).__init__() - self.linear1 = torch.nn.Linear(D_in, H) - self.linear2 = torch.nn.Linear(H, D_out) + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) def forward(self, x): """ @@ -29,28 +32,28 @@ def forward(self, x): a Tensor of output data. We can use Modules defined in the constructor as well as arbitrary operators on Tensors. """ - h_relu = self.linear1(x).clamp(min=0) - y_pred = self.linear2(h_relu) - return y_pred + return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3' -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs -x = torch.randn(N, D_in) -y = torch.randn(N, D_out) +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) # Construct our model by instantiating the class defined above -model = TwoLayerNet(D_in, H, D_out) +model = Polynomial3() # Construct our loss function and an Optimizer. The call to model.parameters() # in the SGD constructor will contain the learnable parameters of the two # nn.Linear modules which are members of the model. criterion = torch.nn.MSELoss(reduction='sum') -optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) -for t in range(500): +optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) +for t in range(2000): # Forward pass: Compute predicted y by passing x to the model y_pred = model(x) @@ -63,3 +66,5 @@ def forward(self, x): optimizer.zero_grad() loss.backward() optimizer.step() + +print(f'Result: {model.string()}') diff --git a/beginner_source/examples_nn/two_layer_net_nn.py b/beginner_source/examples_nn/polynomial_nn.py similarity index 59% rename from beginner_source/examples_nn/two_layer_net_nn.py rename to beginner_source/examples_nn/polynomial_nn.py index 0c1925878e8..3e609be9d65 100755 --- a/beginner_source/examples_nn/two_layer_net_nn.py +++ b/beginner_source/examples_nn/polynomial_nn.py @@ -3,8 +3,8 @@ PyTorch: nn ----------- -A fully-connected ReLU network with one hidden layer, trained to predict y from x -by minimizing squared Euclidean distance. +A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` +to :math:`pi` by minimizing squared Euclidean distance. This implementation uses the nn package from PyTorch to build the network. PyTorch autograd makes it easy to define computational graphs and take gradients, @@ -14,41 +14,47 @@ input and may have some trainable weights. """ import torch +import math -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs -x = torch.randn(N, D_in) -y = torch.randn(N, D_out) +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) -# Use the nn package to define our model as a sequence of layers. nn.Sequential -# is a Module which contains other Modules, and applies them in sequence to -# produce its output. Each Linear Module computes output from input using a -# linear function, and holds internal Tensors for its weight and bias. -model = torch.nn.Sequential( - torch.nn.Linear(D_in, H), - torch.nn.ReLU(), - torch.nn.Linear(H, D_out), -) +# Use the nn package to define our model as a single layer or a sequence of layers. +# For this example, the output y is a linear function of (x, x^2, x^3), so +# we can consider it as a single linear layer neural network. +model = torch.nn.Linear(3, 1) + +# If your model has multiple layers, you can use :class:`torch.nn.Sequential` them in +# sequence to produce its output. Something like: +# model = torch.nn.Sequential( +# torch.nn.Linear(D_in, H), +# torch.nn.ReLU(), +# torch.nn.Linear(H, D_out), +# ) # The nn package also contains definitions of popular loss functions; in this # case we will use Mean Squared Error (MSE) as our loss function. loss_fn = torch.nn.MSELoss(reduction='sum') -learning_rate = 1e-4 -for t in range(500): +learning_rate = 1e-6 +for t in range(2000): + # In order to use :class:`torch.nn.Linear`, we need to prepare our + # input and output in a format of (batch, D_in) and (batch, D_out) + xx = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3])) + yy = y.unsqueeze(-1) + # Forward pass: compute predicted y by passing x to the model. Module objects # override the __call__ operator so you can call them like functions. When # doing so you pass a Tensor of input data to the Module and it produces # a Tensor of output data. - y_pred = model(x) + y_pred = model(xx) # Compute and print loss. We pass Tensors containing the predicted and true # values of y, and the loss function returns a Tensor containing the # loss. - loss = loss_fn(y_pred, y) + loss = loss_fn(y_pred, yy) if t % 100 == 99: print(t, loss.item()) @@ -66,3 +72,5 @@ with torch.no_grad(): for param in model.parameters(): param -= learning_rate * param.grad + +print(f'Result: y = {model.bias.item()} + {model.weight[:, 0].item()} x + {model.weight[:, 1].item()} x^2 + {model.weight[:, 2].item()} x^3') diff --git a/beginner_source/examples_nn/two_layer_net_optim.py b/beginner_source/examples_nn/polynomial_optim.py similarity index 66% rename from beginner_source/examples_nn/two_layer_net_optim.py rename to beginner_source/examples_nn/polynomial_optim.py index 82b67dcc1b0..30d5ca34d1e 100755 --- a/beginner_source/examples_nn/two_layer_net_optim.py +++ b/beginner_source/examples_nn/polynomial_optim.py @@ -3,8 +3,8 @@ PyTorch: optim -------------- -A fully-connected ReLU network with one hidden layer, trained to predict y from x -by minimizing squared Euclidean distance. +A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` +to :math:`pi` by minimizing squared Euclidean distance. This implementation uses the nn package from PyTorch to build the network. @@ -14,35 +14,34 @@ used for deep learning, including SGD+momentum, RMSProp, Adam, etc. """ import torch +import math -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs -x = torch.randn(N, D_in) -y = torch.randn(N, D_out) +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) # Use the nn package to define our model and loss function. -model = torch.nn.Sequential( - torch.nn.Linear(D_in, H), - torch.nn.ReLU(), - torch.nn.Linear(H, D_out), -) +model = torch.nn.Linear(3, 1) loss_fn = torch.nn.MSELoss(reduction='sum') # Use the optim package to define an Optimizer that will update the weights of # the model for us. Here we will use Adam; the optim package contains many other # optimization algorithms. The first argument to the Adam constructor tells the # optimizer which Tensors it should update. -learning_rate = 1e-4 -optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) -for t in range(500): +learning_rate = 1e-6 +optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) +for t in range(2000): + # In order to use :class:`torch.nn.Linear`, we need to prepare our + # input and output in a format of (batch, D_in) and (batch, D_out) + xx = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3])) + yy = y.unsqueeze(-1) + # Forward pass: compute predicted y by passing x to the model. - y_pred = model(x) + y_pred = model(xx) # Compute and print loss. - loss = loss_fn(y_pred, y) + loss = loss_fn(y_pred, yy) if t % 100 == 99: print(t, loss.item()) @@ -60,3 +59,6 @@ # Calling the step function on an Optimizer makes an update to its # parameters optimizer.step() + + +print(f'Result: y = {model.bias.item()} + {model.weight[:, 0].item()} x + {model.weight[:, 1].item()} x^2 + {model.weight[:, 2].item()} x^3') diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index d10aa350147..9ed18341f77 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -157,10 +157,10 @@ containing learnable parameters. The ``nn`` package also defines a set of useful loss functions that are commonly used when training neural networks. -In this example we use the ``nn`` package to implement our two-layer +In this example we use the ``nn`` package to implement our polynomial model network: -.. includenodoc:: /beginner/examples_nn/two_layer_net_nn.py +.. includenodoc:: /beginner/examples_nn/polynomial_nn.py PyTorch: optim -------------- @@ -177,10 +177,10 @@ algorithm and provides implementations of commonly used optimization algorithms. In this example we will use the ``nn`` package to define our model as -before, but we will optimize the model using the Adam algorithm provided +before, but we will optimize the model using the SGD algorithm provided by the ``optim`` package: -.. includenodoc:: /beginner/examples_nn/two_layer_net_optim.py +.. includenodoc:: /beginner/examples_nn/polynomial_optim.py PyTorch: Custom nn Modules -------------------------- @@ -194,7 +194,7 @@ modules or other autograd operations on Tensors. In this example we implement our two-layer network as a custom Module subclass: -.. includenodoc:: /beginner/examples_nn/two_layer_net_module.py +.. includenodoc:: /beginner/examples_nn/polynomial_module.py PyTorch: Control Flow + Weight Sharing -------------------------------------- @@ -228,12 +228,12 @@ Tensors :maxdepth: 2 :hidden: - /beginner/examples_tensor/two_layer_net_numpy - /beginner/examples_tensor/two_layer_net_tensor + /beginner/examples_tensor/polynomial_numpy + /beginner/examples_tensor/polynomial_tensor -.. galleryitem:: /beginner/examples_tensor/two_layer_net_numpy.py +.. galleryitem:: /beginner/examples_tensor/polynomial_numpy.py -.. galleryitem:: /beginner/examples_tensor/two_layer_net_tensor.py +.. galleryitem:: /beginner/examples_tensor/polynomial_tensor.py .. raw:: html @@ -246,16 +246,13 @@ Autograd :maxdepth: 2 :hidden: - /beginner/examples_autograd/two_layer_net_autograd - /beginner/examples_autograd/two_layer_net_custom_function - /beginner/examples_autograd/tf_two_layer_net + /beginner/examples_autograd/polynomial_autograd + /beginner/examples_autograd/polynomial_custom_function -.. galleryitem:: /beginner/examples_autograd/two_layer_net_autograd.py +.. galleryitem:: /beginner/examples_autograd/polynomial_autograd.py -.. galleryitem:: /beginner/examples_autograd/two_layer_net_custom_function.py - -.. galleryitem:: /beginner/examples_autograd/tf_two_layer_net.py +.. galleryitem:: /beginner/examples_autograd/polynomial_custom_function.py .. raw:: html @@ -268,17 +265,17 @@ Autograd :maxdepth: 2 :hidden: - /beginner/examples_nn/two_layer_net_nn - /beginner/examples_nn/two_layer_net_optim - /beginner/examples_nn/two_layer_net_module + /beginner/examples_nn/polynomial_nn + /beginner/examples_nn/polynomial_optim + /beginner/examples_nn/polynomial_module /beginner/examples_nn/dynamic_net -.. galleryitem:: /beginner/examples_nn/two_layer_net_nn.py +.. galleryitem:: /beginner/examples_nn/polynomial_nn.py -.. galleryitem:: /beginner/examples_nn/two_layer_net_optim.py +.. galleryitem:: /beginner/examples_nn/polynomial_optim.py -.. galleryitem:: /beginner/examples_nn/two_layer_net_module.py +.. galleryitem:: /beginner/examples_nn/polynomial_module.py .. galleryitem:: /beginner/examples_nn/dynamic_net.py From 02d7694195db7543ff85fd72177d1f1adca7784a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 16:55:49 -0800 Subject: [PATCH 03/19] Save --- beginner_source/examples_nn/polynomial_nn.py | 43 ++++++++++++------- .../examples_nn/polynomial_optim.py | 19 ++++---- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/beginner_source/examples_nn/polynomial_nn.py b/beginner_source/examples_nn/polynomial_nn.py index 3e609be9d65..80a39fb3bc7 100755 --- a/beginner_source/examples_nn/polynomial_nn.py +++ b/beginner_source/examples_nn/polynomial_nn.py @@ -21,18 +21,29 @@ x = torch.linspace(-math.pi, math.pi, 2000) y = torch.sin(x) -# Use the nn package to define our model as a single layer or a sequence of layers. # For this example, the output y is a linear function of (x, x^2, x^3), so -# we can consider it as a single linear layer neural network. -model = torch.nn.Linear(3, 1) +# we can consider it as a linear layer neural network. Let's prepare the +# tensor (x, x^2, x^3). +p = torch.tensor([1, 2, 3]) +xx = x.unsqueeze(-1).pow(p) -# If your model has multiple layers, you can use :class:`torch.nn.Sequential` them in -# sequence to produce its output. Something like: -# model = torch.nn.Sequential( -# torch.nn.Linear(D_in, H), -# torch.nn.ReLU(), -# torch.nn.Linear(H, D_out), -# ) +# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape +# (3,), for this case, broadcasting semantics will apply to obtain a tensor +# of shape (2000, 3) + +# Use the nn package to define our model as a sequence of layers. nn.Sequential +# is a Module which contains other Modules, and applies them in sequence to +# produce its output. The Linear Module computes output from input using a +# linear function, and holds internal Tensors for its weight and bias. +# The Flatten layer flatens the output of the linear layer to a 1D tensor, +# to match the shape of `y`. +model = torch.nn.Sequential( + torch.nn.Linear(3, 1), + torch.nn.Flatten(0, 1) +) + +# In order to use :class:`torch.nn.Linear`, we need to prepare our +# input and output in a format of (batch, D_in) and (batch, D_out) # The nn package also contains definitions of popular loss functions; in this # case we will use Mean Squared Error (MSE) as our loss function. @@ -40,10 +51,6 @@ learning_rate = 1e-6 for t in range(2000): - # In order to use :class:`torch.nn.Linear`, we need to prepare our - # input and output in a format of (batch, D_in) and (batch, D_out) - xx = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3])) - yy = y.unsqueeze(-1) # Forward pass: compute predicted y by passing x to the model. Module objects # override the __call__ operator so you can call them like functions. When @@ -54,7 +61,7 @@ # Compute and print loss. We pass Tensors containing the predicted and true # values of y, and the loss function returns a Tensor containing the # loss. - loss = loss_fn(y_pred, yy) + loss = loss_fn(y_pred, y) if t % 100 == 99: print(t, loss.item()) @@ -73,4 +80,8 @@ for param in model.parameters(): param -= learning_rate * param.grad -print(f'Result: y = {model.bias.item()} + {model.weight[:, 0].item()} x + {model.weight[:, 1].item()} x^2 + {model.weight[:, 2].item()} x^3') +# You can access the first layer of `model` like accessing the first item of a list +linear_layer = model[0] + +# For linear layer, its parameters are stored as `weight` and `bias`. +print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') diff --git a/beginner_source/examples_nn/polynomial_optim.py b/beginner_source/examples_nn/polynomial_optim.py index 30d5ca34d1e..8c24e4f8153 100755 --- a/beginner_source/examples_nn/polynomial_optim.py +++ b/beginner_source/examples_nn/polynomial_optim.py @@ -21,8 +21,15 @@ x = torch.linspace(-math.pi, math.pi, 2000) y = torch.sin(x) +# Prepare the input tensor (x, x^2, x^3). +p = torch.tensor([1, 2, 3]) +xx = x.unsqueeze(-1).pow(p) + # Use the nn package to define our model and loss function. -model = torch.nn.Linear(3, 1) +model = torch.nn.Sequential( + torch.nn.Linear(3, 1), + torch.nn.Flatten(0, 1) +) loss_fn = torch.nn.MSELoss(reduction='sum') # Use the optim package to define an Optimizer that will update the weights of @@ -32,16 +39,11 @@ learning_rate = 1e-6 optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) for t in range(2000): - # In order to use :class:`torch.nn.Linear`, we need to prepare our - # input and output in a format of (batch, D_in) and (batch, D_out) - xx = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3])) - yy = y.unsqueeze(-1) - # Forward pass: compute predicted y by passing x to the model. y_pred = model(xx) # Compute and print loss. - loss = loss_fn(y_pred, yy) + loss = loss_fn(y_pred, y) if t % 100 == 99: print(t, loss.item()) @@ -61,4 +63,5 @@ optimizer.step() -print(f'Result: y = {model.bias.item()} + {model.weight[:, 0].item()} x + {model.weight[:, 1].item()} x^2 + {model.weight[:, 2].item()} x^3') +linear_layer = model[0] +print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') From 8d6514601e612b363ddff53ed980206450e2dbf4 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 17:33:19 -0800 Subject: [PATCH 04/19] save --- beginner_source/examples_nn/dynamic_net.py | 61 ++++++++++++---------- beginner_source/pytorch_with_examples.rst | 7 ++- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/beginner_source/examples_nn/dynamic_net.py b/beginner_source/examples_nn/dynamic_net.py index 0e56e39dfbd..0c9ad66c894 100755 --- a/beginner_source/examples_nn/dynamic_net.py +++ b/beginner_source/examples_nn/dynamic_net.py @@ -4,30 +4,32 @@ -------------------------------------- To showcase the power of PyTorch dynamic graphs, we will implement a very strange -model: a fully-connected ReLU network that on each forward pass randomly chooses -a number between 1 and 4 and has that many hidden layers, reusing the same -weights multiple times to compute the innermost hidden layers. +model: a third-fifth order polynomial that on each forward pass +chooses a random number between 3 and 5 and uses that many orders, reusing +the same weights multiple times to compute the fourth and fifth order. """ import random import torch +import math class DynamicNet(torch.nn.Module): - def __init__(self, D_in, H, D_out): + def __init__(self): """ - In the constructor we construct three nn.Linear instances that we will use - in the forward pass. + In the constructor we instantiate four parameters and assign them as + member variables. """ - super(DynamicNet, self).__init__() - self.input_linear = torch.nn.Linear(D_in, H) - self.middle_linear = torch.nn.Linear(H, H) - self.output_linear = torch.nn.Linear(H, D_out) + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) + self.e = torch.nn.Parameter(torch.randn(())) def forward(self, x): """ - For the forward pass of the model, we randomly choose either 0, 1, 2, or 3 - and reuse the middle_linear Module that many times to compute hidden layer - representations. + For the forward pass of the model, we randomly choose either 4, 5 + and reuse the e parameter to compute the contribution of these orders. Since each forward pass builds a dynamic computation graph, we can use normal Python control-flow operators like loops or conditional statements when @@ -37,38 +39,41 @@ def forward(self, x): times when defining a computational graph. This is a big improvement from Lua Torch, where each Module could be used only once. """ - h_relu = self.input_linear(x).clamp(min=0) - for _ in range(random.randint(0, 3)): - h_relu = self.middle_linear(h_relu).clamp(min=0) - y_pred = self.output_linear(h_relu) - return y_pred + y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 + for exp in range(4, random.randint(4, 6)): + y = y + self.e * x ** exp + return y + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?' -# N is batch size; D_in is input dimension; -# H is hidden dimension; D_out is output dimension. -N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs -x = torch.randn(N, D_in) -y = torch.randn(N, D_out) +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) # Construct our model by instantiating the class defined above -model = DynamicNet(D_in, H, D_out) +model = DynamicNet() # Construct our loss function and an Optimizer. Training this strange model with # vanilla stochastic gradient descent is tough, so we use momentum criterion = torch.nn.MSELoss(reduction='sum') -optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) -for t in range(500): +optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) +for t in range(30000): # Forward pass: Compute predicted y by passing x to the model y_pred = model(x) # Compute and print loss loss = criterion(y_pred, y) - if t % 100 == 99: + if t % 2000 == 1999: print(t, loss.item()) # Zero gradients, perform a backward pass, and update the weights. optimizer.zero_grad() loss.backward() optimizer.step() + +print(f'Result: {model.string()}') diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index 9ed18341f77..a0c9e9ede72 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -200,10 +200,9 @@ PyTorch: Control Flow + Weight Sharing -------------------------------------- As an example of dynamic graphs and weight sharing, we implement a very -strange model: a fully-connected ReLU network that on each forward pass -chooses a random number between 1 and 4 and uses that many hidden -layers, reusing the same weights multiple times to compute the innermost -hidden layers. +strange model: a third-fifth order polynomial that on each forward pass +chooses a random number between 3 and 5 and uses that many orders, reusing +the same weights multiple times to compute the fourth and fifth order. For this model we can use normal Python flow control to implement the loop, and we can implement weight sharing among the innermost layers by simply From 02ed6c5cf613bbf2b631704eee62033b8bd746ef Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 17:42:10 -0800 Subject: [PATCH 05/19] save --- .../examples_autograd/polynomial_custom_function.py | 4 ++-- beginner_source/pytorch_with_examples.rst | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/beginner_source/examples_autograd/polynomial_custom_function.py b/beginner_source/examples_autograd/polynomial_custom_function.py index 894a8c6fb05..d88e1676e46 100755 --- a/beginner_source/examples_autograd/polynomial_custom_function.py +++ b/beginner_source/examples_autograd/polynomial_custom_function.py @@ -6,7 +6,7 @@ A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` to :math:`pi` by minimizing squared Euclidean distance. Instead of writing the polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as -:math:`y=a+b*P_3(c+dx)` where :math:`P_3(x)=\frac{1/2}\left(5x^3-3x\right)` is +:math:`y=a+b*P_3(c+dx)` where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is the `Legendre polynomial`_ of degree three. .. _Legendre polynomial: @@ -16,7 +16,7 @@ Tensors, and uses PyTorch autograd to compute gradients. In this implementation we implement our own custom autograd function to perform -:math:`P_3'(x)`. By mathematics, :math:`P_3'(x)=\frac{3/2}\left(5x^2-1\right)` +:math:`P_3'(x)`. By mathematics, :math:`P_3'(x)=\frac{3}{2}\left(5x^2-1\right)` """ import torch import math diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index a0c9e9ede72..e7da1dfab73 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -118,7 +118,7 @@ constructing an instance and calling it like a function, passing Tensors containing input data. In this example we define our model as :math:`y=a+b*P_3(c+dx)` instead of -:math:`y=a+bx+cx^2+dx^3`, where :math:`P_3(x)=\frac{1/2}\left(5x^3-3x\right)` +:math:`y=a+bx+cx^2+dx^3`, where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is the `Legendre polynomial`_ of degree three. We write our own custom autograd function for computing forward and backward of P3, and use it to implement our model: From 4c6aa6c3a777ce07bd5263c0c18b2d6ab9b35447 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 17:58:15 -0800 Subject: [PATCH 06/19] fix --- beginner_source/pytorch_with_examples.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index e7da1dfab73..2f2d4b8b530 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -36,7 +36,7 @@ Numpy provides an n-dimensional array object, and many functions for manipulating these arrays. Numpy is a generic framework for scientific computing; it does not know anything about computation graphs, or deep learning, or gradients. However we can easily use numpy to fit a -two-layer network to random data by manually implementing the forward +third order polynomial to sine function by manually implementing the forward and backward passes through the network using numpy operations: .. includenodoc:: /beginner/examples_tensor/polynomial_numpy.py From 472f3d89641c97428f24fb73c92127410aee5a34 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 18:49:21 -0800 Subject: [PATCH 07/19] fix --- beginner_source/examples_tensor/polynomial_numpy.py | 2 +- beginner_source/examples_tensor/polynomial_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/examples_tensor/polynomial_numpy.py b/beginner_source/examples_tensor/polynomial_numpy.py index 8fe6cadac73..a1a378e50ed 100755 --- a/beginner_source/examples_tensor/polynomial_numpy.py +++ b/beginner_source/examples_tensor/polynomial_numpy.py @@ -37,7 +37,7 @@ if t % 100 == 99: print(t, loss) - # Backprop to compute gradients of w1 and w2 with respect to loss + # Backprop to compute gradients of a, b, c, d with respect to loss grad_y_pred = 2.0 * (y_pred - y) grad_a = grad_y_pred.sum() grad_b = (grad_y_pred * x).sum() diff --git a/beginner_source/examples_tensor/polynomial_tensor.py b/beginner_source/examples_tensor/polynomial_tensor.py index 3dade5b1b3e..1e35b0f24bd 100755 --- a/beginner_source/examples_tensor/polynomial_tensor.py +++ b/beginner_source/examples_tensor/polynomial_tensor.py @@ -46,7 +46,7 @@ if t % 100 == 99: print(t, loss) - # Backprop to compute gradients of w1 and w2 with respect to loss + # Backprop to compute gradients of a, b, c, d with respect to loss grad_y_pred = 2.0 * (y_pred - y) grad_a = grad_y_pred.sum() grad_b = (grad_y_pred * x).sum() From c080d7cbd1c96dde992ac7486dad9d9d0704d158 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 19:00:33 -0800 Subject: [PATCH 08/19] fix --- beginner_source/pytorch_with_examples.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index 2f2d4b8b530..59fe44a034a 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -60,7 +60,7 @@ generic tool for scientific computing. Also unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on GPU, you simply -need to cast it to a new datatype. +need to specify the correct device. Here we use PyTorch Tensors to fit a third order polynomial to sine function. Like the numpy example above we need to manually implement the forward From 6982d7a416f3ebf7e267a9ec1ac549ee861474ca Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 19:36:11 -0800 Subject: [PATCH 09/19] fix --- .../examples_autograd/polynomial_autograd.py | 13 +++++-------- .../examples_autograd/polynomial_custom_function.py | 4 ++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/beginner_source/examples_autograd/polynomial_autograd.py b/beginner_source/examples_autograd/polynomial_autograd.py index bd423ae6244..8c1da7f38fc 100755 --- a/beginner_source/examples_autograd/polynomial_autograd.py +++ b/beginner_source/examples_autograd/polynomial_autograd.py @@ -22,8 +22,8 @@ # device = torch.device("cuda:0") # Uncomment this to run on GPU # Create Tensors to hold input and outputs. -# Setting requires_grad=False indicates that we do not need to compute gradients -# with respect to these Tensors during the backward pass. +# By default, requires_grad=False, which indicates that we do not need to +# compute gradients with respect to these Tensors during the backward pass. x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) y = torch.sin(x) @@ -38,10 +38,7 @@ learning_rate = 1e-6 for t in range(2000): - # Forward pass: compute predicted y using operations on Tensors; these - # are exactly the same operations we used to compute the forward pass using - # Tensors, but we do not need to keep references to intermediate values since - # we are not implementing the backward pass by hand. + # Forward pass: compute predicted y using operations on Tensors. y_pred = a + b * x + c * x ** 2 + d * x ** 3 # Compute and print loss using operations on Tensors. @@ -53,8 +50,8 @@ # Use autograd to compute the backward pass. This call will compute the # gradient of loss with respect to all Tensors with requires_grad=True. - # After this call w1.grad and w2.grad will be Tensors holding the gradient - # of the loss with respect to w1 and w2 respectively. + # After this call a.grad, b.grad. c.grad and d.grad will be Tensors holding + # the gradient of the loss with respect to a, b, c, d respectively. loss.backward() # Manually update weights using gradient descent. Wrap in torch.no_grad() diff --git a/beginner_source/examples_autograd/polynomial_custom_function.py b/beginner_source/examples_autograd/polynomial_custom_function.py index d88e1676e46..bbe2e576c40 100755 --- a/beginner_source/examples_autograd/polynomial_custom_function.py +++ b/beginner_source/examples_autograd/polynomial_custom_function.py @@ -56,8 +56,8 @@ def backward(ctx, grad_output): # device = torch.device("cuda:0") # Uncomment this to run on GPU # Create Tensors to hold input and outputs. -# Setting requires_grad=False indicates that we do not need to compute gradients -# with respect to these Tensors during the backward pass. +# By default, requires_grad=False, which indicates that we do not need to +# compute gradients with respect to these Tensors during the backward pass. x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) y = torch.sin(x) From f441cc2cb9f7dcdbd410d9501425c0f77cc56fb2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 19:50:11 -0800 Subject: [PATCH 10/19] no tensor.data --- beginner_source/examples_autograd/polynomial_autograd.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/beginner_source/examples_autograd/polynomial_autograd.py b/beginner_source/examples_autograd/polynomial_autograd.py index 8c1da7f38fc..65ab5892d9e 100755 --- a/beginner_source/examples_autograd/polynomial_autograd.py +++ b/beginner_source/examples_autograd/polynomial_autograd.py @@ -57,10 +57,6 @@ # Manually update weights using gradient descent. Wrap in torch.no_grad() # because weights have requires_grad=True, but we don't need to track this # in autograd. - # An alternative way is to operate on weight.data and weight.grad.data. - # Recall that tensor.data gives a tensor that shares the storage with - # tensor, but doesn't track history. - # You can also use torch.optim.SGD to achieve this. with torch.no_grad(): a -= learning_rate * a.grad b -= learning_rate * b.grad From bda95a90432acbc86fa07c42193f940a3210b45c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 20:41:50 -0800 Subject: [PATCH 11/19] fix --- beginner_source/examples_autograd/polynomial_custom_function.py | 2 +- beginner_source/pytorch_with_examples.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/examples_autograd/polynomial_custom_function.py b/beginner_source/examples_autograd/polynomial_custom_function.py index bbe2e576c40..33fc1a24688 100755 --- a/beginner_source/examples_autograd/polynomial_custom_function.py +++ b/beginner_source/examples_autograd/polynomial_custom_function.py @@ -6,7 +6,7 @@ A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi` to :math:`pi` by minimizing squared Euclidean distance. Instead of writing the polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as -:math:`y=a+b*P_3(c+dx)` where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is +:math:`y=a+b P_3(c+dx)` where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is the `Legendre polynomial`_ of degree three. .. _Legendre polynomial: diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index 59fe44a034a..5bcd49e79a6 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -117,7 +117,7 @@ and ``backward`` functions. We can then use our new autograd operator by constructing an instance and calling it like a function, passing Tensors containing input data. -In this example we define our model as :math:`y=a+b*P_3(c+dx)` instead of +In this example we define our model as :math:`y=a+b P_3(c+dx)` instead of :math:`y=a+bx+cx^2+dx^3`, where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is the `Legendre polynomial`_ of degree three. We write our own custom autograd function for computing forward and backward of P3, and use it to implement our From 466c07b48cd2d5c5917c635fd761de6c9065af4a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:00:09 -0800 Subject: [PATCH 12/19] P3 --- beginner_source/pytorch_with_examples.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index 5bcd49e79a6..39d0ecb5717 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -120,8 +120,8 @@ Tensors containing input data. In this example we define our model as :math:`y=a+b P_3(c+dx)` instead of :math:`y=a+bx+cx^2+dx^3`, where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is the `Legendre polynomial`_ of degree three. We write our own custom autograd -function for computing forward and backward of P3, and use it to implement our -model: +function for computing forward and backward of :math:`P_3`, and use it to implement +our model: .. _Legendre polynomial: https://en.wikipedia.org/wiki/Legendre_polynomials From 9eaf5d84bbe0ae7556aae2023c342f277ebf915d Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:18:42 -0800 Subject: [PATCH 13/19] save --- beginner_source/examples_nn/polynomial_nn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/beginner_source/examples_nn/polynomial_nn.py b/beginner_source/examples_nn/polynomial_nn.py index 80a39fb3bc7..9d5aca0534e 100755 --- a/beginner_source/examples_nn/polynomial_nn.py +++ b/beginner_source/examples_nn/polynomial_nn.py @@ -42,9 +42,6 @@ torch.nn.Flatten(0, 1) ) -# In order to use :class:`torch.nn.Linear`, we need to prepare our -# input and output in a format of (batch, D_in) and (batch, D_out) - # The nn package also contains definitions of popular loss functions; in this # case we will use Mean Squared Error (MSE) as our loss function. loss_fn = torch.nn.MSELoss(reduction='sum') From 929cf87089be72337f7c0348c01b4df513963b2a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:32:35 -0800 Subject: [PATCH 14/19] save --- beginner_source/examples_nn/polynomial_optim.py | 8 ++++---- beginner_source/pytorch_with_examples.rst | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/beginner_source/examples_nn/polynomial_optim.py b/beginner_source/examples_nn/polynomial_optim.py index 8c24e4f8153..434fb6624b3 100755 --- a/beginner_source/examples_nn/polynomial_optim.py +++ b/beginner_source/examples_nn/polynomial_optim.py @@ -33,11 +33,11 @@ loss_fn = torch.nn.MSELoss(reduction='sum') # Use the optim package to define an Optimizer that will update the weights of -# the model for us. Here we will use Adam; the optim package contains many other -# optimization algorithms. The first argument to the Adam constructor tells the +# the model for us. Here we will use RMSprop; the optim package contains many other +# optimization algorithms. The first argument to the RMSprop constructor tells the # optimizer which Tensors it should update. -learning_rate = 1e-6 -optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) +learning_rate = 1e-3 +optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) for t in range(2000): # Forward pass: compute predicted y by passing x to the model. y_pred = model(xx) diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index 39d0ecb5717..a6836d55365 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -166,18 +166,17 @@ PyTorch: optim -------------- Up to this point we have updated the weights of our models by manually -mutating the Tensors holding learnable parameters (with ``torch.no_grad()`` -or ``.data`` to avoid tracking history in autograd). This is not a huge -burden for simple optimization algorithms like stochastic gradient descent, -but in practice we often train neural networks using more sophisticated -optimizers like AdaGrad, RMSProp, Adam, etc. +mutating the Tensors holding learnable parameters with ``torch.no_grad()``. +This is not a huge burden for simple optimization algorithms like stochastic +gradient descent, but in practice we often train neural networks using more +sophisticated optimizers like AdaGrad, RMSProp, Adam, etc. The ``optim`` package in PyTorch abstracts the idea of an optimization algorithm and provides implementations of commonly used optimization algorithms. In this example we will use the ``nn`` package to define our model as -before, but we will optimize the model using the SGD algorithm provided +before, but we will optimize the model using the RMSprop algorithm provided by the ``optim`` package: .. includenodoc:: /beginner/examples_nn/polynomial_optim.py From 1dd199803e6759639b68cbf81d2d93bc3d859966 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:41:17 -0800 Subject: [PATCH 15/19] save --- beginner_source/pytorch_with_examples.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index a6836d55365..7ab37dff5ed 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -190,7 +190,7 @@ Modules by subclassing ``nn.Module`` and defining a ``forward`` which receives input Tensors and produces output Tensors using other modules or other autograd operations on Tensors. -In this example we implement our two-layer network as a custom Module +In this example we implement our third order polynomial as a custom Module subclass: .. includenodoc:: /beginner/examples_nn/polynomial_module.py From eaed12b8477b84db57c6ba5a19dfa806caf35a10 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:48:20 -0800 Subject: [PATCH 16/19] save --- beginner_source/examples_nn/polynomial_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/examples_nn/polynomial_module.py b/beginner_source/examples_nn/polynomial_module.py index 3e1ea195e0e..d958e7add81 100755 --- a/beginner_source/examples_nn/polynomial_module.py +++ b/beginner_source/examples_nn/polynomial_module.py @@ -15,10 +15,10 @@ class Polynomial3(torch.nn.Module): - def __init__(self,): + def __init__(self): """ In the constructor we instantiate four parameters and assign them as - member variables. + member parameters. """ super().__init__() self.a = torch.nn.Parameter(torch.randn(())) From 99b29046648acb1da03cfad65338da8651853c8e Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:54:20 -0800 Subject: [PATCH 17/19] fix --- beginner_source/examples_nn/polynomial_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/examples_nn/polynomial_module.py b/beginner_source/examples_nn/polynomial_module.py index d958e7add81..7b20a5523be 100755 --- a/beginner_source/examples_nn/polynomial_module.py +++ b/beginner_source/examples_nn/polynomial_module.py @@ -49,8 +49,8 @@ def string(self): model = Polynomial3() # Construct our loss function and an Optimizer. The call to model.parameters() -# in the SGD constructor will contain the learnable parameters of the two -# nn.Linear modules which are members of the model. +# in the SGD constructor will contain the learnable parameters of the nn.Linear +# module which is members of the model. criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) for t in range(2000): From 5e5a2b725b7fccc9ea54f28f84f95063c0287266 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:57:17 -0800 Subject: [PATCH 18/19] fix --- beginner_source/pytorch_with_examples.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/pytorch_with_examples.rst b/beginner_source/pytorch_with_examples.rst index 7ab37dff5ed..c0a2b665a56 100644 --- a/beginner_source/pytorch_with_examples.rst +++ b/beginner_source/pytorch_with_examples.rst @@ -204,8 +204,8 @@ chooses a random number between 3 and 5 and uses that many orders, reusing the same weights multiple times to compute the fourth and fifth order. For this model we can use normal Python flow control to implement the loop, -and we can implement weight sharing among the innermost layers by simply -reusing the same Module multiple times when defining the forward pass. +and we can implement weight sharing by simply reusing the same parameter multiple +times when defining the forward pass. We can easily implement this model as a Module subclass: From 28f28d88f7fd0abda5d389f2c48a754c8e5b8fd1 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 2 Dec 2020 21:59:42 -0800 Subject: [PATCH 19/19] fix --- beginner_source/examples_nn/dynamic_net.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/beginner_source/examples_nn/dynamic_net.py b/beginner_source/examples_nn/dynamic_net.py index 0c9ad66c894..31fa40f3e56 100755 --- a/beginner_source/examples_nn/dynamic_net.py +++ b/beginner_source/examples_nn/dynamic_net.py @@ -16,8 +16,7 @@ class DynamicNet(torch.nn.Module): def __init__(self): """ - In the constructor we instantiate four parameters and assign them as - member variables. + In the constructor we instantiate five parameters and assign them as members. """ super().__init__() self.a = torch.nn.Parameter(torch.randn(())) @@ -35,9 +34,8 @@ def forward(self, x): Python control-flow operators like loops or conditional statements when defining the forward pass of the model. - Here we also see that it is perfectly safe to reuse the same Module many - times when defining a computational graph. This is a big improvement from Lua - Torch, where each Module could be used only once. + Here we also see that it is perfectly safe to reuse the same parameter many + times when defining a computational graph. """ y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 for exp in range(4, random.randint(4, 6)):