[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_common.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RF_COMMON_HXX
38 #define VIGRA_RF_COMMON_HXX
39 
40 namespace vigra
41 {
42 
43 
/** \brief tag type selecting classification mode in the random forest */
struct ClassificationTag
{};

/** \brief tag type selecting regression mode in the random forest */
struct RegressionTag
{};
49 
// forward declaration of the default-argument tag and of the factory
// function that hands out its singleton instance
namespace detail
{
    class RF_DEFAULT;
}
inline detail::RF_DEFAULT& rf_default();
55 namespace detail
56 {
57 
58 /* \brief singleton default tag class -
59  *
60  * use the rf_default() factory function to use the tag.
61  * \sa RandomForest<>::learn();
62  */
63 class RF_DEFAULT
64 {
65  private:
66  RF_DEFAULT()
67  {}
68  public:
69  friend RF_DEFAULT& ::vigra::rf_default();
70 
71  /** ok workaround for automatic choice of the decisiontree
72  * stackentry.
73  */
74 };
75 
76 /* \brief chooses between default type and type supplied
77  *
78  * This is an internal class and you shouldn't really care about it.
79  * Just pass on used in RandomForest.learn()
80  * Usage:
81  *\code
82  * // example: use container type supplied by user or ArrayVector if
83  * // rf_default() was specified as argument;
84  * template<class Container_t>
85  * void do_some_foo(Container_t in)
86  * {
87  * typedef ArrayVector<int> Default_Container_t;
88  * Default_Container_t default_value;
89  * Value_Chooser<Container_t, Default_Container_t>
90  * choose(in, default_value);
91  *
92  * // if the user didn't care and the in was of type
93  * // RF_DEFAULT then default_value is used.
94  * do_some_more_foo(choose.value());
95  * }
96  * Value_Chooser choose_val<Type, Default_Type>
97  *\endcode
98  */
99 template<class T, class C>
100 class Value_Chooser
101 {
102 public:
103  typedef T type;
104  static T & choose(T & t, C &)
105  {
106  return t;
107  }
108 };
109 
110 template<class C>
111 class Value_Chooser<detail::RF_DEFAULT, C>
112 {
113 public:
114  typedef C type;
115 
116  static C & choose(detail::RF_DEFAULT &, C & c)
117  {
118  return c;
119  }
120 };
121 
122 
123 
124 
125 } //namespace detail
126 
127 
128 /**\brief factory function to return a RF_DEFAULT tag
129  * \sa RandomForest<>::learn()
130  */
131 detail::RF_DEFAULT& rf_default()
132 {
133  static detail::RF_DEFAULT result;
134  return result;
135 }
136 
/** tags used with the RandomForestOptions class
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag { RF_EQUAL,          ///< stratification: equal samples per class
                    RF_PROPORTIONAL,   ///< stratification: proportional to class frequency
                    RF_EXTERNAL,       ///< strata weights supplied externally
                    RF_NONE,           ///< no stratification
                    RF_FUNCTION,       ///< value computed by a user-supplied function
                    RF_LOG,            ///< mtry = log of the column count
                    RF_SQRT,           ///< mtry = square root of the column count
                    RF_CONST,          ///< value given as an explicit constant
                    RF_ALL};           ///< use all columns
149 
150 
151 /** \addtogroup MachineLearning
152 **/
153 //@{
154 
155 /**\brief Options object for the random forest
156  *
157  * usage:
158  * RandomForestOptions a = RandomForestOptions()
159  * .param1(value1)
160  * .param2(value2)
161  * ...
162  *
163  * This class only contains options/parameters that are not problem
164  * dependent. The ProblemSpec class contains methods to set class weights
165  * if necessary.
166  *
167  * Note that the return value of all methods is *this which makes
168  * concatenating of options as above possible.
169  */
171 {
172  public:
173  /**\name sampling options*/
174  /*\{*/
175  // look at the member access functions for documentation
176  double training_set_proportion_;
177  int training_set_size_;
178  int (*training_set_func_)(int);
180  training_set_calc_switch_;
181 
182  bool sample_with_replacement_;
184  stratification_method_;
185 
186 
187  /**\name general random forest options
188  *
189  * these usually will be used by most split functors and
190  * stopping predicates
191  */
192  /*\{*/
193  RF_OptionTag mtry_switch_;
194  int mtry_;
195  int (*mtry_func_)(int) ;
196 
197  bool predict_weighted_;
198  int tree_count_;
199  int min_split_node_size_;
200  bool prepare_online_learning_;
201  /*\}*/
202 
204  typedef std::map<std::string, double_array> map_type;
205 
206  int serialized_size() const
207  {
208  return 12;
209  }
210 
211 
212  bool operator==(RandomForestOptions & rhs) const
213  {
214  bool result = true;
215  #define COMPARE(field) result = result && (this->field == rhs.field);
216  COMPARE(training_set_proportion_);
217  COMPARE(training_set_size_);
218  COMPARE(training_set_calc_switch_);
219  COMPARE(sample_with_replacement_);
220  COMPARE(stratification_method_);
221  COMPARE(mtry_switch_);
222  COMPARE(mtry_);
223  COMPARE(tree_count_);
224  COMPARE(min_split_node_size_);
225  COMPARE(predict_weighted_);
226  #undef COMPARE
227 
228  return result;
229  }
230  bool operator!=(RandomForestOptions & rhs_) const
231  {
232  return !(*this == rhs_);
233  }
234  template<class Iter>
235  void unserialize(Iter const & begin, Iter const & end)
236  {
237  Iter iter = begin;
238  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
239  "RandomForestOptions::unserialize():"
240  "wrong number of parameters");
241  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
242  PULL(training_set_proportion_, double);
243  PULL(training_set_size_, int);
244  ++iter; //PULL(training_set_func_, double);
245  PULL(training_set_calc_switch_, (RF_OptionTag)int);
246  PULL(sample_with_replacement_, 0 != );
247  PULL(stratification_method_, (RF_OptionTag)int);
248  PULL(mtry_switch_, (RF_OptionTag)int);
249  PULL(mtry_, int);
250  ++iter; //PULL(mtry_func_, double);
251  PULL(tree_count_, int);
252  PULL(min_split_node_size_, int);
253  PULL(predict_weighted_, 0 !=);
254  #undef PULL
255  }
256  template<class Iter>
257  void serialize(Iter const & begin, Iter const & end) const
258  {
259  Iter iter = begin;
260  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
261  "RandomForestOptions::serialize():"
262  "wrong number of parameters");
263  #define PUSH(item_) *iter = double(item_); ++iter;
264  PUSH(training_set_proportion_);
265  PUSH(training_set_size_);
266  if(training_set_func_ != 0)
267  {
268  PUSH(1);
269  }
270  else
271  {
272  PUSH(0);
273  }
274  PUSH(training_set_calc_switch_);
275  PUSH(sample_with_replacement_);
276  PUSH(stratification_method_);
277  PUSH(mtry_switch_);
278  PUSH(mtry_);
279  if(mtry_func_ != 0)
280  {
281  PUSH(1);
282  }
283  else
284  {
285  PUSH(0);
286  }
287  PUSH(tree_count_);
288  PUSH(min_split_node_size_);
289  PUSH(predict_weighted_);
290  #undef PUSH
291  }
292 
293  void make_from_map(map_type & in) // -> const: .operator[] -> .find
294  {
295  typedef MultiArrayShape<2>::type Shp;
296  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
297  #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
298  PULL(training_set_proportion_,double);
299  PULL(training_set_size_, int);
300  PULL(mtry_, int);
301  PULL(tree_count_, int);
302  PULL(min_split_node_size_, int);
303  PULLBOOL(sample_with_replacement_, bool);
304  PULLBOOL(prepare_online_learning_, bool);
305  PULLBOOL(predict_weighted_, bool);
306 
307  PULL(training_set_calc_switch_, (RF_OptionTag)(int));
308 
309  PULL(stratification_method_, (RF_OptionTag)(int));
310  PULL(mtry_switch_, (RF_OptionTag)(int));
311 
312  /*don't pull*/
313  //PULL(mtry_func_!=0, int);
314  //PULL(training_set_func,int);
315  #undef PULL
316  #undef PULLBOOL
317  }
318  void make_map(map_type & in) const
319  {
320  typedef MultiArrayShape<2>::type Shp;
321  #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
322  #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0));
323  PUSH(training_set_proportion_,double);
324  PUSH(training_set_size_, int);
325  PUSH(mtry_, int);
326  PUSH(tree_count_, int);
327  PUSH(min_split_node_size_, int);
328  PUSH(sample_with_replacement_, bool);
329  PUSH(prepare_online_learning_, bool);
330  PUSH(predict_weighted_, bool);
331 
332  PUSH(training_set_calc_switch_, RF_OptionTag);
333  PUSH(stratification_method_, RF_OptionTag);
334  PUSH(mtry_switch_, RF_OptionTag);
335 
336  PUSHFUNC(mtry_func_, int);
337  PUSHFUNC(training_set_func_,int);
338  #undef PUSH
339  #undef PUSHFUNC
340  }
341 
342 
343  /**\brief create a RandomForestOptions object with default initialisation.
344  *
345  * look at the other member functions for more information on default
346  * values
347  */
349  :
350  training_set_proportion_(1.0),
351  training_set_size_(0),
352  training_set_func_(0),
353  training_set_calc_switch_(RF_PROPORTIONAL),
354  sample_with_replacement_(true),
355  stratification_method_(RF_NONE),
356  mtry_switch_(RF_SQRT),
357  mtry_(0),
358  mtry_func_(0),
359  predict_weighted_(false),
360  tree_count_(256),
361  min_split_node_size_(1),
362  prepare_online_learning_(false)
363  {}
364 
365  /**\brief specify stratification strategy
366  *
367  * default: RF_NONE
368  * possible values: RF_EQUAL, RF_PROPORTIONAL,
369  * RF_EXTERNAL, RF_NONE
370  * RF_EQUAL: get equal amount of samples per class.
371  * RF_PROPORTIONAL: sample proportional to fraction of class samples
372  * in population
373  * RF_EXTERNAL: strata_weights_ field of the ProblemSpec_t object
374  * has been set externally. (defunct)
375  */
377  {
378  vigra_precondition(in == RF_EQUAL ||
379  in == RF_PROPORTIONAL ||
380  in == RF_EXTERNAL ||
381  in == RF_NONE,
382  "RandomForestOptions::use_stratification()"
383  "input must be RF_EQUAL, RF_PROPORTIONAL,"
384  "RF_EXTERNAL or RF_NONE");
385  stratification_method_ = in;
386  return *this;
387  }
388 
389  RandomForestOptions & prepare_online_learning(bool in)
390  {
391  prepare_online_learning_=in;
392  return *this;
393  }
394 
395  /**\brief sample from training population with or without replacement?
396  *
397  * <br> Default: true
398  */
400  {
401  sample_with_replacement_ = in;
402  return *this;
403  }
404 
405  /**\brief specify the fraction of the total number of samples
406  * used per tree for learning.
407  *
408  * This value should be in [0.0 1.0] if sampling without
409  * replacement has been specified.
410  *
411  * <br> default : 1.0
412  */
414  {
415  training_set_proportion_ = in;
416  training_set_calc_switch_ = RF_PROPORTIONAL;
417  return *this;
418  }
419 
420  /**\brief directly specify the number of samples per tree
421  */
423  {
424  training_set_size_ = in;
425  training_set_calc_switch_ = RF_CONST;
426  return *this;
427  }
428 
429  /**\brief use external function to calculate the number of samples each
430  * tree should be learnt with.
431  *
432  * \param in function pointer that takes the number of rows in the
433  * learning data and outputs the number samples per tree.
434  */
436  {
437  training_set_func_ = in;
438  training_set_calc_switch_ = RF_FUNCTION;
439  return *this;
440  }
441 
442  /**\brief weight each tree with number of samples in that node
443  */
445  {
446  predict_weighted_ = true;
447  return *this;
448  }
449 
450  /**\brief use built in mapping to calculate mtry
451  *
452  * Use one of the built in mappings to calculate mtry from the number
453  * of columns in the input feature data.
454  * \param in possible values: RF_LOG, RF_SQRT or RF_ALL
455  * <br> default: RF_SQRT.
456  */
458  {
459  vigra_precondition(in == RF_LOG ||
460  in == RF_SQRT||
461  in == RF_ALL,
462  "RandomForestOptions()::features_per_node():"
463  "input must be of type RF_LOG or RF_SQRT");
464  mtry_switch_ = in;
465  return *this;
466  }
467 
468  /**\brief Set mtry to a constant value
469  *
470  * mtry is the number of columns/variates/variables randomly chosen
471  * to select the best split from.
472  *
473  */
475  {
476  mtry_ = in;
477  mtry_switch_ = RF_CONST;
478  return *this;
479  }
480 
481  /**\brief use a external function to calculate mtry
482  *
483  * \param in function pointer that takes int (number of columns
484  * of the and outputs int (mtry)
485  */
487  {
488  mtry_func_ = in;
489  mtry_switch_ = RF_FUNCTION;
490  return *this;
491  }
492 
493  /** How many trees to create?
494  *
495  * <br> Default: 255.
496  */
498  {
499  tree_count_ = in;
500  return *this;
501  }
502 
503  /**\brief Number of examples required for a node to be split.
504  *
505  * When the number of examples in a node is below this number,
506  * the node is not split even if class separation is not yet perfect.
507  * Instead, the node returns the proportion of each class
508  * (among the remaining examples) during the prediction phase.
509  * <br> Default: 1 (complete growing)
510  */
512  {
513  min_split_node_size_ = in;
514  return *this;
515  }
516 };
517 
518 
/** \brief problem types
 */
enum Problem_t { REGRESSION,       ///< continuous response
                 CLASSIFICATION,   ///< discrete class labels
                 CHECKLATER };     ///< not yet determined
522 
523 
524 /** \brief problem specification class for the random forest.
525  *
526  * This class contains all the problem specific parameters the random
527  * forest needs for learning. Specification of an instance of this class
528  * is optional as all necessary fields will be computed prior to learning
529  * if not specified.
530  *
531  * if needed usage is similar to that of RandomForestOptions
532  */
533 
534 template<class LabelType = double>
536 {
537 
538 
539 public:
540 
541  /** \brief problem class
542  */
543 
544  typedef LabelType Label_t;
545  ArrayVector<Label_t> classes;
547  typedef std::map<std::string, double_array> map_type;
548 
549  int column_count_; // number of features
550  int class_count_; // number of classes
551  int row_count_; // number of samples
552 
553  int actual_mtry_; // mtry used in training
554  int actual_msample_; // number if in-bag samples per tree
555 
556  Problem_t problem_type_; // classification or regression
557 
558  int used_; // this ProblemSpec is valid
559  ArrayVector<double> class_weights_; // if classes have different importance
560  int is_weighted_; // class_weights_ are used
561  double precision_; // termination criterion for regression loss
562  int response_size_;
563 
564  template<class T>
565  void to_classlabel(int index, T & out) const
566  {
567  out = T(classes[index]);
568  }
569  template<class T>
570  int to_classIndex(T index) const
571  {
572  return std::find(classes.begin(), classes.end(), index) - classes.begin();
573  }
574 
575  #define EQUALS(field) field(rhs.field)
576  ProblemSpec(ProblemSpec const & rhs)
577  :
578  EQUALS(column_count_),
579  EQUALS(class_count_),
580  EQUALS(row_count_),
581  EQUALS(actual_mtry_),
582  EQUALS(actual_msample_),
583  EQUALS(problem_type_),
584  EQUALS(used_),
585  EQUALS(class_weights_),
586  EQUALS(is_weighted_),
587  EQUALS(precision_),
588  EQUALS(response_size_)
589  {
590  std::back_insert_iterator<ArrayVector<Label_t> >
591  iter(classes);
592  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
593  }
594  #undef EQUALS
595  #define EQUALS(field) field(rhs.field)
596  template<class T>
597  ProblemSpec(ProblemSpec<T> const & rhs)
598  :
599  EQUALS(column_count_),
600  EQUALS(class_count_),
601  EQUALS(row_count_),
602  EQUALS(actual_mtry_),
603  EQUALS(actual_msample_),
604  EQUALS(problem_type_),
605  EQUALS(used_),
606  EQUALS(class_weights_),
607  EQUALS(is_weighted_),
608  EQUALS(precision_),
609  EQUALS(response_size_)
610  {
611  std::back_insert_iterator<ArrayVector<Label_t> >
612  iter(classes);
613  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
614  }
615  #undef EQUALS
616 
617  #define EQUALS(field) (this->field = rhs.field);
618  ProblemSpec & operator=(ProblemSpec const & rhs)
619  {
620  EQUALS(column_count_);
621  EQUALS(class_count_);
622  EQUALS(row_count_);
623  EQUALS(actual_mtry_);
624  EQUALS(actual_msample_);
625  EQUALS(problem_type_);
626  EQUALS(used_);
627  EQUALS(is_weighted_);
628  EQUALS(precision_);
629  EQUALS(response_size_)
630  class_weights_.clear();
631  std::back_insert_iterator<ArrayVector<double> >
632  iter2(class_weights_);
633  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
634  classes.clear();
635  std::back_insert_iterator<ArrayVector<Label_t> >
636  iter(classes);
637  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
638  return *this;
639  }
640 
641  template<class T>
642  ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
643  {
644  EQUALS(column_count_);
645  EQUALS(class_count_);
646  EQUALS(row_count_);
647  EQUALS(actual_mtry_);
648  EQUALS(actual_msample_);
649  EQUALS(problem_type_);
650  EQUALS(used_);
651  EQUALS(is_weighted_);
652  EQUALS(precision_);
653  EQUALS(response_size_)
654  class_weights_.clear();
655  std::back_insert_iterator<ArrayVector<double> >
656  iter2(class_weights_);
657  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
658  classes.clear();
659  std::back_insert_iterator<ArrayVector<Label_t> >
660  iter(classes);
661  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
662  return *this;
663  }
664  #undef EQUALS
665 
666  template<class T>
667  bool operator==(ProblemSpec<T> const & rhs)
668  {
669  bool result = true;
670  #define COMPARE(field) result = result && (this->field == rhs.field);
671  COMPARE(column_count_);
672  COMPARE(class_count_);
673  COMPARE(row_count_);
674  COMPARE(actual_mtry_);
675  COMPARE(actual_msample_);
676  COMPARE(problem_type_);
677  COMPARE(is_weighted_);
678  COMPARE(precision_);
679  COMPARE(used_);
680  COMPARE(class_weights_);
681  COMPARE(classes);
682  COMPARE(response_size_)
683  #undef COMPARE
684  return result;
685  }
686 
687  bool operator!=(ProblemSpec & rhs)
688  {
689  return !(*this == rhs);
690  }
691 
692 
693  size_t serialized_size() const
694  {
695  return 9 + class_count_ *int(is_weighted_+1);
696  }
697 
698 
699  template<class Iter>
700  void unserialize(Iter const & begin, Iter const & end)
701  {
702  Iter iter = begin;
703  vigra_precondition(end - begin >= 9,
704  "ProblemSpec::unserialize():"
705  "wrong number of parameters");
706  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
707  PULL(column_count_,int);
708  PULL(class_count_, int);
709 
710  vigra_precondition(end - begin >= 9 + class_count_,
711  "ProblemSpec::unserialize(): 1");
712  PULL(row_count_, int);
713  PULL(actual_mtry_,int);
714  PULL(actual_msample_, int);
715  PULL(problem_type_, Problem_t);
716  PULL(is_weighted_, int);
717  PULL(used_, int);
718  PULL(precision_, double);
719  PULL(response_size_, int);
720  if(is_weighted_)
721  {
722  vigra_precondition(end - begin == 9 + 2*class_count_,
723  "ProblemSpec::unserialize(): 2");
724  class_weights_.insert(class_weights_.end(),
725  iter,
726  iter + class_count_);
727  iter += class_count_;
728  }
729  classes.insert(classes.end(), iter, end);
730  #undef PULL
731  }
732 
733 
734  template<class Iter>
735  void serialize(Iter const & begin, Iter const & end) const
736  {
737  Iter iter = begin;
738  vigra_precondition(end - begin == serialized_size(),
739  "RandomForestOptions::serialize():"
740  "wrong number of parameters");
741  #define PUSH(item_) *iter = double(item_); ++iter;
742  PUSH(column_count_);
743  PUSH(class_count_)
744  PUSH(row_count_);
745  PUSH(actual_mtry_);
746  PUSH(actual_msample_);
747  PUSH(problem_type_);
748  PUSH(is_weighted_);
749  PUSH(used_);
750  PUSH(precision_);
751  PUSH(response_size_);
752  if(is_weighted_)
753  {
754  std::copy(class_weights_.begin(),
755  class_weights_.end(),
756  iter);
757  iter += class_count_;
758  }
759  std::copy(classes.begin(),
760  classes.end(),
761  iter);
762  #undef PUSH
763  }
764 
765  void make_from_map(map_type & in) // -> const: .operator[] -> .find
766  {
767  typedef MultiArrayShape<2>::type Shp;
768  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
769  PULL(column_count_,int);
770  PULL(class_count_, int);
771  PULL(row_count_, int);
772  PULL(actual_mtry_,int);
773  PULL(actual_msample_, int);
774  PULL(problem_type_, (Problem_t)int);
775  PULL(is_weighted_, int);
776  PULL(used_, int);
777  PULL(precision_, double);
778  PULL(response_size_, int);
779  class_weights_ = in["class_weights_"];
780  #undef PUSH
781  }
782  void make_map(map_type & in) const
783  {
784  typedef MultiArrayShape<2>::type Shp;
785  #define PUSH(item_) in[#item_] = double_array(1, double(item_));
786  PUSH(column_count_);
787  PUSH(class_count_)
788  PUSH(row_count_);
789  PUSH(actual_mtry_);
790  PUSH(actual_msample_);
791  PUSH(problem_type_);
792  PUSH(is_weighted_);
793  PUSH(used_);
794  PUSH(precision_);
795  PUSH(response_size_);
796  in["class_weights_"] = class_weights_;
797  #undef PUSH
798  }
799 
800  /**\brief set default values (-> values not set)
801  */
803  : column_count_(0),
804  class_count_(0),
805  row_count_(0),
806  actual_mtry_(0),
807  actual_msample_(0),
808  problem_type_(CHECKLATER),
809  used_(false),
810  is_weighted_(false),
811  precision_(0.0),
812  response_size_(1)
813  {}
814 
815 
816  ProblemSpec & column_count(int in)
817  {
818  column_count_ = in;
819  return *this;
820  }
821 
822  /**\brief supply with class labels -
823  *
824  * the preprocessor will not calculate the labels needed in this case.
825  */
826  template<class C_Iter>
827  ProblemSpec & classes_(C_Iter begin, C_Iter end)
828  {
829  int size = end-begin;
830  for(int k=0; k<size; ++k, ++begin)
831  classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
832  class_count_ = size;
833  return *this;
834  }
835 
836  /** \brief supply with class weights -
837  *
838  * this is the only case where you would really have to
839  * create a ProblemSpec object.
840  */
841  template<class W_Iter>
842  ProblemSpec & class_weights(W_Iter begin, W_Iter end)
843  {
844  class_weights_.insert(class_weights_.end(), begin, end);
845  is_weighted_ = true;
846  return *this;
847  }
848 
849 
850 
851  void clear()
852  {
853  used_ = false;
854  classes.clear();
855  class_weights_.clear();
856  column_count_ = 0 ;
857  class_count_ = 0;
858  actual_mtry_ = 0;
859  actual_msample_ = 0;
860  problem_type_ = CHECKLATER;
861  is_weighted_ = false;
862  precision_ = 0.0;
863  response_size_ = 0;
864 
865  }
866 
867  bool used() const
868  {
869  return used_ != 0;
870  }
871 };
872 
873 
874 //@}
875 
876 
877 
878 /**\brief Standard early stopping criterion
879  *
880  * Stop if region.size() < min_split_node_size_;
881  */
883 {
884  public:
885  int min_split_node_size_;
886 
887  template<class Opt>
888  EarlyStoppStd(Opt opt)
889  : min_split_node_size_(opt.min_split_node_size_)
890  {}
891 
892  template<class T>
893  void set_external_parameters(ProblemSpec<T>const &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
894  {}
895 
896  template<class Region>
897  bool operator()(Region& region)
898  {
899  return region.size() < min_split_node_size_;
900  }
901 
902  template<class WeightIter, class T, class C>
903  bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
904  {
905  return false;
906  }
907 };
908 
909 
910 } // namespace vigra
911 
912 #endif //VIGRA_RF_COMMON_HXX
RandomForestOptions & features_per_node(RF_OptionTag in)
use built in mapping to calculate mtry
Definition: rf_common.hxx:457
RandomForestOptions & tree_count(int in)
Definition: rf_common.hxx:497
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition: rf_common.hxx:131
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning.
Definition: rf_common.hxx:413
RandomForestOptions & features_per_node(int(*in)(int))
use an external function to calculate mtry
Definition: rf_common.hxx:486
const_iterator begin() const
Definition: array_vector.hxx:223
RandomForestOptions & samples_per_tree(int in)
directly specify the number of samples per tree
Definition: rf_common.hxx:422
RandomForestOptions & samples_per_tree(int(*in)(int))
use an external function to calculate the number of samples each tree should be learnt with...
Definition: rf_common.hxx:435
problem specification class for the random forest.
Definition: rf_common.hxx:535
LabelType Label_t
problem class
Definition: rf_common.hxx:544
RandomForestOptions & min_split_node_size(int in)
Number of examples required for a node to be split.
Definition: rf_common.hxx:511
RandomForestOptions & features_per_node(int in)
Set mtry to a constant value.
Definition: rf_common.hxx:474
Standard early stopping criterion.
Definition: rf_common.hxx:882
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply with class labels -
Definition: rf_common.hxx:827
RandomForestOptions()
create a RandomForestOptions object with default initialisation.
Definition: rf_common.hxx:348
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply with class weights -
Definition: rf_common.hxx:842
Class for fixed size vectors.This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:939
RandomForestOptions & sample_with_replacement(bool in)
sample from training population with or without replacement?
Definition: rf_common.hxx:399
RandomForestOptions & predict_weighted()
weight each tree with number of samples in that node
Definition: rf_common.hxx:444
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:593
Options object for the random forest.
Definition: rf_common.hxx:170
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy
Definition: rf_common.hxx:376
const_iterator end() const
Definition: array_vector.hxx:237
ProblemSpec()
set default values (-> values not set)
Definition: rf_common.hxx:802
RF_OptionTag
Definition: rf_common.hxx:140
Problem_t
problem types
Definition: rf_common.hxx:521

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Sun Aug 10 2014)