diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 5714de8..538a6c3 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -26,15 +26,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install .
+ pip install ".[dev]"
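+          # the "dev" extra should install flake8, pytest and pytest-cov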
- name: Lint with flake8
run: |
- pip install flake8
- # stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
+ flake8 . --count --show-source --statistics
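+          # without --exit-zero, any flake8 error now fails the build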
- name: Test with pytest
run: |
- pip install pytest pytest-cov
pytest test --cov tianshou
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e62cd45..d2d5ba9 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -24,8 +24,7 @@ pytest test --cov tianshou -s
We follow the PEP8 Python code style. To check, run the following in the main directory:
```bash
-flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
+flake8 . --count --show-source --statistics
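+# the CI workflow runs this same command (.github/workflows/pytest.yml)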
```
#### Documents
@@ -40,4 +39,4 @@ To compile docs into webpages, Run
make html
```
under the `docs/` directory. The generated webpages are in `docs/_build` and
-can be viewed with browsers.
\ No newline at end of file
+can be viewed with browsers.
diff --git a/README.md b/README.md
index 6b7e913..14703c4 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,9 @@
Tianshou
-
-
-[](https://tianshou.readthedocs.io/en/latest/?badge=latest)
+[](https://pypi.org/project/tianshou/)
+[](https://github.com/thu-ml/tianshou/actions)
+[](https://tianshou.readthedocs.io)
[](https://github.com/thu-ml/tianshou/stargazers)
[](https://github.com/thu-ml/tianshou/network)
[](https://github.com/thu-ml/tianshou/issues)
@@ -35,7 +35,7 @@ pip3 install tianshou
## Documentation
-The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io). It is under construction currently.
+The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io).
The example scripts are under the [test/](/test/) and [examples/](/examples/) folders.
@@ -53,16 +53,18 @@ We select some of famous (>1k stars) reinforcement learning platforms. Here is t
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
| Algo - Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
-| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | |
-| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | |
-| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | |
-| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | |
+| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | ? |
+| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | ? |
+| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | ? |
+| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
| TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
| SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
*: Could not reach the target reward threshold in 1e6 steps in any of 10 runs. The total runtime is shown in brackets.
+?: We have tried but it is nontrivial to run non-Atari games on rlpyt. See [here](https://github.com/astooke/rlpyt/issues/127#issuecomment-601741210).
+
All of the platforms use 10 different seeds for testing. We discard the trials that failed to train. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over 100 consecutive episodes' mean returns.
Tianshou's and RLlib's configurations are very similar. They both use multiple workers for sampling. Indeed, both RLlib and rlpyt are excellent reinforcement learning platforms :)
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000..8a63d27
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,3 @@
+Results on Ant-v2:
+
+
\ No newline at end of file
diff --git a/test/discrete/net.py b/test/discrete/net.py
index e3d9a4a..f272e09 100644
--- a/test/discrete/net.py
+++ b/test/discrete/net.py
@@ -74,7 +74,7 @@ class DQN(nn.Module):
def forward(self, x, state=None, info={}):
if not isinstance(x, torch.Tensor):
- s = torch.tensor(x, device=self.device, dtype=torch.float)
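+            # convert array-like observations to a float tensor on self.device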
+ x = torch.tensor(x, device=self.device, dtype=torch.float)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
diff --git a/tianshou/env/atari.py b/tianshou/env/atari.py
index 4e243cd..de44b69 100644
--- a/tianshou/env/atari.py
+++ b/tianshou/env/atari.py
@@ -1,13 +1,11 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
import cv2
import gym
import numpy as np
from gym.spaces.box import Box
-def create_atari_environment(name=None, sticky_actions=True, max_episode_steps=2000):
+def create_atari_environment(name=None, sticky_actions=True,
+ max_episode_steps=2000):
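+    # v0 uses sticky actions (an action repeats with probability 0.25); v4 is deterministic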
game_version = 'v0' if sticky_actions else 'v4'
name = '{}NoFrameskip-{}'.format(name, game_version)
env = gym.make(name)
@@ -61,7 +59,8 @@ class preprocessing(object):
self._grayscale_obs(self.screen_buffer[0])
self.screen_buffer[1].fill(0)
- return np.stack([self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
+ return np.stack([
+ self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
def render(self, mode):
@@ -95,7 +94,9 @@ class preprocessing(object):
if len(observation) > 0:
observation = np.stack(observation, axis=-1)
else:
- observation = np.stack([self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
+ observation = np.stack([
+ self._pool_and_resize() for _ in range(self.frame_skip)],
+ axis=-1)
if self.count >= self.max_episode_steps:
terminal = True
self.terminal = terminal
diff --git a/tianshou/env/mujoco/__init__.py b/tianshou/env/mujoco/__init__.py
index 5f53fd2..2e52e12 100644
--- a/tianshou/env/mujoco/__init__.py
+++ b/tianshou/env/mujoco/__init__.py
@@ -1,5 +1,4 @@
from gym.envs.registration import register
-import gym
register(
id='PointMaze-v0',
diff --git a/tianshou/env/mujoco/maze_env_utils.py b/tianshou/env/mujoco/maze_env_utils.py
index 059432e..dafce77 100644
--- a/tianshou/env/mujoco/maze_env_utils.py
+++ b/tianshou/env/mujoco/maze_env_utils.py
@@ -1,5 +1,4 @@
"""Adapted from rllab maze_env_utils.py."""
-import numpy as np
import math
@@ -111,25 +110,24 @@ def construct_maze(maze_id='Maze'):
[1, 1, 1, 1],
]
elif maze_id == 'Block':
- O = 'r'
structure = [
[1, 1, 1, 1, 1],
- [1, O, 0, 0, 1],
+ [1, 'r', 0, 0, 1],
[1, 0, 0, 0, 1],
[1, 0, 0, 0, 1],
[1, 1, 1, 1, 1],
]
elif maze_id == 'BlockMaze':
- O = 'r'
structure = [
[1, 1, 1, 1],
- [1, O, 0, 1],
+ [1, 'r', 0, 1],
[1, 1, 0, 1],
[1, 0, 0, 1],
[1, 1, 1, 1],
]
else:
- raise NotImplementedError('The provided MazeId %s is not recognized' % maze_id)
+ raise NotImplementedError(
+ 'The provided MazeId %s is not recognized' % maze_id)
return structure
@@ -157,7 +155,8 @@ def line_intersect(pt1, pt2, ptA, ptB):
DET = (-dx1 * dy + dy1 * dx)
- if math.fabs(DET) < DET_TOLERANCE: return (0, 0, 0, 0, 0)
+ if math.fabs(DET) < DET_TOLERANCE:
+ return (0, 0, 0, 0, 0)
# now, the determinant should be OK
DETinv = 1.0 / DET
@@ -176,8 +175,9 @@ def line_intersect(pt1, pt2, ptA, ptB):
def ray_segment_intersect(ray, segment):
"""
- Check if the ray originated from (x, y) with direction theta intersects the line segment (x1, y1) -- (x2, y2),
- and return the intersection point if there is one
+    Check if the ray originating from (x, y) with direction theta
+ intersects the line segment (x1, y1) -- (x2, y2), and return
+ the intersection point if there is one
"""
(x, y), theta = ray
# (x1, y1), (x2, y2) = segment
diff --git a/tianshou/env/mujoco/point_maze_env.py b/tianshou/env/mujoco/point_maze_env.py
index 9eb495b..81ce29d 100644
--- a/tianshou/env/mujoco/point_maze_env.py
+++ b/tianshou/env/mujoco/point_maze_env.py
@@ -108,7 +108,7 @@ class PointMazeEnv(gym.Env):
rgba="0.9 0.9 0.9 1",
)
if struct == 1: # Unmovable block.
- # Offset all coordinates so that robot starts at the origin.
+ # Offset all coordinates so that robot starts at the origin
ET.SubElement(
worldbody, "geom",
name="block_%d_%d" % (i, j),
@@ -134,13 +134,13 @@ class PointMazeEnv(gym.Env):
y_offset = 0.0
shrink = 0.1 if spinning else 0.99 if falling else 1.0
height_shrink = 0.1 if spinning else 1.0
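+                # block centre in x/y; _z is half the scaled block height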
+ _x = j * size_scaling - torso_x + x_offset
+ _y = i * size_scaling - torso_y + y_offset
+ _z = height / 2 * size_scaling * height_shrink
movable_body = ET.SubElement(
worldbody, "body",
name=name,
- pos="%f %f %f" % (j * size_scaling - torso_x + x_offset,
- i * size_scaling - torso_y + y_offset,
- height_offset +
- height / 2 * size_scaling * height_shrink),
+ pos="%f %f %f" % (_x, _y, height_offset + _z),
)
ET.SubElement(
movable_body, "geom",
@@ -148,7 +148,7 @@ class PointMazeEnv(gym.Env):
pos="0 0 0",
size="%f %f %f" % (0.5 * size_scaling * shrink,
0.5 * size_scaling * shrink,
- height / 2 * size_scaling * height_shrink),
+ _z),
type="box",
material="",
mass="0.001" if falling else "0.0002",
@@ -232,13 +232,13 @@ class PointMazeEnv(gym.Env):
self._view = np.zeros_like(self._view)
def valid(row, col):
- return self._view.shape[0] > row >= 0 and self._view.shape[1] > col >= 0
+ return self._view.shape[0] > row >= 0 \
+ and self._view.shape[1] > col >= 0
def update_view(x, y, d, row=None, col=None):
if row is None or col is None:
x = x - self._robot_x
y = y - self._robot_y
- th = self._robot_ori
row, col = self._xy_to_rowcol(x, y)
update_view(x, y, d, row=row, col=col)
@@ -252,36 +252,36 @@ class PointMazeEnv(gym.Env):
if valid(row, col):
self._view[row, col, d] += (
- (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
- (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
+ (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
+ (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
if valid(row - 1, col):
self._view[row - 1, col, d] += (
- (max(0., 0.5 - row_frac)) *
- (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
+ (max(0., 0.5 - row_frac)) *
+ (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
if valid(row + 1, col):
self._view[row + 1, col, d] += (
- (max(0., row_frac - 0.5)) *
- (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
+ (max(0., row_frac - 0.5)) *
+ (min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
if valid(row, col - 1):
self._view[row, col - 1, d] += (
- (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
- (max(0., 0.5 - col_frac)))
+ (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
+ (max(0., 0.5 - col_frac)))
if valid(row, col + 1):
self._view[row, col + 1, d] += (
- (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
- (max(0., col_frac - 0.5)))
+ (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
+ (max(0., col_frac - 0.5)))
if valid(row - 1, col - 1):
self._view[row - 1, col - 1, d] += (
- (max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
+ (max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
if valid(row - 1, col + 1):
self._view[row - 1, col + 1, d] += (
- (max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
+ (max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
if valid(row + 1, col + 1):
self._view[row + 1, col + 1, d] += (
- (max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
+ (max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
if valid(row + 1, col - 1):
self._view[row + 1, col - 1, d] += (
- (max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
+ (max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
# Draw ant.
robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2]
@@ -291,7 +291,6 @@ class PointMazeEnv(gym.Env):
structure = self.MAZE_STRUCTURE
size_scaling = self.MAZE_SIZE_SCALING
- height = self.MAZE_HEIGHT
# Draw immovable blocks and chasms.
for i in range(len(structure)):
@@ -311,7 +310,9 @@ class PointMazeEnv(gym.Env):
update_view(block_x, block_y, 2)
import cv2
- cv2.imshow('x.jpg', cv2.resize(np.uint8(self._view * 255), (512, 512), interpolation=cv2.INTER_CUBIC))
+ cv2.imshow('x.jpg', cv2.resize(
+ np.uint8(self._view * 255), (512, 512),
+ interpolation=cv2.INTER_CUBIC))
cv2.waitKey(0)
return self._view
@@ -350,10 +351,11 @@ class PointMazeEnv(gym.Env):
))
for block_name, block_type in self.movable_blocks:
- block_x, block_y, block_z = self.wrapped_env.get_body_com(block_name)[
- :3]
+ block_x, block_y, block_z = \
+ self.wrapped_env.get_body_com(block_name)[:3]
if (block_z + height * size_scaling / 2 >= robot_z and
- robot_z >= block_z - height * size_scaling / 2): # Block in view.
+ robot_z >= block_z - height * size_scaling / 2):
+ # Block in view.
x1 = block_x - 0.5 * size_scaling
x2 = block_x + 0.5 * size_scaling
y1 = block_y - 0.5 * size_scaling
@@ -373,8 +375,8 @@ class PointMazeEnv(gym.Env):
# 3 for wall, drop-off, block
sensor_readings = np.zeros((self._n_bins, 3))
for ray_idx in range(self._n_bins):
- ray_ori = (ori - self._sensor_span * 0.5 +
- (2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
+            ray_ori = (ori - self._sensor_span * 0.5 + (2 * ray_idx + 1.0)
+                       / (2 * self._n_bins) * self._sensor_span)
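+            # ray directions evenly divide the sensor span, centred on ori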
ray_segments = []
# Get all segments that intersect with ray.
for seg in segments:
@@ -400,17 +402,14 @@ class PointMazeEnv(gym.Env):
None)
if first_seg["distance"] <= self._sensor_range:
sensor_readings[ray_idx][idx] = (
- self._sensor_range - first_seg[
- "distance"]) / self._sensor_range
+                        self._sensor_range -
+                        first_seg["distance"]) / self._sensor_range
return sensor_readings
def _get_obs(self):
wrapped_obs = self.wrapped_env._get_obs()
- # print("ant obs", wrapped_obs)
if self._top_down_view:
- view = [self.get_top_down_view().flat]
- else:
- view = []
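+            # refresh self._view; the returned view is not used here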
+ self.get_top_down_view()
if self._observe_blocks:
additional_obs = []
@@ -420,7 +419,7 @@ class PointMazeEnv(gym.Env):
wrapped_obs = np.concatenate([wrapped_obs[:3]] + additional_obs +
[wrapped_obs[3:]])
- range_sensor_obs = self.get_range_sensor_obs()
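+        # the return value is discarded; only wrapped_obs is returned below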
+ self.get_range_sensor_obs()
return wrapped_obs
def seed(self, seed=None):
@@ -446,7 +445,8 @@ class PointMazeEnv(gym.Env):
pos="%f %f %f" % (goal_x,
goal_y,
self.MAZE_HEIGHT / 2 * size_scaling),
- size="%f %f %f" % (0.1 * size_scaling, # smaller than the block to prevent collision
+ # smaller than the block to prevent collision
+ size="%f %f %f" % (0.1 * size_scaling,
0.1 * size_scaling,
self.MAZE_HEIGHT / 2 * size_scaling),
type="box",
@@ -455,7 +455,8 @@ class PointMazeEnv(gym.Env):
conaffinity="1",
rgba="1.0 0.0 0.0 0.5"
)
- # Note: running the lines below will make the robot position wrong! (because the graph is rebuilt)
+ # Note: running the lines below will make the robot position wrong!
+ # (because the graph is rebuilt)
torso = self.tree.find(".//body[@name='torso']")
geoms = torso.findall(".//geom")
for geom in geoms:
@@ -463,18 +464,21 @@ class PointMazeEnv(gym.Env):
raise Exception("Every geom of the torso must have a name "
"defined")
_, file_path = tempfile.mkstemp(text=True, suffix='.xml')
- self.tree.write(
- file_path) # here we write a temporal file with the robot specifications. Why not the original one??
+ self.tree.write(file_path)
+        # here we write a temporary file with the robot specification.
+        # Why not the original one?
model_cls = self.__class__.MODEL_CLASS
- self.wrapped_env = model_cls(*self.args, file_path=file_path,
- **self.kwargs) # file to the robot specifications; model_cls is AntEnv
+        # file_path points to the robot specification; model_cls is AntEnv
+ self.wrapped_env = model_cls(
+ *self.args, file_path=file_path, **self.kwargs)
self.t = 0
self.trajectory = []
self.wrapped_env.reset()
if len(self._init_positions) > 1:
- xy = self._init_positions[self.np_random.randint(len(self._init_positions))]
+ xy = self._init_positions[self.np_random.randint(
+ len(self._init_positions))]
self.wrapped_env.set_xy(xy)
return self._get_obs()
@@ -518,38 +522,38 @@ class PointMazeEnv(gym.Env):
def _is_in_collision(self, pos):
x, y = pos
structure = self.MAZE_STRUCTURE
- size_scaling = self.MAZE_SIZE_SCALING
+ scale = self.MAZE_SIZE_SCALING
for i in range(len(structure)):
for j in range(len(structure[0])):
if structure[i][j] == 1:
- minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
- maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
- miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
- maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
+ minx = j * scale - scale * 0.5 - self._init_torso_x
+ maxx = j * scale + scale * 0.5 - self._init_torso_x
+ miny = i * scale - scale * 0.5 - self._init_torso_y
+ maxy = i * scale + scale * 0.5 - self._init_torso_y
if minx <= x <= maxx and miny <= y <= maxy:
return True
return False
def _rowcol_to_xy(self, j, i):
- size_scaling = self.MAZE_SIZE_SCALING
- minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
- maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
- miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
- maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
+ scale = self.MAZE_SIZE_SCALING
+ minx = j * scale - scale * 0.5 - self._init_torso_x
+ maxx = j * scale + scale * 0.5 - self._init_torso_x
+ miny = i * scale - scale * 0.5 - self._init_torso_y
+ maxy = i * scale + scale * 0.5 - self._init_torso_y
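+        # return the centre of cell (i, j) in world coordinates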
return (minx + maxx) / 2, (miny + maxy) / 2
def step(self, action):
self.t += 1
if self._manual_collision:
old_pos = self.wrapped_env.get_xy()
- inner_next_obs, inner_reward, inner_done, info = self.wrapped_env.step(
- action)
+ inner_next_obs, inner_reward, inner_done, info = \
+ self.wrapped_env.step(action)
new_pos = self.wrapped_env.get_xy()
if self._is_in_collision(new_pos):
self.wrapped_env.set_xy(old_pos)
else:
- inner_next_obs, inner_reward, inner_done, info = self.wrapped_env.step(
- action)
+ inner_next_obs, inner_reward, inner_done, info = \
+ self.wrapped_env.step(action)
next_obs = self._get_obs()
done = False
if self.goal is not None: