revert "rfft fallback to cpu"

This commit is contained in:
Kakaru 2025-10-21 17:11:54 +08:00 committed by GitHub
parent 917f73c38c
commit 3a92c046f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,8 +5,6 @@ import torch
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
import musa_utils
__all__ = [ __all__ = [
"get_mel_banks", "get_mel_banks",
"inverse_mel_scale", "inverse_mel_scale",
@ -307,12 +305,7 @@ def spectrogram(
) )
# size (m, padded_window_size // 2 + 1, 2) # size (m, padded_window_size // 2 + 1, 2)
if musa_utils.is_available() : # 怎么还有算子不支持怎么还有算子不支持怎么还有算子不支持
ori_device = strided_input.device
strided_input = strided_input.cpu()
fft = torch.fft.rfft(strided_input) fft = torch.fft.rfft(strided_input)
if musa_utils.is_available() :
fft = fft.to(ori_device)
# Convert the FFT into a power spectrum # Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1) power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
@ -625,12 +618,7 @@ def fbank(
) )
# size (m, padded_window_size // 2 + 1) # size (m, padded_window_size // 2 + 1)
if musa_utils.is_available() : # 怎么还有算子不支持怎么还有算子不支持怎么还有算子不支持
ori_device = strided_input.device
strided_input = strided_input.cpu()
spectrum = torch.fft.rfft(strided_input).abs() spectrum = torch.fft.rfft(strided_input).abs()
if musa_utils.is_available() :
spectrum = spectrum.to(ori_device)
if use_power: if use_power:
spectrum = spectrum.pow(2.0) spectrum = spectrum.pow(2.0)