Fix exception with watching pistonball environments (#663)

This commit is contained in:
Yifei Cheng 2022-06-11 15:12:48 -04:00 committed by GitHub
parent df35718992
commit 21b15803ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
import argparse import argparse
import os import os
import warnings
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import gym import gym
@@ -177,6 +178,12 @@ def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None: ) -> None:
env = DummyVectorEnv([get_env]) env = DummyVectorEnv([get_env])
if not policy:
warnings.warn(
"watching random agents, as loading pre-trained policies is "
"currently not supported"
)
policy, _, _ = get_agents(args)
policy.eval() policy.eval()
[agent.set_eps(args.eps_test) for agent in policy.policies.values()] [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True) collector = Collector(policy, env, exploration_noise=True)

View File

@@ -1,5 +1,6 @@
import argparse import argparse
import os import os
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import gym import gym
@@ -269,6 +270,12 @@ def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None: ) -> None:
env = DummyVectorEnv([get_env]) env = DummyVectorEnv([get_env])
if not policy:
warnings.warn(
"watching random agents, as loading pre-trained policies is "
"currently not supported"
)
policy, _, _ = get_agents(args)
policy.eval() policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render) result = collector.collect(n_episode=1, render=args.render)