mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-06 14:38:14 +08:00
Compare commits
7 Commits
4a25469099
...
c20d4dec7c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c20d4dec7c | ||
|
|
08d627c333 | ||
|
|
6d95b559e8 | ||
|
|
c93bf48785 | ||
|
|
c9308ec96b | ||
|
|
87c521ea6e | ||
|
|
f19b76b26f |
76
GPT_SoVITS/AR/models/embedding_cudagraph.py
Normal file
76
GPT_SoVITS/AR/models/embedding_cudagraph.py
Normal file
@ -0,0 +1,76 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(self, embedding_dim: int, vocab_size: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_dim = embedding_dim
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def weight(self) -> torch.Tensor:
|
||||
return self.word_embeddings.weight
|
||||
|
||||
def embedding(self, index: int) -> torch.Tensor:
|
||||
return self.word_embeddings.weight[index : index + 1]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.word_embeddings(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class SinePositionalEmbeddingNested(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
dropout: float = 0.0,
|
||||
scale: bool = False,
|
||||
alpha: bool = False,
|
||||
max_batch_size: int = 20,
|
||||
max_seq_len: int = 2500,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.reverse = False
|
||||
self.register_buffer(
|
||||
"pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False
|
||||
)
|
||||
self.pe: torch.Tensor
|
||||
self.compute_pe()
|
||||
|
||||
def compute_pe(self):
|
||||
if self.reverse:
|
||||
position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe = self.pe
|
||||
pe[:, :, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, :, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
def forward(self, input_pos: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = x.shape[0]
|
||||
pe_values = self.pe[torch.arange(batch_size), input_pos - 1]
|
||||
return x * self.x_scale + self.alpha * pe_values.unsqueeze(1)
|
||||
|
||||
def prefill(self, x: torch.Tensor) -> torch.Tensor:
|
||||
input_pos = torch.tensor([i.shape[0] for i in x.unbind()])
|
||||
pe_values = torch.nested.nested_tensor(
|
||||
[self.pe[i, : input_pos[i], :] for i in range(input_pos.size(0))]
|
||||
)
|
||||
return x * self.x_scale + self.alpha.item() * pe_values
|
||||
78
GPT_SoVITS/AR/models/structs_cudagraph.py
Normal file
78
GPT_SoVITS/AR/models/structs_cudagraph.py
Normal file
@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2SResult:
|
||||
result: List[Tensor] | None = None
|
||||
infer_speed: float = 0.0
|
||||
status: Literal["Success", "Error"] = "Success"
|
||||
exception: Optional[Exception] = None
|
||||
traceback: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2SRequest:
|
||||
x: List[torch.Tensor]
|
||||
x_lens: Tensor
|
||||
prompts: torch.Tensor
|
||||
bert_feature: List[Tensor]
|
||||
valid_length: int
|
||||
top_k: int = 5
|
||||
top_p: float = 1
|
||||
early_stop_num: int = -1
|
||||
temperature: float = 1.0
|
||||
repetition_penalty: float = 1.35
|
||||
use_cuda_graph: bool = False
|
||||
debug: bool = False
|
||||
|
||||
|
||||
class T2SSession:
|
||||
def __init__(self, decoder, request: T2SRequest, device: torch.device, dtype: torch.dtype):
|
||||
with device:
|
||||
self.decoder = decoder
|
||||
self.request = request
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
bsz = len(request.x)
|
||||
y_len = request.prompts.size(-1)
|
||||
self.bsz = bsz
|
||||
self.y_len = y_len
|
||||
|
||||
from AR.models.t2s_model_cudagraph import Sampler
|
||||
|
||||
self.sampler = Sampler(bsz, decoder.vocab_size)
|
||||
|
||||
self.x = request.x
|
||||
self.x_lens = request.x_lens.to(torch.int32)
|
||||
self.y = request.prompts
|
||||
self.bert_feature = request.bert_feature
|
||||
|
||||
self.prefill_len = self.x_lens + self.y.size(1)
|
||||
|
||||
self.input_pos = torch.zeros_like(self.prefill_len)
|
||||
self.input_pos.add_(self.prefill_len)
|
||||
|
||||
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
|
||||
self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore
|
||||
|
||||
self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
|
||||
|
||||
attn_mask = []
|
||||
for bs in range(bsz):
|
||||
pos = int(self.x_lens[bs].item())
|
||||
mask = torch.zeros(pos + y_len, pos + y_len).bool()
|
||||
mask[:, :pos].fill_(True)
|
||||
if y_len > 0:
|
||||
mask[-y_len:, -y_len:] = ~torch.triu(
|
||||
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
|
||||
)
|
||||
attn_mask.append(mask)
|
||||
self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
|
||||
602
GPT_SoVITS/AR/models/t2s_model_cudagraph.py
Normal file
602
GPT_SoVITS/AR/models/t2s_model_cudagraph.py
Normal file
@ -0,0 +1,602 @@
|
||||
"""
|
||||
CUDA Graph accelerated T2S decoder.
|
||||
Uses PyTorch native scaled_dot_product_attention (no flash_attn dependency).
|
||||
Adapted from gsvpp/AR/models/t2s_model_abc.py and t2s_model_flash_attn.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, List, MutableSequence, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.graphs import CUDAGraph
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.embedding_cudagraph import (
|
||||
SinePositionalEmbeddingNested as SinePositionalEmbedding,
|
||||
)
|
||||
from AR.models.embedding_cudagraph import TokenEmbedding
|
||||
from AR.models.structs_cudagraph import T2SRequest, T2SResult, T2SSession
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
def __init__(self, batch_size: int, vocab_size: int) -> None:
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Tensor,
|
||||
previous_tokens: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
) -> Tensor:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
q = torch.empty_like(probs).exponential_(1.0)
|
||||
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
|
||||
return idx_next
|
||||
|
||||
|
||||
# ─── KV Cache ────────────────────<E29480><E29480><EFBFBD>───────────────────────────────────────────
|
||||
|
||||
|
||||
class KVCacheNHD(nn.Module):
|
||||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||||
super().__init__()
|
||||
assert batch_size > 0
|
||||
cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
|
||||
self.n_head = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.batch_size = batch_size
|
||||
self.max_seq_length = max_seq_length
|
||||
self.register_buffer(
|
||||
"k_cache", torch.zeros(size=cache_shape), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"v_cache", torch.zeros(size=cache_shape), persistent=False
|
||||
)
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||||
index = (
|
||||
(input_pos - 1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.expand(-1, -1, self.n_head, self.head_dim)
|
||||
.to(torch.int64)
|
||||
)
|
||||
k_out = self.k_cache
|
||||
v_out = self.v_cache
|
||||
k_out.scatter_(1, index, k_val)
|
||||
v_out.scatter_(1, index, v_val)
|
||||
return k_out, v_out
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
|
||||
self.k_cache[[bs], : k_val.shape[1]] = k_val
|
||||
self.v_cache[[bs], : v_val.shape[1]] = v_val
|
||||
|
||||
|
||||
# ─── Attention (PyTorch native SDPA, no flash_attn) ─────────────────────────
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, n_head: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.hidden_dim = hidden_dim
|
||||
assert hidden_dim % n_head == 0
|
||||
self.head_dim = hidden_dim // n_head
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict, prefix, *args):
|
||||
keys_to_modify = [key for key in state_dict if "in_proj_" in key]
|
||||
for key in keys_to_modify:
|
||||
new_key = key.replace("in_proj_", "in_proj.")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheNHD
|
||||
) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj.forward(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
k_cache, v_cache = kv_cache.update(input_pos, k, v)
|
||||
|
||||
q = q.transpose(1, 2) # [B, H, 1, D]
|
||||
k_out = k_cache.transpose(1, 2) # [B, H, max_seq, D]
|
||||
v_out = v_cache.transpose(1, 2) # [B, H, max_seq, D]
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k_out, v_out)
|
||||
|
||||
attn = self.dropout.forward(attn)
|
||||
attn = attn.transpose(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
attn = self.out_proj.forward(attn)
|
||||
return attn
|
||||
|
||||
def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheNHD) -> Tensor:
|
||||
bsz = x.size(0)
|
||||
outputs = []
|
||||
for bs in range(bsz):
|
||||
x_b = x[bs].unsqueeze(0)
|
||||
q, k, v = self.in_proj.forward(x_b.unsqueeze(0)).chunk(3, dim=-1)
|
||||
q = q.contiguous().view(1, -1, self.n_head, self.head_dim)
|
||||
k = k.contiguous().view(1, -1, self.n_head, self.head_dim)
|
||||
v = v.contiguous().view(1, -1, self.n_head, self.head_dim)
|
||||
kv_cache.prefill_kv(k, v, bs)
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
|
||||
attn_mask = (
|
||||
mask[bs].unsqueeze(0).unsqueeze(0).expand(1, self.n_head, -1, -1)
|
||||
)
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
attn = self.dropout.forward(attn)
|
||||
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
|
||||
output = self.out_proj.forward(attn)
|
||||
outputs.append(output.squeeze(0))
|
||||
return torch.nested.nested_tensor(outputs)
|
||||
|
||||
|
||||
# ─── Feed Forward ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.dropout.forward(
|
||||
self.linear2(self.dropout.forward(F.relu(self.linear1(x))))
|
||||
)
|
||||
|
||||
|
||||
# ─── Transformer Block ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, n_head, ffn_dim, hidden_dim) -> None:
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.attention = Attention(n_head, hidden_dim)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([hidden_dim])
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
for key in list(state_dict.keys()):
|
||||
new_key = (
|
||||
key.replace("self_attn", "attention")
|
||||
.replace("linear", "feed_forward.linear")
|
||||
.replace("norm1", "attention_norm")
|
||||
.replace("norm2", "ffn_norm")
|
||||
)
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheNHD
|
||||
) -> Tensor:
|
||||
h = self.attention_norm.forward(
|
||||
x + self.dropout.forward(self.attention.forward(x, input_pos, kv_cache))
|
||||
)
|
||||
out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
|
||||
return out
|
||||
|
||||
def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheNHD) -> Tensor:
|
||||
h = self.attention_norm.forward(
|
||||
x + self.dropout.forward(self.attention.prefill(x, mask, kv_cache))
|
||||
)
|
||||
out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
|
||||
return out
|
||||
|
||||
|
||||
# ─── Transformer Decoder ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.n_head = n_head
|
||||
assert hidden_dim % n_head == 0
|
||||
self.head_dim = hidden_dim // n_head
|
||||
self.vocab_size = vocab_size
|
||||
self.n_layer = n_layer
|
||||
self.layers = nn.ModuleList(
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer)
|
||||
)
|
||||
self.max_seq_length: int = max_seq_length
|
||||
self.max_batch_size: int = max_batch_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_pos: Tensor,
|
||||
x: Tensor,
|
||||
kv_caches: MutableSequence[KVCacheNHD],
|
||||
):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer.forward(x, input_pos, kv_cache)
|
||||
return x
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
x: Tensor,
|
||||
mask: Tensor,
|
||||
kv_caches: MutableSequence[KVCacheNHD],
|
||||
):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer.prefill(x, mask, kv_cache)
|
||||
return x
|
||||
|
||||
|
||||
# ─── T2S Decoder ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class T2SDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
*args,
|
||||
norm_first=False,
|
||||
max_seq_length=2500,
|
||||
max_batch_size=10,
|
||||
**kwds,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_dim = config["model"]["hidden_dim"]
|
||||
embedding_dim = config["model"]["embedding_dim"]
|
||||
n_head = config["model"]["head"]
|
||||
n_layer = config["model"]["n_layer"]
|
||||
vocab_size = config["model"]["vocab_size"]
|
||||
phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
|
||||
p_dropout = config["model"]["dropout"]
|
||||
EOS = config["model"]["EOS"]
|
||||
ffn_dim = hidden_dim * 4
|
||||
|
||||
self.n_layer = n_layer
|
||||
self.hidden_dim = hidden_dim
|
||||
self.n_head = n_head
|
||||
assert hidden_dim % n_head == 0
|
||||
self.head_dim = hidden_dim // n_head
|
||||
self.embedding_dim = embedding_dim
|
||||
self.vocab_size = vocab_size
|
||||
self.phoneme_vocab_size = phoneme_vocab_size
|
||||
self.p_dropout = p_dropout
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
self.EOS = EOS
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_text_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
|
||||
)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.vocab_size, self.p_dropout
|
||||
)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h = TransformerDecoder(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict, prefix, *args):
|
||||
model_keys = [key for key in state_dict if key.startswith("model.")]
|
||||
for key in model_keys:
|
||||
new_key = key[len("model.") :]
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
def init_cache(self, bsz: int = 0) -> nn.ModuleList:
|
||||
bsz = bsz or self.h.max_batch_size
|
||||
assert bsz <= self.h.max_batch_size
|
||||
seq_lens = self.h.max_seq_length
|
||||
device = self.bert_proj.bias.device
|
||||
dtype = self.bert_proj.bias.dtype
|
||||
return nn.ModuleList(
|
||||
[
|
||||
KVCacheNHD(bsz, seq_lens, self.n_head, self.head_dim)
|
||||
for _ in range(self.n_layer)
|
||||
],
|
||||
).to(device, dtype)
|
||||
|
||||
def embed(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
y: torch.Tensor,
|
||||
bert_features: List[torch.Tensor],
|
||||
):
|
||||
x_nested = torch.nested.nested_tensor(x)
|
||||
assert x_nested.size(0) <= self.max_batch_size
|
||||
bert_features_nested = torch.nested.nested_tensor(
|
||||
list(map(lambda t: t.transpose(0, 1), bert_features))
|
||||
)
|
||||
x_emb = self.ar_text_embedding.forward(x_nested)
|
||||
bert = self.bert_proj.forward(bert_features_nested)
|
||||
x_emb = x_emb + bert
|
||||
x_pos = self.ar_text_position.prefill(x_emb)
|
||||
|
||||
y_nested = torch.nested.nested_tensor(list(y.unbind(0)))
|
||||
y_emb = self.ar_audio_embedding.forward(y_nested)
|
||||
y_pos = self.ar_audio_position.prefill(y_emb)
|
||||
|
||||
xy_pos = torch.nested.nested_tensor(
|
||||
[torch.cat([x_pos[i], y_pos[i]]) for i in range(len(x))]
|
||||
)
|
||||
return xy_pos
|
||||
|
||||
def capture(
|
||||
self,
|
||||
input_pos: Tensor,
|
||||
x: Tensor,
|
||||
x_dec: Tensor,
|
||||
kv_caches,
|
||||
) -> CUDAGraph:
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(5):
|
||||
self.h.forward(input_pos, x, kv_caches)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
with torch.cuda.graph(graph):
|
||||
x_dec.copy_(self.h.forward(input_pos, x, kv_caches))
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# ─── CUDA Graph Runner ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class CUDAGraphRunner:
|
||||
def __init__(
|
||||
self,
|
||||
decoder_model: T2SDecoder,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> None:
|
||||
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
|
||||
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.decoder_model: T2SDecoder = decoder_model.to(self.device, self.dtype)
|
||||
self.graph: Optional[CUDAGraph] = None
|
||||
self.xy_pos_ = torch.rand(
|
||||
(1, 1, decoder_model.embedding_dim), device=device
|
||||
).to(dtype)
|
||||
self.xy_dec_ = torch.rand(
|
||||
(1, 1, decoder_model.embedding_dim), device=device
|
||||
).to(dtype)
|
||||
self.kv_cache = decoder_model.init_cache(1)
|
||||
self.input_pos = torch.tensor([10]).int().cuda()
|
||||
|
||||
def _handle_request(self, request: T2SRequest):
|
||||
with self.device:
|
||||
for i in self.kv_cache:
|
||||
i.empty()
|
||||
|
||||
decoder = self.decoder_model
|
||||
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
|
||||
self.input_pos.copy_(session.input_pos)
|
||||
|
||||
t1 = 0.0
|
||||
infer_speed = 0.0
|
||||
y = session.y
|
||||
bsz = y.size(0)
|
||||
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec = decoder.h.prefill(
|
||||
session.xy_pos, session.attn_mask_nested, self.kv_cache
|
||||
)
|
||||
xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
|
||||
else:
|
||||
if (
|
||||
request.use_cuda_graph
|
||||
and self.graph is None
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
self.xy_pos_.copy_(session.xy_pos)
|
||||
self.graph = decoder.capture(
|
||||
self.input_pos,
|
||||
self.xy_pos_,
|
||||
self.xy_dec_,
|
||||
kv_caches=self.kv_cache,
|
||||
)
|
||||
|
||||
if self.graph:
|
||||
self.xy_pos_.copy_(session.xy_pos)
|
||||
self.graph.replay()
|
||||
xy_dec = self.xy_dec_.clone()
|
||||
else:
|
||||
xy_dec = decoder.h.forward(
|
||||
self.input_pos,
|
||||
session.xy_pos,
|
||||
self.kv_cache,
|
||||
)
|
||||
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
self.input_pos.add_(1)
|
||||
|
||||
if idx == 0:
|
||||
logits[:, -1] = float("-inf")
|
||||
|
||||
samples = session.sampler.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y,
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
|
||||
session.y = torch.cat([session.y, samples], dim=1)
|
||||
|
||||
argmax_token = torch.argmax(logits, dim=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (argmax_token == decoder.EOS) | (
|
||||
sample_token == decoder.EOS
|
||||
)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = newly_done_mask.nonzero()
|
||||
|
||||
if newly_done_indices.numel() > 0:
|
||||
session.y_results[newly_done_indices[0]] = session.y[
|
||||
newly_done_indices[0], session.y_len : -1
|
||||
].squeeze(0)
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if torch.all(session.completed).item():
|
||||
if session.y.size(1) == 0:
|
||||
session.y = torch.cat(
|
||||
[session.y, torch.zeros_like(samples)], dim=1
|
||||
)
|
||||
tqdm.write("Bad Zero Prediction")
|
||||
else:
|
||||
tqdm.write(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n"
|
||||
f"{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
tqdm.write(
|
||||
f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s"
|
||||
)
|
||||
infer_speed = (idx - 1) / (time.perf_counter() - t1)
|
||||
break
|
||||
|
||||
if (
|
||||
request.early_stop_num != -1
|
||||
and (session.y.size(1) - session.y_len) > request.early_stop_num
|
||||
) or idx == 1499:
|
||||
for i in range(bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[i, session.y_len :]
|
||||
session.completed[i] = True
|
||||
break
|
||||
|
||||
y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
|
||||
session.xy_pos = decoder.ar_audio_position.forward(
|
||||
self.input_pos - session.x_lens, y_emb
|
||||
)
|
||||
|
||||
if idx == 2:
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx % 100 == 0 and self.device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return session.y_results[: request.valid_length], infer_speed
|
||||
|
||||
def generate(self, request: T2SRequest) -> T2SResult:
|
||||
try:
|
||||
result, infer_speed = self._handle_request(request)
|
||||
t2s_result = T2SResult(
|
||||
result=result, infer_speed=infer_speed, status="Success"
|
||||
)
|
||||
except Exception as e:
|
||||
t2s_result = T2SResult(
|
||||
status="Error", exception=e, traceback=traceback.format_exc()
|
||||
)
|
||||
return t2s_result
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(weights_path, max_batch_size=1) -> T2SDecoder:
|
||||
print(
|
||||
f"Loading Text2Semantic Weights from {weights_path} with CUDA Graph (SDPA) Implement"
|
||||
)
|
||||
dict_s1 = torch.load(
|
||||
weights_path, map_location="cpu", weights_only=False#, mmap=True
|
||||
)
|
||||
config = dict_s1["config"]
|
||||
decoder = T2SDecoder(config, max_batch_size=max_batch_size)
|
||||
state_dict = dict_s1["weight"]
|
||||
decoder.load_state_dict(state_dict)
|
||||
return decoder.eval()
|
||||
File diff suppressed because it is too large
Load Diff
825
api_role.py
Normal file
825
api_role.py
Normal file
@ -0,0 +1,825 @@
|
||||
"""
|
||||
GPT-SoVITS API 实现
|
||||
|
||||
### 完整请求示例 (/ttsrole POST)
|
||||
{
|
||||
"text": "你好", # str, 必填, 要合成的文本内容
|
||||
"role": "role1", # str, 必填, 角色名称,决定使用 roles/{role} 中的配置和音频
|
||||
"emotion": "开心", # str, 可选, 情感标签,用于从 roles/{role}/reference_audios 中选择音频
|
||||
"text_lang": "auto", # str, 可选, 默认 "auto", 文本语言,"auto" 时根据 emotion 或角色目录动态选择
|
||||
"ref_audio_path": "/path/to/ref.wav", # str, 可选, 参考音频路径,若提供则优先使用,跳过自动选择
|
||||
"aux_ref_audio_paths": ["/path1.wav", "/path2.wav"], # List[str], 可选, 辅助参考音频路径,用于多说话人融合
|
||||
"prompt_lang": "ja", # str, 可选, 提示文本语言,若提供 ref_audio_path 则需指定,"auto" 模式下动态选择
|
||||
"prompt_text": "こんにちは", # str, 可选, 提示文本,与 ref_audio_path 配对使用,自动选择时从文件或文件名生成
|
||||
"top_k": 10, # int, 可选, Top-K 采样值,覆盖 inference.top_k
|
||||
"top_p": 0.8, # float, 可选, Top-P 采样值,覆盖 inference.top_p
|
||||
"temperature": 1.0, # float, 可选, 温度值,覆盖 inference.temperature
|
||||
"text_split_method": "cut5", # str, 可选, 文本分割方法,覆盖 inference.text_split_method, 具体见text_segmentation_method.py
|
||||
"batch_size": 2, # int, 可选, 批处理大小,覆盖 inference.batch_size
|
||||
"batch_threshold": 0.75, # float, 可选, 批处理阈值,覆盖 inference.batch_threshold
|
||||
"split_bucket": true, # bool, 可选, 是否按桶分割,覆盖 inference.split_bucket
|
||||
"speed_factor": 1.2, # float, 可选, 语速因子,覆盖 inference.speed_factor
|
||||
"fragment_interval": 0.3, # float, 可选, 片段间隔(秒),覆盖 inference.fragment_interval
|
||||
"seed": 42, # int, 可选, 随机种子,覆盖 seed
|
||||
"media_type": "wav", # str, 可选, 默认 "wav", 输出格式,支持 "wav", "raw", "ogg", "aac"
|
||||
"streaming_mode": false, # bool, 可选, 默认 false, 是否流式返回
|
||||
"parallel_infer": true, # bool, 可选, 默认 true, 是否并行推理
|
||||
"repetition_penalty": 1.35, # float, 可选, 重复惩罚值,覆盖 inference.repetition_penalty
|
||||
"version": "v2", # str, 可选, 配置文件版本,覆盖 version
|
||||
"languages": ["zh", "ja", "en"], # List[str], 可选, 支持的语言列表,覆盖 languages
|
||||
"bert_base_path": "/path/to/bert", # str, 可选, BERT 模型路径,覆盖 bert_base_path
|
||||
"cnhuhbert_base_path": "/path/to/hubert", # str, 可选, HuBERT 模型路径,覆盖 cnhuhbert_base_path
|
||||
"device": "cpu", # str, 可选, 统一设备,覆盖 device
|
||||
"is_half": true, # bool, 可选, 是否使用半精度,覆盖 is_half
|
||||
"t2s_weights_path": "/path/to/gpt.ckpt", # str, 可选, GPT 模型路径,覆盖 t2s_weights_path
|
||||
"vits_weights_path": "/path/to/sovits.pth", # str, 可选, SoVITS 模型路径,覆盖 vits_weights_path
|
||||
"t2s_model_path": "/path/to/gpt.ckpt", # str, 可选, GPT 模型路径(与 t2s_weights_path 同义)
|
||||
"t2s_model_device": "cpu", # str, 可选, GPT 模型设备,覆盖 t2s_model.device,默认检测显卡
|
||||
"vits_model_path": "/path/to/sovits.pth", # str, 可选, SoVITS 模型路径(与 vits_weights_path 同义)
|
||||
"vits_model_device": "cpu" # str, 可选, SoVITS 模型设备,覆盖 vits_model.device,默认检测显卡
|
||||
}
|
||||
|
||||
### 参数必要性和优先级
|
||||
- 必填参数:
|
||||
- /ttsrole: text, role
|
||||
- /tts: text, ref_audio_path, prompt_lang
|
||||
- 可选参数: 其他均为可选,默认值从 roles/{role}/tts_infer.yaml 或 GPT_SoVITS/configs/tts_infer.yaml 获取
|
||||
- 优先级: POST 请求参数 > roles/{role}/tts_infer.yaml > 默认 GPT_SoVITS/configs/tts_infer.yaml
|
||||
|
||||
### 目录结构
|
||||
GPT-SoVITS-roleapi/
|
||||
├── api_role.py # 本文件, API 主程序
|
||||
├── GPT_SoVITS/ # GPT-SoVITS 核心库
|
||||
│ └── configs/
|
||||
│ └── tts_infer.yaml # 默认配置文件
|
||||
├── roles/ # 角色配置目录
|
||||
│ ├── role1/ # 示例角色 role1
|
||||
│ │ ├── tts_infer.yaml # 角色配置文件(可选)
|
||||
│ │ ├── model.ckpt # GPT 模型(可选)
|
||||
│ │ ├── model.pth # SoVITS 模型(可选)
|
||||
│ │ └── reference_audios/ # 角色参考音频目录
|
||||
│ │ ├── zh/
|
||||
│ │ │ ├── 【开心】voice1.wav
|
||||
│ │ │ ├── 【开心】voice1.txt
|
||||
│ │ ├── ja/
|
||||
│ │ │ ├── 【开心】voice2.wav
|
||||
│ │ │ ├── 【开心】voice2.txt
|
||||
│ ├── role2/
|
||||
│ │ ├── tts_infer.yaml
|
||||
│ │ ├── model.ckpt
|
||||
│ │ ├── model.pth
|
||||
│ │ └── reference_audios/
|
||||
│ │ ├── zh/
|
||||
│ │ │ ├── 【开心】voice1.wav
|
||||
│ │ │ ├── 【开心】voice1.txt
|
||||
│ │ │ ├── 【悲伤】asdafasdas.wav
|
||||
│ │ │ ├── 【悲伤】asdafasdas.txt
|
||||
│ │ ├── ja/
|
||||
│ │ │ ├── 【开心】voice2.wav
|
||||
│ │ │ ├── 【开心】voice2.txt
|
||||
|
||||
### text_lang, prompt_lang, prompt_text 选择逻辑 (/ttsrole)
|
||||
1. text_lang 选择逻辑:
|
||||
- 默认值: "auto"
|
||||
- 如果请求未提供 text_lang,视为 "auto"
|
||||
- 当 text_lang = "auto" 且存在 emotion 参数:
|
||||
- 从 roles/{role}/reference_audios 下所有语言文件夹中查找以 "【emotion】" 开头的音频
|
||||
- 随机选择一个匹配的音频,语言由音频所在文件夹确定
|
||||
- 当 text_lang 指定具体语言(如 "zh"):
|
||||
- 从 roles/{role}/reference_audios/{text_lang} 中选择音频
|
||||
- 如果指定语言无匹配音频,则尝试其他语言文件夹
|
||||
2. prompt_lang 选择逻辑:
|
||||
- 如果提供了 ref_audio_path,则需显式指定 prompt_lang
|
||||
- 如果未提供 ref_audio_path 且 text_lang = "auto" 且存在 emotion:
|
||||
- prompt_lang = 随机选择的音频所在语言文件夹名(如 "zh" 或 "ja")
|
||||
- 如果未提供 ref_audio_path 且 text_lang 指定具体语言:
|
||||
- prompt_lang = text_lang(如 "zh")
|
||||
- 如果 text_lang 无匹配音频,则为随机选择的音频所在语言
|
||||
3. prompt_text 选择逻辑:
|
||||
- 如果提供了 ref_audio_path(如 "/path/to/ref.wav"):
|
||||
- 检查文件名是否包含 "【xxx】" 前缀:
|
||||
- 如果有(如 "【开心】abc.wav"):
|
||||
- 若存在对应 .txt 文件(如 "【开心】abc.txt"),prompt_text = .txt 文件内容
|
||||
- 若无对应 .txt 文件,prompt_text = "abc"(去掉 "【开心】" 和 ".wav" 的部分)
|
||||
- 如果无 "【xxx】" 前缀:
|
||||
- 若存在对应 .txt 文件(如 "ref.txt"),prompt_text = .txt 文件内容
|
||||
- 若无对应 .txt 文件,prompt_text = "ref"(去掉 ".wav" 的部分)
|
||||
- 如果未提供 ref_audio_path:
|
||||
- 从 roles/{role}/reference_audios 中选择音频(基于 text_lang 和 emotion):
|
||||
- 优先匹配 "【emotion】" 前缀的音频(如 "【开心】voice1.wav")
|
||||
- 若存在对应 .txt 文件(如 "【开心】voice1.txt"),prompt_text = .txt 文件内容
|
||||
- 若无对应 .txt 文件,prompt_text = "voice1"(去掉 "【开心】" 和 ".wav" 的部分)
|
||||
- 未匹配 emotion 则随机选择一个音频,逻辑同上
|
||||
|
||||
### 讲解
|
||||
1. 必填参数:
|
||||
- /ttsrole: text, role
|
||||
- /tts: text, ref_audio_path, prompt_lang
|
||||
2. 音频选择 (/ttsrole):
|
||||
- 若提供 ref_audio_path,则使用它
|
||||
- 否则根据 role、text_lang、emotion 从 roles/{role}/reference_audios 中选择
|
||||
- text_lang = "auto" 时,若有 emotion,则跨语言匹配 "【emotion】" 前缀音频
|
||||
- emotion 匹配 "【emotion】" 前缀音频,未匹配则随机选择
|
||||
3. 设备选择:
|
||||
- 默认尝试检测显卡(torch.cuda.is_available()),若可用则用 "cuda",否则 "cpu"
|
||||
- 若缺少 torch 依赖或检测失败,回退到 "cpu"
|
||||
- POST 参数 device, t2s_model_device, vits_model_device 可强制指定设备,优先级最高
|
||||
4. 配置文件:
|
||||
- 默认加载 GPT_SoVITS/configs/tts_infer.yaml
|
||||
- 若 roles/{role}/tts_infer.yaml 存在且未被请求参数覆盖,则使用它 (/ttsrole)
|
||||
- 请求参数(如 top_k, bert_base_path)覆盖所有配置文件
|
||||
5. 返回格式:
|
||||
- 成功时返回音频流 (Response 或 StreamingResponse)
|
||||
- 失败时返回 JSON,包含错误消息和可能的异常详情
|
||||
6. 运行:
|
||||
- python api_role.py -a 127.0.0.1 -p 9880
|
||||
- 检查启动日志确认设备
|
||||
|
||||
### 调用示例 (/ttsrole)
|
||||
## 非流式调用,会一次性返回完整的音频数据,适用于需要完整音频文件的场景
|
||||
import requests
|
||||
url = "http://127.0.0.1:9880/ttsrole"
|
||||
payload = {
|
||||
"text": "你好,这是一个测试", # 要合成的文本
|
||||
"role": "role1", # 角色名称,必填
|
||||
"emotion": "开心", # 情感标签,可选
|
||||
"text_lang": "zh", # 文本语言,可选,默认为 "zh"
|
||||
"media_type": "wav" # 输出音频格式,默认 "wav"
|
||||
}
|
||||
response = requests.post(url, json=payload)
|
||||
if response.status_code == 200:
|
||||
with open("output_non_stream.wav", "wb") as f:
|
||||
f.write(response.content)
|
||||
print("非流式音频已生成并保存为 output_non_stream.wav")
|
||||
else:
|
||||
print(f"请求失败: {response.json()}")
|
||||
|
||||
## 流式调用,会分块返回音频数据,适用于实时播放或处理大文件的场景
|
||||
import requests
|
||||
url = "http://127.0.0.1:9880/ttsrole"
|
||||
payload = {
|
||||
"text": "你好,这是一个测试", # 要合成的文本
|
||||
"role": "role1", # 角色名称,必填
|
||||
"emotion": "开心", # 情感标签,可选
|
||||
"text_lang": "zh", # 文本语言,可选,默认为 "zh"
|
||||
"media_type": "wav", # 输出音频格式,默认 "wav"
|
||||
"streaming_mode": True # 启用流式模式
|
||||
}
|
||||
with requests.post(url, json=payload, stream=True) as response:
|
||||
if response.status_code == 200:
|
||||
with open("output_stream.wav", "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk: # 确保 chunk 不为空
|
||||
f.write(chunk)
|
||||
print("流式音频已生成并保存为 output_stream.wav")
|
||||
else:
|
||||
print(f"请求失败: {response.json()}")
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Generator, Optional, List, Dict
|
||||
import random
|
||||
import glob
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import wave
|
||||
import signal
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
from io import BytesIO
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||
|
||||
# 尝试导入 PyTorch,检测显卡支持
|
||||
try:
|
||||
import torch
|
||||
cuda_available = torch.cuda.is_available()
|
||||
except ImportError:
|
||||
cuda_available = False
|
||||
print("缺少 PyTorch 依赖,默认使用 CPU")
|
||||
except Exception as e:
|
||||
cuda_available = False
|
||||
print(f"检测显卡时出错: {str(e)},默认使用 CPU")
|
||||
|
||||
i18n = I18nAuto()
|
||||
cut_method_names = get_cut_method_names()
|
||||
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
|
||||
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
|
||||
parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
|
||||
args = parser.parse_args()
|
||||
config_path = args.tts_config
|
||||
port = args.port
|
||||
host = args.bind_addr
|
||||
argv = sys.argv
|
||||
|
||||
if config_path in [None, ""]:
|
||||
config_path = "GPT_SoVITS/configs/tts_infer.yaml"
|
||||
|
||||
default_device = "cuda" if cuda_available else "cpu"
|
||||
print(f"默认设备设置为: {default_device}")
|
||||
|
||||
# 初始化 TTS 配置
|
||||
tts_config = TTS_Config(config_path)
|
||||
print(f"TTS_Config contents: {tts_config.__dict__}")
|
||||
if hasattr(tts_config, 'device'):
|
||||
tts_config.device = default_device
|
||||
tts_pipeline = TTS(tts_config)
|
||||
|
||||
# 创建线程池用于异步执行 TTS 任务
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
APP = FastAPI()
|
||||
|
||||
class TTS_Request(BaseModel):
|
||||
text: str
|
||||
ref_audio_path: str
|
||||
prompt_lang: str
|
||||
text_lang: str = "auto"
|
||||
aux_ref_audio_paths: Optional[List[str]] = None
|
||||
prompt_text: Optional[str] = ""
|
||||
top_k: Optional[int] = 5
|
||||
top_p: Optional[float] = 1
|
||||
temperature: Optional[float] = 1
|
||||
text_split_method: Optional[str] = "cut5"
|
||||
batch_size: Optional[int] = 1
|
||||
batch_threshold: Optional[float] = 0.75
|
||||
split_bucket: Optional[bool] = True
|
||||
speed_factor: Optional[float] = 1.0
|
||||
fragment_interval: Optional[float] = 0.3
|
||||
seed: Optional[int] = -1
|
||||
media_type: Optional[str] = "wav"
|
||||
streaming_mode: Optional[bool] = False
|
||||
parallel_infer: Optional[bool] = True
|
||||
repetition_penalty: Optional[float] = 1.35
|
||||
device: Optional[str] = None
|
||||
|
||||
class TTSRole_Request(BaseModel):
|
||||
text: str
|
||||
role: str
|
||||
text_lang: Optional[str] = "auto"
|
||||
ref_audio_path: Optional[str] = None
|
||||
aux_ref_audio_paths: Optional[List[str]] = None
|
||||
prompt_lang: Optional[str] = None
|
||||
prompt_text: Optional[str] = None
|
||||
emotion: Optional[str] = None
|
||||
top_k: Optional[int] = 5
|
||||
top_p: Optional[float] = 1
|
||||
temperature: Optional[float] = 1
|
||||
text_split_method: Optional[str] = "cut5"
|
||||
batch_size: Optional[int] = 1
|
||||
batch_threshold: Optional[float] = 0.75
|
||||
split_bucket: Optional[bool] = True
|
||||
speed_factor: Optional[float] = 1.0
|
||||
fragment_interval: Optional[float] = 0.3
|
||||
seed: Optional[int] = -1
|
||||
media_type: Optional[str] = "wav"
|
||||
streaming_mode: Optional[bool] = False
|
||||
parallel_infer: Optional[bool] = True
|
||||
repetition_penalty: Optional[float] = 1.35
|
||||
bert_base_path: Optional[str] = None
|
||||
cnhuhbert_base_path: Optional[str] = None
|
||||
device: Optional[str] = None
|
||||
is_half: Optional[bool] = None
|
||||
t2s_weights_path: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
vits_weights_path: Optional[str] = None
|
||||
t2s_model_path: Optional[str] = None
|
||||
vits_model_path: Optional[str] = None
|
||||
t2s_model_device: Optional[str] = None
|
||||
vits_model_device: Optional[str] = None
|
||||
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
|
||||
audio_file.write(data)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer.write(data.tobytes())
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
sf.write(io_buffer, data, rate, format='wav')
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
process = subprocess.Popen([
|
||||
'ffmpeg', '-f', 's16le', '-ar', str(rate), '-ac', '1', '-i', 'pipe:0',
|
||||
'-c:a', 'aac', '-b:a', '192k', '-vn', '-f', 'adts', 'pipe:1'
|
||||
], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
def pack_audio(data: np.ndarray, rate: int, media_type: str) -> BytesIO:
|
||||
io_buffer = BytesIO()
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
return io_buffer
|
||||
|
||||
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
|
||||
wav_buf = BytesIO()
|
||||
with wave.open(wav_buf, "wb") as vfout:
|
||||
vfout.setnchannels(channels)
|
||||
vfout.setsampwidth(sample_width)
|
||||
vfout.setframerate(sample_rate)
|
||||
vfout.writeframes(frame_input)
|
||||
wav_buf.seek(0)
|
||||
return wav_buf.read()
|
||||
|
||||
def handle_control(command: str):
|
||||
if command == "restart":
|
||||
os.execl(sys.executable, sys.executable, *argv)
|
||||
elif command == "exit":
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
||||
|
||||
def check_params(req: dict, is_ttsrole: bool = False):
|
||||
text = req.get("text")
|
||||
text_lang = req.get("text_lang", "auto")
|
||||
ref_audio_path = req.get("ref_audio_path")
|
||||
prompt_lang = req.get("prompt_lang")
|
||||
media_type = req.get("media_type", "wav")
|
||||
streaming_mode = req.get("streaming_mode", False)
|
||||
text_split_method = req.get("text_split_method", "cut5")
|
||||
|
||||
if not text:
|
||||
return {"status": "error", "message": "text is required"}
|
||||
|
||||
if is_ttsrole:
|
||||
role = req.get("role")
|
||||
if not role:
|
||||
return {"status": "error", "message": "role is required for /ttsrole"}
|
||||
else:
|
||||
if not ref_audio_path:
|
||||
return {"status": "error", "message": "ref_audio_path is required"}
|
||||
if not prompt_lang:
|
||||
return {"status": "error", "message": "prompt_lang is required"}
|
||||
|
||||
languages = req.get("languages") or tts_config.languages
|
||||
if text_lang != "auto" and text_lang.lower() not in languages:
|
||||
return {"status": "error", "message": f"text_lang: {text_lang} is not supported"}
|
||||
if prompt_lang and prompt_lang.lower() not in languages:
|
||||
return {"status": "error", "message": f"prompt_lang: {prompt_lang} is not supported"}
|
||||
|
||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||
return {"status": "error", "message": f"media_type: {media_type} is not supported"}
|
||||
if media_type == "ogg" and not streaming_mode:
|
||||
return {"status": "error", "message": "ogg format is not supported in non-streaming mode"}
|
||||
if text_split_method not in cut_method_names:
|
||||
return {"status": "error", "message": f"text_split_method: {text_split_method} is not supported"}
|
||||
|
||||
return None
|
||||
|
||||
def load_role_config(role: str, req: dict):
|
||||
role_dir = os.path.join(now_dir, "roles", role)
|
||||
if not os.path.exists(role_dir):
|
||||
return False
|
||||
|
||||
if not any(req.get(k) for k in ["version", "bert_base_path", "cnhuhbert_base_path", "device", "is_half", "t2s_weights_path", "vits_weights_path"]):
|
||||
config_path_new = os.path.join(role_dir, "tts_infer.yaml")
|
||||
if os.path.exists(config_path_new):
|
||||
global tts_config, tts_pipeline
|
||||
tts_config = TTS_Config(config_path_new)
|
||||
if hasattr(tts_config, 'device'):
|
||||
tts_config.device = default_device
|
||||
tts_pipeline = TTS(tts_config)
|
||||
|
||||
if not req.get("t2s_weights_path") and not req.get("t2s_model_path"):
|
||||
gpt_path = glob.glob(os.path.join(role_dir, "*.ckpt"))
|
||||
if gpt_path:
|
||||
tts_pipeline.init_t2s_weights(gpt_path[0])
|
||||
if not req.get("vits_weights_path") and not req.get("vits_model_path"):
|
||||
sovits_path = glob.glob(os.path.join(role_dir, "*.pth"))
|
||||
if sovits_path:
|
||||
tts_pipeline.init_vits_weights(sovits_path[0])
|
||||
|
||||
return True
|
||||
|
||||
def select_ref_audio(role: str, text_lang: str, emotion: str = None):
|
||||
audio_base_dir = os.path.join(now_dir, "roles", role, "reference_audios")
|
||||
if not os.path.exists(audio_base_dir):
|
||||
return None, None, None
|
||||
|
||||
if text_lang.lower() == "auto" and emotion:
|
||||
all_langs = [d for d in os.listdir(audio_base_dir) if os.path.isdir(os.path.join(audio_base_dir, d))]
|
||||
emotion_files = []
|
||||
for lang in all_langs:
|
||||
lang_dir = os.path.join(audio_base_dir, lang)
|
||||
emotion_files.extend(glob.glob(os.path.join(lang_dir, f"【{emotion}】*.*")))
|
||||
|
||||
if emotion_files:
|
||||
audio_path = random.choice(emotion_files)
|
||||
txt_path = audio_path.rsplit(".", 1)[0] + ".txt"
|
||||
if os.path.exists(txt_path):
|
||||
with open(txt_path, "r", encoding="utf-8") as f:
|
||||
prompt_text = f.read().strip()
|
||||
else:
|
||||
basename = os.path.basename(audio_path)
|
||||
start_idx = basename.find("】") + 1
|
||||
end_idx = basename.rfind(".")
|
||||
prompt_text = basename[start_idx:end_idx] if end_idx > start_idx else basename
|
||||
|
||||
prompt_lang = os.path.basename(os.path.dirname(audio_path))
|
||||
return audio_path, prompt_text, prompt_lang
|
||||
|
||||
lang_dir = os.path.join(audio_base_dir, text_lang.lower())
|
||||
all_langs = [d for d in os.listdir(audio_base_dir) if os.path.isdir(os.path.join(audio_base_dir, d))]
|
||||
|
||||
def find_audio_in_dir(dir_path):
|
||||
if not os.path.exists(dir_path):
|
||||
return None, None
|
||||
audio_files = glob.glob(os.path.join(dir_path, "【*】*.*"))
|
||||
if not audio_files:
|
||||
audio_files = glob.glob(os.path.join(dir_path, "*.*"))
|
||||
if not audio_files:
|
||||
return None, None
|
||||
|
||||
if emotion:
|
||||
emotion_files = [f for f in audio_files if f"【{emotion}】" in os.path.basename(f)]
|
||||
if emotion_files:
|
||||
audio_path = random.choice(emotion_files)
|
||||
else:
|
||||
audio_path = random.choice(audio_files)
|
||||
else:
|
||||
audio_path = random.choice(audio_files)
|
||||
|
||||
txt_path = audio_path.rsplit(".", 1)[0] + ".txt"
|
||||
prompt_text = None
|
||||
if os.path.exists(txt_path):
|
||||
with open(txt_path, "r", encoding="utf-8") as f:
|
||||
prompt_text = f.read().strip()
|
||||
else:
|
||||
basename = os.path.basename(audio_path)
|
||||
start_idx = basename.find("】") + 1
|
||||
end_idx = basename.rfind(".")
|
||||
if start_idx > 0 and end_idx > start_idx:
|
||||
prompt_text = basename[start_idx:end_idx]
|
||||
else:
|
||||
prompt_text = basename[:end_idx] if end_idx > 0 else basename
|
||||
|
||||
return audio_path, prompt_text
|
||||
|
||||
audio_path, prompt_text = find_audio_in_dir(lang_dir)
|
||||
if audio_path:
|
||||
return audio_path, prompt_text, text_lang.lower()
|
||||
|
||||
for lang in all_langs:
|
||||
if lang != text_lang.lower():
|
||||
audio_path, prompt_text = find_audio_in_dir(os.path.join(audio_base_dir, lang))
|
||||
if audio_path:
|
||||
return audio_path, prompt_text, lang
|
||||
|
||||
return None, None, None
|
||||
|
||||
def set_pipeline_device(pipeline: TTS, device: str):
|
||||
"""将 TTS 管道中的所有模型和相关组件迁移到指定设备,仅在设备变化时执行"""
|
||||
if not torch.cuda.is_available() and device.startswith("cuda"):
|
||||
print(f"警告: CUDA 不可用,强制使用 CPU")
|
||||
device = "cpu"
|
||||
|
||||
target_device = torch.device(device)
|
||||
|
||||
# 检查当前设备是否需要切换
|
||||
current_device = None
|
||||
if hasattr(pipeline, 't2s_model') and pipeline.t2s_model is not None:
|
||||
current_device = next(pipeline.t2s_model.parameters()).device
|
||||
elif hasattr(pipeline, 'vits_model') and pipeline.vits_model is not None:
|
||||
current_device = next(pipeline.vits_model.parameters()).device
|
||||
|
||||
if current_device == target_device:
|
||||
print(f"设备已是 {device},无需切换")
|
||||
return
|
||||
|
||||
# 更新配置中的设备
|
||||
if hasattr(pipeline, 'configs') and hasattr(pipeline.configs, 'device'):
|
||||
pipeline.configs.device = device
|
||||
|
||||
# 迁移所有可能的模型到指定设备
|
||||
for attr in ['t2s_model', 'vits_model']:
|
||||
if hasattr(pipeline, attr) and getattr(pipeline, attr) is not None:
|
||||
getattr(pipeline, attr).to(target_device)
|
||||
|
||||
for attr in dir(pipeline):
|
||||
if attr.endswith('_model') and getattr(pipeline, attr) is not None:
|
||||
try:
|
||||
getattr(pipeline, attr).to(target_device)
|
||||
print(f"迁移 {attr} 到 {device}")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# 清理 GPU 缓存
|
||||
if torch.cuda.is_available() and not device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(f"TTS 管道设备已设置为: {device}")
|
||||
|
||||
def run_tts_pipeline(req):
|
||||
"""在线程池中运行 TTS 任务"""
|
||||
return tts_pipeline.run(req)
|
||||
|
||||
async def tts_handle(req: dict, is_ttsrole: bool = False):
|
||||
streaming_mode = req.get("streaming_mode", False)
|
||||
media_type = req.get("media_type", "wav")
|
||||
|
||||
if "text_lang" not in req:
|
||||
req["text_lang"] = "auto"
|
||||
|
||||
check_res = check_params(req, is_ttsrole)
|
||||
if check_res is not None:
|
||||
return JSONResponse(status_code=400, content=check_res)
|
||||
|
||||
# 如果请求中指定了 device,则覆盖所有与设备相关的参数并更新管道设备
|
||||
if "device" in req and req["device"] is not None:
|
||||
device = req["device"]
|
||||
req["t2s_model_device"] = device
|
||||
req["vits_model_device"] = device
|
||||
if hasattr(tts_config, 'device'):
|
||||
tts_config.device = device
|
||||
set_pipeline_device(tts_pipeline, device)
|
||||
|
||||
if is_ttsrole:
|
||||
role_exists = load_role_config(req["role"], req)
|
||||
|
||||
for key in ["bert_base_path", "cnhuhbert_base_path", "device", "is_half", "t2s_weights_path", "version", "vits_weights_path"]:
|
||||
if req.get(key) is not None:
|
||||
setattr(tts_config, key, req[key])
|
||||
|
||||
if req.get("t2s_model_path"):
|
||||
tts_config.t2s_weights_path = req["t2s_model_path"]
|
||||
tts_pipeline.init_t2s_weights(req["t2s_model_path"])
|
||||
if req.get("vits_model_path"):
|
||||
tts_config.vits_weights_path = req["vits_model_path"]
|
||||
tts_pipeline.init_vits_weights(req["vits_model_path"])
|
||||
|
||||
if not req.get("ref_audio_path"):
|
||||
ref_audio_path, prompt_text, prompt_lang = select_ref_audio(req["role"], req["text_lang"], req.get("emotion"))
|
||||
if ref_audio_path:
|
||||
req["ref_audio_path"] = ref_audio_path
|
||||
req["prompt_text"] = prompt_text or ""
|
||||
req["prompt_lang"] = prompt_lang or req["text_lang"]
|
||||
elif not role_exists:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": "Role directory not found and no suitable reference audio provided"})
|
||||
else:
|
||||
ref_audio_path = req["ref_audio_path"]
|
||||
txt_path = ref_audio_path.rsplit(".", 1)[0] + ".txt"
|
||||
if os.path.exists(txt_path):
|
||||
with open(txt_path, "r", encoding="utf-8") as f:
|
||||
req["prompt_text"] = f.read().strip()
|
||||
else:
|
||||
basename = os.path.basename(ref_audio_path)
|
||||
if "【" in basename and "】" in basename:
|
||||
start_idx = basename.find("】") + 1
|
||||
end_idx = basename.rfind(".")
|
||||
if start_idx > 0 and end_idx > start_idx:
|
||||
req["prompt_text"] = basename[start_idx:end_idx]
|
||||
else:
|
||||
req["prompt_text"] = basename[:end_idx] if end_idx > 0 else basename
|
||||
else:
|
||||
end_idx = basename.rfind(".")
|
||||
req["prompt_text"] = basename[:end_idx] if end_idx > 0 else basename
|
||||
|
||||
if streaming_mode:
|
||||
req["return_fragment"] = True
|
||||
|
||||
try:
|
||||
print(f"当前请求设备: {req.get('device')}")
|
||||
if hasattr(tts_pipeline, 't2s_model') and tts_pipeline.t2s_model is not None:
|
||||
print(f"t2s_model 设备: {next(tts_pipeline.t2s_model.parameters()).device}")
|
||||
if hasattr(tts_pipeline, 'vits_model') and tts_pipeline.vits_model is not None:
|
||||
print(f"vits_model 设备: {next(tts_pipeline.vits_model.parameters()).device}")
|
||||
|
||||
# 异步执行 TTS 任务
|
||||
loop = asyncio.get_event_loop()
|
||||
tts_generator = await loop.run_in_executor(executor, run_tts_pipeline, req)
|
||||
|
||||
if streaming_mode:
|
||||
def streaming_generator():
|
||||
if media_type == "wav":
|
||||
yield wave_header_chunk()
|
||||
stream_type = "raw"
|
||||
else:
|
||||
stream_type = media_type
|
||||
for sr, chunk in tts_generator:
|
||||
buf = pack_audio(chunk, sr, stream_type)
|
||||
yield buf.getvalue()
|
||||
return StreamingResponse(streaming_generator(), media_type=f"audio/{media_type}")
|
||||
else:
|
||||
sr, audio_data = next(tts_generator)
|
||||
buf = pack_audio(audio_data, sr, media_type)
|
||||
return Response(buf.getvalue(), media_type=f"audio/{media_type}")
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": "tts failed", "exception": str(e)})
|
||||
|
||||
@APP.get("/control")
|
||||
async def control(command: str = None):
|
||||
if command is None:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": "command is required"})
|
||||
handle_control(command)
|
||||
|
||||
@APP.get("/tts")
|
||||
async def tts_get_endpoint(
|
||||
text: str,
|
||||
ref_audio_path: str,
|
||||
prompt_lang: str,
|
||||
text_lang: str = "auto",
|
||||
aux_ref_audio_paths: Optional[List[str]] = None,
|
||||
prompt_text: Optional[str] = "",
|
||||
top_k: Optional[int] = 5,
|
||||
top_p: Optional[float] = 1,
|
||||
temperature: Optional[float] = 1,
|
||||
text_split_method: Optional[str] = "cut0",
|
||||
batch_size: Optional[int] = 1,
|
||||
batch_threshold: Optional[float] = 0.75,
|
||||
split_bucket: Optional[bool] = True,
|
||||
speed_factor: Optional[float] = 1.0,
|
||||
fragment_interval: Optional[float] = 0.3,
|
||||
seed: Optional[int] = -1,
|
||||
media_type: Optional[str] = "wav",
|
||||
streaming_mode: Optional[bool] = False,
|
||||
parallel_infer: Optional[bool] = True,
|
||||
repetition_penalty: Optional[float] = 1.35,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
"text_lang": text_lang.lower(),
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"aux_ref_audio_paths": aux_ref_audio_paths,
|
||||
"prompt_lang": prompt_lang.lower(),
|
||||
"prompt_text": prompt_text,
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"text_split_method": text_split_method,
|
||||
"batch_size": batch_size,
|
||||
"batch_threshold": batch_threshold,
|
||||
"split_bucket": split_bucket,
|
||||
"speed_factor": speed_factor,
|
||||
"fragment_interval": fragment_interval,
|
||||
"seed": seed,
|
||||
"media_type": media_type,
|
||||
"streaming_mode": streaming_mode,
|
||||
"parallel_infer": parallel_infer,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"device": device
|
||||
}
|
||||
return await tts_handle(req)
|
||||
|
||||
@APP.post("/tts")
|
||||
async def tts_post_endpoint(request: TTS_Request):
|
||||
req = request.dict(exclude_unset=True)
|
||||
if "text_lang" in req:
|
||||
req["text_lang"] = req["text_lang"].lower()
|
||||
if "prompt_lang" in req:
|
||||
req["prompt_lang"] = req["prompt_lang"].lower()
|
||||
return await tts_handle(req)
|
||||
|
||||
@APP.get("/ttsrole")
|
||||
async def ttsrole_get_endpoint(
|
||||
text: str,
|
||||
role: str,
|
||||
text_lang: str = "auto",
|
||||
ref_audio_path: Optional[str] = None,
|
||||
aux_ref_audio_paths: Optional[List[str]] = None,
|
||||
prompt_lang: Optional[str] = None,
|
||||
prompt_text: Optional[str] = None,
|
||||
emotion: Optional[str] = None,
|
||||
top_k: Optional[int] = 5,
|
||||
top_p: Optional[float] = 1,
|
||||
temperature: Optional[float] = 1,
|
||||
text_split_method: Optional[str] = "cut5",
|
||||
batch_size: Optional[int] = 1,
|
||||
batch_threshold: Optional[float] = 0.75,
|
||||
split_bucket: Optional[bool] = True,
|
||||
speed_factor: Optional[float] = 1.0,
|
||||
fragment_interval: Optional[float] = 0.3,
|
||||
seed: Optional[int] = -1,
|
||||
media_type: Optional[str] = "wav",
|
||||
streaming_mode: Optional[bool] = False,
|
||||
parallel_infer: Optional[bool] = True,
|
||||
repetition_penalty: Optional[float] = 1.35,
|
||||
bert_base_path: Optional[str] = None,
|
||||
cnhuhbert_base_path: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
is_half: Optional[bool] = None,
|
||||
t2s_weights_path: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
vits_weights_path: Optional[str] = None,
|
||||
t2s_model_path: Optional[str] = None,
|
||||
vits_model_path: Optional[str] = None,
|
||||
t2s_model_device: Optional[str] = None,
|
||||
vits_model_device: Optional[str] = None
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
"role": role,
|
||||
"text_lang": text_lang.lower(),
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"aux_ref_audio_paths": aux_ref_audio_paths,
|
||||
"prompt_lang": prompt_lang.lower() if prompt_lang else None,
|
||||
"prompt_text": prompt_text,
|
||||
"emotion": emotion,
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"text_split_method": text_split_method,
|
||||
"batch_size": batch_size,
|
||||
"batch_threshold": batch_threshold,
|
||||
"split_bucket": split_bucket,
|
||||
"speed_factor": speed_factor,
|
||||
"fragment_interval": fragment_interval,
|
||||
"seed": seed,
|
||||
"media_type": media_type,
|
||||
"streaming_mode": streaming_mode,
|
||||
"parallel_infer": parallel_infer,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"bert_base_path": bert_base_path,
|
||||
"cnhuhbert_base_path": cnhuhbert_base_path,
|
||||
"device": device,
|
||||
"is_half": is_half,
|
||||
"t2s_weights_path": t2s_weights_path,
|
||||
"version": version,
|
||||
"vits_weights_path": vits_weights_path,
|
||||
"t2s_model_path": t2s_model_path,
|
||||
"vits_model_path": vits_model_path,
|
||||
"t2s_model_device": t2s_model_device,
|
||||
"vits_model_device": vits_model_device
|
||||
}
|
||||
return await tts_handle(req, is_ttsrole=True)
|
||||
|
||||
@APP.post("/ttsrole")
|
||||
async def ttsrole_post_endpoint(request: TTSRole_Request):
|
||||
req = request.dict(exclude_unset=True)
|
||||
if "text_lang" in req:
|
||||
req["text_lang"] = req["text_lang"].lower()
|
||||
if "prompt_lang" in req:
|
||||
req["prompt_lang"] = req["prompt_lang"].lower()
|
||||
return await tts_handle(req, is_ttsrole=True)
|
||||
|
||||
@APP.get("/set_gpt_weights")
|
||||
async def set_gpt_weights(weights_path: str = None):
|
||||
try:
|
||||
if not weights_path:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": "gpt weight path is required"})
|
||||
tts_pipeline.init_t2s_weights(weights_path)
|
||||
tts_config.t2s_weights_path = weights_path
|
||||
return JSONResponse(status_code=200, content={"status": "success", "message": "success"})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": f"change gpt weight failed", "exception": str(e)})
|
||||
|
||||
@APP.get("/set_sovits_weights")
|
||||
async def set_sovits_weights(weights_path: str = None):
|
||||
try:
|
||||
if not weights_path:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": "sovits weight path is required"})
|
||||
tts_pipeline.init_vits_weights(weights_path)
|
||||
tts_config.vits_weights_path = weights_path
|
||||
return JSONResponse(status_code=200, content={"status": "success", "message": "success"})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": f"change sovits weight failed", "exception": str(e)})
|
||||
|
||||
@APP.get("/set_refer_audio")
|
||||
async def set_refer_audio(refer_audio_path: str = None):
|
||||
try:
|
||||
if not refer_audio_path:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": "refer audio path is required"})
|
||||
tts_pipeline.set_ref_audio(refer_audio_path)
|
||||
return JSONResponse(status_code=200, content={"status": "success", "message": "success"})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"status": "error", "message": f"set refer audio failed", "exception": str(e)})
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
if host == 'None': # 在调用时使用 -a None 参数,可以让api监听双栈
|
||||
host = None
|
||||
uvicorn.run(app=APP, host=host, port=port, workers=1)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
||||
1020
api_role_v3.py
Normal file
1020
api_role_v3.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user