Fix exception with watching pistonball environments (#663)
This commit is contained in:
parent
df35718992
commit
21b15803ac
@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
@ -177,6 +178,12 @@ def watch(
|
|||||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
env = DummyVectorEnv([get_env])
|
env = DummyVectorEnv([get_env])
|
||||||
|
if not policy:
|
||||||
|
warnings.warn(
|
||||||
|
"watching random agents, as loading pre-trained policies is "
|
||||||
|
"currently not supported"
|
||||||
|
)
|
||||||
|
policy, _, _ = get_agents(args)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
||||||
collector = Collector(policy, env, exploration_noise=True)
|
collector = Collector(policy, env, exploration_noise=True)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
@ -269,6 +270,12 @@ def watch(
|
|||||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
env = DummyVectorEnv([get_env])
|
env = DummyVectorEnv([get_env])
|
||||||
|
if not policy:
|
||||||
|
warnings.warn(
|
||||||
|
"watching random agents, as loading pre-trained policies is "
|
||||||
|
"currently not supported"
|
||||||
|
)
|
||||||
|
policy, _, _ = get_agents(args)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
result = collector.collect(n_episode=1, render=args.render)
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user