removed scheduling function
This commit is contained in:
		
							parent
							
								
									2cdba230d8
								
							
						
					
					
						commit
						16635df3e4
					
				
							
								
								
									
										10
									
								
								configs.yaml
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								configs.yaml
									
									
									
									
									
								
							| @ -59,9 +59,9 @@ defaults: | |||||||
|     {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse} |     {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse} | ||||||
|   value_head: 'symlog_disc' |   value_head: 'symlog_disc' | ||||||
|   reward_head: 'symlog_disc' |   reward_head: 'symlog_disc' | ||||||
|   dyn_scale: '0.5' |   dyn_scale: 0.5 | ||||||
|   rep_scale: '0.1' |   rep_scale: 0.1 | ||||||
|   kl_free: '1.0' |   kl_free: 1.0 | ||||||
|   cont_scale: 1.0 |   cont_scale: 1.0 | ||||||
|   reward_scale: 1.0 |   reward_scale: 1.0 | ||||||
|   weight_decay: 0.0 |   weight_decay: 0.0 | ||||||
| @ -93,10 +93,10 @@ defaults: | |||||||
|   discount_lambda: 0.95 |   discount_lambda: 0.95 | ||||||
|   imag_horizon: 15 |   imag_horizon: 15 | ||||||
|   imag_gradient: 'dynamics' |   imag_gradient: 'dynamics' | ||||||
|   imag_gradient_mix: '0.0' |   imag_gradient_mix: 0.0 | ||||||
|   imag_sample: True |   imag_sample: True | ||||||
|   actor_dist: 'normal' |   actor_dist: 'normal' | ||||||
|   actor_entropy: '3e-4' |   actor_entropy: 3e-4 | ||||||
|   actor_state_entropy: 0.0 |   actor_state_entropy: 0.0 | ||||||
|   actor_init_std: 1.0 |   actor_init_std: 1.0 | ||||||
|   actor_min_std: 0.1 |   actor_min_std: 0.1 | ||||||
|  | |||||||
							
								
								
									
										10
									
								
								dreamer.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								dreamer.py
									
									
									
									
									
								
							| @ -40,16 +40,6 @@ class Dreamer(nn.Module): | |||||||
|         # this is update step |         # this is update step | ||||||
|         self._step = logger.step // config.action_repeat |         self._step = logger.step // config.action_repeat | ||||||
|         self._update_count = 0 |         self._update_count = 0 | ||||||
|         # Schedules. |  | ||||||
|         config.actor_entropy = lambda x=config.actor_entropy: tools.schedule( |  | ||||||
|             x, self._step |  | ||||||
|         ) |  | ||||||
|         config.actor_state_entropy = ( |  | ||||||
|             lambda x=config.actor_state_entropy: tools.schedule(x, self._step) |  | ||||||
|         ) |  | ||||||
|         config.imag_gradient_mix = lambda x=config.imag_gradient_mix: tools.schedule( |  | ||||||
|             x, self._step |  | ||||||
|         ) |  | ||||||
|         self._dataset = dataset |         self._dataset = dataset | ||||||
|         self._wm = models.WorldModel(obs_space, act_space, self._step, config) |         self._wm = models.WorldModel(obs_space, act_space, self._step, config) | ||||||
|         self._task_behavior = models.ImagBehavior( |         self._task_behavior = models.ImagBehavior( | ||||||
|  | |||||||
							
								
								
									
										24
									
								
								models.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								models.py
									
									
									
									
									
								
							| @ -128,9 +128,9 @@ class WorldModel(nn.Module): | |||||||
|                 post, prior = self.dynamics.observe( |                 post, prior = self.dynamics.observe( | ||||||
|                     embed, data["action"], data["is_first"] |                     embed, data["action"], data["is_first"] | ||||||
|                 ) |                 ) | ||||||
|                 kl_free = tools.schedule(self._config.kl_free, self._step) |                 kl_free = self._config.kl_free | ||||||
|                 dyn_scale = tools.schedule(self._config.dyn_scale, self._step) |                 dyn_scale = self._config.dyn_scale | ||||||
|                 rep_scale = tools.schedule(self._config.rep_scale, self._step) |                 rep_scale = self._config.rep_scale | ||||||
|                 kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss( |                 kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss( | ||||||
|                     post, prior, kl_free, dyn_scale, rep_scale |                     post, prior, kl_free, dyn_scale, rep_scale | ||||||
|                 ) |                 ) | ||||||
| @ -393,10 +393,10 @@ class ImagBehavior(nn.Module): | |||||||
|             discount = self._config.discount * self._world_model.heads["cont"](inp).mean |             discount = self._config.discount * self._world_model.heads["cont"](inp).mean | ||||||
|         else: |         else: | ||||||
|             discount = self._config.discount * torch.ones_like(reward) |             discount = self._config.discount * torch.ones_like(reward) | ||||||
|         if self._config.future_entropy and self._config.actor_entropy() > 0: |         if self._config.future_entropy and self._config.actor_entropy > 0: | ||||||
|             reward += self._config.actor_entropy() * actor_ent |             reward += self._config.actor_entropy * actor_ent | ||||||
|         if self._config.future_entropy and self._config.actor_state_entropy() > 0: |         if self._config.future_entropy and self._config.actor_state_entropy > 0: | ||||||
|             reward += self._config.actor_state_entropy() * state_ent |             reward += self._config.actor_state_entropy * state_ent | ||||||
|         value = self.value(imag_feat).mode() |         value = self.value(imag_feat).mode() | ||||||
|         target = tools.lambda_return( |         target = tools.lambda_return( | ||||||
|             reward[1:], |             reward[1:], | ||||||
| @ -450,16 +450,16 @@ class ImagBehavior(nn.Module): | |||||||
|                 policy.log_prob(imag_action)[:-1][:, :, None] |                 policy.log_prob(imag_action)[:-1][:, :, None] | ||||||
|                 * (target - self.value(imag_feat[:-1]).mode()).detach() |                 * (target - self.value(imag_feat[:-1]).mode()).detach() | ||||||
|             ) |             ) | ||||||
|             mix = self._config.imag_gradient_mix() |             mix = self._config.imag_gradient_mix | ||||||
|             actor_target = mix * target + (1 - mix) * actor_target |             actor_target = mix * target + (1 - mix) * actor_target | ||||||
|             metrics["imag_gradient_mix"] = mix |             metrics["imag_gradient_mix"] = mix | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError(self._config.imag_gradient) |             raise NotImplementedError(self._config.imag_gradient) | ||||||
|         if not self._config.future_entropy and (self._config.actor_entropy() > 0): |         if not self._config.future_entropy and self._config.actor_entropy > 0: | ||||||
|             actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None] |             actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None] | ||||||
|             actor_target += actor_entropy |             actor_target += actor_entropy | ||||||
|         if not self._config.future_entropy and (self._config.actor_state_entropy() > 0): |         if not self._config.future_entropy and (self._config.actor_state_entropy > 0): | ||||||
|             state_entropy = self._config.actor_state_entropy() * state_ent[:-1] |             state_entropy = self._config.actor_state_entropy * state_ent[:-1] | ||||||
|             actor_target += state_entropy |             actor_target += state_entropy | ||||||
|             metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy)) |             metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy)) | ||||||
|         actor_loss = -torch.mean(weights[:-1] * actor_target) |         actor_loss = -torch.mean(weights[:-1] * actor_target) | ||||||
|  | |||||||
							
								
								
									
										27
									
								
								tools.py
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								tools.py
									
									
									
									
									
								
							| @ -899,33 +899,6 @@ class Until: | |||||||
|         return step < self._until |         return step < self._until | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def schedule(string, step): |  | ||||||
|     try: |  | ||||||
|         return float(string) |  | ||||||
|     except ValueError: |  | ||||||
|         match = re.match(r"linear\((.+),(.+),(.+)\)", string) |  | ||||||
|         if match: |  | ||||||
|             initial, final, duration = [float(group) for group in match.groups()] |  | ||||||
|             mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0] |  | ||||||
|             return (1 - mix) * initial + mix * final |  | ||||||
|         match = re.match(r"warmup\((.+),(.+)\)", string) |  | ||||||
|         if match: |  | ||||||
|             warmup, value = [float(group) for group in match.groups()] |  | ||||||
|             scale = torch.clip(step / warmup, 0, 1) |  | ||||||
|             return scale * value |  | ||||||
|         match = re.match(r"exp\((.+),(.+),(.+)\)", string) |  | ||||||
|         if match: |  | ||||||
|             initial, final, halflife = [float(group) for group in match.groups()] |  | ||||||
|             return (initial - final) * 0.5 ** (step / halflife) + final |  | ||||||
|         match = re.match(r"horizon\((.+),(.+),(.+)\)", string) |  | ||||||
|         if match: |  | ||||||
|             initial, final, duration = [float(group) for group in match.groups()] |  | ||||||
|             mix = torch.clip(step / duration, 0, 1) |  | ||||||
|             horizon = (1 - mix) * initial + mix * final |  | ||||||
|             return 1 - 1 / horizon |  | ||||||
|         raise NotImplementedError(string) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def weight_init(m): | def weight_init(m): | ||||||
|     if isinstance(m, nn.Linear): |     if isinstance(m, nn.Linear): | ||||||
|         in_num = m.in_features |         in_num = m.in_features | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user