forked from gyyang/multitask
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpaper.py
executable file
·121 lines (97 loc) · 4.55 KB
/
paper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Main file for generating results in the paper:
Clustering and compositionality of task representations
in a neural network trained to perform many cognitive tasks
Yang GR et al. 2017 BioRxiv
"""
from __future__ import absolute_import

import os

import tools
from analysis import performance
from analysis import standard_analysis
from analysis import clustering
from analysis import variance
from analysis import taskset
from analysis import varyhp
from analysis import data_analysis
from analysis import contextdm_analysis
# Directories of the trained models and of the sample model analysed below.
# Change these to point to your own directories.
root_dir = './data/train_all'
# The sample model is the subdirectory '0' of root_dir. os.path.join is used
# instead of string concatenation so the path separator matches the host OS;
# the result is still a plain str, as before.
model_dir = os.path.join(root_dir, '0')
## Performance Analysis-----------------------------------------------------
# standard_analysis.schematic_plot(model_dir=model_dir)
# performance.plot_performanceprogress(model_dir)
# performance.psychometric_choice(model_dir) # Psychometric for dm
# performance.psychometric_choiceattend(model_dir, no_ylabel=True)
# performance.psychometric_choiceint(model_dir, no_ylabel=True)
#
# for rule in ['dm1', 'contextdm1', 'multidm']:
# performance.plot_choicefamily_varytime(model_dir, rule)
# performance.psychometric_delaychoice_varytime(model_dir, 'delaydm1')
## Clustering Analysis------------------------------------------------------
# CA = clustering.Analysis(model_dir, data_type='rule')
# CA.plot_example_unit()
# CA.plot_cluster_score()
# CA.plot_variance()
# CA.plot_2Dvisualization('PCA')
# CA.plot_2Dvisualization('MDS')
# CA.plot_2Dvisualization('tSNE')
# CA.plot_lesions()
# CA.plot_connectivity_byclusters()
# CA = clustering.Analysis(model_dir, data_type='epoch')
# CA.plot_variance()
# CA.plot_2Dvisualization('tSNE')
## Varying hyperparameter analysis------------------------------------------
# varyhp_root_dir = './data/varyhp'
# n_clusters, hp_list = varyhp.get_n_clusters(varyhp_root_dir)
# varyhp.plot_n_clusters(n_clusters, hp_list)
# varyhp.plot_n_cluster_hist(n_clusters, hp_list)
## FTV Analysis-------------------------------------------------------------
# variance.plot_hist_varprop_selection(root_dir)
# variance.plot_hist_varprop_selection('./data/tanhgru')
# TODO: set plot_control=True later
# variance.plot_hist_varprop_all(root_dir, plot_control=False)
## ContextDM analysis-------------------------------------------------------
# ua = contextdm_analysis.UnitAnalysis(model_dir)
# ua.plot_inout_connections()
# ua.plot_rec_connections()
# ua.plot_rule_connections()
# ua.prettyplot_hist_varprop()
#
# contextdm_analysis.plot_performance_choicetasks(model_dir, grouping='var')
# contextdm_analysis.plot_performance_2D_all(model_dir, 'contextdm1')
## Task Representation------------------------------------------------------
# tsa = taskset.TaskSetAnalysis(model_dir)
# tsa.compute_and_plot_taskspace(epochs=['stim1'], dim_reduction_type='PCA')
## Compositional Representation---------------------------------------------
# setups = [1, 2, 3]
# for setup in setups:
# taskset.plot_taskspace_group(root_dir, setup=setup,
# restore=True, representation='rate')
# taskset.plot_taskspace_group(root_dir, setup=setup,
# restore=True, representation='weight')
# taskset.plot_replacerule_performance_group(
# root_dir, setup=setup, restore=True)
# setups = [1, 2, 3]
# for setup in setups:
# taskset.plot_replacerule_performance_group(
# './data/tanhgru', setup=setup, restore=True, fig_name_addon='tanhgru')
## Continual Learning Analysis----------------------------------------------
# hp_target0 = {'c_intsyn': 0, 'ksi_intsyn': 0.01,
# 'activation': 'relu', 'max_steps': 4e5}
# hp_target1 = {'c_intsyn': 1, 'ksi_intsyn': 0.01,
# 'activation': 'relu', 'max_steps': 4e5}
# model_dirs0 = tools.find_all_models('data/seq/', hp_target0)
# model_dirs1 = tools.find_all_models('data/seq/', hp_target1)
# model_dirs0 = tools.select_by_perf(model_dirs0, perf_min=0.8)
# model_dirs1 = tools.select_by_perf(model_dirs1, perf_min=0.8)
# performance.plot_performanceprogress_cont((model_dirs0[0], model_dirs1[2]))
# performance.plot_finalperformance_cont(model_dirs0, model_dirs1)
# data_analysis.plot_fracvar_hist_byhp(hp_vary='c_intsyn', mode='all_var')
# data_analysis.plot_fracvar_hist_byhp(hp_vary='p_weight_train', mode='all_var')
## Data analysis------------------------------------------------------------
# Note: these calls will not work without the corresponding data files.
# data_analysis.plot_all('mante_single_ar')
# data_analysis.plot_all('mante_single_fe')
# data_analysis.plot_all('mante_ar')
# data_analysis.plot_all('mante_fe')