Skip to content

Upgrade to Gymnasium interface #6203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ repos:
(?x)^(
.*cs.meta|
.*.css|
.*.meta
.*.meta|
.*.asset|
.*.prefab|
.*.unity|
.*.json
)$
args: [--fix=lf]
args: [--fix=crlf]

- id: trailing-whitespace
name: trailing-whitespace-markdown
Expand Down
4 changes: 2 additions & 2 deletions colab/Colab_UnityEnvironment_4_SB3VectorEnv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@
"from pathlib import Path\n",
"from typing import Callable, Any\n",
"\n",
"import gym\n",
"from gym import Env\n",
"import gymnasium as gym\n",
"from gymnasium import Env\n",
"\n",
"from stable_baselines3 import PPO\n",
"from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv\n",
Expand Down
19 changes: 12 additions & 7 deletions docs/Python-Gym-API-Documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,22 @@ Environment initialization
#### reset

```python
| reset() -> Union[List[np.ndarray], np.ndarray]
| reset(*, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[np.ndarray, Dict]
```

Resets the state of the environment and returns an initial observation.
Returns: observation (object/list): the initial observation of the
space.
Resets the state of the environment and returns an initial observation and info.

**Returns**:

- `observation` _object/list_ - the initial observation of the
space.
- `info` _dict_ - contains auxiliary diagnostic information.

<a name="mlagents_envs.envs.unity_gym_env.UnityToGymWrapper.step"></a>
#### step

```python
| step(action: List[Any]) -> GymStepResult
| step(action: Any) -> GymStepResult
```

Run one timestep of the environment's dynamics. When end of
Expand All @@ -86,14 +90,15 @@ Accepts an action and returns a tuple (observation, reward, done, info).

- `observation` _object/list_ - agent's observation of the current environment
- `reward` _float/list_ - amount of reward returned after previous action
- `done` _boolean/list_ - whether the episode has ended.
- `terminated` _boolean/list_ - whether the episode has ended by termination.
- `truncated` _boolean/list_ - whether the episode has ended by truncation.
- `info` _dict_ - contains auxiliary diagnostic information.

<a name="mlagents_envs.envs.unity_gym_env.UnityToGymWrapper.render"></a>
#### render

```python
| render(mode="rgb_array")
| render()
```

Return the latest visual observations.
Expand Down
2 changes: 1 addition & 1 deletion docs/Python-Gym-API.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ observation, a single discrete action and a single Agent in the scene.
Add the following code to the `train_unity.py` file:

```python
import gym
import gymnasium as gym

from baselines import deepq
from baselines import logger
Expand Down
2 changes: 1 addition & 1 deletion ml-agents-envs/mlagents_envs/envs/unity_aec_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, Optional
from gym import error
from gymnasium import error
from mlagents_envs.base_env import BaseEnv
from pettingzoo import AECEnv

Expand Down
44 changes: 30 additions & 14 deletions ml-agents-envs/mlagents_envs/envs/unity_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union

import gym
from gym import error, spaces
import gymnasium as gym
from gymnasium import error, spaces

from mlagents_envs.base_env import ActionTuple, BaseEnv
from mlagents_envs.base_env import DecisionSteps, TerminalSteps
Expand All @@ -20,7 +20,7 @@ class UnityGymException(error.Error):


logger = logging_util.get_logger(__name__)
GymStepResult = Tuple[np.ndarray, float, bool, Dict]
GymStepResult = Tuple[np.ndarray, float, bool, bool, Dict]


class UnityToGymWrapper(gym.Env):
Expand Down Expand Up @@ -151,21 +151,26 @@ def __init__(
else:
self._observation_space = list_spaces[0] # only return the first one

def reset(self) -> Union[List[np.ndarray], np.ndarray]:
"""Resets the state of the environment and returns an initial observation.
Returns: observation (object/list): the initial observation of the
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> Tuple[np.ndarray, Dict]:
"""Resets the state of the environment and returns an initial observation and info.
Returns:
observation (object/list): the initial observation of the
space.
info (dict): contains auxiliary diagnostic information.
"""
super().reset(seed=seed, options=options)
self._env.reset()
decision_step, _ = self._env.get_steps(self.name)
n_agents = len(decision_step)
self._check_agents(n_agents)
self.game_over = False

res: GymStepResult = self._single_step(decision_step)
return res[0]
return res[0], res[4]

def step(self, action: List[Any]) -> GymStepResult:
def step(self, action: Any) -> GymStepResult:
"""Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling `reset()`
to reset this environment's state.
Expand All @@ -175,14 +180,15 @@ def step(self, action: List[Any]) -> GymStepResult:
Returns:
observation (object/list): agent's observation of the current environment
reward (float/list) : amount of reward returned after previous action
done (boolean/list): whether the episode has ended.
terminated (boolean/list): whether the episode has ended by termination.
truncated (boolean/list): whether the episode has ended by truncation.
info (dict): contains auxiliary diagnostic information.
"""
if self.game_over:
raise UnityGymException(
"You are calling 'step()' even though this environment has already "
"returned done = True. You must always call 'reset()' once you "
"receive 'done = True'."
"returned `terminated` or `truncated` as True. You must always call 'reset()' once you "
"receive `terminated` or `truncated` as True."
)
if self._flattener is not None:
# Translate action into list
Expand Down Expand Up @@ -227,9 +233,19 @@ def _single_step(self, info: Union[DecisionSteps, TerminalSteps]) -> GymStepResu
visual_obs = self._get_vis_obs_list(info)
self.visual_obs = self._preprocess_single(visual_obs[0][0])

done = isinstance(info, TerminalSteps)
if isinstance(info, TerminalSteps):
interrupted = info.interrupted
terminated, truncated = not interrupted, interrupted
else:
terminated, truncated = False, False

return (default_observation, info.reward[0], done, {"step": info})
return (
default_observation,
info.reward[0],
terminated,
truncated,
{"step": info},
)

def _preprocess_single(self, single_visual_obs: np.ndarray) -> np.ndarray:
if self.uint8_visual:
Expand Down Expand Up @@ -276,7 +292,7 @@ def _get_vec_obs_size(self) -> int:
result += obs_spec.shape[0]
return result

def render(self, mode="rgb_array"):
def render(self):
"""
Return the latest visual observations.
Note that it will not render a new frame of the environment.
Expand Down
2 changes: 1 addition & 1 deletion ml-agents-envs/mlagents_envs/envs/unity_parallel_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Optional, Dict, Any, Tuple
from gym import error
from gymnasium import error
from mlagents_envs.base_env import BaseEnv
from pettingzoo import ParallelEnv

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import atexit
from typing import Optional, List, Set, Dict, Any, Tuple
import numpy as np
from gym import error, spaces
from gymnasium import error, spaces
from mlagents_envs.base_env import BaseEnv, ActionTuple
from mlagents_envs.envs.env_helpers import _agent_id_to_behavior, _unwrap_batch_steps

Expand Down
8 changes: 4 additions & 4 deletions ml-agents-envs/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def run(self):
"Pillow>=4.2.1",
"protobuf>=3.6,<3.21",
"pyyaml>=3.1.0",
"gym>=0.21.0",
"pettingzoo==1.15.0",
"numpy>=1.23.5,<1.24.0",
"gymnasium>=0.25.0",
"pettingzoo>=1.15.0",
"numpy>=1.23.5,<2.0",
"filelock>=3.4.0",
],
python_requires=">=3.10.1,<=3.10.12",
python_requires=">=3.9,<4",
# TODO: Remove this once mypy stops having spurious setuptools issues.
cmdclass={"verify": VerifyVersionCommand}, # type: ignore
)
2 changes: 1 addition & 1 deletion ml-agents-envs/tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import numpy as np

from gym import spaces
from gymnasium import spaces

from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from mlagents_envs.base_env import (
Expand Down
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/subprocess_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
UnityCommunicatorStoppedException,
)
from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from multiprocessing.connection import Connection, PipeConnection
from queue import Empty as EmptyQueueException
from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec
from mlagents_envs import logging_util
Expand Down Expand Up @@ -77,7 +77,7 @@ class StepResponse(NamedTuple):


class UnityEnvWorker:
def __init__(self, process: Process, worker_id: int, conn: Connection):
def __init__(self, process: Process, worker_id: int, conn: PipeConnection):
self.process = process
self.worker_id = worker_id
self.conn = conn
Expand Down
7 changes: 3 additions & 4 deletions utils/generate_markdown_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import argparse
import hashlib


# pydoc-markdown -I . -m module_name --render_toc > doc.md


Expand Down Expand Up @@ -52,8 +51,8 @@ def remove_trailing_whitespace(filename):
# compare source and destination and write only if changed
if source_file != destination_file:
num_changed += 1
with open(filename, "wb") as f:
f.write(destination_file.encode())
with open(filename, "w", newline="\r\n") as f:
f.write(destination_file)


if __name__ == "__main__":
Expand Down Expand Up @@ -84,7 +83,7 @@ def remove_trailing_whitespace(filename):
for submodule in submodules:
module_args.append("-m")
module_args.append(f"{module_name}.{submodule}")
with open(output_file_name, "w") as output_file:
with open(output_file_name, "wb") as output_file:
subprocess_args = [
"pydoc-markdown",
"-I",
Expand Down