Added a noise parameter to the collector for the test phase and fixed the examples to adapt to this change (#86)
* Add automatic alpha tuning and exploration noise for SAC. Add the classes BaseNoise and GaussianNoise to represent exploration noise. Add a new SAC test on MountainCarContinuous-v0, which should benefit from the two new features above.
* Add exploration noise to the collector and fix the examples to adapt to this change.
parent e8b44bbaf4
commit 13828f6309
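The diff below only touches call sites; the new tianshou.exploration module itself is not shown in this view. Judging from how the noise object is used in the hunks (constructed as GaussianNoise(sigma=...), called with the action batch's shape, added to the actions, and reset() between collects), a minimal sketch of that interface could look like the following. It illustrates the assumed interface, not the repository's actual implementation.

    from abc import ABC, abstractmethod
    from typing import Sequence

    import numpy as np


    class BaseNoise(ABC):
        """Sketch of the exploration-noise interface assumed by the diff."""

        @abstractmethod
        def __call__(self, shape: Sequence[int]) -> np.ndarray:
            """Return a noise sample of the given shape."""
            raise NotImplementedError

        def reset(self) -> None:
            """Reset internal state; a no-op for stateless noise."""
            pass


    class GaussianNoise(BaseNoise):
        """Zero-mean Gaussian noise with standard deviation sigma (sketch)."""

        def __init__(self, sigma: float = 0.1) -> None:
            self._sigma = sigma

        def __call__(self, shape: Sequence[int]) -> np.ndarray:
            return np.random.normal(0.0, self._sigma, shape)
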
@@ -9,6 +9,7 @@ from tianshou.policy import DDPGPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
+from tianshou.exploration import GaussianNoise
 
 from continuous_net import Actor, Critic
 

@@ -67,7 +68,7 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, args.exploration_noise,
+        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=True, ignore_done=True)
     # collector

@@ -9,6 +9,7 @@ from tianshou.policy import TD3Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
+from tianshou.exploration import GaussianNoise
 
 from continuous_net import Actor, Critic
 

@@ -74,7 +75,8 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.exploration_noise, args.policy_noise,
+        args.tau, args.gamma,
+        GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=True, ignore_done=True)

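Note that TD3 uses noise in two different places, and this hunk touches only one of them: GaussianNoise(sigma=args.exploration_noise) perturbs the actions collected from the environment, while args.policy_noise together with args.noise_clip parameterizes the clipped noise added to the target actor's action inside the critic update (target policy smoothing). A small sketch of the latter, following the standard TD3 formulation rather than this repository's code:

    import torch

    def smoothed_target_action(target_actor, next_obs, policy_noise=0.2,
                               noise_clip=0.5, act_low=-1.0, act_high=1.0):
        """Target policy smoothing: add clipped noise to the target action."""
        a_next = target_actor(next_obs)
        noise = (torch.randn_like(a_next) * policy_noise).clamp(
            -noise_clip, noise_clip)
        return (a_next + noise).clamp(act_low, act_high)
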
@@ -9,6 +9,7 @@ from tianshou.policy import TD3Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
+from tianshou.exploration import GaussianNoise
 from continuous_net import Actor, Critic
 from mujoco.register import reg
 

@@ -77,7 +78,8 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.exploration_noise, args.policy_noise,
+        args.tau, args.gamma,
+        GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=True, ignore_done=True)

@@ -23,20 +23,20 @@ def get_args():
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=3e-4)
     parser.add_argument('--alpha-lr', type=float, default=3e-4)
-    parser.add_argument('--noise_std', type=float, default=0.5)
+    parser.add_argument('--noise_std', type=float, default=1.2)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--auto_alpha', type=bool, default=True)
     parser.add_argument('--alpha', type=float, default=0.2)
     parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--collect-per-step', type=int, default=5)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=80)
+    parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
-    parser.add_argument('--render', type=float, default=1.0/35.0)
+    parser.add_argument('--render', type=float, default=0.)
     parser.add_argument('--rew-norm', type=bool, default=False)
     parser.add_argument(
         '--device', type=str,

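This hunk only retunes hyperparameters of the new SAC test (noise std, collected steps, number of training environments, rendering) and exposes the --auto_alpha and --alpha-lr flags; the tuning logic itself is not part of this view. As a rough, hedged sketch of what automatic entropy-temperature (alpha) tuning usually looks like in SAC, with illustrative names such as log_alpha and target_entropy that need not match the repository:

    import torch

    # Common heuristic: target entropy = -(action dimension).
    action_dim = 1  # MountainCarContinuous-v0 has a 1-dimensional action
    target_entropy = -float(action_dim)

    # Optimize alpha in log space so it stays positive.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

    def update_alpha(log_prob: torch.Tensor) -> float:
        """One temperature update from the log-probs of sampled actions."""
        alpha_loss = -(log_alpha * (log_prob.detach() + target_entropy)).mean()
        alpha_optim.zero_grad()
        alpha_loss.backward()
        alpha_optim.step()
        return log_alpha.exp().item()
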
@@ -28,6 +28,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
+    parser.add_argument('--test-noise', type=float, default=0.1)
     parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=4)

@@ -87,7 +88,8 @@ def test_ddpg(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
+    test_collector = Collector(
+        policy, test_envs, action_noise=GaussianNoise(sigma=args.test_noise))
     # log
     log_path = os.path.join(args.logdir, args.task, 'ddpg')
     writer = SummaryWriter(log_path)

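Together with the new --test-noise flag above, the intended pattern seems to be: during training the policy's own exploration noise is used, while the test collector can inject an independent amount of Gaussian noise into the evaluated actions. A hedged usage sketch, assuming policy, train_envs, test_envs, and args are already set up as in the surrounding script:

    from tianshou.data import Collector, ReplayBuffer
    from tianshou.exploration import GaussianNoise

    # Training collector: no extra action_noise; exploration comes from the
    # noise object passed to the policy itself.
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))

    # Test collector: evaluate under a small, independent Gaussian
    # perturbation, which is what this commit enables.
    test_collector = Collector(
        policy, test_envs, action_noise=GaussianNoise(sigma=args.test_noise))

    result = test_collector.collect(n_episode=args.test_num)
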
@@ -9,6 +9,7 @@ from tianshou.utils import MovAvg
 from tianshou.env import BaseVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
+from tianshou.exploration import BaseNoise
 
 
 class Collector(object):

@@ -27,6 +28,9 @@ class Collector(object):
         added to the buffer, see issue #42, defaults to ``None``.
     :param int stat_size: for the moving average of recording speed, defaults
         to 100.
+    :param BaseNoise action_noise: add a noise to continuous action. Normally
+        a policy already has a noise param for exploration in training phase,
+        so this is recommended to use in test collector for some purpose.
 
     The ``preprocess_fn`` is a function called before the data has been added
     to the buffer with batch format, which receives up to 7 keys as listed in

@@ -87,6 +91,7 @@ class Collector(object):
                  = None,
                  preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
                  stat_size: Optional[int] = 100,
+                 action_noise: Optional[BaseNoise] = None,
                  **kwargs) -> None:
         super().__init__()
         self.env = env

@@ -119,6 +124,7 @@ class Collector(object):
         else:
             raise TypeError('The buffer in data collector is invalid!')
         self.stat_size = stat_size
+        self._action_noise = action_noise
         self.reset()
 
     def reset(self) -> None:

@@ -132,6 +138,8 @@ class Collector(object):
         self.collect_step = 0
         self.collect_episode = 0
         self.collect_time = 0
+        if self._action_noise is not None:
+            self._action_noise.reset()
 
     def reset_buffer(self) -> None:
         """Reset the main data buffer."""

@@ -268,6 +276,8 @@ class Collector(object):
             self._policy = to_numpy(result.policy) \
                 if hasattr(result, 'policy') else [{}] * self.env_num
             self._act = to_numpy(result.act)
+            if self._action_noise is not None:
+                self._act += self._action_noise(self._act.shape)
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
             if not self._multi_env:

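The last hunk is where the new parameter actually takes effect: right after the policy's actions are converted to a NumPy array, a noise sample drawn with the same shape is added element-wise before the environment step. A tiny self-contained illustration of that step in plain NumPy (the sigma of 0.1 is just an example value):

    import numpy as np

    np.random.seed(0)

    # A batch of 4 one-dimensional continuous actions, as a policy might output.
    act = np.array([[0.2], [-0.5], [0.9], [0.0]])

    # Equivalent of `self._act += self._action_noise(self._act.shape)` with
    # GaussianNoise(sigma=0.1).
    act = act + np.random.normal(0.0, 0.1, act.shape)

    print(act.shape)  # (4, 1) -- shape preserved, values perturbed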