fix PointMaze (#8)
* update atari.py * fix setup.py pass the pytest * fix setup.py pass the pytest * add args "render" * change the tensorboard writter * change the tensorboard writter * change device, render, tensorboard log location * change device, render, tensorboard log location * remove some wrong local files * fix some tab mistakes and the envs name in continuous/test_xx.py * add examples and point robot maze environment * fix some bugs during testing examples * add dqn network and fix some args * change back the tensorboard writter's frequency to ensure ppo and a2c can write things normally * add a warning to collector * rm some unrelated files * reformat * fix a bug in test_dqn due to the model wrong selection * change atari frame skip and observation to improve performance * readd some files * change import * modified readme * rm tensorboard log * update atari and mujoco which are ignored * rm the wrong lines * readd the import of PointMaze * fix a typo in test/discrete/net.py * add a class declaration to pass the flake8 * fix flake8 errors
This commit is contained in:
parent
f68f23292e
commit
eb7fb37806
@ -9,13 +9,12 @@ from tianshou.policy import TD3Policy
|
|||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||||
|
|
||||||
from continuous_net import Actor, Critic
|
from continuous_net import Actor, Critic
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--task', type=str, default='PointMaze-v0')
|
parser.add_argument('--task', type=str, default='PointMaze-v1')
|
||||||
parser.add_argument('--seed', type=int, default=1626)
|
parser.add_argument('--seed', type=int, default=1626)
|
||||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
parser.add_argument('--actor-lr', type=float, default=3e-5)
|
parser.add_argument('--actor-lr', type=float, default=3e-5)
|
||||||
|
2
tianshou/env/__init__.py
vendored
2
tianshou/env/__init__.py
vendored
@ -2,8 +2,10 @@ from tianshou.env.utils import CloudpickleWrapper
|
|||||||
from tianshou.env.common import EnvWrapper, FrameStack
|
from tianshou.env.common import EnvWrapper, FrameStack
|
||||||
from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
|
from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
|
||||||
SubprocVectorEnv, RayVectorEnv
|
SubprocVectorEnv, RayVectorEnv
|
||||||
|
from tianshou.env import mujoco
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
'mujoco',
|
||||||
'EnvWrapper',
|
'EnvWrapper',
|
||||||
'FrameStack',
|
'FrameStack',
|
||||||
'BaseVectorEnv',
|
'BaseVectorEnv',
|
||||||
|
37
tianshou/env/mujoco/point_maze_env.py
vendored
37
tianshou/env/mujoco/point_maze_env.py
vendored
@ -233,7 +233,7 @@ class PointMazeEnv(gym.Env):
|
|||||||
|
|
||||||
def valid(row, col):
|
def valid(row, col):
|
||||||
return self._view.shape[0] > row >= 0 \
|
return self._view.shape[0] > row >= 0 \
|
||||||
and self._view.shape[1] > col >= 0
|
and self._view.shape[1] > col >= 0
|
||||||
|
|
||||||
def update_view(x, y, d, row=None, col=None):
|
def update_view(x, y, d, row=None, col=None):
|
||||||
if row is None or col is None:
|
if row is None or col is None:
|
||||||
@ -252,36 +252,36 @@ class PointMazeEnv(gym.Env):
|
|||||||
|
|
||||||
if valid(row, col):
|
if valid(row, col):
|
||||||
self._view[row, col, d] += (
|
self._view[row, col, d] += (
|
||||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||||
if valid(row - 1, col):
|
if valid(row - 1, col):
|
||||||
self._view[row - 1, col, d] += (
|
self._view[row - 1, col, d] += (
|
||||||
(max(0., 0.5 - row_frac)) *
|
(max(0., 0.5 - row_frac)) *
|
||||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||||
if valid(row + 1, col):
|
if valid(row + 1, col):
|
||||||
self._view[row + 1, col, d] += (
|
self._view[row + 1, col, d] += (
|
||||||
(max(0., row_frac - 0.5)) *
|
(max(0., row_frac - 0.5)) *
|
||||||
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
|
||||||
if valid(row, col - 1):
|
if valid(row, col - 1):
|
||||||
self._view[row, col - 1, d] += (
|
self._view[row, col - 1, d] += (
|
||||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||||
(max(0., 0.5 - col_frac)))
|
(max(0., 0.5 - col_frac)))
|
||||||
if valid(row, col + 1):
|
if valid(row, col + 1):
|
||||||
self._view[row, col + 1, d] += (
|
self._view[row, col + 1, d] += (
|
||||||
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
|
||||||
(max(0., col_frac - 0.5)))
|
(max(0., col_frac - 0.5)))
|
||||||
if valid(row - 1, col - 1):
|
if valid(row - 1, col - 1):
|
||||||
self._view[row - 1, col - 1, d] += (
|
self._view[row - 1, col - 1, d] += (
|
||||||
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
|
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
|
||||||
if valid(row - 1, col + 1):
|
if valid(row - 1, col + 1):
|
||||||
self._view[row - 1, col + 1, d] += (
|
self._view[row - 1, col + 1, d] += (
|
||||||
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
|
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
|
||||||
if valid(row + 1, col + 1):
|
if valid(row + 1, col + 1):
|
||||||
self._view[row + 1, col + 1, d] += (
|
self._view[row + 1, col + 1, d] += (
|
||||||
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
|
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
|
||||||
if valid(row + 1, col - 1):
|
if valid(row + 1, col - 1):
|
||||||
self._view[row + 1, col - 1, d] += (
|
self._view[row + 1, col - 1, d] += (
|
||||||
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
|
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
|
||||||
|
|
||||||
# Draw ant.
|
# Draw ant.
|
||||||
robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2]
|
robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2]
|
||||||
@ -376,7 +376,8 @@ class PointMazeEnv(gym.Env):
|
|||||||
sensor_readings = np.zeros((self._n_bins, 3))
|
sensor_readings = np.zeros((self._n_bins, 3))
|
||||||
for ray_idx in range(self._n_bins):
|
for ray_idx in range(self._n_bins):
|
||||||
ray_ori = (ori - self._sensor_span * 0.5 + (
|
ray_ori = (ori - self._sensor_span * 0.5 + (
|
||||||
2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
|
2 * ray_idx + 1.0) /
|
||||||
|
(2 * self._n_bins) * self._sensor_span)
|
||||||
ray_segments = []
|
ray_segments = []
|
||||||
# Get all segments that intersect with ray.
|
# Get all segments that intersect with ray.
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
@ -401,8 +402,8 @@ class PointMazeEnv(gym.Env):
|
|||||||
2 if maze_env_utils.can_move(seg_type) else # Block.
|
2 if maze_env_utils.can_move(seg_type) else # Block.
|
||||||
None)
|
None)
|
||||||
if first_seg["distance"] <= self._sensor_range:
|
if first_seg["distance"] <= self._sensor_range:
|
||||||
sensor_readings[ray_idx][idx] = (
|
sensor_readings[ray_idx][idx] = \
|
||||||
self._sensor_range - first_seg[
|
(self._sensor_range - first_seg[
|
||||||
"distance"]) / self._sensor_range
|
"distance"]) / self._sensor_range
|
||||||
return sensor_readings
|
return sensor_readings
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user