-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCARLOS_TRAINING_CONTROL.jl
135 lines (108 loc) · 7.1 KB
/
CARLOS_TRAINING_CONTROL.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
using MAT
using ForwardDiff
using DiffBase
include("pro_anti.jl")
include("pro_anti_opto.jl")
"""
    JJ_opto_plot(nPro, nAnti; opto_targets=[0.9 0.7], opto_periods=[-1 -1], ..., model_params...)

Run the Pro/Anti model for every opto condition (one per row of `opto_periods`) and every
combination of `rule_and_delay_periods` × `target_periods` × `post_target_periods`. Score the
resulting hit rates against `opto_targets`.

# Arguments
- `nPro`, `nAnti`: number of Pro / Anti trials forwarded to `run_ntrials_opto`.

# Keywords
- `opto_targets`: one row per opto condition. Column 1 is the target Pro hit rate and column 2
  the target Anti hit rate. Must have the same size as `opto_periods`, otherwise an error is raised.
- `opto_periods`: one row per opto condition. Row `kk` is passed as `"opto_period"`.
- `opto_strength`: forwarded as `"opto_strength"`.
- `theta1`: tanh width that turns the row-1 vs row-4 difference into a soft hit score.
- `theta2`: tanh width of the squared "difference" term used in `cost2`.
- `cbeta`: weight of the difference term in `cost2`.
- `seedrand`: if not `NaN`, `srand(seedrand)` is called at the start of every opto condition.
- `plot_conditions`: either a `Bool` applied to every opto condition, or an indexable collection
  with one `Bool` per condition. When true, `run_ntrials_opto` gets `plot_list=1:10`.
- `model_details`: if true, also return the `proVall`, `antiVall`, `opto_fraction`,
  `pro_input` and `anti_input` from the last `run_ntrials_opto` call.
- `nderivs`, `difforder`: forwarded to `ForwardDiffZeros` and `run_ntrials_opto`.
- `verbose`, `pre_string`, `zero_last_sigmas`: accepted but not used in this function.
- `model_params...`: remaining parameters, merged into each run's parameter dict via `make_dict`.

# Returns
`(cost1 + cost2, cost1s, cost2s, hP, hA, dP, dA, hBP, hBA)`, followed by the five
model-detail values when `model_details` is true.

Every per-condition array has size `(number of opto conditions) × (number of period
combinations)`:
- `hP`/`hA`: mean soft hit rates.
- `dP`/`dA`: mean squared-tanh difference terms.
- `hBP`/`hBA`: binary hit fractions.

`cost1` is the mean squared error of the hit rates against the targets. `cost2` is
`-cbeta` times the mean difference term, so a lower total cost favours larger row-1/row-4
separations.
"""
function JJ_opto_plot(nPro, nAnti; opto_targets=[0.9 0.7], theta1=0.025, theta2=0.035,
                      cbeta=0.003, verbose=false, pre_string="", zero_last_sigmas=0,
                      seedrand=NaN, rule_and_delay_periods=[0.4], target_periods=[0.1],
                      post_target_periods=[0.5], opto_periods=[-1 -1], opto_strength=1,
                      nderivs=0, difforder=0, plot_conditions=false, model_details=false,
                      model_params...)
    if size(opto_targets) != size(opto_periods); error("opto parameters are bad"); end

    nperiods   = size(opto_periods, 1)
    nruns_each = length(rule_and_delay_periods)*length(target_periods)*length(post_target_periods)

    # Cost storage may hold ForwardDiff duals; the diagnostics below are plain Float64 arrays.
    cost1s = ForwardDiffZeros(nperiods, nruns_each, nderivs=nderivs, difforder=difforder)
    cost2s = ForwardDiffZeros(nperiods, nruns_each, nderivs=nderivs, difforder=difforder)
    hP  = zeros(nperiods, nruns_each)
    hA  = zeros(nperiods, nruns_each)
    dP  = zeros(nperiods, nruns_each)
    dA  = zeros(nperiods, nruns_each)
    hBP = zeros(nperiods, nruns_each)
    hBA = zeros(nperiods, nruns_each)

    if model_details
        # Placeholders; each run_ntrials_opto call below overwrites them, so the last run wins.
        proVall = []
        antiVall = []
        opto_fraction = []
        pro_input = []
        anti_input = []
    end

    nopto = 0   # column index within the current opto condition
    for kk in 1:nperiods
        nopto = 0
        # Reseed per opto condition so every condition sees the same noise samples
        # and the fit cannot exploit a lucky noise draw for one condition.
        if !isnan(seedrand); srand(seedrand); end
        for i in rule_and_delay_periods
            for j in target_periods
                for k in post_target_periods
                    nopto += 1
                    # Include this opto inactivation in the parameters passed to the model.
                    my_params = make_dict(["rule_and_delay_period", "target_period", "post_target_period","opto_period","opto_strength"], [i, j, k, opto_periods[kk,:], opto_strength], Dict(model_params))

                    # A Bool applies to all conditions; otherwise look up this condition's flag.
                    do_plot = isa(plot_conditions, Bool) ? plot_conditions : plot_conditions[kk]
                    if do_plot
                        proVs, antiVs, proVall, antiVall, opto_fraction, pro_input, anti_input =
                            run_ntrials_opto(nPro, nAnti; plot_list=1:10, nderivs=nderivs,
                                             difforder=difforder, my_params...)
                    else
                        proVs, antiVs, proVall, antiVall, opto_fraction, pro_input, anti_input =
                            run_ntrials_opto(nPro, nAnti; nderivs=nderivs,
                                             difforder=difforder, my_params...)
                    end

                    # Per-trial soft scores. A Pro trial is a hit when row 1 exceeds row 4;
                    # an Anti trial is a hit when row 4 exceeds row 1.
                    hitsP  = 0.5 .* (1 .+ tanh.((proVs[1,:] .- proVs[4,:]) ./ theta1))
                    diffsP = tanh.((proVs[1,:] .- proVs[4,:]) ./ theta2).^2
                    hitsA  = 0.5 .* (1 .+ tanh.((antiVs[4,:] .- antiVs[1,:]) ./ theta1))
                    diffsA = tanh.((antiVs[4,:] .- antiVs[1,:]) ./ theta2).^2

                    mean_hitsP  = mean(hitsP)
                    mean_hitsA  = mean(hitsA)
                    mean_diffsP = mean(diffsP)
                    mean_diffsA = mean(diffsA)

                    hP[kk, nopto] = mean_hitsP
                    hA[kk, nopto] = mean_hitsA
                    dP[kk, nopto] = mean_diffsP
                    dA[kk, nopto] = mean_diffsA
                    # Binary hit fractions; each task is scored from its own trials.
                    hBP[kk, nopto] = sum(proVs[1,:] .>= proVs[4,:]) / nPro
                    hBA[kk, nopto] = sum(antiVs[4,:] .> antiVs[1,:]) / nAnti

                    # Trial-count-weighted cost; a task with zero trials is left out.
                    if nPro > 0 && nAnti > 0
                        cost1s[kk, nopto] = (nPro*(mean_hitsP - opto_targets[kk,1])^2 +
                                             nAnti*(mean_hitsA - opto_targets[kk,2])^2) / (nPro + nAnti)
                        cost2s[kk, nopto] = -cbeta*(nPro*mean_diffsP + nAnti*mean_diffsA) / (nPro + nAnti)
                    elseif nPro > 0
                        cost1s[kk, nopto] = (mean_hitsP - opto_targets[kk,1])^2
                        cost2s[kk, nopto] = -cbeta*mean_diffsP
                    else
                        cost1s[kk, nopto] = (mean_hitsA - opto_targets[kk,2])^2
                        cost2s[kk, nopto] = -cbeta*mean_diffsA
                    end
                end
            end
        end
    end

    cost1 = mean(cost1s)
    cost2 = mean(cost2s)
    if model_details
        return cost1 + cost2, cost1s, cost2s, hP, hA, dP, dA, hBP, hBA, proVall, antiVall, opto_fraction, pro_input, anti_input
    else
        return cost1 + cost2, cost1s, cost2s, hP, hA, dP, dA, hBP, hBA
    end
end
# Load one fitted farm and re-run its model with the training noise.
farmName = "LA";
farmnum=3;
# Load the farm file, e.g. goodfarms/farm_LA0003.mat (farm number zero-padded to 4 digits).
# The pad must be a String/Char: an Int pad (`lpad(x, 4, 0)`) errors on Julia >= 0.7.
F = matread("goodfarms/farm_"*farmName*lpad(farmnum[1],4,"0")*".mat")
model_params = symbol_key_ize(F["model_params"])
# Training evaluation: uses the farm's stored seed F["sr"] and plots only the first opto condition.
train_func = (;params...) -> JJ_opto_plot(model_params[:nPro], model_params[:nAnti];
    rule_and_delay_periods=F["rule_and_delay_periods"], theta1=model_params[:theta1],
    theta2=model_params[:theta2], post_target_periods=F["post_target_periods"],
    seedrand=F["sr"], cbeta=F["cb"], verbose=true,
    plot_conditions=[true, false, false, false, false],
    merge(make_dict(F["args"], F["pars"], merge(model_params, Dict(params))))...)
# Run with all four Pro and Anti initial conditions explicitly set to -0.5.
t_opto_scost, t_opto_scost1, t_opto_scost2, t_opto_hitsP,t_opto_hitsA, t_opto_diffsP, t_opto_diffsA, t_opto_bP, t_opto_bA = train_func(;:start_pro=>[-0.5,-0.5,-0.5,-0.5],:start_anti=>[-0.5,-0.5,-0.5,-0.5]);
# Held-out evaluation: 100 Pro / 100 Anti trials with the farm's test seed F["test_sr"].
# Uncomment both lines below to use it.
#test_func = (;params...) -> JJ_opto_plot(100,100; rule_and_delay_periods=F["rule_and_delay_periods"], theta1=model_params[:theta1], theta2=model_params[:theta2], post_target_periods=F["post_target_periods"], seedrand=F["test_sr"], cbeta=F["cb"], verbose=true,plot_conditions=[true,false,false,false,false],merge(make_dict(F["args"],F["pars"], merge(model_params, Dict(params))))...)
# Run the held-out model, including the same initial conditions.
#opto_scost, opto_scost1, opto_scost2, opto_hitsP,opto_hitsA, opto_diffsP, opto_diffsA, opto_bP, opto_bA = test_func(;:start_pro=>[-0.5,-0.5,-0.5,-0.5],:start_anti=>[-0.5,-0.5,-0.5,-0.5] )
# Summary table: one row per opto condition with Pro and Anti hit rates rounded to 2 decimals.
epochs = ["control"; "full"; "rule"; "delay"; "target"]
[epochs round.(t_opto_hitsP .* 100) ./ 100 round.(t_opto_hitsA .* 100) ./ 100]
# Plot the model hit rates for this farm (solid) against the targets (dashed).
# Red is Pro and blue is Anti; x is the opto condition index in `epochs`.
xs = collect(1:length(epochs))
figure()
plot(xs,t_opto_hitsP,color="red")
plot(xs,t_opto_hitsA,color="blue")
plot(xs,model_params[:opto_targets][:,1],color="red",linestyle="--")
plot(xs,model_params[:opto_targets][:,2],color="blue",linestyle="--")
# 95% binomial confidence half-width, 1.96*sqrt(p*(1-p)/n), rounded to 3 decimals.
# NOTE(review): n is hard-coded to 1000 trials. It presumably should match
# model_params[:nPro] / model_params[:nAnti]; confirm before trusting the bars.
n_ci_trials = 1000
p_err = round.(1.96 .* sqrt.((1/n_ci_trials) .* t_opto_hitsP .* (1 .- t_opto_hitsP)) .* 1000) ./ 1000
a_err = round.(1.96 .* sqrt.((1/n_ci_trials) .* t_opto_hitsA .* (1 .- t_opto_hitsA)) .* 1000) ./ 1000
for i in eachindex(epochs)
    plot([i,i], t_opto_hitsP[i] .+ [p_err[i], -p_err[i]],color="red")
    plot([i,i], t_opto_hitsA[i] .+ [a_err[i], -a_err[i]],color="blue")
end
ylabel("hits %")
xlabel("condition")