Skip to content

Commit

Permalink
Moved estimation to a separate package
Browse files Browse the repository at this point in the history
  • Loading branch information
jmafoster1 committed Aug 6, 2024
1 parent c6f9d31 commit a13d83b
Show file tree
Hide file tree
Showing 34 changed files with 1,348 additions and 1,200 deletions.
75 changes: 75 additions & 0 deletions causal_testing/estimation/cubic_spline_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""This module contains the CubicSplineRegressionEstimator class, for estimating continuous outcomes with changes in behaviour"""

import logging
from abc import ABC, abstractmethod
from typing import Any
from math import ceil

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from patsy import dmatrix # pylint: disable = no-name-in-module
from patsy import ModelDesc
from statsmodels.regression.linear_model import RegressionResultsWrapper
from statsmodels.tools.sm_exceptions import PerfectSeparationError
from lifelines import CoxPHFitter

from causal_testing.specification.variable import Variable
from causal_testing.specification.capabilities import TreatmentSequence, Capability
from causal_testing.estimation.estimator import Estimator
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator

logger = logging.getLogger(__name__)


class CubicSplineRegressionEstimator(LinearRegressionEstimator):
"""A Cubic Spline Regression Estimator is a parametric estimator which restricts the variables in the data to a
combination of parameters and basis functions of the variables.
"""

def __init__(
# pylint: disable=too-many-arguments
self,
treatment: str,
treatment_value: float,
control_value: float,
adjustment_set: set,
outcome: str,
basis: int,
df: pd.DataFrame = None,
effect_modifiers: dict[Variable:Any] = None,
formula: str = None,
alpha: float = 0.05,
expected_relationship=None,
):
super().__init__(
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
)

self.expected_relationship = expected_relationship

if effect_modifiers is None:
effect_modifiers = []

if formula is None:
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"

def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
model = self._run_linear_regression()

x = {"Intercept": 1, self.treatment: self.treatment_value}
if adjustment_config is not None:
for k, v in adjustment_config.items():
x[k] = v
if self.effect_modifiers is not None:
for k, v in self.effect_modifiers.items():
x[k] = v

treatment = model.predict(x).iloc[0]

x[self.treatment] = self.control_value
control = model.predict(x).iloc[0]

return pd.Series(treatment - control)
87 changes: 87 additions & 0 deletions causal_testing/estimation/estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""This module contains the Estimator abstract class"""

import logging
from abc import ABC, abstractmethod
from typing import Any
from math import ceil

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from patsy import dmatrix # pylint: disable = no-name-in-module
from patsy import ModelDesc
from statsmodels.regression.linear_model import RegressionResultsWrapper
from statsmodels.tools.sm_exceptions import PerfectSeparationError
from lifelines import CoxPHFitter

from causal_testing.specification.variable import Variable
from causal_testing.specification.capabilities import TreatmentSequence, Capability

logger = logging.getLogger(__name__)


class Estimator(ABC):
# pylint: disable=too-many-instance-attributes
"""An estimator contains all of the information necessary to compute a causal estimate for the effect of changing
a set of treatment variables to a set of values.
All estimators must implement the following two methods:
1) add_modelling_assumptions: The validity of a model-assisted causal inference result depends on whether
the modelling assumptions imposed by a model actually hold. Therefore, for each model, is important to state
the modelling assumption upon which the validity of the results depend. To achieve this, the estimator object
maintains a list of modelling assumptions (as strings). If a user wishes to implement their own estimator, they
must implement this method and add all assumptions to the list of modelling assumptions.
2) estimate_ate: All estimators must be capable of returning the average treatment effect as a minimum. That is, the
average effect of the intervention (changing treatment from control to treated value) on the outcome of interest
adjusted for all confounders.
"""

def __init__(
# pylint: disable=too-many-arguments
self,
treatment: str,
treatment_value: float,
control_value: float,
adjustment_set: set,
outcome: str,
df: pd.DataFrame = None,
effect_modifiers: dict[str:Any] = None,
alpha: float = 0.05,
query: str = "",
):
self.treatment = treatment
self.treatment_value = treatment_value
self.control_value = control_value
self.adjustment_set = adjustment_set
self.outcome = outcome
self.alpha = alpha
self.df = df.query(query) if query else df

if effect_modifiers is None:
self.effect_modifiers = {}
elif isinstance(effect_modifiers, dict):
self.effect_modifiers = effect_modifiers
else:
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable")
self.modelling_assumptions = []
if query:
self.modelling_assumptions.append(query)
self.add_modelling_assumptions()
logger.debug("Effect Modifiers: %s", self.effect_modifiers)

@abstractmethod
def add_modelling_assumptions(self):
"""
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
must hold if the resulting causal inference is to be considered valid.
"""

def compute_confidence_intervals(self) -> list[float, float]:
"""
Estimate the 95% Wald confidence intervals for the effect of changing the treatment from control values to
treatment values on the outcome.
:return: 95% Wald confidence intervals.
"""
File renamed without changes.
Loading

0 comments on commit a13d83b

Please sign in to comment.