[Doc] Update pendulum and rnn tutos #1691

Merged 5 commits on Nov 9, 2023
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -11,6 +11,7 @@ sphinxcontrib-htmlhelp
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
myst-parser
docutils
sphinx_design

torchvision
dm_control
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -69,6 +69,7 @@
"sphinx_gallery.gen_gallery",
"sphinxcontrib.aafig",
"myst_parser",
"sphinx_design",
]

intersphinx_mapping = {
163 changes: 93 additions & 70 deletions tutorials/sphinx-tutorials/dqn_with_rnn.py
@@ -1,49 +1,66 @@
# -*- coding: utf-8 -*-

"""
Recurrent DQN: Training recurrent policies
==========================================

**Author**: `Vincent Moens <https://github.com/vmoens>`_

Memory-based policies are crucial not only when the observations are partially
observable but also when the time dimension must be taken into account to
make informed decisions.

Recurrent neural networks have long been a popular tool for memory-based
policies. The idea is to keep a recurrent state in memory between two
consecutive steps, and use this as an input to the policy along with the
current observation.

This tutorial shows how to incorporate an RNN in a policy.

Key learnings:

- Incorporating an RNN in an actor in TorchRL;
- Using that memory-based policy with a replay buffer and a loss module.
.. grid:: 2

The core idea of using RNNs in TorchRL is to use TensorDict as a data carrier
for the hidden states from one step to another. We'll build a policy that
reads the previous recurrent state from the current tensordict, and writes the
current recurrent states in the tensordict of the next state:
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

.. figure:: /_static/img/rollout_recurrent.png
:alt: Data collection with a recurrent policy
* How to incorporate an RNN in an actor in TorchRL
* How to use that memory-based policy with a replay buffer and a loss module

As this figure shows, our env populates the tensordict with zeroed recurrent
states which are read by the policy together with the observation to produce an
action, and recurrent states that will be used for the next step.
When the :func:`torchrl.envs.step_mdp` function is called, the recurrent states
from the next state are brought to the current tensordict. Let's see how this
is implemented in practice.
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

* PyTorch v2.0.0
* gym[mujoco]
* tqdm
"""

#########################################################################
# Overview
# --------
#
# Memory-based policies are crucial not only when the observations are partially
# observable but also when the time dimension must be taken into account to
# make informed decisions.
#
# Recurrent neural networks have long been a popular tool for memory-based
# policies. The idea is to keep a recurrent state in memory between two
# consecutive steps, and use this as an input to the policy along with the
# current observation.
#
# This tutorial shows how to incorporate an RNN in a policy using TorchRL.
#
# Key learnings:
#
# - Incorporating an RNN in an actor in TorchRL;
# - Using that memory-based policy with a replay buffer and a loss module.
#
# The core idea of using RNNs in TorchRL is to use TensorDict as a data carrier
# for the hidden states from one step to another. We'll build a policy that
# reads the previous recurrent state from the current TensorDict, and writes the
# current recurrent states in the TensorDict of the next state:
#
# .. figure:: /_static/img/rollout_recurrent.png
# :alt: Data collection with a recurrent policy
#
# As this figure shows, our environment populates the TensorDict with zeroed recurrent
# states which are read by the policy together with the observation to produce an
# action, and recurrent states that will be used for the next step.
# When the :func:`~torchrl.envs.utils.step_mdp` function is called, the recurrent states
# from the next state are brought to the current TensorDict. Let's see how this
# is implemented in practice.
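######################################################################
# A minimal, illustrative sketch (hypothetical key names, not taken from the
# tutorial) of how :func:`~torchrl.envs.utils.step_mdp` moves the recurrent
# state written under ``"next"`` back to the root, where the policy will read
# it at the following step:
#
# .. code-block:: python
#
#    import torch
#    from tensordict import TensorDict
#    from torchrl.envs.utils import step_mdp
#
#    # "recurrent_state" is a placeholder name for this sketch
#    td = TensorDict(
#        {
#            "observation": torch.zeros(3),
#            "recurrent_state": torch.zeros(1, 128),  # read at step t
#            "next": TensorDict(
#                {
#                    "observation": torch.zeros(3),
#                    "recurrent_state": torch.ones(1, 128),  # written for step t+1
#                },
#                [],
#            ),
#        },
#        [],
#    )
#    td_tp1 = step_mdp(td)
#    assert (td_tp1["recurrent_state"] == 1).all()  # state carried to the root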

######################################################################
# If you are running this in Google Colab, make sure you install the following dependencies:
#
# .. code-block:: bash
#
# !pip3 install torchrl-nightly
# !pip3 install torchrl
# !pip3 install gym[mujoco]
# !pip3 install tqdm
#
@@ -87,18 +104,18 @@
# 84x84, scaling down the rewards and normalizing the observations.
#
# .. note::
# The :class:`torchrl.envs.StepCounter` transform is accessory. Since the CartPole
# The :class:`~torchrl.envs.transforms.StepCounter` transform is accessory. Since the CartPole
# task goal is to make trajectories as long as possible, counting the steps
# can help us track the performance of our policy.
#
# Two transforms are important for the purpose of this tutorial:
#
# - :class:`torchrl.envs.InitTracker` will stamp the
# calls to :meth:`torchrl.envs.EnvBase.reset` by adding a ``"is_init"``
# boolean mask in the tensordict that will track which steps require a reset
# - :class:`~torchrl.envs.transforms.InitTracker` will stamp the
# calls to :meth:`~torchrl.envs.EnvBase.reset` by adding a ``"is_init"``
# boolean mask in the TensorDict that will track which steps require a reset
# of the RNN hidden states.
# - The :class:`torchrl.envs.TensorDictPrimer` transform is a bit more
# technical: per se, it is not required to use RNN policies. However, it
# - The :class:`~torchrl.envs.transforms.TensorDictPrimer` transform is a bit more
# technical. It is not required to use RNN policies. However, it
# instructs the environment (and subsequently the collector) that some extra
# keys are to be expected. Once added, a call to `env.reset()` will populate
# the entries indicated in the primer with zeroed tensors. Knowing that
@@ -110,7 +127,7 @@
# the training of our policy, but it will make the recurrent keys disappear
# from the collected data and the replay buffer, which will in turn lead to
# a slightly less optimal training.
# Fortunately, the :class:`torchrl.modules.LSTMModule` we propose is
# Fortunately, the :class:`~torchrl.modules.LSTMModule` we propose is
# equipped with a helper method to build just that transform for us, so
# we can wait until we build it!
#
@@ -127,6 +144,7 @@
ObservationNorm(standard_normal=True, in_keys=["pixels"]),
),
)
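######################################################################
# As a quick, illustrative aside (not part of the tutorial), the ``"is_init"``
# flag written by :class:`~torchrl.envs.transforms.InitTracker` is ``True``
# right after a reset and ``False`` on regular steps, which is how the LSTM
# will know when to zero its hidden states:
#
# .. code-block:: python
#
#    from torchrl.envs import GymEnv, TransformedEnv
#    from torchrl.envs.transforms import InitTracker
#
#    # a bare CartPole env, just to inspect the flag
#    demo_env = TransformedEnv(GymEnv("CartPole-v1"), InitTracker())
#    td = demo_env.reset()
#    print(td["is_init"])           # True: recurrent states must be (re)initialized
#    td = demo_env.rand_step(td)
#    print(td["next", "is_init"])   # False on ordinary steps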

######################################################################
# As always, we need to manually initialize our normalization constants:
#
@@ -137,16 +155,16 @@
# Policy
# ------
#
# Our policy will have 3 components: a :class:`torchrl.modules.ConvNet`
# backbone, an :class:`torchrl.modules.LSTMModule` memory layer and a shallow
# :class:`torchrl.modules.MLP` block that will map the LSTM output onto the
# Our policy will have 3 components: a :class:`~torchrl.modules.ConvNet`
# backbone, an :class:`~torchrl.modules.LSTMModule` memory layer and a shallow
# :class:`~torchrl.modules.MLP` block that will map the LSTM output onto the
# action values.
#
# Convolutional network
# ~~~~~~~~~~~~~~~~~~~~~
#
# We build a convolutional network flanked with a :class:torch.nn.AdaptiveAvgPool2d`
# that will squash the output in a vector of size 64. The :class:`torchrl.modules.ConvNet`
# We build a convolutional network flanked with a :class:`torch.nn.AdaptiveAvgPool2d`
# that will squash the output in a vector of size 64. The :class:`~torchrl.modules.ConvNet`
# can assist us with this:
#
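######################################################################
# The backbone definition itself is collapsed in this diff. As a hedged sketch
# (the hyperparameters are illustrative), such a module could be assembled as
# follows, wrapped in a :class:`~tensordict.nn.TensorDictModule` so that it
# reads the ``"pixels"`` entry and writes an ``"embed"`` vector:
#
# .. code-block:: python
#
#    import torch
#    from tensordict.nn import TensorDictModule
#    from torchrl.modules import ConvNet
#
#    feature_sketch = TensorDictModule(
#        ConvNet(
#            num_cells=[32, 32, 64],          # illustrative channel sizes
#            squeeze_output=True,
#            aggregator_class=torch.nn.AdaptiveAvgPool2d,
#            aggregator_kwargs={"output_size": (1, 1)},
#        ),
#        in_keys=["pixels"],
#        out_keys=["embed"],
#    )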

@@ -171,20 +189,20 @@
# LSTM Module
# ~~~~~~~~~~~
#
# TorchRL provides a specialized :class:`torchrl.modules.LSTMModule` class
# to incorporate LSTMs in your code-base. It is a :class:`tensordict.nn.TensorDictModuleBase`
# TorchRL provides a specialized :class:`~torchrl.modules.LSTMModule` class
# to incorporate LSTMs in your code-base. It is a :class:`~tensordict.nn.TensorDictModuleBase`
# subclass: as such, it has a set of ``in_keys`` and ``out_keys`` that indicate
# what values should be expected to be read and written/updated during the
# execution of the module. The class comes with customizable pre-defined
# execution of the module. The class comes with customizable predefined
# values for these attributes to facilitate its construction.
#
# .. note::
# *Usage limitations*: The class supports almost all LSTM features such as
# dropout or multi-layered LSTMs.
# However, to respect TorchRL's conventions, this LSTM must have the ``batch_first``
# attribute set to ``True`` which is **not** the default in PyTorch. However,
# our :class:`torchrl.modules.LSTMModule` changes this default
# behaviour so we're good with a native call.
# our :class:`~torchrl.modules.LSTMModule` changes this default
# behavior, so we're good with a native call.
#
# Also, the LSTM cannot have a ``bidirectional`` attribute set to ``True`` as
# this wouldn't be usable in online settings. In this case, the default value
@@ -200,28 +218,28 @@
)
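######################################################################
# The arguments of the ``lstm`` construction above are collapsed in this diff.
# A hedged sketch of what such a call can look like (the sizes are
# illustrative, and ``device`` is assumed to be defined earlier in the
# tutorial):
#
# .. code-block:: python
#
#    from torchrl.modules import LSTMModule
#
#    lstm_sketch = LSTMModule(
#        input_size=64,       # must match the feature vector from the backbone
#        hidden_size=128,
#        device=device,       # assumed to exist in the tutorial's namespace
#        in_key="embed",
#        out_key="embed",
#    )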

######################################################################
# Let us look at the lstm class, specifically its in and out_keys:
# Let us look at the :class:`~torchrl.modules.LSTMModule` class, specifically its ``in_keys`` and ``out_keys``:
print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)

######################################################################
# We can see that these values contain the key we indicated as the in_key (and out_key)
# as well as recurrent key names. The out_keys are preceded by a "next" prefix
# that indicates that they will need to be written in the "next" tensordict.
# that indicates that they will need to be written in the "next" TensorDict.
# We use this convention (which can be overridden by passing the in_keys/out_keys
# arguments) to make sure that a call to :func:`torchrl.envs.step_mdp` will
# move the recurrent state to the root tensordict, making it available to the
# arguments) to make sure that a call to :func:`~torchrl.envs.utils.step_mdp` will
# move the recurrent state to the root TensorDict, making it available to the
# RNN during the following call (see figure in the intro).
#
# As mentioned earlier, we have one more optional transform to add to our
# environment to make sure that the recurrent states are passed to the buffer.
# The :meth:`torchrl.modules.LSTMModule.make_tensordict_primer` method does
# The :meth:`~torchrl.modules.LSTMModule.make_tensordict_primer` method does
# exactly that:
#
env.append_transform(lstm.make_tensordict_primer())

######################################################################
# and that's it! We can print the env to check that everything looks good now
# and that's it! We can print the environment to check that everything looks good now
# that we have added the primer:
print(env)

@@ -249,7 +267,8 @@
# Using the Q-Values to select an action
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The last part of our policy is the Q-Value Module. The Q-Value module :class:`torchrl.modules.QValueModule`
# The last part of our policy is the Q-Value Module.
# The Q-Value module :class:`~torchrl.modules.tensordict_module.QValueModule`
# will read the ``"action_values"`` key that is produced by our MLP and
# from it, gather the action that has the maximum value.
# The only thing we need to do is to specify the action space, which can be done
@@ -261,19 +280,20 @@
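######################################################################
# The construction itself is collapsed in this diff. A hedged sketch, assuming
# the action space is taken from the environment's spec:
#
# .. code-block:: python
#
#    from torchrl.modules import QValueModule
#
#    # env.action_spec is assumed to be available from the environment above
#    qval_sketch = QValueModule(action_space=env.action_spec)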
######################################################################
# .. note::
# TorchRL also provides a wrapper class :class:`torchrl.modules.QValueActor` that
# wraps a module in a Sequential together with a :class:`torchrl.modules.QValueModule`
# wraps a module in a Sequential together with a :class:`~torchrl.modules.tensordict_module.QValueModule`
# like we are doing explicitly here. There is little advantage to doing this
# and the process is less transparent, but the end results will be similar to
# what we do here.
#
# We can now put things together in a :class:`tensordict.nn.TensorDictSequential`
# We can now put things together in a :class:`~tensordict.nn.TensorDictSequential`
#
stoch_policy = Seq(feature, lstm, mlp, qval)
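######################################################################
# As an illustrative sanity check (not part of the tutorial), a single call on
# a reset TensorDict should produce an action together with the recurrent
# states for the next step (the default ``LSTMModule`` key names are assumed
# here):
#
# .. code-block:: python
#
#    td = env.reset()
#    td = stoch_policy(td)
#    print(td["action"])                           # one-hot action for CartPole
#    print(td["next", "recurrent_state_h"].shape)  # hidden state for step t+1 (assumed default key)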

######################################################################
# DQN being a deterministic algorithm, exploration is a crucial part of it.
# We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying
# progressively to 0. This decay is achieved via a call to :meth:`torchrl.modules.EGreedyWrapper.step`
# progressively to 0.
# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step`
# (see training loop below).
#
stoch_policy = EGreedyWrapper(
@@ -291,7 +311,7 @@
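######################################################################
# The remaining arguments of the ``EGreedyWrapper`` call above are collapsed in
# this diff. A hedged sketch of a comparable construction (the annealing length
# is illustrative):
#
# .. code-block:: python
#
#    from torchrl.modules import EGreedyWrapper
#
#    stoch_policy_sketch = EGreedyWrapper(
#        stoch_policy,
#        eps_init=0.2,                   # start at epsilon = 0.2
#        eps_end=0.0,                    # decay to 0
#        annealing_num_steps=1_000_000,  # illustrative schedule length
#        spec=env.action_spec,
#    )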
# To use it, we just need to tell the LSTM module to run on "recurrent-mode"
# when used by the loss.
# As we'll usually want to have two copies of the LSTM module, we do this by
# calling a :meth:`torchrl.modules.LSTMModule.set_recurrent_mode` method that
# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that
# will return a new instance of the LSTM (with shared weights) that will
# assume that the input data is sequential in nature.
#
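######################################################################
# A hedged sketch of that step (the actual line is collapsed in this diff),
# reusing the ``Seq`` alias and the modules defined earlier in the tutorial:
#
# .. code-block:: python
#
#    # a second view of the same LSTM weights, run in recurrent mode for the loss
#    policy_for_loss = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)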
@@ -309,7 +329,7 @@
#
# Our DQN loss requires us to pass the policy and, again, the action space.
# While this may seem redundant, it is important as we want to make sure that
# the :class:`torchrl.objectives.DQNLoss` and the :class:`torchrl.modules.QValueModule`
# the :class:`~torchrl.objectives.DQNLoss` and the :class:`~torchrl.modules.tensordict_module.QValueModule`
# classes are compatible, but aren't strongly dependent on each other.
#
# To use the Double-DQN, we ask for a ``delay_value`` argument that will
@@ -319,7 +339,7 @@
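######################################################################
# A hedged sketch of the loss construction (collapsed in this diff), reusing
# the recurrent-mode policy sketched above:
#
# .. code-block:: python
#
#    from torchrl.objectives import DQNLoss
#
#    # delay_value=True creates a detached copy of the value network to act as a target
#    loss_sketch = DQNLoss(policy_for_loss, action_space=env.action_spec, delay_value=True)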

######################################################################
# Since we are using a double DQN, we need to update the target parameters.
# We'll use a :class:`torchrl.objectives.SoftUpdate` instance to carry out
# We'll use a :class:`~torchrl.objectives.SoftUpdate` instance to carry out
# this work.
#
updater = SoftUpdate(loss_fn, eps=0.95)
@@ -335,7 +355,7 @@
# will be designed to store 20 thousand trajectories of 50 steps each.
# At each optimization step (16 per data collection), we'll collect 4 items
# from our buffer, for a total of 200 transitions.
# We'll use a :class:`torchrl.data.LazyMemmapStorage` storage to keep the data
# We'll use a :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` storage to keep the data
# on disk.
#
# .. note::
@@ -374,7 +394,7 @@
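######################################################################
# A hedged sketch of such a buffer (its actual construction, along with the
# collector and optimizer, is collapsed in this diff); the arguments mirror the
# numbers given above:
#
# .. code-block:: python
#
#    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
#
#    rb_sketch = TensorDictReplayBuffer(
#        storage=LazyMemmapStorage(20_000),  # up to 20k stored trajectories, memory-mapped on disk
#        batch_size=4,                       # 4 trajectories per sample
#        prefetch=10,                        # illustrative prefetch depth
#    )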
# it is important to pass data that is not flattened
rb.extend(data.unsqueeze(0).to_tensordict().cpu())
for _ in range(utd):
s = rb.sample().to(device)
s = rb.sample().to(device, non_blocking=True)
loss_vals = loss_fn(s)
loss_vals["loss"].backward()
optim.step()
@@ -386,10 +406,9 @@
stoch_policy.step(data.numel())
updater.step()

if i % 50 == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
rollout = env.rollout(10000, stoch_policy)
traj_lens.append(rollout.get(("next", "step_count")).max().item())
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
rollout = env.rollout(10000, stoch_policy)
traj_lens.append(rollout.get(("next", "step_count")).max().item())

######################################################################
# Let's plot our results:
@@ -405,14 +424,18 @@
# Conclusion
# ----------
#
# We have seen how an RNN can be incorporated in a policy in torchrl.
# We have seen how an RNN can be incorporated in a policy in TorchRL.
# You should now be able to:
#
# - To create an LSTM module that acts as a TensorDictModule;
# - How to indicate to the LSTMModule that a reset is needed via an :class:`torchrl.envs.InitTracker`
# transform.
# - Incorporate this module in a policy and in a loss module;
# - Create an LSTM module that acts as a :class:`~tensordict.nn.TensorDictModule`
# - Indicate to the LSTM module that a reset is needed via an :class:`~torchrl.envs.transforms.InitTracker`
# transform
# - Incorporate this module in a policy and in a loss module
# - Make sure that the collector is made aware of the recurrent state entries
# such that they can be stored in the replay buffer along with the rest of
# the data.
# the data
#
# Further Reading
# ---------------
#
# - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.