26 lines
		
	
	
		
			485 B
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			26 lines
		
	
	
		
			485 B
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | from typing import overload | ||
|  | 
 | ||
|  | import torch | ||
|  | 
 | ||
|  | 
 | ||
|  | @overload | ||
|  | def to_optional_float(x: torch.Tensor) -> float: | ||
|  |     ... | ||
|  | 
 | ||
|  | 
 | ||
|  | @overload | ||
|  | def to_optional_float(x: float) -> float: | ||
|  |     ... | ||
|  | 
 | ||
|  | 
 | ||
|  | @overload | ||
|  | def to_optional_float(x: None) -> None: | ||
|  |     ... | ||
|  | 
 | ||
|  | 
 | ||
|  | def to_optional_float(x: torch.Tensor | float | None) -> float | None: | ||
|  |     """For the common case where one needs to extract a float from a scalar Tensor, which may be None.""" | ||
|  |     if isinstance(x, torch.Tensor): | ||
|  |         return x.item() | ||
|  |     return x |