Fix/reset before collect in procedural examples, tests and hl experiment (#1100)

Needed due to a breaking change in the Collector, which was overlooked in some of the examples: the Collector no longer resets its environments implicitly, so they must be reset explicitly (via collector.reset() or by passing reset_before_collect=True) before the first call to collect().
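A minimal sketch of the pattern applied throughout this commit, using the variable names from the procedural examples (train_collector, args.batch_size, args.training_num):

    # reset the environments once to obtain the initial obs/info,
    # then pre-fill the replay buffer
    train_collector.reset()
    train_collector.collect(n_step=args.batch_size * args.training_num)

Per the updated Collector docstring further below, the reset can alternatively be requested as part of the collect call itself:

    # equivalent: let collect() perform the reset
    train_collector.collect(
        n_step=args.batch_size * args.training_num,
        reset_before_collect=True,
    )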
maxhuettenrauch 2024-04-16 10:30:21 +02:00 committed by GitHub
parent 766f6fedf2
commit 60d1ba1c8f
33 changed files with 45 additions and 18 deletions


@ -190,6 +190,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(


@ -232,6 +232,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(


@ -203,6 +203,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(


@ -200,6 +200,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer


@ -256,6 +256,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OnpolicyTrainer(


@ -194,6 +194,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(


@ -230,6 +230,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(


@ -243,6 +243,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(


@ -92,6 +92,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn")


@ -117,6 +117,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")


@ -94,6 +94,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn")


@ -213,6 +213,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)
def save_best_fn(policy: BasePolicy) -> None:


@ -122,6 +122,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)
# log


@ -150,6 +150,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)
# log


@ -144,6 +144,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)
# log


@ -142,6 +142,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)
# log


@ -196,6 +196,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(


@ -258,6 +258,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
sys.exit(0)
# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OnpolicyTrainer(


@ -36,7 +36,7 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--alpha", type=float, default=0.2)
parser.add_argument("--auto-alpha", type=int, default=1)
parser.add_argument("--alpha-lr", type=float, default=3e-4)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--step-per-epoch", type=int, default=24000)
parser.add_argument("--il-step-per-epoch", type=int, default=500)
parser.add_argument("--step-per-collect", type=int, default=10)


@ -115,7 +115,8 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
def train_fn(epoch: int, env_step: int) -> None: # exp decay
eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)


@ -119,7 +119,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "c51")
writer = SummaryWriter(log_path)


@ -110,7 +110,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn")
writer = SummaryWriter(log_path)


@ -94,7 +94,8 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
# the stack_num is for RNN training: sample framestack obs
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "drqn")
writer = SummaryWriter(log_path)


@ -127,7 +127,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "fqf")
writer = SummaryWriter(log_path)


@ -123,7 +123,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "iqn")
writer = SummaryWriter(log_path)


@ -112,7 +112,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path)


@ -127,7 +127,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "rainbow")
writer = SummaryWriter(log_path)


@ -153,7 +153,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn_icm")
writer = SummaryWriter(log_path)


@ -118,7 +118,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
test_collector = Collector(policy, test_envs, exploration_noise=True)
test_collector.reset()
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path)


@ -135,7 +135,8 @@ def train_agent(
exploration_noise=True,
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, "pistonball", "dqn")
writer = SummaryWriter(log_path)


@ -171,7 +171,8 @@ def train_agent(
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
writer = SummaryWriter(log_path)


@ -330,10 +330,7 @@ class Collector:
:param random: whether to use random policy for collecting data.
:param render: the sleep time between rendering consecutive frames.
:param no_grad: whether to retain gradient in policy.forward().
:param reset_before_collect: whether to reset the environment before
collecting data.
It has only an effect if n_episode is not None, i.e.
if one wants to collect a fixed number of episodes.
:param reset_before_collect: whether to reset the environment before collecting data.
(The collector needs the initial obs and info to function properly.)
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Only used if reset_before_collect is True.
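For illustration only, a hedged sketch of how the two documented parameters combine (the collector name and the seed keyword are placeholders, not part of this diff):

    # gym_reset_kwargs is forwarded to env.reset(), and per the docstring above
    # it only takes effect when reset_before_collect is True
    collector.collect(
        n_episode=1,
        reset_before_collect=True,
        gym_reset_kwargs={"seed": 0},
    )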


@ -311,7 +311,8 @@ class Experiment(ToStringMixin):
) -> None:
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True)
collector.reset()
result = collector.collect(n_episode=num_episodes, render=render)
assert result.returns_stat is not None # for mypy
assert result.lens_stat is not None # for mypy
log.info(