learnable initial state options for RSSM
parent 1328ff1088
commit 0eb66997fb
@@ -63,6 +63,7 @@ defaults:
   weight_decay: 0.0
   unimix_ratio: 0.01
   action_unimix_ratio: 0.01
+  initial: 'learned'
 
   # Training
   batch_size: 16
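The new initial key selects how the RSSM seeds its recurrent state at the start of an episode: 'zeros' keeps the previous all-zero start state, while 'learned' trains the start state jointly with the world model (see the networks.py changes below). A minimal sketch of the distinction, not part of the commit, with an assumed state size:

    import torch

    deter_size = 512  # assumed deterministic state size, for illustration only
    W = torch.zeros(1, deter_size, requires_grad=True)  # optimized with the world model

    def initial_deter(kind, batch_size):
        if kind == "zeros":
            return torch.zeros(batch_size, deter_size)
        if kind == "learned":
            # tanh keeps the learned start state in the (-1, 1) range of GRU states
            return torch.tanh(W).repeat(batch_size, 1)
        raise NotImplementedError(kind)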
@@ -110,9 +110,10 @@ class Dreamer(nn.Module):
             )
         else:
             latent, action = state
-        embed = self._wm.encoder(self._wm.preprocess(obs))
+        obs = self._wm.preprocess(obs)
+        embed = self._wm.encoder(obs)
         latent, _ = self._wm.dynamics.obs_step(
-            latent, action, embed, self._config.collect_dyn_sample
+            latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
         )
         if self._config.eval_state_mean:
             latent["stoch"] = latent["mean"]
models.py (14 changed lines)
@@ -66,6 +66,7 @@ class WorldModel(nn.Module):
             config.dyn_min_std,
             config.dyn_cell,
             config.unimix_ratio,
+            config.initial,
             config.num_actions,
             embed_size,
             config.device,
@@ -95,6 +96,7 @@ class WorldModel(nn.Module):
                 config.norm,
                 dist=config.reward_head,
                 outscale=0.0,
+                device=config.device,
             )
         else:
             self.heads["reward"] = networks.DenseHead(
@@ -106,6 +108,7 @@ class WorldModel(nn.Module):
                 config.norm,
                 dist=config.reward_head,
                 outscale=0.0,
+                device=config.device,
             )
         self.heads["cont"] = networks.DenseHead(
             feat_size,  # pytorch version
@@ -115,6 +118,7 @@ class WorldModel(nn.Module):
             config.act,
             config.norm,
             dist="binary",
+            device=config.device,
         )
         for name in config.grad_heads:
             assert name in self.heads, name
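The device argument threaded through these DenseHead constructions lets a head allocate the tensors it creates at forward time on the same device as the model; the TwoHotDistSymlog change at the bottom of this diff is the consumer. A toy illustration of the pattern (the bucket range and count are assumptions, not the repo's values):

    import torch

    class TwoHotToy:
        def __init__(self, logits, device="cpu"):
            # allocate the fixed support where the logits live, not on the CPU default
            self.buckets = torch.linspace(-20.0, 20.0, 255, device=device)
            self.logits = logits

Without the explicit device, constants like the bucket support default to CPU and force a transfer (or a device-mismatch error) on every call when the model runs on GPU.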
@@ -140,7 +144,9 @@ class WorldModel(nn.Module):
         with tools.RequiresGrad(self):
             with torch.cuda.amp.autocast(self._use_amp):
                 embed = self.encoder(data)
-                post, prior = self.dynamics.observe(embed, data["action"])
+                post, prior = self.dynamics.observe(
+                    embed, data["action"], data["is_first"]
+                )
                 kl_free = tools.schedule(self._config.kl_free, self._step)
                 dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
                 rep_scale = tools.schedule(self._config.rep_scale, self._step)
@@ -204,7 +210,9 @@ class WorldModel(nn.Module):
         data = self.preprocess(data)
         embed = self.encoder(data)
 
-        states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5])
+        states, _ = self.dynamics.observe(
+            embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
+        )
         recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
         reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
         init = {k: v[:, -1] for k, v in states.items()}
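video_pred conditions the model on a few real steps before imagining forward, so is_first must be sliced the same way as embed and action now that observe requires it. A quick shape check of the slicing (sizes assumed):

    import torch

    embed = torch.zeros(16, 64, 1024)  # (batch, time, embed)
    context = embed[:6, :5]            # 6 sequences, first 5 steps as posterior context
    print(context.shape)               # torch.Size([6, 5, 1024])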
@@ -257,6 +265,7 @@ class ImagBehavior(nn.Module):
                 config.norm,
                 config.value_head,
                 outscale=0.0,
+                device=config.device,
             )
         else:
             self.value = networks.DenseHead(
@@ -268,6 +277,7 @@ class ImagBehavior(nn.Module):
                 config.norm,
                 config.value_head,
                 outscale=0.0,
+                device=config.device,
             )
         if config.slow_value_target:
             self._slow_value = copy.deepcopy(self.value)
networks.py (52 changed lines)
@@ -28,6 +28,7 @@ class RSSM(nn.Module):
         min_std=0.1,
         cell="gru",
         unimix_ratio=0.01,
+        initial="learned",
         num_actions=None,
         embed=None,
         device=None,
@@ -48,6 +49,7 @@ class RSSM(nn.Module):
         self._std_act = std_act
         self._temp_post = temp_post
         self._unimix_ratio = unimix_ratio
+        self._initial = initial
         self._embed = embed
         self._device = device
 
@@ -112,6 +114,12 @@ class RSSM(nn.Module):
             self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
             self._obs_stat_layer.apply(tools.weight_init)
 
+        if self._initial == "learned":
+            self.W = torch.nn.Parameter(
+                torch.zeros((1, self._deter), device=torch.device(self._device)),
+                requires_grad=True,
+            )
+
     def initial(self, batch_size):
         deter = torch.zeros(batch_size, self._deter).to(self._device)
         if self._discrete:
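Because W is zero-initialized, tanh(W) is exactly zero at the start of training, so the 'learned' option begins identical to 'zeros' and only drifts away as gradients accumulate. A two-line check:

    import torch

    W = torch.nn.Parameter(torch.zeros(1, 4))
    print(torch.tanh(W))  # tensor([[0., 0., 0., 0.]], grad_fn=<TanhBackward0>): same as the zeros init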
@@ -131,19 +139,27 @@ class RSSM(nn.Module):
                 stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
                 deter=deter,
             )
-        return state
+        if self._initial == "zeros":
+            return state
+        elif self._initial == "learned":
+            state["deter"] = torch.tanh(self.W).repeat(batch_size, 1)
+            state["stoch"] = self.get_stoch(state["deter"])
+            return state
+        else:
+            raise NotImplementedError(self._initial)
 
-    def observe(self, embed, action, state=None):
+    def observe(self, embed, action, is_first, state=None):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
         if state is None:
             state = self.initial(action.shape[0])
         # (batch, time, ch) -> (time, batch, ch)
-        embed, action = swap(embed), swap(action)
+        embed, action, is_first = swap(embed), swap(action), swap(is_first)
+        # prev_state[0] selects the posterior from obs_step's (posterior, prior) return
         post, prior = tools.static_scan(
-            lambda prev_state, prev_act, embed: self.obs_step(
-                prev_state[0], prev_act, embed
+            lambda prev_state, prev_act, embed, is_first: self.obs_step(
+                prev_state[0], prev_act, embed, is_first
             ),
-            (action, embed),
+            (action, embed, is_first),
             (state, state),
         )
 
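observe scans over the time axis, so is_first is swapped to time-major alongside embed and action before static_scan consumes them step by step. A quick shape check of the swap lambda from the hunk above (sizes assumed):

    import torch

    swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
    embed = torch.zeros(16, 64, 1024)  # (batch, time, embed)
    print(swap(embed).shape)           # torch.Size([64, 16, 1024]): time-major for the scan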
@@ -184,10 +200,22 @@ class RSSM(nn.Module):
             )
         return dist
 
-    def obs_step(self, prev_state, prev_action, embed, sample=True):
+    def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
         # if shared is True, prior and post both use the same networks (inp_layers, _img_out_layers, _ims_stat_layer)
         # otherwise, post uses a different network (_obs_out_layers) with prior[deter] and embed as inputs
         prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
+
+        if torch.sum(is_first) > 0:
+            is_first = is_first[:, None]
+            prev_action *= 1.0 - is_first
+            init_state = self.initial(len(is_first))
+            for key, val in prev_state.items():
+                is_first_r = torch.reshape(
+                    is_first,
+                    is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
+                )
+                val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
+
         prior = self.img_step(prev_state, prev_action, None, sample)
         if self._shared:
             post = self.img_step(prev_state, prev_action, embed, sample)
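The new block zeroes the previous action and blends the previous state toward the initial state wherever is_first is set, so episode boundaries inside a batch no longer leak state across resets. One caveat: as written, the loop rebinds the local name val, which in Python does not write the blended tensor back into prev_state. A sketch of the masking with an explicit write-back, assuming that is the intent:

    import torch

    def reset_masked(prev_state, prev_action, is_first, init_state):
        is_first = is_first[:, None].float()          # (batch,) -> (batch, 1)
        prev_action = prev_action * (1.0 - is_first)  # drop actions taken before a reset
        for key, val in prev_state.items():
            # broadcast the mask up to val's rank, e.g. (batch, 1, 1) for discrete stoch
            mask = is_first.reshape(is_first.shape + (1,) * (val.dim() - is_first.dim()))
            prev_state[key] = val * (1.0 - mask) + init_state[key] * mask
        return prev_state, prev_action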
@@ -242,6 +270,12 @@ class RSSM(nn.Module):
         prior = {"stoch": stoch, "deter": deter, **stats}
         return prior
 
+    def get_stoch(self, deter):
+        x = self._img_out_layers(deter)
+        stats = self._suff_stats_layer("ims", x)
+        dist = self.get_dist(stats)
+        return dist.mode()
+
     def _suff_stats_layer(self, name, x):
         if self._discrete:
             if name == "ims":
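get_stoch derives the stochastic half of the learned start state from the deterministic half: it runs the prior head on deter and takes the distribution's mode, so no parameters beyond W are introduced. A self-contained toy version (layer sizes and the Gaussian stand-in are assumptions; the repo's prior may be categorical):

    import torch
    import torch.nn as nn

    class ToyLearnedInit(nn.Module):
        def __init__(self, deter=32, stoch=8):
            super().__init__()
            self.W = nn.Parameter(torch.zeros(1, deter))
            self.to_stats = nn.Linear(deter, 2 * stoch)  # stand-in for _img_out_layers + _suff_stats_layer

        def initial(self, batch_size):
            deter = torch.tanh(self.W).repeat(batch_size, 1)
            mean, _std = self.to_stats(deter).chunk(2, -1)
            return {"deter": deter, "stoch": mean}  # mode of a Gaussian is its mean

    init = ToyLearnedInit().initial(4)
    print(init["deter"].shape, init["stoch"].shape)  # torch.Size([4, 32]) torch.Size([4, 8])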
@@ -435,6 +469,7 @@ class DenseHead(nn.Module):
         dist="normal",
         std=1.0,
         outscale=1.0,
+        device="cuda",
     ):
         super(DenseHead, self).__init__()
         self._shape = (shape,) if isinstance(shape, int) else shape
@@ -446,6 +481,7 @@ class DenseHead(nn.Module):
         self._norm = norm
         self._dist = dist
         self._std = std
+        self._device = device
 
         layers = []
         for index in range(self._layers):
@@ -491,7 +527,7 @@ class DenseHead(nn.Module):
                 )
             )
         if self._dist == "twohot_symlog":
-            return tools.TwoHotDistSymlog(logits=mean)
+            return tools.TwoHotDistSymlog(logits=mean, device=self._device)
         raise NotImplementedError(self._dist)
 
 