Fix/reset before collect in procedural examples, tests and hl experiment (#1100)
Needed due to a breaking change in the Collector that was overlooked in some of the examples: the reworked collect() needs the collector to hold the initial obs/info, so the procedural examples, the tests, and the high-level experiment now reset the collector before the first collect.
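A minimal before/after sketch of the pattern the hunks below apply (names like policy, train_envs, buf, and args are the objects already built earlier in each procedural example; this is illustrative, not a complete script). Some hunks instead drop a reset_before_collect=True argument in favour of the explicit reset().

from tianshou.data import Collector

# Before: the first collect() was issued right after constructing the collector.
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num)

# After: reset once so the collector obtains the initial obs/info, then collect.
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)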
This commit is contained in:
parent 766f6fedf2
commit 60d1ba1c8f
@@ -190,6 +190,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OffpolicyTrainer(

@@ -232,6 +232,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OffpolicyTrainer(

@@ -203,6 +203,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OffpolicyTrainer(

@@ -200,6 +200,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer

@@ -256,6 +256,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OnpolicyTrainer(

@@ -194,6 +194,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OffpolicyTrainer(

@@ -230,6 +230,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OffpolicyTrainer(

@@ -243,6 +243,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OffpolicyTrainer(
@@ -92,6 +92,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
 )
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "dqn")

@@ -117,6 +117,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
 )
 test_collector = Collector(policy, test_envs, exploration_noise=False)
 # policy.set_eps(1)
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

@@ -94,6 +94,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
 )
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "dqn")
@@ -213,6 +213,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
 )
 train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
 test_collector = Collector(policy, test_envs)
+train_collector.reset()
 train_collector.collect(n_step=args.start_timesteps, random=True)

 def save_best_fn(policy: BasePolicy) -> None:

@@ -122,6 +122,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
 buffer = ReplayBuffer(args.buffer_size)
 train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
 test_collector = Collector(policy, test_envs)
+train_collector.reset()
 train_collector.collect(n_step=args.start_timesteps, random=True)

 # log

@@ -150,6 +150,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
 buffer = ReplayBuffer(args.buffer_size)
 train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
 test_collector = Collector(policy, test_envs)
+train_collector.reset()
 train_collector.collect(n_step=args.start_timesteps, random=True)

 # log

@@ -144,6 +144,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
 buffer = ReplayBuffer(args.buffer_size)
 train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
 test_collector = Collector(policy, test_envs)
+train_collector.reset()
 train_collector.collect(n_step=args.start_timesteps, random=True)

 # log

@@ -142,6 +142,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
 buffer = ReplayBuffer(args.buffer_size)
 train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
 test_collector = Collector(policy, test_envs)
+train_collector.reset()
 train_collector.collect(n_step=args.start_timesteps, random=True)

 # log
@@ -196,6 +196,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OffpolicyTrainer(

@@ -258,6 +258,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
 sys.exit(0)

 # test train_collector and start filling replay buffer
+train_collector.reset()
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # trainer
 result = OnpolicyTrainer(
@@ -36,7 +36,7 @@ def get_args() -> argparse.Namespace:
 parser.add_argument("--alpha", type=float, default=0.2)
 parser.add_argument("--auto-alpha", type=int, default=1)
 parser.add_argument("--alpha-lr", type=float, default=3e-4)
-parser.add_argument("--epoch", type=int, default=5)
+parser.add_argument("--epoch", type=int, default=10)
 parser.add_argument("--step-per-epoch", type=int, default=24000)
 parser.add_argument("--il-step-per-epoch", type=int, default=500)
 parser.add_argument("--step-per-collect", type=int, default=10)
@@ -115,7 +115,8 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
 )
 test_collector = Collector(policy, test_envs, exploration_noise=False)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)

 def train_fn(epoch: int, env_step: int) -> None:  # exp decay
 eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)

@@ -119,7 +119,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
 train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "c51")
 writer = SummaryWriter(log_path)

@@ -110,7 +110,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
 train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "dqn")
 writer = SummaryWriter(log_path)

@@ -94,7 +94,8 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
 # the stack_num is for RNN training: sample framestack obs
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "drqn")
 writer = SummaryWriter(log_path)

@@ -127,7 +127,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
 train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "fqf")
 writer = SummaryWriter(log_path)

@@ -123,7 +123,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
 train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "iqn")
 writer = SummaryWriter(log_path)

@@ -112,7 +112,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
 train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "qrdqn")
 writer = SummaryWriter(log_path)

@@ -127,7 +127,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
 train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "rainbow")
 writer = SummaryWriter(log_path)

@@ -153,7 +153,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
 train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "dqn_icm")
 writer = SummaryWriter(log_path)
@@ -118,7 +118,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 test_collector.reset()
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, args.task, "qrdqn")
 writer = SummaryWriter(log_path)
@@ -135,7 +135,8 @@ def train_agent(
 exploration_noise=True,
 )
 test_collector = Collector(policy, test_envs, exploration_noise=True)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, "pistonball", "dqn")
 writer = SummaryWriter(log_path)

@@ -171,7 +171,8 @@ def train_agent(
 )
 test_collector = Collector(policy, test_envs, exploration_noise=True)
 # policy.set_eps(1)
-train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+train_collector.reset()
+train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
 log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
 writer = SummaryWriter(log_path)
@@ -330,10 +330,7 @@ class Collector:
 :param random: whether to use random policy for collecting data.
 :param render: the sleep time between rendering consecutive frames.
 :param no_grad: whether to retain gradient in policy.forward().
-:param reset_before_collect: whether to reset the environment before
-    collecting data.
-    It has only an effect if n_episode is not None, i.e.
-    if one wants to collect a fixed number of episodes.
+:param reset_before_collect: whether to reset the environment before collecting data.
+    (The collector needs the initial obs and info to function properly.)
 :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
     reset function. Only used if reset_before_collect is True.
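As a usage illustration of the two documented keyword arguments above (a sketch, not part of this commit): instead of calling reset() separately, a caller can let collect() perform the reset itself and forward extra kwargs to the environments' reset(). The seed value below is a hypothetical example.

result = train_collector.collect(
    n_step=args.batch_size * args.training_num,
    reset_before_collect=True,  # the collector resets the envs and grabs the initial obs/info first
    gym_reset_kwargs={"seed": args.seed},  # forwarded to env.reset(); only used when reset_before_collect is True
)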
@@ -311,7 +311,8 @@ class Experiment(ToStringMixin):
 ) -> None:
 policy.eval()
 collector = Collector(policy, env)
-result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True)
+collector.reset()
+result = collector.collect(n_episode=num_episodes, render=render)
 assert result.returns_stat is not None  # for mypy
 assert result.lens_stat is not None  # for mypy
 log.info(
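For comparison, roughly the same watch logic written procedurally (a sketch assuming a trained policy, a single renderable env, and that returns_stat exposes a mean field; none of this is part of the commit):

policy.eval()
collector = Collector(policy, env)
collector.reset()  # required before the first collect with the reworked Collector
result = collector.collect(n_episode=10, render=0.01)  # render = sleep time between consecutive frames
if result.returns_stat is not None:  # may be None, hence the asserts in the hunk above
    print("mean return:", result.returns_stat.mean)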