diff --git a/src/chain.cc b/src/chain.cc index b27ce04c6..ca82a933b 100644 --- a/src/chain.cc +++ b/src/chain.cc @@ -40,33 +40,37 @@ namespace dd while(hit!=ad._data.end()) { std::string ad_key = (*hit).first; - if ((rhit = _replacements->find(ad_key))!=_replacements->end()) + auto rhit_range = _replacements->equal_range(ad_key); + int rcount = std::distance(rhit_range.first,rhit_range.second); + if (rcount > 0) { - std::string nested_chain = (*rhit).second.list_keys().at(0); - - // we erase the chainid, and add up the model object - adc._data.erase(ad_key); - - // recursive replacements for chains with > 2 models - bool recursive_changes = false; - visitor_nested vn(_replacements); - APIData nested_ad = (*rhit).second.getobj(nested_chain); - auto nhit = nested_ad._data.begin(); - while(nhit!=nested_ad._data.end()) + for (auto rhit=rhit_range.first; rhit!=rhit_range.second; ++rhit) { - mapbox::util::apply_visitor(vn,(*nhit).second); - if (!vn._vad.empty()) + std::string nested_chain = (*rhit).second.list_keys().at(0); + + // recursive replacements for chains with > 2 models + bool recursive_changes = false; + visitor_nested vn(_replacements); + APIData nested_ad = (*rhit).second.getobj(nested_chain); + auto nhit = nested_ad._data.begin(); + while(nhit!=nested_ad._data.end()) { - adc.add(nested_chain,vn._vad); - recursive_changes = true; + mapbox::util::apply_visitor(vn,(*nhit).second); + if (!vn._vad.empty()) + { + adc.add(nested_chain,vn._vad); + recursive_changes = true; + } + ++nhit; } - ++nhit; + + if (!recursive_changes) + adc.add(nested_chain, + (*rhit).second.getobj(nested_chain)); + } - - if (!recursive_changes) - adc.add(nested_chain, - (*rhit).second.getobj(nested_chain)); - + // we erase the chainid, and add up the model object + adc._data.erase(ad_key); _vad.push_back(adc); } else @@ -90,12 +94,13 @@ namespace dd // pre-compile models != first model std::vector uris; APIData first_model_out; - std::unordered_map other_models_out; + std::unordered_multimap other_models_out; std::unordered_map::const_iterator hit = _model_data.begin(); while(hit!=_model_data.end()) { - std::string model_name = (*hit).first; - if (model_name == _first_sname) + std::string model_id = (*hit).first; + std::string model_name = get_model_sname(model_id); + if (model_id == _first_id) { first_model_out = (*hit).second; std::vector predictions = first_model_out.getv("predictions"); diff --git a/src/chain.h b/src/chain.h index a119341d8..0b04e5c2c 100644 --- a/src/chain.h +++ b/src/chain.h @@ -22,7 +22,7 @@ #ifndef CHAIN_H #define CHAIN_H -#include "chain_actions.h" +#include "apidata.h" #include namespace dd @@ -37,29 +37,65 @@ namespace dd ChainData() {} ~ChainData() {} - void add_model_data(const std::string &sname, + void add_model_data(const std::string &id, const APIData &out) { - auto hit = _model_data.begin(); - if ((hit=_model_data.find(sname))!=_model_data.end()) + std::unordered_map::iterator hit; + if ((hit=_model_data.find(id))!=_model_data.end()) _model_data.erase(hit); - _model_data.insert(std::pair(sname,out)); + _model_data.insert(std::pair(id,out)); } - APIData get_model_data(const std::string &sname) const + APIData get_model_data(const std::string &id) const { std::unordered_map::const_iterator hit; - if ((hit = _model_data.find(sname))!=_model_data.end()) + if ((hit=_model_data.find(id))!=_model_data.end()) return (*hit).second; else return APIData(); } + + void add_action_data(const std::string &id, + const APIData &out) + { + std::unordered_map::iterator hit; + if ((hit=_action_data.find(id))!=_action_data.end()) + _action_data.erase(hit); + _action_data.insert(std::pair(id,out)); + } + + APIData get_action_data(const std::string &id) const + { + std::unordered_map::const_iterator hit; + if ((hit=_action_data.find(id))!=_action_data.end()) + return (*hit).second; + else + return APIData(); + } + + void add_model_sname(const std::string &id, + const std::string &sname) + { + std::unordered_map::iterator hit; + if ((hit=_id_sname.find(id))==_id_sname.end()) + _id_sname.insert(std::pair(id,sname)); + } + + std::string get_model_sname(const std::string &id) + { + std::unordered_map::const_iterator hit; + if ((hit=_id_sname.find(id))!=_id_sname.end()) + return (*hit).second; + else return std::string(); + } APIData nested_chain_output(); std::unordered_map _model_data; - std::vector _action_data; - std::string _first_sname; + std::unordered_map _action_data; + std::unordered_map _id_sname; + //std::string _first_sname; + std::string _first_id; }; /** @@ -68,7 +104,7 @@ namespace dd class visitor_nested { public: - visitor_nested(std::unordered_map *r) + visitor_nested(std::unordered_multimap *r) :_replacements(r) {} ~visitor_nested() {} @@ -87,7 +123,7 @@ namespace dd void operator()(const APIData &ad); void operator()(const std::vector &vad); - std::unordered_map *_replacements = nullptr; + std::unordered_multimap *_replacements = nullptr; std::vector _vad; }; diff --git a/src/chain_actions.cc b/src/chain_actions.cc index 7a151951d..5986d4402 100644 --- a/src/chain_actions.cc +++ b/src/chain_actions.cc @@ -28,7 +28,7 @@ namespace dd { void ImgsCropAction::apply(APIData &model_out, - std::vector &actions_data) + ChainData &cdata) { std::vector vad = model_out.getv("predictions"); std::vector imgs = model_out.getobj("input").get("imgs").get>(); @@ -114,14 +114,14 @@ namespace dd APIData action_out; action_out.add("data",cropped_imgs); action_out.add("cids",bbox_ids); - actions_data.push_back(action_out); + cdata.add_action_data(_action_id,action_out); // updated model data with chain ids model_out.add("predictions",cvad); } void ClassFilter::apply(APIData &model_out, - std::vector &actions_data) + ChainData &cdata) { if (!_params.has("classes")) { @@ -158,7 +158,8 @@ namespace dd } // empty action data - actions_data.push_back(APIData()); + cdata.add_action_data(_action_id,APIData()); + //actions_data.push_back(APIData()); // updated model data model_out.add("predictions",cvad); diff --git a/src/chain_actions.h b/src/chain_actions.h index fd7bc5bce..e22168f20 100644 --- a/src/chain_actions.h +++ b/src/chain_actions.h @@ -23,6 +23,7 @@ #define CHAIN_ACTIONS_H #include "apidata.h" +#include "chain.h" namespace dd { @@ -53,10 +54,12 @@ namespace dd { public: ChainAction(const APIData &adc, + const std::string &action_id, const std::string &action_type) - :_action_type(action_type) + :_action_id(action_id),_action_type(action_type) { - _params = adc.getobj("parameters"); + APIData action_adc = adc.getobj("action"); + _params = action_adc.getobj("parameters"); } ~ChainAction() {} @@ -69,8 +72,9 @@ namespace dd } void apply(APIData &model_out, - std::vector &actions_data); + ChainData &cdata); + std::string _action_id; std::string _action_type; APIData _params; bool _in_place = false; @@ -81,25 +85,27 @@ namespace dd { public: ImgsCropAction(const APIData &adc, + const std::string &action_id, const std::string &action_type) - :ChainAction(adc,action_type) {} + :ChainAction(adc,action_id,action_type) {} ~ImgsCropAction() {} void apply(APIData &model_out, - std::vector &actions_data); + ChainData &cdata); }; class ClassFilter : public ChainAction { public: ClassFilter(const APIData &adc, + const std::string &action_id, const std::string &action_type) - :ChainAction(adc,action_type) {_in_place = true;} + :ChainAction(adc,action_id,action_type) {_in_place = true;} ~ClassFilter() {} void apply(APIData &model_out, - std::vector &action_data); + ChainData &cdata); }; class ChainActionFactory @@ -111,17 +117,21 @@ namespace dd void apply_action(const std::string &action_type, APIData &model_out, - std::vector &action_out) + ChainData &cdata) { + std::string action_id; + if (_adc.has("id")) + action_id = _adc.get("id").get(); + else action_id = std::to_string(cdata._action_data.size()); if (action_type == "crop") { - ImgsCropAction act(_adc,action_type); - act.apply(model_out,action_out); + ImgsCropAction act(_adc,action_id,action_type); + act.apply(model_out,cdata); } else if (action_type == "filter") { - ClassFilter act(_adc,action_type); - act.apply(model_out,action_out); + ClassFilter act(_adc,action_id,action_type); + act.apply(model_out,cdata); } else { diff --git a/src/services.h b/src/services.h index f9dc825d2..59d5ebb60 100644 --- a/src/services.h +++ b/src/services.h @@ -32,6 +32,7 @@ #include "svminputfileconn.h" #include "outputconnectorstrategy.h" #include "chain.h" +#include "chain_actions.h" #ifdef USE_CAFFE #include "backends/caffe/caffelib.h" #endif @@ -616,6 +617,163 @@ namespace dd return pout._status; } + //TODO: parent_id + int chain_service(const std::string &cname, + const std::shared_ptr &chain_logger, + APIData &adc, + ChainData &cdata, + const std::string &pred_id, + std::vector &meta_uris, + std::vector &index_uris, + const int &prec_action_id, + const int chain_pos, + int &npredicts) + { + std::string sname = adc.get("service").get(); + chain_logger->info("[" + std::to_string(chain_pos) + "] / executing predict on service " + sname); + + // need to check that service exists + if (!service_exists(sname)) + { + spdlog::drop(cname); + throw ServiceNotFoundException("Service " + sname + " does not exist"); + } + + // parent_id, if any + std::string parent_id; + if (adc.has("parent_id")) + parent_id = adc.get("parent_id").get(); + + // if not first predict call in the chain, need to setup the input data! + if (chain_pos != 0) + { + // take data from the previous action + APIData act_data = cdata.get_action_data(!parent_id.empty() ? parent_id : std::to_string(prec_action_id)); + adc.add("data",act_data.get("data").get>()); // action output data must be string for now (more types to be supported / auto-detected) + adc.add("ids",act_data.get("cids").get>()); // chain ids of processed elements + adc.add("meta_uris",meta_uris); + adc.add("index_uris",index_uris); + } + else { + cdata._first_id = pred_id; + } + + APIData pred_out; + try + { + int pred_status = predict(adc,sname,pred_out,true); + } + catch(...) + { + spdlog::drop(cname); + throw; + } + + // check on results + std::vector vad = pred_out.getv("predictions"); + if (vad.empty()) + { + chain_logger->info("[" + std::to_string(chain_pos) + "] no predictions"); + //break; + return 1; + } + + int classes_size = 0; + int vals_size = 0; + std::vector nmeta_uris; + std::vector nindex_uris; + for (size_t j=0;j(vad.at(j).has("vals")); + if (chain_pos == 0) // first call's response contains uniformized top level URIs. + { + for (size_t k=0;k()); + if (vad.at(j).has("index_uri")) + nindex_uris.push_back(vad.at(j).get("index_uri").get()); + } + } + else // update meta uris to batch size at the current level of the chain + { + for (size_t k=0;kinfo("[" + std::to_string(chain_pos) + "] / no result from prediction"); + return 1; + } + ++npredicts; + + // store model output + cdata.add_model_data(pred_id,pred_out); + + return 0; + } + + int chain_action(const std::shared_ptr &chain_logger, + APIData &adc, + ChainData &cdata, + const int &chain_pos, + const std::string &prec_pred_id) + { + std::string action_type = adc.getobj("action").get("type").get(); + + APIData prev_data = cdata.get_model_data(prec_pred_id); + if (!prev_data.getv("predictions").size()) + { + // no prediction to work from + chain_logger->info("no prediction to act on"); + return 1; + } + + // call chain action factory + chain_logger->info("[" + std::to_string(chain_pos) + "] / executing action " + action_type); + ChainActionFactory caf(adc); + caf.apply_action(action_type, + prev_data, + cdata); + + // replace prev_data in cdata for prec_pred_id + cdata.add_model_data(prec_pred_id,prev_data); + + std::vector vad = prev_data.getv("predictions"); + if (vad.empty()) + { + // no prediction to work from + chain_logger->info("no prediction to act on after applying action " + action_type); + return 1; + } + + int classes_size = 0; + int vals_size = 0; + for (size_t i=0;i(vad.at(i).has("vals")); + } + + if (!classes_size && !vals_size) + { + chain_logger->info("[" + std::to_string(chain_pos) + "] / no result after applying action " + action_type); + return 1; + } + + return 0; + } + int chain(const APIData &ad, const std::string &cname, APIData &out) { #ifdef USE_DD_SYSLOG @@ -643,7 +801,7 @@ namespace dd std::vector meta_uris; std::vector index_uris; int npredicts = 0; - std::string prec_pred_sname; + std::string prec_pred_id; int prec_action_id = 0; int aid = 0; for (size_t i=0;i(); - - chain_logger->info("[" + std::to_string(i) + "] / executing predict on service " + pred_sname); - - // need to check that service exists - if (!service_exists(pred_sname)) - { - spdlog::drop(cname); - throw ServiceNotFoundException("Service " + pred_sname + " does not exist"); - } - - // if not first predict call in the chain, need to setup the input data! - if (i != 0) - { - // take data from the previous action - APIData act_data = cdata._action_data.at(prec_action_id); - - adc.add("data",act_data.get("data").get>()); // action output data must be string for now (more types to be supported / auto-detected) - adc.add("ids",act_data.get("cids").get>()); // chain ids of processed elements - adc.add("meta_uris",meta_uris); - adc.add("index_uris",index_uris); - } - else cdata._first_sname = pred_sname; - - APIData pred_out; - try - { - int pred_status = predict(adc,pred_sname,pred_out,true); - } - catch(...) - { - spdlog::drop(cname); - throw; - } - - // check on results - std::vector vad = pred_out.getv("predictions"); - if (vad.empty()) - { - chain_logger->info("[" + std::to_string(i) + "] no predictions"); - break; - } - - int classes_size = 0; - int vals_size = 0; - std::vector nmeta_uris; - std::vector nindex_uris; - for (size_t j=0;j(vad.at(j).has("vals")); - if (i == 0) // first call's response contains uniformized top level URIs. - { - for (size_t k=0;k()); - if (vad.at(j).has("index_uri")) - nindex_uris.push_back(vad.at(j).get("index_uri").get()); - } - } - else // update meta uris to batch size at the current level of the chain - { - for (size_t k=0;kinfo("[" + std::to_string(i) + "] / no result from prediction"); - break; - } - ++npredicts; - - // store model output - cdata.add_model_data(pred_sname,pred_out); - - prec_pred_sname = pred_sname; + std::string pred_id; + if (adc.has("id")) + pred_id = adc.get("id").get(); + else pred_id = std::to_string(i); + cdata.add_model_sname(pred_id,adc.get("service").get()); + chain_service(cname,chain_logger,adc,cdata, + pred_id,meta_uris,index_uris, + prec_action_id,i,npredicts); + prec_pred_id = pred_id; } else if (adc.has("action")) { - std::string action_type = adc.getobj("action").get("type").get(); - - APIData prev_data = cdata.get_model_data(prec_pred_sname); - if (!prev_data.getv("predictions").size()) - { - // no prediction to work from - chain_logger->info("no prediction to act on"); - break; - } - - // call chain action factory - chain_logger->info("[" + std::to_string(i) + "] / executing action " + action_type); - ChainActionFactory caf(adc.getobj("action")); - caf.apply_action(action_type, - prev_data, - cdata._action_data); - - // replace prev_data in cdata for prec_pred_sname - cdata.add_model_data(prec_pred_sname,prev_data); - - std::vector vad = prev_data.getv("predictions"); - if (vad.empty()) - { - // no prediction to work from - chain_logger->info("no prediction to act on after applying action " + action_type); - break; - } - - int classes_size = 0; - int vals_size = 0; - for (size_t i=0;i(vad.at(i).has("vals")); - } - - if (!classes_size && !vals_size) - { - chain_logger->info("[" + std::to_string(i) + "] / no result after applying action " + action_type); - break; - } - + chain_action(chain_logger,adc,cdata,i,prec_pred_id); prec_action_id = aid; ++aid; } @@ -791,7 +831,7 @@ namespace dd APIData nested_out; if (npredicts > 1) nested_out = cdata.nested_chain_output(); - else nested_out = cdata.get_model_data(cdata._first_sname); + else nested_out = cdata.get_model_data(cdata._first_id); out = nested_out; std::chrono::time_point tstop = std::chrono::system_clock::now();