-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmerge-data.py
24 lines (20 loc) · 1.57 KB
/
merge-data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
# Merge sharded EDM-generated datasets (several ``*_partN.npz`` files, each
# holding an ``image`` and a ``label`` array) into a single ``.npz`` per dataset.

# Each entry: (output path, part-path template with ``{}`` for the 1-based part
# index, number of parts).
MERGE_JOBS = (
    ("edm_data/cifar10/20m.npz", "edm_data/cifar10/20m_part{}.npz", 2),
    ("edm_data/cifar10/50m.npz", "edm_data/cifar10/50m_part{}.npz", 4),
    ("edm_data/cifar100/50m.npz", "edm_data/cifar100/50m_part{}.npz", 4),
    ("edm_data/svhn/50m.npz", "edm_data/svhn/svhn_50m_part{}.npz", 5),
)


def merge_npz_parts(part_paths, out_path):
    """Concatenate the ``image`` and ``label`` arrays of several ``.npz`` files.

    Parts are concatenated along axis 0 in the order given, and the result is
    written to *out_path* with ``np.savez`` under the keys ``image`` and
    ``label``.

    Args:
        part_paths: Iterable of input ``.npz`` paths, in merge order. Each file
            must contain an ``image`` and a ``label`` array.
        out_path: Destination file passed to ``np.savez``.

    Raises:
        ValueError: If *part_paths* is empty, or if a part's ``image`` and
            ``label`` arrays have different lengths (merging it would
            silently misalign images and labels).
        KeyError: If a part lacks an ``image`` or ``label`` entry.
    """
    images, labels = [], []
    for path in part_paths:
        # np.load on an .npz keeps the archive open; the context manager
        # closes it as soon as this part's arrays are read into memory.
        with np.load(path) as part:
            image, label = part["image"], part["label"]
        if len(image) != len(label):
            raise ValueError(
                f"{path}: {len(image)} images but {len(label)} labels"
            )
        images.append(image)
        labels.append(label)
    if not images:
        raise ValueError("no part files given")
    np.savez(out_path, image=np.concatenate(images), label=np.concatenate(labels))


def main():
    """Merge every dataset listed in ``MERGE_JOBS``."""
    for out_path, part_template, num_parts in MERGE_JOBS:
        part_paths = [part_template.format(i) for i in range(1, num_parts + 1)]
        merge_npz_parts(part_paths, out_path)


if __name__ == "__main__":
    main()