diff --git a/tools.py b/tools.py index bc46903..8b52379 100644 --- a/tools.py +++ b/tools.py @@ -804,7 +804,9 @@ def weight_init(m): denoms = (in_num + out_num) / 2.0 scale = 1.0 / denoms std = np.sqrt(scale) / 0.87962566103423978 - nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) + nn.init.trunc_normal_( + m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std + ) if hasattr(m.bias, "data"): m.bias.data.fill_(0.0) elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):