Commit: changes
porteratzo committed Dec 26, 2023
1 parent fb63061 commit ac17821
Showing 15 changed files with 735 additions and 719 deletions.
83 changes: 69 additions & 14 deletions openfl-tutorials/experimental/Federeated_Pytorch_LLM_Horovod.py
@@ -2,9 +2,6 @@
import openfl.native as fx
import sys

sys.path.append("openfl/openfl-workspace/torch_llm")
from src.pt_model import LLMTaskRunner
from src.ptglue_inmemory import GlueMrpcFederatedDataLoader
import openfl.interface.workspace as workspace
import os
import subprocess
@@ -58,29 +55,87 @@ def propogate_workspace():
])
if result.returncode != 0:
raise RuntimeError(result.stderr)

def propogate_dataset(data_loader):
remote_hosts = [
i.split(":")[0] for i in HOSTS.split(",") if i.split(":")[0] != LOCAL_HOST
]
for rem_host in remote_hosts:
result = subprocess.run(
[
"scp",
"-r",
os.getcwd() + f"/temp_dataset_{data_loader.data_path}_train",
rem_host
+ ":"
+ os.getcwd()
+ f"/temp_dataset_{data_loader.data_path}_train",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(result.stderr)
result = subprocess.run(
[
"scp",
"-r",
os.getcwd() + f"/temp_dataset_{data_loader.data_path}_valid",
rem_host
+ ":"
+ os.getcwd()
+ f"/temp_dataset_{data_loader.data_path}_valid",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(result.stderr)
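
# Illustrative usage sketch (assumptions, not lines from this commit): HOSTS is a
# comma-separated list of "host:slots" entries, and the loader caches its splits under
# temp_dataset_<data_path>_train and temp_dataset_<data_path>_valid in the working dir:
#
#     loader = GlueMrpcFederatedDataLoader(1, 32, collaborator_count=1)
#     propogate_dataset(loader)  # scp the cached train/valid folders to each remote host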

def get_args():
"""
Parse the command-line arguments for this script.
Returns:
- args (Namespace): Parsed arguments, including the --dont_propogate_to_nodes flag.
"""
import argparse

parser = argparse.ArgumentParser(description="Run the federated PyTorch LLM Horovod tutorial.")
parser.add_argument(
"--dont_propogate_to_nodes", action='store_true', help="Path to the data.", required=False,
)
args = parser.parse_args()
return args
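
# Example invocation (illustrative): run the tutorial without copying the workspace
# and dataset to the other hosts listed in HOSTS:
#
#     python Federeated_Pytorch_LLM_Horovod.py --dont_propogate_to_nodes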

def main():
args = get_args()
print(WORKSPACE_PREFIX)
log_level = "INFO"
log_file = None
workspace.create(WORKSPACE_PREFIX, "torch_llm")
os.chdir(WORKSPACE_PREFIX)
#workspace.create(WORKSPACE_PREFIX, "torch_llm_horovod")
#os.chdir(WORKSPACE_PREFIX)
sys.path.append(WORKSPACE_PREFIX)
propogate_workspace()
fx.setup_logging(level=log_level, log_file=log_file)
num_collaborators = 1

collaborator_models = [
LLMTaskRunner(

from src.pt_model import LLMTaskRunner
from src.ptglue_inmemory import GlueMrpcFederatedDataLoader

collaborator_models = [LLMTaskRunner(
data_loader=GlueMrpcFederatedDataLoader(
data_slice, 32, collaborator_count=num_collaborators
1, 32, collaborator_count=num_collaborators
)
)
for data_slice in range(num_collaborators)
]
)]
collaborators = {
"one": collaborator_models[0],
}
if not args.dont_propogate_to_nodes:
propogate_workspace()
propogate_dataset(collaborator_models[0].data_loader)

#fx.setup_logging(level=log_level, log_file=log_file)

# Collaborator one's data
for i, model in enumerate(collaborator_models):