From 7433d1e87747ff574cdeffb5820f2a2647d212bd Mon Sep 17 00:00:00 2001
From: NM512
Date: Sat, 28 Sep 2024 07:58:15 +0900
Subject: [PATCH] avoid ".to(device)"

---
 dreamer.py     |  4 ++--
 exploration.py | 21 +++++++++++----------
 models.py      | 18 +++++++++++-------
 networks.py    | 20 ++++++++++----------
 tools.py       |  7 +++----
 5 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/dreamer.py b/dreamer.py
index 7384845..f4de475 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -258,8 +258,8 @@ def main(config):
         else:
             random_actor = torchd.independent.Independent(
                 torchd.uniform.Uniform(
-                    torch.Tensor(acts.low).repeat(config.envs, 1),
-                    torch.Tensor(acts.high).repeat(config.envs, 1),
+                    torch.tensor(acts.low).repeat(config.envs, 1),
+                    torch.tensor(acts.high).repeat(config.envs, 1),
                 ),
                 1,
             )
diff --git a/exploration.py b/exploration.py
index 98d5231..fd335ce 100644
--- a/exploration.py
+++ b/exploration.py
@@ -16,19 +16,19 @@ class Random(nn.Module):
     def actor(self, feat):
         if self._config.actor["dist"] == "onehot":
             return tools.OneHotDist(
-                torch.zeros(self._config.num_actions)
-                .repeat(self._config.envs, 1)
-                .to(self._config.device)
+                torch.zeros(
+                    self._config.num_actions, device=self._config.device
+                ).repeat(self._config.envs, 1)
             )
         else:
             return torchd.independent.Independent(
                 torchd.uniform.Uniform(
-                    torch.Tensor(self._act_space.low)
-                    .repeat(self._config.envs, 1)
-                    .to(self._config.device),
-                    torch.Tensor(self._act_space.high)
-                    .repeat(self._config.envs, 1)
-                    .to(self._config.device),
+                    torch.tensor(
+                        self._act_space.low, device=self._config.device
+                    ).repeat(self._config.envs, 1),
+                    torch.tensor(
+                        self._act_space.high, device=self._config.device
+                    ).repeat(self._config.envs, 1),
                 ),
                 1,
             )
@@ -97,7 +97,8 @@ class Plan2Explore(nn.Module):
             inputs = context["feat"]
             if self._config.disag_action_cond:
                 inputs = torch.concat(
-                    [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
+                    [inputs, torch.tensor(data["action"], device=self._config.device)],
+                    -1,
                 )
             metrics.update(self._train_ensemble(inputs, target))
         metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
diff --git a/models.py b/models.py
index f0d3f63..5d27ff1 100644
--- a/models.py
+++ b/models.py
@@ -14,7 +14,7 @@ class RewardEMA:
     def __init__(self, device, alpha=1e-2):
         self.device = device
         self.alpha = alpha
-        self.range = torch.tensor([0.05, 0.95]).to(device)
+        self.range = torch.tensor([0.05, 0.95], device=device)

     def __call__(self, x, ema_vals):
         flat_x = torch.flatten(x.detach())
@@ -172,18 +172,20 @@ class WorldModel(nn.Module):

     # this function is called during both rollout and training
     def preprocess(self, obs):
-        obs = obs.copy()
-        obs["image"] = torch.Tensor(obs["image"]) / 255.0
+        obs = {
+            k: torch.tensor(v, device=self._config.device, dtype=torch.float32)
+            for k, v in obs.items()
+        }
+        obs["image"] = obs["image"] / 255.0
         if "discount" in obs:
             obs["discount"] *= self._config.discount
             # (batch_size, batch_length) -> (batch_size, batch_length, 1)
-            obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
+            obs["discount"] = obs["discount"].unsqueeze(-1)
         # 'is_first' is necesarry to initialize hidden state at training
         assert "is_first" in obs
         # 'is_terminal' is necesarry to train cont_head
         assert "is_terminal" in obs
-        obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
-        obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
+        obs["cont"] = (1.0 - obs["is_terminal"]).unsqueeze(-1)
         return obs

     def video_pred(self, data):
@@ -277,7 +279,9 @@ class ImagBehavior(nn.Module):
         )
         if self._config.reward_EMA:
             # register ema_vals to nn.Module for enabling torch.save and torch.load
-            self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
+            self.register_buffer(
+                "ema_vals", torch.zeros((2,), device=self._config.device)
+            )
             self.reward_ema = RewardEMA(device=self._config.device)

     def _train(
diff --git a/networks.py b/networks.py
index 2517b3b..12efdc1 100644
--- a/networks.py
+++ b/networks.py
@@ -97,22 +97,22 @@ class RSSM(nn.Module):
             )

     def initial(self, batch_size):
-        deter = torch.zeros(batch_size, self._deter).to(self._device)
+        deter = torch.zeros(batch_size, self._deter, device=self._device)
         if self._discrete:
             state = dict(
-                logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(
-                    self._device
+                logit=torch.zeros(
+                    [batch_size, self._stoch, self._discrete], device=self._device
                 ),
-                stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(
-                    self._device
+                stoch=torch.zeros(
+                    [batch_size, self._stoch, self._discrete], device=self._device
                 ),
                 deter=deter,
             )
         else:
             state = dict(
-                mean=torch.zeros([batch_size, self._stoch]).to(self._device),
-                std=torch.zeros([batch_size, self._stoch]).to(self._device),
-                stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
+                mean=torch.zeros([batch_size, self._stoch], device=self._device),
+                std=torch.zeros([batch_size, self._stoch], device=self._device),
+                stoch=torch.zeros([batch_size, self._stoch], device=self._device),
                 deter=deter,
             )
         if self._initial == "zeros":
@@ -175,8 +175,8 @@ class RSSM(nn.Module):
         # initialize all prev_state
         if prev_state == None or torch.sum(is_first) == len(is_first):
             prev_state = self.initial(len(is_first))
-            prev_action = torch.zeros((len(is_first), self._num_actions)).to(
-                self._device
+            prev_action = torch.zeros(
+                (len(is_first), self._num_actions), device=self._device
             )
         # overwrite the prev_state only where is_first=True
         elif torch.sum(is_first) > 0:
diff --git a/tools.py b/tools.py
index c968e68..3efb932 100644
--- a/tools.py
+++ b/tools.py
@@ -461,7 +461,7 @@ class DiscDist:
     ):
         self.logits = logits
         self.probs = torch.softmax(logits, -1)
-        self.buckets = torch.linspace(low, high, steps=255).to(device)
+        self.buckets = torch.linspace(low, high, steps=255, device=device)
         self.width = (self.buckets[-1] - self.buckets[0]) / 255
         self.transfwd = transfwd
         self.transbwd = transbwd
@@ -624,8 +624,7 @@ class UnnormalizedHuber(torchd.normal.Normal):

     def log_prob(self, event):
         return -(
-            torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
-            - self._threshold
+            torch.sqrt((event - self.mean) ** 2 + self._threshold**2) - self._threshold
         )

     def mode(self):
@@ -762,7 +761,7 @@ class Optimizer:
             self._scaler.update()
             # self._opt.step()
         self._opt.zero_grad()
-        metrics[f"{self._name}_grad_norm"] = norm.item()
+        metrics[f"{self._name}_grad_norm"] = to_np(norm)
         return metrics

     def _apply_weight_decay(self, varibs):
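
Note: the pattern this patch applies throughout is to allocate tensors on the target device at construction time, via the device= keyword, instead of building them on the CPU and copying them over with .to(device), and to prefer torch.tensor over the legacy torch.Tensor constructor. A minimal sketch of the before/after pattern follows; the shapes and the local device variable are made up for illustration and stand in for the repo's config.device.

import torch

# Pick whichever device is available; the repo reads this from config.device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: the zeros are materialized on the CPU and then copied to the device,
# costing an extra allocation plus a host-to-device transfer.
prev_action = torch.zeros((16, 6)).to(device)

# After: the tensor is allocated directly on the target device.
prev_action = torch.zeros((16, 6), device=device)

# torch.Tensor(data) is the legacy constructor and always yields the default
# float dtype; torch.tensor(data, ...) infers the dtype from the data and
# accepts explicit dtype=/device= keywords, so the copy lands on the right
# device in one step.
low = torch.tensor([-1.0, -1.0], device=device)

The explicit dtype=torch.float32 in the rewritten preprocess() keeps the behaviour of the old torch.Tensor call, which always upcast observations to float32; torch.tensor on its own would preserve the incoming integer and bool dtypes.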