From 970d1dc3e91a59f60a6b1512b8c9ccd64658ce2a Mon Sep 17 00:00:00 2001 From: NM512 Date: Sat, 17 Jun 2023 15:28:26 +0900 Subject: [PATCH] bug fix of limits for trunc_normal_ --- tools.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools.py b/tools.py index bc46903..8b52379 100644 --- a/tools.py +++ b/tools.py @@ -804,7 +804,9 @@ def weight_init(m): denoms = (in_num + out_num) / 2.0 scale = 1.0 / denoms std = np.sqrt(scale) / 0.87962566103423978 - nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) + nn.init.trunc_normal_( + m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std + ) if hasattr(m.bias, "data"): m.bias.data.fill_(0.0) elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):