fix pytest error on non-linux system (#638)
This commit is contained in:
parent
bf8f63ffc3
commit
a03f19af72
@ -1,11 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
|
|
||||||
try:
|
|
||||||
import envpool
|
|
||||||
except ImportError:
|
|
||||||
envpool = None
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -19,6 +13,11 @@ from tianshou.utils import TensorboardLogger
|
|||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.utils.net.continuous import Actor, ActorProb, Critic
|
from tianshou.utils.net.continuous import Actor, ActorProb, Critic
|
||||||
|
|
||||||
|
try:
|
||||||
|
import envpool
|
||||||
|
except ImportError:
|
||||||
|
envpool = None
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -57,8 +56,9 @@ def get_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now")
|
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||||
def test_sac_with_il(args=get_args()):
|
def test_sac_with_il(args=get_args()):
|
||||||
|
# if you want to use python vector env, please refer to other test scripts
|
||||||
train_envs = env = envpool.make_gym(
|
train_envs = env = envpool.make_gym(
|
||||||
args.task, num_envs=args.training_num, seed=args.seed
|
args.task, num_envs=args.training_num, seed=args.seed
|
||||||
)
|
)
|
||||||
|
@ -2,9 +2,9 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import pprint
|
import pprint
|
||||||
|
|
||||||
import envpool
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
@ -15,6 +15,11 @@ from tianshou.utils import TensorboardLogger
|
|||||||
from tianshou.utils.net.common import ActorCritic, Net
|
from tianshou.utils.net.common import ActorCritic, Net
|
||||||
from tianshou.utils.net.discrete import Actor, Critic
|
from tianshou.utils.net.discrete import Actor, Critic
|
||||||
|
|
||||||
|
try:
|
||||||
|
import envpool
|
||||||
|
except ImportError:
|
||||||
|
envpool = None
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -52,7 +57,9 @@ def get_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||||
def test_a2c_with_il(args=get_args()):
|
def test_a2c_with_il(args=get_args()):
|
||||||
|
# if you want to use python vector env, please refer to other test scripts
|
||||||
train_envs = env = envpool.make_gym(
|
train_envs = env = envpool.make_gym(
|
||||||
args.task, num_envs=args.training_num, seed=args.seed
|
args.task, num_envs=args.training_num, seed=args.seed
|
||||||
)
|
)
|
||||||
|
@ -2,8 +2,8 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import pprint
|
import pprint
|
||||||
|
|
||||||
import envpool
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
@ -12,6 +12,11 @@ from tianshou.policy import PSRLPolicy
|
|||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger
|
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import envpool
|
||||||
|
except ImportError:
|
||||||
|
envpool = None
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -40,7 +45,9 @@ def get_args():
|
|||||||
return parser.parse_known_args()[0]
|
return parser.parse_known_args()[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||||
def test_psrl(args=get_args()):
|
def test_psrl(args=get_args()):
|
||||||
|
# if you want to use python vector env, please refer to other test scripts
|
||||||
train_envs = env = envpool.make_gym(
|
train_envs = env = envpool.make_gym(
|
||||||
args.task, num_envs=args.training_num, seed=args.seed
|
args.task, num_envs=args.training_num, seed=args.seed
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user