fix PointMaze (#8)
* update atari.py * fix setup.py pass the pytest * fix setup.py pass the pytest * add args "render" * change the tensorboard writter * change the tensorboard writter * change device, render, tensorboard log location * change device, render, tensorboard log location * remove some wrong local files * fix some tab mistakes and the envs name in continuous/test_xx.py * add examples and point robot maze environment * fix some bugs during testing examples * add dqn network and fix some args * change back the tensorboard writter's frequency to ensure ppo and a2c can write things normally * add a warning to collector * rm some unrelated files * reformat * fix a bug in test_dqn due to the model wrong selection * change atari frame skip and observation to improve performance * readd some files * change import * modified readme * rm tensorboard log * update atari and mujoco which are ignored * rm the wrong lines * readd the import of PointMaze * fix a typo in test/discrete/net.py * add a class declaration to pass the flake8 * fix flake8 errors
This commit is contained in:
parent
f68f23292e
commit
eb7fb37806
@ -9,13 +9,12 @@ from tianshou.policy import TD3Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
|
||||
from continuous_net import Actor, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PointMaze-v0')
|
||||
parser.add_argument('--task', type=str, default='PointMaze-v1')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--actor-lr', type=float, default=3e-5)
|
||||
|
2
tianshou/env/__init__.py
vendored
2
tianshou/env/__init__.py
vendored
@ -2,8 +2,10 @@ from tianshou.env.utils import CloudpickleWrapper
|
||||
from tianshou.env.common import EnvWrapper, FrameStack
|
||||
from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
|
||||
SubprocVectorEnv, RayVectorEnv
|
||||
from tianshou.env import mujoco
|
||||
|
||||
__all__ = [
|
||||
'mujoco',
|
||||
'EnvWrapper',
|
||||
'FrameStack',
|
||||
'BaseVectorEnv',
|
||||
|
37
tianshou/env/mujoco/point_maze_env.py
vendored
37
tianshou/env/mujoco/point_maze_env.py
vendored
@ -233,7 +233,7 @@ class PointMazeEnv(gym.Env):
|
||||
|
||||
def valid(row, col):
|
||||
return self._view.shape[0] > row >= 0 \
|
||||
and self._view.shape[1] > col >= 0
|
||||
and self._view.shape[1] > col >= 0
|
||||
|
||||
def update_view(x, y, d, row=None, col=None):
|
||||
if row is None or col is None:
|
||||
@ -252,36 +252,36 @@ class PointMazeEnv(gym.Env):
|
||||
|
||||
if valid(row, col):
|
||||
self._view[row, col, d] += (
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
if valid(row - 1, col):
|
||||
self._view[row - 1, col, d] += (
|
||||
(max(0., 0.5 - row_frac)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
(max(0., 0.5 - row_frac)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
if valid(row + 1, col):
|
||||
self._view[row + 1, col, d] += (
|
||||
(max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
(max(0., row_frac - 0.5)) *
|
||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||
if valid(row, col - 1):
|
||||
self._view[row, col - 1, d] += (
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., 0.5 - col_frac)))
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., 0.5 - col_frac)))
|
||||
if valid(row, col + 1):
|
||||
self._view[row, col + 1, d] += (
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., col_frac - 0.5)))
|
||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||
(max(0., col_frac - 0.5)))
|
||||
if valid(row - 1, col - 1):
|
||||
self._view[row - 1, col - 1, d] += (
|
||||
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
|
||||
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
|
||||
if valid(row - 1, col + 1):
|
||||
self._view[row - 1, col + 1, d] += (
|
||||
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
|
||||
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
|
||||
if valid(row + 1, col + 1):
|
||||
self._view[row + 1, col + 1, d] += (
|
||||
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
|
||||
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
|
||||
if valid(row + 1, col - 1):
|
||||
self._view[row + 1, col - 1, d] += (
|
||||
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
|
||||
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
|
||||
|
||||
# Draw ant.
|
||||
robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2]
|
||||
@ -376,7 +376,8 @@ class PointMazeEnv(gym.Env):
|
||||
sensor_readings = np.zeros((self._n_bins, 3))
|
||||
for ray_idx in range(self._n_bins):
|
||||
ray_ori = (ori - self._sensor_span * 0.5 + (
|
||||
2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
|
||||
2 * ray_idx + 1.0) /
|
||||
(2 * self._n_bins) * self._sensor_span)
|
||||
ray_segments = []
|
||||
# Get all segments that intersect with ray.
|
||||
for seg in segments:
|
||||
@ -401,8 +402,8 @@ class PointMazeEnv(gym.Env):
|
||||
2 if maze_env_utils.can_move(seg_type) else # Block.
|
||||
None)
|
||||
if first_seg["distance"] <= self._sensor_range:
|
||||
sensor_readings[ray_idx][idx] = (
|
||||
self._sensor_range - first_seg[
|
||||
sensor_readings[ray_idx][idx] = \
|
||||
(self._sensor_range - first_seg[
|
||||
"distance"]) / self._sensor_range
|
||||
return sensor_readings
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user