dict_methods.py
import copy
import numpy as np
from scipy import stats as sstats
from collections import defaultdict
import itertools


def filter(res, filter_dict=None):
    '''
    Filter results to only include entries matching the key, value pairs in filter_dict.
    For example, filter_dict = {'day': 0} keeps only the entries of res whose key 'day' == 0.
    Note: this shadows the built-in filter within this module.
    :param res: flattened dict of results (each value is an array with one entry per result)
    :param filter_dict: dict mapping a key of res to a value, or list of values, to keep
    :return: a copy of res with the filter applied
    '''
    if filter_dict is None:
        return res
    out = copy.copy(res)
    list_of_ixs = []
    for key, vals in filter_dict.items():
        membership = np.isin(res[key], vals)
        list_of_ixs.append(membership)
    # keep only the entries that match every key in filter_dict
    select_ixs = np.all(list_of_ixs, axis=0)
    for key, value in res.items():
        out[key] = value[select_ixs]
    return out
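
# Usage sketch for ``filter`` (hypothetical toy data; assumes every value in
# res is a 1-D numpy array of equal length):
#   res = {'day': np.array([0, 0, 1]), 'score': np.array([0.1, 0.2, 0.3])}
#   filter(res, {'day': 0})
#   -> {'day': array([0, 0]), 'score': array([0.1, 0.2])}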


def exclude(res, exclude_dict=None):
    '''
    Remove entries matching all the key, value pairs in exclude_dict (the complement of filter).
    :param res: flattened dict of results
    :param exclude_dict: dict mapping a key of res to a value, or list of values, to drop
    :return: a copy of res with the excluded entries removed
    '''
    if exclude_dict is None:
        return res
    out = copy.copy(res)
    list_of_ixs = []
    for key, vals in exclude_dict.items():
        membership = np.isin(res[key], vals)
        list_of_ixs.append(membership)
    # drop the entries that match every key in exclude_dict
    exclude_ixs = np.all(list_of_ixs, axis=0)
    select_ixs = np.invert(exclude_ixs)
    for key, value in res.items():
        out[key] = value[select_ixs]
    return out
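
# Usage sketch for ``exclude`` (same hypothetical toy data as the ``filter``
# example above): exclude keeps the complement of the matching entries.
#   exclude(res, {'day': 0})
#   -> {'day': array([1]), 'score': array([0.3])}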


def filter_reduce(res, filter_keys, reduce_key):
    '''
    For every unique combination of values of filter_keys, filter res down to that
    combination and reduce reduce_key to its mean (see reduce_by_mean), then chain
    the reduced results together into a single dict of arrays.
    '''
    # TODO: have not tested behavior
    out = defaultdict(list)
    if isinstance(filter_keys, str):
        filter_keys = [filter_keys]
    unique_combinations, ixs = retrieve_unique_entries(res, filter_keys)
    for v in unique_combinations:
        filter_dict = {filter_key: val for filter_key, val in zip(filter_keys, v)}
        cur_res = filter(res, filter_dict)
        temp_res = reduce_by_mean(cur_res, reduce_key)
        chain_defaultdicts(out, temp_res)
    for key, val in out.items():
        out[key] = np.array(val)
    return out


def reduce_by_mean(res, key):
    '''
    Reduce res[key] to its mean across entries, adding key + '_std' and key + '_sem'.
    Every other key is kept only if it has a single unique value; otherwise it is
    dropped from the output.
    '''
    # TODO: have not tested behavior
    data = res[key]
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    sem = sstats.sem(data, axis=0)
    out = defaultdict(list)
    for k, v in res.items():
        if k == key:
            out[k] = mean
            out[k + '_std'] = std
            out[k + '_sem'] = sem
        else:
            if len(set(v)) == 1:
                out[k] = v[0]
            else:
                # non-unique values cannot be carried over; the key is dropped
                list_elements = ','.join([str(x) for x in np.unique(v)])
                print('Dropped key ' + k + ' with non-unique entries: ' + list_elements)
    return out
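
# Usage sketch for ``reduce_by_mean`` (hypothetical toy data): the reduce key
# collapses to its mean with '_std' and '_sem' companions; other keys survive
# only if they hold a single unique value.
#   r = {'day': np.array([0, 0]), 'score': np.array([1.0, 3.0])}
#   reduce_by_mean(r, 'score')
#   -> {'score': 2.0, 'score_std': 1.0, 'score_sem': 1.0, 'day': 0}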


def retrieve_unique_entries(res, loop_keys):
    '''
    Return every combination of the unique values of loop_keys (in order of first
    appearance), together with the indices of the entries matching each combination.
    '''
    unique_entries_per_loopkey = []
    for x in loop_keys:
        a = res[x]
        # unique values, preserving the order in which they first appear
        indexes = np.unique(a, return_index=True)[1]
        unique_entries_per_loopkey.append([a[index] for index in sorted(indexes)])
    unique_entry_combinations = list(itertools.product(*unique_entries_per_loopkey))
    list_of_ind = []
    for x in range(len(unique_entry_combinations)):
        list_of_ixs = []
        cur_combination = unique_entry_combinations[x]
        for i, val in enumerate(cur_combination):
            list_of_ixs.append(val == res[loop_keys[i]])
        ind = np.all(list_of_ixs, axis=0)
        ind_ = np.where(ind)[0]
        list_of_ind.append(ind_)
    return unique_entry_combinations, list_of_ind
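
# Usage sketch for ``retrieve_unique_entries`` (hypothetical toy data):
#   r = {'day': np.array([0, 0, 1, 1]), 'mouse': np.array(['a', 'b', 'a', 'b'])}
#   retrieve_unique_entries(r, ['day', 'mouse'])
#   -> combinations [(0, 'a'), (0, 'b'), (1, 'a'), (1, 'b')]
#      and matching index arrays [array([0]), array([1]), array([2]), array([3])]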


def chain_defaultdicts(dictA, dictB):
    # extend dictA's lists with dictB's values, key by key, in place
    for k in dictB.keys():
        # np.atleast_1d lets scalar values (e.g. a single mean) be appended instead of raising
        dictA[k] = list(itertools.chain(dictA[k], np.atleast_1d(dictB[k])))
    for key, val in dictA.items():
        dictA[key] = np.array(val)
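

# Minimal end-to-end sketch (assumptions: every value in res is an equal-length
# 1-D numpy array; the keys 'day', 'mouse', and 'score' are hypothetical and only
# illustrate the expected flattened-dict layout).
if __name__ == '__main__':
    res = {'day': np.array([0, 0, 1, 1]),
           'mouse': np.array(['a', 'b', 'a', 'b']),
           'score': np.array([1.0, 2.0, 3.0, 4.0])}
    print(filter(res, {'day': 0}))             # keep day 0 only
    print(exclude(res, {'mouse': ['b']}))      # drop entries from mouse 'b'
    print(filter_reduce(res, 'day', 'score'))  # mean score per day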