37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
84 SamplerOptions return_opt;
86 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
142 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
149 typedef detail::DecisionTree DecisionTree_t;
156 typedef LabelType LabelT;
231 template<
class TopologyIterator,
class ParameterIterator>
233 TopologyIterator topology_begin,
234 ParameterIterator parameter_begin,
238 trees_(treeCount, DecisionTree_t(problem_spec)),
239 ext_param_(problem_spec),
242 for(
unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
244 trees_[k].topology_ = *topology_begin;
245 trees_[k].parameters_ = *parameter_begin;
264 vigra_precondition(ext_param_.used() ==
true,
265 "RandomForest::ext_param(): "
266 "Random forest has not been trained yet.");
282 vigra_precondition(ext_param_.used() ==
false,
283 "RandomForest::set_ext_param():"
284 "Random forest has been trained! Call reset()"
285 "before specifying new extrinsic parameters.");
309 DecisionTree_t
const &
tree(
int index)
const
311 return trees_[index];
316 DecisionTree_t &
tree(
int index)
318 return trees_[index];
328 return ext_param_.column_count_;
339 return ext_param_.column_count_;
347 return ext_param_.class_count_;
354 return options_.tree_count_;
359 template<
class U,
class C1,
372 bool adjust_thresholds=
false);
374 template <
class U,
class C1,
class U2,
class C2>
379 onlineLearn(features,
389 template<
class U,
class C1,
395 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
396 MultiArrayView<2,U2,C2>
const & response,
403 template<
class U,
class C1,
class U2,
class C2>
404 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
405 MultiArrayView<2, U2, C2>
const & labels,
408 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
453 template <
class U,
class C1,
459 void learn( MultiArrayView<2, U, C1>
const & features,
460 MultiArrayView<2, U2,C2>
const & response,
464 Random_t
const & random);
466 template <
class U,
class C1,
471 void learn( MultiArrayView<2, U, C1>
const & features,
472 MultiArrayView<2, U2,C2>
const & response,
478 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
487 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
488 void learn( MultiArrayView<2, U, C1>
const & features,
489 MultiArrayView<2, U2,C2>
const & labels,
499 template <
class U,
class C1,
class U2,
class C2,
500 class Visitor_t,
class Split_t>
501 void learn( MultiArrayView<2, U, C1>
const & features,
502 MultiArrayView<2, U2,C2>
const & labels,
531 template <
class U,
class C1,
class U2,
class C2>
559 template <
class U,
class C,
class Stop>
562 template <
class U,
class C>
573 template <
class U,
class C>
574 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
575 ArrayVectorView<double> prior)
const;
585 template <
class U,
class C1,
class T,
class C2>
589 vigra_precondition(features.
shape(0) == labels.
shape(0),
590 "RandomForest::predictLabels(): Label array has wrong size.");
591 for(
int k=0; k<features.
shape(0); ++k)
595 template <
class U,
class C1,
class T,
class C2,
class Stop>
600 vigra_precondition(features.
shape(0) == labels.
shape(0),
601 "RandomForest::predictLabels(): Label array has wrong size.");
602 for(
int k=0; k<features.
shape(0); ++k)
613 template <
class U,
class C1,
class T,
class C2,
class Stop>
615 MultiArrayView<2, T, C2> & prob,
617 template <
class T1,
class T2,
class C>
619 MultiArrayView<2, T2, C> & prob);
627 template <
class U,
class C1,
class T,
class C2>
634 template <
class U,
class C1,
class T,
class C2>
644 template <
class LabelType,
class PreprocessorTag>
645 template<
class U,
class C1,
651 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
652 MultiArrayView<2,U2,C2>
const & response,
658 bool adjust_thresholds)
660 online_visitor_.activate();
661 online_visitor_.adjust_thresholds=adjust_thresholds;
665 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
666 typedef UniformIntRandomFunctor<Random_t>
673 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
674 Default_Stop_t default_stop(options_);
675 typename RF_CHOOSER(Stop_t)::type stop
676 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
677 Default_Split_t default_split;
678 typename RF_CHOOSER(Split_t)::type split
679 = RF_CHOOSER(Split_t)::choose(split_, default_split);
680 rf::visitors::StopVisiting stopvisiting;
681 typedef rf::visitors::detail::VisitorNode
682 <rf::visitors::OnlineLearnVisitor,
683 typename RF_CHOOSER(Visitor_t)::type>
686 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
693 ext_param_.class_count_=0;
694 Preprocessor_t preprocessor( features, response,
695 options_, ext_param_);
698 RandFunctor_t randint ( random);
701 split.set_external_parameters(ext_param_);
702 stop.set_external_parameters(ext_param_);
706 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
712 for(
int ii = 0; ii < (int)trees_.size(); ++ii)
714 online_visitor_.tree_id=ii;
715 poisson_sampler.sample();
716 std::map<int,int> leaf_parents;
717 leaf_parents.clear();
719 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
721 int sample=poisson_sampler[s];
722 online_visitor_.current_label=preprocessor.response()(sample,0);
723 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
724 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
728 online_visitor_.add_to_index_list(ii,leaf,sample);
731 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
733 leaf_parents[leaf]=online_visitor_.last_node_id;
738 std::map<int,int>::iterator leaf_iterator;
739 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
741 int leaf=leaf_iterator->first;
742 int parent=leaf_iterator->second;
743 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
744 ArrayVector<Int32> indeces;
746 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
747 StackEntry_t stack_entry(indeces.begin(),
749 ext_param_.class_count_);
754 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
756 stack_entry.leftParent=parent;
760 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
761 stack_entry.rightParent=parent;
765 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
767 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
780 online_visitor_.deactivate();
783 template<
class LabelType,
class PreprocessorTag>
784 template<
class U,
class C1,
805 ext_param_.class_count_=0;
813 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
815 typename RF_CHOOSER(Stop_t)::type stop
816 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
818 typename RF_CHOOSER(Split_t)::type split
819 = RF_CHOOSER(Split_t)::choose(split_, default_split);
823 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
825 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
827 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
828 online_visitor_.activate();
831 RandFunctor_t randint ( random);
837 Preprocessor_t preprocessor( features, response,
838 options_, ext_param_);
841 split.set_external_parameters(ext_param_);
842 stop.set_external_parameters(ext_param_);
849 preprocessor.strata().end(),
850 detail::make_sampler_opt(options_)
851 .sampleSize(ext_param().actual_msample_),
858 first_stack_entry( sampler.sampledIndices().begin(),
859 sampler.sampledIndices().end(),
860 ext_param_.class_count_);
862 .set_oob_range( sampler.oobIndices().begin(),
863 sampler.oobIndices().end());
864 online_visitor_.reset_tree(treeId);
865 online_visitor_.tree_id=treeId;
866 trees_[treeId].reset();
868 .learn( preprocessor.features(),
869 preprocessor.response(),
876 .visit_after_tree( *
this,
882 online_visitor_.deactivate();
885 template <
class LabelType,
class PreprocessorTag>
886 template <
class U,
class C1,
898 Random_t
const & random)
909 vigra_precondition(features.
shape(0) == response.
shape(0),
910 "RandomForest::learn(): shape mismatch between features and response.");
917 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
919 typename RF_CHOOSER(Stop_t)::type stop
920 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
922 typename RF_CHOOSER(Split_t)::type split
923 = RF_CHOOSER(Split_t)::choose(split_, default_split);
927 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
929 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
931 if(options_.prepare_online_learning_)
932 online_visitor_.activate();
934 online_visitor_.deactivate();
938 RandFunctor_t randint ( random);
945 Preprocessor_t preprocessor( features, response,
946 options_, ext_param_);
949 split.set_external_parameters(ext_param_);
950 stop.set_external_parameters(ext_param_);
954 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
957 preprocessor.strata().end(),
958 detail::make_sampler_opt(options_)
959 .sampleSize(ext_param().actual_msample_),
962 visitor.visit_at_beginning(*
this, preprocessor);
965 for(
int ii = 0; ii < (int)trees_.size(); ++ii)
971 first_stack_entry( sampler.sampledIndices().begin(),
972 sampler.sampledIndices().end(),
973 ext_param_.class_count_);
975 .set_oob_range( sampler.oobIndices().begin(),
976 sampler.oobIndices().end());
978 .learn( preprocessor.features(),
979 preprocessor.response(),
986 .visit_after_tree( *
this,
993 visitor.visit_at_end(*
this, preprocessor);
995 online_visitor_.deactivate();
1001 template <
class LabelType,
class Tag>
1002 template <
class U,
class C,
class Stop>
1006 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1007 "RandomForestn::predictLabel():"
1008 " Too few columns in feature matrix.");
1009 vigra_precondition(
rowCount(features) == 1,
1010 "RandomForestn::predictLabel():"
1011 " Feature matrix must have a singlerow.");
1013 garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
1015 predictProbabilities(features, garbage_prediction_, stop);
1016 ext_param_.to_classlabel(
argMax(garbage_prediction_), d);
1022 template <
class LabelType,
class PreprocessorTag>
1023 template <
class U,
class C>
1028 using namespace functor;
1029 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1030 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1031 vigra_precondition(
rowCount(features) == 1,
1032 "RandomForestn::predictLabel():"
1033 " Feature matrix must have a single row.");
1034 Matrix<double> prob(1,ext_param_.class_count_);
1035 predictProbabilities(features, prob);
1036 std::transform( prob.begin(), prob.end(),
1037 priors.
begin(), prob.begin(),
1040 ext_param_.to_classlabel(
argMax(prob), d);
1044 template<
class LabelType,
class PreprocessorTag>
1045 template <
class T1,
class T2,
class C>
1054 "RandomFroest::predictProbabilities():"
1055 " Feature matrix and probability matrix size mismatch.");
1058 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1059 "RandomForestn::predictProbabilities():"
1060 " Too few columns in feature matrix.");
1063 "RandomForestn::predictProbabilities():"
1064 " Probability matrix must have as many columns as there are classes.");
1067 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1070 for(
int k=0; k<options_.tree_count_; ++k)
1072 set_id=(set_id+1) % predictionSet.indices[0].size();
1073 typedef std::set<SampleRange<T1> > my_set;
1074 typedef typename my_set::iterator set_it;
1077 std::vector<std::pair<int,set_it> > stack;
1079 for(set_it i=predictionSet.ranges[set_id].begin();
1080 i!=predictionSet.ranges[set_id].end();++i)
1081 stack.push_back(std::pair<int,set_it>(2,i));
1083 int num_decisions=0;
1084 while(!stack.empty())
1086 set_it range=stack.back().second;
1087 int index=stack.back().first;
1091 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1094 trees_[k].parameters_,
1095 index).prob_begin();
1096 for(
int i=range->start;i!=range->end;++i)
1099 for(
int l=0; l<ext_param_.class_count_; ++l)
1101 prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
1103 totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
1110 if(trees_[k].topology_[index]!=i_ThresholdNode)
1112 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1114 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1115 if(range->min_boundaries[node.column()]>=node.threshold())
1118 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1121 if(range->max_boundaries[node.column()]<node.threshold())
1124 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1128 SampleRange<T1> new_range=*range;
1129 new_range.min_boundaries[node.column()]=FLT_MAX;
1130 range->max_boundaries[node.column()]=-FLT_MAX;
1131 new_range.start=new_range.end=range->end;
1133 while(i!=range->end)
1136 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1138 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1139 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1142 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1147 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1148 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1153 if(range->start==range->end)
1155 predictionSet.ranges[set_id].erase(range);
1159 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1162 if(new_range.start!=new_range.end)
1164 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1165 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1169 predictionSet.cumulativePredTime[k]=num_decisions;
1171 for(
unsigned int i=0;i<totalWeights.size();++i)
1175 for(
int l=0; l<ext_param_.class_count_; ++l)
1178 prob(i, l) /= totalWeights[i];
1180 assert(test==totalWeights[i]);
1181 assert(totalWeights[i]>0.0);
1185 template <
class LabelType,
class PreprocessorTag>
1186 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1189 MultiArrayView<2, T, C2> & prob,
1190 Stop_t & stop_)
const
1196 "RandomForestn::predictProbabilities():"
1197 " Feature matrix and probability matrix size mismatch.");
1201 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1202 "RandomForestn::predictProbabilities():"
1203 " Too few columns in feature matrix.");
1206 "RandomForestn::predictProbabilities():"
1207 " Probability matrix must have as many columns as there are classes.");
1209 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1210 Default_Stop_t default_stop(options_);
1211 typename RF_CHOOSER(Stop_t)::type & stop
1212 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1214 stop.set_external_parameters(ext_param_, tree_count());
1215 prob.init(NumericTraits<T>::zero());
1225 for(
int row=0; row <
rowCount(features); ++row)
1227 ArrayVector<double>::const_iterator weights;
1230 double totalWeight = 0.0;
1233 for(
int k=0; k<options_.tree_count_; ++k)
1236 weights = trees_[k ].predict(
rowVector(features, row));
1239 int weighted = options_.predict_weighted_;
1240 for(
int l=0; l<ext_param_.class_count_; ++l)
1242 double cur_w = weights[l] * (weighted * (*(weights-1))
1244 prob(row, l) += (T)cur_w;
1246 totalWeight += cur_w;
1248 if(stop.after_prediction(weights,
1258 for(
int l=0; l< ext_param_.class_count_; ++l)
1260 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1266 template <
class LabelType,
class PreprocessorTag>
1267 template <
class U,
class C1,
class T,
class C2>
1268 void RandomForest<LabelType, PreprocessorTag>
1269 ::predictRaw(MultiArrayView<2, U, C1>
const & features,
1270 MultiArrayView<2, T, C2> & prob)
const
1276 "RandomForestn::predictProbabilities():"
1277 " Feature matrix and probability matrix size mismatch.");
1281 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1282 "RandomForestn::predictProbabilities():"
1283 " Too few columns in feature matrix.");
1286 "RandomForestn::predictProbabilities():"
1287 " Probability matrix must have as many columns as there are classes.");
1289 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1290 prob.init(NumericTraits<T>::zero());
1300 for(
int row=0; row <
rowCount(features); ++row)
1302 ArrayVector<double>::const_iterator weights;
1305 double totalWeight = 0.0;
1308 for(
int k=0; k<options_.tree_count_; ++k)
1311 weights = trees_[k ].predict(
rowVector(features, row));
1314 int weighted = options_.predict_weighted_;
1315 for(
int l=0; l<ext_param_.class_count_; ++l)
1317 double cur_w = weights[l] * (weighted * (*(weights-1))
1319 prob(row, l) += (T)cur_w;
1321 totalWeight += cur_w;
1325 prob/= options_.tree_count_;
1333 #include "random_forest/rf_algorithm.hxx"
1334 #endif // VIGRA_RANDOM_FOREST_HXX
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition: random_forest.hxx:280
int class_count() const
return number of classes used while training.
Definition: random_forest.hxx:345
Definition: rf_nodeproxy.hxx:626
detail::RF_DEFAULT & rf_default()
factory function to return an RF_DEFAULT tag
Definition: rf_common.hxx:131
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:669
Definition: rf_preprocessing.hxx:62
int feature_count() const
return number of features used while training.
Definition: random_forest.hxx:326
int column_count() const
return number of features used while training.
Definition: random_forest.hxx:337
Create random samples from a sequence of indices.
Definition: sampling.hxx:233
const difference_type & shape() const
Definition: multi_array.hxx:1602
Definition: rf_split.hxx:993
const_iterator begin() const
Definition: array_vector.hxx:223
problem specification class for the random forest.
Definition: rf_common.hxx:535
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition: random_forest.hxx:198
void sample()
Definition: sampling.hxx:468
std::ptrdiff_t MultiArrayIndex
Definition: multi_iterator.hxx:348
MultiArray< 2, double > garbage_prediction_
Definition: random_forest.hxx:161
Standard early stopping criterion.
Definition: rf_common.hxx:882
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition: random_forest.hxx:262
DecisionTree_t & tree(int index)
access trees
Definition: random_forest.hxx:316
DecisionTree_t const & tree(int index) const
access const trees
Definition: random_forest.hxx:309
Options_t & set_options()
access random forest options
Definition: random_forest.hxx:292
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:893
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:790
Definition: random_forest.hxx:143
Options_t const & options() const
access const random forest options
Definition: random_forest.hxx:302
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple samples (one sample per row of the feature matrix)
Definition: rf_visitors.hxx:244
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Definition: algorithm.hxx:96
Definition: rf_visitors.hxx:573
Class for fixed size vectors. This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:939
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition: random_forest.hxx:586
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus avoids that a trained classifier is biased towards the majority class.
Definition: sampling.hxx:144
SamplerOptions & withReplacement(bool in=true)
Sample from the training population with replacement.
Definition: sampling.hxx:86
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple samples (one sample per row of the feature matrix)
Definition: random_forest.hxx:628
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:695
int tree_count() const
return number of trees
Definition: random_forest.hxx:352
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:682
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition: random_forest.hxx:232
Definition: random.hxx:335
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:593
Options object for the random forest.
Definition: rf_common.hxx:170
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1214
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label for a single sample (a feature matrix with exactly one row).
Definition: random_forest.hxx:1004
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition: random_forest.hxx:532
Definition: rf_visitors.hxx:224