make unit test faster (#522)

* test cache expert data in offline training

* faster cql test

* faster tests

* use dummy

* test ray dependency
This commit is contained in:
Jiayi Weng 2022-02-08 11:24:52 -05:00 committed by GitHub
parent 9c100e0705
commit 3d697aa4c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 99 additions and 99 deletions

View File

@ -22,6 +22,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install ".[dev]" --upgrade python -m pip install ".[dev]" --upgrade
python -m pip uninstall ray -y
- name: wandb login - name: wandb login
run: | run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef wandb login e2366d661b89f2bee877c40bee15502d67b7abef

View File

@ -41,6 +41,7 @@ setup(
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
], ],
keywords="reinforcement learning platform pytorch", keywords="reinforcement learning platform pytorch",
packages=find_packages( packages=find_packages(
@ -66,7 +67,7 @@ setup(
"isort", "isort",
"pytest", "pytest",
"pytest-cov", "pytest-cov",
"ray>=1.0.0,<1.7.0", "ray>=1.0.0",
"wandb>=0.12.0", "wandb>=0.12.0",
"networkx", "networkx",
"mypy", "mypy",

View File

@ -8,7 +8,7 @@ import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy from tianshou.policy import PPOPolicy
from tianshou.trainer import onpolicy_trainer from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger from tianshou.utils import TensorboardLogger
@ -19,7 +19,7 @@ from tianshou.utils.net.discrete import Actor, Critic
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--gamma', type=float, default=0.99)
@ -57,11 +57,11 @@ def test_ppo(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task) # train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv # you can also use tianshou.env.SubprocVectorEnv
train_envs = SubprocVectorEnv( train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)] [lambda: gym.make(args.task) for _ in range(args.training_num)]
) )
# test_envs = gym.make(args.task) # test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv( test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)] [lambda: gym.make(args.task) for _ in range(args.test_num)]
) )
# seed # seed

View File

@ -8,7 +8,7 @@ import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteSACPolicy from tianshou.policy import DiscreteSACPolicy
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger from tianshou.utils import TensorboardLogger
@ -19,7 +19,7 @@ from tianshou.utils.net.discrete import Actor, Critic
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=1e-4) parser.add_argument('--actor-lr', type=float, default=1e-4)
parser.add_argument('--critic-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3)
@ -32,8 +32,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--logdir', type=str, default='log')
@ -49,13 +49,16 @@ def get_args():
def test_discrete_sac(args=get_args()): def test_discrete_sac(args=get_args()):
env = gym.make(args.task) env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 180 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv( train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)] [lambda: gym.make(args.task) for _ in range(args.training_num)]
) )
test_envs = SubprocVectorEnv( test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)] [lambda: gym.make(args.task) for _ in range(args.test_num)]
) )
# seed # seed

View File

@ -193,12 +193,5 @@ def test_dqn_icm(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}") print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_pdqn_icm(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
args.seed = 1
test_dqn_icm(args)
if __name__ == '__main__': if __name__ == '__main__':
test_dqn_icm(get_args()) test_dqn_icm(get_args())

View File

@ -8,18 +8,18 @@ import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv from tianshou.env import DummyVectorEnv
from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy import ICMPolicy, PPOPolicy
from tianshou.trainer import onpolicy_trainer from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, ActorCritic, DataParallelNet, Net from tianshou.utils.net.common import MLP, ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--gamma', type=float, default=0.99)
@ -75,11 +75,11 @@ def test_ppo(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task) # train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv # you can also use tianshou.env.SubprocVectorEnv
train_envs = SubprocVectorEnv( train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)] [lambda: gym.make(args.task) for _ in range(args.training_num)]
) )
# test_envs = gym.make(args.task) # test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv( test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)] [lambda: gym.make(args.task) for _ in range(args.test_num)]
) )
# seed # seed
@ -89,12 +89,6 @@ def test_ppo(args=get_args()):
test_envs.seed(args.seed) test_envs.seed(args.seed)
# model # model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
if torch.cuda.is_available():
actor = DataParallelNet(
Actor(net, args.action_shape, device=None).to(args.device)
)
critic = DataParallelNet(Critic(net, device=None).to(args.device))
else:
actor = Actor(net, args.action_shape, device=args.device).to(args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device)
actor_critic = ActorCritic(actor, critic) actor_critic = ActorCritic(actor, critic)

Binary file not shown.

Binary file not shown.

View File

@ -15,6 +15,10 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
def expert_file_name():
return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl")
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--task', type=str, default='CartPole-v0')
@ -42,9 +46,7 @@ def get_args():
parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4) parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument( parser.add_argument('--save-buffer-name', type=str, default=expert_file_name())
'--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
) )
@ -155,6 +157,9 @@ def gather_data():
policy.set_eps(0.2) policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True) collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size) result = collector.collect(n_step=args.buffer_size)
if args.save_buffer_name.endswith(".hdf5"):
buf.save_hdf5(args.save_buffer_name)
else:
pickle.dump(buf, open(args.save_buffer_name, "wb")) pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean()) print(result["rews"].mean())
return buf return buf

View File

@ -16,11 +16,15 @@ from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.net.continuous import ActorProb, Critic
def expert_file_name():
return os.path.join(os.path.dirname(__file__), "expert_SAC_Pendulum-v0.pkl")
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=200000) parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
parser.add_argument('--actor-lr', type=float, default=1e-3) parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3)
@ -52,9 +56,7 @@ def get_args():
parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--rew-norm', action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--n-step', type=int, default=3)
parser.add_argument( parser.add_argument("--save-buffer-name", type=str, default=expert_file_name())
"--save-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
)
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
return args return args
@ -166,5 +168,8 @@ def gather_data():
result = train_collector.collect(n_step=args.buffer_size) result = train_collector.collect(n_step=args.buffer_size)
rews, lens = result["rews"], result["lens"] rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}") print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if args.save_buffer_name.endswith(".hdf5"):
buffer.save_hdf5(args.save_buffer_name)
else:
pickle.dump(buffer, open(args.save_buffer_name, "wb")) pickle.dump(buffer, open(args.save_buffer_name, "wb"))
return buffer return buffer

View File

@ -9,7 +9,7 @@ import numpy as np
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv from tianshou.env import SubprocVectorEnv
from tianshou.policy import BCQPolicy from tianshou.policy import BCQPolicy
from tianshou.trainer import offline_trainer from tianshou.trainer import offline_trainer
@ -18,26 +18,26 @@ from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation from tianshou.utils.net.continuous import VAE, Critic, Perturbation
if __name__ == "__main__": if __name__ == "__main__":
from gather_pendulum_data import gather_data from gather_pendulum_data import expert_file_name, gather_data
else: # pytest else: # pytest
from test.offline.gather_pendulum_data import gather_data from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[200, 150]) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64])
parser.add_argument('--actor-lr', type=float, default=1e-3) parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--epoch', type=int, default=7) parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=2000) parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--batch-size', type=int, default=256) parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.) parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[375, 375]) parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[32, 32])
# default to 2 * action_dim # default to 2 * action_dim
parser.add_argument('--latent_dim', type=int, default=None) parser.add_argument('--latent_dim', type=int, default=None)
parser.add_argument("--gamma", default=0.99) parser.add_argument("--gamma", default=0.99)
@ -56,15 +56,16 @@ def get_args():
action='store_true', action='store_true',
help='watch the play of pre-trained policy only', help='watch the play of pre-trained policy only',
) )
parser.add_argument( parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
"--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
)
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
return args return args
def test_bcq(args=get_args()): def test_bcq(args=get_args()):
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
buffer = pickle.load(open(args.load_buffer_name, "rb")) buffer = pickle.load(open(args.load_buffer_name, "rb"))
else: else:
buffer = gather_data() buffer = gather_data()
@ -73,7 +74,7 @@ def test_bcq(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float args.max_action = env.action_space.high[0] # float
if args.task == 'Pendulum-v0': if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -800 # too low? env.spec.reward_threshold = -1100 # too low?
args.state_dim = args.state_shape[0] args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0] args.action_dim = args.action_shape[0]

View File

@ -9,7 +9,7 @@ import numpy as np
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv from tianshou.env import SubprocVectorEnv
from tianshou.policy import CQLPolicy from tianshou.policy import CQLPolicy
from tianshou.trainer import offline_trainer from tianshou.trainer import offline_trainer
@ -18,16 +18,16 @@ from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.net.continuous import ActorProb, Critic
if __name__ == "__main__": if __name__ == "__main__":
from gather_pendulum_data import gather_data from gather_pendulum_data import expert_file_name, gather_data
else: # pytest else: # pytest
from test.offline.gather_pendulum_data import gather_data from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--actor-lr', type=float, default=1e-3) parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--alpha', type=float, default=0.2)
@ -35,10 +35,10 @@ def get_args():
parser.add_argument('--alpha-lr', type=float, default=1e-3) parser.add_argument('--alpha-lr', type=float, default=1e-3)
parser.add_argument('--cql-alpha-lr', type=float, default=1e-3) parser.add_argument('--cql-alpha-lr', type=float, default=1e-3)
parser.add_argument("--start-timesteps", type=int, default=10000) parser.add_argument("--start-timesteps", type=int, default=10000)
parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=2000) parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=256) parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--temperature", type=float, default=1.0)
@ -48,7 +48,6 @@ def get_args():
parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--eval-freq", type=int, default=1) parser.add_argument("--eval-freq", type=int, default=1)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1 / 35) parser.add_argument('--render', type=float, default=1 / 35)
@ -62,15 +61,16 @@ def get_args():
action='store_true', action='store_true',
help='watch the play of pre-trained policy only', help='watch the play of pre-trained policy only',
) )
parser.add_argument( parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
"--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
)
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
return args return args
def test_cql(args=get_args()): def test_cql(args=get_args()):
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
buffer = pickle.load(open(args.load_buffer_name, "rb")) buffer = pickle.load(open(args.load_buffer_name, "rb"))
else: else:
buffer = gather_data() buffer = gather_data()
@ -106,7 +106,7 @@ def test_cql(args=get_args()):
max_action=args.max_action, max_action=args.max_action,
device=args.device, device=args.device,
unbounded=True, unbounded=True,
conditioned_sigma=True conditioned_sigma=True,
).to(args.device) ).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

View File

@ -8,7 +8,7 @@ import numpy as np
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteBCQPolicy from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer from tianshou.trainer import offline_trainer
@ -17,9 +17,9 @@ from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor from tianshou.utils.net.discrete import Actor
if __name__ == "__main__": if __name__ == "__main__":
from gather_cartpole_data import gather_data from gather_cartpole_data import expert_file_name, gather_data
else: # pytest else: # pytest
from test.offline.gather_cartpole_data import gather_data from test.offline.gather_cartpole_data import expert_file_name, gather_data
def get_args(): def get_args():
@ -40,11 +40,7 @@ def get_args():
parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
"--load-buffer-name",
type=str,
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
@ -94,6 +90,9 @@ def test_discrete_bcq(args=get_args()):
) )
# buffer # buffer
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
buffer = pickle.load(open(args.load_buffer_name, "rb")) buffer = pickle.load(open(args.load_buffer_name, "rb"))
else: else:
buffer = gather_data() buffer = gather_data()

View File

@ -8,7 +8,7 @@ import numpy as np
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteCQLPolicy from tianshou.policy import DiscreteCQLPolicy
from tianshou.trainer import offline_trainer from tianshou.trainer import offline_trainer
@ -16,9 +16,9 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
if __name__ == "__main__": if __name__ == "__main__":
from gather_cartpole_data import gather_data from gather_cartpole_data import expert_file_name, gather_data
else: # pytest else: # pytest
from test.offline.gather_cartpole_data import gather_data from test.offline.gather_cartpole_data import expert_file_name, gather_data
def get_args(): def get_args():
@ -26,24 +26,20 @@ def get_args():
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--lr", type=float, default=3e-3)
parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200) parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--min-q-weight", type=float, default=10.) parser.add_argument("--min-q-weight", type=float, default=10.)
parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64])
parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
"--load-buffer-name",
type=str,
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
@ -57,7 +53,7 @@ def test_discrete_cql(args=get_args()):
# envs # envs
env = gym.make(args.task) env = gym.make(args.task)
if args.task == 'CartPole-v0': if args.task == 'CartPole-v0':
env.spec.reward_threshold = 185 # lower the goal env.spec.reward_threshold = 170 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv( test_envs = DummyVectorEnv(
@ -89,6 +85,9 @@ def test_discrete_cql(args=get_args()):
).to(args.device) ).to(args.device)
# buffer # buffer
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
buffer = pickle.load(open(args.load_buffer_name, "rb")) buffer = pickle.load(open(args.load_buffer_name, "rb"))
else: else:
buffer = gather_data() buffer = gather_data()

View File

@ -8,7 +8,7 @@ import numpy as np
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteCRRPolicy from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer from tianshou.trainer import offline_trainer
@ -17,9 +17,9 @@ from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.net.discrete import Actor, Critic
if __name__ == "__main__": if __name__ == "__main__":
from gather_cartpole_data import gather_data from gather_cartpole_data import expert_file_name, gather_data
else: # pytest else: # pytest
from test.offline.gather_cartpole_data import gather_data from test.offline.gather_cartpole_data import expert_file_name, gather_data
def get_args(): def get_args():
@ -37,11 +37,7 @@ def get_args():
parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
"--load-buffer-name",
type=str,
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
@ -55,7 +51,7 @@ def test_discrete_crr(args=get_args()):
# envs # envs
env = gym.make(args.task) env = gym.make(args.task)
if args.task == 'CartPole-v0': if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal env.spec.reward_threshold = 180 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv( test_envs = DummyVectorEnv(
@ -92,6 +88,9 @@ def test_discrete_crr(args=get_args()):
).to(args.device) ).to(args.device)
# buffer # buffer
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
buffer = pickle.load(open(args.load_buffer_name, "rb")) buffer = pickle.load(open(args.load_buffer_name, "rb"))
else: else:
buffer = gather_data() buffer = gather_data()