Fix Plan2Explore
This commit is contained in:
		
							parent
							
								
									2cdba230d8
								
							
						
					
					
						commit
						6bf120237b
					
				| @@ -58,9 +58,9 @@ class Plan2Explore(nn.Module): | ||||
|             "feat": config.dyn_stoch + config.dyn_deter, | ||||
|         }[self._config.disag_target] | ||||
|         kw = dict( | ||||
|-            inp_dim=feat_size + config.num_actions | ||||
|+            inp_dim=feat_size + (config.num_actions | ||||
|             if config.disag_action_cond | ||||
|-            else 0,  # pytorch version | ||||
|+            else 0),  # pytorch version | ||||
|             shape=size, | ||||
|             layers=config.disag_layers, | ||||
|             units=config.disag_units, | ||||
| @@ -72,7 +72,7 @@ class Plan2Explore(nn.Module): | ||||
|         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) | ||||
|         self._model_opt = tools.Optimizer( | ||||
|             "explorer", | ||||
|-            self.parameters(), | ||||
|+            self._networks.parameters(), | ||||
|             config.model_lr, | ||||
|             config.opt_eps, | ||||
|             config.grad_clip, | ||||
| @@ -80,7 +80,7 @@ class Plan2Explore(nn.Module): | ||||
|         ) | ||||
| 
 | ||||
|     def train(self, start, context, data): | ||||
|-        with tools.RequiresGrad(self): | ||||
|+        with tools.RequiresGrad(self._networks): | ||||
|             metrics = {} | ||||
|             stoch = start["stoch"] | ||||
|             if self._config.dyn_discrete: | ||||
| @@ -129,5 +129,5 @@ class Plan2Explore(nn.Module): | ||||
|                 [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0 | ||||
|             ) | ||||
|             loss = -torch.mean(likes) | ||||
|-        metrics = self._model_opt(loss, self.parameters()) | ||||
|+        metrics = self._model_opt(loss, self._networks.parameters()) | ||||
|         return metrics | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user