Modularized data download function
Signed-off-by: noopur <[email protected]>
noopurintel committed Jan 20, 2025
1 parent 00fcccb commit 458a3ea
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions tests/end_to_end/utils/federation_helper.py
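At a high level, the commit splits the former monolithic `setup_collaborator_data` into an orchestrator plus two helpers, `pre_existing_data` and `download_data`. The outline below is a minimal sketch of that control flow with the helper bodies stubbed out; the real implementations are in the diff that follows.

```python
# Minimal sketch of the refactored control flow; helper bodies are stubbed
# here, the real implementations are in the diff below.
import logging

log = logging.getLogger(__name__)


def pre_existing_data(collaborators):
    """Stub: report whether every collaborator already has its data folder."""
    return False


def download_data(collaborators, model_name, local_bind_path):
    """Stub: download the dataset once and copy data/<index> into each workspace."""
    return True


def setup_collaborator_data(collaborators, model_name, local_bind_path):
    """Orchestrator: download only when no usable data is already in place."""
    if not pre_existing_data(collaborators):
        download_data(collaborators, model_name, local_bind_path)
    log.info("Data setup is complete for all the collaborators")
```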
@@ -689,20 +689,17 @@ def setup_collaborator_data(collaborators, model_name, local_bind_path):
model_name (str): Model name
local_bind_path (str): Local bind path
"""
# Check if data already exists, if yes, skip the download part
# This is mainly helpful in case of re-runs
for index, collaborator in enumerate(collaborators, start=1):
data_path = os.path.join(constants.COL_WORKSPACE_PATH.format(local_bind_path, collaborator.name), "data")
if os.path.exists(data_path):
folders = [f for f in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, f))]
# For example, in case of xgb_higgs, data is present in folders like
# data/1, data/2, etc. under respective collaborator workspaces.
if folders and folders[0] != str(index):
raise ex.DataSetupException(f"Data is present but not for {collaborator.name}.")

# Download the data for the model in the local bind path (for e.g. /home/user/results/xgb_higgs)
# and then copy to the respective collaborator workspaces
log.info("Downloading the data for the model. This will take some time to complete based on the data size ..")
if not pre_existing_data(collaborators):
download_data(collaborators, model_name, local_bind_path)

log.info("Data setup is complete for all the collaborators")


def download_data(collaborators, model_name, local_bind_path):
"""
Download the data for the model and copy to the respective collaborator workspaces
"""
log.info(f"Copying {constants.DATA_SETUP_FILE} from one of the collaborator workspaces to the local bind path..")
try:
shutil.copyfile(
src=os.path.join(collaborators[0].workspace_path, "src", constants.DATA_SETUP_FILE),
@@ -711,16 +708,16 @@ def setup_collaborator_data(collaborators, model_name, local_bind_path):
except Exception as e:
raise ex.DataSetupException(f"Failed to copy data setup file: {e}")

log.info("Downloading the data for the model. This will take some time to complete based on the data size ..")
error_msg = f"Failed to download data for {model_name}"
try:
return_code, output, error = run_command(
return_code, _, error = run_command(
f"python -v {constants.DATA_SETUP_FILE} {len(collaborators)}",
workspace_path=local_bind_path,
error_msg=error_msg,
return_error=True,
)
log.info("Data download completed successfully. Modifying the data.yaml file..")
if error:
if return_code != 0 or error:
raise ex.DataSetupException(f"{error_msg}: {error}")

except Exception:
@@ -729,8 +726,8 @@ def setup_collaborator_data(collaborators, model_name, local_bind_path):
try:
# Move the data to the respective workspace based on the index
for index, collaborator in enumerate(collaborators, start=1):
src_folder = os.path.join(local_bind_path, 'data', str(index))
dst_folder = os.path.join(collaborator.workspace_path, 'data', str(index))
src_folder = os.path.join(local_bind_path, "data", str(index))
dst_folder = os.path.join(collaborator.workspace_path, "data", str(index))
if os.path.exists(src_folder):
shutil.copytree(src_folder, dst_folder, dirs_exist_ok=True)
log.info(f"Copied data from {src_folder} to {dst_folder}")
@@ -747,6 +744,26 @@ def setup_collaborator_data(collaborators, model_name, local_bind_path):
return True


def pre_existing_data(collaborators):
"""
Check if data already exists for the model
Args:
collaborators (list): List of collaborator objects
Returns:
bool: True if data already exists, else False
"""
# Check if data already exists, if yes, skip the download part
# This is mainly helpful in case of re-runs
for index, collaborator in enumerate(collaborators, start=1):
dst_folder = os.path.join(collaborator.workspace_path, "data", str(index))
if os.path.exists(dst_folder):
log.info(f"Destination folder {dst_folder} already exists. Using the existing data..")
continue
else:
return False
return True


def extract_memory_usage(log_file):
"""
Extracts memory usage data from a log file.
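For context only (not part of the commit), the snippet below mimics the re-run check that the new `pre_existing_data` helper performs. The `Collaborator` dataclass and temporary directories are hypothetical stand-ins for the end-to-end test fixtures.

```python
import os
import tempfile
from dataclasses import dataclass


@dataclass
class Collaborator:
    name: str
    workspace_path: str


def has_pre_existing_data(collaborators):
    # Same check as pre_existing_data in the diff: every collaborator must
    # already have a data/<index> folder inside its workspace.
    for index, collaborator in enumerate(collaborators, start=1):
        if not os.path.exists(os.path.join(collaborator.workspace_path, "data", str(index))):
            return False
    return True


root = tempfile.mkdtemp(prefix="results_")
collaborators = [
    Collaborator(f"collaborator{i}", os.path.join(root, f"collaborator{i}", "workspace"))
    for i in (1, 2)
]

print(has_pre_existing_data(collaborators))  # False: first run, nothing downloaded yet

for index, collaborator in enumerate(collaborators, start=1):
    os.makedirs(os.path.join(collaborator.workspace_path, "data", str(index)))

print(has_pre_existing_data(collaborators))  # True: a re-run can skip the download
```

On a first run the check fails and the download path is taken; on re-runs it short-circuits, which is the behaviour the removed inline check in `setup_collaborator_data` also aimed for.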
