Skip to content

Commit

Permalink
Merge pull request #629 from jolibrain/chain_branches
Browse files Browse the repository at this point in the history
Support for tree-structured chains
  • Loading branch information
beniz authored Sep 2, 2019
2 parents 19989f7 + e6b3e93 commit 8ae6462
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 182 deletions.
55 changes: 30 additions & 25 deletions src/chain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,37 @@ namespace dd
while(hit!=ad._data.end())
{
std::string ad_key = (*hit).first;
if ((rhit = _replacements->find(ad_key))!=_replacements->end())
auto rhit_range = _replacements->equal_range(ad_key);
int rcount = std::distance(rhit_range.first,rhit_range.second);
if (rcount > 0)
{
std::string nested_chain = (*rhit).second.list_keys().at(0);

// we erase the chainid, and add up the model object
adc._data.erase(ad_key);

// recursive replacements for chains with > 2 models
bool recursive_changes = false;
visitor_nested vn(_replacements);
APIData nested_ad = (*rhit).second.getobj(nested_chain);
auto nhit = nested_ad._data.begin();
while(nhit!=nested_ad._data.end())
for (auto rhit=rhit_range.first; rhit!=rhit_range.second; ++rhit)
{
mapbox::util::apply_visitor(vn,(*nhit).second);
if (!vn._vad.empty())
std::string nested_chain = (*rhit).second.list_keys().at(0);

// recursive replacements for chains with > 2 models
bool recursive_changes = false;
visitor_nested vn(_replacements);
APIData nested_ad = (*rhit).second.getobj(nested_chain);
auto nhit = nested_ad._data.begin();
while(nhit!=nested_ad._data.end())
{
adc.add(nested_chain,vn._vad);
recursive_changes = true;
mapbox::util::apply_visitor(vn,(*nhit).second);
if (!vn._vad.empty())
{
adc.add(nested_chain,vn._vad);
recursive_changes = true;
}
++nhit;
}
++nhit;

if (!recursive_changes)
adc.add(nested_chain,
(*rhit).second.getobj(nested_chain));

}

if (!recursive_changes)
adc.add(nested_chain,
(*rhit).second.getobj(nested_chain));

// we erase the chainid, and add up the model object
adc._data.erase(ad_key);
_vad.push_back(adc);
}
else
Expand All @@ -90,12 +94,13 @@ namespace dd
// pre-compile models != first model
std::vector<std::string> uris;
APIData first_model_out;
std::unordered_map<std::string,APIData> other_models_out;
std::unordered_multimap<std::string,APIData> other_models_out;
std::unordered_map<std::string,APIData>::const_iterator hit = _model_data.begin();
while(hit!=_model_data.end())
{
std::string model_name = (*hit).first;
if (model_name == _first_sname)
std::string model_id = (*hit).first;
std::string model_name = get_model_sname(model_id);
if (model_id == _first_id)
{
first_model_out = (*hit).second;
std::vector<APIData> predictions = first_model_out.getv("predictions");
Expand Down
58 changes: 47 additions & 11 deletions src/chain.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#ifndef CHAIN_H
#define CHAIN_H

#include "chain_actions.h"
#include "apidata.h"
#include <iostream>

namespace dd
Expand All @@ -37,29 +37,65 @@ namespace dd
ChainData() {}
~ChainData() {}

void add_model_data(const std::string &sname,
void add_model_data(const std::string &id,
const APIData &out)
{
auto hit = _model_data.begin();
if ((hit=_model_data.find(sname))!=_model_data.end())
std::unordered_map<std::string,APIData>::iterator hit;
if ((hit=_model_data.find(id))!=_model_data.end())
_model_data.erase(hit);
_model_data.insert(std::pair<std::string,APIData>(sname,out));
_model_data.insert(std::pair<std::string,APIData>(id,out));
}

APIData get_model_data(const std::string &sname) const
APIData get_model_data(const std::string &id) const
{
std::unordered_map<std::string,APIData>::const_iterator hit;
if ((hit = _model_data.find(sname))!=_model_data.end())
if ((hit=_model_data.find(id))!=_model_data.end())
return (*hit).second;
else
return APIData();
}

void add_action_data(const std::string &id,
const APIData &out)
{
std::unordered_map<std::string,APIData>::iterator hit;
if ((hit=_action_data.find(id))!=_action_data.end())
_action_data.erase(hit);
_action_data.insert(std::pair<std::string,APIData>(id,out));
}

APIData get_action_data(const std::string &id) const
{
std::unordered_map<std::string,APIData>::const_iterator hit;
if ((hit=_action_data.find(id))!=_action_data.end())
return (*hit).second;
else
return APIData();
}

void add_model_sname(const std::string &id,
const std::string &sname)
{
std::unordered_map<std::string,std::string>::iterator hit;
if ((hit=_id_sname.find(id))==_id_sname.end())
_id_sname.insert(std::pair<std::string,std::string>(id,sname));
}

std::string get_model_sname(const std::string &id)
{
std::unordered_map<std::string,std::string>::const_iterator hit;
if ((hit=_id_sname.find(id))!=_id_sname.end())
return (*hit).second;
else return std::string();
}

APIData nested_chain_output();

std::unordered_map<std::string,APIData> _model_data;
std::vector<APIData> _action_data;
std::string _first_sname;
std::unordered_map<std::string,APIData> _action_data;
std::unordered_map<std::string,std::string> _id_sname;
//std::string _first_sname;
std::string _first_id;
};

/**
Expand All @@ -68,7 +104,7 @@ namespace dd
class visitor_nested
{
public:
visitor_nested(std::unordered_map<std::string,APIData> *r)
visitor_nested(std::unordered_multimap<std::string,APIData> *r)
:_replacements(r) {}
~visitor_nested() {}

Expand All @@ -87,7 +123,7 @@ namespace dd
void operator()(const APIData &ad);
void operator()(const std::vector<APIData> &vad);

std::unordered_map<std::string,APIData> *_replacements = nullptr;
std::unordered_multimap<std::string,APIData> *_replacements = nullptr;
std::vector<APIData> _vad;
};

Expand Down
9 changes: 5 additions & 4 deletions src/chain_actions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace dd
{

void ImgsCropAction::apply(APIData &model_out,
std::vector<APIData> &actions_data)
ChainData &cdata)
{
std::vector<APIData> vad = model_out.getv("predictions");
std::vector<cv::Mat> imgs = model_out.getobj("input").get("imgs").get<std::vector<cv::Mat>>();
Expand Down Expand Up @@ -114,14 +114,14 @@ namespace dd
APIData action_out;
action_out.add("data",cropped_imgs);
action_out.add("cids",bbox_ids);
actions_data.push_back(action_out);
cdata.add_action_data(_action_id,action_out);

// updated model data with chain ids
model_out.add("predictions",cvad);
}

void ClassFilter::apply(APIData &model_out,
std::vector<APIData> &actions_data)
ChainData &cdata)
{
if (!_params.has("classes"))
{
Expand Down Expand Up @@ -158,7 +158,8 @@ namespace dd
}

// empty action data
actions_data.push_back(APIData());
cdata.add_action_data(_action_id,APIData());
//actions_data.push_back(APIData());

// updated model data
model_out.add("predictions",cvad);
Expand Down
34 changes: 22 additions & 12 deletions src/chain_actions.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#define CHAIN_ACTIONS_H

#include "apidata.h"
#include "chain.h"

namespace dd
{
Expand Down Expand Up @@ -53,10 +54,12 @@ namespace dd
{
public:
ChainAction(const APIData &adc,
const std::string &action_id,
const std::string &action_type)
:_action_type(action_type)
:_action_id(action_id),_action_type(action_type)
{
_params = adc.getobj("parameters");
APIData action_adc = adc.getobj("action");
_params = action_adc.getobj("parameters");
}

~ChainAction() {}
Expand All @@ -69,8 +72,9 @@ namespace dd
}

void apply(APIData &model_out,
std::vector<APIData> &actions_data);
ChainData &cdata);

std::string _action_id;
std::string _action_type;
APIData _params;
bool _in_place = false;
Expand All @@ -81,25 +85,27 @@ namespace dd
{
public:
ImgsCropAction(const APIData &adc,
const std::string &action_id,
const std::string &action_type)
:ChainAction(adc,action_type) {}
:ChainAction(adc,action_id,action_type) {}

~ImgsCropAction() {}

void apply(APIData &model_out,
std::vector<APIData> &actions_data);
ChainData &cdata);
};

class ClassFilter : public ChainAction
{
public:
ClassFilter(const APIData &adc,
const std::string &action_id,
const std::string &action_type)
:ChainAction(adc,action_type) {_in_place = true;}
:ChainAction(adc,action_id,action_type) {_in_place = true;}
~ClassFilter() {}

void apply(APIData &model_out,
std::vector<APIData> &action_data);
ChainData &cdata);
};

class ChainActionFactory
Expand All @@ -111,17 +117,21 @@ namespace dd

void apply_action(const std::string &action_type,
APIData &model_out,
std::vector<APIData> &action_out)
ChainData &cdata)
{
std::string action_id;
if (_adc.has("id"))
action_id = _adc.get("id").get<std::string>();
else action_id = std::to_string(cdata._action_data.size());
if (action_type == "crop")
{
ImgsCropAction act(_adc,action_type);
act.apply(model_out,action_out);
ImgsCropAction act(_adc,action_id,action_type);
act.apply(model_out,cdata);
}
else if (action_type == "filter")
{
ClassFilter act(_adc,action_type);
act.apply(model_out,action_out);
ClassFilter act(_adc,action_id,action_type);
act.apply(model_out,cdata);
}
else
{
Expand Down
Loading

0 comments on commit 8ae6462

Please sign in to comment.