gpt_sovits_v3
RVC-Boss authored on 2025-02-11 21:14:48 +08:00, committed by GitHub
parent 17d9be2a70
commit 6b12b4b10b
7 changed files with 553 additions and 0 deletions


@@ -0,0 +1,77 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d

# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load

anti_alias_activation_cuda = load.load()


class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
    The hyperparameters are hard-coded in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )
        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        raise NotImplementedError
        return output_grads, None, None


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)
        self.fused = fused  # Whether to use fused CUDA kernel or not

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # Snake uses same params for alpha and beta
            else:
                beta = self.act.beta.data  # SnakeBeta uses different params for alpha and beta
            alpha = self.act.alpha.data
            if not self.act.alpha_logscale:  # Exp baked into cuda kernel, cancel it out with a log
                alpha = torch.log(alpha)
                beta = torch.log(beta)
            x = FusedAntiAliasActivation.apply(
                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
            )
            return x
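
For reference, a minimal usage sketch of Activation1d, assuming the fused extension above has been built and using a hypothetical SnakeBeta-style stand-in (for illustration only) that exposes the per-channel alpha, beta, and alpha_logscale attributes the fused forward path reads:

import torch
import torch.nn as nn

class SnakeBeta(nn.Module):
    # Hypothetical minimal stand-in with the attributes Activation1d expects.
    def __init__(self, channels, alpha_logscale=True):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        init = torch.zeros(channels) if alpha_logscale else torch.ones(channels)
        self.alpha = nn.Parameter(init)
        self.beta = nn.Parameter(init.clone())

    def forward(self, x):  # x: [B, C, T]
        alpha = self.alpha.view(1, -1, 1)
        beta = self.beta.view(1, -1, 1)
        if self.alpha_logscale:
            alpha, beta = alpha.exp(), beta.exp()
        return x + (1.0 / (beta + 1e-9)) * torch.sin(x * alpha) ** 2

# act = Activation1d(SnakeBeta(64), fused=True).cuda()
# y = act(torch.randn(2, 64, 8192, device="cuda"))  # output keeps the [B, C, T] shape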


@@ -0,0 +1,23 @@
/* coding=utf-8
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}
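
The module above exposes a single forward entry point. As a hedged sketch (shapes taken from the kernel in the next file: input [batches, channels, seq_len], 12-tap up/down filters, and per-channel log-scale alpha/beta, since the kernel applies exp internally), a direct call from Python might look like the following; the random filters are placeholders, whereas Activation1d passes the actual low-pass filters:

import torch
from alias_free_activation.cuda import load

anti_alias_activation_cuda = load.load()  # JIT build, see load.py below

x = torch.randn(1, 32, 4096, device="cuda")   # [B, C, T]
up_f = torch.randn(12, device="cuda")         # placeholder; real code uses UpSample1d's filter
down_f = torch.randn(12, device="cuda")       # placeholder; real code uses DownSample1d's lowpass filter
log_alpha = torch.zeros(32, device="cuda")    # one value per channel, log scale
log_beta = torch.zeros(32, device="cuda")
y = anti_alias_activation_cuda.forward(x, up_f, down_f, log_alpha, log_beta)  # same shape as x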


@@ -0,0 +1,246 @@
/* coding=utf-8
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace
{
// Hard-coded hyperparameters
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
constexpr int BUFFER_SIZE = 32;
constexpr int FILTER_SIZE = 12;
constexpr int HALF_FILTER_SIZE = 6;
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
template <typename input_t, typename output_t, typename acc_t>
__global__ void anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
int batch_size,
int channels,
int seq_len)
{
// Up and downsample filters
input_t up_filter[FILTER_SIZE];
input_t down_filter[FILTER_SIZE];
// Load data from global memory including extra indices reserved for replication paddings
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
// Output stores downsampled output before writing to dst
output_t output[BUFFER_SIZE];
// blockDim/threadIdx = (128, 1, 1)
// gridDim/blockIdx = (seq_blocks, channels, batches)
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
int local_offset = threadIdx.x * BUFFER_SIZE;
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
// intermediate have double the seq_len
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
// Get values needed for replication padding before moving pointer
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
input_t seq_left_most_value = right_most_pntr[0];
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
// Move src and dst pointers
src += block_offset + local_offset;
dst += block_offset + local_offset;
// Alpha and beta values for snake activations. Applies exp by default
alpha = alpha + blockIdx.y;
input_t alpha_val = expf(alpha[0]);
beta = beta + blockIdx.y;
input_t beta_val = expf(beta[0]);
#pragma unroll
for (int it = 0; it < FILTER_SIZE; it += 1)
{
up_filter[it] = up_ftr[it];
down_filter[it] = down_ftr[it];
}
// Apply replication padding for upsampling, matching torch impl
#pragma unroll
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
{
int element_index = seq_offset + it; // index for element
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
}
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
}
if ((element_index >= 0) && (element_index < seq_len))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
}
}
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
#pragma unroll
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
{
input_t acc = 0.0;
int element_index = intermediate_seq_offset + it; // index for intermediate
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
{
if ((element_index + f_idx) >= 0)
{
acc += up_filter[f_idx] * elements[it + f_idx];
}
}
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
}
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
double no_div_by_zero = 0.000000001;
#pragma unroll
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
{
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
}
// Apply replication padding before downsampling conv from intermediates
#pragma unroll
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
{
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
}
#pragma unroll
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
{
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
}
// Apply downsample strided convolution (assuming stride=2) from intermediates
#pragma unroll
for (int it = 0; it < BUFFER_SIZE; it += 1)
{
input_t acc = 0.0;
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
{
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
}
output[it] = acc;
}
// Write output to dst
#pragma unroll
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
{
int element_index = seq_offset + it;
if (element_index < seq_len)
{
dst[it] = output[it];
}
}
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
int batch_size,
int channels,
int seq_len)
{
if (seq_len == 0)
{
return;
}
else
{
// Use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
constexpr int seq_len_per_block = 4096;
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
dim3 blocks(blocks_per_seq_len, channels, batch_size);
dim3 threads(threads_per_block, 1, 1);
anti_alias_activation_forward<input_t, output_t, acc_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
}
}
}
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
const int batches = input.size(0);
const int channels = input.size(1);
const int seq_len = input.size(2);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor anti_alias_activation_results =
torch::empty({batches, channels, seq_len}, act_options);
void *input_ptr = static_cast<void *>(input.data_ptr());
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
void *beta_ptr = static_cast<void *>(beta.data_ptr());
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch anti alias activation_forward",
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
reinterpret_cast<const scalar_t *>(input_ptr),
reinterpret_cast<const scalar_t *>(up_filter_ptr),
reinterpret_cast<const scalar_t *>(down_filter_ptr),
reinterpret_cast<const scalar_t *>(alpha_ptr),
reinterpret_cast<const scalar_t *>(beta_ptr),
batches,
channels,
seq_len););
return anti_alias_activation_results;
}
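
For reference, the element-wise nonlinearity the kernel applies between the up- and down-sampling convolutions (the line adding (1 / (beta_val + no_div_by_zero)) * sin(x * alpha_val)^2 above) is the snake activation; a plain PyTorch sketch of just that step, assuming per-channel log-scale parameters as Activation1d passes them in:

import torch

def snake_reference(x, log_alpha, log_beta, eps=1e-9):
    # Mirrors the kernel: alpha_val = exp(alpha), beta_val = exp(beta),
    # then x -> x + (1 / (beta + eps)) * sin(alpha * x) ** 2, per channel.
    alpha = log_alpha.exp().view(1, -1, 1)  # [C] -> [1, C, 1] for [B, C, T] inputs
    beta = log_beta.exp().view(1, -1, 1)
    return x + (1.0 / (beta + eps)) * torch.sin(alpha * x) ** 2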


@@ -0,0 +1,29 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif


@@ -0,0 +1,86 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

"""
Setting this param to a list has a problem of generating different compilation commands (with a different order of architectures) and leading to recompilation of fused kernels.
Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below.
"""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load():
    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                "-O3",
            ],
            extra_cuda_cflags=[
                "-O3",
                "-gencode",
                "arch=compute_70,code=sm_70",
                "--use_fast_math",
            ]
            + extra_cuda_flags
            + cc_flag,
            verbose=True,
        )

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    sources = [
        srcpath / "anti_alias_activation.cpp",
        srcpath / "anti_alias_activation_cuda.cu",
    ]

    anti_alias_activation_cuda = _cpp_extention_load_helper(
        "anti_alias_activation_cuda", sources, extra_cuda_flags
    )

    return anti_alias_activation_cuda


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
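
As a quick illustration of the parsing in _get_cuda_bare_metal_version, a sample nvcc -V line (assumed here for illustration; the real string comes from running the nvcc binary) is tokenized and the release number extracted the same way:

sample = "Cuda compilation tools, release 11.8, V11.8.89"
tokens = sample.split()
release = tokens[tokens.index("release") + 1].split(".")        # ["11", "8,"]
bare_metal_major, bare_metal_minor = release[0], release[1][0]  # "11", "8"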


@@ -0,0 +1,92 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch (TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}