Tests: fixed test_psrl.py: use args.reward_threshold instead of spec
For some reason now env.spec.reward_treshold is None - some change in upstream code Also added better pytest skip message
This commit is contained in:
parent
6a5b3c837a
commit
78ea013956
@ -44,7 +44,10 @@ def get_args() -> argparse.Namespace:
|
|||||||
return parser.parse_known_args()[0]
|
return parser.parse_known_args()[0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
@pytest.mark.skipif(
|
||||||
|
envpool is None,
|
||||||
|
reason="EnvPool is not installed. If on linux, please install it (e.g. as poetry extra)",
|
||||||
|
)
|
||||||
def test_psrl(args: argparse.Namespace = get_args()) -> None:
|
def test_psrl(args: argparse.Namespace = get_args()) -> None:
|
||||||
# if you want to use python vector env, please refer to other test scripts
|
# if you want to use python vector env, please refer to other test scripts
|
||||||
train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)
|
train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)
|
||||||
@ -115,4 +118,4 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
|
|||||||
logger=logger,
|
logger=logger,
|
||||||
test_in_train=False,
|
test_in_train=False,
|
||||||
).run()
|
).run()
|
||||||
assert result.best_reward >= env.spec.reward_threshold
|
assert result.best_reward >= args.reward_threshold
|
||||||
|
Loading…
x
Reference in New Issue
Block a user