diff --git a/tools.py b/tools.py index 3b8a912..24e0e3d 100644 --- a/tools.py +++ b/tools.py @@ -328,9 +328,8 @@ class TwoHotDistSymlog: self.width = (self.buckets[-1] - self.buckets[0]) / 255 def mean(self): - print("mean called") - _mode = self.probs * self.buckets - return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + _mean = self.probs * self.buckets + return symexp(torch.sum(_mean, dim=-1, keepdim=True)) def mode(self): _mode = self.probs * self.buckets