import sys
import io
import os
import re
import json
import tarfile
from functools import partial

import webdataset as wds
from webdataset import ResampledShards, DataPipeline, tarfile_to_samples
from webdataset.filters import pipelinefilter
from webdataset.tariterators import group_by_keys  # note: url_opener is redefined below to carry meta streams
from webdataset.handlers import reraise_exception
from webdataset.gopen import gopen_schemes, gopen


def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    try:
        import torch.distributed

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD
            rank = torch.distributed.get_rank(group=group)
            world_size = torch.distributed.get_world_size(group=group)
    except ModuleNotFoundError:
        pass
    try:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers
    except ModuleNotFoundError:
        pass

    return rank, world_size, worker, num_workers
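

# Hedged usage sketch (not part of the original API): outside any initialized
# torch.distributed process group and outside a DataLoader worker, every
# fallback above applies, so all fields come back as their defaults.
def _example_pytorch_worker_info():
    rank, world_size, worker, num_workers = pytorch_worker_info()
    print(rank, world_size, worker, num_workers)  # plain single-process run: 0 1 0 1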


def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node.

    Note: seeds are only distinct as long as each rank runs fewer than 1000
    dataloader workers, since the seed is ``rank * 1000 + worker``.
    """
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    return rank * 1000 + worker


def worker_seed_sat(group=None, seed=0):
    """Offset the per-worker seed by a dataset-level seed, so different datasets shuffle differently."""
    return pytorch_worker_seed(group=group) + seed * 23
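

# Hedged sketch of how the two seed helpers compose: on rank 2, worker 3, with
# dataset seed 7, the shuffle seed would be 2 * 1000 + 3 + 7 * 23 = 2164.
def _example_worker_seed():
    return worker_seed_sat(seed=7)  # 161 in a plain single-process run (rank 0, worker 0)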


class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0

        try:
            from megatron.core.parallel_state import get_data_parallel_group

            group = get_data_parallel_group()
            print_rank0("Using megatron data parallel group.")
        except Exception:  # fall back to sat if megatron is unavailable or uninitialized
            from sat.mpu import get_data_parallel_group

            try:
                group = get_data_parallel_group()
                print_rank0("Using sat data parallel group.")
            except AssertionError:
                group = None
                print_rank0("No data parallel group is specified!")
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)
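

# Hedged note: ConfiguredResampledShards is normally the first stage of a
# DataPipeline rather than used standalone; a brace-expandable shard pattern
# (a webdataset convention) is a typical `urls` argument, e.g.
#   ConfiguredResampledShards("data/shards-{000000..000099}.tar", seed=42)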


class SimpleDistributedWebDataset(DataPipeline):
    """Resampled webdataset pipeline: shards -> samples -> shuffle -> process_fn."""

    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # set shuffle_buffer = 1 to disable shuffling; with model parallelism,
        # shuffling would make data differ across model-parallel ranks
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed),  # many shards are recommended, or the split may be uneven
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
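

# Hedged usage sketch (the path and `decode` are illustrative, not part of the
# module): wiring SimpleDistributedWebDataset into a PyTorch DataLoader. Any
# generator that consumes the raw sample stream works as process_fn.
def _example_simple_dataset(path="data/shards-{000000..000009}.tar"):
    from torch.utils.data import DataLoader

    def decode(src):
        for sample in src:
            yield sample  # decode bytes to tensors here

    ds = SimpleDistributedWebDataset(path, decode, seed=42)
    return DataLoader(ds, batch_size=4, num_workers=2)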


def tar_file_iterator_with_meta(
    fileobj,
    meta_names,
    skip_meta=r"__[^/]*__($|/)",
    suffix=None,
    handler=reraise_exception,
    meta_stream=None,
):
    """Iterate over a tar stream, yielding a dict with ``fname`` and ``data`` for each member.

    :param fileobj: byte stream suitable for tarfile
    :param meta_names: keys of the items to pick up from the meta file
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    :param suffix: optional string appended to the content of every ``.txt`` member
    :param handler: exception handler (Default value = reraise_exception)
    :param meta_stream: optional already-open stream of the ``.meta.jsonl`` file;
        if None, it is looked up next to ``fileobj`` on the local filesystem
    """
    stream = tarfile.open(fileobj=fileobj, mode="r|*")
    data_dir, filename = fileobj.name.rsplit("/", 1)
    meta_data = {}  # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}

    if meta_stream is None:
        meta_file_name = filename.split(".")[0] + ".meta.jsonl"
        meta_path = os.path.join(data_dir, meta_file_name)
        if os.path.exists(meta_path):
            meta_stream = open(meta_path, "r")
    else:
        meta_file_name = meta_stream.name

    if meta_stream is not None:
        for lineno, line in enumerate(meta_stream):
            try:
                item = json.loads(line)
            except Exception:
                from sat.helpers import print_rank0

                print_rank0(
                    f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}",
                    level="DEBUG",
                )
                continue
            if item["key"] not in meta_data:
                meta_data[item["key"]] = {}
            for meta_name in meta_names:
                if meta_name in item:
                    meta_data[item["key"]][meta_name] = item[meta_name]
        meta_stream.close()

    try:
        for tarinfo in stream:
            fname = tarinfo.name
            try:
                if not tarinfo.isreg():
                    continue
                if fname is None:
                    continue
                if "/" not in fname and fname.startswith("__") and fname.endswith("__"):
                    # skipping metadata for now
                    continue
                if skip_meta is not None and re.match(skip_meta, fname):
                    continue
                if fname.endswith(".txt") and suffix is not None:
                    data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
                else:
                    data = stream.extractfile(tarinfo).read()
                result = dict(fname=fname, data=data)
                yield result

                if fname.endswith(".id"):
                    fid = fname.split(".")[0]
                    if "-$#%@&" in fid:
                        sfid = fid.split("-$#%@&")[0]
                    else:
                        sfid = fid
                    meta_data_fid = meta_data.get(sfid, {})
                    for meta_name in meta_names:
                        meta_fname = fid + "." + meta_name
                        meta = meta_data_fid.get(meta_name, None)
                        yield dict(fname=meta_fname, data=meta)
                stream.members = []
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                if handler(exn):
                    continue
                else:
                    break
    except Exception as exn:
        print(exn)
    del stream
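

# Hedged sketch of the on-disk layout this iterator expects (names and the
# "caption" field are illustrative): a shard `part-000.tar` whose samples carry
# `.id` members, plus a sibling `part-000.meta.jsonl` with one JSON object per
# line keyed by "key".
def _example_write_meta_shard(tmpdir):
    with open(os.path.join(tmpdir, "part-000.meta.jsonl"), "w") as f:
        f.write(json.dumps({"key": "sample0", "caption": "a cat"}) + "\n")
    tar_path = os.path.join(tmpdir, "part-000.tar")
    with tarfile.open(tar_path, "w") as tf:
        info = tarfile.TarInfo("sample0.id")
        payload = b"sample0"
        info.size = len(payload)
        tf.addfile(info, io.BytesIO(payload))
    return tar_path
# Reading it back, tar_file_iterator_with_meta(open(tar_path, "rb"), ["caption"])
# would yield {"fname": "sample0.id", "data": b"sample0"} followed by
# {"fname": "sample0.caption", "data": "a cat"}.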


def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over sample dicts with ``fname``, ``data``, and ``__url__``.
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator_with_meta(
                source["stream"], meta_names, meta_stream=source["meta_stream"]
            ):
                assert isinstance(sample, dict) and "data" in sample and "fname" in sample
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break
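

# Hedged note on the dict protocol between stages: url_opener (below) yields
# {"url": ..., "stream": ..., "meta_stream": ...}; the expander above turns each
# such dict into per-member samples {"fname": ..., "data": ..., "__url__": ...},
# which group_by_keys then merges into one dict per sample id.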


def url_opener(
    data,
    handler,
    **kw,
):
    """Open URLs and yield a stream of url+stream pairs.

    This shadows webdataset's own ``url_opener`` in order to carry an extra
    ``meta_stream`` alongside the data stream.

    Args:
        data: iterator over dict(url=...)
        handler: exception handler.
        kw: keyword arguments for gopen.gopen.

    Yields:
        a stream of url+stream pairs.
    """
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = gopen(url, **kw)
            if hasattr(stream, "meta_stream"):
                meta_stream = stream.meta_stream
                del stream.meta_stream
            else:
                meta_stream = None
            sample.update(stream=stream, meta_stream=meta_stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break


def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
    """Compose url_opener, the meta-aware tar expander, and key grouping into one stage."""
    streams = url_opener(src, handler=handler)
    files = tar_file_expander_with_meta(streams, meta_names, handler)
    samples = group_by_keys(files, handler=handler)
    return samples
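

# Hedged note: pipelinefilter (from webdataset) curries the function above into
# a pipeline-stage factory, so
#   pipelinefilter(partial(tarfile_samples_with_meta, meta_names=["caption"]))()
# drops into a DataPipeline the same way tarfile_to_samples() does; this is
# exactly how MetaDistributedWebDataset wires it below.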


class MetaDistributedWebDataset(DataPipeline):
    """WebDataset with meta information files.

    Extra format:
        In the webdataset (tar), each sample has a '.id' member;
        next to each tar file there is a '.meta.jsonl' file with the same stem.
        The '.meta.jsonl' file contains one JSON object per line, each with a
        'key' field matching a sample's '.id'.
    """

    def __init__(
        self,
        path,
        process_fn,
        seed,
        *,
        meta_names=[],
        nshards=sys.maxsize,
        shuffle_buffer=1000,
        include_dirs=None,
    ):
        # os.environ['WDS_SHOW_SEED'] = '1'
        import torch

        if torch.distributed.get_rank() == 0:
            if include_dirs is not None:  # e.g. "/webdatasets/A,/webdatasets/C"
                other_paths = []
                include_dirs = include_dirs.split(",")
                for include_dir in include_dirs:
                    if "*" in include_dir:  # "dir*n" repeats the dir's shards n times
                        include_dir, n = include_dir.split("*")
                        n = int(n)
                    else:
                        n = 1
                    for cur_dir, dirs, files in os.walk(include_dir):
                        for f in files:
                            if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0:
                                # other_paths.append(os.path.join(cur_dir,f))
                                other_paths.extend([os.path.join(cur_dir, f)] * n)
                # print(f'Adding dataset paths {",".join(other_paths)}')
                from braceexpand import braceexpand

                if len(path) > 0:  # not ""
                    path = list(braceexpand(path)) + other_paths
                else:
                    path = other_paths
            path = [path]
        else:
            path = [None]
        torch.distributed.broadcast_object_list(path, src=0)
        path = path[0]

        tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
        tarfile_to_samples = pipelinefilter(tarfile_samples)

        # if model parallel, shuffle_buffer should be 1 to disable shuffling
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass

        super().__init__(
            ConfiguredResampledShards(path, seed, nshards=nshards),
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
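

# Hedged usage sketch (paths, "caption", and `decode` are illustrative): a
# meta-aware dataset that picks up the "caption" field from each shard's
# .meta.jsonl. Assumes torch.distributed is already initialized, which the
# rank-0 path broadcast above requires.
def _example_meta_dataset():
    def decode(src):
        for sample in src:
            yield sample  # each sample dict now carries a "caption" entry

    return MetaDistributedWebDataset(
        "data/part-{000..099}.tar",
        decode,
        seed=42,
        meta_names=["caption"],
        include_dirs="/webdatasets/A,/webdatasets/C*2",  # shards under C weighted 2x
    )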


# rclone support
from webdataset.gopen import Pipe


def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32):
    """Open a URL with `rclone`.

    :param url: rclone url, e.g. data:bucket1/foo.tar. The remote ("data" here) should be configured.
    :param mode: file mode
    :param bufsize: buffer size
    """
    url = url.replace("rclone://", "")
    if mode[0] == "r":
        cmd = f"rclone cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"rclone rcat '{url}'"  # rcat streams stdin to the remote path
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
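

# Hedged sketch: with the scheme registration at the bottom of this module, an
# rclone-backed shard can be addressed directly in a dataset path, given a
# configured rclone remote (here named "data"):
#   SimpleDistributedWebDataset("rclone://data:bucket1/part-{000..099}.tar", decode, seed=42)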


def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
    """Open a URL with the boto3 API.

    :param url: boto3 url, e.g. boto3://bucket1/foo.tar. Credentials and the
        endpoint are taken from the S3_* environment variables.
    :param mode: file mode
    :param bufsize: buffer size
    """
    import boto3

    # boto3.set_stream_logger('botocore', level='DEBUG')
    if url.startswith("boto3://"):
        url = url.replace("boto3://", "")
        need_meta = False
    else:
        url = url.replace("metaboto3://", "")
        need_meta = True
    endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
    access_key = os.environ.get("S3_ACCESS_KEY_ID", None)
    secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)

    if mode[0] == "r":
        s3_client = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
        bucket, key = url.split("/", 1)

        if need_meta:
            # download the meta jsonl into memory
            meta_file_key = key.split(".")[0] + ".meta.jsonl"
            meta_stream = io.BytesIO()
            s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
            meta_stream.seek(0)
            meta_stream.name = meta_file_key
        else:
            meta_stream = None

        # data tar stream
        response = s3_client.get_object(Bucket=bucket, Key=key)  # Range is optional
        response["Body"].name = key  # actually not used
        response["Body"].meta_stream = meta_stream
        return response["Body"]
    else:
        raise ValueError(f"{mode}: unknown mode")
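

# Hedged sketch of the environment gopen_boto3 reads (values are placeholders):
#   export S3_ENDPOINT_URL=http://localhost:9000
#   export S3_ACCESS_KEY_ID=minioadmin
#   export S3_SECRET_ACCESS_KEY=minioadmin
# A metaboto3:// path additionally makes gopen_boto3 fetch the sibling
# .meta.jsonl and attach it as stream.meta_stream, which url_opener above
# passes through to the meta-aware tar expander, e.g.:
#   MetaDistributedWebDataset("metaboto3://bucket1/part-{000..099}.tar",
#                             decode, seed=42, meta_names=["caption"])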
gopen_schemes["rclone"] = gopen_rclone
|
|
gopen_schemes["boto3"] = gopen_boto3
|
|
gopen_schemes["metaboto3"] = gopen_boto3
|