Skip to content

Commit

Permalink
restore data_analyzer.py to master
Browse files Browse the repository at this point in the history
  • Loading branch information
bm-synth authored Feb 9, 2025
1 parent a59906d commit 5f67a49
Showing 1 changed file with 34 additions and 198 deletions.
232 changes: 34 additions & 198 deletions deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,26 +386,10 @@ def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_
index_to_metric_builder.merge_file_(chunk_im_fname)
close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
num_sample_per_value = {}
index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True)
index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True)
index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged"
index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname,
sample_idx_dtype)
for v_idx in range(len(index_to_sample)):
if v_idx > 0:
assert index_to_metric[v_idx] > index_to_metric[v_idx - 1]
num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx])
assert sum(num_sample_per_value.values()) == total_num_samples
merge_step = max(1, len(index_to_sample) // 100)
for v_idx in range(0, len(index_to_sample), merge_step):
merged_samples = np.copy(
np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))],
axis=None))
index_to_sample_merged_builder.add_item(
torch.tensor(merged_samples.astype(np.int64), dtype=torch.long))
logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.")
close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname)

num_sample_per_value = DataAnalyzer.output_index_to_sample_percentile(
index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path, total_num_samples,
sample_idx_dtype)
self.get_metric_value_percentiles(metric_name, num_sample_per_value, total_num_samples)
elif metric_type == 'accumulate_value_over_samples':
metric_save_path = f"{save_path}/{metric_name}/"
Expand All @@ -427,6 +411,29 @@ def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_
metric_value_builder.add_item(torch.tensor(metric_value.astype(np.int64), dtype=torch.long))
close_mmap_dataset_builder(metric_value_builder, metric_value_fname)

@staticmethod
def output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path,
                                      total_num_samples, sample_idx_dtype):
    """Read the index_to_metric and index_to_sample mmap files and write the
    merged sample distribution to ``{metric_name}_index_to_sample_percentile_merged``.

    Args:
        index_to_sample_fname: path prefix of the mmap dataset mapping row -> sample ids.
        index_to_metric_fname: path prefix of the mmap dataset mapping row -> metric value.
        metric_name: metric name, used to build the output file name.
        metric_save_path: directory where the merged output is written.
        total_num_samples: expected total number of samples, used as a sanity check.
        sample_idx_dtype: numpy integer dtype used to store sample indexes.

    Returns:
        dict mapping each metric value to the number of samples having that value.
    """
    num_sample_per_value = {}
    index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True)
    index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True)
    index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged"
    index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname, sample_idx_dtype)
    for v_idx in range(len(index_to_sample)):
        if v_idx > 0:
            # rows must be sorted by strictly increasing metric value
            assert index_to_metric[v_idx] > index_to_metric[v_idx - 1]
        num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx])
    # fix: sum() consumes the values view directly, no need to build a list first
    assert sum(num_sample_per_value.values()) == total_num_samples
    # merge rows into ~100 buckets so each output row covers ~1 percentile of values
    merge_step = max(1, len(index_to_sample) // 100)
    for v_idx in range(0, len(index_to_sample), merge_step):
        merged_samples = np.copy(
            np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))], axis=None))
        index_to_sample_merged_builder.add_item(torch.tensor(merged_samples.astype(np.int64), dtype=torch.long))
        logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.")
    close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname)
    return num_sample_per_value

def run_reduce(self):
if self.custom_reduce is None:
self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
Expand Down Expand Up @@ -633,7 +640,7 @@ def run_map_reduce(self):
metric_to_samples_dict[value.item()] = []
metric_to_samples_dict[value.item()].append(sample.item())

# index_to_metric and index_to_sample serialize a dictionary from metric to samples
# index_to_metric and index_to_sample serialize a dictionary from metric to samples
# index_to_metric stores a key per row, index_to_sample stores the values per row
values = [torch.tensor([x]) for x in metric_to_samples_dict.keys()]
samples = [torch.tensor(metric_to_samples_dict[x]) for x in metric_to_samples_dict.keys()]
Expand Down Expand Up @@ -807,140 +814,7 @@ def sample_sort(tensor, comm_group, num_workers, n_samples=100):
return recv


class SerialDataAnalyzer(object):
    """Serial (single-process) data analyzer.

    Runs the same map-reduce metric computation as the distributed analyzers,
    but keeps all intermediate results in the memory of one process and only
    writes the final mmap-indexed output files to disk.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        num_threads=4,
        metric_names=None,
        metric_functions=None,
        metric_types=None,
        save_path="./",
        collate_fn=None,
        sample_indices=None,
    ) -> None:
        """
        Args:
            dataset: map-style dataset to analyze.
            batch_size: batch size used when iterating the dataset.
            num_threads: number of DataLoader workers.
            metric_names: one name per metric, used to build output paths.
            metric_functions: one callable per metric; takes a batch and
                returns a tensor or numpy array of metric values.
            metric_types: per metric, 'single_value_per_sample' or
                'accumulate_value_over_samples'.
            save_path: root directory for output files.
            collate_fn: optional DataLoader collate function.
            sample_indices: optional user-defined mapping from iteration
                order to sample ids.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_threads = num_threads
        # fix: avoid mutable ([]) default arguments; normalize None -> []
        self.metric_names = [] if metric_names is None else metric_names
        self.metric_functions = [] if metric_functions is None else metric_functions
        self.metric_types = [] if metric_types is None else metric_types
        self.save_path = save_path
        self.collate_fn = collate_fn
        self.sample_indices = sample_indices
        # fix: no placeholders in the message, so a plain string (no f-string)
        logger.info("Serial data analyzer initialized.")

    def run_map_reduce(self):
        """Iterate the dataset once, accumulate all metric results in memory,
        then write the same output files the distributed analyzers produce."""

        # fix: pass batch_size through — it was stored in __init__ but never
        # given to the DataLoader, so non-default values were silently ignored
        dataloader = DataLoader(dataset=self.dataset,
                                batch_size=self.batch_size,
                                num_workers=self.num_threads,
                                collate_fn=self.collate_fn,
                                pin_memory=False)

        # set initial results list: one slot per metric
        metric_results = []
        for metric_type in self.metric_types:
            assert metric_type in ['single_value_per_sample', 'accumulate_value_over_samples'], \
                f"metric_type {metric_type} not implemented."
            if metric_type == 'single_value_per_sample':
                metric_results.append({'values_list': [], 'metric_to_samples_dict': {}})
            else:
                metric_results.append(None)

        # map step: update the results list batch by batch
        processed_samples = 0
        for data in dataloader:
            for m_idx in range(len(self.metric_names)):
                metric_type, metric_function = self.metric_types[m_idx], self.metric_functions[m_idx]
                metric_values = metric_function(data)
                assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \
                    "metric_function must return a tensor or array"
                if isinstance(metric_values, np.ndarray):
                    metric_values = torch.from_numpy(metric_values)
                assert metric_values.dtype in valid_dtypes, \
                    f"metric_function result dtype {metric_values.dtype} not supported. Supported dtypes {valid_dtypes}"

                if metric_type == 'single_value_per_sample':
                    metric_to_samples_dict = metric_results[m_idx]['metric_to_samples_dict']
                    values = metric_results[m_idx]['values_list']
                    for row in range(metric_values.size()[0]):
                        value = metric_values[row].item()
                        sample_idx = processed_samples + row  # sample idx following dataset iteration order
                        if isinstance(data, dict) and 'index' in data:  # Megatron use case
                            sample_idx = data['index'][row][0].item()
                        elif self.sample_indices is not None:  # user defined shuffling of indices
                            sample_idx = self.sample_indices[sample_idx]
                        if value not in metric_to_samples_dict:
                            metric_to_samples_dict[value] = []
                        metric_to_samples_dict[value].append(sample_idx)
                        values.append(value)
                elif metric_type == 'accumulate_value_over_samples':
                    if metric_results[m_idx] is None:
                        metric_results[m_idx] = metric_values
                    else:
                        metric_results[m_idx].add_(metric_values)
            processed_samples += len(data)

        # reduce step: compute the smallest int dtype that fits the sample ids
        total_num_samples = len(self.dataset)
        sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
        logger.info(f"Total number of data samples: {total_num_samples}.")
        logger.info(f"Will use {sample_idx_dtype} to store the sample indexes.")

        for m_idx in range(len(self.metric_names)):
            metric_values, metric_name, metric_type = \
                metric_results[m_idx], self.metric_names[m_idx], self.metric_types[m_idx]
            metric_save_path = f"{self.save_path}/{metric_name}/"
            os.makedirs(metric_save_path, exist_ok=True)

            if metric_type == 'single_value_per_sample':

                # Compute metric value dtype based on the observed value range
                values = metric_results[m_idx]['values_list']
                metric_value_dtype = find_fit_int_dtype(min(values), max(values))

                # sample_to_metric maps sample ids to metric values, as a list of metric values
                sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
                values = [torch.tensor([x], device='cpu') for x in values]
                self.file_write(values, sample_to_metric_fname, metric_value_dtype)

                # Compute sample id dtype based on the observed id range
                metric_to_samples_dict = metric_values['metric_to_samples_dict']
                samples = metric_to_samples_dict.values()
                sample_value_dtype = find_fit_int_dtype(min([min(x) for x in samples]), max([max(x) for x in samples]))

                # index_to_metric and index_to_sample serialize a dictionary from metric to samples
                # index_to_metric stores a key per row, index_to_sample stores the values per row
                values = [torch.tensor([x]) for x in metric_to_samples_dict.keys()]
                samples = [torch.tensor(x) for x in metric_to_samples_dict.values()]
                index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric"  # dict keys
                index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample"  # dict values
                self.file_write(values, index_to_metric_fname, metric_value_dtype)
                self.file_write(samples, index_to_sample_fname, sample_value_dtype)

                DataAnalyzer.output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname,
                                                               metric_name, metric_save_path, total_num_samples,
                                                               sample_idx_dtype)

            elif metric_type == 'accumulate_value_over_samples':
                metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
                metric_value_dtype = find_fit_int_dtype(metric_values.min(), metric_values.max())
                self.file_write([metric_values], metric_value_fname, metric_value_dtype)

    def file_write(self, tensor_list, fname, numpy_dtype):
        """Write a list of tensors to a new mmap-indexed dataset at *fname*."""
        # prepares output folder and file
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        builder = create_mmap_dataset_builder(fname, numpy_dtype)
        builder.add_items(tensor_list)
        close_mmap_dataset_builder(builder, fname)  # close file


def test_compare_data_analyzers(dataset, num_threads=16):
def test_compare_both_data_analyzers(dataset):
""" given a dataset, compare file and memory based data analyser"""

id = lambda t: t.to(torch.int64) # identity
Expand All @@ -956,41 +830,7 @@ def test_compare_data_analyzers(dataset, num_threads=16):
metric_types=['single_value_per_sample', 'accumulate_value_over_samples'],
num_threads=num_threads,
)
worker_id = int(os.environ['RANK'])
num_workers = int(os.environ['WORLD_SIZE'])

# run Serial Data Analyzer (with on single CPU-memory storage of map-reduce)
start_time = time.time()
if worker_id == 0:
sda = SerialDataAnalyzer(save_path="./output_sda", num_threads=num_threads, **kwargs)
sda.run_map_reduce()
print("SerialDataAnalyzer runtime: %s seconds " % (time.time() - start_time))

# run Distributed Data Analyzer (with distributed CUDA-memory storage of map-reduce)
start_time = time.time()
dda = DistributedDataAnalyzer(
save_path="./output_dda",
device=f"cuda:{int(os.environ['LOCAL_RANK'])}",
**kwargs | dict(worker_id=worker_id, num_workers=num_workers),
)
dda.run_map_reduce()
if worker_id == 0:
print("DistributedDataAnalyzer runtime: %s seconds " % (time.time() - start_time))

# run regular Data Analyzer (with shared disk storage of map-reduce)
start_time = time.time()
da = DataAnalyzer(
num_threads=num_threads,
num_threads_reduce=num_threads,
metric_dtypes=[torch.int64, torch.int64],
save_path="./output_da",
**kwargs | dict(worker_id=worker_id, num_workers=num_workers),
)
da.run_map_reduce()
if worker_id == 0:
print("DataAnalyzer runtime: %s seconds " % (time.time() - start_time))

# check that all output files match
dda = DistributedDataAnalyzer(
save_path="./output_dist",
device=f"cuda:{int(os.environ['LOCAL_RANK'])}",
Expand Down Expand Up @@ -1018,16 +858,12 @@ def test_compare_data_analyzers(dataset, num_threads=16):
"mod/mod_sample_to_metric.bin", "mod/mod_sample_to_metric.idx"
]

if worker_id == 0:
if dda.worker_id == 0:
for path in output_paths:
with open(os.path.join(da.save_path, path), 'rb') as f1, \
open(os.path.join(dda.save_path, path), 'rb') as f2, \
open(os.path.join(sda.save_path, path), 'rb') as f3:
f1c, f2c, f3c = f1.read(), f2.read(), f3.read()
if f1c != f2c:
print(f"DataAnalyzer and DistributedDataAnalyzer {path} are not identical.")
if f2c != f3c:
print(f"DistributedDataAnalyzer and SerialDataAnalyzer {path} are not identical.")
open(os.path.join(dda.save_path, path), 'rb') as f2:
if f1.read() != f2.read():
print(f"files {path} are not identical.")


if __name__ == "__main__":
Expand All @@ -1041,4 +877,4 @@ def __init__(self, size=10_000_000):
__len__ = lambda self: self.size
__getitem__ = lambda self, idx: self.values[idx]

test_compare_data_analyzers(TestDataset())
test_compare_both_data_analyzers(TestDataset())

0 comments on commit 5f67a49

Please sign in to comment.