diff --git a/.github/workflows/extra_sys.yml b/.github/workflows/extra_sys.yml
index df21abc..4e6ecda 100644
--- a/.github/workflows/extra_sys.yml
+++ b/.github/workflows/extra_sys.yml
@@ -22,6 +22,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install ".[dev]" --upgrade
+          python -m pip uninstall ray -y
       - name: wandb login
         run: |
           wandb login e2366d661b89f2bee877c40bee15502d67b7abef
diff --git a/setup.py b/setup.py
index 963fb79..00af4fe 100644
--- a/setup.py
+++ b/setup.py
@@ -41,6 +41,7 @@ setup(
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
     ],
     keywords="reinforcement learning platform pytorch",
     packages=find_packages(
@@ -66,7 +67,7 @@ setup(
             "isort",
             "pytest",
             "pytest-cov",
-            "ray>=1.0.0,<1.7.0",
+            "ray>=1.0.0",
             "wandb>=0.12.0",
             "networkx",
             "mypy",
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index b952cc5..da74b7a 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -8,7 +8,7 @@ import torch
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -19,7 +19,7 @@ from tianshou.utils.net.discrete import Actor, Critic
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.99)
@@ -57,11 +57,11 @@ def test_ppo(args=get_args()):
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
     # you can also use tianshou.env.SubprocVectorEnv
-    train_envs = SubprocVectorEnv(
+    train_envs = DummyVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)]
     )
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv(
+    test_envs = DummyVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.test_num)]
     )
     # seed
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index 41be368..118a296 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -8,7 +8,7 @@ import torch
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import DummyVectorEnv
 from tianshou.policy import DiscreteSACPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -19,7 +19,7 @@ from tianshou.utils.net.discrete import Actor, Critic
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -32,8 +32,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--update-per-step', type=float, default=0.1)
-    parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -49,13 +49,16 @@ def get_args():
 
 
 def test_discrete_sac(args=get_args()):
     env = gym.make(args.task)
+    if args.task == 'CartPole-v0':
+        env.spec.reward_threshold = 180  # lower the goal
+
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
-    train_envs = SubprocVectorEnv(
+    train_envs = DummyVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = DummyVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.test_num)]
     )
     # seed
diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py
index 1fcd60f..64bcda6 100644
--- a/test/modelbased/test_dqn_icm.py
+++ b/test/modelbased/test_dqn_icm.py
@@ -193,12 +193,5 @@ def test_dqn_icm(args=get_args()):
     print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
 
 
-def test_pdqn_icm(args=get_args()):
-    args.prioritized_replay = True
-    args.gamma = .95
-    args.seed = 1
-    test_dqn_icm(args)
-
-
 if __name__ == '__main__':
     test_dqn_icm(get_args())
diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py
index 1c202f9..f548197 100644
--- a/test/modelbased/test_ppo_icm.py
+++ b/test/modelbased/test_ppo_icm.py
@@ -8,18 +8,18 @@ import torch
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import DummyVectorEnv
 from tianshou.policy import ICMPolicy, PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import MLP, ActorCritic, DataParallelNet, Net
+from tianshou.utils.net.common import MLP, ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.99)
@@ -75,11 +75,11 @@ def test_ppo(args=get_args()):
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
     # you can also use tianshou.env.SubprocVectorEnv
-    train_envs = SubprocVectorEnv(
+    train_envs = DummyVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)]
     )
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv(
+    test_envs = DummyVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.test_num)]
     )
     # seed
@@ -89,14 +89,8 @@ def test_ppo(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    if torch.cuda.is_available():
-        actor = DataParallelNet(
-            Actor(net, args.action_shape, device=None).to(args.device)
-        )
-        critic = DataParallelNet(Critic(net, device=None).to(args.device))
-    else:
-        actor = Actor(net, args.action_shape, device=args.device).to(args.device)
-        critic = Critic(net, device=args.device).to(args.device)
+    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
+    critic = Critic(net, device=args.device).to(args.device)
     actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
     for m in actor_critic.modules():
diff --git a/test/offline/expert_QRDQN_CartPole-v0.pkl b/test/offline/expert_QRDQN_CartPole-v0.pkl
new file mode 100644
index 0000000..4557f7b
Binary files /dev/null and b/test/offline/expert_QRDQN_CartPole-v0.pkl differ
diff --git a/test/offline/expert_SAC_Pendulum-v0.pkl b/test/offline/expert_SAC_Pendulum-v0.pkl
new file mode 100644
index 0000000..64ce17b
Binary files /dev/null and b/test/offline/expert_SAC_Pendulum-v0.pkl differ
diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py
index 78d8517..eeddf1a 100644
--- a/test/offline/gather_cartpole_data.py
+++ b/test/offline/gather_cartpole_data.py
@@ -15,6 +15,10 @@ from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 
 
+def expert_file_name():
+    return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
@@ -42,9 +46,7 @@ def get_args():
     parser.add_argument('--prioritized-replay', action="store_true", default=False)
     parser.add_argument('--alpha', type=float, default=0.6)
     parser.add_argument('--beta', type=float, default=0.4)
-    parser.add_argument(
-        '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
-    )
+    parser.add_argument('--save-buffer-name', type=str, default=expert_file_name())
     parser.add_argument(
         '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
     )
@@ -155,6 +157,9 @@ def gather_data():
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     result = collector.collect(n_step=args.buffer_size)
-    pickle.dump(buf, open(args.save_buffer_name, "wb"))
+    if args.save_buffer_name.endswith(".hdf5"):
+        buf.save_hdf5(args.save_buffer_name)
+    else:
+        pickle.dump(buf, open(args.save_buffer_name, "wb"))
     print(result["rews"].mean())
     return buf
diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py
index 4c0275e..80ce12e 100644
--- a/test/offline/gather_pendulum_data.py
+++ b/test/offline/gather_pendulum_data.py
@@ -16,11 +16,15 @@ from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb, Critic
 
 
+def expert_file_name():
+    return os.path.join(os.path.dirname(__file__), "expert_SAC_Pendulum-v0.pkl")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--buffer-size', type=int, default=200000)
+    parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
     parser.add_argument('--actor-lr', type=float, default=1e-3)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -52,9 +56,7 @@ def get_args():
     parser.add_argument('--alpha-lr', type=float, default=3e-4)
     parser.add_argument('--rew-norm', action="store_true", default=False)
     parser.add_argument('--n-step', type=int, default=3)
-    parser.add_argument(
-        "--save-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
default="./expert_SAC_Pendulum-v0.pkl" - ) + parser.add_argument("--save-buffer-name", type=str, default=expert_file_name()) args = parser.parse_known_args()[0] return args @@ -166,5 +168,8 @@ def gather_data(): result = train_collector.collect(n_step=args.buffer_size) rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - pickle.dump(buffer, open(args.save_buffer_name, "wb")) + if args.save_buffer_name.endswith(".hdf5"): + buffer.save_hdf5(args.save_buffer_name) + else: + pickle.dump(buffer, open(args.save_buffer_name, "wb")) return buffer diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index ab98e49..ac5fe95 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import BCQPolicy from tianshou.trainer import offline_trainer @@ -18,26 +18,26 @@ from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation if __name__ == "__main__": - from gather_pendulum_data import gather_data + from gather_pendulum_data import expert_file_name, gather_data else: # pytest - from test.offline.gather_pendulum_data import gather_data + from test.offline.gather_pendulum_data import expert_file_name, gather_data def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[200, 150]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64]) parser.add_argument('--actor-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3) - parser.add_argument('--epoch', type=int, default=7) - parser.add_argument('--step-per-epoch', type=int, default=2000) - parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=500) + parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[375, 375]) + parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[32, 32]) # default to 2 * action_dim parser.add_argument('--latent_dim', type=int, default=None) parser.add_argument("--gamma", default=0.99) @@ -56,16 +56,17 @@ def get_args(): action='store_true', help='watch the play of pre-trained policy only', ) - parser.add_argument( - "--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl" - ) + parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) args = parser.parse_known_args()[0] return args def test_bcq(args=get_args()): if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if args.load_buffer_name.endswith(".hdf5"): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + buffer = pickle.load(open(args.load_buffer_name, "rb")) else: buffer = gather_data() env = gym.make(args.task) @@ -73,7 +74,7 @@ def test_bcq(args=get_args()): args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] # float if args.task == 'Pendulum-v0': - env.spec.reward_threshold = -800 # too low? + env.spec.reward_threshold = -1100 # too low? args.state_dim = args.state_shape[0] args.action_dim = args.action_shape[0] diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index ae1507a..ce780ac 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import CQLPolicy from tianshou.trainer import offline_trainer @@ -18,16 +18,16 @@ from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic if __name__ == "__main__": - from gather_pendulum_data import gather_data + from gather_pendulum_data import expert_file_name, gather_data else: # pytest - from test.offline.gather_pendulum_data import gather_data + from test.offline.gather_pendulum_data import expert_file_name, gather_data def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--actor-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3) parser.add_argument('--alpha', type=float, default=0.2) @@ -35,10 +35,10 @@ def get_args(): parser.add_argument('--alpha-lr', type=float, default=1e-3) parser.add_argument('--cql-alpha-lr', type=float, default=1e-3) parser.add_argument("--start-timesteps", type=int, default=10000) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2000) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=500) parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--batch-size', type=int, default=64) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--temperature", type=float, 
@@ -48,7 +48,6 @@ def get_args():
     parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--eval-freq", type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=10)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=1 / 35)
@@ -62,16 +61,17 @@ def get_args():
         action='store_true',
         help='watch the play of pre-trained policy only',
     )
-    parser.add_argument(
-        "--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
-    )
+    parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
     args = parser.parse_known_args()[0]
     return args
 
 
 def test_cql(args=get_args()):
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        if args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
     else:
         buffer = gather_data()
     env = gym.make(args.task)
@@ -106,7 +106,7 @@ def test_cql(args=get_args()):
         max_action=args.max_action,
         device=args.device,
         unbounded=True,
-        conditioned_sigma=True
+        conditioned_sigma=True,
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py
index 460ddb3..e83d13d 100644
--- a/test/offline/test_discrete_bcq.py
+++ b/test/offline/test_discrete_bcq.py
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Collector
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
@@ -17,9 +17,9 @@ from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor
 
 if __name__ == "__main__":
-    from gather_cartpole_data import gather_data
+    from gather_cartpole_data import expert_file_name, gather_data
 else:  # pytest
-    from test.offline.gather_cartpole_data import gather_data
+    from test.offline.gather_cartpole_data import expert_file_name, gather_data
 
 
 def get_args():
@@ -40,11 +40,7 @@ def get_args():
     parser.add_argument("--test-num", type=int, default=100)
     parser.add_argument("--logdir", type=str, default="log")
     parser.add_argument("--render", type=float, default=0.)
- parser.add_argument( - "--load-buffer-name", - type=str, - default="./expert_QRDQN_CartPole-v0.pkl", - ) + parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, @@ -94,7 +90,10 @@ def test_discrete_bcq(args=get_args()): ) # buffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if args.load_buffer_name.endswith(".hdf5"): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + buffer = pickle.load(open(args.load_buffer_name, "rb")) else: buffer = gather_data() diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index c97f456..eaac481 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -8,7 +8,7 @@ import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DiscreteCQLPolicy from tianshou.trainer import offline_trainer @@ -16,9 +16,9 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net if __name__ == "__main__": - from gather_cartpole_data import gather_data + from gather_cartpole_data import expert_file_name, gather_data else: # pytest - from test.offline.gather_cartpole_data import gather_data + from test.offline.gather_cartpole_data import expert_file_name, gather_data def get_args(): @@ -26,24 +26,20 @@ def get_args(): parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) - parser.add_argument("--lr", type=float, default=7e-4) + parser.add_argument("--lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument('--num-quantiles', type=int, default=200) parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--min-q-weight", type=float, default=10.) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=1000) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
- parser.add_argument( - "--load-buffer-name", - type=str, - default="./expert_QRDQN_CartPole-v0.pkl", - ) + parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, @@ -57,7 +53,7 @@ def test_discrete_cql(args=get_args()): # envs env = gym.make(args.task) if args.task == 'CartPole-v0': - env.spec.reward_threshold = 185 # lower the goal + env.spec.reward_threshold = 170 # lower the goal args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( @@ -89,7 +85,10 @@ def test_discrete_cql(args=get_args()): ).to(args.device) # buffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if args.load_buffer_name.endswith(".hdf5"): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + buffer = pickle.load(open(args.load_buffer_name, "rb")) else: buffer = gather_data() diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 0b4e7c6..2e8916a 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -8,7 +8,7 @@ import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer @@ -17,9 +17,9 @@ from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic if __name__ == "__main__": - from gather_cartpole_data import gather_data + from gather_cartpole_data import expert_file_name, gather_data else: # pytest - from test.offline.gather_cartpole_data import gather_data + from test.offline.gather_cartpole_data import expert_file_name, gather_data def get_args(): @@ -37,11 +37,7 @@ def get_args(): parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) - parser.add_argument( - "--load-buffer-name", - type=str, - default="./expert_QRDQN_CartPole-v0.pkl", - ) + parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, @@ -55,7 +51,7 @@ def test_discrete_crr(args=get_args()): # envs env = gym.make(args.task) if args.task == 'CartPole-v0': - env.spec.reward_threshold = 190 # lower the goal + env.spec.reward_threshold = 180 # lower the goal args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( @@ -92,7 +88,10 @@ def test_discrete_crr(args=get_args()): ).to(args.device) # buffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if args.load_buffer_name.endswith(".hdf5"): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + buffer = pickle.load(open(args.load_buffer_name, "rb")) else: buffer = gather_data()