bug fix of limits for trunc_normal_
This commit is contained in:
parent
f7c505579c
commit
970d1dc3e9
4
tools.py
4
tools.py
@ -804,7 +804,9 @@ def weight_init(m):
|
|||||||
denoms = (in_num + out_num) / 2.0
|
denoms = (in_num + out_num) / 2.0
|
||||||
scale = 1.0 / denoms
|
scale = 1.0 / denoms
|
||||||
std = np.sqrt(scale) / 0.87962566103423978
|
std = np.sqrt(scale) / 0.87962566103423978
|
||||||
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
|
nn.init.trunc_normal_(
|
||||||
|
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
|
||||||
|
)
|
||||||
if hasattr(m.bias, "data"):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user