Closes #947.

This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. The resulting default behavior should cover the vast majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`.

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already effectively non-optional, as there was a `ValueError` in `BasePolicy.__init__`. Several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. Passing it as `True` was never allowed there; an error would have been raised in `compute_n_step_reward`. Now it has been removed from the interface.
3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`).
4. Several fields were renamed (mostly from private to public, so these renamings are backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of duplication of information already present in the signatures is unnecessary.
2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named arguments as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried to highlight them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			65 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			65 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from collections.abc import Iterator
 | |
| from typing import TypeVar
 | |
| 
 | |
| import torch
 | |
| from torch import nn
 | |
| 
 | |
| 
 | |
| def optim_step(
 | |
|     loss: torch.Tensor,
 | |
|     optim: torch.optim.Optimizer,
 | |
|     module: nn.Module,
 | |
|     max_grad_norm: float | None = None,
 | |
| ) -> None:
 | |
|     """Perform a single optimization step.
 | |
| 
 | |
|     :param loss:
 | |
|     :param optim:
 | |
|     :param module:
 | |
|     :param max_grad_norm: if passed, will clip gradients using this
 | |
|     """
 | |
|     optim.zero_grad()
 | |
|     loss.backward()
 | |
|     if max_grad_norm:
 | |
|         nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm)
 | |
|     optim.step()
 | |
| 
 | |
| 
 | |
| _STANDARD_TORCH_OPTIMIZERS = [
 | |
|     torch.optim.Adam,
 | |
|     torch.optim.SGD,
 | |
|     torch.optim.RMSprop,
 | |
|     torch.optim.Adadelta,
 | |
|     torch.optim.AdamW,
 | |
|     torch.optim.Adamax,
 | |
|     torch.optim.NAdam,
 | |
|     torch.optim.SparseAdam,
 | |
|     torch.optim.LBFGS,
 | |
| ]
 | |
| 
 | |
| TOptim = TypeVar("TOptim", bound=torch.optim.Optimizer)
 | |
| 
 | |
| 
 | |
| def clone_optimizer(
 | |
|     optim: TOptim,
 | |
|     new_params: nn.Parameter | Iterator[nn.Parameter],
 | |
| ) -> TOptim:
 | |
|     """Clone an optimizer to get a new optim instance with new parameters.
 | |
| 
 | |
|     **WARNING**: This is a temporary measure, and should not be used in downstream code!
 | |
|     Once tianshou interfaces have moved to optimizer factories instead of optimizers,
 | |
|     this will be removed.
 | |
| 
 | |
|     :param optim: the optimizer to clone
 | |
|     :param new_params: the new parameters to use
 | |
|     :return: a new optimizer with the same configuration as the old one
 | |
|     """
 | |
|     optim_class = type(optim)
 | |
|     # custom optimizers may not behave as expected
 | |
|     if optim_class not in _STANDARD_TORCH_OPTIMIZERS:
 | |
|         raise ValueError(
 | |
|             f"Cannot clone optimizer {optim} of type {optim_class}"
 | |
|             f"Currently, only standard torch optimizers are supported.",
 | |
|         )
 | |
|     return optim_class(new_params, **optim.defaults)
 |