Skip to content

Commit

Permalink
move neuralnet datahandlers to machine_learning/datahandler directory…
Browse files Browse the repository at this point in the history
…. and let DataHandler be able to treat multiple type data.
  • Loading branch information
MasahiroOgawa committed Nov 26, 2016
1 parent faee8f1 commit 271f25c
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 93 deletions.
2 changes: 1 addition & 1 deletion ext/mnist/src/mnist_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ MNIST_dataset<Container, Sub<Pixel>, Label> read_dataset(const std::string mnist
std::string trainlabel_fname = mnist_dir + "/train-labels-idx1-ubyte";
std::string testlabel_fname = mnist_dir + "/t10k-labels-idx1-ubyte";

return read_dataset_direct<Container, Sub<Pixel>>(trainimg_fname, testimg_fname, trainlabel_fname, testlabel_fname, training_limit, test_limit);
return read_dataset_direct<Container, Sub<Pixel>, Label>(trainimg_fname, testimg_fname, trainlabel_fname, testlabel_fname, training_limit, test_limit);
}

} //end of namespace mnist
Expand Down
26 changes: 17 additions & 9 deletions machine_learning/data_handler/src/datahandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@ using namespace std;

namespace mo {

//-----------------------------------------
std::unique_ptr<DataHandler<>> create_handler(const DataType data_type){
switch (data_type) {
case DataType::mnist:
return std::unique_ptr<DataHandler<>>{new MnistDataHandler()};
break;
default:
throw std::logic_error("no such data type.");
}
std::unique_ptr<DataHandler<std::vector<unsigned char>, unsigned char>> create_imgdata_handler(const DataType data_type){
switch (data_type) {
case DataType::mnist:
return unique_ptr<DataHandler<vector<unsigned char>, unsigned char>>{ new MnistDataHandler<vector<unsigned char>, unsigned char>() };
default:
throw std::logic_error("no such data type. @DataHandler::create_imgdata_handler()");
}
}

//------------------------------
std::unique_ptr<DataHandler<std::vector<double>, double>> create_vecdata_handler(const DataType data_type){
switch (data_type) {
case DataType::mnist:
return unique_ptr<DataHandler<vector<double>, double>>{ new MnistDataHandler<vector<double>, double>() };
default:
throw std::logic_error("no such data type. @DataHandler::create_imgdata_handler()");
}
}

} // namespace mo
27 changes: 19 additions & 8 deletions machine_learning/data_handler/src/datahandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ enum class DataType{
mnist
};

template<typename Datum = std::vector<unsigned char>, typename Label = unsigned char>
template<typename Datum, typename Label>
struct Dataset{
std::vector<Datum> train_data;
std::vector<Datum> test_data;
Expand All @@ -23,35 +23,46 @@ struct Dataset{
/////////////////////////////
/// \brief DataHandler
/////////////////////////////
template<typename Datum = std::vector<unsigned char>, typename Label = unsigned char>
template<typename Datum , typename Label>
class DataHandler
{
public:
DataHandler(){}
virtual void read(const std::string& datadir) = 0; // pure virtual function; must be implemented in derived classes.
virtual char show(const Datum& datum, const std::string& winname) = 0;
virtual void show_traindata() = 0;
virtual const Matrix& X()const = 0;
virtual const Matrix& B()const = 0;

virtual const Matrix& train_datamat()const = 0; // for Neuralnet input
virtual const Matrix& train_labelmat()const = 0;
virtual ~DataHandler(){} // Because DataHandler has virtual func, need virtual destructor to call derived class's destructor.
std::vector<Datum>& train_data(){return dataset_.train_data;}
std::vector<Datum>& test_data(){return dataset_.test_data;}
std::vector<Label>& train_labels(){return dataset_.train_labels;}
std::vector<Label>& test_labels(){return dataset_.test_labels;}

virtual char show(const Datum& datum, const std::string& winname) = 0;
virtual void show_traindata() = 0;

protected:
Dataset<Datum, Label> dataset_; //currently support only the same type MNIST_dataset
};


//////////////////////////////
/// \brief create_handler
/// \brief create_imgdata_handler
/// \param data_type
/// \return
/// we cannot define tempalete version.
/// because in that case we have to define whole inplementation in .h, and which needs mnisthandler.h,
/// so it cause cross include.
//////////////////////////////
std::unique_ptr<DataHandler<>> create_handler(const DataType data_type);
std::unique_ptr<DataHandler<std::vector<unsigned char>, unsigned char>> create_imgdata_handler(const DataType data_type);


//////////////////////////////
/// \brief create_vecdata_handler
/// \param data_type
/// \return
//////////////////////////////
std::unique_ptr<DataHandler<std::vector<double>, double>> create_vecdata_handler(const DataType data_type);

} // namespace mo

Expand Down
14 changes: 7 additions & 7 deletions machine_learning/data_handler/src/irisdatahandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ namespace mo {
IrisDataHandler::IrisDataHandler()
{

cls_map_["Iris-setosa"] = (cv::Mat_<double>(3,1) << 1, 0, 0);
cls_map_["Iris-versicolor"] = (cv::Mat_<double>(3,1) << 0, 1, 0);
cls_map_["Iris-virginica"] = (cv::Mat_<double>(3,1) << 0, 0, 1);
cls_map_["Iris-setosa"] = (Matrix(3,1) << 1, 0, 0);
cls_map_["Iris-versicolor"] = (Matrix(3,1) << 0, 1, 0);
cls_map_["Iris-virginica"] = (Matrix(3,1) << 0, 0, 1);
}


Expand All @@ -29,20 +29,20 @@ try{

string line;
string cell;
cv::Mat Xt;
Matrix Xt;
unsigned datadim_ = 4;
while(getline(fi, line)){ //read lines
stringstream lstream(line);
if(line.empty()) break; //in case reading empty line.
cv::Mat x;
Matrix x;
for(unsigned i=0;i<datadim_;++i){ //read feature cells
getline(lstream, cell, ',');
x.push_back(stod(cell));
}
Xt.push_back(cv::Mat(x.t()));
Xt.push_back(Matrix(x.t()));

getline(lstream, cell, ','); //read instruction signal
cv::Mat b_vec = cls_map_.find(cell)->second;
Matrix b_vec = cls_map_.find(cell)->second;
if(B_.empty()) B_ = b_vec;
else cv::hconcat(B_, b_vec, B_);
}
Expand Down
68 changes: 10 additions & 58 deletions machine_learning/data_handler/src/mnistdatahandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,67 +5,19 @@ using namespace std;

namespace mo {

//---------------------------------------
void MnistDataHandler::read(const std::string& datadir){
std::cout << "[INFO] try to read mnist data from " << datadir << " @" << __func__ << "\n";

mnist::MNIST_dataset<std::vector, std::vector<unsigned char>, unsigned char> mnist_dataset = mnist::read_dataset(datadir);

this->dataset_.train_data = std::move(mnist_dataset.training_images); // we need this-> to make it dependent name(template arguments).
this->dataset_.test_data = std::move(mnist_dataset.test_images);
this->dataset_.train_labels = std::move(mnist_dataset.training_labels);
this->dataset_.test_labels = std::move(mnist_dataset.test_labels);

//print out info
size_t num_traindata = this->dataset_.train_data.size();
size_t num_testdata = this->dataset_.test_data.size();
size_t num_trainlbl = this->dataset_.train_labels.size();
size_t num_testlbl = this->dataset_.test_labels.size();
std::cout << "[INFO] read " << "#train data=" << num_traindata << ", #train labels=" << num_trainlbl
<< ", #test data=" << num_testdata << ", #test labels=" << num_testlbl << "\n";
if(num_traindata != num_trainlbl){
std::string errmsg = "#train_data != #train_labels @";
errmsg += __func__;
throw(std::logic_error(errmsg));
}
if(this->dataset_.test_data.size() != this->dataset_.test_labels.size()){
std::string errmsg = "#test_data != #test_labels @";
errmsg += __func__;
throw(std::logic_error(errmsg));
}
}
////---------------------------------------
//template<>
//const Matrix& MnistDataHandler<std::vector<double>, double>::train_datamat()const{
// return move(Matrix());
//}

//---------------------------------------
void MnistDataHandler::show_traindata(){
std::cout << "q: stop display.\n";

const std::vector<std::vector<unsigned char>>& train_data = this->dataset_.train_data;
for(auto tr_img : train_data){
char ch = this->show(tr_img,"train image");
if(ch=='q'){
destroy_window("train image");
break;
}
}

return;
}


//---------------------------------------
char MnistDataHandler::show(const std::vector<unsigned char>& datum, const std::string& winname){
Image_gray img(28, 28, const_cast<unsigned char*>(datum.data()));
return mo::show(winname, img, 0);
}

//---------------------------------------
const Matrix& MnistDataHandler::X()const{
return Matrix();
}

//---------------------------------------
const Matrix& MnistDataHandler::B()const{
return Matrix();
}
////---------------------------------------
//template<>
//const Matrix& MnistDataHandler<std::vector<double>, double>::train_labelmat()const{
// return move(Matrix());
//}

} // namespace mo
96 changes: 91 additions & 5 deletions machine_learning/data_handler/src/mnistdatahandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,110 @@

namespace mo{

class MnistDataHandler : public DataHandler<std::vector<unsigned char>, unsigned char>
template<typename Datum , typename Label>
class MnistDataHandler : public DataHandler<Datum, Label>
{
public:
MnistDataHandler(){}

//////////////////
/// \brief read
/// \param datadir
//////////////////
void read(const std::string& datadir);

//////////////////
/// \brief show
/// \param datum
/// \return pushed key
/// \return pushed key on the window
//////////////////
char show(const std::vector<unsigned char>& datum, const std::string& winname="mnist image");
char show(const Datum& datum, const std::string& winname="mnist image");

void show_traindata();
const Matrix& X()const;
const Matrix& B()const;
const Matrix& train_datamat()const;
const Matrix& train_labelmat()const;
};


//---------------------------------------
template<typename Datum , typename Label>
void MnistDataHandler<Datum,Label>::read(const std::string& datadir){
std::cout << "[INFO] try to read mnist data from " << datadir << " @" << __func__ << "\n";

mnist::MNIST_dataset<std::vector, Datum, Label> mnist_dataset = mnist::read_dataset<std::vector, std::vector, typename Datum::value_type, Label>(datadir);

this->dataset_.train_data = std::move(mnist_dataset.training_images); // we need this-> to make it dependent name(template arguments).
this->dataset_.test_data = std::move(mnist_dataset.test_images);
this->dataset_.train_labels = std::move(mnist_dataset.training_labels);
this->dataset_.test_labels = std::move(mnist_dataset.test_labels);

//print out info
size_t num_traindata = this->dataset_.train_data.size();
size_t num_testdata = this->dataset_.test_data.size();
size_t num_trainlbl = this->dataset_.train_labels.size();
size_t num_testlbl = this->dataset_.test_labels.size();
std::cout << "[INFO] read " << "#train data=" << num_traindata << ", #train labels=" << num_trainlbl
<< ", #test data=" << num_testdata << ", #test labels=" << num_testlbl << "\n";
if(num_traindata != num_trainlbl){
std::string errmsg = "#train_data != #train_labels @";
errmsg += __func__;
throw(std::logic_error(errmsg));
}
if(this->dataset_.test_data.size() != this->dataset_.test_labels.size()){
std::string errmsg = "#test_data != #test_labels @";
errmsg += __func__;
throw(std::logic_error(errmsg));
}
}

//---------------------------------------
template<typename Datum , typename Label>
char MnistDataHandler<Datum, Label>::show(const Datum& datum, const std::string& winname){
std::cout << "datum" << std::endl;
return '0';
}

//---------------------------------------
// template specialization for image data.
// to avoid being called primary, this must be placed in header file.
// in this case, it needs inline specifier. then it has the same address in every translation unit.
// otherwise it cause compile error of multiple definition.
template<>
inline char MnistDataHandler<std::vector<unsigned char>, unsigned char>::show(const std::vector<unsigned char>& datum, const std::string& winname){
Image_gray img(28, 28, const_cast<unsigned char*>(datum.data()));
return mo::show(winname, img, 0);
}

//---------------------------------------
template<typename Datum , typename Label>
void MnistDataHandler<Datum, Label>::show_traindata(){
std::cout << "q: stop display.\n";

const auto& train_data = this->train_data();
for(auto tr_img : train_data){
char ch = this->show(tr_img,"train image");
if(ch=='q'){
destroy_window("train image");
break;
}
}

return;
}

//---------------------------------------
template<typename Datum , typename Label>
const Matrix& MnistDataHandler<Datum, Label>::train_datamat()const{
return Matrix();
}

//---------------------------------------
template<typename Datum , typename Label>
const Matrix& MnistDataHandler<Datum, Label>::train_labelmat()const{
return Matrix();
}


} // namespace mo

#endif // MNISTDATAHANDLER_H
2 changes: 1 addition & 1 deletion machine_learning/data_handler/src/test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using namespace std;
using namespace mo;

int main()try{
unique_ptr<DataHandler<>> pdh{create_handler(DataType::mnist)};
unique_ptr<DataHandler<vector<unsigned char>, unsigned char>> pdh{create_imgdata_handler(DataType::mnist)};
pdh->read("~/git/my/mopf/data/mnist");
pdh->show_traindata();

Expand Down
4 changes: 2 additions & 2 deletions machine_learning/knearest_neighbor/src/knearest_neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class KNearestNeighbor

private:
int k_;
std::unique_ptr<DataHandler<>> pdh_;
std::unique_ptr<DataHandler<std::vector<unsigned char>, unsigned char>> pdh_;
bool show_result_;
DistanceType disttp_;

Expand All @@ -31,7 +31,7 @@ class KNearestNeighbor
template<typename Datum, typename Label>
KNearestNeighbor<Datum, Label>::KNearestNeighbor(const int k, const std::string& datadir, const DataType dt
, const bool show_result, const DistanceType disttp):
k_ {k}, show_result_ {show_result}, pdh_{create_handler(dt)}, disttp_{disttp}
k_ {k}, show_result_ {show_result}, pdh_{create_imgdata_handler(dt)}, disttp_{disttp}
{
pdh_->read(datadir);
if(show_result_) pdh_->show_traindata();
Expand Down
5 changes: 3 additions & 2 deletions machine_learning/neuralnet/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "../../data_handler/src/datahandler.h"
#include "../../data_handler/src/daviddatahandler.h"
#include "../../data_handler/src/irisdatahandler.h"
#include "../../data_handler/src/mnistdatahandler.h"
#include "neuralnethandler.h"
#include <memory> //for unique_ptr
using namespace std;
Expand All @@ -17,12 +18,12 @@ try{
read_param(prm_file,prm);

//read data
unique_ptr<DataHandler<>> pdh{create_handler(static_cast<DataType>(prm.data.data_type))};
unique_ptr<DataHandler<vector<double>, double>> pdh{create_vecdata_handler(static_cast<DataType>(prm.data.data_type))};
pdh->read(prm.data.datafname);

//learn
NeuralnetHandler nns(prm.nn, prm.vis);
nns.learn(pdh->X(), pdh->B());
nns.learn(pdh->train_datamat(), pdh->train_labelmat());

return 0;
}catch(runtime_error& e){
Expand Down

0 comments on commit 271f25c

Please sign in to comment.