35 #define VIGRA_RF_ALGORITHM_HXX
38 #include "splices.hxx"
58 template<
class OrigMultiArray,
61 void choose(OrigMultiArray
const & in,
70 for(Iter iter = b; iter != e; ++iter, ++ii)
100 template<
class Feature_t,
class Response_t>
102 Response_t
const & response)
125 typedef std::vector<int> FeatureList_t;
126 typedef std::vector<double> ErrorList_t;
127 typedef FeatureList_t::iterator Pivot_t;
153 template<
class FeatureT,
156 class ErrorRateCallBack>
157 bool init(FeatureT
const & all_features,
158 ResponseT
const & response,
161 ErrorRateCallBack errorcallback)
163 bool ret_ = init(all_features, response, errorcallback);
166 vigra_precondition(std::distance(b, e) == (std::ptrdiff_t)
selected.size(),
167 "Number of features in ranking != number of features matrix");
172 template<
class FeatureT,
175 bool init(FeatureT
const & all_features,
176 ResponseT
const & response,
181 return init(all_features, response, b, e, ecallback);
185 template<
class FeatureT,
187 bool init(FeatureT
const & all_features,
188 ResponseT
const & response)
190 return init(all_features, response, RFErrorCallback());
202 template<
class FeatureT,
204 class ErrorRateCallBack>
205 bool init(FeatureT
const & all_features,
206 ResponseT
const & response,
207 ErrorRateCallBack errorcallback)
215 selected.resize(all_features.shape(1), 0);
216 for(
unsigned int ii = 0; ii <
selected.size(); ++ii)
218 errors.resize(all_features.shape(1), -1);
219 errors.back() = errorcallback(all_features, response);
223 std::map<typename ResponseT::value_type, int> res_map;
224 std::vector<int> cts;
226 for(
int ii = 0; ii < response.shape(0); ++ii)
228 if(res_map.find(response(ii, 0)) == res_map.end())
230 res_map[response(ii, 0)] = counter;
234 cts[res_map[response(ii,0)]] +=1;
236 no_features = double(*(std::max_element(cts.begin(),
238 /
double(response.shape(0));
293 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
295 ResponseT
const & response,
297 ErrorRateCallBack errorcallback)
299 VariableSelectionResult::FeatureList_t & selected = result.
selected;
300 VariableSelectionResult::ErrorList_t & errors = result.
errors;
301 VariableSelectionResult::Pivot_t & pivot = result.pivot;
302 int featureCount = features.shape(1);
304 if(!result.init(features, response, errorcallback))
308 vigra_precondition((
int)selected.size() == featureCount,
309 "forward_selection(): Number of features in Feature "
310 "matrix and number of features in previously used "
311 "result struct mismatch!");
315 int not_selected_size = std::distance(pivot, selected.end());
316 while(not_selected_size > 1)
318 std::vector<double> current_errors;
319 VariableSelectionResult::Pivot_t next = pivot;
320 for(
int ii = 0; ii < not_selected_size; ++ii, ++next)
322 std::swap(*pivot, *next);
324 detail::choose( features,
328 double error = errorcallback(cur_feats, response);
329 current_errors.push_back(error);
330 std::swap(*pivot, *next);
332 int pos = std::distance(current_errors.begin(),
333 std::min_element(current_errors.begin(),
334 current_errors.end()));
336 std::advance(next, pos);
337 std::swap(*pivot, *next);
338 errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
340 std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr,
", "));
341 std::cerr <<
"Choosing " << *pivot <<
" at error of " << current_errors[pos] << std::endl;
344 not_selected_size = std::distance(pivot, selected.end());
347 template<
class FeatureT,
class ResponseT>
349 ResponseT
const & response,
350 VariableSelectionResult & result)
395 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
397 ResponseT
const & response,
399 ErrorRateCallBack errorcallback)
401 int featureCount = features.shape(1);
402 VariableSelectionResult::FeatureList_t & selected = result.
selected;
403 VariableSelectionResult::ErrorList_t & errors = result.
errors;
404 VariableSelectionResult::Pivot_t & pivot = result.pivot;
407 if(!result.init(features, response, errorcallback))
411 vigra_precondition((
int)selected.size() == featureCount,
412 "backward_elimination(): Number of features in Feature "
413 "matrix and number of features in previously used "
414 "result struct mismatch!");
416 pivot = selected.end() - 1;
418 int selected_size = std::distance(selected.begin(), pivot);
419 while(selected_size > 1)
421 VariableSelectionResult::Pivot_t next = selected.begin();
422 std::vector<double> current_errors;
423 for(
int ii = 0; ii < selected_size; ++ii, ++next)
425 std::swap(*pivot, *next);
427 detail::choose( features,
431 double error = errorcallback(cur_feats, response);
432 current_errors.push_back(error);
433 std::swap(*pivot, *next);
435 int pos = std::distance(current_errors.begin(),
436 std::min_element(current_errors.begin(),
437 current_errors.end()));
438 next = selected.begin();
439 std::advance(next, pos);
440 std::swap(*pivot, *next);
442 errors[std::distance(selected.begin(), pivot)-1] = current_errors[pos];
443 selected_size = std::distance(selected.begin(), pivot);
445 std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr,
", "));
446 std::cerr <<
"Eliminating " << *pivot <<
" at error of " << current_errors[pos] << std::endl;
452 template<
class FeatureT,
class ResponseT>
454 ResponseT
const & response,
455 VariableSelectionResult & result)
492 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
494 ResponseT
const & response,
496 ErrorRateCallBack errorcallback)
498 VariableSelectionResult::FeatureList_t & selected = result.
selected;
499 VariableSelectionResult::ErrorList_t & errors = result.
errors;
500 VariableSelectionResult::Pivot_t & iter = result.pivot;
501 int featureCount = features.shape(1);
503 if(!result.init(features, response, errorcallback))
507 vigra_precondition((
int)selected.size() == featureCount,
508 "forward_selection(): Number of features in Feature "
509 "matrix and number of features in previously used "
510 "result struct mismatch!");
514 for(; iter != selected.end(); ++iter)
518 detail::choose( features,
522 double error = errorcallback(cur_feats, response);
523 errors[std::distance(selected.begin(), iter)] = error;
525 std::copy(selected.begin(), iter+1, std::ostream_iterator<int>(std::cerr,
", "));
526 std::cerr <<
"Choosing " << *(iter+1) <<
" at error of " << error << std::endl;
532 template<
class FeatureT,
class ResponseT>
534 ResponseT
const & response,
535 VariableSelectionResult & result)
542 enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
557 ClusterNode():NodeBase(){}
558 ClusterNode(
int nCol,
559 BT::T_Container_type & topology,
560 BT::P_Container_type & split_param)
561 : BT(nCol + 5, 5,topology, split_param)
571 ClusterNode( BT::T_Container_type
const & topology,
572 BT::P_Container_type
const & split_param,
574 :
NodeBase(5 , 5,topology, split_param, n)
580 ClusterNode( BT & node_)
585 BT::parameter_size_ += 0;
591 void set_index(
int in)
617 HC_Entry(
int p,
int l,
int a,
bool in)
618 : parent(p), level(l), addr(a), infm(in)
647 double dist_func(
double a,
double b)
649 return std::min(a, b);
655 template<
class Functor>
659 std::vector<int> stack;
660 stack.push_back(begin_addr);
661 while(!stack.empty())
663 ClusterNode node(topology_, parameters_, stack.
back());
667 if(node.columns_size() != 1)
669 stack.push_back(node.child(0));
670 stack.push_back(node.child(1));
678 template<
class Functor>
682 std::queue<HC_Entry> queue;
687 queue.push(
HC_Entry(parent,level,begin_addr, infm));
688 while(!queue.empty())
690 level = queue.front().level;
691 parent = queue.front().parent;
692 addr = queue.front().addr;
693 infm = queue.front().infm;
694 ClusterNode node(topology_, parameters_, queue.
front().addr);
698 parnt = ClusterNode(topology_, parameters_, parent);
701 bool istrue = tester(node, level, parnt, infm);
702 if(node.columns_size() != 1)
704 queue.push(
HC_Entry(addr, level +1,node.child(0),istrue));
705 queue.push(
HC_Entry(addr, level +1,node.child(1),istrue));
712 void save(std::string file, std::string prefix)
717 Shp(topology_.
size(),1),
721 Shp(parameters_.
size(), 1),
722 parameters_.
data()));
732 template<
class T,
class C>
736 std::vector<std::pair<int, int> > addr;
737 typedef std::pair<int, int> Entry;
739 for(
int ii = 0; ii < distance.
shape(0); ++ii)
741 addr.push_back(std::make_pair(topology_.
size(), ii));
742 ClusterNode leaf(1, topology_, parameters_);
743 leaf.set_index(index);
745 leaf.columns_begin()[0] = ii;
748 while(addr.size() != 1)
753 double min_dist = dist((addr.begin()+ii_min)->second,
754 (addr.begin()+jj_min)->second);
755 for(
unsigned int ii = 0; ii < addr.size(); ++ii)
757 for(
unsigned int jj = ii+1; jj < addr.size(); ++jj)
759 if( dist((addr.begin()+ii_min)->second,
760 (addr.begin()+jj_min)->second)
761 > dist((addr.begin()+ii)->second,
762 (addr.begin()+jj)->second))
764 min_dist = dist((addr.begin()+ii)->second,
765 (addr.begin()+jj)->second);
777 ClusterNode firstChild(topology_,
779 (addr.begin() +ii_min)->first);
780 ClusterNode secondChild(topology_,
782 (addr.begin() +jj_min)->first);
783 col_size = firstChild.columns_size() + secondChild.columns_size();
785 int cur_addr = topology_.
size();
786 begin_addr = cur_addr;
788 ClusterNode parent(col_size,
791 ClusterNode firstChild(topology_,
793 (addr.begin() +ii_min)->first);
794 ClusterNode secondChild(topology_,
796 (addr.begin() +jj_min)->first);
797 parent.parameters_begin()[0] = min_dist;
798 parent.set_index(index);
800 std::merge(firstChild.columns_begin(), firstChild.columns_end(),
801 secondChild.columns_begin(),secondChild.columns_end(),
802 parent.columns_begin());
806 if(*parent.columns_begin() == *firstChild.columns_begin())
808 parent.child(0) = (addr.begin()+ii_min)->first;
809 parent.child(1) = (addr.begin()+jj_min)->first;
810 (addr.begin()+ii_min)->first = cur_addr;
812 to_desc = (addr.begin()+jj_min)->second;
813 addr.erase(addr.begin()+jj_min);
817 parent.child(1) = (addr.begin()+ii_min)->first;
818 parent.child(0) = (addr.begin()+jj_min)->first;
819 (addr.begin()+jj_min)->first = cur_addr;
821 to_desc = (addr.begin()+ii_min)->second;
822 addr.erase(addr.begin()+ii_min);
826 for(
int jj = 0 ; jj < (int)addr.size(); ++jj)
830 double bla = dist_func(
831 dist(to_desc, (addr.begin()+jj)->second),
832 dist((addr.begin()+ii_keep)->second,
833 (addr.begin()+jj)->second));
835 dist((addr.begin()+ii_keep)->second,
836 (addr.begin()+jj)->second) = bla;
837 dist((addr.begin()+jj)->second,
838 (addr.begin()+ii_keep)->second) = bla;
859 bool operator()(Node& node)
872 template<
class Iter,
class DT>
877 Matrix<double> tmp_mem_;
880 Matrix<double> feats_;
887 template<
class Feat_T,
class Label_T>
890 Feat_T
const & feats,
891 Label_T
const & labls,
896 :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
899 feats_(_spl(a,b).size(), feats.shape(1)),
900 labels_(_spl(a,b).size(),1),
906 copy_splice(_spl(a,b),
907 _spl(feats.shape(1)),
910 copy_splice(_spl(a,b),
911 _spl(labls.shape(1)),
917 bool operator()(Node& node)
921 int class_count = perm_imp.
shape(1) - 1;
923 for(
int kk = 0; kk < nPerm; ++kk)
926 for(
int ii = 0; ii <
rowCount(feats_); ++ii)
929 for(
int jj = 0; jj < node.columns_size(); ++jj)
931 if(node.columns_begin()[jj] != feats_.shape(1))
932 tmp_mem_(ii, node.columns_begin()[jj])
933 = tmp_mem_(index, node.columns_begin()[jj]);
937 for(
int ii = 0; ii <
rowCount(tmp_mem_); ++ii)
944 ++perm_imp(index,labels_(ii, 0));
946 ++perm_imp(index, class_count);
950 double node_status = perm_imp(index, class_count);
951 node_status /= nPerm;
952 node_status -= orig_imp(0, class_count);
954 node_status /= oob_size;
955 node.status() += node_status;
976 void save(std::string file, std::string prefix)
984 bool operator()(Node& node)
986 for(
int ii = 0; ii < node.columns_size(); ++ii)
987 variables(index, ii) = node.columns_begin()[ii];
1001 bool operator()(Nde & cur,
int level, Nde parent,
bool infm)
1004 cur.status() = std::min(parent.status(), cur.status());
1031 std::ofstream graphviz;
1036 std::string
const gz)
1037 :features_(features), labels_(labels),
1038 graphviz(gz.c_str(), std::ios::out)
1040 graphviz <<
"digraph G\n{\n node [shape=\"record\"]";
1044 graphviz <<
"\n}\n";
1049 bool operator()(Nde & cur,
int level, Nde parent,
bool infm)
1051 graphviz <<
"node" << cur.index() <<
" [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() <<
"\\n";
1052 graphviz <<
" status: " << cur.status() <<
"\\n";
1053 for(
int kk = 0; kk < cur.columns_size(); ++kk)
1055 graphviz << cur.columns_begin()[kk] <<
" ";
1059 graphviz <<
"\"] [color = \"" <<cur.status() <<
" 1.000 1.000\"];\n";
1061 graphviz <<
"\"node" << parent.index() <<
"\" -> \"node" << cur.index() <<
"\";\n";
1081 int repetition_count_;
1087 void save(std::string filename, std::string prefix)
1089 std::string prefix1 =
"cluster_importance_" + prefix;
1093 prefix1 =
"vars_" + prefix;
1101 : repetition_count_(rep_cnt), clustering(clst)
1107 template<
class RF,
class PR>
1110 Int32 const class_count = rf.ext_param_.class_count_;
1111 Int32 const column_count = rf.ext_param_.column_count_+1;
1132 template<
class RF,
class PR,
class SM,
class ST>
1136 Int32 column_count = rf.ext_param_.column_count_ +1;
1137 Int32 class_count = rf.ext_param_.class_count_;
1141 typename PR::Feature_t & features
1142 =
const_cast<typename PR::Feature_t &
>(pr.features());
1149 if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
1153 for(
int ii = 0; ii < pr.features().shape(0); ++ii)
1154 indices.push_back(ii);
1155 std::random_shuffle(indices.begin(), indices.end());
1156 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1158 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
1160 oob_indices.push_back(indices[ii]);
1161 ++cts[pr.response()(indices[ii], 0)];
1167 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1168 if(!sm.is_used()[ii])
1169 oob_indices.push_back(ii);
1179 oob_right(Shp_t(1, class_count + 1));
1182 for(iter = oob_indices.
begin();
1183 iter != oob_indices.
end();
1187 .predictLabel(
rowVector(features, *iter))
1188 == pr.response()(*iter, 0))
1191 ++oob_right[pr.response()(*iter,0)];
1193 ++oob_right[class_count];
1198 perm_oob_right (Shp_t(2* column_count-1, class_count + 1));
1201 pc(oob_indices.
begin(), oob_indices.
end(),
1210 perm_oob_right /= repetition_count_;
1211 for(
int ii = 0; ii <
rowCount(perm_oob_right); ++ii)
1212 rowVector(perm_oob_right, ii) -= oob_right;
1214 perm_oob_right *= -1;
1215 perm_oob_right /= oob_indices.
size();
1224 template<
class RF,
class PR,
class SM,
class ST>
1232 template<
class RF,
class PR>
1272 template<
class FeatureT,
class ResponseT>
1274 ResponseT
const & response,
1281 if(features.shape(0) > 40000)
1288 RF.
learn(features, response,
1317 template<
class FeatureT,
class ResponseT>
1319 ResponseT
const & response,
1320 HClustering & linkage)
1327 template<
class Array1,
class Vector1>
1328 void get_ranking(Array1
const & in, Vector1 & out)
1330 std::map<double, int> mymap;
1331 for(
int ii = 0; ii < in.size(); ++ii)
1333 for(std::map<double, int>::reverse_iterator iter = mymap.rbegin(); iter!= mymap.rend(); ++iter)
1335 out.push_back(iter->second);
UInt32 uniformInt() const
Definition: random.hxx:448
double no_features
Definition: rf_algorithm.hxx:151
void visit_at_end(RF &rf, PR &pr)
Definition: rf_algorithm.hxx:1233
RandomForestOptions & tree_count(int in)
Definition: rf_common.hxx:497
reference back()
Definition: array_vector.hxx:293
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:725
MultiArray< 2, double > cluster_stdev_
Definition: rf_algorithm.hxx:1080
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:669
Topology_type column_data() const
Definition: rf_nodeproxy.hxx:159
MultiArray< 2, double > cluster_importance_
Definition: rf_algorithm.hxx:1077
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning.
Definition: rf_common.hxx:413
const difference_type & shape() const
Definition: multi_array.hxx:1602
void visit_at_beginning(RF const &rf, PR const &pr)
Definition: rf_algorithm.hxx:1108
Definition: rf_algorithm.hxx:1068
const_iterator begin() const
Definition: array_vector.hxx:223
NodeBase()
Definition: rf_nodeproxy.hxx:237
Definition: rf_algorithm.hxx:848
NormalizeStatus(double m)
Definition: rf_algorithm.hxx:855
Definition: rf_algorithm.hxx:997
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2756
Definition: rf_visitors.hxx:852
void forward_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:294
Definition: rf_visitors.hxx:1484
void backward_elimination(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:396
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:334
Definition: rf_algorithm.hxx:611
Definition: rf_algorithm.hxx:873
Definition: rf_algorithm.hxx:83
difference_type_1 size() const
Definition: multi_array.hxx:1595
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:893
Definition: random_forest.hxx:143
reference front()
Definition: array_vector.hxx:279
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
bool init(FeatureT const &all_features, ResponseT const &response, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:205
void breadth_first_traversal(Functor &tester)
Definition: rf_algorithm.hxx:679
Definition: rf_algorithm.hxx:638
void cluster_permutation_importance(FeatureT const &features, ResponseT const &response, HClustering &linkage, MultiArray< 2, double > &distance)
Definition: rf_algorithm.hxx:1273
Definition: rf_algorithm.hxx:964
Parameter_type parameters_begin() const
Definition: rf_nodeproxy.hxx:207
Definition: metaprogramming.hxx:117
ErrorList_t errors
Definition: rf_algorithm.hxx:146
Class for fixed size vectors.This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:939
void writeHDF5(...)
Store array data in an HDF5 file.
Definition: rf_visitors.hxx:1449
INT & typeID()
Definition: rf_nodeproxy.hxx:136
MultiArray< 2, double > distance
Definition: rf_visitors.hxx:1513
void cluster(MultiArrayView< 2, T, C > distance)
Definition: rf_algorithm.hxx:733
Definition: rf_algorithm.hxx:1025
MultiArray< 2, int > variables
Definition: rf_algorithm.hxx:1074
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:695
Definition: rf_visitors.hxx:101
void rank_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:493
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_algorithm.hxx:1225
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_algorithm.hxx:1133
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:682
Definition: random.hxx:335
double oob_breiman
Definition: rf_visitors.hxx:863
Options object for the random forest.
Definition: rf_common.hxx:170
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy
Definition: rf_common.hxx:376
const_iterator end() const
Definition: array_vector.hxx:237
const_pointer data() const
Definition: array_vector.hxx:209
void iterate(Functor &tester)
Definition: rf_algorithm.hxx:656
FeatureList_t selected
Definition: rf_algorithm.hxx:133
size_type size() const
Definition: array_vector.hxx:330
MultiArrayView< 2, int > variables
Definition: rf_algorithm.hxx:970
double operator()(Feature_t const &features, Response_t const &response)
Definition: rf_algorithm.hxx:101
RFErrorCallback(RandomForestOptions opt=RandomForestOptions())
Definition: rf_algorithm.hxx:93
Definition: rf_algorithm.hxx:116