remove context manager and fix path

Sucial 2025-02-19 17:56:51 +08:00
parent c65aa507c2
commit 8714087fde
2 changed files with 5 additions and 63 deletions

View File

@@ -1,19 +1,8 @@
-from functools import wraps
 from packaging import version
-from collections import namedtuple
-import os
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
-from einops import rearrange, reduce
-# constants
-FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
-# helpers
 def exists(val):
     return val is not None
@@ -21,21 +10,6 @@ def exists(val):
 def default(v, d):
     return v if exists(v) else d
-def once(fn):
-    called = False
-    @wraps(fn)
-    def inner(x):
-        nonlocal called
-        if called:
-            return
-        called = True
-        return fn(x)
-    return inner
-print_once = once(print)
-# main class
 class Attend(nn.Module):
     def __init__(
         self,
@@ -51,48 +25,16 @@ class Attend(nn.Module):
         self.flash = flash
         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
-        # determine efficient attention configs for cuda and cpu
-        self.cpu_config = FlashAttentionConfig(True, True, True)
-        self.cuda_config = None
-        if not torch.cuda.is_available() or not flash:
-            return
-        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
-        device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
-        if device_version >= version.parse('8.0'):
-            if os.name == 'nt':
-                print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
-                self.cuda_config = FlashAttentionConfig(False, True, True)
-            else:
-                print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
-                self.cuda_config = FlashAttentionConfig(True, False, False)
-        else:
-            print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
-            self.cuda_config = FlashAttentionConfig(False, True, True)
     def flash_attn(self, q, k, v):
-        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+        # _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
         if exists(self.scale):
             default_scale = q.shape[-1] ** -0.5
             q = q * (self.scale / default_scale)
-        # Check if there is a compatible device for flash attention
-        config = self.cuda_config if is_cuda else self.cpu_config
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
-            out = F.scaled_dot_product_attention(
-                q, k, v,
-                dropout_p = self.dropout if self.training else 0.
-            )
-        return out
+        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+        return F.scaled_dot_product_attention(q, k, v, dropout_p = self.dropout if self.training else 0.)
     def forward(self, q, k, v):
         """
@@ -103,7 +45,7 @@ class Attend(nn.Module):
         d - feature dimension
         """
-        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+        # q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
         scale = default(self.scale, q.shape[-1] ** -0.5)
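
A minimal sketch (not part of the commit) of the call pattern the first file now uses: with the torch.backends.cuda.sdp_kernel context manager and the FlashAttentionConfig selection removed, F.scaled_dot_product_attention is called directly and PyTorch picks the attention backend on its own. Tensor shapes below are made up for illustration.

import torch
import torch.nn.functional as F

# (batch, heads, sequence length, head dimension) - arbitrary example shapes
q = torch.randn(2, 8, 256, 64)
k = torch.randn(2, 8, 256, 64)
v = torch.randn(2, 8, 256, 64)

# direct call, no sdp_kernel context manager: the backend (flash / math /
# memory-efficient) is chosen by PyTorch based on the device and inputs
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 8, 256, 64])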

View File

@@ -250,7 +250,7 @@ class Roformer_Loader:
             sf.write(path, data, sr)
         else:
             sf.write(path, data, sr)
-            os.system("ffmpeg -i '{}' -vn '{}' -q:a 2 -y".format(path, path[:-3] + format))
+            os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format))
             try: os.remove(path)
             except: pass
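
A small illustration (my reading of the "fix path" part, with a hypothetical file name) of why the second file switches the ffmpeg command from single to double quotes: a path containing an apostrophe would otherwise terminate a single-quoted shell argument.

# hypothetical example values; song_path and fmt are not from the repository
song_path = "/tmp/it's_a_mix.wav"
fmt = "mp3"

old_cmd = "ffmpeg -i '{}' -vn '{}' -q:a 2 -y".format(song_path, song_path[:-3] + fmt)
new_cmd = "ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(song_path, song_path[:-3] + fmt)

print(old_cmd)  # the apostrophe in the file name breaks the single-quoted argument
print(new_cmd)  # double quotes keep the whole path as one argument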