Fix truncation limits for trunc_normal_ (scale a and b by std)

This commit is contained in:
NM512 2023-06-17 15:28:26 +09:00
parent f7c505579c
commit 970d1dc3e9

View File

@@ -804,7 +804,9 @@ def weight_init(m):
denoms = (in_num + out_num) / 2.0 denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978 std = np.sqrt(scale) / 0.87962566103423978
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) nn.init.trunc_normal_(
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
)
if hasattr(m.bias, "data"): if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0) m.bias.data.fill_(0.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):