import math from typing import Tuple import torch import torchaudio from torch import Tensor __all__ = [ "get_mel_banks", "inverse_mel_scale", "inverse_mel_scale_scalar", "mel_scale", "mel_scale_scalar", "spectrogram", "fbank", "fbank_onnx" "mfcc", "vtln_warp_freq", "vtln_warp_mel_freq", ] # numeric_limits::epsilon() 1.1920928955078125e-07 EPSILON = torch.tensor(torch.finfo(torch.float).eps) # 1 milliseconds = 0.001 seconds MILLISECONDS_TO_SECONDS = 0.001 # window types HAMMING = "hamming" HANNING = "hanning" POVEY = "povey" RECTANGULAR = "rectangular" BLACKMAN = "blackman" WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] def _get_epsilon(device, dtype): return EPSILON.to(device=device, dtype=dtype) def _next_power_of_2(x: int) -> int: r"""Returns the smallest power of 2 that is greater than x""" return 1 if x == 0 else 2 ** (x - 1).bit_length() def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor: r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) representing how the window is shifted along the waveform. Each row is a frame. Args: waveform (Tensor): Tensor of size ``num_samples`` window_size (int): Frame length window_shift (int): Frame shift snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame_length. If False, the number of frames depends only on the frame_shift, and we reflect the data at the ends. Returns: Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame """ assert waveform.dim() == 1 num_samples = waveform.size(0) strides = (window_shift * waveform.stride(0), waveform.stride(0)) if snip_edges: if num_samples < window_size: return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) else: m = 1 + (num_samples - window_size) // window_shift else: reversed_waveform = torch.flip(waveform, [0]) m = (num_samples + (window_shift // 2)) // window_shift pad = window_size // 2 - window_shift // 2 pad_right = reversed_waveform if pad > 0: # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect' # but we want [2, 1, 0, 0, 1, 2] pad_left = reversed_waveform[-pad:] waveform = torch.cat((pad_left, waveform, pad_right), dim=0) else: # pad is negative so we want to trim the waveform at the front waveform = torch.cat((waveform[-pad:], pad_right), dim=0) sizes = (m, window_size) return waveform.as_strided(sizes, strides) def _feature_window_function( window_type: str, window_size: int, blackman_coeff: float, device: torch.device, dtype: int, ) -> Tensor: r"""Returns a window function with the given type and size""" if window_type == HANNING: return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) elif window_type == HAMMING: return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) elif window_type == POVEY: # like hanning but goes to zero at edges return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) elif window_type == RECTANGULAR: return torch.ones(window_size, device=device, dtype=dtype) elif window_type == BLACKMAN: a = 2 * math.pi / (window_size - 1) window_function = torch.arange(window_size, device=device, dtype=dtype) # can't use torch.blackman_window as they use different coefficients return ( blackman_coeff - 0.5 * torch.cos(a * window_function) + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function) ).to(device=device, dtype=dtype) else: raise Exception("Invalid window type " + window_type) def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor: r"""Returns the log energy of size (m) for a strided_input (m,*)""" device, dtype = strided_input.device, strided_input.dtype log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) if energy_floor == 0.0: return log_energy return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) def _get_waveform_and_window_properties( waveform: Tensor, channel: int, sample_frequency: float, frame_shift: float, frame_length: float, round_to_power_of_two: bool, preemphasis_coefficient: float, ) -> Tuple[Tensor, int, int, int]: r"""Gets the waveform and window properties""" channel = max(channel, 0) assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0)) waveform = waveform[channel, :] # size (n) window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format( window_size, len(waveform) ) assert 0 < window_shift, "`window_shift` must be greater than 0" assert padded_window_size % 2 == 0, ( "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`" ) assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]" assert sample_frequency > 0, "`sample_frequency` must be greater than zero" return waveform, window_shift, window_size, padded_window_size def _get_window( waveform: Tensor, padded_window_size: int, window_size: int, window_shift: int, window_type: str, blackman_coeff: float, snip_edges: bool, raw_energy: bool, energy_floor: float, dither: float, remove_dc_offset: bool, preemphasis_coefficient: float, ) -> Tuple[Tensor, Tensor]: r"""Gets a window and its log energy Returns: (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) """ device, dtype = waveform.device, waveform.dtype epsilon = _get_epsilon(device, dtype) # size (m, window_size) strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) if dither != 0.0: rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype) strided_input = strided_input + rand_gauss * dither if remove_dc_offset: # Subtract each row/frame by its mean row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) strided_input = strided_input - row_means if raw_energy: # Compute the log energy of each row/frame before applying preemphasis and # window function signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) if preemphasis_coefficient != 0.0: # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze( 0 ) # size (m, window_size + 1) strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] # Apply window_function to each row/frame window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze( 0 ) # size (1, window_size) strided_input = strided_input * window_function # size (m, window_size) # Pad columns with zero until we reach size (m, padded_window_size) if padded_window_size != window_size: padding_right = padded_window_size - window_size strided_input = torch.nn.functional.pad( strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0 ).squeeze(0) # Compute energy after window function (not the raw one) if not raw_energy: signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) return strided_input, signal_log_energy def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: # subtracts the column mean of the tensor size (m, n) if subtract_mean=True # it returns size (m, n) if subtract_mean: col_means = torch.mean(tensor, dim=0).unsqueeze(0) tensor = tensor - col_means return tensor def spectrogram( waveform: Tensor, blackman_coeff: float = 0.42, channel: int = -1, dither: float = 0.0, energy_floor: float = 1.0, frame_length: float = 25.0, frame_shift: float = 10.0, min_duration: float = 0.0, preemphasis_coefficient: float = 0.97, raw_energy: bool = True, remove_dc_offset: bool = True, round_to_power_of_two: bool = True, sample_frequency: float = 16000.0, snip_edges: bool = True, subtract_mean: bool = False, window_type: str = POVEY, ) -> Tensor: r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's compute-spectrogram-feats. Args: waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: this floor is applied to the zeroth component, representing the total signal energy. The floor on the individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input to FFT. (Default: ``True``) sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if specified there) (Default: ``16000.0``) snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame_length. If False, the number of frames depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do it this way. (Default: ``False``) window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: ``'povey'``) Returns: Tensor: A spectrogram identical to what Kaldi would output. The shape is (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided """ device, dtype = waveform.device, waveform.dtype epsilon = _get_epsilon(device, dtype) waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient ) if len(waveform) < min_duration * sample_frequency: # signal is too short return torch.empty(0) strided_input, signal_log_energy = _get_window( waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient, ) # size (m, padded_window_size // 2 + 1, 2) fft = torch.fft.rfft(strided_input) # Convert the FFT into a power spectrum power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1) power_spectrum[:, 0] = signal_log_energy power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) return power_spectrum def inverse_mel_scale_scalar(mel_freq: float) -> float: return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) def inverse_mel_scale(mel_freq: Tensor) -> Tensor: return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) def mel_scale_scalar(freq: float) -> float: return 1127.0 * math.log(1.0 + freq / 700.0) def mel_scale(freq: Tensor) -> Tensor: return 1127.0 * (1.0 + freq / 700.0).log() def vtln_warp_freq( vtln_low_cutoff: float, vtln_high_cutoff: float, low_freq: float, high_freq: float, vtln_warp_factor: float, freq: Tensor, ) -> Tensor: r"""This computes a VTLN warping function that is not the same as HTK's one, but has similar inputs (this function has the advantage of never producing empty bins). This function computes a warp function F(freq), defined between low_freq and high_freq inclusive, with the following properties: F(low_freq) == low_freq F(high_freq) == high_freq The function is continuous and piecewise linear with two inflection points. The lower inflection point (measured in terms of the unwarped frequency) is at frequency l, determined as described below. The higher inflection point is at a frequency h, determined as described below. If l <= f <= h, then F(f) = f/vtln_warp_factor. If the higher inflection point (measured in terms of the unwarped frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. Since (by the last point) F(h) == h/vtln_warp_factor, then max(h, h/vtln_warp_factor) == vtln_high_cutoff, so h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). = vtln_high_cutoff * min(1, vtln_warp_factor). If the lower inflection point (measured in terms of the unwarped frequency) is at l, then min(l, F(l)) == vtln_low_cutoff This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) = vtln_low_cutoff * max(1, vtln_warp_factor) Args: vtln_low_cutoff (float): Lower frequency cutoffs for VTLN vtln_high_cutoff (float): Upper frequency cutoffs for VTLN low_freq (float): Lower frequency cutoffs in mel computation high_freq (float): Upper frequency cutoffs in mel computation vtln_warp_factor (float): Vtln warp factor freq (Tensor): given frequency in Hz Returns: Tensor: Freq after vtln warp """ assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq" assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]" l = vtln_low_cutoff * max(1.0, vtln_warp_factor) h = vtln_high_cutoff * min(1.0, vtln_warp_factor) scale = 1.0 / vtln_warp_factor Fl = scale * l # F(l) Fh = scale * h # F(h) assert l > low_freq and h < high_freq # slope of left part of the 3-piece linear function scale_left = (Fl - low_freq) / (l - low_freq) # [slope of center part is just "scale"] # slope of right part of the 3-piece linear function scale_right = (high_freq - Fh) / (high_freq - h) res = torch.empty_like(freq) outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq before_l = torch.lt(freq, l) # freq < l before_h = torch.lt(freq, h) # freq < h after_h = torch.ge(freq, h) # freq >= h # order of operations matter here (since there is overlapping frequency regions) res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) res[before_h] = scale * freq[before_h] res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) res[outside_low_high_freq] = freq[outside_low_high_freq] return res def vtln_warp_mel_freq( vtln_low_cutoff: float, vtln_high_cutoff: float, low_freq, high_freq: float, vtln_warp_factor: float, mel_freq: Tensor, ) -> Tensor: r""" Args: vtln_low_cutoff (float): Lower frequency cutoffs for VTLN vtln_high_cutoff (float): Upper frequency cutoffs for VTLN low_freq (float): Lower frequency cutoffs in mel computation high_freq (float): Upper frequency cutoffs in mel computation vtln_warp_factor (float): Vtln warp factor mel_freq (Tensor): Given frequency in Mel Returns: Tensor: ``mel_freq`` after vtln warp """ return mel_scale( vtln_warp_freq( vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq) ) ) def get_mel_banks( num_bins: int, window_length_padded: int, sample_freq: float, low_freq: float, high_freq: float, vtln_low: float, vtln_high: float, vtln_warp_factor: float, device=None, dtype=None, ) -> Tuple[Tensor, Tensor]: """ Returns: (Tensor, Tensor): The tuple consists of ``bins`` (which is melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is center frequencies of bins of size (``num_bins``)). """ assert num_bins > 3, "Must have at least 3 mel bins" assert window_length_padded % 2 == 0 num_fft_bins = window_length_padded / 2 nyquist = 0.5 * sample_freq if high_freq <= 0.0: high_freq += nyquist assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), ( "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist) ) # fft-bin width [think of it as Nyquist-freq / half-window-length] fft_bin_width = sample_freq / window_length_padded mel_low_freq = mel_scale_scalar(low_freq) mel_high_freq = mel_scale_scalar(high_freq) # divide by num_bins+1 in next line because of end-effects where the bins # spread out to the sides. mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) if vtln_high < 0.0: vtln_high += nyquist assert vtln_warp_factor == 1.0 or ( (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high) ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format( vtln_low, vtln_high, low_freq, high_freq ) bin = torch.arange(num_bins).unsqueeze(1) left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1) right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) if vtln_warp_factor != 1.0: left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel) center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel) right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel) # center_freqs = inverse_mel_scale(center_mel) # size (num_bins) # size(1, num_fft_bins) mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0) # size (num_bins, num_fft_bins) up_slope = (mel - left_mel) / (center_mel - left_mel) down_slope = (right_mel - mel) / (right_mel - center_mel) if vtln_warp_factor == 1.0: # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) else: # warping can move the order of left_mel, center_mel, right_mel anywhere bins = torch.zeros_like(up_slope) up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel bins[up_idx] = up_slope[up_idx] bins[down_idx] = down_slope[down_idx] return bins.to(device=device, dtype=dtype) # , center_freqs cache = {} def fbank( waveform: Tensor, blackman_coeff: float = 0.42, channel: int = -1, dither: float = 0.0, energy_floor: float = 1.0, frame_length: float = 25.0, frame_shift: float = 10.0, high_freq: float = 0.0, htk_compat: bool = False, low_freq: float = 20.0, min_duration: float = 0.0, num_mel_bins: int = 23, preemphasis_coefficient: float = 0.97, raw_energy: bool = True, remove_dc_offset: bool = True, round_to_power_of_two: bool = True, sample_frequency: float = 16000.0, snip_edges: bool = True, subtract_mean: bool = False, use_energy: bool = False, use_log_fbank: bool = True, use_power: bool = True, vtln_high: float = -500.0, vtln_low: float = 100.0, vtln_warp: float = 1.0, window_type: str = POVEY, ) -> Tensor: r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's compute-fbank-feats. Args: waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: this floor is applied to the zeroth component, representing the total signal energy. The floor on the individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (Default: ``0.0``) htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features (need to change other parameters). (Default: ``False``) low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input to FFT. (Default: ``True``) sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if specified there) (Default: ``16000.0``) snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame_length. If False, the number of frames depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do it this way. (Default: ``False``) use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``) use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``) vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if negative, offset from high-mel-freq (Default: ``-500.0``) vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: ``'povey'``) Returns: Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``) where m is calculated in _get_strided """ device, dtype = waveform.device, waveform.dtype waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient ) if len(waveform) < min_duration * sample_frequency: # signal is too short return torch.empty(0, device=device, dtype=dtype) # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) strided_input, signal_log_energy = _get_window( waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient, ) # size (m, padded_window_size // 2 + 1) spectrum = torch.fft.rfft(strided_input).abs() if use_power: spectrum = spectrum.pow(2.0) # size (num_mel_bins, padded_window_size // 2) # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp) cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % ( num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype, ) if cache_key not in cache: mel_energies = get_mel_banks( num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype, ) cache[cache_key] = mel_energies else: mel_energies = cache[cache_key] # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) mel_energies = torch.mm(spectrum, mel_energies.T) if use_log_fbank: # avoid log of zero (which should be prevented anyway by dithering) mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() # if use_energy then add it as the last column for htk_compat == true else first column if use_energy: signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) # returns size (m, num_mel_bins + 1) if htk_compat: mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1) else: mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) mel_energies = _subtract_column_mean(mel_energies, subtract_mean) return mel_energies def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: # returns a dct matrix of size (num_mel_bins, num_ceps) # size (num_mel_bins, num_mel_bins) dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho") # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) # this would be the first column in the dct_matrix for torchaudio as it expects a # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi # expects a left multiply e.g. dct_matrix * vector). dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) dct_matrix = dct_matrix[:, :num_ceps] return dct_matrix def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: # returns size (num_ceps) # Compute liftering coefficients (scaling on cepstral coeffs) # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. i = torch.arange(num_ceps) return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter) def mfcc( waveform: Tensor, blackman_coeff: float = 0.42, cepstral_lifter: float = 22.0, channel: int = -1, dither: float = 0.0, energy_floor: float = 1.0, frame_length: float = 25.0, frame_shift: float = 10.0, high_freq: float = 0.0, htk_compat: bool = False, low_freq: float = 20.0, num_ceps: int = 13, min_duration: float = 0.0, num_mel_bins: int = 23, preemphasis_coefficient: float = 0.97, raw_energy: bool = True, remove_dc_offset: bool = True, round_to_power_of_two: bool = True, sample_frequency: float = 16000.0, snip_edges: bool = True, subtract_mean: bool = False, use_energy: bool = False, vtln_high: float = -500.0, vtln_low: float = 100.0, vtln_warp: float = 1.0, window_type: str = POVEY, ) -> Tensor: r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's compute-mfcc-feats. Args: waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: this floor is applied to the zeroth component, representing the total signal energy. The floor on the individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (Default: ``0.0``) htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features (need to change other parameters). (Default: ``False``) low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input to FFT. (Default: ``True``) sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if specified there) (Default: ``16000.0``) snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame_length. If False, the number of frames depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do it this way. (Default: ``False``) use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if negative, offset from high-mel-freq (Default: ``-500.0``) vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: ``"povey"``) Returns: Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``) where m is calculated in _get_strided """ assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins) device, dtype = waveform.device, waveform.dtype # The mel_energies should not be squared (use_power=True), not have mean subtracted # (subtract_mean=False), and use log (use_log_fbank=True). # size (m, num_mel_bins + use_energy) feature = fbank( waveform=waveform, blackman_coeff=blackman_coeff, channel=channel, dither=dither, energy_floor=energy_floor, frame_length=frame_length, frame_shift=frame_shift, high_freq=high_freq, htk_compat=htk_compat, low_freq=low_freq, min_duration=min_duration, num_mel_bins=num_mel_bins, preemphasis_coefficient=preemphasis_coefficient, raw_energy=raw_energy, remove_dc_offset=remove_dc_offset, round_to_power_of_two=round_to_power_of_two, sample_frequency=sample_frequency, snip_edges=snip_edges, subtract_mean=False, use_energy=use_energy, use_log_fbank=True, use_power=True, vtln_high=vtln_high, vtln_low=vtln_low, vtln_warp=vtln_warp, window_type=window_type, ) if use_energy: # size (m) signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] # offset is 0 if htk_compat==True else 1 mel_offset = int(not htk_compat) feature = feature[:, mel_offset : (num_mel_bins + mel_offset)] # size (num_mel_bins, num_ceps) dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) # size (m, num_ceps) feature = feature.matmul(dct_matrix) if cepstral_lifter != 0.0: # size (1, num_ceps) lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) feature *= lifter_coeffs.to(device=device, dtype=dtype) # if use_energy then replace the last column for htk_compat == true else first column if use_energy: feature[:, 0] = signal_log_energy if htk_compat: energy = feature[:, 0].unsqueeze(1) # size (m, 1) feature = feature[:, 1:] # size (m, num_ceps - 1) if not use_energy: # scale on C0 (actually removing a scale we previously added that's # part of one common definition of the cosine transform.) energy *= math.sqrt(2) feature = torch.cat((feature, energy), dim=1) feature = _subtract_column_mean(feature, subtract_mean) return feature def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor: r"""Returns the log energy of size (m) for a strided_input (m,*)""" device, dtype = strided_input.device, strided_input.dtype log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype)) def _get_waveform_and_window_properties_onnx( waveform: Tensor, ) -> Tuple[Tensor, int, int, int]: r"""ONNX-compatible version with hardcoded parameters from traced fbank call: channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0, round_to_power_of_two=True, preemphasis_coefficient=0.97""" # Hardcoded values from traced parameters # channel=-1 -> 0 after max(channel, 0) channel = 0 # Extract channel 0 from waveform if waveform.dim() == 1: # Mono waveform, use as-is waveform_selected = waveform else: # Multi-channel, select first channel waveform_selected = waveform[channel, :] # Hardcoded calculations: # window_shift = int(16000 * 10.0 * 0.001) = 160 # window_size = int(16000 * 25.0 * 0.001) = 400 # padded_window_size = _next_power_of_2(400) = 512 window_shift = 160 window_size = 400 padded_window_size = 512 return waveform_selected, window_shift, window_size, padded_window_size def _get_window_onnx( waveform: Tensor, ) -> Tuple[Tensor, Tensor]: r"""ONNX-compatible version with hardcoded parameters from traced fbank call: padded_window_size=512, window_size=400, window_shift=160, window_type='povey', blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0, dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97 Returns: (Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m) """ device, dtype = waveform.device, waveform.dtype epsilon = _get_epsilon(device, dtype) # Hardcoded values from traced parameters window_size = 400 window_shift = 160 padded_window_size = 512 snip_edges = True # size (m, window_size) strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges) # dither=0, so skip dithering (lines 209-211 from original) # remove_dc_offset=True, so execute this branch row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) strided_input = strided_input - row_means # raw_energy=True, so execute this branch signal_log_energy = _get_log_energy_onnx(strided_input, epsilon) # energy_floor=1.0 # preemphasis_coefficient=0.97 != 0.0, so execute this branch offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0) strided_input = strided_input - 0.97 * offset_strided_input[:, :-1] # Apply povey window function to each row/frame # povey window: torch.hann_window(window_size, periodic=False).pow(0.85) window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0) strided_input = strided_input * window_function # Pad columns from window_size=400 to padded_window_size=512 padding_right = padded_window_size - window_size # 112 strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0) # raw_energy=True, so skip the "not raw_energy" branch (lines 244-245) return strided_input, signal_log_energy def _get_strided_onnx(waveform: Tensor, window_size = 400, window_shift = 160, snip_edges = 512) -> Tensor: seq_len = waveform.size(0) # Calculate number of windows num_windows = 1 + (seq_len - window_size) // window_shift # Create indices for all windows at once window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device) window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0) # Extract windows using advanced indexing windows = waveform[window_indices] # [num_windows, window_size] return windows def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor: """ONNX-compatible version with hardcoded parameters from traced fbank call: subtract_mean=False, so this function returns the input tensor unchanged. Args: tensor: Input tensor of size (m, n) Returns: Tensor: Same as input tensor (m, n) since subtract_mean=False """ # subtract_mean=False from traced parameters, so return tensor as-is return tensor def get_mel_banks_onnx( device=None, dtype=None, ) -> Tensor: """ONNX-compatible version with hardcoded parameters from traced fbank call: num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0, high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0 Returns: Tensor: melbank of size (80, 256) (num_bins, num_fft_bins) """ # Hardcoded values from traced parameters num_bins = 80 window_length_padded = 512 sample_freq = 16000.0 low_freq = 20.0 high_freq = 0.0 # Will be adjusted to nyquist vtln_warp_factor = 1.0 # Calculate dynamic values to ensure accuracy num_fft_bins = window_length_padded // 2 # 256 (integer division) nyquist = 0.5 * sample_freq # 8000.0 # high_freq <= 0.0, so high_freq += nyquist if high_freq <= 0.0: high_freq += nyquist # 8000.0 # fft-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25 fft_bin_width = sample_freq / window_length_padded # Calculate mel scale values dynamically mel_low_freq = mel_scale_scalar(low_freq) mel_high_freq = mel_scale_scalar(high_freq) # mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) # vtln_warp_factor == 1.0, so no VTLN warping needed # Create mel bin centers bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1) left_mel = mel_low_freq + bin_indices * mel_freq_delta center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta # No VTLN warping since vtln_warp_factor == 1.0 # Create frequency bins for FFT fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype) mel = mel_scale(fft_freqs).unsqueeze(0) # size(1, num_fft_bins) # Calculate triangular filter banks up_slope = (mel - left_mel) / (center_mel - left_mel) down_slope = (right_mel - mel) / (right_mel - center_mel) # Since vtln_warp_factor == 1.0, use the simpler branch bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope)) return bins def fbank_onnx( waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0 ) -> Tensor: r"""ONNX-compatible fbank function with hardcoded parameters from traced call: num_mel_bins=80, sample_frequency=16000, dither=0 All other parameters use their traced default values. Args: waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) Returns: Tensor: A fbank identical to what Kaldi would output. The shape is (m, 80) where m is calculated in _get_strided """ device, dtype = waveform.device, waveform.dtype # Use ONNX-compatible version of _get_waveform_and_window_properties waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform) # min_duration=0.0, so skip the duration check (signal will never be too short) # Use ONNX-compatible version of _get_window strided_input, signal_log_energy = _get_window_onnx(waveform_selected) # spectrum = torch.fft.rfft(strided_input).abs() m, frame_size = strided_input.shape # Process all frames at once using batch processing # Reshape to (m, 1, frame_size) to treat each frame as a separate batch item batched_frames = strided_input.unsqueeze(1) # Shape: (m, 1, 512) # Create rectangular window for all frames at once rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype) # Apply STFT to all frames simultaneously # The batch dimension allows us to process all m frames in parallel stft_result = torch.stft( batched_frames.flatten(0, 1), # Shape: (m, 512) - flatten batch and channel dims n_fft=512, hop_length=512, # Process entire frame at once window=rectangular_window, center=False, # Don't add padding return_complex=False ) # stft_result shape: (m, 257, 1, 2) where last dim is [real, imag] # Calculate magnitude: sqrt(real^2 + imag^2) real_part = stft_result[..., 0] # Shape: (m, 257, 1) imag_part = stft_result[..., 1] # Shape: (m, 257, 1) spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1) # Shape: (m, 257) # use_power=True, so execute this branch spectrum = spectrum.pow(2.0) # Get mel filterbanks using ONNX-compatible version mel_energies = get_mel_banks_onnx(device, dtype) # pad right column with zeros to match FFT output size (80, 256) -> (80, 257) mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) # sum with mel filterbanks over the power spectrum, size (m, 80) mel_energies = torch.mm(spectrum, mel_energies.T) # use_log_fbank=True, so execute this branch mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() # use_energy=False, so skip the energy addition (lines 828-834) # Use ONNX-compatible version of _subtract_column_mean mel_energies = _subtract_column_mean_onnx(mel_energies) return mel_energies # Test to compare original fbank vs fbank_onnx if __name__ == "__main__": import torch print("Testing fbank vs fbank_onnx with traced parameters...") # Create test waveform torch.manual_seed(42) sample_rate = 16000 duration = 1.0 # 1 second num_samples = int(sample_rate * duration) # Create a test waveform (sine wave + noise) t = torch.linspace(0, duration, num_samples) frequency = 440.0 # A4 note waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples) # Test with both mono and stereo inputs mono_waveform = waveform.unsqueeze(0) # Shape: (1, num_samples) print(f"Test waveform shape: {mono_waveform.shape}") # Test parameters from trace: num_mel_bins=80, sample_frequency=16000, dither=0 try: print("\n=== DEBUGGING: Step-by-step comparison ===") # Step 1: Check waveform processing orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties( mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97 ) onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform) print(f"Original waveform shape: {orig_waveform.shape}") print(f"ONNX waveform shape: {onnx_waveform.shape}") print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}") print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}") print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}") # Step 2: Check windowing orig_strided, orig_energy = _get_window( orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift, 'povey', 0.42, True, True, 1.0, 0, True, 0.97 ) onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform) print(f"\nOriginal strided shape: {orig_strided.shape}") print(f"ONNX strided shape: {onnx_strided.shape}") print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}") print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}") # Step 3: Check mel banks orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype) onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype) print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}") print(f"ONNX mel banks shape: {onnx_mel_banks.shape}") print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}") # Step 4: Full comparison print("\n=== FULL COMPARISON ===") # Original fbank original_result = fbank( mono_waveform, num_mel_bins=80, sample_frequency=16000, dither=0 ) # ONNX-compatible fbank onnx_result = fbank_onnx(mono_waveform) print(f"Original fbank output shape: {original_result.shape}") print(f"ONNX fbank output shape: {onnx_result.shape}") # Check if shapes match if original_result.shape == onnx_result.shape: print("✅ Output shapes match") else: print("❌ Output shapes don't match") print(f" Original: {original_result.shape}") print(f" ONNX: {onnx_result.shape}") # Check numerical differences if original_result.shape == onnx_result.shape: diff = torch.abs(original_result - onnx_result) max_diff = torch.max(diff).item() mean_diff = torch.mean(diff).item() relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item() print(f"Max absolute difference: {max_diff:.2e}") print(f"Mean absolute difference: {mean_diff:.2e}") print(f"Mean relative difference: {relative_diff:.2e}") # Find where the max difference occurs max_idx = torch.argmax(diff) max_coords = torch.unravel_index(max_idx, diff.shape) print(f"Max difference at coordinates: {max_coords}") print(f" Original value: {original_result[max_coords].item():.6f}") print(f" ONNX value: {onnx_result[max_coords].item():.6f}") # Check if results are numerically close tolerance = 1e-5 if max_diff < tolerance: print(f"✅ Results are numerically identical (within {tolerance})") else: print(f"❌ Results {max_diff} differ by more than {tolerance}") # Additional statistics print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]") print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]") except Exception as e: print(f"❌ Error during testing: {e}") import traceback traceback.print_exc()