-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Moved estimation to a separate package
- Loading branch information
1 parent
c6f9d31
commit a13d83b
Showing
34 changed files
with
1,348 additions
and
1,200 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""This module contains the CubicSplineRegressionEstimator class, for estimating continuous outcomes with changes in behaviour""" | ||
|
||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
from math import ceil | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import statsmodels.api as sm | ||
import statsmodels.formula.api as smf | ||
from patsy import dmatrix # pylint: disable = no-name-in-module | ||
from patsy import ModelDesc | ||
from statsmodels.regression.linear_model import RegressionResultsWrapper | ||
from statsmodels.tools.sm_exceptions import PerfectSeparationError | ||
from lifelines import CoxPHFitter | ||
|
||
from causal_testing.specification.variable import Variable | ||
from causal_testing.specification.capabilities import TreatmentSequence, Capability | ||
from causal_testing.estimation.estimator import Estimator | ||
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CubicSplineRegressionEstimator(LinearRegressionEstimator): | ||
"""A Cubic Spline Regression Estimator is a parametric estimator which restricts the variables in the data to a | ||
combination of parameters and basis functions of the variables. | ||
""" | ||
|
||
def __init__( | ||
# pylint: disable=too-many-arguments | ||
self, | ||
treatment: str, | ||
treatment_value: float, | ||
control_value: float, | ||
adjustment_set: set, | ||
outcome: str, | ||
basis: int, | ||
df: pd.DataFrame = None, | ||
effect_modifiers: dict[Variable:Any] = None, | ||
formula: str = None, | ||
alpha: float = 0.05, | ||
expected_relationship=None, | ||
): | ||
super().__init__( | ||
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha | ||
) | ||
|
||
self.expected_relationship = expected_relationship | ||
|
||
if effect_modifiers is None: | ||
effect_modifiers = [] | ||
|
||
if formula is None: | ||
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers)) | ||
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})" | ||
|
||
def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series: | ||
model = self._run_linear_regression() | ||
|
||
x = {"Intercept": 1, self.treatment: self.treatment_value} | ||
if adjustment_config is not None: | ||
for k, v in adjustment_config.items(): | ||
x[k] = v | ||
if self.effect_modifiers is not None: | ||
for k, v in self.effect_modifiers.items(): | ||
x[k] = v | ||
|
||
treatment = model.predict(x).iloc[0] | ||
|
||
x[self.treatment] = self.control_value | ||
control = model.predict(x).iloc[0] | ||
|
||
return pd.Series(treatment - control) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""This module contains the Estimator abstract class""" | ||
|
||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
from math import ceil | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import statsmodels.api as sm | ||
import statsmodels.formula.api as smf | ||
from patsy import dmatrix # pylint: disable = no-name-in-module | ||
from patsy import ModelDesc | ||
from statsmodels.regression.linear_model import RegressionResultsWrapper | ||
from statsmodels.tools.sm_exceptions import PerfectSeparationError | ||
from lifelines import CoxPHFitter | ||
|
||
from causal_testing.specification.variable import Variable | ||
from causal_testing.specification.capabilities import TreatmentSequence, Capability | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Estimator(ABC): | ||
# pylint: disable=too-many-instance-attributes | ||
"""An estimator contains all of the information necessary to compute a causal estimate for the effect of changing | ||
a set of treatment variables to a set of values. | ||
All estimators must implement the following two methods: | ||
1) add_modelling_assumptions: The validity of a model-assisted causal inference result depends on whether | ||
the modelling assumptions imposed by a model actually hold. Therefore, for each model, is important to state | ||
the modelling assumption upon which the validity of the results depend. To achieve this, the estimator object | ||
maintains a list of modelling assumptions (as strings). If a user wishes to implement their own estimator, they | ||
must implement this method and add all assumptions to the list of modelling assumptions. | ||
2) estimate_ate: All estimators must be capable of returning the average treatment effect as a minimum. That is, the | ||
average effect of the intervention (changing treatment from control to treated value) on the outcome of interest | ||
adjusted for all confounders. | ||
""" | ||
|
||
def __init__( | ||
# pylint: disable=too-many-arguments | ||
self, | ||
treatment: str, | ||
treatment_value: float, | ||
control_value: float, | ||
adjustment_set: set, | ||
outcome: str, | ||
df: pd.DataFrame = None, | ||
effect_modifiers: dict[str:Any] = None, | ||
alpha: float = 0.05, | ||
query: str = "", | ||
): | ||
self.treatment = treatment | ||
self.treatment_value = treatment_value | ||
self.control_value = control_value | ||
self.adjustment_set = adjustment_set | ||
self.outcome = outcome | ||
self.alpha = alpha | ||
self.df = df.query(query) if query else df | ||
|
||
if effect_modifiers is None: | ||
self.effect_modifiers = {} | ||
elif isinstance(effect_modifiers, dict): | ||
self.effect_modifiers = effect_modifiers | ||
else: | ||
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable") | ||
self.modelling_assumptions = [] | ||
if query: | ||
self.modelling_assumptions.append(query) | ||
self.add_modelling_assumptions() | ||
logger.debug("Effect Modifiers: %s", self.effect_modifiers) | ||
|
||
@abstractmethod | ||
def add_modelling_assumptions(self): | ||
""" | ||
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that | ||
must hold if the resulting causal inference is to be considered valid. | ||
""" | ||
|
||
def compute_confidence_intervals(self) -> list[float, float]: | ||
""" | ||
Estimate the 95% Wald confidence intervals for the effect of changing the treatment from control values to | ||
treatment values on the outcome. | ||
:return: 95% Wald confidence intervals. | ||
""" |
File renamed without changes.
Oops, something went wrong.