This commit is contained in:
lucidrains 2025-10-21 10:57:55 -07:00
parent 40da985c6b
commit e316499047

View File

@@ -14,14 +14,14 @@ class MockEnv(Module):
self, self,
image_shape, image_shape,
reward_range = (-100., 100.), reward_range = (-100., 100.),
batch_size = 1, num_envs = 1,
vectorized = False vectorized = False
): ):
super().__init__() super().__init__()
self.image_shape = image_shape self.image_shape = image_shape
self.reward_range = reward_range self.reward_range = reward_range
self.batch_size = batch_size self.num_envs = num_envs
self.vectorized = vectorized self.vectorized = vectorized
self.register_buffer('_step', tensor(0)) self.register_buffer('_step', tensor(0))
@@ -42,8 +42,12 @@ class MockEnv(Module):
state = self.get_random_state() state = self.get_random_state()
reward = randint(*self.reward_range, ()).float() reward = randint(*self.reward_range, ()).float()
if self.vectorized: if not self.vectorized:
state = repeat(state, '... -> b ...', b = self.batch_size)
reward = repeat(rewardstate, ' -> b', b = self.batch_size)
return state, reward return state, reward
assert actions.shape[0] == self.num_envs
state = repeat(state, '... -> b ...', b = self.num_envs)
reward = repeat(reward, ' -> b', b = self.num_envs)
return state, rewards