Skip to content

Commit

Permalink
Associate node stat agg with request
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Mar 4, 2025
1 parent 4cf10b8 commit c8c8edc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

package org.opensearch.knn.plugin.transport;

import lombok.Getter;
import lombok.Setter;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.plugin.stats.KNNNodeStatAggregation;
import org.opensearch.knn.plugin.stats.StatNames;

import java.io.IOException;
Expand All @@ -26,6 +29,12 @@ public class KNNStatsRequest extends BaseNodesRequest<KNNStatsRequest> {
public static final String ALL_STATS_KEY = "_all";
private final Set<String> validStats;
private final Set<String> statsToBeRetrieved;
/**
* Node stat aggregation associated with the request. Not serialized between nodes. Can be null.
*/
@Getter
@Setter
private KNNNodeStatAggregation aggregation;

/**
* Empty constructor needed for KNNStatsTransportAction
Expand All @@ -34,6 +43,7 @@ public KNNStatsRequest() {
super((String[]) null);
validStats = StatNames.getNames();
statsToBeRetrieved = new HashSet<>();
aggregation = null;
}

/**
Expand All @@ -46,6 +56,7 @@ public KNNStatsRequest(StreamInput in) throws IOException {
super(in);
validStats = in.readSet(StreamInput::readString);
statsToBeRetrieved = in.readSet(StreamInput::readString);
aggregation = null;
}

/**
Expand All @@ -57,6 +68,7 @@ public KNNStatsRequest(String... nodeIds) {
super(nodeIds);
validStats = StatNames.getNames();
statsToBeRetrieved = new HashSet<>();
aggregation = null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ public class KNNStatsTransportAction extends TransportNodesAction<

private final KNNStats knnStats;
private final Client client;
private KNNNodeStatAggregation knnNodeStatAggregation;

/**
* Constructor
Expand Down Expand Up @@ -71,7 +70,6 @@ public KNNStatsTransportAction(
);
this.knnStats = knnStats;
this.client = client;
this.knnNodeStatAggregation = null;
}

@Override
Expand All @@ -98,7 +96,7 @@ protected void doExecute(Task task, KNNStatsRequest request, ActionListener<KNNS
// Add the stats makes sure that we dont recurse infinitely.
dependentStats.forEach(knnStatsRequest::addStat);
client.execute(KNNStatsAction.INSTANCE, knnStatsRequest, ActionListener.wrap(knnStatsResponse -> {
knnNodeStatAggregation = new KNNNodeStatAggregation(knnStatsResponse.getNodes());
request.setAggregation(new KNNNodeStatAggregation(knnStatsResponse.getNodes()));
contextListener.onResponse(null);
}, contextListener::onFailure));
} else {
Expand All @@ -118,7 +116,7 @@ protected KNNStatsResponse newResponse(

for (String statName : knnStats.getClusterStats().keySet()) {
if (statsToBeRetrieved.contains(statName)) {
clusterStats.put(statName, knnStats.getStats().get(statName).getValue(knnNodeStatAggregation));
clusterStats.put(statName, knnStats.getStats().get(statName).getValue(request.getAggregation()));
}
}

Expand Down

0 comments on commit c8c8edc

Please sign in to comment.