erased unnecessary reward input
This commit is contained in:
parent
9ca5082da3
commit
f07d843953
@ -66,7 +66,7 @@ class Dreamer(nn.Module):
|
|||||||
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
|
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
|
||||||
)[config.expl_behavior]().to(self._config.device)
|
)[config.expl_behavior]().to(self._config.device)
|
||||||
|
|
||||||
def __call__(self, obs, reset, state=None, reward=None, training=True):
|
def __call__(self, obs, reset, state=None, training=True):
|
||||||
step = self._step
|
step = self._step
|
||||||
if self._should_reset(step):
|
if self._should_reset(step):
|
||||||
state = None
|
state = None
|
||||||
@ -295,7 +295,7 @@ def main(config):
|
|||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def random_agent(o, d, s, r):
|
def random_agent(o, d, s):
|
||||||
action = random_actor.sample()
|
action = random_actor.sample()
|
||||||
logprob = random_actor.log_prob(action)
|
logprob = random_actor.log_prob(action)
|
||||||
return {"action": action, "logprob": logprob}, None
|
return {"action": action, "logprob": logprob}, None
|
||||||
|
3
tools.py
3
tools.py
@ -148,10 +148,9 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
|
|||||||
add_to_cache(cache, envs[i].id, t)
|
add_to_cache(cache, envs[i].id, t)
|
||||||
for index, result in zip(indices, results):
|
for index, result in zip(indices, results):
|
||||||
obs[index] = result
|
obs[index] = result
|
||||||
reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
|
|
||||||
# Step agents.
|
# Step agents.
|
||||||
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
|
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
|
||||||
action, agent_state = agent(obs, done, agent_state, reward)
|
action, agent_state = agent(obs, done, agent_state)
|
||||||
if isinstance(action, dict):
|
if isinstance(action, dict):
|
||||||
action = [
|
action = [
|
||||||
{k: np.array(action[k][i].detach().cpu()) for k in action}
|
{k: np.array(action[k][i].detach().cpu()) for k in action}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user