fix pytest error on non-linux system (#638)

This commit is contained in:
Jiayi Weng 2022-05-12 08:52:55 -04:00 committed by GitHub
parent bf8f63ffc3
commit a03f19af72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 9 deletions

View File

@ -1,11 +1,5 @@
import argparse import argparse
import os import os
import sys
try:
import envpool
except ImportError:
envpool = None
import numpy as np import numpy as np
import pytest import pytest
@ -19,6 +13,11 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, ActorProb, Critic from tianshou.utils.net.continuous import Actor, ActorProb, Critic
try:
import envpool
except ImportError:
envpool = None
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -57,8 +56,9 @@ def get_args():
return args return args
@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now") @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_sac_with_il(args=get_args()): def test_sac_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym( train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed args.task, num_envs=args.training_num, seed=args.seed
) )

View File

@ -2,9 +2,9 @@ import argparse
import os import os
import pprint import pprint
import envpool
import gym import gym
import numpy as np import numpy as np
import pytest
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -15,6 +15,11 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.net.discrete import Actor, Critic
try:
import envpool
except ImportError:
envpool = None
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -52,7 +57,9 @@ def get_args():
return args return args
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_a2c_with_il(args=get_args()): def test_a2c_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym( train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed args.task, num_envs=args.training_num, seed=args.seed
) )

View File

@ -2,8 +2,8 @@ import argparse
import os import os
import pprint import pprint
import envpool
import numpy as np import numpy as np
import pytest
import torch import torch
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -12,6 +12,11 @@ from tianshou.policy import PSRLPolicy
from tianshou.trainer import onpolicy_trainer from tianshou.trainer import onpolicy_trainer
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger
try:
import envpool
except ImportError:
envpool = None
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -40,7 +45,9 @@ def get_args():
return parser.parse_known_args()[0] return parser.parse_known_args()[0]
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_psrl(args=get_args()): def test_psrl(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym( train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed args.task, num_envs=args.training_num, seed=args.seed
) )