avoid ".to(device)"
This commit is contained in:
parent
669b7e1b43
commit
7433d1e877
@ -258,8 +258,8 @@ def main(config):
|
|||||||
else:
|
else:
|
||||||
random_actor = torchd.independent.Independent(
|
random_actor = torchd.independent.Independent(
|
||||||
torchd.uniform.Uniform(
|
torchd.uniform.Uniform(
|
||||||
torch.Tensor(acts.low).repeat(config.envs, 1),
|
torch.tensor(acts.low).repeat(config.envs, 1),
|
||||||
torch.Tensor(acts.high).repeat(config.envs, 1),
|
torch.tensor(acts.high).repeat(config.envs, 1),
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
@ -16,19 +16,19 @@ class Random(nn.Module):
|
|||||||
def actor(self, feat):
|
def actor(self, feat):
|
||||||
if self._config.actor["dist"] == "onehot":
|
if self._config.actor["dist"] == "onehot":
|
||||||
return tools.OneHotDist(
|
return tools.OneHotDist(
|
||||||
torch.zeros(self._config.num_actions)
|
torch.zeros(
|
||||||
.repeat(self._config.envs, 1)
|
self._config.num_actions, device=self._config.device
|
||||||
.to(self._config.device)
|
).repeat(self._config.envs, 1)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return torchd.independent.Independent(
|
return torchd.independent.Independent(
|
||||||
torchd.uniform.Uniform(
|
torchd.uniform.Uniform(
|
||||||
torch.Tensor(self._act_space.low)
|
torch.tensor(
|
||||||
.repeat(self._config.envs, 1)
|
self._act_space.low, device=self._config.device
|
||||||
.to(self._config.device),
|
).repeat(self._config.envs, 1),
|
||||||
torch.Tensor(self._act_space.high)
|
torch.tensor(
|
||||||
.repeat(self._config.envs, 1)
|
self._act_space.high, device=self._config.device
|
||||||
.to(self._config.device),
|
).repeat(self._config.envs, 1),
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
@ -97,7 +97,8 @@ class Plan2Explore(nn.Module):
|
|||||||
inputs = context["feat"]
|
inputs = context["feat"]
|
||||||
if self._config.disag_action_cond:
|
if self._config.disag_action_cond:
|
||||||
inputs = torch.concat(
|
inputs = torch.concat(
|
||||||
[inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
|
[inputs, torch.tensor(data["action"], device=self._config.device)],
|
||||||
|
-1,
|
||||||
)
|
)
|
||||||
metrics.update(self._train_ensemble(inputs, target))
|
metrics.update(self._train_ensemble(inputs, target))
|
||||||
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||||
|
18
models.py
18
models.py
@ -14,7 +14,7 @@ class RewardEMA:
|
|||||||
def __init__(self, device, alpha=1e-2):
|
def __init__(self, device, alpha=1e-2):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.range = torch.tensor([0.05, 0.95]).to(device)
|
self.range = torch.tensor([0.05, 0.95], device=device)
|
||||||
|
|
||||||
def __call__(self, x, ema_vals):
|
def __call__(self, x, ema_vals):
|
||||||
flat_x = torch.flatten(x.detach())
|
flat_x = torch.flatten(x.detach())
|
||||||
@ -172,18 +172,20 @@ class WorldModel(nn.Module):
|
|||||||
|
|
||||||
# this function is called during both rollout and training
|
# this function is called during both rollout and training
|
||||||
def preprocess(self, obs):
|
def preprocess(self, obs):
|
||||||
obs = obs.copy()
|
obs = {
|
||||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
k: torch.tensor(v, device=self._config.device, dtype=torch.float32)
|
||||||
|
for k, v in obs.items()
|
||||||
|
}
|
||||||
|
obs["image"] = obs["image"] / 255.0
|
||||||
if "discount" in obs:
|
if "discount" in obs:
|
||||||
obs["discount"] *= self._config.discount
|
obs["discount"] *= self._config.discount
|
||||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||||
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
|
obs["discount"] = obs["discount"].unsqueeze(-1)
|
||||||
# 'is_first' is necesarry to initialize hidden state at training
|
# 'is_first' is necesarry to initialize hidden state at training
|
||||||
assert "is_first" in obs
|
assert "is_first" in obs
|
||||||
# 'is_terminal' is necesarry to train cont_head
|
# 'is_terminal' is necesarry to train cont_head
|
||||||
assert "is_terminal" in obs
|
assert "is_terminal" in obs
|
||||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
obs["cont"] = (1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||||
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
|
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def video_pred(self, data):
|
def video_pred(self, data):
|
||||||
@ -277,7 +279,9 @@ class ImagBehavior(nn.Module):
|
|||||||
)
|
)
|
||||||
if self._config.reward_EMA:
|
if self._config.reward_EMA:
|
||||||
# register ema_vals to nn.Module for enabling torch.save and torch.load
|
# register ema_vals to nn.Module for enabling torch.save and torch.load
|
||||||
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
|
self.register_buffer(
|
||||||
|
"ema_vals", torch.zeros((2,), device=self._config.device)
|
||||||
|
)
|
||||||
self.reward_ema = RewardEMA(device=self._config.device)
|
self.reward_ema = RewardEMA(device=self._config.device)
|
||||||
|
|
||||||
def _train(
|
def _train(
|
||||||
|
20
networks.py
20
networks.py
@ -97,22 +97,22 @@ class RSSM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def initial(self, batch_size):
|
def initial(self, batch_size):
|
||||||
deter = torch.zeros(batch_size, self._deter).to(self._device)
|
deter = torch.zeros(batch_size, self._deter, device=self._device)
|
||||||
if self._discrete:
|
if self._discrete:
|
||||||
state = dict(
|
state = dict(
|
||||||
logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(
|
logit=torch.zeros(
|
||||||
self._device
|
[batch_size, self._stoch, self._discrete], device=self._device
|
||||||
),
|
),
|
||||||
stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(
|
stoch=torch.zeros(
|
||||||
self._device
|
[batch_size, self._stoch, self._discrete], device=self._device
|
||||||
),
|
),
|
||||||
deter=deter,
|
deter=deter,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state = dict(
|
state = dict(
|
||||||
mean=torch.zeros([batch_size, self._stoch]).to(self._device),
|
mean=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||||
std=torch.zeros([batch_size, self._stoch]).to(self._device),
|
std=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||||
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
|
stoch=torch.zeros([batch_size, self._stoch], device=self._device),
|
||||||
deter=deter,
|
deter=deter,
|
||||||
)
|
)
|
||||||
if self._initial == "zeros":
|
if self._initial == "zeros":
|
||||||
@ -175,8 +175,8 @@ class RSSM(nn.Module):
|
|||||||
# initialize all prev_state
|
# initialize all prev_state
|
||||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||||
prev_state = self.initial(len(is_first))
|
prev_state = self.initial(len(is_first))
|
||||||
prev_action = torch.zeros((len(is_first), self._num_actions)).to(
|
prev_action = torch.zeros(
|
||||||
self._device
|
(len(is_first), self._num_actions), device=self._device
|
||||||
)
|
)
|
||||||
# overwrite the prev_state only where is_first=True
|
# overwrite the prev_state only where is_first=True
|
||||||
elif torch.sum(is_first) > 0:
|
elif torch.sum(is_first) > 0:
|
||||||
|
7
tools.py
7
tools.py
@ -461,7 +461,7 @@ class DiscDist:
|
|||||||
):
|
):
|
||||||
self.logits = logits
|
self.logits = logits
|
||||||
self.probs = torch.softmax(logits, -1)
|
self.probs = torch.softmax(logits, -1)
|
||||||
self.buckets = torch.linspace(low, high, steps=255).to(device)
|
self.buckets = torch.linspace(low, high, steps=255, device=device)
|
||||||
self.width = (self.buckets[-1] - self.buckets[0]) / 255
|
self.width = (self.buckets[-1] - self.buckets[0]) / 255
|
||||||
self.transfwd = transfwd
|
self.transfwd = transfwd
|
||||||
self.transbwd = transbwd
|
self.transbwd = transbwd
|
||||||
@ -624,8 +624,7 @@ class UnnormalizedHuber(torchd.normal.Normal):
|
|||||||
|
|
||||||
def log_prob(self, event):
|
def log_prob(self, event):
|
||||||
return -(
|
return -(
|
||||||
torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
|
torch.sqrt((event - self.mean) ** 2 + self._threshold**2) - self._threshold
|
||||||
- self._threshold
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def mode(self):
|
def mode(self):
|
||||||
@ -762,7 +761,7 @@ class Optimizer:
|
|||||||
self._scaler.update()
|
self._scaler.update()
|
||||||
# self._opt.step()
|
# self._opt.step()
|
||||||
self._opt.zero_grad()
|
self._opt.zero_grad()
|
||||||
metrics[f"{self._name}_grad_norm"] = norm.item()
|
metrics[f"{self._name}_grad_norm"] = to_np(norm)
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def _apply_weight_decay(self, varibs):
|
def _apply_weight_decay(self, varibs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user