limit action values in sampling stage
This commit is contained in:
		
							parent
							
								
									a9e85e8b7c
								
							
						
					
					
						commit
						a27711ab96
					
				| @ -239,9 +239,10 @@ class ImagBehavior(nn.Module): | ||||
|             "learned", | ||||
|             config.actor["min_std"], | ||||
|             config.actor["max_std"], | ||||
|             config.actor["temp"], | ||||
|             absmax=1.0, | ||||
|             temp=config.actor["temp"], | ||||
|             unimix_ratio=config.actor["unimix_ratio"], | ||||
|             outscale=1.0, | ||||
|             outscale=config.actor["outscale"], | ||||
|             name="Actor", | ||||
|         ) | ||||
|         self.value = networks.MLP( | ||||
|  | ||||
							
								
								
									
										30
									
								
								networks.py
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								networks.py
									
									
									
									
									
								
							| @ -200,9 +200,8 @@ class RSSM(nn.Module): | ||||
|         return dist | ||||
| 
 | ||||
|     def obs_step(self, prev_state, prev_action, embed, is_first, sample=True): | ||||
|         # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer) | ||||
|         # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer) | ||||
|         # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs | ||||
|         prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() | ||||
| 
 | ||||
|         # initialize all prev_state | ||||
|         if prev_state == None or torch.sum(is_first) == len(is_first): | ||||
| @ -246,7 +245,6 @@ class RSSM(nn.Module): | ||||
|     # this is used for making future image | ||||
|     def img_step(self, prev_state, prev_action, embed=None, sample=True): | ||||
|         # (batch, stoch, discrete_num) | ||||
|         prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() | ||||
|         prev_stoch = prev_state["stoch"] | ||||
|         if self._discrete: | ||||
|             shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] | ||||
| @ -644,6 +642,7 @@ class MLP(nn.Module): | ||||
|         std=1.0, | ||||
|         min_std=0.1, | ||||
|         max_std=1.0, | ||||
|         absmax=None, | ||||
|         temp=0.1, | ||||
|         unimix_ratio=0.01, | ||||
|         outscale=1.0, | ||||
| @ -660,12 +659,13 @@ class MLP(nn.Module): | ||||
|         norm = getattr(torch.nn, norm) | ||||
|         self._dist = dist | ||||
|         self._std = std | ||||
|         self._symlog_inputs = symlog_inputs | ||||
|         self._device = device | ||||
|         self._min_std = min_std | ||||
|         self._max_std = max_std | ||||
|         self._absmax = absmax | ||||
|         self._temp = temp | ||||
|         self._unimix_ratio = unimix_ratio | ||||
|         self._symlog_inputs = symlog_inputs | ||||
|         self._device = device | ||||
| 
 | ||||
|         self.layers = nn.Sequential() | ||||
|         for index in range(self._layers): | ||||
| @ -738,23 +738,33 @@ class MLP(nn.Module): | ||||
|                 std + 2.0 | ||||
|             ) + self._min_std | ||||
|             dist = torchd.normal.Normal(torch.tanh(mean), std) | ||||
|             dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0) | ||||
|             dist = tools.ContDist( | ||||
|                 torchd.independent.Independent(dist, 1), absmax=self._absmax | ||||
|             ) | ||||
|         elif self._dist == "normal_std_fixed": | ||||
|             dist = torchd.normal.Normal(mean, self._std) | ||||
|             dist = tools.ContDist(torchd.independent.Independent(dist, 1)) | ||||
|             dist = tools.ContDist( | ||||
|                 torchd.independent.Independent(dist, 1), absmax=self._absmax | ||||
|             ) | ||||
|         elif self._dist == "trunc_normal": | ||||
|             mean = torch.tanh(mean) | ||||
|             std = 2 * torch.sigmoid(std / 2) + self._min_std | ||||
|             dist = tools.SafeTruncatedNormal(mean, std, -1, 1) | ||||
|             dist = tools.ContDist(torchd.independent.Independent(dist, 1)) | ||||
|             dist = tools.ContDist( | ||||
|                 torchd.independent.Independent(dist, 1), absmax=self._absmax | ||||
|             ) | ||||
|         elif self._dist == "onehot": | ||||
|             dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio) | ||||
|         elif self._dist == "onehot_gumble": | ||||
|             dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp)) | ||||
|             dist = tools.ContDist( | ||||
|                 torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax | ||||
|             ) | ||||
|         elif dist == "huber": | ||||
|             dist = tools.ContDist( | ||||
|                 torchd.independent.Independent( | ||||
|                     tools.UnnormalizedHuber(mean, std, 1.0), len(shape) | ||||
|                     tools.UnnormalizedHuber(mean, std, 1.0), | ||||
|                     len(shape), | ||||
|                     absmax=self._absmax, | ||||
|                 ) | ||||
|             ) | ||||
|         elif dist == "binary": | ||||
|  | ||||
							
								
								
									
										13
									
								
								tools.py
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								tools.py
									
									
									
									
									
								
							| @ -562,10 +562,11 @@ class SymlogDist: | ||||
| 
 | ||||
| 
 | ||||
| class ContDist: | ||||
|     def __init__(self, dist=None): | ||||
|     def __init__(self, dist=None, absmax=None): | ||||
|         super().__init__() | ||||
|         self._dist = dist | ||||
|         self.mean = dist.mean | ||||
|         self.absmax = absmax | ||||
| 
 | ||||
|     def __getattr__(self, name): | ||||
|         return getattr(self._dist, name) | ||||
| @ -574,10 +575,16 @@ class ContDist: | ||||
|         return self._dist.entropy() | ||||
| 
 | ||||
|     def mode(self): | ||||
|         return self._dist.mean | ||||
|         out = self._dist.mean | ||||
|         if self.absmax is not None: | ||||
|             out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach() | ||||
|         return out | ||||
| 
 | ||||
|     def sample(self, sample_shape=()): | ||||
|         return self._dist.rsample(sample_shape) | ||||
|         out = self._dist.rsample(sample_shape) | ||||
|         if self.absmax is not None: | ||||
|             out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach() | ||||
|         return out | ||||
| 
 | ||||
|     def log_prob(self, x): | ||||
|         return self._dist.log_prob(x) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user