Fix exception with watching pistonball environments (#663)
This commit is contained in:
parent
df35718992
commit
21b15803ac
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import gym
|
||||
@ -177,6 +178,12 @@ def watch(
|
||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||
) -> None:
|
||||
env = DummyVectorEnv([get_env])
|
||||
if not policy:
|
||||
warnings.warn(
|
||||
"watching random agents, as loading pre-trained policies is "
|
||||
"currently not supported"
|
||||
)
|
||||
policy, _, _ = get_agents(args)
|
||||
policy.eval()
|
||||
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
||||
collector = Collector(policy, env, exploration_noise=True)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
@ -269,6 +270,12 @@ def watch(
|
||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||
) -> None:
|
||||
env = DummyVectorEnv([get_env])
|
||||
if not policy:
|
||||
warnings.warn(
|
||||
"watching random agents, as loading pre-trained policies is "
|
||||
"currently not supported"
|
||||
)
|
||||
policy, _, _ = get_agents(args)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
|
Loading…
x
Reference in New Issue
Block a user