# 2025-08-24 02:00:32 -04:00
#
# 1234 lines
# 52 KiB
# Python

import math
from typing import Tuple
import torch
import torchaudio
from torch import Tensor
# Public API exported by ``from <module> import *``. Every entry must be a
# separate string literal: adjacent literals without a comma are silently
# concatenated by Python into a single (non-existent) name.
__all__ = [
    "get_mel_banks",
    "inverse_mel_scale",
    "inverse_mel_scale_scalar",
    "mel_scale",
    "mel_scale_scalar",
    "spectrogram",
    "fbank",
    "fbank_onnx",
    "mfcc",
    "vtln_warp_freq",
    "vtln_warp_mel_freq",
]
# Machine epsilon of float32, i.e. C++ std::numeric_limits<float>::epsilon()
# (1.1920928955078125e-07), stored as a 0-dim tensor.
EPSILON = torch.tensor(torch.finfo(torch.float32).eps)
# Multiply a duration in milliseconds by this factor to obtain seconds.
MILLISECONDS_TO_SECONDS = 1e-3
# Names of the supported analysis windows.
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
def _get_epsilon(device, dtype):
    """Return the module-level ``EPSILON`` scalar moved to ``device`` and cast to ``dtype``."""
    eps = EPSILON
    return eps.to(device=device, dtype=dtype)
def _next_power_of_2(x: int) -> int:
r"""Returns the smallest power of 2 that is greater than x"""
return 1 if x == 0 else 2 ** (x - 1).bit_length()
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
representing how the window is shifted along the waveform. Each row is a frame.
Args:
waveform (Tensor): Tensor of size ``num_samples``
window_size (int): Frame length
window_shift (int): Frame shift
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends.
Returns:
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
"""
assert waveform.dim() == 1
num_samples = waveform.size(0)
strides = (window_shift * waveform.stride(0), waveform.stride(0))
if snip_edges:
if num_samples < window_size:
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
else:
m = 1 + (num_samples - window_size) // window_shift
else:
reversed_waveform = torch.flip(waveform, [0])
m = (num_samples + (window_shift // 2)) // window_shift
pad = window_size // 2 - window_shift // 2
pad_right = reversed_waveform
if pad > 0:
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
# but we want [2, 1, 0, 0, 1, 2]
pad_left = reversed_waveform[-pad:]
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
else:
# pad is negative so we want to trim the waveform at the front
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
sizes = (m, window_size)
return waveform.as_strided(sizes, strides)
def _feature_window_function(
    window_type: str,
    window_size: int,
    blackman_coeff: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    r"""Returns a window function with the given type and size.

    Args:
        window_type (str): One of 'hamming', 'hanning', 'povey', 'rectangular', 'blackman'.
        window_size (int): Number of samples in the window.
        blackman_coeff (float): Constant coefficient of the generalized Blackman window
            (only used for 'blackman').
        device (torch.device): Device of the returned tensor.
        dtype (torch.dtype): Dtype of the returned tensor.

    Returns:
        Tensor: 1-D window of size ``window_size``.

    Raises:
        ValueError: If ``window_type`` is not one of the supported types.
    """
    if window_type == HANNING:
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
    elif window_type == HAMMING:
        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
    elif window_type == POVEY:
        # symmetric Hann window raised to the power 0.85
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
    elif window_type == RECTANGULAR:
        return torch.ones(window_size, device=device, dtype=dtype)
    elif window_type == BLACKMAN:
        a = 2 * math.pi / (window_size - 1)
        window_function = torch.arange(window_size, device=device, dtype=dtype)
        # can't use torch.blackman_window as it uses fixed coefficients; Kaldi's
        # generalized Blackman window is parameterized by blackman_coeff
        return (
            blackman_coeff
            - 0.5 * torch.cos(a * window_function)
            + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
        ).to(device=device, dtype=dtype)
    # ValueError subclasses Exception, so existing `except Exception` callers still work;
    # the f-string also copes with non-str window_type values.
    raise ValueError(f"Invalid window type {window_type}")
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
if energy_floor == 0.0:
return log_energy
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Select one channel and convert the millisecond framing options into sample counts.

    Args:
        waveform (Tensor): Audio indexed as ``waveform[channel, :]``.
        channel (int): Channel to keep; negative values select channel 0.
        sample_frequency (float): Sample rate in Hz.
        frame_shift (float): Frame shift in milliseconds.
        frame_length (float): Frame length in milliseconds.
        round_to_power_of_two (bool): If True, the padded size is the next power of two
            that is >= the window size; otherwise it equals the window size.
        preemphasis_coefficient (float): Only validated here (must lie in [0, 1]).

    Returns:
        (Tensor, int, int, int): the selected channel, window shift, window size and
        padded window size (sizes in samples).
    """
    channel_idx = channel if channel > 0 else 0
    num_channels = waveform.size(0)
    assert channel_idx < num_channels, f"Invalid channel {channel_idx} for size {num_channels}"
    mono = waveform[channel_idx, :]  # size (n)

    # Keep the multiplication order (rate * ms * 0.001) so truncation matches exactly.
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    if round_to_power_of_two:
        padded_window_size = _next_power_of_2(window_size)
    else:
        padded_window_size = window_size

    num_samples = len(mono)
    assert 2 <= window_size <= num_samples, f"choose a window size {window_size} that is [2, {num_samples}]"
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert padded_window_size % 2 == 0, (
        "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
    )
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return mono, window_shift, window_size, padded_window_size
def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    window_type: str,
    blackman_coeff: float,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    dither: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Frame the waveform, condition each frame, and compute per-frame log energy.

    Processing order per frame: optional Gaussian dither, optional DC-offset removal,
    (raw) log energy if ``raw_energy``, optional pre-emphasis, window function,
    zero-padding on the right up to ``padded_window_size``, and finally the log
    energy of the windowed frame if ``raw_energy`` is False.

    Returns:
        (Tensor, Tensor): frames of size (m, ``padded_window_size``) and their log energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    eps = _get_epsilon(device, dtype)

    # One frame per row, size (m, window_size).
    frames = _get_strided(waveform, window_size, window_shift, snip_edges)

    if dither != 0.0:
        noise = torch.randn(frames.shape, device=device, dtype=dtype)
        frames = frames + noise * dither

    if remove_dc_offset:
        # Remove each frame's own mean.
        frames = frames - frames.mean(dim=1, keepdim=True)

    if raw_energy:
        # Energy measured before pre-emphasis and windowing.
        log_energy = _get_log_energy(frames, eps, energy_floor)

    if preemphasis_coefficient != 0.0:
        # previous[i, j] = frames[i, max(0, j - 1)]: replicate-pad one column on the left.
        previous = torch.nn.functional.pad(frames.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
        frames = frames - preemphasis_coefficient * previous[:, :-1]

    window = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype)
    frames = frames * window.unsqueeze(0)

    extra_columns = padded_window_size - window_size
    if extra_columns != 0:
        # Zero-pad on the right to the FFT size, giving (m, padded_window_size).
        frames = torch.nn.functional.pad(
            frames.unsqueeze(0), (0, extra_columns), mode="constant", value=0
        ).squeeze(0)

    if not raw_energy:
        # Energy of the processed (windowed, padded) frame instead of the raw one.
        log_energy = _get_log_energy(frames, eps, energy_floor)

    return frames, log_energy
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
# it returns size (m, n)
if subtract_mean:
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
tensor = tensor - col_means
return tensor
def spectrogram(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_duration: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
    compute-spectrogram-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A spectrogram identical to what Kaldi would output. The shape is
        (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided. If the signal is shorter
        than ``min_duration`` an empty 1-D tensor with the waveform's device and dtype is returned.
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # Signal is too short: return an empty tensor on the caller's device/dtype
        # (the same convention fbank uses) rather than a CPU float32 tensor.
        return torch.empty(0, device=device, dtype=dtype)

    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # Complex spectrum, size (m, padded_window_size // 2 + 1)
    fft = torch.fft.rfft(strided_input)
    # Log power spectrum, floored at epsilon to avoid log(0); same size as fft
    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()
    # The zeroth (DC) column is replaced by the frame log energy
    power_spectrum[:, 0] = signal_log_energy

    return _subtract_column_mean(power_spectrum, subtract_mean)
def inverse_mel_scale_scalar(mel_freq: float) -> float:
    """Convert a mel value to Hz using ``700 * (exp(mel / 1127) - 1)``."""
    exponent = mel_freq / 1127.0
    return 700.0 * (math.exp(exponent) - 1.0)
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    """Element-wise mel-to-Hz conversion: ``700 * (exp(mel / 1127) - 1)``."""
    return 700.0 * (torch.exp(mel_freq / 1127.0) - 1.0)
def mel_scale_scalar(freq: float) -> float:
    """Convert a frequency in Hz to mel using ``1127 * ln(1 + f / 700)``."""
    ratio = 1.0 + freq / 700.0
    return 1127.0 * math.log(ratio)
def mel_scale(freq: Tensor) -> Tensor:
    """Element-wise Hz-to-mel conversion: ``1127 * ln(1 + f / 700)``."""
    return 1127.0 * torch.log(1.0 + freq / 700.0)
def vtln_warp_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    freq: Tensor,
) -> Tensor:
    r"""Piecewise-linear VTLN frequency warp (Kaldi style, never produces empty bins).

    The warp F is defined on [low_freq, high_freq] and maps each end point to itself:
    F(low_freq) == low_freq and F(high_freq) == high_freq. It is continuous and linear
    on three pieces separated by two knees, l (lower) and h (upper), both expressed in
    unwarped frequency:

    * between the knees, F(f) = f / vtln_warp_factor;
    * h is chosen so that max(h, F(h)) == vtln_high_cutoff, which gives
      h = vtln_high_cutoff * min(1, vtln_warp_factor);
    * l is chosen so that min(l, F(l)) == vtln_low_cutoff, which gives
      l = vtln_low_cutoff * max(1, vtln_warp_factor).

    Frequencies outside [low_freq, high_freq] are returned unchanged.

    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        freq (Tensor): given frequency in Hz

    Returns:
        Tensor: Freq after vtln warp
    """
    assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
    assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"

    low_knee = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    high_knee = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    inv_factor = 1.0 / vtln_warp_factor
    warped_low_knee = inv_factor * low_knee
    warped_high_knee = inv_factor * high_knee
    assert low_knee > low_freq and high_knee < high_freq

    # Slopes of the outer two pieces (the middle piece's slope is inv_factor).
    left_slope = (warped_low_knee - low_freq) / (low_knee - low_freq)
    right_slope = (high_freq - warped_high_knee) / (high_freq - high_knee)

    warped = torch.empty_like(freq)
    # The regions overlap, so later assignments deliberately override earlier ones.
    upper = freq >= high_knee
    warped[upper] = high_freq + right_slope * (freq[upper] - high_freq)
    middle = freq < high_knee
    warped[middle] = inv_factor * freq[middle]
    lower = freq < low_knee
    warped[lower] = low_freq + left_slope * (freq[lower] - low_freq)
    outside = (freq < low_freq) | (freq > high_freq)
    warped[outside] = freq[outside]
    return warped
def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: Tensor,
) -> Tensor:
    r"""Apply the piecewise-linear VTLN warp to frequencies given on the mel scale.

    ``mel_freq`` is converted to Hz with ``inverse_mel_scale``, warped with
    ``vtln_warp_freq`` and converted back with ``mel_scale``.

    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    freq_hz = inverse_mel_scale(mel_freq)
    warped_hz = vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, freq_hz)
    return mel_scale(warped_hz)
def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
    device=None,
    dtype=None,
) -> Tensor:
    """Build triangular mel filterbank weights over the FFT bins.

    Args:
        num_bins (int): Number of triangular mel bins (must be at least 3).
        window_length_padded (int): Padded FFT window length (must be even).
        sample_freq (float): Sample frequency in Hz.
        low_freq (float): Low cutoff frequency of the mel bins in Hz.
        high_freq (float): High cutoff frequency in Hz; if <= 0 it is an offset from Nyquist.
        vtln_low (float): Low VTLN inflection point in Hz.
        vtln_high (float): High VTLN inflection point in Hz; if negative it is an offset from Nyquist.
        vtln_warp_factor (float): VTLN warp factor (1.0 disables warping).
        device (optional): Device of the returned tensor.
        dtype (optional): Dtype of the returned tensor.

    Returns:
        Tensor: Filterbank ``bins`` of size (``num_bins``, ``window_length_padded // 2``). Row ``i`` holds
        the weights of mel bin ``i`` for FFT bins ``0 .. window_length_padded // 2 - 1`` (the Nyquist bin
        is not included).
    """
    assert num_bins >= 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
        "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
    )

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 because of end-effects where the bins spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or (
        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )

    bin_idx = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin_idx * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin_idx + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin_idx + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    # Mel value of each FFT bin's frequency, size (1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins.to(device=device, dtype=dtype)
# Memoized mel filterbanks used by ``fbank``, keyed by a string built from every
# get_mel_banks argument (including device and dtype).
# NOTE(review): entries are never evicted; assumes few distinct configurations.
cache = dict()
def fbank(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Compute mel filterbank (FBANK) features from raw audio, matching the input/output of Kaldi's
    compute-fbank-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
            (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: Features of size (m, ``num_mel_bins + use_energy``), where m is computed in _get_strided.
        An empty 1-D tensor (same device/dtype as ``waveform``) is returned when the selected channel
        is shorter than ``min_duration``.
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # Too short to process.
        return torch.empty(0, device=device, dtype=dtype)

    # frames: (m, padded_window_size); log_energy: (m)
    frames, log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # Magnitude spectrum, (m, padded_window_size // 2 + 1); squared when power is requested.
    spectrum = torch.fft.rfft(frames).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # Filterbank of size (num_mel_bins, padded_window_size // 2), memoized in the module cache.
    key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
        num_mel_bins,
        padded_window_size,
        sample_frequency,
        low_freq,
        high_freq,
        vtln_low,
        vtln_high,
        vtln_warp,
        device,
        dtype,
    )
    banks = cache.get(key)
    if banks is None:
        banks = get_mel_banks(
            num_mel_bins,
            padded_window_size,
            sample_frequency,
            low_freq,
            high_freq,
            vtln_low,
            vtln_high,
            vtln_warp,
            device,
            dtype,
        )
        cache[key] = banks

    # Append a zero column for the Nyquist bin: (num_mel_bins, padded_window_size // 2 + 1).
    banks = torch.nn.functional.pad(banks, (0, 1), mode="constant", value=0)
    # Weighted sums over the spectrum: (m, num_mel_bins).
    mel_energies = torch.mm(spectrum, banks.T)

    if use_log_fbank:
        # Floor at epsilon so the log is always defined.
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    if use_energy:
        # Energy goes last for htk_compat, first otherwise: (m, num_mel_bins + 1).
        energy_column = log_energy.unsqueeze(1)
        if htk_compat:
            pieces = (mel_energies, energy_column)
        else:
            pieces = (energy_column, mel_energies)
        mel_energies = torch.cat(pieces, dim=1)

    return _subtract_column_mean(mel_energies, subtract_mean)
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
    """Return a Kaldi-compatible DCT matrix of size (num_mel_bins, num_ceps) for right-multiplication.

    torchaudio's orthonormal DCT is used, except that its first column (the C0 basis
    vector, since torchaudio expects ``features @ dct``) is set to the constant
    sqrt(1 / num_mel_bins) that Kaldi uses for the zeroth cepstral coefficient.
    """
    full_dct = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
    c0_weight = math.sqrt(1 / float(num_mel_bins))
    full_dct[:, 0] = c0_weight
    return full_dct[:, :num_ceps]
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = torch.arange(num_ceps)
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
def mfcc(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Replace C0 with the frame log energy in the MFCC output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``"povey"``)

    Returns:
        Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided. If the input is shorter than ``min_duration``,
        the empty 1-D tensor produced by ``fbank`` is returned unchanged.
    """
    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

    device, dtype = waveform.device, waveform.dtype

    # MFCCs are computed from log mel energies of the *power* spectrum
    # (use_power=True, use_log_fbank=True). Mean subtraction is applied only at the
    # end, on the cepstra, so it is disabled for the filterbank (subtract_mean=False).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )

    if feature.dim() == 1:
        # fbank returns an empty 1-D tensor for inputs shorter than min_duration;
        # the 2-D column indexing below would raise on it, so hand it back as-is.
        return feature

    if use_energy:
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)

    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.to(device=device, dtype=dtype)

    # with use_energy, C0 is replaced by the frame log energy
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        # move C0 / energy to the last column
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy = energy * math.sqrt(2)
        feature = torch.cat((feature, energy), dim=1)

    return _subtract_column_mean(feature, subtract_mean)
def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype))
def _get_waveform_and_window_properties_onnx(
waveform: Tensor,
) -> Tuple[Tensor, int, int, int]:
r"""ONNX-compatible version with hardcoded parameters from traced fbank call:
channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0,
round_to_power_of_two=True, preemphasis_coefficient=0.97"""
# Hardcoded values from traced parameters
# channel=-1 -> 0 after max(channel, 0)
channel = 0
# Extract channel 0 from waveform
if waveform.dim() == 1:
# Mono waveform, use as-is
waveform_selected = waveform
else:
# Multi-channel, select first channel
waveform_selected = waveform[channel, :]
# Hardcoded calculations:
# window_shift = int(16000 * 10.0 * 0.001) = 160
# window_size = int(16000 * 25.0 * 0.001) = 400
# padded_window_size = _next_power_of_2(400) = 512
window_shift = 160
window_size = 400
padded_window_size = 512
return waveform_selected, window_shift, window_size, padded_window_size
def _get_window_onnx(
    waveform: Tensor,
) -> Tuple[Tensor, Tensor]:
    r"""Frame, condition and window a 1-D waveform with the traced fbank settings.

    This is a specialisation of ``_get_window`` with ``padded_window_size=512``,
    ``window_size=400``, ``window_shift=160``, ``window_type='povey'``,
    ``snip_edges=True``, ``raw_energy=True``, ``energy_floor=1.0``,
    ``dither=0``, ``remove_dc_offset=True`` and ``preemphasis_coefficient=0.97``.
    Dithering is therefore skipped. The energy is taken from the DC-removed
    frames, before pre-emphasis and windowing.

    Returns:
        (Tensor, Tensor): Windowed frames of size (m, 512) and their log
        energy of size (m).
    """
    device, dtype = waveform.device, waveform.dtype
    frame_len, hop, fft_len = 400, 160, 512

    # Each row is one 400-sample frame, advancing by 160 samples.
    frames = _get_strided_onnx(waveform, frame_len, hop, True)

    # remove_dc_offset=True: subtract each frame's own mean.
    frame_means = torch.mean(frames, dim=1).unsqueeze(1)
    frames = frames - frame_means

    # raw_energy=True: measure energy before pre-emphasis / windowing.
    log_energy = _get_log_energy_onnx(frames, _get_epsilon(device, dtype))

    # Pre-emphasis: x[j] -= 0.97 * x[max(0, j - 1)]. The left replicate-pad
    # supplies x[0] as its own predecessor.
    shifted = torch.nn.functional.pad(frames.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
    frames = frames - 0.97 * shifted[:, :-1]

    # Povey window = symmetric Hann window raised to the power 0.85.
    povey = torch.hann_window(frame_len, periodic=False, device=device, dtype=dtype).pow(0.85)
    frames = frames * povey.unsqueeze(0)

    # Zero-pad every frame on the right up to the FFT length.
    frames = torch.nn.functional.pad(
        frames.unsqueeze(0), (0, fft_len - frame_len), mode="constant", value=0
    ).squeeze(0)
    return frames, log_energy
def _get_strided_onnx(waveform: Tensor, window_size = 400, window_shift = 160, snip_edges = 512) -> Tensor:
seq_len = waveform.size(0)
# Calculate number of windows
num_windows = 1 + (seq_len - window_size) // window_shift
# Create indices for all windows at once
window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device)
window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0)
# Extract windows using advanced indexing
windows = waveform[window_indices] # [num_windows, window_size]
return windows
def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor:
"""ONNX-compatible version with hardcoded parameters from traced fbank call:
subtract_mean=False, so this function returns the input tensor unchanged.
Args:
tensor: Input tensor of size (m, n)
Returns:
Tensor: Same as input tensor (m, n) since subtract_mean=False
"""
# subtract_mean=False from traced parameters, so return tensor as-is
return tensor
def get_mel_banks_onnx(
    device=None,
    dtype=None,
) -> Tensor:
    """Build the fixed triangular mel filterbank used by ``fbank_onnx``.

    This is a specialisation of ``get_mel_banks`` for ``num_bins=80``,
    ``window_length_padded=512``, ``sample_freq=16000``, ``low_freq=20.0``
    and ``high_freq=0.0`` (meaning "Nyquist"). It also assumes
    ``vtln_warp_factor=1.0``, so no VTLN warping is applied. The vtln_low
    and vtln_high values are irrelevant in that case. Only the filter
    weights are returned, not the centre frequencies.

    Args:
        device: Device for the created tensors.
        dtype: dtype for the created tensors.

    Returns:
        Tensor: Filterbank of size (80, 256), i.e. (num_bins, num_fft_bins).
    """
    n_mels = 80
    fft_len = 512
    sr = 16000.0
    f_lo = 20.0
    f_hi = 0.0

    half_fft = fft_len // 2  # 256 FFT bins; the Nyquist bin is not included
    nyquist = 0.5 * sr
    # A non-positive high frequency is interpreted as an offset from Nyquist.
    if f_hi <= 0.0:
        f_hi += nyquist
    bin_hz = sr / fft_len  # 31.25 Hz per FFT bin

    mel_lo = mel_scale_scalar(f_lo)
    mel_hi = mel_scale_scalar(f_hi)
    # num_bins + 1 intervals: each triangle spans two neighbouring intervals.
    mel_step = (mel_hi - mel_lo) / (n_mels + 1)

    idx = torch.arange(n_mels, device=device, dtype=dtype).unsqueeze(1)  # (80, 1)
    left_mel = mel_lo + idx * mel_step
    center_mel = mel_lo + (idx + 1.0) * mel_step
    right_mel = mel_lo + (idx + 2.0) * mel_step

    # Mel value of every FFT bin centre, as a row vector (1, 256).
    fft_mel = mel_scale(bin_hz * torch.arange(half_fft, device=device, dtype=dtype)).unsqueeze(0)

    rising = (fft_mel - left_mel) / (center_mel - left_mel)
    falling = (right_mel - fft_mel) / (right_mel - center_mel)
    # Triangle = min of the two slopes, clipped at zero outside its support.
    zero = torch.zeros(1, device=device, dtype=dtype)
    return torch.max(zero, torch.min(rising, falling))
def fbank_onnx(
    waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0
) -> Tensor:
    r"""ONNX-exportable log-mel filterbank matching the traced ``fbank`` call.

    Every stage is specialised to the configuration the original call was
    traced with:

    * ``num_mel_bins=80``, ``sample_frequency=16000``, ``dither=0``
    * ``channel=-1`` (channel 0 is used), ``frame_length=25.0``, ``frame_shift=10.0``
    * ``round_to_power_of_two=True``, ``snip_edges=True``, ``window_type='povey'``
    * ``preemphasis_coefficient=0.97``, ``remove_dc_offset=True``, ``raw_energy=True``
    * ``low_freq=20.0``, ``high_freq=0.0``, ``vtln_warp=1.0``
    * ``use_power=True``, ``use_log_fbank=True``, ``use_energy=False``, ``subtract_mean=False``

    The FFT is done with ``torch.stft(return_complex=False)`` on one
    full-length frame per row. This avoids the complex tensors that
    ``torch.fft.rfft`` would produce.

    Args:
        waveform (Tensor): A 1-D signal, or a (c, n) tensor of which channel 0 is used.
        num_mel_bins (int, optional): Must be 80. Kept for call compatibility. (Default: ``80``)
        sample_frequency (float, optional): Must be 16000. (Default: ``16000``)
        dither (float, optional): Must be 0. (Default: ``0``)

    Returns:
        Tensor: Log mel energies of size (m, 80), with
        m = 1 + (n - 400) // 160 frames.

    Raises:
        ValueError: If ``num_mel_bins``, ``sample_frequency`` or ``dither``
            differ from the hard-coded configuration. These arguments used
            to be silently ignored, which produced wrong features.
    """
    # All helpers below hard-code the traced settings; refuse anything else
    # rather than silently computing features for a different configuration.
    if num_mel_bins != 80:
        raise ValueError(f"fbank_onnx only supports num_mel_bins=80, got {num_mel_bins}")
    if sample_frequency != 16000:
        raise ValueError(f"fbank_onnx only supports sample_frequency=16000, got {sample_frequency}")
    if dither != 0:
        raise ValueError(f"fbank_onnx only supports dither=0, got {dither}")

    device, dtype = waveform.device, waveform.dtype
    waveform_selected, _, _, padded_window_size = _get_waveform_and_window_properties_onnx(waveform)
    # min_duration=0.0, so there is no duration check.
    # use_energy=False, so the per-frame log energy is not needed.
    strided_input, _ = _get_window_onnx(waveform_selected)  # size (m, padded_window_size)

    # One STFT frame per row: n_fft == hop == frame length and center=False,
    # with a rectangular window (the Povey window was already applied).
    rectangular_window = torch.ones(padded_window_size, device=strided_input.device, dtype=strided_input.dtype)
    stft_result = torch.stft(
        strided_input,
        n_fft=padded_window_size,
        hop_length=padded_window_size,
        window=rectangular_window,
        center=False,
        return_complex=False,
    )
    # stft_result: (m, padded_window_size // 2 + 1, 1, 2), last dim = [real, imag]
    real_part = stft_result[..., 0]
    imag_part = stft_result[..., 1]
    # Magnitude, then power (use_power=True); mirrors rfft(...).abs().pow(2).
    spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1)  # (m, 257)
    spectrum = spectrum.pow(2.0)

    # (80, 256) filterbank; append a zero column for the Nyquist bin -> (80, 257).
    mel_banks = get_mel_banks_onnx(device, dtype)
    mel_banks = torch.nn.functional.pad(mel_banks, (0, 1), mode="constant", value=0)
    mel_energies = torch.mm(spectrum, mel_banks.T)  # (m, 80)

    # use_log_fbank=True: floor at machine epsilon before the log.
    mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
    # subtract_mean=False: pass-through.
    return _subtract_column_mean_onnx(mel_energies)
# Self-test: compare the reference ``fbank`` with ``fbank_onnx`` stage by stage.
# Exits with status 1 if any stage raises or the final features disagree.
if __name__ == "__main__":
    print("Testing fbank vs fbank_onnx with traced parameters...")
    # Create test waveform
    torch.manual_seed(42)
    sample_rate = 16000
    duration = 1.0  # 1 second
    num_samples = int(sample_rate * duration)
    # Create a test waveform (sine wave + noise)
    t = torch.linspace(0, duration, num_samples)
    frequency = 440.0  # A4 note
    waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples)
    # Mono input with an explicit channel axis; a stereo check follows the full comparison.
    mono_waveform = waveform.unsqueeze(0)  # Shape: (1, num_samples)
    print(f"Test waveform shape: {mono_waveform.shape}")
    tolerance = 1e-5
    all_passed = True
    # Test parameters from trace: num_mel_bins=80, sample_frequency=16000, dither=0
    try:
        print("\n=== DEBUGGING: Step-by-step comparison ===")
        # Step 1: Check waveform processing
        orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties(
            mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97
        )
        onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform)
        print(f"Original waveform shape: {orig_waveform.shape}")
        print(f"ONNX waveform shape: {onnx_waveform.shape}")
        print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}")
        print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}")
        print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}")
        # Step 2: Check windowing
        orig_strided, orig_energy = _get_window(
            orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift,
            'povey', 0.42, True, True, 1.0, 0, True, 0.97
        )
        onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform)
        print(f"\nOriginal strided shape: {orig_strided.shape}")
        print(f"ONNX strided shape: {onnx_strided.shape}")
        print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}")
        print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}")
        # Step 3: Check mel banks
        orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype)
        # The reference may return (bins, center_freqs), as torchaudio's does;
        # only the bins are comparable with get_mel_banks_onnx.
        if isinstance(orig_mel_banks, tuple):
            orig_mel_banks = orig_mel_banks[0]
        onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype)
        print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}")
        print(f"ONNX mel banks shape: {onnx_mel_banks.shape}")
        print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}")
        # Step 4: Full comparison
        print("\n=== FULL COMPARISON ===")
        # Original fbank
        original_result = fbank(
            mono_waveform,
            num_mel_bins=80,
            sample_frequency=16000,
            dither=0
        )
        # ONNX-compatible fbank
        onnx_result = fbank_onnx(mono_waveform)
        print(f"Original fbank output shape: {original_result.shape}")
        print(f"ONNX fbank output shape: {onnx_result.shape}")
        if original_result.shape == onnx_result.shape:
            print("✅ Output shapes match")
            # Check numerical differences
            diff = torch.abs(original_result - onnx_result)
            max_diff = torch.max(diff).item()
            mean_diff = torch.mean(diff).item()
            relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item()
            print(f"Max absolute difference: {max_diff:.2e}")
            print(f"Mean absolute difference: {mean_diff:.2e}")
            print(f"Mean relative difference: {relative_diff:.2e}")
            # Find where the max difference occurs
            max_idx = torch.argmax(diff)
            max_coords = torch.unravel_index(max_idx, diff.shape)
            print(f"Max difference at coordinates: {max_coords}")
            print(f"  Original value: {original_result[max_coords].item():.6f}")
            print(f"  ONNX value: {onnx_result[max_coords].item():.6f}")
            # Check if results are numerically close
            if max_diff < tolerance:
                print(f"✅ Results are numerically identical (within {tolerance})")
            else:
                print(f"❌ Results {max_diff} differ by more than {tolerance}")
                all_passed = False
        else:
            print("❌ Output shapes don't match")
            print(f"   Original: {original_result.shape}")
            print(f"   ONNX: {onnx_result.shape}")
            all_passed = False
        # Additional statistics
        print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]")
        print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]")
        # Step 5: Multi-channel input; both paths should keep channel 0 (channel=-1 -> 0).
        stereo_waveform = torch.stack((waveform, 0.5 * waveform))  # Shape: (2, num_samples)
        stereo_original = fbank(stereo_waveform, num_mel_bins=80, sample_frequency=16000, dither=0)
        stereo_onnx = fbank_onnx(stereo_waveform)
        if stereo_original.shape == stereo_onnx.shape:
            stereo_diff = torch.max(torch.abs(stereo_original - stereo_onnx)).item()
            print(f"\nStereo max absolute difference: {stereo_diff:.2e}")
            if stereo_diff >= tolerance:
                print(f"❌ Stereo results differ by more than {tolerance}")
                all_passed = False
        else:
            print(f"❌ Stereo output shapes don't match: {stereo_original.shape} vs {stereo_onnx.shape}")
            all_passed = False
    except Exception as e:
        print(f"❌ Error during testing: {e}")
        import traceback
        traceback.print_exc()
        all_passed = False
    # Non-zero exit status so shell / CI callers can detect a failed comparison.
    if not all_passed:
        raise SystemExit(1)