Skip to content

Commit

Permalink
moved callback to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
shivamsanju committed Jun 28, 2022
1 parent d7bc984 commit aeb3e45
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 42 deletions.
18 changes: 10 additions & 8 deletions docs/how-to-guides/client-callback-function.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,29 @@ A callback function is a function that is sent to another function as an argumen

## How to use callback functions

Currently the below functions in feathr client support passing a callback as an argument:
A callback function can be passed to the Feathr client when it is initialized:

```python
client = FeathrClient(config_path, callback)
```

The functions below accept an optional parameter named **params** — a dictionary through which the user can pass arguments to the callback function.

- get_online_features
- multi_get_online_features
- get_offline_features
- monitor_features
- materialize_features

These functions accept two optional parameters named **callback** and **params**.
callback is of type function and params is a dictionary where user can pass the arguments for the callback function.

An example on how to use it:

```python
# inside notebook
client = FeathrClient(config_path)
client.get_offline_features(observation_settings,feature_query,output_path, callback, params)

# users can define their own callback function and params
client = FeathrClient(config_path, callback)
params = {"param1":"value1", "param2":"value2"}
client.get_offline_features(observation_settings, feature_query, output_path, params=params)

# users can define their own callback function
async def callback(params):
import httpx
async with httpx.AsyncClient() as requestHandler:
Expand Down
40 changes: 18 additions & 22 deletions feathr_project/feathr/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ class FeathrClient(object):
local_workspace_dir (str, optional): set where is the local work space dir. If not set, Feathr will create a temporary folder to store local workspace related files.
credential (optional): credential to access cloud resources, most likely to be the returned result of DefaultAzureCredential(). If not set, Feathr will initialize DefaultAzureCredential() inside the __init__ function to get credentials.
project_registry_tag (Dict[str, str]): adding tags for project in Feathr registry. This might be useful if you want to tag your project as deprecated, or allow certain customizations on project leve. Default is empty
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread. This is optional.
Raises:
RuntimeError: Fail to create the client since necessary environment variables are not set for Redis
client creation.
"""
def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir: str = None, credential=None, project_registry_tag: Dict[str, str]=None):
def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir: str = None, credential=None, project_registry_tag: Dict[str, str]=None, callback:callable = None):
self.logger = logging.getLogger(__name__)
# Redis key separator
self._KEY_SEPARATOR = ':'
Expand Down Expand Up @@ -183,6 +184,7 @@ def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir
'feature_registry', 'purview', 'purview_name')
# initialize the registry no matter whether we set purview name or not, given some of the methods are used there.
self.registry = _FeatureRegistry(self.project_name, self.azure_purview_name, self.registry_delimiter, project_registry_tag, config_path = config_path, credential=self.credential)
self.callback = callback

def _check_required_environment_variables_exist(self):
"""Checks if the required environment variables(form feathr_config.yaml) is set.
Expand Down Expand Up @@ -265,15 +267,14 @@ def _get_registry_client(self):
"""
return self.registry._get_registry_client()

def get_online_features(self, feature_table, key, feature_names, callback: callable = None, params: dict = None):
def get_online_features(self, feature_table, key, feature_names, params: dict = None):
"""Fetches feature value for a certain key from a online feature table. There is an optional callback function
and the params to extend this function's capability.For eg. cosumer of the features.
Args:
feature_table: the name of the feature table.
key: the key of the entity
feature_names: list of feature names to fetch
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function
Return:
Expand All @@ -288,19 +289,18 @@ def get_online_features(self, feature_table, key, feature_names, callback: calla
redis_key = self._construct_redis_key(feature_table, key)
res = self.redis_clint.hmget(redis_key, *feature_names)
feature_values = self._decode_proto(res)
if (callback is not None) and (params is not None):
if (self.callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))
event_loop.create_task(self.callback(params))
return feature_values

def multi_get_online_features(self, feature_table, keys, feature_names, callback: callable = None, params: dict = None):
def multi_get_online_features(self, feature_table, keys, feature_names, params: dict = None):
"""Fetches feature value for a list of keys from a online feature table. This is the batch version of the get API.
Args:
feature_table: the name of the feature table.
keys: list of keys for the entities
feature_names: list of feature names to fetch
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function
Return:
Expand All @@ -322,9 +322,9 @@ def multi_get_online_features(self, feature_table, keys, feature_names, callback
for feature_list in pipeline_result:
decoded_pipeline_result.append(self._decode_proto(feature_list))

if (callback is not None) and (params is not None):
if (self.callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))
event_loop.create_task(self.callback(params))

return dict(zip(keys, decoded_pipeline_result))

Expand Down Expand Up @@ -427,7 +427,6 @@ def get_offline_features(self,
execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {},
udf_files = None,
verbose: bool = False,
callback: callable = None,
params: dict = None
):
"""
Expand All @@ -438,7 +437,6 @@ def get_offline_features(self,
feature_query: features that are requested to add onto the observation data
output_path: output path of job, i.e. the observation data with features attached.
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function
"""
feature_queries = feature_query if isinstance(feature_query, List) else [feature_query]
Expand Down Expand Up @@ -476,10 +474,10 @@ def get_offline_features(self,
FeaturePrinter.pretty_print_feature_query(feature_query)

write_to_file(content=config, full_file_name=config_file_path)
job_info = self._get_offline_features_with_config(config_file_path, execution_configuratons, udf_files=udf_files)
if (callback is not None) and (params is not None):
job_info = self._get_offline_features_with_config(config_file_path, execution_configurations, udf_files=udf_files)
if (self.callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))
event_loop.create_task(self.callback(params))
return job_info

def _get_offline_features_with_config(self, feature_join_conf_path='feature_join_conf/feature_join.conf', execution_configurations: Dict[str,str] = {}, udf_files=[]):
Expand Down Expand Up @@ -557,29 +555,27 @@ def wait_job_to_finish(self, timeout_sec: int = 300):
else:
raise RuntimeError('Spark job failed.')

def monitor_features(self, settings: MonitoringSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, callback: callable = None, params: dict = None):
def monitor_features(self, settings: MonitoringSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, params: dict = None):
"""Create a offline job to generate statistics to monitor feature data. There is an optional
callback function and the params to extend this function's capability.For eg. cosumer of the features.
Args:
settings: Feature monitoring settings
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function.
"""
self.materialize_features(settings, execution_configuratons, verbose)
if (callback is not None) and (params is not None):
if (self.callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))
event_loop.create_task(self.callback(params))

def materialize_features(self, settings: MaterializationSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, callback: callable = None, params: dict = None):
def materialize_features(self, settings: MaterializationSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, params: dict = None):
"""Materialize feature data. There is an optional callback function and the params
to extend this function's capability.For eg. cosumer of the feature store.
Args:
settings: Feature materialization settings
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
params: a dictionary of parameters for the callback function
"""
# produce materialization config
Expand Down Expand Up @@ -608,9 +604,9 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf
if verbose and settings:
FeaturePrinter.pretty_print_materialize_features(settings)

if (callback is not None) and (params is not None):
if (self.callback is not None) and (params is not None):
event_loop = asyncio.get_event_loop()
event_loop.create_task(callback(params))
event_loop.create_task(self.callback(params))

def _materialize_features_with_config(self, feature_gen_conf_path: str = 'feature_gen_conf/feature_gen.conf',execution_configurations: Dict[str,str] = {}, udf_files=[]):
"""Materializes feature data based on the feature generation config. The feature
Expand Down
Loading

0 comments on commit aeb3e45

Please sign in to comment.