Replace deprecated function

This commit is contained in:
KamioRinn 2024-07-21 01:40:29 +08:00
parent f043bfebbb
commit 25a2bbd3b9
5 changed files with 16 additions and 11 deletions

View File

@ -433,7 +433,8 @@ class FFN(nn.Module):
import torch.nn as nn
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
class Depthwise_Separable_Conv1D(nn.Module):

View File

@ -9,7 +9,8 @@ from module import modules
from module import attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer

View File

@ -9,7 +9,8 @@ from module import modules
from module import attentions_onnx as attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer

View File

@ -5,7 +5,8 @@ from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from module import commons
from module.commons import init_weights, get_padding
@ -159,7 +160,7 @@ class WN(torch.nn.Module):
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
self.cond_layer = weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
@ -171,7 +172,7 @@ class WN(torch.nn.Module):
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
in_layer = weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
@ -181,7 +182,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
@ -213,11 +214,11 @@ class WN(torch.nn.Module):
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
remove_weight_norm(l)
class ResBlock1(torch.nn.Module):

View File

@ -2,7 +2,8 @@
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from module.attentions import MultiHeadAttention