Skip to content

Commit

Permalink
modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
tundeakins committed Nov 28, 2022
1 parent 69c5f4a commit df1917b
Show file tree
Hide file tree
Showing 11 changed files with 649 additions and 148 deletions.
3 changes: 2 additions & 1 deletion CONAN3/RVmodel_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_RVmod(params,tt,RVmes,RVerr,bis,fwhm,contra,nfilt,baseLSQ,inmcmc,nddf,no
# print(ome2)
else:
ome2 = 0
print('ome2 000')
# print('ome2 000')
params[5] = np.sqrt(ecc)*np.sin(ome2)
# print('here')

Expand Down Expand Up @@ -120,6 +120,7 @@ def get_RVmod(params,tt,RVmes,RVerr,bis,fwhm,contra,nfilt,baseLSQ,inmcmc,nddf,no
phases = ((tt-params[0])/params[4]) - np.round( ((tt-params[0])/params[4]))
of=open(outfile,'w')
of.write("%10s %10s %10s %10s %10s %10s %10s %10s\n" %("# time","RV","error","full_mod","base","Rvmodel","det_RV", "phase"))
print(f"gamma_kms = {params[gammaind]}")
for k in range(len(tt)):
of.write('%10.6f %10.6f %10.6f %10.6f %10.6f %10.6f %10.6f %10.6f\n' % (tt[k], RVmes[k], RVerr[k], mod_RVbl[k],bfuncRV[k], +\
mod_RV[k]-params[gammaind],RVmes[k]-bfuncRV[k]-params[gammaind],phases[k]))
Expand Down
278 changes: 246 additions & 32 deletions CONAN3/_classes.py

Large diffs are not rendered by default.

81 changes: 62 additions & 19 deletions CONAN3/fit_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from multiprocessing import Pool
import pickle
import emcee
import time

from occultquad import *
from occultnl import *
Expand Down Expand Up @@ -41,17 +42,25 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
"""
function to fit the data using the light-curve object lc, rv_object rv and mcmc setup object mcmc.
Parameters
----------
statistic : str;
statistic to run on posteriors to obtain model parameters and create model output file ".._out_full.dat".
must be one of ["median", "max", "bestfit"], default is "median".
"max" and "median" calculate the maximum and median of each parameter posterior respectively while "bestfit" \
is the parameter combination that gives the maximum joint posterior probability.
Returns:
--------
result: object containing labeled mcmc chains
result : object containing labeled mcmc chains
Object that contains methods to plot the chains, corner, and histogram of parameters.
e.g result.plot_chains(), result.plot_burnin_chains(), result.plot_corner, result.plot_posterior("T_0")
**kwargs: other parameters sent to emcee.EnsembleSampler.run_mcmc() function
"""
print('CONAN3 launched!!!\n')
#begin loading data from the 3 objects and calling the methods
assert statistic in ["median", "max"], 'statistic can only be either median or max'
assert statistic in ["median", "max", "bestfit"], 'statistic can only be either median, max or bestfit'

#============lc_data=========================
#from load_lightcurves()
Expand All @@ -71,7 +80,7 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",

nfilt = len(filnames)
ngroup = len(grnames)

useSpline = lc._spline

#============GP Setup=============================
#from load_lightcurves.add_GP()
Expand Down Expand Up @@ -205,6 +214,7 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
gamprilo = [] if rv is None else rv._gamprilo
gamprihi = [] if rv is None else rv._gamprihi
sinPs = [] if rv is None else rv._sinPs
rv_fpath = [] if rv is None else rv._fpath
if rv is None: #remove k as a free parameter
lc._config_par["K"].to_fit = "n"
lc._config_par["K"].step_size = 0
Expand Down Expand Up @@ -519,7 +529,8 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
GPrvWNstep = np.array([0.1])

#============================= SETUP ARRAYS =======================================
print('Setting up photometry arrays ...')
print('Setting up photometry arrays ...')
if useSpline.use: print('Setting up Spline fitting ...')
tarr=np.array([]) # initializing array with all timestamps
farr=np.array([]) # initializing array with all flux values
earr=np.array([]) # initializing array with all error values
Expand Down Expand Up @@ -609,7 +620,7 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
pnames=np.concatenate((pnames,[RVnames[i]+'_gamma']), axis=0)

if (jit_apply=='y'):
print('does jitter work?')
# print('does jitter work?')
# print(nothing)
params=np.concatenate((params,[0.01]), axis=0)
stepsize=np.concatenate((stepsize,[0.001]), axis=0)
Expand All @@ -618,7 +629,16 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
prior=np.concatenate((prior,[0.]), axis=0)
priorlow=np.concatenate((priorlow,[0.]), axis=0)
priorup=np.concatenate((priorup,[0.]), axis=0)
pnames=np.concatenate((pnames,[RVnames[i]+'_jitter']), axis=0)
pnames=np.concatenate((pnames,[RVnames[i]+'_jitter']), axis=0)
else:
params=np.concatenate((params,[0.]), axis=0)
stepsize=np.concatenate((stepsize,[0.]), axis=0)
pmin=np.concatenate((pmin,[0.]), axis=0)
pmax=np.concatenate((pmax,[0]), axis=0)
prior=np.concatenate((prior,[0.]), axis=0)
priorlow=np.concatenate((priorlow,[0.]), axis=0)
priorup=np.concatenate((priorup,[0.]), axis=0)
pnames=np.concatenate((pnames,[RVnames[i]+'_jitter']), axis=0)

nbc_tot = np.copy(0) # total number of baseline coefficients let to vary (leastsq OR jumping)

Expand Down Expand Up @@ -950,10 +970,10 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
pargps.append(pargp)
# time.sleep(3000)

print('Setting up RV arrays ...')
if rv is not None: print('Setting up RV arrays ...')

for i in range(nRV):
t, rv, err, bis, fwhm, contrast = np.loadtxt(fpath+RVnames[i], usecols=(0,1,2,3,4,5), unpack = True) # reading in the data
t, rv, err, bis, fwhm, contrast = np.loadtxt(rv_fpath+RVnames[i], usecols=(0,1,2,3,4,5), unpack = True) # reading in the data

tarr = np.concatenate((tarr,t), axis=0)
farr = np.concatenate((farr,rv), axis=0) # ! add the RVs to the "flux" array !
Expand Down Expand Up @@ -1137,11 +1157,15 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
nocc,rprs0,erprs0,grprs,egrprs,grnames,groups,ngroup,ewarr, inmcmc, paraCNM, baseLSQ, bvars, bvarsRV,
cont,names,RVnames,earr,divwhite,dwCNMarr,dwCNMind,params,useGPphot,useGPrv,GPobjects,GPparams,GPindex,
pindices,jumping,pnames,LCjump,priors[jumping],priorwids[jumping],lim_low[jumping],lim_up[jumping],pargps,
jumping_noGP,GPphotWN,jit_apply,jumping_GP,GPstepsizes,GPcombined]
jumping_noGP,GPphotWN,jit_apply,jumping_GP,GPstepsizes,GPcombined,useSpline]

mval, merr,dump1,dump2 = logprob_multi(initial[jumping],*indparams,verbose=True,debug=debug)
debug_t1 = time.time()
mval, merr,dump1,dump2 = logprob_multi(initial[jumping],*indparams,make_out_file=True,verbose=True,debug=debug)
if debug: print(f'finished logprob_multi, took {(time.time() - debug_t1)} secs')
if not os.path.exists("init"): os.mkdir("init") #folder to put initial plots
debug_t2 = time.time()
mcmc_plots(mval,tarr,farr,earr,xarr,yarr,warr,aarr,sarr,barr,carr,lind, nphot, nRV, indlist, filters, names, RVnames, 'init/init_',initial,T0_in[0],per_in[0])
if debug: print(f'finished mcmc_plots, took {(time.time() - debug_t2)} secs')


########################### MCMC run ###########################################
Expand All @@ -1152,7 +1176,7 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
nocc,rprs0,erprs0,grprs,egrprs,grnames,groups,ngroup,ewarr, inmcmc, paraCNM, baseLSQ, bvars, bvarsRV,
cont,names,RVnames,earr,divwhite,dwCNMarr,dwCNMind,params,useGPphot,useGPrv,GPobjects,GPparams,GPindex,
pindices,jumping,pnames,LCjump,priors[jumping],priorwids[jumping],lim_low[jumping],lim_up[jumping],pargps,
jumping_noGP,GPphotWN,jit_apply,jumping_GP,GPstepsizes,GPcombined]
jumping_noGP,GPphotWN,jit_apply,jumping_GP,GPstepsizes,GPcombined,useSpline]

print('No of dimensions: ', ndim)
print('No of chains: ', nchains)
Expand Down Expand Up @@ -1194,12 +1218,17 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
fig.savefig("burnin_chains.png", bbox_inches="tight")
print("Burn-in chains plot saved as: burnin_chains.png")
except:
print(f"burn-in chains not plotted (number of parameters ({ndim}) exceeds 20. use result.plot_burnin_chains()")
print(f"full burn-in chains not plotted (number of parameters ({ndim}) exceeds 20. use result.plot_burnin_chains()")
print(f"saving burn-in chain plot for the first 20 parameters")
pl_pars = list(burn_result._par_names)[:20]
fig = burn_result.plot_burnin_chains(pl_pars)
fig.savefig("burnin_chains.png", bbox_inches="tight")

matplotlib.use(__default_backend__)
sampler.reset()

print("Running production...")
pos, prob, state = sampler.run_mcmc(pos, ppchain, progress=True)
print("\nRunning production...")
pos, prob, state = sampler.run_mcmc(pos, ppchain,skip_initial_state_check=True, progress=True )
bp = pos[np.argmax(prob)]

posterior = sampler.flatchain
Expand Down Expand Up @@ -1259,14 +1288,20 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
baseLSQ, bvars, bvarsRV, cont,names,RVnames,earr,divwhite,dwCNMarr,dwCNMind,params,
useGPphot,useGPrv,GPobjects,GPparams,GPindex,pindices,jumping,pnames,LCjump,
priors[jumping],priorwids[jumping],lim_low[jumping],lim_up[jumping],pargps,
jumping_noGP,GPphotWN,jumping_GP,jit_apply,GPstepsizes,GPcombined]

jumping_noGP,GPphotWN,jumping_GP,jit_apply,GPstepsizes,GPcombined,useSpline]

#AKIN: save config parameters indparams and summary_stats and as a hidden files.
#can be used to run logprob_multi() to generate out_full.dat files for median posterior, max posterior and best fit values
pickle.dump(indparams, open(".par_config.pkl","wb"))
stat_vals = dict(med = medp[jumping], max=maxp[jumping], bf = bpfull[jumping])
pickle.dump(stat_vals, open(".stat_vals.pkl","wb"))

#median
mval, merr,T0_post,p_post = logprob_multi(medp[jumping],*indparams,verbose=True)
mval, merr,T0_post,p_post = logprob_multi(medp[jumping],*indparams,make_out_file=(statistic=="median"), verbose=True,)
mcmc_plots(mval,tarr,farr,earr,xarr,yarr,warr,aarr,sarr,barr,carr,lind, nphot, nRV, indlist, filters, names, RVnames, 'med_',medp,T0_post,p_post)

#max_posterior
mval2, merr2, T0_post, p_post = logprob_multi(maxp[jumping],*indparams)
mval2, merr2, T0_post, p_post = logprob_multi(maxp[jumping],*indparams,make_out_file=(statistic=="max"),verbose=False)
mcmc_plots(mval2,tarr,farr,earr,xarr,yarr,warr,aarr,sarr,barr,carr,lind, nphot, nRV, indlist, filters, names, RVnames, 'max_',maxp, T0_post,p_post)


Expand Down Expand Up @@ -1323,13 +1358,21 @@ def fit_data(lc, rv=None, mcmc=None, statistic = "median",
fig = result.plot_chains()
fig.savefig("chains.png", bbox_inches="tight")
except:
pass
print(f"\nfull chains not plotted (number of parameters ({ndim}) exceeds 20. use result.plot_burnin_chains()")
print(f"saving chain plot for the first 20 parameters")
pl_pars = list(result._par_names)[:20]
fig = result.plot_chains(pl_pars)
fig.savefig("chains.png", bbox_inches="tight")

try:
fig = result.plot_corner()
fig.savefig("corner.png", bbox_inches="tight")
except:
print(f"\ncorner not plotted (number of parameters ({ndim}) exceeds 14. use result.plot_corner(force_plot=True)")
print("saving corner plot for the first 14 parameters")
pl_pars = list(result._par_names)[:14]
fig = result.plot_corner(pl_pars,force_plot=True)
fig.savefig("corner.png", bbox_inches="tight")

matplotlib.use(__default_backend__)

Expand Down
Loading

0 comments on commit df1917b

Please sign in to comment.