added noise param to collector for test phase, fixed examples to adapt modification (#86)

* Add auto alpha tuning and exploration noise for sac.
Add class BaseNoise and GaussianNoise for the concept of exploration noise.
Add new test for sac tested in MountainCarContinuous-v0,
which should benefits from the two above new feature.

* add exploration noise to collector, fix example to adapt modification
This commit is contained in:
danagi 2020-06-23 07:20:51 +08:00 committed by GitHub
parent e8b44bbaf4
commit 13828f6309
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 25 additions and 8 deletions

View File

@ -9,6 +9,7 @@ from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from continuous_net import Actor, Critic from continuous_net import Actor, Critic
@ -67,7 +68,7 @@ def test_ddpg(args=get_args()):
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy( policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim, actor, actor_optim, critic, critic_optim,
args.tau, args.gamma, args.exploration_noise, args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
[env.action_space.low[0], env.action_space.high[0]], [env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True, ignore_done=True) reward_normalization=True, ignore_done=True)
# collector # collector

View File

@ -9,6 +9,7 @@ from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from continuous_net import Actor, Critic from continuous_net import Actor, Critic
@ -74,7 +75,8 @@ def test_td3(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy( policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.exploration_noise, args.policy_noise, args.tau, args.gamma,
GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
args.update_actor_freq, args.noise_clip, args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]], [env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True, ignore_done=True) reward_normalization=True, ignore_done=True)

View File

@ -9,6 +9,7 @@ from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from continuous_net import Actor, Critic from continuous_net import Actor, Critic
from mujoco.register import reg from mujoco.register import reg
@ -77,7 +78,8 @@ def test_td3(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy( policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.exploration_noise, args.policy_noise, args.tau, args.gamma,
GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
args.update_actor_freq, args.noise_clip, args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]], [env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True, ignore_done=True) reward_normalization=True, ignore_done=True)

View File

@ -23,20 +23,20 @@ def get_args():
parser.add_argument('--actor-lr', type=float, default=3e-4) parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--critic-lr', type=float, default=3e-4) parser.add_argument('--critic-lr', type=float, default=3e-4)
parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--noise_std', type=float, default=0.5) parser.add_argument('--noise_std', type=float, default=1.2)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--auto_alpha', type=bool, default=True) parser.add_argument('--auto_alpha', type=bool, default=True)
parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=2400) parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=1) parser.add_argument('--collect-per-step', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1) parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=80) parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1.0/35.0) parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=False) parser.add_argument('--rew-norm', type=bool, default=False)
parser.add_argument( parser.add_argument(
'--device', type=str, '--device', type=str,

View File

@ -28,6 +28,7 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--exploration-noise', type=float, default=0.1)
parser.add_argument('--test-noise', type=float, default=0.1)
parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=2400) parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=4) parser.add_argument('--collect-per-step', type=int, default=4)
@ -87,7 +88,8 @@ def test_ddpg(args=get_args()):
# collector # collector
train_collector = Collector( train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size)) policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs) test_collector = Collector(
policy, test_envs, action_noise=GaussianNoise(sigma=args.test_noise))
# log # log
log_path = os.path.join(args.logdir, args.task, 'ddpg') log_path = os.path.join(args.logdir, args.task, 'ddpg')
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -9,6 +9,7 @@ from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
from tianshou.exploration import BaseNoise
class Collector(object): class Collector(object):
@ -27,6 +28,9 @@ class Collector(object):
added to the buffer, see issue #42, defaults to ``None``. added to the buffer, see issue #42, defaults to ``None``.
:param int stat_size: for the moving average of recording speed, defaults :param int stat_size: for the moving average of recording speed, defaults
to 100. to 100.
:param BaseNoise action_noise: add a noise to continuous action. Normally
a policy already has a noise param for exploration in training phase,
so this is recommended to use in test collector for some purpose.
The ``preprocess_fn`` is a function called before the data has been added The ``preprocess_fn`` is a function called before the data has been added
to the buffer with batch format, which receives up to 7 keys as listed in to the buffer with batch format, which receives up to 7 keys as listed in
@ -87,6 +91,7 @@ class Collector(object):
= None, = None,
preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
stat_size: Optional[int] = 100, stat_size: Optional[int] = 100,
action_noise: Optional[BaseNoise] = None,
**kwargs) -> None: **kwargs) -> None:
super().__init__() super().__init__()
self.env = env self.env = env
@ -119,6 +124,7 @@ class Collector(object):
else: else:
raise TypeError('The buffer in data collector is invalid!') raise TypeError('The buffer in data collector is invalid!')
self.stat_size = stat_size self.stat_size = stat_size
self._action_noise = action_noise
self.reset() self.reset()
def reset(self) -> None: def reset(self) -> None:
@ -132,6 +138,8 @@ class Collector(object):
self.collect_step = 0 self.collect_step = 0
self.collect_episode = 0 self.collect_episode = 0
self.collect_time = 0 self.collect_time = 0
if self._action_noise is not None:
self._action_noise.reset()
def reset_buffer(self) -> None: def reset_buffer(self) -> None:
"""Reset the main data buffer.""" """Reset the main data buffer."""
@ -268,6 +276,8 @@ class Collector(object):
self._policy = to_numpy(result.policy) \ self._policy = to_numpy(result.policy) \
if hasattr(result, 'policy') else [{}] * self.env_num if hasattr(result, 'policy') else [{}] * self.env_num
self._act = to_numpy(result.act) self._act = to_numpy(result.act)
if self._action_noise is not None:
self._act += self._action_noise(self._act.shape)
obs_next, self._rew, self._done, self._info = self.env.step( obs_next, self._rew, self._done, self._info = self.env.step(
self._act if self._multi_env else self._act[0]) self._act if self._multi_env else self._act[0])
if not self._multi_env: if not self._multi_env: