remove context manager and fix path

Sucial 2025-02-19 17:56:51 +08:00
parent c65aa507c2
commit 8714087fde
2 changed files with 5 additions and 63 deletions

View File

@@ -1,19 +1,8 @@
-from functools import wraps
 from packaging import version
-from collections import namedtuple
-import os
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
-from einops import rearrange, reduce
-# constants
-FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
 # helpers
 def exists(val):
     return val is not None
@@ -21,21 +10,6 @@ def exists(val):
 def default(v, d):
     return v if exists(v) else d
-def once(fn):
-    called = False
-    @wraps(fn)
-    def inner(x):
-        nonlocal called
-        if called:
-            return
-        called = True
-        return fn(x)
-    return inner
-print_once = once(print)
 # main class
 class Attend(nn.Module):
     def __init__(
         self,
@@ -51,48 +25,16 @@ class Attend(nn.Module):
         self.flash = flash
         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
-        # determine efficient attention configs for cuda and cpu
-        self.cpu_config = FlashAttentionConfig(True, True, True)
-        self.cuda_config = None
-        if not torch.cuda.is_available() or not flash:
-            return
-        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
-        device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
-        if device_version >= version.parse('8.0'):
-            if os.name == 'nt':
-                print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
-                self.cuda_config = FlashAttentionConfig(False, True, True)
-            else:
-                print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
-                self.cuda_config = FlashAttentionConfig(True, False, False)
-        else:
-            print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
-            self.cuda_config = FlashAttentionConfig(False, True, True)
 
     def flash_attn(self, q, k, v):
-        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+        # _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
         if exists(self.scale):
             default_scale = q.shape[-1] ** -0.5
             q = q * (self.scale / default_scale)
-        # Check if there is a compatible device for flash attention
-        config = self.cuda_config if is_cuda else self.cpu_config
-        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
-            out = F.scaled_dot_product_attention(
-                q, k, v,
-                dropout_p = self.dropout if self.training else 0.
-            )
-        return out
+        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+        return F.scaled_dot_product_attention(q, k, v, dropout_p = self.dropout if self.training else 0.)
 
     def forward(self, q, k, v):
         """
@@ -103,7 +45,7 @@ class Attend(nn.Module):
         d - feature dimension
         """
 
-        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+        # q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
 
         scale = default(self.scale, q.shape[-1] ** -0.5)
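
Note on this change: the removed torch.backends.cuda.sdp_kernel context manager is deprecated in recent PyTorch releases, and calling F.scaled_dot_product_attention directly lets PyTorch choose the fastest available backend on its own, which is what the slimmed-down flash_attn above now relies on. A minimal sketch of both styles, assuming PyTorch 2.3+ for the newer torch.nn.attention.sdpa_kernel API (the tensor shapes here are illustrative, not from this repo):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend  # needs PyTorch >= 2.3

    q = k = v = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)

    # What the simplified flash_attn does: let PyTorch pick the backend.
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

    # Non-deprecated way to pin a backend, replacing torch.backends.cuda.sdp_kernel.
    with sdpa_kernel(SDPBackend.MATH):
        out_math = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)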

View File

@@ -250,7 +250,7 @@ class Roformer_Loader:
             sf.write(path, data, sr)
         else:
             sf.write(path, data, sr)
-            os.system("ffmpeg -i '{}' -vn '{}' -q:a 2 -y".format(path, path[:-3] + format))
+            os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format))
             try: os.remove(path)
             except: pass
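
Note on the "fix path" half of the commit: cmd.exe on Windows does not treat single quotes as quoting characters, so the old single-quoted ffmpeg command broke on paths containing spaces; double quotes work in both Windows and POSIX shells. A sketch of a still more robust alternative (not what this commit does) that passes ffmpeg an argument list through subprocess, avoiding shell quoting entirely; path and format mirror the names in the snippet above, and the helper name convert_to_format is hypothetical:

    import subprocess

    def convert_to_format(path, format):
        out_path = path[:-3] + format
        # An argument list bypasses the shell, so no quoting is needed at all;
        # output options such as -q:a must come before the output file.
        subprocess.run(["ffmpeg", "-y", "-i", path, "-vn", "-q:a", "2", out_path], check=True)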