From a03f19af7266293921d7a7cc53c15b5654e74203 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 12 May 2022 08:52:55 -0400 Subject: [PATCH] fix pytest error on non-linux system (#638) --- test/continuous/test_sac_with_il.py | 14 +++++++------- test/discrete/test_a2c_with_il.py | 9 ++++++++- test/modelbased/test_psrl.py | 9 ++++++++- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index a204e55..b65e2d3 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,11 +1,5 @@ import argparse import os -import sys - -try: - import envpool -except ImportError: - envpool = None import numpy as np import pytest @@ -19,6 +13,11 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic +try: + import envpool +except ImportError: + envpool = None + def get_args(): parser = argparse.ArgumentParser() @@ -57,8 +56,9 @@ def get_args(): return args -@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now") +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_sac_with_il(args=get_args()): + # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gym( args.task, num_envs=args.training_num, seed=args.seed ) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index b777e2e..0a453e1 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -2,9 +2,9 @@ import argparse import os import pprint -import envpool import gym import numpy as np +import pytest import torch from torch.utils.tensorboard import SummaryWriter @@ -15,6 +15,11 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic +try: + import envpool +except ImportError: + envpool = None + def get_args(): parser = argparse.ArgumentParser() @@ -52,7 +57,9 @@ def get_args(): return args +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_a2c_with_il(args=get_args()): + # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gym( args.task, num_envs=args.training_num, seed=args.seed ) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 3a30230..e0716aa 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -2,8 +2,8 @@ import argparse import os import pprint -import envpool import numpy as np +import pytest import torch from torch.utils.tensorboard import SummaryWriter @@ -12,6 +12,11 @@ from tianshou.policy import PSRLPolicy from tianshou.trainer import onpolicy_trainer from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger +try: + import envpool +except ImportError: + envpool = None + def get_args(): parser = argparse.ArgumentParser() @@ -40,7 +45,9 @@ def get_args(): return parser.parse_known_args()[0] +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_psrl(args=get_args()): + # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gym( args.task, num_envs=args.training_num, seed=args.seed )