Fix exception with watching pistonball environments (#663)

This commit is contained in:
Yifei Cheng 2022-06-11 15:12:48 -04:00 committed by GitHub
parent df35718992
commit 21b15803ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 0 deletions

View File

@ -1,5 +1,6 @@
import argparse
import os
import warnings
from typing import List, Optional, Tuple
import gym
@ -177,6 +178,12 @@ def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = DummyVectorEnv([get_env])
if not policy:
warnings.warn(
"watching random agents, as loading pre-trained policies is "
"currently not supported"
)
policy, _, _ = get_agents(args)
policy.eval()
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True)

View File

@ -1,5 +1,6 @@
import argparse
import os
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
@ -269,6 +270,12 @@ def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = DummyVectorEnv([get_env])
if not policy:
warnings.warn(
"watching random agents, as loading pre-trained policies is "
"currently not supported"
)
policy, _, _ = get_agents(args)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)