diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index c6fe6dd..fc04a21 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -190,6 +190,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index 765463c..f51c7fd 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -232,6 +232,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index f616a68..127a14b 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -203,6 +203,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index 9110694..8b16252 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -200,6 +200,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
 
     # trainer
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
index 612b540..f1a89ef 100644
--- a/examples/atari/atari_ppo.py
+++ b/examples/atari/atari_ppo.py
@@ -256,6 +256,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OnpolicyTrainer(
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 7d6330e..dfb9641 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -194,6 +194,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index 86e7fe0..7b341c0 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -230,6 +230,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py
index d5edf1a..f06964c 100644
--- a/examples/atari/atari_sac.py
+++ b/examples/atari/atari_sac.py
@@ -243,6 +243,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index ad53b16..96e61b6 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -92,6 +92,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn")
diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py
index f52f6d5..8b1e8ca 100644
--- a/examples/box2d/bipedal_bdq.py
+++ b/examples/box2d/bipedal_bdq.py
@@ -117,6 +117,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=False)
    # policy.set_eps(1)
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index 47ba9d1..9e5db58 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -94,6 +94,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn")
diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py
index be6594a..cd6ceec 100644
--- a/examples/mujoco/fetch_her_ddpg.py
+++ b/examples/mujoco/fetch_her_ddpg.py
@@ -213,6 +213,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     )
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)
 
     def save_best_fn(policy: BasePolicy) -> None:
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index ceac476..b2a4087 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -122,6 +122,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)
 
     # log
diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py
index b300e49..ae46b22 100755
--- a/examples/mujoco/mujoco_redq.py
+++ b/examples/mujoco/mujoco_redq.py
@@ -150,6 +150,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)
 
     # log
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index 2058a71..a0bd567 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -144,6 +144,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)
 
     # log
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index 30e7539..6b6dfdc 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -142,6 +142,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)
 
     # log
diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index 62daaf6..4211585 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -196,6 +196,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py
index 7476d4f..f5abf0b 100644
--- a/examples/vizdoom/vizdoom_ppo.py
+++ b/examples/vizdoom/vizdoom_ppo.py
@@ -258,6 +258,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         sys.exit(0)
 
     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OnpolicyTrainer(
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index fd5b15a..d13b03d 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -36,7 +36,7 @@ def get_args() -> argparse.Namespace:
     parser.add_argument("--alpha", type=float, default=0.2)
     parser.add_argument("--auto-alpha", type=int, default=1)
     parser.add_argument("--alpha-lr", type=float, default=3e-4)
-    parser.add_argument("--epoch", type=int, default=5)
+    parser.add_argument("--epoch", type=int, default=10)
     parser.add_argument("--step-per-epoch", type=int, default=24000)
     parser.add_argument("--il-step-per-epoch", type=int, default=500)
     parser.add_argument("--step-per-collect", type=int, default=10)
diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py
index c3f6afe..1089d4b 100644
--- a/test/discrete/test_bdq.py
+++ b/test/discrete/test_bdq.py
@@ -115,7 +115,8 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=False)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
 
     def train_fn(epoch: int, env_step: int) -> None:  # exp decay
         eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index 483aca9..4d25d43 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -119,7 +119,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "c51")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 6c58883..b62a93c 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -110,7 +110,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 8bca5c1..5c24518 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -94,7 +94,8 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
     # the stack_num is for RNN training: sample framestack obs
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "drqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index f1af574..8ff9eeb 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -127,7 +127,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "fqf")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index 87e7398..765bbf9 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -123,7 +123,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "iqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index 879717a..6485637 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -112,7 +112,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "qrdqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index c703534..ff4ef1c 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -127,7 +127,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "rainbow")
     writer = SummaryWriter(log_path)
diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py
index 9ca9c70..9a4206e 100644
--- a/test/modelbased/test_dqn_icm.py
+++ b/test/modelbased/test_dqn_icm.py
@@ -153,7 +153,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn_icm")
     writer = SummaryWriter(log_path)
diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py
index e8411b2..9387794 100644
--- a/test/offline/gather_cartpole_data.py
+++ b/test/offline/gather_cartpole_data.py
@@ -118,7 +118,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     test_collector.reset()
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "qrdqn")
     writer = SummaryWriter(log_path)
diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py
index 7b3fb4d..abd0c88 100644
--- a/test/pettingzoo/pistonball.py
+++ b/test/pettingzoo/pistonball.py
@@ -135,7 +135,8 @@ def train_agent(
         exploration_noise=True,
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, "pistonball", "dqn")
     writer = SummaryWriter(log_path)
diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py
index da580f3..7ed6319 100644
--- a/test/pettingzoo/tic_tac_toe.py
+++ b/test/pettingzoo/tic_tac_toe.py
@@ -171,7 +171,8 @@ def train_agent(
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
     writer = SummaryWriter(log_path)
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 751fedf..345d50b 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -330,10 +330,7 @@ class Collector:
         :param random: whether to use random policy for collecting data.
         :param render: the sleep time between rendering consecutive frames.
         :param no_grad: whether to retain gradient in policy.forward().
-        :param reset_before_collect: whether to reset the environment before
-            collecting data.
-            It has only an effect if n_episode is not None, i.e.
-            if one wants to collect a fixed number of episodes.
+        :param reset_before_collect: whether to reset the environment before collecting data. (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
 
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 17f0550..87cbd49 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -311,7 +311,8 @@ class Experiment(ToStringMixin):
     ) -> None:
         policy.eval()
         collector = Collector(policy, env)
-        result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True)
+        collector.reset()
+        result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(
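
Note (usage reference, not part of the patch): the hunks above converge on one warm-up pattern, calling Collector.reset() explicitly before the first collect() so the collector holds the initial obs and info, rather than passing reset_before_collect=True to collect(). Below is a minimal sketch of that pattern; the environment name, buffer size, and step count are illustrative placeholders, not values taken from the touched files.

# Sketch of the "reset, then collect" warm-up pattern used throughout this patch.
import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BasePolicy


def warm_up(policy: BasePolicy, num_envs: int = 4, n_step: int = 256) -> None:
    # Vectorized training envs and a matching replay buffer (illustrative sizes).
    train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
    buffer = VectorReplayBuffer(total_size=20000, buffer_num=num_envs)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)

    # Previously: train_collector.collect(n_step=n_step, reset_before_collect=True)
    # Now: reset explicitly so the collector has the initial obs/info, then fill the buffer.
    train_collector.reset()
    train_collector.collect(n_step=n_step)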