Skip to content

Commit

Permalink
Merge pull request #25 from bruin-data/feature/introduce-stripe-source
Browse files Browse the repository at this point in the history
Feature/introduce stripe source
  • Loading branch information
karakanb authored Aug 14, 2024
2 parents 3897f4c + 92f206c commit f813c47
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ingestr/src/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
NotionSource,
ShopifySource,
SqlSource,
StripeAnalyticsSource,
)

SQL_SOURCE_SCHEMES = [
Expand Down Expand Up @@ -102,6 +103,9 @@ def get_source(self) -> SourceProtocol:
return ShopifySource()
elif self.source_scheme == "gorgias":
return GorgiasSource()
elif self.source_scheme == "stripe":
return StripeAnalyticsSource()

else:
raise ValueError(f"Unsupported source scheme: {self.source_scheme}")

Expand Down
55 changes: 55 additions & 0 deletions ingestr/src/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ingestr.src.notion import notion_databases
from ingestr.src.shopify import shopify_source
from ingestr.src.sql_database import sql_table
from ingestr.src.stripe_analytics import stripe_source
from ingestr.src.table_definition import table_string_to_dataclass


Expand Down Expand Up @@ -295,3 +296,57 @@ def dlt_source(self, uri: str, table: str, **kwargs):
range_names=[table_fields.dataset],
get_named_ranges=False,
)


class StripeAnalyticsSource:
    """ingestr source that loads one Stripe resource via ``stripe_source``.

    The connection URI carries the Stripe secret key in its query string
    (``?api_key=...``) and the table name selects the Stripe resource to load.
    """

    # Canonical Stripe resource names accepted as tables. The exact casing
    # matters downstream, so user input is always mapped back onto one of
    # these spellings.
    _SUPPORTED_ENDPOINTS = (
        "Subscription",
        "Account",
        "Coupon",
        "Customer",
        "Product",
        "Price",
        "BalanceTransaction",
        "Invoice",
        "Event",
    )

    def handles_incrementality(self) -> bool:
        """Report that this source manages incremental loading by itself."""
        return True

    @classmethod
    def _resolve_endpoint(cls, table: str) -> str:
        """Map a user-supplied table name onto its canonical endpoint name.

        Matching ignores case and underscores, so ``balancetransaction``,
        ``balance_transaction`` and ``BalanceTransaction`` all resolve to
        ``BalanceTransaction``. ``str.capitalize`` cannot be used here because
        it lower-cases every character after the first, which breaks
        multi-word names.

        Raises:
            ValueError: If the table does not match a supported endpoint.
        """
        wanted = table.replace("_", "").lower()
        for endpoint in cls._SUPPORTED_ENDPOINTS:
            if endpoint.lower() == wanted:
                return endpoint
        raise ValueError(
            f"Resource '{table}' is not supported for stripe source yet, if you are interested in it please create a GitHub issue at https://github.com/bruin-data/ingestr"
        )

    def dlt_source(self, uri: str, table: str, **kwargs):
        """Build the dlt source restricted to the requested Stripe resource.

        Args:
            uri: Source URI whose query string contains ``api_key``.
            table: Name of the Stripe resource to load (case-insensitive).
            **kwargs: Optional ``interval_start`` / ``interval_end`` values,
                forwarded as ``start_date`` / ``end_date``. ``incremental_key``
                must not be provided.

        Returns:
            The ``stripe_source`` instance narrowed to the selected resource.

        Raises:
            ValueError: If ``incremental_key`` is given, the API key is
                missing from the URI, or the table is not supported.
        """
        if kwargs.get("incremental_key"):
            raise ValueError(
                "Stripe takes care of incrementality on its own, you should not provide incremental_key"
            )

        source_params = parse_qs(urlparse(uri).query)
        api_key = source_params.get("api_key")
        if not api_key:
            raise ValueError("api_key in the URI is required to connect to Stripe")

        endpoint = self._resolve_endpoint(table)

        # Only forward the date bounds that were actually supplied.
        date_args = {}
        if kwargs.get("interval_start"):
            date_args["start_date"] = kwargs.get("interval_start")
        if kwargs.get("interval_end"):
            date_args["end_date"] = kwargs.get("interval_end")

        return stripe_source(
            endpoints=[
                endpoint,
            ],
            # parse_qs returns a list per key; the first value is the key.
            stripe_secret_key=api_key[0],
            **date_args,
        ).with_resources(endpoint)
99 changes: 99 additions & 0 deletions ingestr/src/stripe_analytics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""This source uses Stripe API and dlt to load data such as Customer, Subscription, Event etc. to the database and to calculate the MRR and churn rate."""

from typing import Any, Dict, Generator, Iterable, Optional, Tuple

import dlt
import stripe
from dlt.sources import DltResource
from pendulum import DateTime

from .helpers import pagination, transform_date
from .settings import ENDPOINTS, INCREMENTAL_ENDPOINTS


@dlt.source
def stripe_source(
    endpoints: Tuple[str, ...] = ENDPOINTS,
    stripe_secret_key: str = dlt.secrets.value,
    start_date: Optional[DateTime] = None,
    end_date: Optional[DateTime] = None,
) -> Iterable[DltResource]:
    """
    Yield one full-refresh dlt resource per requested Stripe endpoint.

    Stripe API responses carry no "updated" key, so every resource produced
    here is written in 'replace' mode. Any endpoint can be loaded this way;
    endpoints such as 'Event' or 'Invoice' may alternatively be loaded
    incrementally (see incremental_stripe_source).

    Args:
        endpoints (Tuple[str, ...]): Names of the endpoints to load. Defaults to ENDPOINTS.
        stripe_secret_key (str): Stripe API key. Defaults to the value in the `dlt.secrets` object.
        start_date (Optional[DateTime]): Optional inclusive lower bound on the creation date. Defaults to None.
        end_date (Optional[DateTime]): Optional exclusive upper bound on the creation date. Defaults to None.

    Returns:
        Iterable[DltResource]: One resource per endpoint, limited to objects created
        at or after 'start_date' and before 'end_date'.
    """
    # The stripe client is configured through module-level settings.
    stripe.api_version = "2022-11-15"
    stripe.api_key = stripe_secret_key

    def stripe_resource(
        endpoint: str,
    ) -> Generator[Dict[Any, Any], Any, None]:
        # Stream every page for this endpoint within the requested window.
        for page in pagination(endpoint, start_date, end_date):
            yield page

    for endpoint_name in endpoints:
        resource = dlt.resource(
            stripe_resource,
            name=endpoint_name,
            write_disposition="replace",
        )
        yield resource(endpoint_name)


@dlt.source
def incremental_stripe_source(
    endpoints: Tuple[str, ...] = INCREMENTAL_ENDPOINTS,
    stripe_secret_key: str = dlt.secrets.value,
    initial_start_date: Optional[DateTime] = None,
    end_date: Optional[DateTime] = None,
) -> Iterable[DltResource]:
    """
    Yield one append-mode, incrementally loaded dlt resource per endpoint.

    Stripe API responses carry no "updated" key, so incremental loading is only
    possible for endpoints whose objects are never edited. Each resource tracks
    the "created" field and appends only data that has not been loaded yet,
    avoiding duplicates and repeated full downloads.

    Args:
        endpoints (tuple): Names of the endpoints to load. Defaults to INCREMENTAL_ENDPOINTS.
        stripe_secret_key (str): Stripe API key. Defaults to the value in the `dlt.secrets` object.
        initial_start_date (Optional[DateTime]): Initial value for dlt.sources.incremental.
            When given, the first run loads only data created from this date onwards.
            Defaults to None. Format: datetime(YYYY, MM, DD).
        end_date (Optional[DateTime]): Optional end date to limit the data retrieved.
            Defaults to None. Format: datetime(YYYY, MM, DD).

    Returns:
        Iterable[DltResource]: Resources containing only data that has not yet been loaded.
    """
    stripe.api_version = "2022-11-15"
    stripe.api_key = stripe_secret_key

    # -1 is the initial cursor value when no explicit start date is supplied.
    if initial_start_date is not None:
        start_date_unix = transform_date(initial_start_date)
    else:
        start_date_unix = -1

    def incremental_resource(
        endpoint: str,
        created: Optional[Any] = dlt.sources.incremental(
            "created", initial_value=start_date_unix
        ),
    ) -> Generator[Dict[Any, Any], Any, None]:
        # Resume from the last "created" value tracked by the incremental cursor.
        cursor_start = created.last_value
        for page in pagination(endpoint, start_date=cursor_start, end_date=end_date):
            yield page

    for endpoint_name in endpoints:
        resource = dlt.resource(
            incremental_resource,
            name=endpoint_name,
            write_disposition="append",
            primary_key="id",
        )
        yield resource(endpoint_name)
68 changes: 68 additions & 0 deletions ingestr/src/stripe_analytics/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Stripe analytics source helpers"""

from typing import Any, Dict, Iterable, Optional, Union

import stripe
from dlt.common import pendulum
from dlt.common.typing import TDataItem
from pendulum import DateTime


def pagination(
    endpoint: str, start_date: Optional[Any] = None, end_date: Optional[Any] = None
) -> Iterable[TDataItem]:
    """
    Walk through every page of an endpoint, yielding one page of items at a time.

    Args:
        endpoint (str): The endpoint to retrieve data from.
        start_date (Optional[Any]): Optional start date to limit the data retrieved. Defaults to None.
        end_date (Optional[Any]): Optional end date to limit the data retrieved. Defaults to None.

    Returns:
        Iterable[TDataItem]: The "data" list of each response, in the order received.
    """
    cursor = None
    more_pages = True
    while more_pages:
        page = stripe_get_data(
            endpoint,
            start_date=start_date,
            end_date=end_date,
            starting_after=cursor,
        )
        items = page["data"]

        # The id of the last item on this page is the cursor for the next request.
        if len(items) > 0:
            cursor = items[-1]["id"]
        yield items

        more_pages = page["has_more"]


def transform_date(date: Union[str, "DateTime", int]) -> int:
    """Convert a date-like value into a Unix timestamp in whole seconds.

    Args:
        date: A UTC string formatted as ``YYYY-MM-DDTHH:MM:SSZ``, a datetime
            object (a pendulum ``DateTime`` or any other ``datetime.datetime``),
            or a value that is already an integer timestamp.

    Returns:
        The timestamp as an int. Values that are neither strings nor
        datetimes are returned unchanged.

    Raises:
        ValueError: If a string does not match the expected format.
    """
    from datetime import datetime, timezone

    if isinstance(date, str):
        # The pattern uses strptime directives, which pendulum.from_format does
        # not understand (it expects tokens such as YYYY-MM-DD), so parse with
        # the standard library. The trailing "Z" denotes UTC.
        date = datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").replace(
            tzinfo=timezone.utc
        )
    if isinstance(date, datetime):
        # Checking against datetime also covers pendulum's DateTime subclass
        # and keeps plain datetimes from slipping through unconverted.
        date = int(date.timestamp())
    return date


def stripe_get_data(
    resource: str,
    start_date: Optional[Any] = None,
    end_date: Optional[Any] = None,
    **kwargs: Any,
) -> Dict[Any, Any]:
    """Fetch a single page of up to 100 objects for a Stripe resource.

    Args:
        resource: Name of the attribute on the ``stripe`` module to list.
        start_date: Optional inclusive lower bound for the ``created`` filter.
        end_date: Optional exclusive upper bound for the ``created`` filter.
        **kwargs: Extra parameters forwarded to the resource's ``list`` call.

    Returns:
        The list response converted to a plain dict.
    """
    # Truthy bounds are normalised through transform_date; falsy ones pass as-is.
    created_filter = {
        "gte": transform_date(start_date) if start_date else start_date,
        "lt": transform_date(end_date) if end_date else end_date,
    }

    # Subscriptions are requested with status "all".
    if resource == "Subscription":
        kwargs["status"] = "all"

    stripe_class = getattr(stripe, resource)
    listing = stripe_class.list(created=created_filter, limit=100, **kwargs)
    return dict(listing)
14 changes: 14 additions & 0 deletions ingestr/src/stripe_analytics/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Stripe analytics source settings and constants"""

# The most popular endpoints.
# The full list of Stripe API endpoints is available at https://stripe.com/docs/api.
ENDPOINTS = ("Subscription", "Account", "Coupon", "Customer", "Product", "Price")

# Endpoints that can be loaded incrementally.
INCREMENTAL_ENDPOINTS = (
    "Event",
    "Invoice",
    "BalanceTransaction",
)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ exclude = [
'src/google_sheets/.*',
'src/shopify/.*',
'src/gorgias/.*',
'src/stripe_analytics/.*'
]

[[tool.mypy.overrides]]
Expand All @@ -77,6 +78,7 @@ module = [
"ingestr.src.google_sheets.*",
"ingestr.src.shopify.*",
"ingestr.src.gorgias.*",
"ingestr.src.stripe_analytics.*",
]
follow_imports = "skip"

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ SQLAlchemy==1.4.52
sqlalchemy2-stubs==0.0.2a38
tqdm==4.66.2
typer==0.12.3
stripe==10.7.0

0 comments on commit f813c47

Please sign in to comment.