diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 4f739f0..48a5f20 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -1,5 +1,6 @@ import argparse import os +import warnings from typing import List, Optional, Tuple import gym @@ -177,6 +178,12 @@ def watch( args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None ) -> None: env = DummyVectorEnv([get_env]) + if not policy: + warnings.warn( + "watching random agents, as loading pre-trained policies is " + "currently not supported" + ) + policy, _, _ = get_agents(args) policy.eval() [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index da1285a..8f48184 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -1,5 +1,6 @@ import argparse import os +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import gym @@ -269,6 +270,12 @@ def watch( args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None ) -> None: env = DummyVectorEnv([get_env]) + if not policy: + warnings.warn( + "watching random agents, as loading pre-trained policies is " + "currently not supported" + ) + policy, _, _ = get_agents(args) policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render)