diff --git a/docs/how-to-guides/client-callback-function.md b/docs/how-to-guides/client-callback-function.md new file mode 100644 index 000000000..e7bfca830 --- /dev/null +++ b/docs/how-to-guides/client-callback-function.md @@ -0,0 +1,42 @@ +--- +layout: default +title: How to use a callback function in the Feathr client +parent: Feathr How-to Guides +--- + +## What is a callback function + +A callback function is a function that is passed to another function as an argument. It can be used to extend that function's behavior according to the user's needs. + +## How to use callback functions + +We can pass a callback function when initializing the Feathr client. + +```python +client = FeathrClient(config_path, callback) +``` + +The functions below accept an optional parameter named **params**. **params** is a dictionary in which the user can pass arguments for the callback function. + +- get_online_features +- multi_get_online_features +- get_offline_features +- monitor_features +- materialize_features + +An example of how to use it: + +```python +# inside notebook +client = FeathrClient(config_path, callback) +params = {"param1":"value1", "param2":"value2"} +client.get_offline_features(observation_settings, feature_query, output_path, params=params) + +# users can define their own callback function +async def callback(params): + import httpx + async with httpx.AsyncClient() as requestHandler: + response = await requestHandler.post('https://some-endpoint', json=params) + return response + +``` diff --git a/feathr_project/feathr/client.py b/feathr_project/feathr/client.py index 9575a3151..2a8c4cd86 100644 --- a/feathr_project/feathr/client.py +++ b/feathr_project/feathr/client.py @@ -2,6 +2,7 @@ import logging import os import tempfile +import asyncio from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional, Union @@ -82,12 +83,13 @@ class FeathrClient(object): local_workspace_dir (str, optional): set where is the local work space dir. If not set, Feathr will create a temporary folder to store local workspace related files. credential (optional): credential to access cloud resources, most likely to be the returned result of DefaultAzureCredential(). If not set, Feathr will initialize DefaultAzureCredential() inside the __init__ function to get credentials. project_registry_tag (Dict[str, str]): adding tags for project in Feathr registry. This might be useful if you want to tag your project as deprecated, or allow certain customizations on project leve. Default is empty + callback: an async callback function that will be called after execution of the original logic. The callback should not block the thread. This is optional. Raises: RuntimeError: Fail to create the client since necessary environment variables are not set for Redis client creation. """ - def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir: str = None, credential=None, project_registry_tag: Dict[str, str]=None): + def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir: str = None, credential=None, project_registry_tag: Dict[str, str]=None, callback:callable = None): self.logger = logging.getLogger(__name__) # Redis key separator self._KEY_SEPARATOR = ':' @@ -182,6 +184,7 @@ def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir 'feature_registry', 'purview', 'purview_name') # initialize the registry no matter whether we set purview name or not, given some of the methods are used there.
self.registry = _FeatureRegistry(self.project_name, self.azure_purview_name, self.registry_delimiter, project_registry_tag, config_path = config_path, credential=self.credential) + self.callback = callback def _check_required_environment_variables_exist(self): """Checks if the required environment variables(form feathr_config.yaml) is set. @@ -264,13 +267,15 @@ def _get_registry_client(self): """ return self.registry._get_registry_client() - def get_online_features(self, feature_table, key, feature_names): - """Fetches feature value for a certain key from a online feature table. + def get_online_features(self, feature_table, key, feature_names, params: dict = None): + """Fetches feature value for a certain key from an online feature table. There is an optional callback function + and the params argument to extend this function's capability, e.g. to notify a consumer of the features. Args: feature_table: the name of the feature table. key: the key of the entity feature_names: list of feature names to fetch + params: a dictionary of parameters for the callback function Return: A list of feature values for this entity. It's ordered by the requested feature names. @@ -283,15 +288,20 @@ def get_online_features(self, feature_table, key, feature_names): """ redis_key = self._construct_redis_key(feature_table, key) res = self.redis_clint.hmget(redis_key, *feature_names) - return self._decode_proto(res) + feature_values = self._decode_proto(res) + if (self.callback is not None) and (params is not None): + event_loop = asyncio.get_event_loop() + event_loop.create_task(self.callback(params)) + return feature_values - def multi_get_online_features(self, feature_table, keys, feature_names): + def multi_get_online_features(self, feature_table, keys, feature_names, params: dict = None): """Fetches feature value for a list of keys from a online feature table. This is the batch version of the get API. Args: feature_table: the name of the feature table. keys: list of keys for the entities feature_names: list of feature names to fetch + params: a dictionary of parameters for the callback function Return: A list of feature values for the requested entities. It's ordered by the requested feature names. For @@ -312,6 +322,10 @@ def multi_get_online_features(self, feature_table, keys, feature_names): for feature_list in pipeline_result: decoded_pipeline_result.append(self._decode_proto(feature_list)) + if (self.callback is not None) and (params is not None): + event_loop = asyncio.get_event_loop() + event_loop.create_task(self.callback(params)) + return dict(zip(keys, decoded_pipeline_result)) def _decode_proto(self, feature_list): @@ -412,15 +426,18 @@ def get_offline_features(self, output_path: str, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, udf_files = None, - verbose: bool = False + verbose: bool = False, + params: dict = None ): """ - Get offline features for the observation dataset + Get offline features for the observation dataset. There is an optional callback function and the params + argument to extend this function's capability, e.g. to notify a consumer of the features. Args: observation_settings: settings of the observation data, e.g. timestamp columns, input path, etc. feature_query: features that are requested to add onto the observation data output_path: output path of job, i.e. the observation data with features attached. - execution_configurations: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations".
Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations. + execution_configurations: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations. + params: a dictionary of parameters for the callback function """ feature_queries = feature_query if isinstance(feature_query, List) else [feature_query] feature_names = [] @@ -457,7 +474,11 @@ def get_offline_features(self, FeaturePrinter.pretty_print_feature_query(feature_query) write_to_file(content=config, full_file_name=config_file_path) - return self._get_offline_features_with_config(config_file_path, execution_configurations, udf_files=udf_files) + job_info = self._get_offline_features_with_config(config_file_path, execution_configurations, udf_files=udf_files) + if (self.callback is not None) and (params is not None): + event_loop = asyncio.get_event_loop() + event_loop.create_task(self.callback(params)) + return job_info def _get_offline_features_with_config(self, feature_join_conf_path='feature_join_conf/feature_join.conf', execution_configurations: Dict[str,str] = {}, udf_files=[]): """Joins the features to your offline observation dataset based on the join config. @@ -534,21 +555,28 @@ def wait_job_to_finish(self, timeout_sec: int = 300): else: raise RuntimeError('Spark job failed.') - def monitor_features(self, settings: MonitoringSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False): - """Create a offline job to generate statistics to monitor feature data + def monitor_features(self, settings: MonitoringSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, params: dict = None): + """Create an offline job to generate statistics to monitor feature data. There is an optional + callback function and the params argument to extend this function's capability, e.g. to notify a consumer of the features. Args: settings: Feature monitoring settings - execution_configurations: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations. + execution_configurations: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations. + params: a dictionary of parameters for the callback function.
""" - self.materialize_features(settings, execution_configurations, verbose) + self.materialize_features(settings, execution_configurations, verbose) + if (self.callback is not None) and (params is not None): + event_loop = asyncio.get_event_loop() + event_loop.create_task(self.callback(params)) - def materialize_features(self, settings: MaterializationSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False): - """Materialize feature data + def materialize_features(self, settings: MaterializationSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, params: dict = None): + """Materialize feature data. There is an optional callback function and the params + argument to extend this function's capability, e.g. to notify a consumer of the feature store. Args: settings: Feature materialization settings - execution_configurations: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations. + execution_configurations: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations. + params: a dictionary of parameters for the callback function """ # produce materialization config for end in settings.get_backfill_cutoff_time(): @@ -575,6 +603,10 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf # Pretty print feature_names of materialized features if verbose and settings: FeaturePrinter.pretty_print_materialize_features(settings) + + if (self.callback is not None) and (params is not None): + event_loop = asyncio.get_event_loop() + event_loop.create_task(self.callback(params)) def _materialize_features_with_config(self, feature_gen_conf_path: str = 'feature_gen_conf/feature_gen.conf',execution_configurations: Dict[str,str] = {}, udf_files=[]): """Materializes feature data based on the feature generation config.
The feature diff --git a/feathr_project/test/test_client_callback.py b/feathr_project/test/test_client_callback.py new file mode 100644 index 000000000..544c4c20b --- /dev/null +++ b/feathr_project/test/test_client_callback.py @@ -0,0 +1,201 @@ +import os +import asyncio +import unittest.mock as mock +import time +from subprocess import call +from datetime import datetime, timedelta + +from pathlib import Path +from feathr import ValueType +from feathr import FeatureQuery +from feathr import ObservationSettings +from feathr import TypedKey +from test_fixture import get_online_test_table_name +from feathr.definition._materialization_utils import _to_materialization_config +from feathr import (BackfillTime, MaterializationSettings) +from feathr import (BackfillTime, MaterializationSettings, FeatureQuery, + ObservationSettings, SparkExecutionConfiguration) +from feathr import RedisSink, HdfsSink +from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32, STRING, + DerivedFeature, Feature, FeatureAnchor, HdfsSource, + TypedKey, ValueType, WindowAggTransformation) +from feathr import FeathrClient + +params = {"wait" : 0.1} +async def sample_callback(params): + print(params) + await asyncio.sleep(0.1) + +callback = mock.MagicMock(return_value=sample_callback(params)) + + +def basic_test_setup_with_callback(config_path: str, callback: callable): + + now = datetime.now() + # set workspace folder by time; make sure we don't have write conflict if there are many CI tests running + os.environ['SPARK_CONFIG__DATABRICKS__WORK_DIR'] = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), '_', str(now.microsecond)]) + os.environ['SPARK_CONFIG__AZURE_SYNAPSE__WORKSPACE_DIR'] = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_github_ci','_', str(now.minute), '_', str(now.second) ,'_', str(now.microsecond)]) + + client = FeathrClient(config_path=config_path, callback=callback) + batch_source = HdfsSource(name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss") + + f_trip_distance = Feature(name="f_trip_distance", + feature_type=FLOAT, transform="trip_distance") + f_trip_time_duration = Feature(name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") + + features = [ + f_trip_distance, + f_trip_time_duration, + Feature(name="f_is_long_trip_distance", + feature_type=BOOLEAN, + transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", + feature_type=INT32, + transform="dayofweek(lpep_dropoff_datetime)"), + ] + + + request_anchor = FeatureAnchor(name="request_features", + source=INPUT_CONTEXT, + features=features) + + f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[ + f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration") + + f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10") + + location_id = TypedKey(key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id") + agg_features = [Feature(name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + 
transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", + agg_func="AVG", + window="90d")), + Feature(name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", + agg_func="MAX", + window="90d")) + ] + + agg_anchor = FeatureAnchor(name="aggregationFeatures", + source=batch_source, + features=agg_features) + + client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[ + f_trip_time_distance, f_trip_time_rounded]) + + return client + + + +def test_client_callback_offline_feature(): + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client = basic_test_setup_with_callback(os.path.join(test_workspace_dir, "feathr_config.yaml"),callback) + + location_id = TypedKey(key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id") + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) + + settings = ObservationSettings( + observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss") + + now = datetime.now() + output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + + res = client.get_offline_features(observation_settings=settings, + feature_query=feature_query, + output_path=output_path, + params=params) + callback.assert_called_with(params) + + +def test_client_callback_materialization(): + online_test_table = get_online_test_table_name("nycTaxiCITable") + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + + client = basic_test_setup_with_callback(os.path.join(test_workspace_dir, "feathr_config.yaml"),callback) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + redisSink = RedisSink(table_name=online_test_table) + settings = MaterializationSettings("nycTaxiTable", + sinks=[redisSink], + feature_names=[ + "f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time) + client.materialize_features(settings, params=params) + callback.assert_called_with(params) + +def test_client_callback_monitor_features(): + online_test_table = get_online_test_table_name("nycTaxiCITable") + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + + client = basic_test_setup_with_callback(os.path.join(test_workspace_dir, "feathr_config.yaml"),callback) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + redisSink = RedisSink(table_name=online_test_table) + settings = MaterializationSettings("nycTaxiTable", + sinks=[redisSink], + feature_names=[ + "f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time) + client.monitor_features(settings, params=params) + callback.assert_called_with(params) + +def test_client_callback_get_online_features(): + online_test_table = get_online_test_table_name("nycTaxiCITable") + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + + client = basic_test_setup_with_callback(os.path.join(test_workspace_dir, "feathr_config.yaml"),callback) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + redisSink = RedisSink(table_name=online_test_table) + settings = 
MaterializationSettings("nycTaxiTable", + sinks=[redisSink], + feature_names=[ + "f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time) + client.materialize_features(settings) + callback.assert_called_with(params) + client.wait_job_to_finish(timeout_sec=900) + # wait for a few secs for the data to come in redis + time.sleep(5) + client.get_online_features('nycTaxiDemoFeature', '265', ['f_location_avg_fare', 'f_location_max_fare'], params=params) + callback.assert_called_with(params) + + +def test_client_callback_multi_get_online_features(): + online_test_table = get_online_test_table_name("nycTaxiCITable") + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + + client = basic_test_setup_with_callback(os.path.join(test_workspace_dir, "feathr_config.yaml"),callback) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + redisSink = RedisSink(table_name=online_test_table) + settings = MaterializationSettings("nycTaxiTable", + sinks=[redisSink], + feature_names=[ + "f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time) + client.materialize_features(settings) + callback.assert_called_with(params) + client.wait_job_to_finish(timeout_sec=900) + # wait for a few secs for the data to come in redis + time.sleep(5) + client.multi_get_online_features('nycTaxiDemoFeature', ["239", "265"], ['f_location_avg_fare', 'f_location_max_fare'], params=params) + callback.assert_called_with(params) \ No newline at end of file
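For readers who want a concrete callback rather than the `MagicMock` used in the tests, the sketch below shows one way to write it under the same contract (an async function that receives the `params` dictionary). The function name `notify_webhook`, the endpoint URL, and the dictionary keys are hypothetical, and `httpx` is assumed to be installed, as in the how-to guide:

```python
# Minimal, self-contained sketch of the callback pattern exercised above.
# `notify_webhook`, the endpoint URL, and the keys inside `params` are illustrative only.
import asyncio
import httpx

async def notify_webhook(params: dict):
    """Async callback: forward the caller-supplied params to an external endpoint."""
    async with httpx.AsyncClient() as http:
        response = await http.post("https://example.com/feathr-events", json=params)
        return response.status_code

async def main():
    params = {"consumer": "recommendation-service", "job": "materialize_features"}
    # This mirrors what the client does internally: it schedules the coroutine on the
    # current event loop with create_task instead of awaiting it, so the calling
    # method returns without blocking on the callback.
    task = asyncio.get_event_loop().create_task(notify_webhook(params))
    await task  # shown only so this standalone example completes; the client never awaits

if __name__ == "__main__":
    asyncio.run(main())
```

Because the client only schedules the coroutine and never awaits it, the callback runs only while an event loop is alive, which is why the documentation example above is annotated `# inside notebook`.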