bug fix of limits for trunc_normal_
This commit is contained in:
parent
f7c505579c
commit
970d1dc3e9
4
tools.py
4
tools.py
@ -804,7 +804,9 @@ def weight_init(m):
|
||||
denoms = (in_num + out_num) / 2.0
|
||||
scale = 1.0 / denoms
|
||||
std = np.sqrt(scale) / 0.87962566103423978
|
||||
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
|
||||
nn.init.trunc_normal_(
|
||||
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
|
||||
)
|
||||
if hasattr(m.bias, "data"):
|
||||
m.bias.data.fill_(0.0)
|
||||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
||||
|
Loading…
x
Reference in New Issue
Block a user