fix pytest error on non-linux system (#638)

This commit is contained in:
Jiayi Weng 2022-05-12 08:52:55 -04:00 committed by GitHub
parent bf8f63ffc3
commit a03f19af72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 9 deletions

View File

@ -1,11 +1,5 @@
import argparse
import os
import sys
try:
import envpool
except ImportError:
envpool = None
import numpy as np
import pytest
@ -19,6 +13,11 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, ActorProb, Critic
try:
import envpool
except ImportError:
envpool = None
def get_args():
parser = argparse.ArgumentParser()
@ -57,8 +56,9 @@ def get_args():
return args
@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now")
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_sac_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed
)

View File

@ -2,9 +2,9 @@ import argparse
import os
import pprint
import envpool
import gym
import numpy as np
import pytest
import torch
from torch.utils.tensorboard import SummaryWriter
@ -15,6 +15,11 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic
try:
import envpool
except ImportError:
envpool = None
def get_args():
parser = argparse.ArgumentParser()
@ -52,7 +57,9 @@ def get_args():
return args
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_a2c_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed
)

View File

@ -2,8 +2,8 @@ import argparse
import os
import pprint
import envpool
import numpy as np
import pytest
import torch
from torch.utils.tensorboard import SummaryWriter
@ -12,6 +12,11 @@ from tianshou.policy import PSRLPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger
try:
import envpool
except ImportError:
envpool = None
def get_args():
parser = argparse.ArgumentParser()
@ -40,7 +45,9 @@ def get_args():
return parser.parse_known_args()[0]
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_psrl(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed
)