Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback function as parameter in the feature usage functions #396

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/how-to-guides/client-callback-function.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
---
layout: default
title: How to use callback function in feathr client
parent: Feathr How-to Guides
---

## What is a callback function

A callback function is a function that is passed to another function as an argument. It can be used to extend that function's behavior according to the user's needs.

## How to use callback functions

Currently, the following functions in the Feathr client accept a callback as an argument:

- get_online_features
- multi_get_online_features
- get_offline_features
- monitor_features
- materialize_features

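These functions accept two optional parameters named **callback** and **params**.
`callback` must be an async function (it is scheduled as a task on the asyncio event loop), and `params` is a dictionary holding the arguments for the callback function. The callback is only invoked when both `callback` and `params` are provided.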

An example on how to use it:

```python
# inside notebook
client = FeathrClient(config_path)
client.get_offline_features(observation_settings, feature_query, output_path, callback=callback, params=params)

# users can define their own callback function and params
params = {"param1":"value1", "param2":"value2"}

async def callback(params):
    import httpx
    async with httpx.AsyncClient() as requestHandler:
        response = await requestHandler.post('https://some-endpoint', json = params)
        return response

```
58 changes: 47 additions & 11 deletions feathr_project/feathr/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import tempfile
import asyncio
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -264,13 +265,16 @@ def _get_registry_client(self):
"""
return self.registry._get_registry_client()

def get_online_features(self, feature_table, key, feature_names):
"""Fetches feature value for a certain key from a online feature table.
def get_online_features(self, feature_table, key, feature_names, callback: callable = None, params: dict = None):
"""Fetches feature value for a certain key from a online feature table. There is an optional callback function
and the params to extend this function's capability, e.g. for a consumer of the features.

Args:
feature_table: the name of the feature table.
key: the key of the entity
feature_names: list of feature names to fetch
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function

Return:
A list of feature values for this entity. It's ordered by the requested feature names.
Expand All @@ -283,15 +287,21 @@ def get_online_features(self, feature_table, key, feature_names):
"""
redis_key = self._construct_redis_key(feature_table, key)
res = self.redis_clint.hmget(redis_key, *feature_names)
return self._decode_proto(res)
feature_values = self._decode_proto(res)
if (callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))
return feature_values

def multi_get_online_features(self, feature_table, keys, feature_names):
def multi_get_online_features(self, feature_table, keys, feature_names, callback: callable = None, params: dict = None):
"""Fetches feature value for a list of keys from a online feature table. This is the batch version of the get API.

Args:
feature_table: the name of the feature table.
keys: list of keys for the entities
feature_names: list of feature names to fetch
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function

Return:
A list of feature values for the requested entities. It's ordered by the requested feature names. For
Expand All @@ -312,6 +322,10 @@ def multi_get_online_features(self, feature_table, keys, feature_names):
for feature_list in pipeline_result:
decoded_pipeline_result.append(self._decode_proto(feature_list))

if (callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))

return dict(zip(keys, decoded_pipeline_result))

def _decode_proto(self, feature_list):
Expand Down Expand Up @@ -412,15 +426,20 @@ def get_offline_features(self,
output_path: str,
execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {},
udf_files = None,
verbose: bool = False
verbose: bool = False,
callback: callable = None,
params: dict = None
):
"""
Get offline features for the observation dataset
Get offline features for the observation dataset. There is an optional callback function and the params
to extend this function's capability, e.g. for a consumer of the features.
Args:
observation_settings: settings of the observation data, e.g. timestamp columns, input path, etc.
feature_query: features that are requested to add onto the observation data
output_path: output path of job, i.e. the observation data with features attached.
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function
"""
feature_queries = feature_query if isinstance(feature_query, List) else [feature_query]
feature_names = []
Expand Down Expand Up @@ -457,7 +476,11 @@ def get_offline_features(self,
FeaturePrinter.pretty_print_feature_query(feature_query)

write_to_file(content=config, full_file_name=config_file_path)
return self._get_offline_features_with_config(config_file_path, execution_configuratons, udf_files=udf_files)
job_info = self._get_offline_features_with_config(config_file_path, execution_configuratons, udf_files=udf_files)
if (callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))
return job_info

def _get_offline_features_with_config(self, feature_join_conf_path='feature_join_conf/feature_join.conf', execution_configuratons: Dict[str,str] = {}, udf_files=[]):
"""Joins the features to your offline observation dataset based on the join config.
Expand Down Expand Up @@ -534,21 +557,30 @@ def wait_job_to_finish(self, timeout_sec: int = 300):
else:
raise RuntimeError('Spark job failed.')

def monitor_features(self, settings: MonitoringSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False):
"""Create a offline job to generate statistics to monitor feature data
def monitor_features(self, settings: MonitoringSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, callback: callable = None, params: dict = None):
"""Create a offline job to generate statistics to monitor feature data. There is an optional
callback function and the params to extend this function's capability, e.g. for a consumer of the features.

Args:
settings: Feature monitoring settings
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function.
"""
self.materialize_features(settings, execution_configuratons, verbose)
if (callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))

def materialize_features(self, settings: MaterializationSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False):
"""Materialize feature data
def materialize_features(self, settings: MaterializationSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, callback: callable = None, params: dict = None):
"""Materialize feature data. There is an optional callback function and the params
to extend this function's capability, e.g. for a consumer of the feature store.

Args:
settings: Feature materialization settings
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function
"""
# produce materialization config
for end in settings.get_backfill_cutoff_time():
Expand All @@ -575,6 +607,10 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf
# Pretty print feature_names of materialized features
if verbose and settings:
FeaturePrinter.pretty_print_materialize_features(settings)

if (callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))

def _materialize_features_with_config(self, feature_gen_conf_path: str = 'feature_gen_conf/feature_gen.conf',execution_configuratons: Dict[str,str] = {}, udf_files=[]):
"""Materializes feature data based on the feature generation config. The feature
Expand Down
124 changes: 124 additions & 0 deletions feathr_project/test/test_client_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os
import asyncio
import unittest.mock as mock
import time
from subprocess import call
from datetime import datetime, timedelta

from pathlib import Path
from feathr import ValueType
from feathr import FeatureQuery
from feathr import ObservationSettings
from feathr import TypedKey
from test_fixture import basic_test_setup
from test_fixture import get_online_test_table_name
from feathr.definition._materialization_utils import _to_materialization_config
from feathr import (BackfillTime, MaterializationSettings)
from feathr import (BackfillTime, MaterializationSettings, FeatureQuery,
ObservationSettings, SparkExecutionConfiguration)
from feathr import RedisSink, HdfsSink


# Arguments handed to the callback by the client; "wait" is the number of
# seconds the sample callback sleeps.
params = {"wait": 0.1}


async def sample_callback(params):
    """Print *params*, then sleep without blocking the event loop.

    Args:
        params: dictionary of callback arguments. The optional "wait" entry
            is the sleep duration in seconds (defaults to 0.1).
    """
    print(params)
    await asyncio.sleep(params.get("wait", 0.1))


# side_effect (rather than return_value) makes every call build a fresh
# coroutine. A single coroutine object created at import time would be handed
# out on every call, and a coroutine cannot be scheduled or awaited twice.
callback = mock.MagicMock(side_effect=sample_callback)

def test_client_callback_offline_feature():
    """Check that get_offline_features invokes the callback with params."""
    # The mock is shared by every test in this module; clear earlier calls so
    # the assertion below can only be satisfied by this test's own call.
    callback.reset_mock()

    test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace"
    client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"))

    location_id = TypedKey(key_column="DOLocationID",
                           key_column_type=ValueType.INT32,
                           description="location id in NYC",
                           full_name="nyc_taxi.location_id")
    feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id)

    settings = ObservationSettings(
        observation_path="wasbs://[email protected]/sample_data/green_tripdata_2020-04.csv",
        event_timestamp_column="lpep_dropoff_datetime",
        timestamp_format="yyyy-MM-dd HH:mm:ss")

    # Minute/second suffix keeps output paths of successive runs apart.
    now = datetime.now()
    output_path = f"dbfs:/feathrazure_cijob_{now.minute}_{now.second}.avro"

    client.get_offline_features(observation_settings=settings,
                                feature_query=feature_query,
                                output_path=output_path,
                                callback=callback,
                                params=params)
    callback.assert_called_with(params)


def test_client_callback_materialization():
    """Check that materialize_features invokes the callback with params."""
    # The mock is shared by every test in this module; clear earlier calls so
    # the assertion below can only be satisfied by this test's own call.
    callback.reset_mock()

    online_test_table = get_online_test_table_name("nycTaxiCITable")
    test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace"

    client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"))
    backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))
    redisSink = RedisSink(table_name=online_test_table)
    settings = MaterializationSettings("nycTaxiTable",
                                       sinks=[redisSink],
                                       feature_names=[
                                           "f_location_avg_fare", "f_location_max_fare"],
                                       backfill_time=backfill_time)
    client.materialize_features(settings, callback=callback, params=params)
    callback.assert_called_with(params)

def test_client_callback_monitor_features():
    """Check that monitor_features invokes the callback with params."""
    # The mock is shared by every test in this module; clear earlier calls so
    # the assertion below can only be satisfied by this test's own call.
    callback.reset_mock()

    online_test_table = get_online_test_table_name("nycTaxiCITable")
    test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace"

    client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"))
    backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))
    redisSink = RedisSink(table_name=online_test_table)
    settings = MaterializationSettings("nycTaxiTable",
                                       sinks=[redisSink],
                                       feature_names=[
                                           "f_location_avg_fare", "f_location_max_fare"],
                                       backfill_time=backfill_time)
    client.monitor_features(settings, callback=callback, params=params)
    callback.assert_called_with(params)

def test_client_callback_get_online_features():
    """Check that get_online_features invokes the callback with params."""
    # The mock is shared by every test in this module; clear earlier calls so
    # the assertion below can only be satisfied by this test's own call.
    callback.reset_mock()

    online_test_table = get_online_test_table_name("nycTaxiCITable")
    test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace"

    client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"))
    backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))
    redisSink = RedisSink(table_name=online_test_table)
    settings = MaterializationSettings("nycTaxiTable",
                                       sinks=[redisSink],
                                       feature_names=[
                                           "f_location_avg_fare", "f_location_max_fare"],
                                       backfill_time=backfill_time)
    # Materialization runs without a callback here, so the mock is only
    # checked after the online lookup that actually receives it.
    client.materialize_features(settings)
    client.wait_job_to_finish(timeout_sec=900)
    # wait for a few secs for the data to come in redis
    time.sleep(5)
    # NOTE(review): the lookup reads 'nycTaxiDemoFeature' while the sink above
    # writes to online_test_table -- confirm which table is intended.
    client.get_online_features('nycTaxiDemoFeature', '265', ['f_location_avg_fare', 'f_location_max_fare'], callback=callback, params=params)
    callback.assert_called_with(params)


def test_client_callback_multi_get_online_features():
    """Check that multi_get_online_features invokes the callback with params."""
    # The mock is shared by every test in this module; clear earlier calls so
    # the assertion below can only be satisfied by this test's own call.
    callback.reset_mock()

    online_test_table = get_online_test_table_name("nycTaxiCITable")
    test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace"

    client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"))
    backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))
    redisSink = RedisSink(table_name=online_test_table)
    settings = MaterializationSettings("nycTaxiTable",
                                       sinks=[redisSink],
                                       feature_names=[
                                           "f_location_avg_fare", "f_location_max_fare"],
                                       backfill_time=backfill_time)
    # Materialization runs without a callback here, so the mock is only
    # checked after the online lookup that actually receives it.
    client.materialize_features(settings)
    client.wait_job_to_finish(timeout_sec=900)
    # wait for a few secs for the data to come in redis
    time.sleep(5)
    # NOTE(review): the lookup reads 'nycTaxiDemoFeature' while the sink above
    # writes to online_test_table -- confirm which table is intended.
    client.multi_get_online_features('nycTaxiDemoFeature', ["239", "265"], ['f_location_avg_fare', 'f_location_max_fare'], callback=callback, params=params)
    callback.assert_called_with(params)