Aggregator
Aggregator collects values from all executors, combines them, and makes the aggregated result available to every executor. Common use cases include finding the top k elements, counting elements, and summing the gradients in a machine learning algorithm.
First, include the header. Including `lib/aggregator_factory.hpp` alone is enough.
Then create aggregators with the class `Aggregator`. The constructor takes two arguments: the initial value and the aggregation rule, which is a lambda. For example, to create an aggregator for summation:

```cpp
Aggregator<int> agg(0, [](int& a, const int& b){ a += b; });
```
The initial value here is zero. The aggregation rule specifies how the aggregator value `a` should be updated when a new value `b` arrives; for summation, we just add `b` into `a`. By default, `Aggregator` uses the default constructor for the initial value and `+` or `+=` as the aggregation rule, so the summation aggregator above can be simplified to:

```cpp
Aggregator<int> agg;
```
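To make the constructor semantics concrete, the following single-process sketch shows how an initial value and an aggregation rule fold incoming values together. `MiniAggregator` and `plus_assign` are hypothetical names used only for illustration; this is not Husky's `Aggregator`, which additionally handles local copies, threads and network synchronization. It only mimics the defaults described above: a default-constructed initial value and `+=` as the rule.

```cpp
#include <functional>
#include <utility>

// Hypothetical, single-process stand-in for illustration only; not Husky's Aggregator.
template <typename T>
class MiniAggregator {
 public:
    using Rule = std::function<void(T&, const T&)>;

    // Default rule, mirroring the documented default of `+=`.
    static void plus_assign(T& a, const T& b) { a += b; }

    // Defaults mirror the documented behaviour: T() as the initial value, += as the rule.
    explicit MiniAggregator(T init = T(), Rule rule = plus_assign)
        : value_(std::move(init)), rule_(std::move(rule)) {}

    // Fold one incoming value b into the current value a using the rule.
    void update(const T& b) { rule_(value_, b); }

    const T& get_value() const { return value_; }

 private:
    T value_;    // current aggregated value, starts at the initial value
    Rule rule_;  // how a new value b updates the current value a
};
```

With these defaults, `MiniAggregator<int> sum; sum.update(1); sum.update(2);` leaves `sum.get_value()` at 3, and passing a different rule such as `[](int& a, const int& b) { a = std::max(a, b); }` turns the same class into a maximum aggregator.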
Another example is an aggregator over a vector of 10 doubles, e.g. for summing gradients:

```cpp
Aggregator<vector<double> > grad(vector<double>(10),
    [](vector<double>& a, const vector<double>& b) {
        for (size_t i = 0; i < a.size(); ++i) a[i] += b[i];  // element-wise addition
    }, [](vector<double>& v) {
        v.assign(10, 0.0);  // v becomes a vector containing 10 zeros
    });
```
Here a third parameter, also a lambda, is passed to the aggregator constructor. This lambda is called the Zero Value Lambda, and it resets an aggregator value to zero.
When a Husky job is running, many executors may be aggregating values into the same aggregator at once. The values aggregated by one executor first go into a local copy, and the local copies are later sent to a global center for a global aggregation. A local copy therefore needs an initial state, and that state is given by the Zero Value Lambda. One important property of the zero value is that `a + zero value` must always equal `a`.
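The zero-value property can be checked directly. The sketch below reuses the element-wise rule and the ten-zeros Zero Value Lambda from the gradient example as plain functions and models the two-level scheme described above: each executor's local copy starts from the zero value, and the local copies are then folded into the global value, which starts from the initial value. The helper names (`add_rule`, `reset_to_zero`, `zero_is_identity`, `aggregate`) are hypothetical, and this is a model of the idea rather than Husky's code.

```cpp
#include <cstddef>
#include <vector>

// Aggregation rule from the gradient example: element-wise addition.
inline void add_rule(std::vector<double>& a, const std::vector<double>& b) {
    for (std::size_t i = 0; i < a.size(); ++i) a[i] += b[i];
}

// Zero Value Lambda from the same example: reset to ten zeros.
inline void reset_to_zero(std::vector<double>& v) { v.assign(10, 0.0); }

// Check the identity a + zero == a for one sample value a.
inline bool zero_is_identity(const std::vector<double>& a) {
    std::vector<double> zero;
    reset_to_zero(zero);
    std::vector<double> sum = a;
    add_rule(sum, zero);
    return sum == a;
}

// Model of the two-level scheme: each executor's local copy starts at zero and
// folds in that executor's updates; local copies are then folded into the
// global value, which starts at the initial value.
inline std::vector<double> aggregate(
        const std::vector<double>& initial,
        const std::vector<std::vector<std::vector<double>>>& per_executor_updates) {
    std::vector<double> global = initial;
    for (const auto& updates : per_executor_updates) {
        std::vector<double> local;
        reset_to_zero(local);                        // local copy's initial state
        for (const auto& u : updates) add_rule(local, u);
        add_rule(global, local);                     // global aggregation step
    }
    return global;
}
```

If the zero value were not an identity (say, ten ones instead of ten zeros), every local copy would contribute that offset once, so the result would depend on how many executors happened to take part.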
After creating aggregators, you can aggregate values into them using `update` or `update_any`. With `update`, you aggregate values of the same type as the aggregator value, using the aggregation rule given to the constructor:

```cpp
agg.update(1);  // add 1 to the aggregator value
```
With `update_any`, you can aggregate values of any type and apply a different aggregation rule, e.g. add a value at a specific index of a vector:

```cpp
int idx = ...;
double val = ...;
grad.update_any([&](vector<double>& d){
    d[idx] += val;
});
```
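The difference between the two calls can be modelled on a single local copy. In the hypothetical `LocalCopy` class below (an illustration, not Husky's API), `update` passes a whole value of the aggregator's type through the constructor's rule, while `update_any` hands the caller's own lambda direct access to the local value.

```cpp
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical model of one executor's local copy; not Husky's real class.
template <typename T>
class LocalCopy {
 public:
    LocalCopy(T zero, std::function<void(T&, const T&)> rule)
        : value_(std::move(zero)), rule_(std::move(rule)) {}

    // update: the incoming value has the aggregator's type and goes through the rule.
    void update(const T& b) { rule_(value_, b); }

    // update_any: the caller supplies its own mutation of the local value.
    template <typename F>
    void update_any(F&& f) { std::forward<F>(f)(value_); }

    const T& value() const { return value_; }

 private:
    T value_;                                 // starts from the zero value
    std::function<void(T&, const T&)> rule_;  // rule used by update()
};
```

For a sparse change such as the index update above, `update_any` avoids building a full 10-element vector just to add one entry. It also suggests why a local copy needs a proper initial state: the caller's lambda mutates an existing value, so that value must already be a valid zero.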
After aggregating, the updates are in fact only kept inside the local copies. One way to perform a global aggregation is to call `HuskyAggregatorFactory::sync()`, and you must ensure that all executors take part in the synchronization. The other way is to get the aggregator channel from `HuskyAggregatorFactory::get_channel()` and use it as an out channel of a `list_execute`; the global aggregation is then performed by the `list_execute`, although it is still `HuskyAggregatorFactory::sync()` that is called eventually.
```cpp
AggregatorFactory::sync();
/*** or using aggregator channel ***/
auto& ac = AggregatorFactory::get_channel();
list_execute(obj_list, {}, {&ac}, [&](OBJ& obj) {
    ...  // here we can give updates to some aggregators
});
```
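The local-then-global pattern behind `sync()` can be illustrated with plain threads. The sketch below is only an in-process analogy (Husky merges across machines over the network, and `sum_with_local_copies` is a hypothetical name): each thread plays an executor, sums into its own local copy without locking, then merges that copy into the shared total under a mutex. Joining every thread before reading the total mirrors the requirement that all executors take part in the synchronization; reading earlier would see a partial result.

```cpp
#include <mutex>
#include <thread>
#include <vector>

// In-process analogy of local aggregation followed by a global merge.
inline long long sum_with_local_copies(int num_executors, int values_per_executor) {
    long long global = 0;  // initial value
    std::mutex mu;
    std::vector<std::thread> workers;
    for (int e = 0; e < num_executors; ++e) {
        workers.emplace_back([&global, &mu, values_per_executor] {
            long long local = 0;  // local copy starts at the zero value
            for (int v = 1; v <= values_per_executor; ++v) local += v;  // like update(v)
            std::lock_guard<std::mutex> lock(mu);
            global += local;  // merge this executor's local copy
        });
    }
    for (auto& t : workers) t.join();  // every executor must finish before reading
    return global;
}
```

Keeping the hot loop on a private local copy means the lock is taken once per executor rather than once per update.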
After the global aggregation, we can fetch the value using `get_value()`. `get_value()` returns a reference to the final aggregated value, and this value is shared by all executors on the same machine. Modifying it may affect other executors and may cause thread-safety issues.

```cpp
int sum = agg.get_value();
vector<double>& sum_of_grad = grad.get_value();
```
Finally, here is a short example that summarizes the steps:

```cpp
// 0. header
#include "lib/aggregator_factory.hpp"

void job() {
    // 1. create the aggregator
    Aggregator<int> sum;
    auto& ac = AggregatorFactory::get_channel();
    // 3. ensure you perform a list_execute for a global aggregation
    //    (obj_list and OBJ stand for your own object list and object type)
    list_execute(obj_list, {}, {&ac}, [&sum](OBJ& obj) {
        // 2. update the aggregator
        sum.update(1);
    });
    // 4. get the aggregation value
    LOG_I << sum.get_value();
}
```
Once created, aggregators keep aggregating values. Sometimes, however, we want to reset the aggregator value and start a new aggregation after each global aggregation, e.g. to count how many objects have sent messages in each `list_execute` and stop the job when no object sends any. This can be done with `to_reset_each_iter()`. Note that once `to_reset_each_iter()` or `to_keep_aggregate()` is called, it must be called by all executors:

```cpp
agg.to_reset_each_iter();
// or back to keep aggregating
agg.to_keep_aggregate();
```
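The effect of the two modes on a termination check can be modelled with a few integers. In the hypothetical helpers below, `senders_per_iter` lists how many objects sent messages in each `list_execute`; `values_after_each_sync` returns what the aggregator would hold after each global aggregation under either mode, and `iterations_until_quiet` stops at the first iteration whose count is zero. This is a model of the behaviour described above, not Husky's code.

```cpp
#include <vector>

// Value an aggregator would hold after each global aggregation.
inline std::vector<int> values_after_each_sync(const std::vector<int>& senders_per_iter,
                                               bool reset_each_iter) {
    std::vector<int> seen;
    int value = 0;
    for (int senders : senders_per_iter) {
        if (reset_each_iter) value = 0;  // start a fresh aggregation this iteration
        value += senders;                // updates made during this list_execute
        seen.push_back(value);
    }
    return seen;
}

// Iterations run before "no object sent a message" stops the job (reset mode).
inline int iterations_until_quiet(const std::vector<int>& senders_per_iter) {
    int iters = 0;
    for (int v : values_after_each_sync(senders_per_iter, true)) {
        ++iters;
        if (v == 0) break;
    }
    return iters;
}
```

With `senders_per_iter = {5, 3, 0, 2}`, the reset mode yields `{5, 3, 0, 2}` and stops after the third iteration, while the keep mode yields `{5, 8, 8, 10}`, a running total that never returns to zero once positive, so it cannot drive this stopping rule.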
Also, by default, the aggregators that have received new updates take part in a global aggregation. Sometimes we want only some of them to be globally aggregated. This can be done with `inactivate()`: inactive aggregators still accept updates but do not take part in the global aggregation that synchronizes them. Note that once `activate()` or `inactivate()` is called, it must be called by all executors:

```cpp
agg.inactivate();
// or back to active
agg.activate();
```
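A small model makes the rule concrete. In the sketch below, `ModelAgg` and `sync_all` are hypothetical: updates are always accepted, but a sync folds in only those aggregators that are both active and updated, and an inactive aggregator keeps its pending updates until it is activated and synchronized again. Whether Husky retains pending updates in exactly this way is an assumption based on the statement that inactive aggregators still accept updates.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical model of activation; not Husky's implementation.
struct ModelAgg {
    int global = 0;        // value visible after the last global aggregation
    int pending = 0;       // local updates not yet synchronized
    bool updated = false;  // received updates since the last sync
    bool active = true;    // toggled by activate()/inactivate()

    void update(int v) { pending += v; updated = true; }  // accepted even when inactive
};

// Fold pending updates into the global value, but only for aggregators that
// are active and were updated; returns how many took part.
inline std::size_t sync_all(std::vector<ModelAgg>& aggs) {
    std::size_t participants = 0;
    for (auto& a : aggs) {
        if (!a.active || !a.updated) continue;  // skipped by this global aggregation
        a.global += a.pending;
        a.pending = 0;
        a.updated = false;
        ++participants;
    }
    return participants;
}
```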
One more detail: inside the implementation of `Aggregator`, thread-safety concerns sometimes require making a new copy of the aggregator value, and modifying the new copy must not affect the original. By default, `Aggregator` uses `operator=` to make the copy, but for some classes, such as `std::shared_ptr`, `operator=` does not create an independent copy. To solve this, specialize the template function `lib::copy_assign`. Taking `std::shared_ptr` as an example:
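```cpp
namespace lib {
// Function templates cannot be partially specialized, so specialize for a
// concrete pointee type. MyState is a placeholder for your own type.
template <>
inline void copy_assign<shared_ptr<MyState>>(shared_ptr<MyState>& a, const shared_ptr<MyState>& b) {
    // instead of `a = b;`, which shares the pointee, make a deep copy
    a = make_shared<MyState>(*b);
}
}  // namespace lib
```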
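The reason the specialization matters can be demonstrated without Husky. The sketch below defines its own stand-in primary template in a hypothetical `demo` namespace (assumed to behave like the default described above, i.e. plain assignment) and a hypothetical `MyState` payload. With the specialization, the copy owns a separate object, so writing through it leaves the original untouched; with plain `a = b;` both pointers would refer to the same object.

```cpp
#include <memory>

namespace demo {  // hypothetical stand-in for lib

// Assumed shape of the primary template: plain assignment.
template <typename T>
void copy_assign(T& a, const T& b) { a = b; }

struct MyState { int counter = 0; };  // hypothetical payload type

// Full specialization for shared_ptr<MyState>: copy the pointee, not the pointer.
template <>
inline void copy_assign<std::shared_ptr<MyState>>(std::shared_ptr<MyState>& a,
                                                  const std::shared_ptr<MyState>& b) {
    a = std::make_shared<MyState>(*b);
}

}  // namespace demo
```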
During the global aggregation, aggregator values are serialized, sent to the global center, and deserialized there; the same happens again when the values are broadcast back from the global center. So make sure (this is the recommended approach) that you have overloaded `operator<<` and `operator>>` to serialize the value into a `BinStream` and deserialize it from a `BinStream`. If you don't want to overload these two operators, you can instead pass two lambdas, one for deserialization and one for serialization, as the fourth and fifth constructor arguments. For example:
```cpp
const int K = 10;
Aggregator<set<int>> unique_topk(set<int>(),  // 1. initial value
    [](set<int>& a, const set<int>& b) {  // 2. aggregation rule: keep the K largest distinct values
        for (int x : b) {
            if (a.count(x)) continue;  // already kept; evicting first would shrink the set
            if (a.size() == K && *a.begin() < x) a.erase(a.begin());
            if (a.size() < K) a.insert(x);
        }
    },
    [](set<int>& s) { s.clear(); },  // 3. zero value
    [](BinStream& b, set<int>& s) {  // 4. deserialization
        size_t n; b >> n; s.clear();
        for (int x; n--;) { b >> x; s.insert(x); }
    },
    [](BinStream& b, const set<int>& s) {  // 5. serialization
        b << s.size();
        for (int i : s) b << i;
    });
```
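The aggregation rule can be exercised on its own, outside Husky. The sketch below lifts it into a free function, `merge_topk` (a hypothetical name), and checks that merging two partial results gives the same top-K set in either order, since nothing in the description fixes the order in which local copies are merged. It includes an `a.count(x)` check that skips values already kept; without it, a value already present in a full set would evict the smallest element and then fail to be re-inserted, leaving fewer than K entries.

```cpp
#include <cstddef>
#include <set>

constexpr int K = 10;  // number of distinct largest values to keep

// Aggregation rule from the example: merge b into a, keeping the K largest distinct values.
inline void merge_topk(std::set<int>& a, const std::set<int>& b) {
    const std::size_t k = static_cast<std::size_t>(K);
    for (int x : b) {
        if (a.count(x)) continue;                            // already kept
        if (a.size() == k && *a.begin() < x) a.erase(a.begin());  // evict the smallest
        if (a.size() < k) a.insert(x);
    }
}
```

Starting from the zero value (an empty set), merging `{3, 1, 2}` simply yields `{1, 2, 3}`; merging `{1..10}` with `{5..14}` yields `{5..14}` whichever side is merged into the other.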
- Aggregators can only be used while the `job` function is running.
- Every thread must create aggregators in the same order.
- Don't forget to provide the zero value if it differs from the default-constructed value.
- The Zero Value Lambda is used only when `update_any` is used.
- The zero value must really be zero: `A + Zero Value == A`.
- Aggregators are reference-counted. They are copyable within one thread; avoid sharing an aggregator instance among multiple threads.
- Each aggregator has ONE copy for reading and local copies for writing on each machine. These copies are destroyed when the reference count drops to 0.
| Operation | Cost | Properties |
| --- | --- | --- |
| Creation | | wait_for_first, ordered, call_by_all |
| Update | O(1) | disordered |
| GetValue | O(1) | disordered |
| Removal | O(1) | disordered, call_by_all |
| Activate/Inactivate | O(1) | disordered, call_by_all |
| ToResetEachIter/KeepAggregate | O(1) | disordered, call_by_all |
| GlobalSynchronization | messages: `O(2*num_active_updated_agg*num_machine)` | wait_for_last, call_by_all, global_synchronization |

- The global aggregation work is assigned to executors evenly in terms of the number of aggregators. So if the aggregation of one aggregator is heavy, consider splitting it into multiple aggregators.
- Different aggregators may have different global centers for their global aggregation.
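The message bound and the even split can be turned into quick arithmetic. In the sketch below, `sync_messages` evaluates the bound from the table, reading the factor 2 as one trip to the global center and one broadcast back per active, updated aggregator per machine, as the serialization paragraph describes. `aggregators_per_executor` uses round robin as one possible way to spread aggregators evenly by count; the actual assignment policy is not specified here, so treat it as an assumption.

```cpp
#include <vector>

// Message count from the complexity table: 2 * num_active_updated_agg * num_machine.
inline long long sync_messages(long long num_active_updated_agg, long long num_machine) {
    return 2 * num_active_updated_agg * num_machine;  // to the center and back
}

// Illustrative round-robin split of aggregator IDs over executors.
inline std::vector<int> aggregators_per_executor(int num_aggregators, int num_executors) {
    std::vector<int> load(num_executors, 0);
    for (int id = 0; id < num_aggregators; ++id) ++load[id % num_executors];
    return load;
}
```

For example, 3 active, updated aggregators on 4 machines cost about 2 * 3 * 4 = 24 messages per synchronization, and 10 aggregators over 4 executors split as 3, 3, 2, 2. Because the split is by count, one heavy aggregator still lands on a single executor, which is why splitting it is suggested.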
Q: Why is the implementation of Aggregator so complicated?
A: Because the implementation takes care of storage, thread cooperation for local aggregation, network communication for global aggregation, etc. But as long as you're happy using it, I think it's worth it.
Q: I get unexpected results, exceptions or segfaults when using Aggregator. How do I debug?
A: There are many potential causes:
- Wrong initial value, e.g. a vector of the wrong size or an object with some `nullptr` members.
- Wrong zero value, e.g. a vector of the wrong size or an object with some `nullptr` members.
- Wrong serialization or deserialization of the aggregator value.
- A destroyed variable is used inside the Zero Value Lambda, i.e. the variable may already have been destroyed because it went out of scope.
- Check whether you need to specialize the `copy_assign` function.
- If none of the items above helps, kindly open a new issue and give us the details.
Q: What's the relationship between husky and alaska?
A: Kindly see here.