avoid ".to(device)"
This commit is contained in:
parent
669b7e1b43
commit
7433d1e877
@ -258,8 +258,8 @@ def main(config):
|
||||
else:
|
||||
random_actor = torchd.independent.Independent(
|
||||
torchd.uniform.Uniform(
|
||||
torch.Tensor(acts.low).repeat(config.envs, 1),
|
||||
torch.Tensor(acts.high).repeat(config.envs, 1),
|
||||
torch.tensor(acts.low).repeat(config.envs, 1),
|
||||
torch.tensor(acts.high).repeat(config.envs, 1),
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
@ -16,19 +16,19 @@ class Random(nn.Module):
|
||||
def actor(self, feat):
|
||||
if self._config.actor["dist"] == "onehot":
|
||||
return tools.OneHotDist(
|
||||
torch.zeros(self._config.num_actions)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device)
|
||||
torch.zeros(
|
||||
self._config.num_actions, device=self._config.device
|
||||
).repeat(self._config.envs, 1)
|
||||
)
|
||||
else:
|
||||
return torchd.independent.Independent(
|
||||
torchd.uniform.Uniform(
|
||||
torch.Tensor(self._act_space.low)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device),
|
||||
torch.Tensor(self._act_space.high)
|
||||
.repeat(self._config.envs, 1)
|
||||
.to(self._config.device),
|
||||
torch.tensor(
|
||||
self._act_space.low, device=self._config.device
|
||||
).repeat(self._config.envs, 1),
|
||||
torch.tensor(
|
||||
self._act_space.high, device=self._config.device
|
||||
).repeat(self._config.envs, 1),
|
||||
),
|
||||
1,
|
||||
)
|
||||
@ -97,7 +97,8 @@ class Plan2Explore(nn.Module):
|
||||
inputs = context["feat"]
|
||||
if self._config.disag_action_cond:
|
||||
inputs = torch.concat(
|
||||
[inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
|
||||
[inputs, torch.tensor(data["action"], device=self._config.device)],
|
||||
-1,
|
||||
)
|
||||
metrics.update(self._train_ensemble(inputs, target))
|
||||
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||
|
18
models.py
18
models.py
@ -14,7 +14,7 @@ class RewardEMA:
|
||||
def __init__(self, device, alpha=1e-2):
|
||||
self.device = device
|
||||
self.alpha = alpha
|
||||
self.range = torch.tensor([0.05, 0.95]).to(device)
|
||||
self.range = torch.tensor([0.05, 0.95], device=device)
|
||||
|
||||
def __call__(self, x, ema_vals):
|
||||
flat_x = torch.flatten(x.detach())
|
||||
@ -172,18 +172,20 @@ class WorldModel(nn.Module):
|
||||
|
||||
# this function is called during both rollout and training
|
||||
def preprocess(self, obs):
|
||||
obs = obs.copy()
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
||||
obs = {
|
||||
k: torch.tensor(v, device=self._config.device, dtype=torch.float32)
|
||||
for k, v in obs.items()
|
||||
}
|
||||
obs["image"] = obs["image"] / 255.0
|
||||
if "discount" in obs:
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
|
||||
obs["discount"] = obs["discount"].unsqueeze(-1)
|
||||
# 'is_first' is necessary to initialize hidden state at training
|
||||
assert "is_first" in obs
|
||||
# 'is_terminal' is necessary to train cont_head
|
||||
assert "is_terminal" in obs
|
||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
|
||||
obs["cont"] = (1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
return obs
|
||||
|
||||
def video_pred(self, data):
|
||||
@ -277,7 +279,9 @@ class ImagBehavior(nn.Module):
|
||||
)
|
||||
if self._config.reward_EMA:
|
||||
# register ema_vals to nn.Module for enabling torch.save and torch.load
|
||||
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
|
||||
self.register_buffer(
|
||||
"ema_vals", torch.zeros((2,), device=self._config.device)
|
||||
)
|
||||
self.reward_ema = RewardEMA(device=self._config.device)
|
||||
|
||||
def _train(
|
||||
|
20
networks.py
20
networks.py
@ -97,22 +97,22 @@ class RSSM(nn.Module):
|
||||
)
|
||||
|
||||
def initial(self, batch_size):
|
||||
deter = torch.zeros(batch_size, self._deter).to(self._device)
|
||||
deter = torch.zeros(batch_size, self._deter, device=self._device)
|
||||
if self._discrete:
|
||||
state = dict(
|
||||
logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(
|
||||
self._device
|
||||
logit=torch.zeros(
|
||||
[batch_size, self._stoch, self._discrete], device=self._device
|
||||
),
|
||||
stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(
|
||||
self._device
|
||||
stoch=torch.zeros(
|
||||
[batch_size, self._stoch, self._discrete], device=self._device
|
||||
),
|
||||
deter=deter,
|
||||
)
|
||||
else:
|
||||
state = dict(
|
||||
mean=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
std=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
mean=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||
std=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||
stoch=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||
deter=deter,
|
||||
)
|
||||
if self._initial == "zeros":
|
||||
@ -175,8 +175,8 @@ class RSSM(nn.Module):
|
||||
# initialize all prev_state
|
||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||
prev_state = self.initial(len(is_first))
|
||||
prev_action = torch.zeros((len(is_first), self._num_actions)).to(
|
||||
self._device
|
||||
prev_action = torch.zeros(
|
||||
(len(is_first), self._num_actions), device=self._device
|
||||
)
|
||||
# overwrite the prev_state only where is_first=True
|
||||
elif torch.sum(is_first) > 0:
|
||||
|
7
tools.py
7
tools.py
@ -461,7 +461,7 @@ class DiscDist:
|
||||
):
|
||||
self.logits = logits
|
||||
self.probs = torch.softmax(logits, -1)
|
||||
self.buckets = torch.linspace(low, high, steps=255).to(device)
|
||||
self.buckets = torch.linspace(low, high, steps=255, device=device)
|
||||
self.width = (self.buckets[-1] - self.buckets[0]) / 255
|
||||
self.transfwd = transfwd
|
||||
self.transbwd = transbwd
|
||||
@ -624,8 +624,7 @@ class UnnormalizedHuber(torchd.normal.Normal):
|
||||
|
||||
def log_prob(self, event):
|
||||
return -(
|
||||
torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
|
||||
- self._threshold
|
||||
torch.sqrt((event - self.mean) ** 2 + self._threshold**2) - self._threshold
|
||||
)
|
||||
|
||||
def mode(self):
|
||||
@ -762,7 +761,7 @@ class Optimizer:
|
||||
self._scaler.update()
|
||||
# self._opt.step()
|
||||
self._opt.zero_grad()
|
||||
metrics[f"{self._name}_grad_norm"] = norm.item()
|
||||
metrics[f"{self._name}_grad_norm"] = to_np(norm)
|
||||
return metrics
|
||||
|
||||
def _apply_weight_decay(self, varibs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user