update readme and force flake8
This commit is contained in:
parent
068c4068ec
commit
f68f23292e
9
.github/workflows/pytest.yml
vendored
9
.github/workflows/pytest.yml
vendored
@ -26,15 +26,10 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .
|
||||
pip install ".[dev]"
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
pip install flake8
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
|
||||
flake8 . --count --show-source --statistics
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pip install pytest pytest-cov
|
||||
pytest test --cov tianshou
|
||||
|
@ -24,8 +24,7 @@ pytest test --cov tianshou -s
|
||||
|
||||
We follow PEP8 python code style. To check, in the main directory, run:
|
||||
```python
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
|
||||
flake8 . --count --show-source --statistics
|
||||
```
|
||||
|
||||
#### Documents
|
||||
@ -40,4 +39,4 @@ To compile docs into webpages, Run
|
||||
make html
|
||||
```
|
||||
under the `docs/` directory. The generated webpages are in `docs/_build` and
|
||||
can be viewed with browsers.
|
||||
can be viewed with browsers.
|
||||
|
18
README.md
18
README.md
@ -1,9 +1,9 @@
|
||||
|
||||
<h1 align="center">Tianshou</h1>
|
||||
|
||||

|
||||

|
||||
[](https://tianshou.readthedocs.io/en/latest/?badge=latest)
|
||||
[](https://pypi.org/project/tianshou/)
|
||||
[](https://github.com/thu-ml/tianshou/actions)
|
||||
[](https://tianshou.readthedocs.io)
|
||||
[](https://github.com/thu-ml/tianshou/stargazers)
|
||||
[](https://github.com/thu-ml/tianshou/network)
|
||||
[](https://github.com/thu-ml/tianshou/issues)
|
||||
@ -35,7 +35,7 @@ pip3 install tianshou
|
||||
|
||||
## Documentation
|
||||
|
||||
The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io). It is under construction currently.
|
||||
The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io).
|
||||
|
||||
The example scripts are under [test/](/test/) folder and [examples/](/examples/) folder.
|
||||
|
||||
@ -53,16 +53,18 @@ We select some of famous (>1k stars) reinforcement learning platforms. Here is t
|
||||
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||
| GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
|
||||
| Algo - Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
|
||||
| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | |
|
||||
| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | |
|
||||
| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | |
|
||||
| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | |
|
||||
| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | ? |
|
||||
| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | ? |
|
||||
| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | ? |
|
||||
| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | ? |
|
||||
| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
|
||||
| TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
|
||||
| SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
|
||||
|
||||
*: Could not reach the target reward threshold in 1e6 steps in any of 10 runs. The total runtime is in the brackets.
|
||||
|
||||
?: We have tried but it is nontrivial for running non-Atari game on rlpyt. See [here](https://github.com/astooke/rlpyt/issues/127#issuecomment-601741210).
|
||||
|
||||
All of the platforms use 10 different seeds for testing. We erase those trials which failed for training. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes' mean returns.
|
||||
|
||||
Tianshou and RLlib's configures are very similar. They both use multiple workers for sampling. Indeed, both RLlib and rlpyt are excellent reinforcement learning platform :)
|
||||
|
3
examples/README.md
Normal file
3
examples/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
Result of Ant-v2:
|
||||
|
||||

|
@ -74,7 +74,7 @@ class DQN(nn.Module):
|
||||
|
||||
def forward(self, x, state=None, info={}):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
s = torch.tensor(x, device=self.device, dtype=torch.float)
|
||||
x = torch.tensor(x, device=self.device, dtype=torch.float)
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
|
13
tianshou/env/atari.py
vendored
13
tianshou/env/atari.py
vendored
@ -1,13 +1,11 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import cv2
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym.spaces.box import Box
|
||||
|
||||
|
||||
def create_atari_environment(name=None, sticky_actions=True, max_episode_steps=2000):
|
||||
def create_atari_environment(name=None, sticky_actions=True,
|
||||
max_episode_steps=2000):
|
||||
game_version = 'v0' if sticky_actions else 'v4'
|
||||
name = '{}NoFrameskip-{}'.format(name, game_version)
|
||||
env = gym.make(name)
|
||||
@ -61,7 +59,8 @@ class preprocessing(object):
|
||||
self._grayscale_obs(self.screen_buffer[0])
|
||||
self.screen_buffer[1].fill(0)
|
||||
|
||||
return np.stack([self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
||||
return np.stack([
|
||||
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
||||
|
||||
def render(self, mode):
|
||||
|
||||
@ -95,7 +94,9 @@ class preprocessing(object):
|
||||
if len(observation) > 0:
|
||||
observation = np.stack(observation, axis=-1)
|
||||
else:
|
||||
observation = np.stack([self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
||||
observation = np.stack([
|
||||
self._pool_and_resize() for _ in range(self.frame_skip)],
|
||||
axis=-1)
|
||||
if self.count >= self.max_episode_steps:
|
||||
terminal = True
|
||||
self.terminal = terminal
|
||||
|
1
tianshou/env/mujoco/__init__.py
vendored
1
tianshou/env/mujoco/__init__.py
vendored
@ -1,5 +1,4 @@
|
||||
from gym.envs.registration import register
|
||||
import gym
|
||||
|
||||
register(
|
||||
id='PointMaze-v0',
|
||||
|
18
tianshou/env/mujoco/maze_env_utils.py
vendored
18
tianshou/env/mujoco/maze_env_utils.py
vendored
@ -1,5 +1,4 @@
|
||||
"""Adapted from rllab maze_env_utils.py."""
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
@ -111,25 +110,24 @@ def construct_maze(maze_id='Maze'):
|
||||
[1, 1, 1, 1],
|
||||
]
|
||||
elif maze_id == 'Block':
|
||||
O = 'r'
|
||||
structure = [
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, O, 0, 0, 1],
|
||||
[1, 'r', 0, 0, 1],
|
||||
[1, 0, 0, 0, 1],
|
||||
[1, 0, 0, 0, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
]
|
||||
elif maze_id == 'BlockMaze':
|
||||
O = 'r'
|
||||
structure = [
|
||||
[1, 1, 1, 1],
|
||||
[1, O, 0, 1],
|
||||
[1, 'r', 0, 1],
|
||||
[1, 1, 0, 1],
|
||||
[1, 0, 0, 1],
|
||||
[1, 1, 1, 1],
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError('The provided MazeId %s is not recognized' % maze_id)
|
||||
raise NotImplementedError(
|
||||
'The provided MazeId %s is not recognized' % maze_id)
|
||||
|
||||
return structure
|
||||
|
||||
@ -157,7 +155,8 @@ def line_intersect(pt1, pt2, ptA, ptB):
|
||||
|
||||
DET = (-dx1 * dy + dy1 * dx)
|
||||
|
||||
if math.fabs(DET) < DET_TOLERANCE: return (0, 0, 0, 0, 0)
|
||||
if math.fabs(DET) < DET_TOLERANCE:
|
||||
return (0, 0, 0, 0, 0)
|
||||
|
||||
# now, the determinant should be OK
|
||||
DETinv = 1.0 / DET
|
||||
@ -176,8 +175,9 @@ def line_intersect(pt1, pt2, ptA, ptB):
|
||||
|
||||
def ray_segment_intersect(ray, segment):
|
||||
"""
|
||||
Check if the ray originated from (x, y) with direction theta intersects the line segment (x1, y1) -- (x2, y2),
|
||||
and return the intersection point if there is one
|
||||
Check if the ray originated from (x, y) with direction theta
|
||||
intersects the line segment (x1, y1) -- (x2, y2), and return
|
||||
the intersection point if there is one
|
||||
"""
|
||||
(x, y), theta = ray
|
||||
# (x1, y1), (x2, y2) = segment
|
||||
|
118
tianshou/env/mujoco/point_maze_env.py
vendored
118
tianshou/env/mujoco/point_maze_env.py
vendored
@ -108,7 +108,7 @@ class PointMazeEnv(gym.Env):
|
||||
rgba="0.9 0.9 0.9 1",
|
||||
)
|
||||
if struct == 1: # Unmovable block.
|
||||
# Offset all coordinates so that robot starts at the origin.
|
||||
# Offset all coordinates so that robot starts at the origin
|
||||
ET.SubElement(
|
||||
worldbody, "geom",
|
||||
name="block_%d_%d" % (i, j),
|
||||
@ -134,13 +134,13 @@ class PointMazeEnv(gym.Env):
|
||||
y_offset = 0.0
|
||||
shrink = 0.1 if spinning else 0.99 if falling else 1.0
|
||||
height_shrink = 0.1 if spinning else 1.0
|
||||
_x = j * size_scaling - torso_x + x_offset
|
||||
_y = i * size_scaling - torso_y + y_offset
|
||||
_z = height / 2 * size_scaling * height_shrink
|
||||
movable_body = ET.SubElement(
|
||||
worldbody, "body",
|
||||
name=name,
|
||||
pos="%f %f %f" % (j * size_scaling - torso_x + x_offset,
|
||||
i * size_scaling - torso_y + y_offset,
|
||||
height_offset +
|
||||
height / 2 * size_scaling * height_shrink),
|
||||
pos="%f %f %f" % (_x, _y, height_offset + _z),
|
||||
)
|
||||
ET.SubElement(
|
||||
movable_body, "geom",
|
||||
@ -148,7 +148,7 @@ class PointMazeEnv(gym.Env):
|
||||
pos="0 0 0",
|
||||
size="%f %f %f" % (0.5 * size_scaling * shrink,
|
||||
0.5 * size_scaling * shrink,
|
||||
height / 2 * size_scaling * height_shrink),
|
||||
_z),
|
||||
type="box",
|
||||
material="",
|
||||
mass="0.001" if falling else "0.0002",
|
||||
@ -232,13 +232,13 @@ class PointMazeEnv(gym.Env):
|
||||
self._view = np.zeros_like(self._view)
|
||||
|
||||
def valid(row, col):
|
||||
return self._view.shape[0] > row >= 0 and self._view.shape[1] > col >= 0
|
||||
return self._view.shape[0] > row >= 0 \
|
||||
and self._view.shape[1] > col >= 0
|
||||
|
||||
def update_view(x, y, d, row=None, col=None):
|
||||
if row is None or col is None:
|
||||
x = x - self._robot_x
|
||||
y = y - self._robot_y
|
||||
th = self._robot_ori
|
||||
|
||||
row, col = self._xy_to_rowcol(x, y)
|
||||
update_view(x, y, d, row=row, col=col)
|
||||
@ -252,36 +252,36 @@ class PointMazeEnv(gym.Env):
|
||||
|
||||
if valid(row, col):
|
||||
self._view[row, col, d] += (
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
if valid(row - 1, col):
|
||||
self._view[row - 1, col, d] += (
|
||||
(max(0., 0.5 - row_frac)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
(max(0., 0.5 - row_frac)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
if valid(row + 1, col):
|
||||
self._view[row + 1, col, d] += (
|
||||
(max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
(max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
if valid(row, col - 1):
|
||||
self._view[row, col - 1, d] += (
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., 0.5 - col_frac)))
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., 0.5 - col_frac)))
|
||||
if valid(row, col + 1):
|
||||
self._view[row, col + 1, d] += (
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., col_frac - 0.5)))
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., col_frac - 0.5)))
|
||||
if valid(row - 1, col - 1):
|
||||
self._view[row - 1, col - 1, d] += (
|
||||
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
|
||||
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
|
||||
if valid(row - 1, col + 1):
|
||||
self._view[row - 1, col + 1, d] += (
|
||||
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
|
||||
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
|
||||
if valid(row + 1, col + 1):
|
||||
self._view[row + 1, col + 1, d] += (
|
||||
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
|
||||
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
|
||||
if valid(row + 1, col - 1):
|
||||
self._view[row + 1, col - 1, d] += (
|
||||
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
|
||||
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
|
||||
|
||||
# Draw ant.
|
||||
robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2]
|
||||
@ -291,7 +291,6 @@ class PointMazeEnv(gym.Env):
|
||||
|
||||
structure = self.MAZE_STRUCTURE
|
||||
size_scaling = self.MAZE_SIZE_SCALING
|
||||
height = self.MAZE_HEIGHT
|
||||
|
||||
# Draw immovable blocks and chasms.
|
||||
for i in range(len(structure)):
|
||||
@ -311,7 +310,9 @@ class PointMazeEnv(gym.Env):
|
||||
update_view(block_x, block_y, 2)
|
||||
|
||||
import cv2
|
||||
cv2.imshow('x.jpg', cv2.resize(np.uint8(self._view * 255), (512, 512), interpolation=cv2.INTER_CUBIC))
|
||||
cv2.imshow('x.jpg', cv2.resize(
|
||||
np.uint8(self._view * 255), (512, 512),
|
||||
interpolation=cv2.INTER_CUBIC))
|
||||
cv2.waitKey(0)
|
||||
|
||||
return self._view
|
||||
@ -350,10 +351,11 @@ class PointMazeEnv(gym.Env):
|
||||
))
|
||||
|
||||
for block_name, block_type in self.movable_blocks:
|
||||
block_x, block_y, block_z = self.wrapped_env.get_body_com(block_name)[
|
||||
:3]
|
||||
block_x, block_y, block_z = \
|
||||
self.wrapped_env.get_body_com(block_name)[:3]
|
||||
if (block_z + height * size_scaling / 2 >= robot_z and
|
||||
robot_z >= block_z - height * size_scaling / 2): # Block in view.
|
||||
robot_z >= block_z - height * size_scaling / 2):
|
||||
# Block in view.
|
||||
x1 = block_x - 0.5 * size_scaling
|
||||
x2 = block_x + 0.5 * size_scaling
|
||||
y1 = block_y - 0.5 * size_scaling
|
||||
@ -373,8 +375,8 @@ class PointMazeEnv(gym.Env):
|
||||
# 3 for wall, drop-off, block
|
||||
sensor_readings = np.zeros((self._n_bins, 3))
|
||||
for ray_idx in range(self._n_bins):
|
||||
ray_ori = (ori - self._sensor_span * 0.5 +
|
||||
(2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
|
||||
ray_ori = (ori - self._sensor_span * 0.5 + (
|
||||
2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
|
||||
ray_segments = []
|
||||
# Get all segments that intersect with ray.
|
||||
for seg in segments:
|
||||
@ -400,17 +402,14 @@ class PointMazeEnv(gym.Env):
|
||||
None)
|
||||
if first_seg["distance"] <= self._sensor_range:
|
||||
sensor_readings[ray_idx][idx] = (
|
||||
self._sensor_range - first_seg[
|
||||
"distance"]) / self._sensor_range
|
||||
self._sensor_range - first_seg[
|
||||
"distance"]) / self._sensor_range
|
||||
return sensor_readings
|
||||
|
||||
def _get_obs(self):
|
||||
wrapped_obs = self.wrapped_env._get_obs()
|
||||
# print("ant obs", wrapped_obs)
|
||||
if self._top_down_view:
|
||||
view = [self.get_top_down_view().flat]
|
||||
else:
|
||||
view = []
|
||||
self.get_top_down_view()
|
||||
|
||||
if self._observe_blocks:
|
||||
additional_obs = []
|
||||
@ -420,7 +419,7 @@ class PointMazeEnv(gym.Env):
|
||||
wrapped_obs = np.concatenate([wrapped_obs[:3]] + additional_obs +
|
||||
[wrapped_obs[3:]])
|
||||
|
||||
range_sensor_obs = self.get_range_sensor_obs()
|
||||
self.get_range_sensor_obs()
|
||||
return wrapped_obs
|
||||
|
||||
def seed(self, seed=None):
|
||||
@ -446,7 +445,8 @@ class PointMazeEnv(gym.Env):
|
||||
pos="%f %f %f" % (goal_x,
|
||||
goal_y,
|
||||
self.MAZE_HEIGHT / 2 * size_scaling),
|
||||
size="%f %f %f" % (0.1 * size_scaling, # smaller than the block to prevent collision
|
||||
# smaller than the block to prevent collision
|
||||
size="%f %f %f" % (0.1 * size_scaling,
|
||||
0.1 * size_scaling,
|
||||
self.MAZE_HEIGHT / 2 * size_scaling),
|
||||
type="box",
|
||||
@ -455,7 +455,8 @@ class PointMazeEnv(gym.Env):
|
||||
conaffinity="1",
|
||||
rgba="1.0 0.0 0.0 0.5"
|
||||
)
|
||||
# Note: running the lines below will make the robot position wrong! (because the graph is rebuilt)
|
||||
# Note: running the lines below will make the robot position wrong!
|
||||
# (because the graph is rebuilt)
|
||||
torso = self.tree.find(".//body[@name='torso']")
|
||||
geoms = torso.findall(".//geom")
|
||||
for geom in geoms:
|
||||
@ -463,18 +464,21 @@ class PointMazeEnv(gym.Env):
|
||||
raise Exception("Every geom of the torso must have a name "
|
||||
"defined")
|
||||
_, file_path = tempfile.mkstemp(text=True, suffix='.xml')
|
||||
self.tree.write(
|
||||
file_path) # here we write a temporal file with the robot specifications. Why not the original one??
|
||||
self.tree.write(file_path)
|
||||
# here we write a temporal file with the robot specifications.
|
||||
# Why not the original one??
|
||||
|
||||
model_cls = self.__class__.MODEL_CLASS
|
||||
self.wrapped_env = model_cls(*self.args, file_path=file_path,
|
||||
**self.kwargs) # file to the robot specifications; model_cls is AntEnv
|
||||
# file to the robot specifications; model_cls is AntEnv
|
||||
self.wrapped_env = model_cls(
|
||||
*self.args, file_path=file_path, **self.kwargs)
|
||||
|
||||
self.t = 0
|
||||
self.trajectory = []
|
||||
self.wrapped_env.reset()
|
||||
if len(self._init_positions) > 1:
|
||||
xy = self._init_positions[self.np_random.randint(len(self._init_positions))]
|
||||
xy = self._init_positions[self.np_random.randint(
|
||||
len(self._init_positions))]
|
||||
self.wrapped_env.set_xy(xy)
|
||||
return self._get_obs()
|
||||
|
||||
@ -518,38 +522,38 @@ class PointMazeEnv(gym.Env):
|
||||
def _is_in_collision(self, pos):
|
||||
x, y = pos
|
||||
structure = self.MAZE_STRUCTURE
|
||||
size_scaling = self.MAZE_SIZE_SCALING
|
||||
scale = self.MAZE_SIZE_SCALING
|
||||
for i in range(len(structure)):
|
||||
for j in range(len(structure[0])):
|
||||
if structure[i][j] == 1:
|
||||
minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
|
||||
maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
|
||||
miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
|
||||
maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
|
||||
minx = j * scale - scale * 0.5 - self._init_torso_x
|
||||
maxx = j * scale + scale * 0.5 - self._init_torso_x
|
||||
miny = i * scale - scale * 0.5 - self._init_torso_y
|
||||
maxy = i * scale + scale * 0.5 - self._init_torso_y
|
||||
if minx <= x <= maxx and miny <= y <= maxy:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _rowcol_to_xy(self, j, i):
|
||||
size_scaling = self.MAZE_SIZE_SCALING
|
||||
minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x
|
||||
maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x
|
||||
miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y
|
||||
maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y
|
||||
scale = self.MAZE_SIZE_SCALING
|
||||
minx = j * scale - scale * 0.5 - self._init_torso_x
|
||||
maxx = j * scale + scale * 0.5 - self._init_torso_x
|
||||
miny = i * scale - scale * 0.5 - self._init_torso_y
|
||||
maxy = i * scale + scale * 0.5 - self._init_torso_y
|
||||
return (minx + maxx) / 2, (miny + maxy) / 2
|
||||
|
||||
def step(self, action):
|
||||
self.t += 1
|
||||
if self._manual_collision:
|
||||
old_pos = self.wrapped_env.get_xy()
|
||||
inner_next_obs, inner_reward, inner_done, info = self.wrapped_env.step(
|
||||
action)
|
||||
inner_next_obs, inner_reward, inner_done, info = \
|
||||
self.wrapped_env.step(action)
|
||||
new_pos = self.wrapped_env.get_xy()
|
||||
if self._is_in_collision(new_pos):
|
||||
self.wrapped_env.set_xy(old_pos)
|
||||
else:
|
||||
inner_next_obs, inner_reward, inner_done, info = self.wrapped_env.step(
|
||||
action)
|
||||
inner_next_obs, inner_reward, inner_done, info = \
|
||||
self.wrapped_env.step(action)
|
||||
next_obs = self._get_obs()
|
||||
done = False
|
||||
if self.goal is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user