How can I implement a YOLO model using the Flower framework? #4361
Hey @wkqco33, I first saw your question here: https://discuss.flower.ai/t/how-can-i-implement-a-yolo-model-using-the-flower-framework/368/2. Let us know if you get your YOLO + Flower experiments working.
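Here is my server code:

```python
from logging import INFO
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
from flwr.common import FitRes, Metrics, Parameters, Scalar, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.server import ServerConfig
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg

# load_model and get_weights are the helper functions shown further below.


class SaveModelStrategy(FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        # Delegates to FedAvg; nothing is saved to disk here yet despite the class name.
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )
        return aggregated_parameters, aggregated_metrics


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Weight each client's accuracy by the number of examples it reported.
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}  # type: ignore


def server_fn():
    model = load_model()
    parameters = ndarrays_to_parameters(get_weights(model))
    log(INFO, "Initial parameters")
    strategy = SaveModelStrategy(
        fraction_fit=0.5,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=parameters,
        evaluate_metrics_aggregation_fn=weighted_average,
    )
    num_rounds = 3
    config = ServerConfig(num_rounds=num_rounds)
    return strategy, config


if __name__ == "__main__":
    strategy, config = server_fn()
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        strategy=strategy,
        config=config,
    )
```

And here is my client code:

```python
import os
from logging import INFO
from typing import Optional

import flwr as fl
from flwr.client import NumPyClient
from flwr.common import NDArrays
from flwr.common.logger import log
from ultralytics import YOLO
from ultralytics.utils.metrics import DetMetrics

# load_model, get_weights and set_weights are the helper functions shown below.


class RobotClient(NumPyClient):
    def __init__(self, data, epochs):
        self.model: YOLO = load_model()
        self.data = data
        self.epochs = epochs

    def fit(self, parameters: NDArrays, config):
        set_weights(self.model, parameters)
        results: Optional[DetMetrics] = self.model.train(
            data=self.data,
            epochs=self.epochs,
        )
        if results is not None:
            log(INFO, f"Results: {results.box.map}")
        # Hard-coded example count; FedAvg weights each client's update by this value.
        return get_weights(self.model), 10, {}

    def evaluate(self, parameters: NDArrays, config):
        set_weights(self.model, parameters)
        # Pass the dataset explicitly so val() does not depend on whether train() ran here.
        metrics: DetMetrics = self.model.val(data=self.data)
        accuracy = float(metrics.box.map)  # mAP50-95
        # fitness is a higher-is-better score, not a true loss.
        loss = float(metrics.fitness)
        return loss, 10, {"accuracy": accuracy}


def client_fn():
    data = "coco8.yaml"
    epochs = 1
    return RobotClient(data, epochs).to_client()


if __name__ == "__main__":
    server_ip = os.getenv("SERVER_IP", "localhost")
    server_port = os.getenv("SERVER_PORT", "8080")
    log(INFO, f"Server IP: {server_ip}:{server_port}")
    fl.client.start_client(
        server_address=f"{server_ip}:{server_port}",
        client=client_fn(),
    )
```

I changed the `get_weights` and `set_weights` functions following your advice:

```python
from collections import OrderedDict
import torch
from flwr.common import NDArrays
from ultralytics import YOLO

def load_model() -> YOLO:
    return YOLO("yolo11n.pt")

def get_weights(model: YOLO) -> NDArrays:
    # The state_dict key order defines the order of the arrays Flower sends.
    return [val.cpu().numpy() for _, val in model.state_dict().items()]

def set_weights(model: YOLO, parameters: NDArrays) -> None:
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    # strict=False silently ignores missing or unexpected keys.
    model.load_state_dict(state_dict, strict=False)
```

But I can't see the results because of these error messages.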
I really appreciate you and your support, @jafermarq!
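In the client code above, both `fit` and `evaluate` return a hard-coded `10` as the example count. FedAvg and the `weighted_average` function weight each client by that count, so a constant turns aggregation into a plain unweighted mean. The following is a minimal pure-Python sketch of example-weighted averaging, assuming each client's weights are a list of layers stored as flat float lists (real Flower clients send NumPy arrays); `fedavg` is a hypothetical helper for illustration, not Flower's own implementation.

```python
from typing import List, Tuple

# Hypothetical representation: a model is a list of layers, each a flat list of floats.
Weights = List[List[float]]


def fedavg(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights layer by layer, weighting each client by num_examples."""
    total = sum(num_examples for _, num_examples in results)
    if total <= 0:
        raise ValueError("total number of examples must be positive")
    num_layers = len(results[0][0])
    averaged: Weights = []
    for layer in range(num_layers):
        size = len(results[0][0][layer])
        acc = [0.0] * size
        for weights, num_examples in results:
            if len(weights) != num_layers or len(weights[layer]) != size:
                raise ValueError(f"client weights do not match at layer {layer}")
            for i, value in enumerate(weights[layer]):
                acc[i] += value * num_examples / total
        averaged.append(acc)
    return averaged


if __name__ == "__main__":
    client_a = [[1.0, 1.0], [0.0]]
    client_b = [[3.0, 3.0], [4.0]]
    # client_b reports three times as many examples, so it dominates the average.
    print(fedavg([(client_a, 100), (client_b, 300)]))  # [[2.5, 2.5], [3.0]]
    # A constant count (like the hard-coded 10) gives an unweighted mean.
    print(fedavg([(client_a, 10), (client_b, 10)]))  # [[2.0, 2.0], [2.0]]
```

Returning the real number of training images from `fit` (and validation images from `evaluate`) keeps both parameter and metric aggregation proportional to each client's data.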
What is your question?
How to Pass Weights as Parameters in Flower?
I’m trying to use the Flower framework to train a YOLO model in a federated learning setting. I’m having trouble figuring out how to properly pass the model weights as parameters between the server and clients.
Here’s what I’ve tried so far:
However, I’m encountering errors during training, and I suspect it’s related to how the weights are being handled.
Could someone provide guidance or examples on how to correctly pass YOLO model weights as parameters in Flower? Any help would be greatly appreciated!
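Flower's `NumPyClient` exchanges weights as an ordered list of arrays (`NDArrays`) with no parameter names attached, so sender and receiver must agree on the order; in the `get_weights`/`set_weights` pattern that order comes from `state_dict()` key order. The sketch below illustrates that contract in pure Python, assuming a hypothetical state dict that maps names to flat lists of floats instead of PyTorch tensors. Unlike `load_state_dict(..., strict=False)`, it fails loudly on a count or size mismatch, which may help pinpoint where weight handling goes wrong.

```python
from typing import Dict, List

# Hypothetical stand-in for a PyTorch state_dict: parameter name -> flat list of floats.
# Plain dicts preserve insertion order (Python 3.7+), which the round trip relies on.
StateDict = Dict[str, List[float]]


def get_weights(state: StateDict) -> List[List[float]]:
    """Flatten to an ordered, nameless list, mirroring what Flower transmits."""
    return [list(values) for values in state.values()]


def set_weights(template: StateDict, arrays: List[List[float]]) -> StateDict:
    """Pair the template's key order with received arrays, rejecting any mismatch."""
    keys = list(template.keys())
    if len(keys) != len(arrays):
        raise ValueError(f"expected {len(keys)} arrays, got {len(arrays)}")
    rebuilt: StateDict = {}
    for key, values in zip(keys, arrays):
        if len(values) != len(template[key]):
            raise ValueError(
                f"{key}: expected {len(template[key])} values, got {len(values)}"
            )
        rebuilt[key] = list(values)
    return rebuilt


if __name__ == "__main__":
    local = {"conv.weight": [0.1, 0.2], "conv.bias": [0.0]}
    received = get_weights({"conv.weight": [0.5, 0.6], "conv.bias": [1.0]})
    print(set_weights(local, received))  # {'conv.weight': [0.5, 0.6], 'conv.bias': [1.0]}
    try:
        set_weights(local, received[:1])  # one array missing
    except ValueError as err:
        print("rejected:", err)  # rejected: expected 2 arrays, got 1
```

With real PyTorch tensors, the per-key check would compare `.shape` rather than list length, and passing `strict=True` to `load_state_dict` gives a similar guarantee for key names.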