fix pytest error on non-linux system (#638)
This commit is contained in:
parent
bf8f63ffc3
commit
a03f19af72
@ -1,11 +1,5 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
import envpool
|
||||
except ImportError:
|
||||
envpool = None
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -19,6 +13,11 @@ from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import Actor, ActorProb, Critic
|
||||
|
||||
try:
|
||||
import envpool
|
||||
except ImportError:
|
||||
envpool = None
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -57,8 +56,9 @@ def get_args():
|
||||
return args
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now")
|
||||
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||
def test_sac_with_il(args=get_args()):
|
||||
# if you want to use python vector env, please refer to other test scripts
|
||||
train_envs = env = envpool.make_gym(
|
||||
args.task, num_envs=args.training_num, seed=args.seed
|
||||
)
|
||||
|
@ -2,9 +2,9 @@ import argparse
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import envpool
|
||||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -15,6 +15,11 @@ from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
try:
|
||||
import envpool
|
||||
except ImportError:
|
||||
envpool = None
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -52,7 +57,9 @@ def get_args():
|
||||
return args
|
||||
|
||||
|
||||
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||
def test_a2c_with_il(args=get_args()):
|
||||
# if you want to use python vector env, please refer to other test scripts
|
||||
train_envs = env = envpool.make_gym(
|
||||
args.task, num_envs=args.training_num, seed=args.seed
|
||||
)
|
||||
|
@ -2,8 +2,8 @@ import argparse
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import envpool
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -12,6 +12,11 @@ from tianshou.policy import PSRLPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger
|
||||
|
||||
try:
|
||||
import envpool
|
||||
except ImportError:
|
||||
envpool = None
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -40,7 +45,9 @@ def get_args():
|
||||
return parser.parse_known_args()[0]
|
||||
|
||||
|
||||
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||
def test_psrl(args=get_args()):
|
||||
# if you want to use python vector env, please refer to other test scripts
|
||||
train_envs = env = envpool.make_gym(
|
||||
args.task, num_envs=args.training_num, seed=args.seed
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user