mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Initialize modules for espnet2 neural networks."""
|
|
|
|
import torch
|
|
from typeguard import check_argument_types
|
|
|
|
|
|
def initialize(model: torch.nn.Module, init: str):
    """Initialize the weights and biases of a neural network module in place.

    Every parameter with more than one dimension is re-initialized with the
    distribution selected by ``init``. Every one-dimensional parameter whose
    name is ``bias`` or contains ``.bias`` is then set to zero. Other
    one-dimensional parameters (for example normalization scales) are left
    untouched. The chosen method is also printed to stdout.

    Note:
        This function does not look for or call any per-module custom
        initialization hook; only the rules described above are applied.

    Args:
        model: Module whose parameters are initialized in place.
        init: Method of initialization. One of ``"xavier_uniform"``,
            ``"xavier_normal"``, ``"kaiming_uniform"`` or
            ``"kaiming_normal"`` (the Kaiming variants use
            ``nonlinearity="relu"``).

    Raises:
        TypeError: If ``model`` is not a ``torch.nn.Module`` or ``init`` is
            not a ``str``.
        ValueError: If ``init`` is not one of the supported method names.
            This is raised before any parameter is modified.
    """
    # Explicit checks instead of `assert ...`: asserts are stripped under
    # `python -O`, which would silently disable the validation.
    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            "model must be a torch.nn.Module, got " + type(model).__name__
        )
    if not isinstance(init, str):
        raise TypeError("init must be a str, got " + type(init).__name__)

    print("init with", init)

    # Map each supported method name to (initializer, extra keyword args).
    weight_initializers = {
        "xavier_uniform": (torch.nn.init.xavier_uniform_, {}),
        "xavier_normal": (torch.nn.init.xavier_normal_, {}),
        "kaiming_uniform": (
            torch.nn.init.kaiming_uniform_,
            {"nonlinearity": "relu"},
        ),
        "kaiming_normal": (
            torch.nn.init.kaiming_normal_,
            {"nonlinearity": "relu"},
        ),
    }

    # Resolve the method once, up front, so an unknown name is rejected even
    # when the model happens to contain no multi-dimensional parameter.
    if init not in weight_initializers:
        raise ValueError("Unknown initialization: " + init)
    init_weight, init_kwargs = weight_initializers[init]

    for name, param in model.named_parameters():
        if param.dim() > 1:
            # weight init
            init_weight(param.data, **init_kwargs)
        elif param.dim() == 1 and (name == "bias" or ".bias" in name):
            # bias init; `name == "bias"` covers a leaf module passed
            # directly, whose parameter names carry no "<submodule>." prefix.
            param.data.zero_()
|