
random_forest.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
39 
40 #include <iostream>
41 #include <algorithm>
42 #include <map>
43 #include <set>
44 #include <list>
45 #include <numeric>
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
49 #include "matrix.hxx"
50 #include "random.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
63 namespace vigra
64 {
65 
66 /** \addtogroup MachineLearning Machine Learning
67 
68  This module provides classification algorithms that map
69  features to labels or label probabilities.
70  Look at the RandomForest class first for an overview of most of the
71  functionality provided, as well as typical use cases.
72 **/
73 //@{
74 
75 namespace detail
76 {
77 
78 
79 
80 /* \brief sampling option factory function
81  */
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
83 {
84  SamplerOptions return_opt;
85  return_opt.withReplacement(RF_opt.sample_with_replacement_);
86  return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
87  return return_opt;
88 }
89 }//namespace detail
90 
91 /** Random Forest class
92  *
93  * \tparam PreprocessorTag (default: ClassificationTag) class used to preprocess
94  * the input while learning and predicting. Currently available:
95  * ClassificationTag and RegressionTag. It is recommended to use
96  * Splitfunctor::Preprocessor_t when using custom split functors,
97  * as they may need the data to be in a different format.
98  * \sa Preprocessor
99  *
100  * Simple usage for classification (regression is not yet supported):
101  * see RandomForest::learn() as well as RandomForestOptions() for additional
102  * options.
103  *
104  * \code
105  * using namespace vigra;
106  * using namespace rf;
107  * typedef xxx feature_t; // replace xxx with whichever type
108  * typedef yyy label_t; // likewise
109  *
110  * // allocate the training data
111  * MultiArrayView<2, feature_t> f = get_training_features();
112  * MultiArrayView<2, label_t> l = get_training_labels();
113  *
114  * RandomForest<> rf;
115  *
116  * // construct visitor to calculate out-of-bag error
117  * visitors::OOB_Error oob_v;
118  *
119  * // perform training
120  * rf.learn(f, l, visitors::create_visitor(oob_v));
121  *
122  * std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
123  *
124  * // get features for new data to be used for prediction
125  * MultiArrayView<2, feature_t> pf = get_features();
126  *
127  * // allocate space for the response (pf.shape(0) is the number of samples)
128  * MultiArray<2, label_t> prediction(Shape2(pf.shape(0), 1));
129  * MultiArray<2, double> prob(Shape2(pf.shape(0), rf.class_count()));
130  *
131  * // perform prediction on new data
132  * rf.predictLabels(pf, prediction);
133  * rf.predictProbabilities(pf, prob);
134  *
135  * \endcode
136  *
137  * Additional information, such as variable importance measures, is accessed
138  * via visitors defined in rf::visitors.
139  * Have a look at rf::split for other splitting methods.
140  *
141 */
142 template <class LabelType = double , class PreprocessorTag = ClassificationTag >
143 class RandomForest
144 {
145 
146  public:
147  //public typedefs
148  typedef RandomForestOptions Options_t;
149  typedef detail::DecisionTree DecisionTree_t;
150  typedef ProblemSpec<LabelType> ProblemSpec_t;
151  typedef GiniSplit Default_Split_t;
152  typedef EarlyStoppStd Default_Stop_t;
153  typedef rf::visitors::StopVisiting Default_Visitor_t;
154  typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
155  StackEntry_t;
156  typedef LabelType LabelT;
157  protected:
158 
159  /** optimisation for predictLabels
160  */
161  mutable MultiArray<2, double> garbage_prediction_;
162 
163  public:
164 
165  //problem independent data.
166  Options_t options_;
167  //problem dependent data members - is only set if
168  //a copy constructor, some sort of import
169  //function or the learn function is called
170  ArrayVector<DecisionTree_t> trees_;
171  ProblemSpec_t ext_param_;
172  /*mutable ArrayVector<int> tree_indices_;*/
173  rf::visitors::OnlineLearnVisitor online_visitor_;
174 
175 
176  void reset()
177  {
178  ext_param_.clear();
179  trees_.clear();
180  }
181 
182  public:
183 
184  /** \name Constructors
185  * Note: no copy constructor is specified, as no pointers are manipulated
186  * in this class.
187  */
188  /*\{*/
189  /**\brief default constructor
190  *
191  * \param options general options to the Random Forest. Must be of Type
192  * Options_t
193  * \param ext_param problem specific values that can be supplied
194  * additionally. (class weights , labels etc)
195  * \sa RandomForestOptions, ProblemSpec
196  *
197  */
198  RandomForest(Options_t const & options = Options_t(),
199  ProblemSpec_t const & ext_param = ProblemSpec_t())
200  :
201  options_(options),
202  ext_param_(ext_param)/*,
203  tree_indices_(options.tree_count_,0)*/
204  {
205  /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
206  tree_indices_[ii] = ii;*/
207  }
208 
209  /**\brief Create RF from external source
210  * \param treeCount Number of trees to add.
211  * \param topology_begin
212  * Iterator to a Container where the topology_ data
213  * of the trees are stored.
214  * Iterator should support at least treeCount forward
215  iterations (i.e. topology_end - topology_begin >= treeCount).
216  * \param parameter_begin
217  * iterator to a Container where the parameters_ data
218  * of the trees are stored. Iterator should support at
219  * least treeCount forward iterations.
220  * \param problem_spec
221  * Extrinsic parameters that specify the problem e.g.
222  * ClassCount, featureCount etc.
223  * \param options (optional) specify options used to train the original
224  * Random forest. This parameter is not used anywhere
225  * during prediction and thus is optional.
226  *
227  */
228  /* TODO: This constructor may be replaced by a Constructor using
229  * NodeProxy iterators to encapsulate the underlying data type.
230  */
231  template<class TopologyIterator, class ParameterIterator>
232  RandomForest(int treeCount,
233  TopologyIterator topology_begin,
234  ParameterIterator parameter_begin,
235  ProblemSpec_t const & problem_spec,
236  Options_t const & options = Options_t())
237  :
238  trees_(treeCount, DecisionTree_t(problem_spec)),
239  ext_param_(problem_spec),
240  options_(options)
241  {
242  for(unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
243  {
244  trees_[k].topology_ = *topology_begin;
245  trees_[k].parameters_ = *parameter_begin;
246  }
247  }
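// A minimal usage sketch for the constructor above, assuming each tree's
// topology_ and parameters_ arrays were stored externally. The load_*
// helpers are hypothetical placeholders, not VIGRA API:
//
//   std::vector<ArrayVector<Int32> >  topologies = load_topologies();
//   std::vector<ArrayVector<double> > parameters = load_parameters();
//   ProblemSpec<> spec = load_problem_spec();
//   RandomForest<> rf((int)topologies.size(),
//                     topologies.begin(), parameters.begin(), spec);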
248 
249  /*\}*/
250 
251 
252  /** \name Data Access
253  * data access interface - usage of member variables is deprecated
254  */
255 
256  /*\{*/
257 
258 
259  /**\brief return external parameters for viewing
260  * \return ProblemSpec_t
261  */
262  ProblemSpec_t const & ext_param() const
263  {
264  vigra_precondition(ext_param_.used() == true,
265  "RandomForest::ext_param(): "
266  "Random forest has not been trained yet.");
267  return ext_param_;
268  }
269 
270  /**\brief set external parameters
271  *
272  * \param in external parameters to be set
273  *
274  * Set external parameters explicitly.
275  * If the random forest has not been trained yet, the preprocessor will
276  * either ignore values set this way or throw an exception
277  * if manually specified values do not match the values calculated
278  * during the preparation step.
279  */
280  void set_ext_param(ProblemSpec_t const & in)
281  {
282  vigra_precondition(ext_param_.used() == false,
283  "RandomForest::set_ext_param(): "
284  "Random forest has been trained! Call reset() "
285  "before specifying new extrinsic parameters.");
286  ext_param_ = in; // store the supplied parameters
287  }
287 
288  /**\brief access random forest options
289  *
290  * \return random forest options
291  */
292  Options_t & set_options()
293  {
294  return options_;
295  }
296 
297 
298  /**\brief access const random forest options
299  *
300  * \return const Option_t
301  */
302  Options_t const & options() const
303  {
304  return options_;
305  }
306 
307  /**\brief access const trees
308  */
309  DecisionTree_t const & tree(int index) const
310  {
311  return trees_[index];
312  }
313 
314  /**\brief access trees
315  */
316  DecisionTree_t & tree(int index)
317  {
318  return trees_[index];
319  }
320 
321  /*\}*/
322 
323  /**\brief return number of features used while
324  * training.
325  */
326  int feature_count() const
327  {
328  return ext_param_.column_count_;
329  }
330 
331 
332  /**\brief return number of features used while
333  * training.
334  *
335  * deprecated. Use feature_count() instead.
336  */
337  int column_count() const
338  {
339  return ext_param_.column_count_;
340  }
341 
342  /**\brief return number of classes used while
343  * training.
344  */
345  int class_count() const
346  {
347  return ext_param_.class_count_;
348  }
349 
350  /**\brief return number of trees
351  */
352  int tree_count() const
353  {
354  return options_.tree_count_;
355  }
356 
357 
358 
359  template<class U,class C1,
360  class U2, class C2,
361  class Split_t,
362  class Stop_t,
363  class Visitor_t,
364  class Random_t>
365  void onlineLearn( MultiArrayView<2,U,C1> const & features,
366  MultiArrayView<2,U2,C2> const & response,
367  int new_start_index,
368  Visitor_t visitor_,
369  Split_t split_,
370  Stop_t stop_,
371  Random_t & random,
372  bool adjust_thresholds=false);
373 
374  template <class U, class C1, class U2,class C2>
375  void onlineLearn( MultiArrayView<2, U, C1> const & features,
376  MultiArrayView<2, U2,C2> const & labels,int new_start_index,bool adjust_thresholds=false)
377  {
378  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
379  onlineLearn(features,
380  labels,
381  new_start_index,
382  rf_default(),
383  rf_default(),
384  rf_default(),
385  rnd,
386  adjust_thresholds);
387  }
388 
389  template<class U,class C1,
390  class U2, class C2,
391  class Split_t,
392  class Stop_t,
393  class Visitor_t,
394  class Random_t>
395  void reLearnTree(MultiArrayView<2,U,C1> const & features,
396  MultiArrayView<2,U2,C2> const & response,
397  int treeId,
398  Visitor_t visitor_,
399  Split_t split_,
400  Stop_t stop_,
401  Random_t & random);
402 
403  template<class U, class C1, class U2, class C2>
404  void reLearnTree(MultiArrayView<2, U, C1> const & features,
405  MultiArrayView<2, U2, C2> const & labels,
406  int treeId)
407  {
408  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
409  reLearnTree(features,
410  labels,
411  treeId,
412  rf_default(),
413  rf_default(),
414  rf_default(),
415  rnd);
416  }
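// A sketch of the online-learning flow served by onlineLearn() and
// reLearnTree(). The option setter prepare_online_learning() is an
// assumption inferred from the prepare_online_learning_ flag checked in
// reLearnTree() below; f and l are the matrices from the class example:
//
//   RandomForest<> rf(RandomForestOptions().prepare_online_learning(true));
//   rf.learn(f, l);                // initial training on rows [0, N)
//   // ... append new samples as additional rows of f and l ...
//   rf.onlineLearn(f, l, N, true); // incorporate rows starting at index N,
//                                  // adjusting split thresholds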
417 
418 
419  /**\name Learning
420  * The following functions differ in the degree of customization
421  * they allow.
422  */
423  /*\{*/
424  /**\brief learn on data with custom config and random number generator
425  *
426  * \param features a N x M matrix containing N samples with M
427  * features
428  * \param response a N x D matrix containing the corresponding
429  * response. Current split functors assume D to
430  * be 1 and ignore any additional columns.
431  * This is not enforced to allow future support
432  * for uncertain labels, label independent strata etc.
433  * The Preprocessor specified during construction
434  * should be able to handle the given features
435  * and labels.
436  * see also: SplitFunctor, Preprocessing
437  *
438  * \param visitor visitor which is to be applied after each split,
439  * tree and at the end. Use rf_default for using
440  * default value. (No Visitors)
441  * see also: rf::visitors
442  * \param split split functor to be used to calculate each split
443  * use rf_default() for using default value. (GiniSplit)
444  * see also: rf::split
445  * \param stop
446  * predicate used to decide when to stop splitting a node.
447  * use rf_default() for using default value. (EarlyStoppStd)
448  * \param random RandomNumberGenerator to be used. Use
449  * rf_default() to use default value. (RandomMT19937)
450  *
451  *
452  */
453  template <class U, class C1,
454  class U2,class C2,
455  class Split_t,
456  class Stop_t,
457  class Visitor_t,
458  class Random_t>
459  void learn( MultiArrayView<2, U, C1> const & features,
460  MultiArrayView<2, U2,C2> const & response,
461  Visitor_t visitor,
462  Split_t split,
463  Stop_t stop,
464  Random_t const & random);
465 
466  template <class U, class C1,
467  class U2,class C2,
468  class Split_t,
469  class Stop_t,
470  class Visitor_t>
471  void learn( MultiArrayView<2, U, C1> const & features,
472  MultiArrayView<2, U2,C2> const & response,
473  Visitor_t visitor,
474  Split_t split,
475  Stop_t stop)
476 
477  {
478  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
479  learn( features,
480  response,
481  visitor,
482  split,
483  stop,
484  rnd);
485  }
486 
487  template <class U, class C1, class U2,class C2, class Visitor_t>
488  void learn( MultiArrayView<2, U, C1> const & features,
489  MultiArrayView<2, U2,C2> const & labels,
490  Visitor_t visitor)
491  {
492  learn( features,
493  labels,
494  visitor,
495  rf_default(),
496  rf_default());
497  }
498 
499  template <class U, class C1, class U2,class C2,
500  class Visitor_t, class Split_t>
501  void learn( MultiArrayView<2, U, C1> const & features,
502  MultiArrayView<2, U2,C2> const & labels,
503  Visitor_t visitor,
504  Split_t split)
505  {
506  learn( features,
507  labels,
508  visitor,
509  split,
510  rf_default());
511  }
512 
513  /**\brief learn on data with default configuration
514  *
515  * \param features a N x M matrix containing N samples with M
516  * features
517  * \param labels a N x D matrix containing the corresponding
518  * N labels. Current split functors assume D to
519  * be 1 and ignore any additional columns.
520  * This is not enforced, to allow future support
521  * for uncertain labels.
522  *
523  * learning is done with:
524  *
525  * \sa rf::split, EarlyStoppStd
526  *
527  * - Randomly seeded random number generator
528  * - the default Gini split functor as described by Breiman
529  * - the standard early stopping criterion (EarlyStoppStd)
530  */
531  template <class U, class C1, class U2,class C2>
532  void learn( MultiArrayView<2, U, C1> const & features,
533  MultiArrayView<2, U2,C2> const & labels)
534  {
535  learn( features,
536  labels,
537  rf_default(),
538  rf_default(),
539  rf_default());
540  }
541  /*\}*/
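// A sketch of the fully customized learn() overload with a fixed seed for
// reproducible training. The option setters tree_count() and
// samples_per_tree() are assumed from RandomForestOptions; f and l as in
// the class example:
//
//   RandomForest<> rf(RandomForestOptions()
//                         .tree_count(100)
//                         .samples_per_tree(0.8));
//   RandomNumberGenerator<> rnd(1234); // fixed seed -> reproducible forest
//   rf.learn(f, l, rf_default(), rf_default(), rf_default(), rnd);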
542 
543 
544 
545  /**\name prediction
546  */
547  /*\{*/
548  /** \brief predict a label given a feature.
549  *
550  * \param features: a 1 by featureCount matrix containing
551  * the data point to be predicted (this only works in
552  * the classification setting)
553  * \param stop: early stopping criterion
554  * \return double value representing class. You can use the
555  * predictLabels() function together with the
556  * rf.ext_param().class_type_ attribute
557  * to get back the same type used during learning.
558  */
559  template <class U, class C, class Stop>
560  LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;
561 
562  template <class U, class C>
563  LabelType predictLabel(MultiArrayView<2, U, C>const & features)
564  {
565  return predictLabel(features, rf_default());
566  }
567  /** \brief predict a label with features and class priors
568  *
569  * \param features: same as above.
570  * \param prior: iterator to prior weighting of classes
571  * \return same as above.
572  */
573  template <class U, class C>
574  LabelType predictLabel(MultiArrayView<2, U, C> const & features,
575  ArrayVectorView<double> prior) const;
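// A sketch of prior-weighted prediction with the overload above, assuming
// a trained rf and a feature matrix pf as in the class example:
//
//   ArrayVector<double> prior(rf.class_count(), 1.0);
//   prior[0] = 2.0; // give class 0 twice the vote weight
//   double label = rf.predictLabel(rowVector(pf, 0), prior);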
576 
577  /** \brief predict multiple labels with given features
578  *
579  * \param features: an n by featureCount matrix containing
580  * the data points to be predicted (this only works in
581  * the classification setting)
582  * \param labels: an n by 1 matrix passed by reference to store
583  * output.
584  */
585  template <class U, class C1, class T, class C2>
586  void predictLabels(MultiArrayView<2, U, C1> const & features,
587  MultiArrayView<2, T, C2> & labels) const
588  {
589  vigra_precondition(features.shape(0) == labels.shape(0),
590  "RandomForest::predictLabels(): Label array has wrong size.");
591  for(int k=0; k<features.shape(0); ++k)
592  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
593  }
594 
595  template <class U, class C1, class T, class C2, class Stop>
596  void predictLabels(MultiArrayView<2, U, C1>const & features,
597  MultiArrayView<2, T, C2> & labels,
598  Stop & stop) const
599  {
600  vigra_precondition(features.shape(0) == labels.shape(0),
601  "RandomForest::predictLabels(): Label array has wrong size.");
602  for(int k=0; k<features.shape(0); ++k)
603  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
604  }
605  /** \brief predict the class probabilities for multiple labels
606  *
607  * \param features same as above
608  * \param prob an n x class_count_ matrix, passed by reference to
609  * store the class probabilities
610  * \param stop early stopping criterion
611  * \sa EarlyStopping
612  */
613  template <class U, class C1, class T, class C2, class Stop>
614  void predictProbabilities(MultiArrayView<2, U, C1>const & features,
615  MultiArrayView<2, T, C2> & prob,
616  Stop & stop) const;
617  template <class T1,class T2, class C>
618  void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
619  MultiArrayView<2, T2, C> & prob);
620 
621  /** \brief predict the class probabilities for multiple labels
622  *
623  * \param features same as above
624  * \param prob an n x class_count_ matrix, passed by reference to
625  * store the class probabilities
626  */
627  template <class U, class C1, class T, class C2>
628  void predictProbabilities(MultiArrayView<2, U, C1> const & features,
629  MultiArrayView<2, T, C2> & prob) const
630  {
631  predictProbabilities(features, prob, rf_default());
632  }
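// A sketch of typical predictProbabilities() usage, reading back the most
// probable class per row (pf as in the class example):
//
//   MultiArray<2, double> prob(Shape2(pf.shape(0), rf.class_count()));
//   rf.predictProbabilities(pf, prob);
//   for(int row = 0; row < prob.shape(0); ++row)
//   {
//       int best = (int)argMax(rowVector(prob, row)); // most probable class
//       std::cout << "row " << row << ": class " << best
//                 << " (p = " << prob(row, best) << ")\n";
//   }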
633 
634  template <class U, class C1, class T, class C2>
635  void predictRaw(MultiArrayView<2, U, C1>const & features,
636  MultiArrayView<2, T, C2> & prob) const;
637 
638 
639  /*\}*/
640 
641 };
642 
643 
644 template <class LabelType, class PreprocessorTag>
645 template<class U,class C1,
646  class U2, class C2,
647  class Split_t,
648  class Stop_t,
649  class Visitor_t,
650  class Random_t>
651 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
652  MultiArrayView<2,U2,C2> const & response,
653  int new_start_index,
654  Visitor_t visitor_,
655  Split_t split_,
656  Stop_t stop_,
657  Random_t & random,
658  bool adjust_thresholds)
659 {
660  online_visitor_.activate();
661  online_visitor_.adjust_thresholds=adjust_thresholds;
662 
663  using namespace rf;
664  //typedefs
665  typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
666  typedef UniformIntRandomFunctor<Random_t>
667  RandFunctor_t;
668  // default values and initialization
669  // Value Chooser chooses second argument as value if first argument
670  // is of type RF_DEFAULT. (thanks to template magic - don't care about
671  // it - just smile and wave.)
672 
673  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
674  Default_Stop_t default_stop(options_);
675  typename RF_CHOOSER(Stop_t)::type stop
676  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
677  Default_Split_t default_split;
678  typename RF_CHOOSER(Split_t)::type split
679  = RF_CHOOSER(Split_t)::choose(split_, default_split);
680  rf::visitors::StopVisiting stopvisiting;
681  typedef rf::visitors::detail::VisitorNode
682  <rf::visitors::OnlineLearnVisitor,
683  typename RF_CHOOSER(Visitor_t)::type>
684  IntermedVis;
685  IntermedVis
686  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
687  #undef RF_CHOOSER
688 
689  // Preprocess the data to get something the split functor can work
690  // with. Also fill the ext_param structure by preprocessing
691  // option parameters that could only be completely evaluated
692  // when the training data is known.
693  ext_param_.class_count_=0;
694  Preprocessor_t preprocessor( features, response,
695  options_, ext_param_);
696 
697  // Make stl compatible random functor.
698  RandFunctor_t randint ( random);
699 
700  // Give the Split functor information about the data.
701  split.set_external_parameters(ext_param_);
702  stop.set_external_parameters(ext_param_);
703 
704 
705  //Create poisson samples
706  PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));
707 
708  //TODO: visitors for online learning
709  //visitor.visit_at_beginning(*this, preprocessor);
710 
711  // THE MAIN EFFING RF LOOP - YEAY DUDE!
712  for(int ii = 0; ii < (int)trees_.size(); ++ii)
713  {
714  online_visitor_.tree_id=ii;
715  poisson_sampler.sample();
716  std::map<int,int> leaf_parents;
717  leaf_parents.clear();
718  //Get all the leaf nodes for that sample
719  for(int s=0;s<poisson_sampler.numOfSamples();++s)
720  {
721  int sample=poisson_sampler[s];
722  online_visitor_.current_label=preprocessor.response()(sample,0);
723  online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
724  int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);
725 
726 
727  //Add to the list for that leaf
728  online_visitor_.add_to_index_list(ii,leaf,sample);
729  //TODO: Class count?
730  //Store parent
731  if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
732  {
733  leaf_parents[leaf]=online_visitor_.last_node_id;
734  }
735  }
736 
737 
738  std::map<int,int>::iterator leaf_iterator;
739  for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
740  {
741  int leaf=leaf_iterator->first;
742  int parent=leaf_iterator->second;
743  int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
744  ArrayVector<Int32> indices;
745  indices.clear();
746  indices.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
747  StackEntry_t stack_entry(indices.begin(),
748  indices.end(),
749  ext_param_.class_count_);
750 
751 
752  if(parent!=-1)
753  {
754  if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
755  {
756  stack_entry.leftParent=parent;
757  }
758  else
759  {
760  vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
761  stack_entry.rightParent=parent;
762  }
763  }
764  //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
765  trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
766  //Now, the last one moved onto leaf
767  online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
768  //Now it should be classified correctly!
769  }
770 
771  /*visitor
772  .visit_after_tree( *this,
773  preprocessor,
774  poisson_sampler,
775  stack_entry,
776  ii);*/
777  }
778 
779  //visitor.visit_at_end(*this, preprocessor);
780  online_visitor_.deactivate();
781 }
782 
783 template<class LabelType, class PreprocessorTag>
784 template<class U,class C1,
785  class U2, class C2,
786  class Split_t,
787  class Stop_t,
788  class Visitor_t,
789  class Random_t>
790 void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
791  MultiArrayView<2,U2,C2> const & response,
792  int treeId,
793  Visitor_t visitor_,
794  Split_t split_,
795  Stop_t stop_,
796  Random_t & random)
797 {
798  using namespace rf;
799 
800 
801  typedef UniformIntRandomFunctor<Random_t>
802  RandFunctor_t;
803 
804  // See rf_preprocessing.hxx for more info on this
805  ext_param_.class_count_=0;
806  typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
807 
808  // default values and initialization
809  // Value Chooser chooses second argument as value if first argument
810  // is of type RF_DEFAULT. (thanks to template magic - don't care about
811  // it - just smile and wave.)
812 
813  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
814  Default_Stop_t default_stop(options_);
815  typename RF_CHOOSER(Stop_t)::type stop
816  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
817  Default_Split_t default_split;
818  typename RF_CHOOSER(Split_t)::type split
819  = RF_CHOOSER(Split_t)::choose(split_, default_split);
820  rf::visitors::StopVisiting stopvisiting;
821  typedef rf::visitors::detail::VisitorNode
822  <rf::visitors::OnlineLearnVisitor,
823  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
824  IntermedVis
825  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
826  #undef RF_CHOOSER
827  vigra_precondition(options_.prepare_online_learning_,"reLearnTree: relearning trees only makes sense if online learning is enabled");
828  online_visitor_.activate();
829 
830  // Make stl compatible random functor.
831  RandFunctor_t randint ( random);
832 
833  // Preprocess the data to get something the split functor can work
834  // with. Also fill the ext_param structure by preprocessing
835  // option parameters that could only be completely evaluated
836  // when the training data is known.
837  Preprocessor_t preprocessor( features, response,
838  options_, ext_param_);
839 
840  // Give the Split functor information about the data.
841  split.set_external_parameters(ext_param_);
842  stop.set_external_parameters(ext_param_);
843 
844  /**\todo replace this crappy class. It uses function pointers
845  * and is making the code slower, according to me.
846  * Comment from Nathan: this is copied from Rahul, so me=Rahul.
847  */
848  Sampler<Random_t > sampler(preprocessor.strata().begin(),
849  preprocessor.strata().end(),
850  detail::make_sampler_opt(options_)
851  .sampleSize(ext_param().actual_msample_),
852  &random);
853  //initialize First region/node/stack entry
854  sampler
855  .sample();
856 
857  StackEntry_t
858  first_stack_entry( sampler.sampledIndices().begin(),
859  sampler.sampledIndices().end(),
860  ext_param_.class_count_);
861  first_stack_entry
862  .set_oob_range( sampler.oobIndices().begin(),
863  sampler.oobIndices().end());
864  online_visitor_.reset_tree(treeId);
865  online_visitor_.tree_id=treeId;
866  trees_[treeId].reset();
867  trees_[treeId]
868  .learn( preprocessor.features(),
869  preprocessor.response(),
870  first_stack_entry,
871  split,
872  stop,
873  visitor,
874  randint);
875  visitor
876  .visit_after_tree( *this,
877  preprocessor,
878  sampler,
879  first_stack_entry,
880  treeId);
881 
882  online_visitor_.deactivate();
883 }
884 
885 template <class LabelType, class PreprocessorTag>
886 template <class U, class C1,
887  class U2,class C2,
888  class Split_t,
889  class Stop_t,
890  class Visitor_t,
891  class Random_t>
892 void RandomForest<LabelType, PreprocessorTag>::
893  learn( MultiArrayView<2, U, C1> const & features,
894  MultiArrayView<2, U2,C2> const & response,
895  Visitor_t visitor_,
896  Split_t split_,
897  Stop_t stop_,
898  Random_t const & random)
899 {
900  using namespace rf;
901  //this->reset();
902  //typedefs
903  typedef UniformIntRandomFunctor<Random_t>
904  RandFunctor_t;
905 
906  // See rf_preprocessing.hxx for more info on this
907  typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
908 
909  vigra_precondition(features.shape(0) == response.shape(0),
910  "RandomForest::learn(): shape mismatch between features and response.");
911 
912  // default values and initialization
913  // Value Chooser chooses second argument as value if first argument
914  // is of type RF_DEFAULT. (thanks to template magic - don't care about
915  // it - just smile and wave.)
916 
917  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
918  Default_Stop_t default_stop(options_);
919  typename RF_CHOOSER(Stop_t)::type stop
920  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
921  Default_Split_t default_split;
922  typename RF_CHOOSER(Split_t)::type split
923  = RF_CHOOSER(Split_t)::choose(split_, default_split);
924  rf::visitors::StopVisiting stopvisiting;
925  typedef rf::visitors::detail::VisitorNode
926  <rf::visitors::OnlineLearnVisitor,
927  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
928  IntermedVis
929  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
930  #undef RF_CHOOSER
931  if(options_.prepare_online_learning_)
932  online_visitor_.activate();
933  else
934  online_visitor_.deactivate();
935 
936 
937  // Make stl compatible random functor.
938  RandFunctor_t randint ( random);
939 
940 
941  // Preprocess the data to get something the split functor can work
942  // with. Also fill the ext_param structure by preprocessing
943  // option parameters that could only be completely evaluated
944  // when the training data is known.
945  Preprocessor_t preprocessor( features, response,
946  options_, ext_param_);
947 
948  // Give the Split functor information about the data.
949  split.set_external_parameters(ext_param_);
950  stop.set_external_parameters(ext_param_);
951 
952 
953  //initialize trees.
954  trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
955 
956  Sampler<Random_t > sampler(preprocessor.strata().begin(),
957  preprocessor.strata().end(),
958  detail::make_sampler_opt(options_)
959  .sampleSize(ext_param().actual_msample_),
960  &random);
961 
962  visitor.visit_at_beginning(*this, preprocessor);
963  // THE MAIN EFFING RF LOOP - YEAY DUDE!
964 
965  for(int ii = 0; ii < (int)trees_.size(); ++ii)
966  {
967  //initialize First region/node/stack entry
968  sampler
969  .sample();
970  StackEntry_t
971  first_stack_entry( sampler.sampledIndices().begin(),
972  sampler.sampledIndices().end(),
973  ext_param_.class_count_);
974  first_stack_entry
975  .set_oob_range( sampler.oobIndices().begin(),
976  sampler.oobIndices().end());
977  trees_[ii]
978  .learn( preprocessor.features(),
979  preprocessor.response(),
980  first_stack_entry,
981  split,
982  stop,
983  visitor,
984  randint);
985  visitor
986  .visit_after_tree( *this,
987  preprocessor,
988  sampler,
989  first_stack_entry,
990  ii);
991  }
992 
993  visitor.visit_at_end(*this, preprocessor);
994  // Only for online learning?
995  online_visitor_.deactivate();
996 }
997 
998 
999 
1000 
1001 template <class LabelType, class Tag>
1002 template <class U, class C, class Stop>
1003 LabelType RandomForest<LabelType, Tag>
1004  ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
1005 {
1006  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1007  "RandomForestn::predictLabel():"
1008  " Too few columns in feature matrix.");
1009  vigra_precondition(rowCount(features) == 1,
1010  "RandomForestn::predictLabel():"
1011  " Feature matrix must have a singlerow.");
1012  typedef MultiArrayShape<2>::type Shp;
1013  garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
1014  LabelType d;
1015  predictProbabilities(features, garbage_prediction_, stop);
1016  ext_param_.to_classlabel(argMax(garbage_prediction_), d);
1017  return d;
1018 }
1019 
1020 
1021 //Same thing as above with priors for each label !!!
1022 template <class LabelType, class PreprocessorTag>
1023 template <class U, class C>
1024 LabelType RandomForest<LabelType, PreprocessorTag>
1025  ::predictLabel(MultiArrayView<2, U, C> const & features,
1026  ArrayVectorView<double> priors) const
1027 {
1028  using namespace functor;
1029  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1030  "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1031  vigra_precondition(rowCount(features) == 1,
1032  "RandomForestn::predictLabel():"
1033  " Feature matrix must have a single row.");
1034  Matrix<double> prob(1,ext_param_.class_count_);
1035  predictProbabilities(features, prob);
1036  std::transform( prob.begin(), prob.end(),
1037  priors.begin(), prob.begin(),
1038  Arg1()*Arg2());
1039  LabelType d;
1040  ext_param_.to_classlabel(argMax(prob), d);
1041  return d;
1042 }
1043 
1044 template<class LabelType,class PreprocessorTag>
1045 template <class T1,class T2, class C>
1046 void RandomForest<LabelType, PreprocessorTag>
1047  ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
1048  MultiArrayView<2, T2, C> & prob)
1049 {
1050  //Features are n x p
1051  //prob is n x class_count: probability for each sample in each class
1052 
1053  vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
1054  "RandomFroest::predictProbabilities():"
1055  " Feature matrix and probability matrix size mismatch.");
1056  // The feature matrix must have at least as many columns as were used
1057  // during training; any extra columns are ignored.
1058  vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
1059  "RandomForestn::predictProbabilities():"
1060  " Too few columns in feature matrix.");
1061  vigra_precondition( columnCount(prob)
1062  == (MultiArrayIndex)ext_param_.class_count_,
1063  "RandomForestn::predictProbabilities():"
1064  " Probability matrix must have as many columns as there are classes.");
1065  prob.init(0.0);
1066  //store total weights
1067  std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1068  //Go through all trees
1069  int set_id=-1;
1070  for(int k=0; k<options_.tree_count_; ++k)
1071  {
1072  set_id=(set_id+1) % predictionSet.indices[0].size();
1073  typedef std::set<SampleRange<T1> > my_set;
1074  typedef typename my_set::iterator set_it;
1075  //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
1076  //Build a stack with all the ranges we have
1077  std::vector<std::pair<int,set_it> > stack;
1078  stack.clear();
1079  for(set_it i=predictionSet.ranges[set_id].begin();
1080  i!=predictionSet.ranges[set_id].end();++i)
1081  stack.push_back(std::pair<int,set_it>(2,i));
1082  //get weights predicted by single tree
1083  int num_decisions=0;
1084  while(!stack.empty())
1085  {
1086  set_it range=stack.back().second;
1087  int index=stack.back().first;
1088  stack.pop_back();
1089  ++num_decisions;
1090 
1091  if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1092  {
1093  ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
1094  trees_[k].parameters_,
1095  index).prob_begin();
1096  for(int i=range->start;i!=range->end;++i)
1097  {
1098  //update votecount.
1099  for(int l=0; l<ext_param_.class_count_; ++l)
1100  {
1101  prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
1102  //every weight in totalWeight.
1103  totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
1104  }
1105  }
1106  }
1107 
1108  else
1109  {
1110  if(trees_[k].topology_[index]!=i_ThresholdNode)
1111  {
1112  throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
1113  }
1114  Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1115  if(range->min_boundaries[node.column()]>=node.threshold())
1116  {
1117  //Everything goes to right child
1118  stack.push_back(std::pair<int,set_it>(node.child(1),range));
1119  continue;
1120  }
1121  if(range->max_boundaries[node.column()]<node.threshold())
1122  {
1123  //Everything goes to the left child
1124  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1125  continue;
1126  }
1127  //We have to split at this node
1128  SampleRange<T1> new_range=*range;
1129  new_range.min_boundaries[node.column()]=FLT_MAX;
1130  range->max_boundaries[node.column()]=-FLT_MAX;
1131  new_range.start=new_range.end=range->end;
1132  int i=range->start;
1133  while(i!=range->end)
1134  {
1135  //Decide for range->indices[i]
1136  if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1137  {
1138  new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1139  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1140  --range->end;
1141  --new_range.start;
1142  std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1143 
1144  }
1145  else
1146  {
1147  range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1148  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1149  ++i;
1150  }
1151  }
1152  //The old one ...
1153  if(range->start==range->end)
1154  {
1155  predictionSet.ranges[set_id].erase(range);
1156  }
1157  else
1158  {
1159  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1160  }
1161  //And the new one ...
1162  if(new_range.start!=new_range.end)
1163  {
1164  std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1165  stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1166  }
1167  }
1168  }
1169  predictionSet.cumulativePredTime[k]=num_decisions;
1170  }
1171  for(unsigned int i=0;i<totalWeights.size();++i)
1172  {
1173  double test=0.0;
1174  //Normalise votes in each row by the total vote count (totalWeights[i])
1175  for(int l=0; l<ext_param_.class_count_; ++l)
1176  {
1177  test+=prob(i,l);
1178  prob(i, l) /= totalWeights[i];
1179  }
1180  assert(test==totalWeights[i]);
1181  assert(totalWeights[i]>0.0);
1182  }
1183 }
1184 
1185 template <class LabelType, class PreprocessorTag>
1186 template <class U, class C1, class T, class C2, class Stop_t>
1187 void RandomForest<LabelType, PreprocessorTag>
1188  ::predictProbabilities(MultiArrayView<2, U, C1>const & features,
1189  MultiArrayView<2, T, C2> & prob,
1190  Stop_t & stop_) const
1191 {
1192  //Features are n x p
1193  //prob is n x class_count: probability for each sample in each class
1194 
1195  vigra_precondition(rowCount(features) == rowCount(prob),
1196  "RandomForestn::predictProbabilities():"
1197  " Feature matrix and probability matrix size mismatch.");
1198 
1199  // The feature matrix must have at least as many columns as were used
1200  // during training; any extra columns are ignored.
1201  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1202  "RandomForestn::predictProbabilities():"
1203  " Too few columns in feature matrix.");
1204  vigra_precondition( columnCount(prob)
1205  == (MultiArrayIndex)ext_param_.class_count_,
1206  "RandomForestn::predictProbabilities():"
1207  " Probability matrix must have as many columns as there are classes.");
1208 
1209  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1210  Default_Stop_t default_stop(options_);
1211  typename RF_CHOOSER(Stop_t)::type & stop
1212  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1213  #undef RF_CHOOSER
1214  stop.set_external_parameters(ext_param_, tree_count());
1215  prob.init(NumericTraits<T>::zero());
1216  /* This code was originally there for testing early stopping
1217  * - we wanted the order of the trees to be randomized
1218  if(tree_indices_.size() != 0)
1219  {
1220  std::random_shuffle(tree_indices_.begin(),
1221  tree_indices_.end());
1222  }
1223  */
1224  //Classify for each row.
1225  for(int row=0; row < rowCount(features); ++row)
1226  {
1227  ArrayVector<double>::const_iterator weights;
1228 
1229  //totalWeight == totalVoteCount!
1230  double totalWeight = 0.0;
1231 
1232  //Let each tree classify...
1233  for(int k=0; k<options_.tree_count_; ++k)
1234  {
1235  //get weights predicted by single tree
1236  weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1237 
1238  //update votecount.
1239  int weighted = options_.predict_weighted_;
1240  for(int l=0; l<ext_param_.class_count_; ++l)
1241  {
1242  double cur_w = weights[l] * (weighted * (*(weights-1))
1243  + (1-weighted));
1244  prob(row, l) += (T)cur_w;
1245  //every weight in totalWeight.
1246  totalWeight += cur_w;
1247  }
1248  if(stop.after_prediction(weights,
1249  k,
1250  rowVector(prob, row),
1251  totalWeight))
1252  {
1253  break;
1254  }
1255  }
1256 
1257  //Normalise votes in each row by the total vote count (totalWeight)
1258  for(int l=0; l< ext_param_.class_count_; ++l)
1259  {
1260  prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1261  }
1262  }
1263 
1264 }
1265 
1266 template <class LabelType, class PreprocessorTag>
1267 template <class U, class C1, class T, class C2>
1268 void RandomForest<LabelType, PreprocessorTag>
1269  ::predictRaw(MultiArrayView<2, U, C1>const & features,
1270  MultiArrayView<2, T, C2> & prob) const
1271 {
1272  //Features are n x p
1273  //prob is n x class_count: probability for each sample in each class
1274 
1275  vigra_precondition(rowCount(features) == rowCount(prob),
1276  "RandomForestn::predictProbabilities():"
1277  " Feature matrix and probability matrix size mismatch.");
1278 
1279  // num of features must be bigger than num of features in Random forest training
1280  // but why bigger?
1281  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1282  "RandomForestn::predictProbabilities():"
1283  " Too few columns in feature matrix.");
1284  vigra_precondition( columnCount(prob)
1285  == (MultiArrayIndex)ext_param_.class_count_,
1286  "RandomForestn::predictProbabilities():"
1287  " Probability matrix must have as many columns as there are classes.");
1288 
1290  prob.init(NumericTraits<T>::zero());
1291  /* This code was originally there for testing early stopping
1292  * - we wanted the order of the trees to be randomized
1293  if(tree_indices_.size() != 0)
1294  {
1295  std::random_shuffle(tree_indices_.begin(),
1296  tree_indices_.end());
1297  }
1298  */
1299  //Classify for each row.
1300  for(int row=0; row < rowCount(features); ++row)
1301  {
1302  ArrayVector<double>::const_iterator weights;
1303 
1304  //totalWeight == totalVoteCount!
1305  double totalWeight = 0.0;
1306 
1307  //Let each tree classify...
1308  for(int k=0; k<options_.tree_count_; ++k)
1309  {
1310  //get weights predicted by single tree
1311  weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1312 
1313  //update votecount.
1314  int weighted = options_.predict_weighted_;
1315  for(int l=0; l<ext_param_.class_count_; ++l)
1316  {
1317  double cur_w = weights[l] * (weighted * (*(weights-1))
1318  + (1-weighted));
1319  prob(row, l) += (T)cur_w;
1320  //every weight in totalWeight.
1321  totalWeight += cur_w;
1322  }
1323  }
1324  }
1325  prob/= options_.tree_count_;
1326 
1327 }
1328 
1329 //@}
1330 
1331 } // namespace vigra
1332 
1333 #include "random_forest/rf_algorithm.hxx"
1334 #endif // VIGRA_RANDOM_FOREST_HXX