naming
This commit is contained in:
parent
40da985c6b
commit
e316499047
@ -14,14 +14,14 @@ class MockEnv(Module):
|
||||
self,
|
||||
image_shape,
|
||||
reward_range = (-100., 100.),
|
||||
batch_size = 1,
|
||||
num_envs = 1,
|
||||
vectorized = False
|
||||
):
|
||||
super().__init__()
|
||||
self.image_shape = image_shape
|
||||
self.reward_range = reward_range
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.num_envs = num_envs
|
||||
self.vectorized = vectorized
|
||||
self.register_buffer('_step', tensor(0))
|
||||
|
||||
@ -42,8 +42,12 @@ class MockEnv(Module):
|
||||
state = self.get_random_state()
|
||||
reward = randint(*self.reward_range, ()).float()
|
||||
|
||||
if self.vectorized:
|
||||
state = repeat(state, '... -> b ...', b = self.batch_size)
|
||||
reward = repeat(rewardstate, ' -> b', b = self.batch_size)
|
||||
if not self.vectorized:
|
||||
return state, reward
|
||||
|
||||
return state, reward
|
||||
assert actions.shape[0] == self.num_envs
|
||||
|
||||
state = repeat(state, '... -> b ...', b = self.num_envs)
|
||||
reward = repeat(reward, ' -> b', b = self.num_envs)
|
||||
|
||||
return state, rewards
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user