diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 2b2e12e..c347f75 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -78,6 +78,8 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed + args.training_num)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device)
@@ -181,10 +183,12 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
         action_scaling=True,
         action_bound_method="clip",
     )
+    il_test_env = gym.make(args.task)
+    il_test_env.reset(seed=args.seed + args.training_num + args.test_num)
     il_test_collector = Collector(
         il_policy,
         # envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed),
-        gym.make(args.task),
+        il_test_env,
     )
     train_collector.reset()
     result = OffpolicyTrainer(