update readme and force flake8
This commit is contained in:
parent
068c4068ec
commit
f68f23292e
9
.github/workflows/pytest.yml
vendored
9
.github/workflows/pytest.yml
vendored
@ -26,15 +26,10 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .
|
pip install ".[dev]"
|
||||||
- name: Lint with flake8
|
- name: Lint with flake8
|
||||||
run: |
|
run: |
|
||||||
pip install flake8
|
flake8 . --count --show-source --statistics
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
|
||||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
|
||||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
|
||||||
flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
pip install pytest pytest-cov
|
|
||||||
pytest test --cov tianshou
|
pytest test --cov tianshou
|
||||||
|
@ -24,8 +24,7 @@ pytest test --cov tianshou -s
|
|||||||
|
|
||||||
We follow PEP8 python code style. To check, in the main directory, run:
|
We follow PEP8 python code style. To check, in the main directory, run:
|
||||||
```python
|
```python
|
||||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
flake8 . --count --show-source --statistics
|
||||||
flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Documents
|
#### Documents
|
||||||
|
18
README.md
18
README.md
@ -1,9 +1,9 @@
|
|||||||
|
|
||||||
<h1 align="center">Tianshou</h1>
|
<h1 align="center">Tianshou</h1>
|
||||||
|
|
||||||

|
[](https://pypi.org/project/tianshou/)
|
||||||

|
[](https://github.com/thu-ml/tianshou/actions)
|
||||||
[](https://tianshou.readthedocs.io/en/latest/?badge=latest)
|
[](https://tianshou.readthedocs.io)
|
||||||
[](https://github.com/thu-ml/tianshou/stargazers)
|
[](https://github.com/thu-ml/tianshou/stargazers)
|
||||||
[](https://github.com/thu-ml/tianshou/network)
|
[](https://github.com/thu-ml/tianshou/network)
|
||||||
[](https://github.com/thu-ml/tianshou/issues)
|
[](https://github.com/thu-ml/tianshou/issues)
|
||||||
@ -35,7 +35,7 @@ pip3 install tianshou
|
|||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io). It is under construction currently.
|
The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io).
|
||||||
|
|
||||||
The example scripts are under [test/](/test/) folder and [examples/](/examples/) folder.
|
The example scripts are under [test/](/test/) folder and [examples/](/examples/) folder.
|
||||||
|
|
||||||
@ -53,16 +53,18 @@ We select some of famous (>1k stars) reinforcement learning platforms. Here is t
|
|||||||
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||||
| GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
|
| GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
|
||||||
| Algo - Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
|
| Algo - Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
|
||||||
| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | |
|
| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | ? |
|
||||||
| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | |
|
| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | ? |
|
||||||
| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | |
|
| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | ? |
|
||||||
| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | |
|
| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
|
||||||
| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
|
| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
|
||||||
| TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
|
| TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
|
||||||
| SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
|
| SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
|
||||||
|
|
||||||
*: Could not reach the target reward threshold in 1e6 steps in any of 10 runs. The total runtime is in the brackets.
|
*: Could not reach the target reward threshold in 1e6 steps in any of 10 runs. The total runtime is in the brackets.
|
||||||
|
|
||||||
|
?: We have tried but it is nontrivial for running non-Atari game on rlpyt. See [here](https://github.com/astooke/rlpyt/issues/127#issuecomment-601741210).
|
||||||
|
|
||||||
All of the platforms use 10 different seeds for testing. We erase those trials which failed for training. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes' mean returns.
|
All of the platforms use 10 different seeds for testing. We erase those trials which failed for training. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes' mean returns.
|
||||||
|
|
||||||
Tianshou and RLlib's configures are very similar. They both use multiple workers for sampling. Indeed, both RLlib and rlpyt are excellent reinforcement learning platform :)
|
Tianshou and RLlib's configures are very similar. They both use multiple workers for sampling. Indeed, both RLlib and rlpyt are excellent reinforcement learning platform :)
|
||||||
|
3
examples/README.md
Normal file
3
examples/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
Result of Ant-v2:
|
||||||
|
|
||||||
|

|
@ -74,7 +74,7 @@ class DQN(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, state=None, info={}):
|
def forward(self, x, state=None, info={}):
|
||||||
if not isinstance(x, torch.Tensor):
|
if not isinstance(x, torch.Tensor):
|
||||||
s = torch.tensor(x, device=self.device, dtype=torch.float)
|
x = torch.tensor(x, device=self.device, dtype=torch.float)
|
||||||
x = F.relu(self.bn1(self.conv1(x)))
|
x = F.relu(self.bn1(self.conv1(x)))
|
||||||
x = F.relu(self.bn2(self.conv2(x)))
|
x = F.relu(self.bn2(self.conv2(x)))
|
||||||
x = F.relu(self.bn3(self.conv3(x)))
|
x = F.relu(self.bn3(self.conv3(x)))
|
||||||
|
13
tianshou/env/atari.py
vendored
13
tianshou/env/atari.py
vendored
@ -1,13 +1,11 @@
|
|||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
import cv2
|
import cv2
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.spaces.box import Box
|
from gym.spaces.box import Box
|
||||||
|
|
||||||
|
|
||||||
def create_atari_environment(name=None, sticky_actions=True, max_episode_steps=2000):
|
def create_atari_environment(name=None, sticky_actions=True,
|
||||||
|
max_episode_steps=2000):
|
||||||
game_version = 'v0' if sticky_actions else 'v4'
|
game_version = 'v0' if sticky_actions else 'v4'
|
||||||
name = '{}NoFrameskip-{}'.format(name, game_version)
|
name = '{}NoFrameskip-{}'.format(name, game_version)
|
||||||
env = gym.make(name)
|
env = gym.make(name)
|
||||||
@ -61,7 +59,8 @@ class preprocessing(object):
|
|||||||
self._grayscale_obs(self.screen_buffer[0])
|
self._grayscale_obs(self.screen_buffer[0])
|
||||||
self.screen_buffer[1].fill(0)
|
self.screen_buffer[1].fill(0)
|
||||||
|
|
||||||
return np.stack([self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
return np.stack([
|
||||||
|
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
||||||
|
|
||||||
def render(self, mode):
|
def render(self, mode):
|
||||||
|
|
||||||
@ -95,7 +94,9 @@ class preprocessing(object):
|
|||||||
if len(observation) > 0:
|
if len(observation) > 0:
|
||||||
observation = np.stack(observation, axis=-1)
|
observation = np.stack(observation, axis=-1)
|
||||||
else:
|
else:
|
||||||
observation = np.stack([self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
observation = np.stack([
|
||||||
|
self._pool_and_resize() for _ in range(self.frame_skip)],
|
||||||
|
axis=-1)
|
||||||
if self.count >= self.max_episode_steps:
|
if self.count >= self.max_episode_steps:
|
||||||
terminal = True
|
terminal = True
|
||||||
self.terminal = terminal
|
self.terminal = terminal
|
||||||
|
1
tianshou/env/mujoco/__init__.py
vendored
1
tianshou/env/mujoco/__init__.py
vendored
@ -1,5 +1,4 @@
|
|||||||
from gym.envs.registration import register
|
from gym.envs.registration import register
|
||||||
import gym
|
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='PointMaze-v0',
|
id='PointMaze-v0',
|
||||||
|
18
tianshou/env/mujoco/maze_env_utils.py
vendored
18
tianshou/env/mujoco/maze_env_utils.py
vendored
@ -1,5 +1,4 @@
|
|||||||
"""Adapted from rllab maze_env_utils.py."""
|
"""Adapted from rllab maze_env_utils.py."""
|
||||||
import numpy as np
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
@ -111,25 +110,24 @@ def construct_maze(maze_id='Maze'):
|
|||||||
[1, 1, 1, 1],
|
[1, 1, 1, 1],
|
||||||
]
|
]
|
||||||
elif maze_id == 'Block':
|
elif maze_id == 'Block':
|
||||||
O = 'r'
|
|
||||||
structure = [
|
structure = [
|
||||||
[1, 1, 1, 1, 1],
|
[1, 1, 1, 1, 1],
|
||||||
[1, O, 0, 0, 1],
|
[1, 'r', 0, 0, 1],
|
||||||
[1, 0, 0, 0, 1],
|
[1, 0, 0, 0, 1],
|
||||||
[1, 0, 0, 0, 1],
|
[1, 0, 0, 0, 1],
|
||||||
[1, 1, 1, 1, 1],
|
[1, 1, 1, 1, 1],
|
||||||
]
|
]
|
||||||
elif maze_id == 'BlockMaze':
|
elif maze_id == 'BlockMaze':
|
||||||
O = 'r'
|
|
||||||
structure = [
|
structure = [
|
||||||
[1, 1, 1, 1],
|
[1, 1, 1, 1],
|
||||||
[1, O, 0, 1],
|
[1, 'r', 0, 1],
|
||||||
[1, 1, 0, 1],
|
[1, 1, 0, 1],
|
||||||
[1, 0, 0, 1],
|
[1, 0, 0, 1],
|
||||||
[1, 1, 1, 1],
|
[1, 1, 1, 1],
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('The provided MazeId %s is not recognized' % maze_id)
|
raise NotImplementedError(
|
||||||
|
'The provided MazeId %s is not recognized' % maze_id)
|
||||||
|
|
||||||
return structure
|
return structure
|
||||||
|
|
||||||
@ -157,7 +155,8 @@ def line_intersect(pt1, pt2, ptA, ptB):
|
|||||||
|
|
||||||
DET = (-dx1 * dy + dy1 * dx)
|
DET = (-dx1 * dy + dy1 * dx)
|
||||||
|
|
||||||
if math.fabs(DET) < DET_TOLERANCE: return (0, 0, 0, 0, 0)
|
if math.fabs(DET) < DET_TOLERANCE:
|
||||||
|
return (0, 0, 0, 0, 0)
|
||||||
|
|
||||||
# now, the determinant should be OK
|
# now, the determinant should be OK
|
||||||
DETinv = 1.0 / DET
|
DETinv = 1.0 / DET
|
||||||
@ -176,8 +175,9 @@ def line_intersect(pt1, pt2, ptA, ptB):
|
|||||||
|
|
||||||
def ray_segment_intersect(ray, segment):
|
def ray_segment_intersect(ray, segment):
|
||||||
"""
|
"""
|
||||||
Check if the ray originated from (x, y) with direction theta intersects the line segment (x1, y1) -- (x2, y2),
|
Check if the ray originated from (x, y) with direction theta
|
||||||
and return the intersection point if there is one
|
intersects the line segment (x1, y1) -- (x2, y2), and return
|
||||||
|
the intersection point if there is one
|
||||||
"""
|
"""
|
||||||
(x, y), theta = ray
|
(x, y), theta = ray
|
||||||
# (x1, y1), (x2, y2) = segment
|
# (x1, y1), (x2, y2) = segment
|
||||||
|
86
tianshou/env/mujoco/point_maze_env.py
vendored
86
tianshou/env/mujoco/point_maze_env.py
vendored
@ -108,7 +108,7 @@ class PointMazeEnv(gym.Env):
|
|||||||
rgba="0.9 0.9 0.9 1",
|
rgba="0.9 0.9 0.9 1",
|
||||||
)
|
)
|
||||||
if struct == 1: # Unmovable block.
|
if struct == 1: # Unmovable block.
|
||||||
# Offset all coordinates so that robot starts at the origin.
|
# Offset all coordinates so that robot starts at the origin
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
worldbody, "geom",
|
worldbody, "geom",
|
||||||
name="block_%d_%d" % (i, j),
|
name="block_%d_%d" % (i, j),
|
||||||
@ -134,13 +134,13 @@ class PointMazeEnv(gym.Env):
|
|||||||
y_offset = 0.0
|
y_offset = 0.0
|
||||||
shrink = 0.1 if spinning else 0.99 if falling else 1.0
|
shrink = 0.1 if spinning else 0.99 if falling else 1.0
|
||||||
height_shrink = 0.1 if spinning else 1.0
|
height_shrink = 0.1 if spinning else 1.0
|
||||||
|
_x = j * size_scaling - torso_x + x_offset
|
||||||
|
_y = i * size_scaling - torso_y + y_offset
|
||||||
|
_z = height / 2 * size_scaling * height_shrink
|
||||||
movable_body = ET.SubElement(
|
movable_body = ET.SubElement(
|
||||||
worldbody, "body",
|
worldbody, "body",
|
||||||
name=name,
|
name=name,
|
||||||
pos="%f %f %f" % (j * size_scaling - torso_x + x_offset,
|
pos="%f %f %f" % (_x, _y, height_offset + _z),
|
||||||
i * size_scaling - torso_y + y_offset,
|
|
||||||
height_offset +
|
|
||||||
height / 2 * size_scaling * height_shrink),
|
|
||||||
)
|
)
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
movable_body, "geom",
|
movable_body, "geom",
|
||||||
@ -148,7 +148,7 @@ class PointMazeEnv(gym.Env):
|
|||||||
pos="0 0 0",
|
pos="0 0 0",
|
||||||
size="%f %f %f" % (0.5 * size_scaling * shrink,
|
size="%f %f %f" % (0.5 * size_scaling * shrink,
|
||||||
0.5 * size_scaling * shrink,
|
0.5 * size_scaling * shrink,
|
||||||
height / 2 * size_scaling * height_shrink),
|
_z),
|
||||||
type="box",
|
type="box",
|
||||||
material="",
|
material="",
|
||||||
mass="0.001" if falling else "0.0002",
|
mass="0.001" if falling else "0.0002",
|
||||||
@ -232,13 +232,13 @@ class PointMazeEnv(gym.Env):
|
|||||||
self._view = np.zeros_like(self._view)
|
self._view = np.zeros_like(self._view)
|
||||||
|
|
||||||
def valid(row, col):
|
def valid(row, col):
|
||||||
return self._view.shape[0] > row >= 0 and self._view.shape[1] > col >= 0
|
return self._view.shape[0] > row >= 0 \
|
||||||
|
and self._view.shape[1] > col >= 0
|
||||||
|
|
||||||
def update_view(x, y, d, row=None, col=None):
|
def update_view(x, y, d, row=None, col=None):
|
||||||
if row is None or col is None:
|
if row is None or col is None:
|
||||||
x = x - self._robot_x
|
x = x - self._robot_x
|
||||||
y = y - self._robot_y
|
y = y - self._robot_y
|
||||||
th = self._robot_ori
|
|
||||||
|
|
||||||
row, col = self._xy_to_rowcol(x, y)
|
row, col = self._xy_to_rowcol(x, y)
|
||||||
update_view(x, y, d, row=row, col=col)
|
update_view(x, y, d, row=row, col=col)
|
||||||
@ -291,7 +291,6 @@ class PointMazeEnv(gym.Env):
|
|||||||
|
|
||||||
structure = self.MAZE_STRUCTURE
|
structure = self.MAZE_STRUCTURE
|
||||||
size_scaling = self.MAZE_SIZE_SCALING
|
size_scaling = self.MAZE_SIZE_SCALING
|
||||||
height = self.MAZE_HEIGHT
|
|
||||||
|
|
||||||
# Draw immovable blocks and chasms.
|
# Draw immovable blocks and chasms.
|
||||||
for i in range(len(structure)):
|
for i in range(len(structure)):
|
||||||
@ -311,7 +310,9 @@ class PointMazeEnv(gym.Env):
|
|||||||
update_view(block_x, block_y, 2)
|
update_view(block_x, block_y, 2)
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
cv2.imshow('x.jpg', cv2.resize(np.uint8(self._view * 255), (512, 512), interpolation=cv2.INTER_CUBIC))
|
cv2.imshow('x.jpg', cv2.resize(
|
||||||
|
np.uint8(self._view * 255), (512, 512),
|
||||||
|
interpolation=cv2.INTER_CUBIC))
|
||||||
cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
|
|
||||||
return self._view
|
return self._view
|
||||||
@ -350,10 +351,11 @@ class PointMazeEnv(gym.Env):
|
|||||||
))
|
))
|
||||||
|
|
||||||
for block_name, block_type in self.movable_blocks:
|
for block_name, block_type in self.movable_blocks:
|
||||||
block_x, block_y, block_z = self.wrapped_env.get_body_com(block_name)[
|
block_x, block_y, block_z = \
|
||||||
:3]
|
self.wrapped_env.get_body_com(block_name)[:3]
|
||||||
if (block_z + height * size_scaling / 2 >= robot_z and
|
if (block_z + height * size_scaling / 2 >= robot_z and
|
||||||
robot_z >= block_z - height * size_scaling / 2): # Block in view.
|
robot_z >= block_z - height * size_scaling / 2):
|
||||||
|
# Block in view.
|
||||||
x1 = block_x - 0.5 * size_scaling
|
x1 = block_x - 0.5 * size_scaling
|
||||||
x2 = block_x + 0.5 * size_scaling
|
x2 = block_x + 0.5 * size_scaling
|
||||||
y1 = block_y - 0.5 * size_scaling
|
y1 = block_y - 0.5 * size_scaling
|
||||||
@ -373,8 +375,8 @@ class PointMazeEnv(gym.Env):
|
|||||||
# 3 for wall, drop-off, block
|
# 3 for wall, drop-off, block
|
||||||
sensor_readings = np.zeros((self._n_bins, 3))
|
sensor_readings = np.zeros((self._n_bins, 3))
|
||||||
for ray_idx in range(self._n_bins):
|
for ray_idx in range(self._n_bins):
|
||||||
ray_ori = (ori - self._sensor_span * 0.5 +
|
ray_ori = (ori - self._sensor_span * 0.5 + (
|
||||||
(2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
|
2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
|
||||||
ray_segments = []
|
ray_segments = []
|
||||||
# Get all segments that intersect with ray.
|
# Get all segments that intersect with ray.
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
@ -406,11 +408,8 @@ class PointMazeEnv(gym.Env):
|
|||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
wrapped_obs = self.wrapped_env._get_obs()
|
wrapped_obs = self.wrapped_env._get_obs()
|
||||||
# print("ant obs", wrapped_obs)
|
|
||||||
if self._top_down_view:
|
if self._top_down_view:
|
||||||
view = [self.get_top_down_view().flat]
|
self.get_top_down_view()
|
||||||
else:
|
|
||||||
view = []
|
|
||||||
|
|
||||||
if self._observe_blocks:
|
if self._observe_blocks:
|
||||||
additional_obs = []
|
additional_obs = []
|
||||||
@ -420,7 +419,7 @@ class PointMazeEnv(gym.Env):
|
|||||||
wrapped_obs = np.concatenate([wrapped_obs[:3]] + additional_obs +
|
wrapped_obs = np.concatenate([wrapped_obs[:3]] + additional_obs +
|
||||||
[wrapped_obs[3:]])
|
[wrapped_obs[3:]])
|
||||||
|
|
||||||
range_sensor_obs = self.get_range_sensor_obs()
|
self.get_range_sensor_obs()
|
||||||
return wrapped_obs
|
return wrapped_obs
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
@ -446,7 +445,8 @@ class PointMazeEnv(gym.Env):
|
|||||||
pos="%f %f %f" % (goal_x,
|
pos="%f %f %f" % (goal_x,
|
||||||
goal_y,
|
goal_y,
|
||||||
self.MAZE_HEIGHT / 2 * size_scaling),
|
self.MAZE_HEIGHT / 2 * size_scaling),
|
||||||
size="%f %f %f" % (0.1 * size_scaling, # smaller than the block to prevent collision
|
# smaller than the block to prevent collision
|
||||||
|
size="%f %f %f" % (0.1 * size_scaling,
|
||||||
0.1 * size_scaling,
|
0.1 * size_scaling,
|
||||||
self.MAZE_HEIGHT / 2 * size_scaling),
|
self.MAZE_HEIGHT / 2 * size_scaling),
|
||||||
type="box",
|
type="box",
|
||||||
@ -455,7 +455,8 @@ class PointMazeEnv(gym.Env):
|
|||||||
conaffinity="1",
|
conaffinity="1",
|
||||||
rgba="1.0 0.0 0.0 0.5"
|
rgba="1.0 0.0 0.0 0.5"
|
||||||
)
|
)
|
||||||
# Note: running the lines below will make the robot position wrong! (because the graph is rebuilt)
|
# Note: running the lines below will make the robot position wrong!
|
||||||
|
# (because the graph is rebuilt)
|
||||||
torso = self.tree.find(".//body[@name='torso']")
|
torso = self.tree.find(".//body[@name='torso']")
|
||||||
geoms = torso.findall(".//geom")
|
geoms = torso.findall(".//geom")
|
||||||
for geom in geoms:
|
for geom in geoms:
|
||||||
@ -463,18 +464,21 @@ class PointMazeEnv(gym.Env):
|
|||||||
raise Exception("Every geom of the torso must have a name "
|
raise Exception("Every geom of the torso must have a name "
|
||||||
"defined")
|
"defined")
|
||||||
_, file_path = tempfile.mkstemp(text=True, suffix='.xml')
|
_, file_path = tempfile.mkstemp(text=True, suffix='.xml')
|
||||||
self.tree.write(
|
self.tree.write(file_path)
|
||||||
file_path) # here we write a temporal file with the robot specifications. Why not the original one??
|
# here we write a temporal file with the robot specifications.
|
||||||
|
# Why not the original one??
|
||||||
|
|
||||||
model_cls = self.__class__.MODEL_CLASS
|
model_cls = self.__class__.MODEL_CLASS
|
||||||
self.wrapped_env = model_cls(*self.args, file_path=file_path,
|
# file to the robot specifications; model_cls is AntEnv
|
||||||
**self.kwargs) # file to the robot specifications; model_cls is AntEnv
|
self.wrapped_env = model_cls(
|
||||||
|
*self.args, file_path=file_path, **self.kwargs)
|
||||||
|
|
||||||
self.t = 0
|
self.t = 0
|
||||||
self.trajectory = []
|
self.trajectory = []
|
||||||
self.wrapped_env.reset()
|
self.wrapped_env.reset()
|
||||||
if len(self._init_positions) > 1:
|
if len(self._init_positions) > 1:
|
||||||
xy = self._init_positions[self.np_random.randint(len(self._init_positions))]
|
xy = self._init_positions[self.np_random.randint(
|
||||||
|
len(self._init_positions))]
|
||||||
self.wrapped_env.set_xy(xy)
|
self.wrapped_env.set_xy(xy)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
@ -518,38 +522,38 @@ class PointMazeEnv(gym.Env):
|
|||||||
def _is_in_collision(self, pos):
|
def _is_in_collision(self, pos):
|
||||||
x, y = pos
|
x, y = pos
|
||||||
structure = self.MAZE_STRUCTURE
|
structure = self.MAZE_STRUCTURE
|
||||||
size_scaling = self.MAZE_SIZE_SCALING
|
scale = self.MAZE_SIZE_SCALING
|
||||||
for i in range(len(structure)):
|
for i in range(len(structure)):
|
||||||
for j in range(len(structure[0])):
|
for j in range(len(structure[0])):
|
||||||
if structure[i][j] == 1:
|
if structure[i][j] == 1:
|
||||||
minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
|
minx = j * scale - scale * 0.5 - self._init_torso_x
|
||||||
maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
|
maxx = j * scale + scale * 0.5 - self._init_torso_x
|
||||||
miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
|
miny = i * scale - scale * 0.5 - self._init_torso_y
|
||||||
maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
|
maxy = i * scale + scale * 0.5 - self._init_torso_y
|
||||||
if minx <= x <= maxx and miny <= y <= maxy:
|
if minx <= x <= maxx and miny <= y <= maxy:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _rowcol_to_xy(self, j, i):
|
def _rowcol_to_xy(self, j, i):
|
||||||
size_scaling = self.MAZE_SIZE_SCALING
|
scale = self.MAZE_SIZE_SCALING
|
||||||
minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
|
minx = j * scale - scale * 0.5 - self._init_torso_x
|
||||||
maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
|
maxx = j * scale + scale * 0.5 - self._init_torso_x
|
||||||
miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
|
miny = i * scale - scale * 0.5 - self._init_torso_y
|
||||||
maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
|
maxy = i * scale + scale * 0.5 - self._init_torso_y
|
||||||
return (minx + maxx) / 2, (miny + maxy) / 2
|
return (minx + maxx) / 2, (miny + maxy) / 2
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.t += 1
|
self.t += 1
|
||||||
if self._manual_collision:
|
if self._manual_collision:
|
||||||
old_pos = self.wrapped_env.get_xy()
|
old_pos = self.wrapped_env.get_xy()
|
||||||
inner_next_obs, inner_reward, inner_done, info = self.wrapped_env.step(
|
inner_next_obs, inner_reward, inner_done, info = \
|
||||||
action)
|
self.wrapped_env.step(action)
|
||||||
new_pos = self.wrapped_env.get_xy()
|
new_pos = self.wrapped_env.get_xy()
|
||||||
if self._is_in_collision(new_pos):
|
if self._is_in_collision(new_pos):
|
||||||
self.wrapped_env.set_xy(old_pos)
|
self.wrapped_env.set_xy(old_pos)
|
||||||
else:
|
else:
|
||||||
inner_next_obs, inner_reward, inner_done, info = self.wrapped_env.step(
|
inner_next_obs, inner_reward, inner_done, info = \
|
||||||
action)
|
self.wrapped_env.step(action)
|
||||||
next_obs = self._get_obs()
|
next_obs = self._get_obs()
|
||||||
done = False
|
done = False
|
||||||
if self.goal is not None:
|
if self.goal is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user