Skip to content

Commit

Permalink
fix model for more than 1 decorrelation vector
Browse files Browse the repository at this point in the history
  • Loading branch information
j-faria committed Feb 2, 2024
1 parent edd0666 commit c27c064
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/kima/pykima/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,14 @@ def _select_posterior_samples(self, Np=None, mask=None):
return self.posterior_sample[mask & mask_Np].copy()

def log_prior(self, sample):
""" Calculate the log prior for a given sample
Args:
sample (array): sample for which to calculate the log prior
To evaluate at all posterior samples, consider using
np.apply_along_axis(self.log_prior, 1, self.posterior_sample)
"""
logp = []
for p, v in zip(self.parameter_priors, sample):
if p is None:
Expand Down Expand Up @@ -1536,7 +1544,7 @@ def print_sample(self, p, star_mass=1.0, show_a=False, show_m=False,
else:
print(' ', p[self.indices['jitter']])

if self.indicator_correlations:
if hasattr(self, 'indicator_correlations') and self.indicator_correlations:
print('indicator correlations:')
c = p[self.indices['betas']]
print(f' {c}')
Expand Down Expand Up @@ -1892,14 +1900,14 @@ def eval_model(self, sample, t = None,
else:
v += np.polyval(trend_par, t - self.tmiddle)

# TODO: fix this for more than one indicator
if self.indicator_correlations and include_indicator_correlations:
c = sample[self.indices['betas']]
# TODO: check if _extra_data is always read correctly
if hasattr(self, 'indicator_correlations') and self.indicator_correlations and include_indicator_correlations:
betas = sample[self.indices['betas']].copy()
interp_u = np.zeros_like(t)
for i, ai in enumerate(self.activity_indicators):
for i, (c, ai) in enumerate(zip(betas, self.activity_indicators)):
if ai != '':
interp_u += np.interp(t, self.data.t, self._extra_data[:, 3 + i])
v += c * interp_u
interp_u += c * np.interp(t, self.data.t, self._extra_data[:, 3 + i])
v += interp_u

return v

Expand Down Expand Up @@ -2492,6 +2500,9 @@ def eta4(self):
plot5 = display.plot_gp_corner
plot_gp_corner = display.plot_gp_corner

#
corner_planet_parameters = display.corner_planet_parameters


def get_sorted_planet_samples(self, full=True):
# all posterior samples for the planet parameters
Expand Down

0 comments on commit c27c064

Please sign in to comment.