
rf_visitors.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
37 
38 #ifdef HasHDF5
39 # include "vigra/hdf5impex.hxx"
40 #endif // HasHDF5
41 #include <vigra/windows.h>
42 #include <iostream>
43 #include <iomanip>
44 
45 #include <vigra/multi_pointoperators.hxx>
46 #include <vigra/timing.hxx>
47 
48 namespace vigra
49 {
50 namespace rf
51 {
52 /** \addtogroup MachineLearning Machine Learning
53 **/
54 //@{
55 
56 /**
57  This namespace contains all classes and methods related to extracting information during
58  learning of the random forest. All Visitors share the same interface defined in
59  visitors::VisitorBase. The member methods are invoked at certain points of the main code in
60  the order they were supplied.
61 
62  For the random forest, the visitor concept is implemented as a statically linked list
63  (using templates). Each visitor object is encapsulated in a detail::VisitorNode object. The
64  VisitorNode object calls the next visitor after one of its visit() methods has terminated.
65 
66  To simplify usage, create_visitor() factory methods are supplied.
67  Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
68  It is possible to supply more than one visitor. They will then be invoked in serial order.
69 
70  The calculated information is stored in public data members of the visitor classes - see the
71  documentation of the individual visitors.
72 
73  When creating a new visitor, the new class should therefore publicly inherit from
74  visitors::VisitorBase (e.g., see visitors::OOB_Error).
75 
76  \code
77 
78  typedef xxx feature_t;  // replace xxx with the actual feature type
79  typedef yyy label_t;    // same for the label type
80  MultiArrayView<2, feature_t> f = get_some_features();
81  MultiArrayView<2, label_t> l = get_some_labels();
82  RandomForest<> rf;
83 
84  // calculate OOB error
85  visitors::OOB_Error oob_v;
86  // calculate variable importance
87  visitors::VariableImportanceVisitor varimp_v;
88 
89  double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
90  //the data can be found in the attributes of oob_v and varimp_v now
91 
92  \endcode
93 */
94 namespace visitors
95 {
96 
97 
98 /** Base Class from which all Visitors derive. Can be used as a template to create new
99  * Visitors.
100  */
101 class VisitorBase
102 {
103  public:
104  bool active_;
105  bool is_active()
106  {
107  return active_;
108  }
109 
110  bool has_value()
111  {
112  return false;
113  }
114 
115  VisitorBase()
116  : active_(true)
117  {}
118 
119  void deactivate()
120  {
121  active_ = false;
122  }
123  void activate()
124  {
125  active_ = true;
126  }
127 
128  /** do something after the Split has decided how to process the Region
129  * (Stack entry)
130  *
131  * \param tree reference to the tree that is currently being learned
132  * \param split reference to the split object
133  * \param parent current stack entry which was used to decide the split
134  * \param leftChild left stack entry that will be pushed
135  * \param rightChild
136  * right stack entry that will be pushed.
137  * \param features features matrix
138  * \param labels label matrix
139  * \sa RF_Traits::StackEntry_t
140  */
141  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
142  void visit_after_split( Tree & tree,
143  Split & split,
144  Region & parent,
145  Region & leftChild,
146  Region & rightChild,
147  Feature_t & features,
148  Label_t & labels)
149  {}
150 
151  /** do something after each tree has been learned
152  *
153  * \param rf reference to the random forest object that called this
154  * visitor
155  * \param pr reference to the preprocessor that processed the input
156  * \param sm reference to the sampler object
157  * \param st reference to the first stack entry
158  * \param index index of current tree
159  */
160  template<class RF, class PR, class SM, class ST>
161  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
162  {}
163 
164  /** do something after all trees have been learned
165  *
166  * \param rf reference to the random forest object that called this
167  * visitor
168  * \param pr reference to the preprocessor that processed the input
169  */
170  template<class RF, class PR>
171  void visit_at_end(RF const & rf, PR const & pr)
172  {}
173 
174  /** do something before learning starts
175  *
176  * \param rf reference to the random forest object that called this
177  * visitor
178  * \param pr reference to the Processor class used.
179  */
180  template<class RF, class PR>
181  void visit_at_beginning(RF const & rf, PR const & pr)
182  {}
183  /** do something while traversing the tree after it has been learned
184  * (external nodes)
185  *
186  * \param tr reference to the tree object that called this visitor
187  * \param index index in the topology_ array we currently are at
188  * \param node_t type of node we have (will be e_.... - )
189  * \param features feature matrix
190  * \sa NodeTags;
191  *
192  * you can create the node by using a switch on node_tag and using the
193  * corresponding Node objects. Or - if you do not care about the type
194  * use the NodeBase class.
195  */
196  template<class TR, class IntT, class TopT,class Feat>
197  void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
198  {}
199 
200  /** do something when visiting an internal node after it has been learned
201  *
202  * \sa visit_external_node
203  */
204  template<class TR, class IntT, class TopT,class Feat>
205  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
206  {}
207 
208  /** return a double value. The value of the first
209  * visitor encountered that has a return value is returned with the
210  * RandomForest::learn() method - or -1.0 if no return value visitor
211  * existed. This functionality basically only exists so that the
212  * OOB - visitor can return the oob error rate like in the old version
213  * of the random forest.
214  */
215  double return_val()
216  {
217  return -1.0;
218  }
219 };
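/** Example (editor's sketch, not part of the original header): a minimal custom visitor.
 * It publicly inherits from VisitorBase, overrides only the callback it needs, and is later
 * passed to RandomForest::learn() via create_visitor(). The class and member names below
 * are illustrative only.
 *
 * \code
 * class TreeCountVisitor : public visitors::VisitorBase
 * {
 *   public:
 *     int trees_seen;
 *     TreeCountVisitor() : trees_seen(0) {}
 *
 *     template<class RF, class PR, class SM, class ST>
 *     void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
 *     {
 *         ++trees_seen;   // called once after each tree has been learned
 *     }
 * };
 * \endcode
 */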
220 
221 
222 /** Last Visitor that should be called to stop the recursion.
223  */
224 class StopVisiting: public VisitorBase
225 {
226  public:
227  bool has_value()
228  {
229  return true;
230  }
231  double return_val()
232  {
233  return -1.0;
234  }
235 };
236 namespace detail
237 {
238 /** Container elements of the statically linked Visitor list.
239  *
240  * use the create_visitor() factory functions to create visitors up to size 10;
241  *
242  */
243 template <class Visitor, class Next = StopVisiting>
244 class VisitorNode
245 {
246  public:
247 
248  StopVisiting stop_;
249  Next next_;
250  Visitor & visitor_;
251  VisitorNode(Visitor & visitor, Next & next)
252  :
253  next_(next), visitor_(visitor)
254  {}
255 
256  VisitorNode(Visitor & visitor)
257  :
258  next_(stop_), visitor_(visitor)
259  {}
260 
261  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
262  void visit_after_split( Tree & tree,
263  Split & split,
264  Region & parent,
265  Region & leftChild,
266  Region & rightChild,
267  Feature_t & features,
268  Label_t & labels)
269  {
270  if(visitor_.is_active())
271  visitor_.visit_after_split(tree, split,
272  parent, leftChild, rightChild,
273  features, labels);
274  next_.visit_after_split(tree, split, parent, leftChild, rightChild,
275  features, labels);
276  }
277 
278  template<class RF, class PR, class SM, class ST>
279  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
280  {
281  if(visitor_.is_active())
282  visitor_.visit_after_tree(rf, pr, sm, st, index);
283  next_.visit_after_tree(rf, pr, sm, st, index);
284  }
285 
286  template<class RF, class PR>
287  void visit_at_beginning(RF & rf, PR & pr)
288  {
289  if(visitor_.is_active())
290  visitor_.visit_at_beginning(rf, pr);
291  next_.visit_at_beginning(rf, pr);
292  }
293  template<class RF, class PR>
294  void visit_at_end(RF & rf, PR & pr)
295  {
296  if(visitor_.is_active())
297  visitor_.visit_at_end(rf, pr);
298  next_.visit_at_end(rf, pr);
299  }
300 
301  template<class TR, class IntT, class TopT,class Feat>
302  void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
303  {
304  if(visitor_.is_active())
305  visitor_.visit_external_node(tr, index, node_t,features);
306  next_.visit_external_node(tr, index, node_t,features);
307  }
308  template<class TR, class IntT, class TopT,class Feat>
309  void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
310  {
311  if(visitor_.is_active())
312  visitor_.visit_internal_node(tr, index, node_t,features);
313  next_.visit_internal_node(tr, index, node_t,features);
314  }
315 
316  double return_val()
317  {
318  if(visitor_.is_active() && visitor_.has_value())
319  return visitor_.return_val();
320  return next_.return_val();
321  }
322 };
323 
324 } //namespace detail
325 
326 //////////////////////////////////////////////////////////////////////////////
327 // Visitor Factory function up to 10 visitors //
328 //////////////////////////////////////////////////////////////////////////////
329 
330 /** factory method to be used with RandomForest::learn()
331  */
332 template<class A>
333 detail::VisitorNode<A>
334 create_visitor(A & a)
335 {
336  typedef detail::VisitorNode<A> _0_t;
337  _0_t _0(a);
338  return _0;
339 }
340 
341 
342 /** factory method to be used with RandomForest::learn()
343  */
344 template<class A, class B>
345 detail::VisitorNode<A, detail::VisitorNode<B> >
346 create_visitor(A & a, B & b)
347 {
348  typedef detail::VisitorNode<B> _1_t;
349  _1_t _1(b);
350  typedef detail::VisitorNode<A, _1_t> _0_t;
351  _0_t _0(a, _1);
352  return _0;
353 }
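/** Editor's note (sketch): create_visitor(a, b) builds the statically linked list
 * detail::VisitorNode<A, detail::VisitorNode<B> >, i.e. visitor a is invoked first,
 * then b, and the implicit StopVisiting node terminates the chain. Usage follows the
 * example at the top of this file (rf, f, l are illustrative names):
 *
 * \code
 * visitors::OOB_Error oob_v;
 * visitors::VariableImportanceVisitor varimp_v;
 * rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
 * \endcode
 */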
354 
355 
356 /** factory method to be used with RandomForest::learn()
357  */
358 template<class A, class B, class C>
359 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
360 create_visitor(A & a, B & b, C & c)
361 {
362  typedef detail::VisitorNode<C> _2_t;
363  _2_t _2(c);
364  typedef detail::VisitorNode<B, _2_t> _1_t;
365  _1_t _1(b, _2);
366  typedef detail::VisitorNode<A, _1_t> _0_t;
367  _0_t _0(a, _1);
368  return _0;
369 }
370 
371 
372 /** factory method to be used with RandomForest::learn()
373  */
374 template<class A, class B, class C, class D>
375 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
376  detail::VisitorNode<D> > > >
377 create_visitor(A & a, B & b, C & c, D & d)
378 {
379  typedef detail::VisitorNode<D> _3_t;
380  _3_t _3(d);
381  typedef detail::VisitorNode<C, _3_t> _2_t;
382  _2_t _2(c, _3);
383  typedef detail::VisitorNode<B, _2_t> _1_t;
384  _1_t _1(b, _2);
385  typedef detail::VisitorNode<A, _1_t> _0_t;
386  _0_t _0(a, _1);
387  return _0;
388 }
389 
390 
391 /** factory method to be used with RandomForest::learn()
392  */
393 template<class A, class B, class C, class D, class E>
394 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
395  detail::VisitorNode<D, detail::VisitorNode<E> > > > >
396 create_visitor(A & a, B & b, C & c,
397  D & d, E & e)
398 {
399  typedef detail::VisitorNode<E> _4_t;
400  _4_t _4(e);
401  typedef detail::VisitorNode<D, _4_t> _3_t;
402  _3_t _3(d, _4);
403  typedef detail::VisitorNode<C, _3_t> _2_t;
404  _2_t _2(c, _3);
405  typedef detail::VisitorNode<B, _2_t> _1_t;
406  _1_t _1(b, _2);
407  typedef detail::VisitorNode<A, _1_t> _0_t;
408  _0_t _0(a, _1);
409  return _0;
410 }
411 
412 
413 /** factory method to be used with RandomForest::learn()
414  */
415 template<class A, class B, class C, class D, class E,
416  class F>
417 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
418  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
419 create_visitor(A & a, B & b, C & c,
420  D & d, E & e, F & f)
421 {
422  typedef detail::VisitorNode<F> _5_t;
423  _5_t _5(f);
424  typedef detail::VisitorNode<E, _5_t> _4_t;
425  _4_t _4(e, _5);
426  typedef detail::VisitorNode<D, _4_t> _3_t;
427  _3_t _3(d, _4);
428  typedef detail::VisitorNode<C, _3_t> _2_t;
429  _2_t _2(c, _3);
430  typedef detail::VisitorNode<B, _2_t> _1_t;
431  _1_t _1(b, _2);
432  typedef detail::VisitorNode<A, _1_t> _0_t;
433  _0_t _0(a, _1);
434  return _0;
435 }
436 
437 
438 /** factory method to be used with RandomForest::learn()
439  */
440 template<class A, class B, class C, class D, class E,
441  class F, class G>
442 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
443  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
444  detail::VisitorNode<G> > > > > > >
445 create_visitor(A & a, B & b, C & c,
446  D & d, E & e, F & f, G & g)
447 {
448  typedef detail::VisitorNode<G> _6_t;
449  _6_t _6(g);
450  typedef detail::VisitorNode<F, _6_t> _5_t;
451  _5_t _5(f, _6);
452  typedef detail::VisitorNode<E, _5_t> _4_t;
453  _4_t _4(e, _5);
454  typedef detail::VisitorNode<D, _4_t> _3_t;
455  _3_t _3(d, _4);
456  typedef detail::VisitorNode<C, _3_t> _2_t;
457  _2_t _2(c, _3);
458  typedef detail::VisitorNode<B, _2_t> _1_t;
459  _1_t _1(b, _2);
460  typedef detail::VisitorNode<A, _1_t> _0_t;
461  _0_t _0(a, _1);
462  return _0;
463 }
464 
465 
466 /** factory method to be used with RandomForest::learn()
467  */
468 template<class A, class B, class C, class D, class E,
469  class F, class G, class H>
470 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
471  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
472  detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
473 create_visitor(A & a, B & b, C & c,
474  D & d, E & e, F & f,
475  G & g, H & h)
476 {
477  typedef detail::VisitorNode<H> _7_t;
478  _7_t _7(h);
479  typedef detail::VisitorNode<G, _7_t> _6_t;
480  _6_t _6(g, _7);
481  typedef detail::VisitorNode<F, _6_t> _5_t;
482  _5_t _5(f, _6);
483  typedef detail::VisitorNode<E, _5_t> _4_t;
484  _4_t _4(e, _5);
485  typedef detail::VisitorNode<D, _4_t> _3_t;
486  _3_t _3(d, _4);
487  typedef detail::VisitorNode<C, _3_t> _2_t;
488  _2_t _2(c, _3);
489  typedef detail::VisitorNode<B, _2_t> _1_t;
490  _1_t _1(b, _2);
491  typedef detail::VisitorNode<A, _1_t> _0_t;
492  _0_t _0(a, _1);
493  return _0;
494 }
495 
496 
497 /** factory method to be used with RandomForest::learn()
498  */
499 template<class A, class B, class C, class D, class E,
500  class F, class G, class H, class I>
501 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
502  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
503  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
504 create_visitor(A & a, B & b, C & c,
505  D & d, E & e, F & f,
506  G & g, H & h, I & i)
507 {
508  typedef detail::VisitorNode<I> _8_t;
509  _8_t _8(i);
510  typedef detail::VisitorNode<H, _8_t> _7_t;
511  _7_t _7(h, _8);
512  typedef detail::VisitorNode<G, _7_t> _6_t;
513  _6_t _6(g, _7);
514  typedef detail::VisitorNode<F, _6_t> _5_t;
515  _5_t _5(f, _6);
516  typedef detail::VisitorNode<E, _5_t> _4_t;
517  _4_t _4(e, _5);
518  typedef detail::VisitorNode<D, _4_t> _3_t;
519  _3_t _3(d, _4);
520  typedef detail::VisitorNode<C, _3_t> _2_t;
521  _2_t _2(c, _3);
522  typedef detail::VisitorNode<B, _2_t> _1_t;
523  _1_t _1(b, _2);
524  typedef detail::VisitorNode<A, _1_t> _0_t;
525  _0_t _0(a, _1);
526  return _0;
527 }
528 
529 /** factory method to be used with RandomForest::learn()
530  */
531 template<class A, class B, class C, class D, class E,
532  class F, class G, class H, class I, class J>
533 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
534  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
535  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
536  detail::VisitorNode<J> > > > > > > > > >
537 create_visitor(A & a, B & b, C & c,
538  D & d, E & e, F & f,
539  G & g, H & h, I & i,
540  J & j)
541 {
542  typedef detail::VisitorNode<J> _9_t;
543  _9_t _9(j);
544  typedef detail::VisitorNode<I, _9_t> _8_t;
545  _8_t _8(i, _9);
546  typedef detail::VisitorNode<H, _8_t> _7_t;
547  _7_t _7(h, _8);
548  typedef detail::VisitorNode<G, _7_t> _6_t;
549  _6_t _6(g, _7);
550  typedef detail::VisitorNode<F, _6_t> _5_t;
551  _5_t _5(f, _6);
552  typedef detail::VisitorNode<E, _5_t> _4_t;
553  _4_t _4(e, _5);
554  typedef detail::VisitorNode<D, _4_t> _3_t;
555  _3_t _3(d, _4);
556  typedef detail::VisitorNode<C, _3_t> _2_t;
557  _2_t _2(c, _3);
558  typedef detail::VisitorNode<B, _2_t> _1_t;
559  _1_t _1(b, _2);
560  typedef detail::VisitorNode<A, _1_t> _0_t;
561  _0_t _0(a, _1);
562  return _0;
563 }
564 
565 //////////////////////////////////////////////////////////////////////////////
566 // Visitors of communal interest. //
567 //////////////////////////////////////////////////////////////////////////////
568 
569 
570 /** Visitor to gain information, later needed for online learning.
571  */
572 
573 class OnlineLearnVisitor: public VisitorBase
574 {
575 public:
576  //Set if we adjust thresholds
577  bool adjust_thresholds;
578  //Current tree id
579  int tree_id;
580  //Last node id for finding parent
581  int last_node_id;
582  //Need to know the label for interior node visiting
583  vigra::Int32 current_label;
584  //marginal distribution for interior nodes
585  //
586  OnlineLearnVisitor():
587  adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
588  {}
589  struct MarginalDistribution
590  {
591  ArrayVector<Int32> leftCounts;
592  Int32 leftTotalCounts;
593  ArrayVector<Int32> rightCounts;
594  Int32 rightTotalCounts;
595  double gap_left;
596  double gap_right;
597  };
598  typedef ArrayVector<vigra::Int32> IndexList;
599 
600  //All information for one tree
601  struct TreeOnlineInformation
602  {
603  std::vector<MarginalDistribution> mag_distributions;
604  std::vector<IndexList> index_lists;
605  //map for linear index of mag_distributions
606  std::map<int,int> interior_to_index;
607  //map for linear index of index_lists
608  std::map<int,int> exterior_to_index;
609  };
610 
611  //All trees
612  std::vector<TreeOnlineInformation> trees_online_information;
613 
614  /** Initialize, set the number of trees
615  */
616  template<class RF,class PR>
617  void visit_at_beginning(RF & rf,const PR & pr)
618  {
619  tree_id=0;
620  trees_online_information.resize(rf.options_.tree_count_);
621  }
622 
623  /** Reset a tree
624  */
625  void reset_tree(int tree_id)
626  {
627  trees_online_information[tree_id].mag_distributions.clear();
628  trees_online_information[tree_id].index_lists.clear();
629  trees_online_information[tree_id].interior_to_index.clear();
630  trees_online_information[tree_id].exterior_to_index.clear();
631  }
632 
633  /** simply increase the tree count
634  */
635  template<class RF, class PR, class SM, class ST>
636  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
637  {
638  tree_id++;
639  }
640 
641  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
642  void visit_after_split( Tree & tree,
643  Split & split,
644  Region & parent,
645  Region & leftChild,
646  Region & rightChild,
647  Feature_t & features,
648  Label_t & labels)
649  {
650  int linear_index;
651  int addr=tree.topology_.size();
652  if(split.createNode().typeID() == i_ThresholdNode)
653  {
654  if(adjust_thresholds)
655  {
656  //Store marginal distribution
657  linear_index=trees_online_information[tree_id].mag_distributions.size();
658  trees_online_information[tree_id].interior_to_index[addr]=linear_index;
659  trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
660 
661  trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
662  trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
663 
664  trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
665  trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
666  //Store the gap
667  double gap_left,gap_right;
668  int i;
669  gap_left=features(leftChild[0],split.bestSplitColumn());
670  for(i=1;i<leftChild.size();++i)
671  if(features(leftChild[i],split.bestSplitColumn())>gap_left)
672  gap_left=features(leftChild[i],split.bestSplitColumn());
673  gap_right=features(rightChild[0],split.bestSplitColumn());
674  for(i=1;i<rightChild.size();++i)
675  if(features(rightChild[i],split.bestSplitColumn())<gap_right)
676  gap_right=features(rightChild[i],split.bestSplitColumn());
677  trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
678  trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
679  }
680  }
681  else
682  {
683  //Store index list
684  linear_index=trees_online_information[tree_id].index_lists.size();
685  trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
686 
687  trees_online_information[tree_id].index_lists.push_back(IndexList());
688 
689  trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
690  std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
691  }
692  }
693  void add_to_index_list(int tree,int node,int index)
694  {
695  if(!this->active_)
696  return;
697  TreeOnlineInformation &ti=trees_online_information[tree];
698  ti.index_lists[ti.exterior_to_index[node]].push_back(index);
699  }
700  void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
701  {
702  if(!this->active_)
703  return;
704  trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
705  trees_online_information[src_tree].exterior_to_index.erase(src_index);
706  }
707  /** do something when visiting an internal node during getToLeaf
708  *
709  * remember as last node id, for finding the parent of the last external node
710  * also: adjust class counts and borders
711  */
712  template<class TR, class IntT, class TopT,class Feat>
713  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
714  {
715  last_node_id=index;
716  if(adjust_thresholds)
717  {
718  vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
719  //Check if we are in the gap
720  double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
721  TreeOnlineInformation &ti=trees_online_information[tree_id];
722  MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
723  if(value>m.gap_left && value<m.gap_right)
724  {
725  //Check which side we want to go to
726  if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
727  {
728  //We want to go left
729  m.gap_left=value;
730  }
731  else
732  {
733  //We want to go right
734  m.gap_right=value;
735  }
736  Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
737  }
738  //Adjust class counts
739  if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
740  {
741  ++m.rightTotalCounts;
742  ++m.rightCounts[current_label];
743  }
744  else
745  {
746  ++m.leftTotalCounts;
747  ++m.leftCounts[current_label];
748  }
749  }
750  }
751  /** do something when visiting an external node during getToLeaf
752  *
753  * Store the new index!
754  */
755 };
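/** Example (editor's sketch): registering the OnlineLearnVisitor during learning so that
 * the information needed for later online updates is collected. The online-learning entry
 * points themselves live outside this header; only the registration is shown, and rf, f, l
 * are illustrative names as in the example at the top of this file.
 *
 * \code
 * visitors::OnlineLearnVisitor online_v;
 * online_v.adjust_thresholds = true;   // also track the split-threshold gaps
 * rf.learn(f, l, visitors::create_visitor(online_v));
 * \endcode
 */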
756 
757 //////////////////////////////////////////////////////////////////////////////
758 // Out of Bag Error estimates //
759 //////////////////////////////////////////////////////////////////////////////
760 
761 
762 /** Visitor that calculates the oob error of each individual randomized
763  * decision tree.
764  *
765  * After training a tree, all those samples that are OOB for this particular tree
766  * are put down the tree and the error estimated.
767  * The per-tree OOB error is the average of the individual error estimates
768  * (oobError = average error of one randomized tree).
769  * Note: This is not the OOB error estimate suggested by Breiman (see the OOB_Error
770  * visitor).
771  */
772 class OOB_PerTreeError: public VisitorBase
773 {
774 public:
775  /** Average error of one randomized decision tree
776  */
777  double oobError;
778 
779  int totalOobCount;
780  ArrayVector<int> oobCount,oobErrorCount;
781 
782  OOB_PerTreeError()
783  : oobError(0.0),
784  totalOobCount(0)
785  {}
786 
787 
788  bool has_value()
789  {
790  return true;
791  }
792 
793 
794  /** does the basic calculation per tree*/
795  template<class RF, class PR, class SM, class ST>
796  void visit_after_tree( RF& rf, PR & pr, SM & sm, ST & st, int index)
797  {
798  //do the first time called.
799  if(int(oobCount.size()) != rf.ext_param_.row_count_)
800  {
801  oobCount.resize(rf.ext_param_.row_count_, 0);
802  oobErrorCount.resize(rf.ext_param_.row_count_, 0);
803  }
804  // go through the samples
805  for(int l = 0; l < rf.ext_param_.row_count_; ++l)
806  {
807  // if the lth sample is oob...
808  if(!sm.is_used()[l])
809  {
810  ++oobCount[l];
811  if( rf.tree(index)
812  .predictLabel(rowVector(pr.features(), l))
813  != pr.response()(l,0))
814  {
815  ++oobErrorCount[l];
816  }
817  }
818 
819  }
820  }
821 
822  /** Does the normalisation
823  */
824  template<class RF, class PR>
825  void visit_at_end(RF & rf, PR & pr)
826  {
827  // do some normalisation
828  for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
829  {
830  if(oobCount[l])
831  {
832  oobError += double(oobErrorCount[l]) / oobCount[l];
833  ++totalOobCount;
834  }
835  }
836  oobError/=totalOobCount;
837  }
838 
839 };
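/** Example (editor's sketch): computing the average per-tree OOB error. The value of
 * oobError is only valid after learn() has returned; rf, f, l are illustrative names.
 *
 * \code
 * visitors::OOB_PerTreeError per_tree_v;
 * rf.learn(f, l, visitors::create_visitor(per_tree_v));
 * std::cout << "mean per-tree OOB error: " << per_tree_v.oobError << std::endl;
 * \endcode
 */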
840 
841 /** Visitor that calculates the OOB error of the ensemble.
842  * This rate should be used to estimate the cross-validation
843  * error rate.
844  * Here each sample is put down those trees for which this sample
845  * is OOB, i.e. if sample #1 is OOB for trees 1, 3 and 5, we calculate
846  * the output using the ensemble consisting only of trees 1, 3 and 5.
847  *
848  * Using normal bagged sampling, each sample is OOB for approx. 33% of the trees.
849  * The error rate obtained in this way therefore corresponds to the cross-validation
850  * error rate obtained using an ensemble containing 33% of the trees.
851  */
852 class OOB_Error : public VisitorBase
853 {
854  typedef MultiArrayShape<2>::type Shp;
855  int class_count;
856  bool is_weighted;
857  MultiArray<2,double> tmp_prob;
858  public:
859 
860  MultiArray<2, double> prob_oob;
861  /** Ensemble oob error rate
862  */
863  double oob_breiman;
864 
865  MultiArray<2, double> oobCount;
866  ArrayVector< int> indices;
867  OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
868 #ifdef HasHDF5
869  void save(std::string filen, std::string pathn)
870  {
871  if(*(pathn.end()-1) != '/')
872  pathn += "/";
873  const char* filename = filen.c_str();
874  MultiArray<2, double> temp(Shp(1,1), 0.0);
875  temp[0] = oob_breiman;
876  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
877  }
878 #endif
879  // negative value if sample was ib, number indicates how often.
880  // value >=0 if sample was oob, 0 means fail, 1 correct
881 
882  template<class RF, class PR>
883  void visit_at_beginning(RF & rf, PR & pr)
884  {
885  class_count = rf.class_count();
886  tmp_prob.reshape(Shp(1, class_count), 0);
887  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
888  is_weighted = rf.options().predict_weighted_;
889  indices.resize(rf.ext_param().row_count_);
890  if(int(oobCount.size()) != rf.ext_param_.row_count_)
891  {
892  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
893  }
894  for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
895  {
896  indices[ii] = ii;
897  }
898  }
899 
900  template<class RF, class PR, class SM, class ST>
901  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
902  {
903  // go through the samples
904  int total_oob =0;
905  // FIXME: magic number 10000: invoke special treatment when msample << sample_count
906  // (i.e. the OOB sample is very large)
907  // 40000: use at most 40000 OOB samples per class for OOB error estimate
908  if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
909  {
910  ArrayVector<int> oob_indices;
911  ArrayVector<int> cts(class_count, 0);
912  std::random_shuffle(indices.begin(), indices.end());
913  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
914  {
915  if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
916  {
917  oob_indices.push_back(indices[ii]);
918  ++cts[pr.response()(indices[ii], 0)];
919  }
920  }
921  for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
922  {
923  // update number of trees in which current sample is oob
924  ++oobCount[oob_indices[ll]];
925 
926  // update number of oob samples in this tree.
927  ++total_oob;
928  // get the predicted votes ---> tmp_prob;
929  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
930  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
931  rf.tree(index).parameters_,
932  pos);
933  tmp_prob.init(0);
934  for(int ii = 0; ii < class_count; ++ii)
935  {
936  tmp_prob[ii] = node.prob_begin()[ii];
937  }
938  if(is_weighted)
939  {
940  for(int ii = 0; ii < class_count; ++ii)
941  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
942  }
943  rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
944 
945  }
946  }else
947  {
948  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
949  {
950  // if the lth sample is oob...
951  if(!sm.is_used()[ll])
952  {
953  // update number of trees in which current sample is oob
954  ++oobCount[ll];
955 
956  // update number of oob samples in this tree.
957  ++total_oob;
958  // get the predicted votes ---> tmp_prob;
959  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
960  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
961  rf.tree(index).parameters_,
962  pos);
963  tmp_prob.init(0);
964  for(int ii = 0; ii < class_count; ++ii)
965  {
966  tmp_prob[ii] = node.prob_begin()[ii];
967  }
968  if(is_weighted)
969  {
970  for(int ii = 0; ii < class_count; ++ii)
971  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
972  }
973  rowVector(prob_oob, ll) += tmp_prob;
974  }
975  }
976  }
977  // go through the ib samples;
978  }
979 
980  /** Normalise the OOB error after the number of trees is known.
981  */
982  template<class RF, class PR>
983  void visit_at_end(RF & rf, PR & pr)
984  {
985  // ullis original metric and breiman style stuff
986  int totalOobCount =0;
987  int breimanstyle = 0;
988  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
989  {
990  if(oobCount[ll])
991  {
992  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
993  ++breimanstyle;
994  ++totalOobCount;
995  }
996  }
997  oob_breiman = double(breimanstyle)/totalOobCount;
998  }
999 };
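/** Example (editor's sketch): estimating the ensemble OOB error rate in the Breiman style
 * and reading it from the visitor afterwards; rf, f, l are illustrative names.
 *
 * \code
 * visitors::OOB_Error oob_v;
 * rf.learn(f, l, visitors::create_visitor(oob_v));
 * std::cout << "ensemble OOB error: " << oob_v.oob_breiman << std::endl;
 * \endcode
 */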
1000 
1001 
1002 /** Visitor that calculates different OOB error statistics
1003  */
1004 class CompleteOOBInfo : public VisitorBase
1005 {
1006  typedef MultiArrayShape<2>::type Shp;
1007  int class_count;
1008  bool is_weighted;
1009  MultiArray<2,double> tmp_prob;
1010  public:
1011 
1012  /** OOB Error rate of each individual tree
1013  */
1015  /** Mean of oob_per_tree
1016  */
1017  double oob_mean;
1018  /**Standard deviation of oob_per_tree
1019  */
1020  double oob_std;
1021 
1022  MultiArray<2, double> prob_oob;
1023  /** Ensemble OOB error
1024  *
1025  * \sa OOB_Error
1026  */
1027  double oob_breiman;
1028 
1029  MultiArray<2, double> oobCount;
1030  MultiArray<2, double> oobErrorCount;
1031  /** Per Tree OOB error calculated as in OOB_PerTreeError
1032  * (Ulli's version)
1033  */
1034  double oob_per_tree2;
1035 
1036  /**Column containing the development of the Ensemble
1037  * error rate with increasing number of trees
1038  */
1039  MultiArray<2, double> breiman_per_tree;
1040  /** 4 dimensional array containing the development of confusion matrices
1041  * with number of trees - can be used to estimate ROC curves etc.
1042  *
1043  * oobroc_per_tree(ii,jj,kk,ll)
1044  * corresponds true label = ii
1045  * predicted label = jj
1046  * confusion matrix after ll trees
1047  *
1048  * explanation of third index:
1049  *
1050  * Two class case:
1051  * kk = 0 - (treeCount-1)
1052  * The threshold on the probability for class 0 is kk/(treeCount-1);
1053  * More classes:
1054  * kk = 0. Threshold on probability set by argMax of the probability array.
1055  */
1056  MultiArray<4, double> oobroc_per_tree;
1057 
1058  CompleteOOBInfo() : VisitorBase(), oob_mean(0), oob_std(0), oob_per_tree2(0) {}
1059 
1060 #ifdef HasHDF5
1061  /** save to HDF5 file
1062  */
1063  void save(std::string filen, std::string pathn)
1064  {
1065  if(*(pathn.end()-1) != '/')
1066  pathn += "/";
1067  const char* filename = filen.c_str();
1068  MultiArray<2, double> temp(Shp(1,1), 0.0);
1069  writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1070  writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1071  writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1072  temp[0] = oob_mean;
1073  writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1074  temp[0] = oob_std;
1075  writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1076  temp[0] = oob_breiman;
1077  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1078  temp[0] = oob_per_tree2;
1079  writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1080  }
1081 #endif
1082  // negative value if sample was ib, number indicates how often.
1083  // value >=0 if sample was oob, 0 means fail, 1 correct
1084 
1085  template<class RF, class PR>
1086  void visit_at_beginning(RF & rf, PR & pr)
1087  {
1088  class_count = rf.class_count();
1089  if(class_count == 2)
1090  oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
1091  else
1092  oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
1093  tmp_prob.reshape(Shp(1, class_count), 0);
1094  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1095  is_weighted = rf.options().predict_weighted_;
1096  oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1097  breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1098  //do the first time called.
1099  if(int(oobCount.size()) != rf.ext_param_.row_count_)
1100  {
1101  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1102  oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
1103  }
1104  }
1105 
1106  template<class RF, class PR, class SM, class ST>
1107  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
1108  {
1109  // go through the samples
1110  int total_oob =0;
1111  int wrong_oob =0;
1112  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1113  {
1114  // if the lth sample is oob...
1115  if(!sm.is_used()[ll])
1116  {
1117  // update number of trees in which current sample is oob
1118  ++oobCount[ll];
1119 
1120  // update number of oob samples in this tree.
1121  ++total_oob;
1122  // get the predicted votes ---> tmp_prob;
1123  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
1124  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1125  rf.tree(index).parameters_,
1126  pos);
1127  tmp_prob.init(0);
1128  for(int ii = 0; ii < class_count; ++ii)
1129  {
1130  tmp_prob[ii] = node.prob_begin()[ii];
1131  }
1132  if(is_weighted)
1133  {
1134  for(int ii = 0; ii < class_count; ++ii)
1135  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1136  }
1137  rowVector(prob_oob, ll) += tmp_prob;
1138  int label = argMax(tmp_prob);
1139 
1140  if(label != pr.response()(ll, 0))
1141  {
1142  // update number of wrong oob samples in this tree.
1143  ++wrong_oob;
1144  // update number of trees in which current sample is wrong oob
1145  ++oobErrorCount[ll];
1146  }
1147  }
1148  }
1149  int breimanstyle = 0;
1150  int totalOobCount = 0;
1151  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1152  {
1153  if(oobCount[ll])
1154  {
1155  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1156  ++breimanstyle;
1157  ++totalOobCount;
1158  if(oobroc_per_tree.shape(2) == 1)
1159  {
1160  oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
1161  }
1162  }
1163  }
1164  if(oobroc_per_tree.shape(2) == 1)
1165  oobroc_per_tree.bindOuter(index)/=totalOobCount;
1166  if(oobroc_per_tree.shape(2) > 1)
1167  {
1168  MultiArrayView<3, double> current_roc
1169  = oobroc_per_tree.bindOuter(index);
1170  for(int gg = 0; gg < current_roc.shape(2); ++gg)
1171  {
1172  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1173  {
1174  if(oobCount[ll])
1175  {
1176  int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1177  1 : 0;
1178  current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1179  }
1180  }
1181  current_roc.bindOuter(gg)/= totalOobCount;
1182  }
1183  }
1184  breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
1185  oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1186  // go through the ib samples;
1187  }
1188 
1189  /** Normalise the OOB statistics after the number of trees is known.
1190  */
1191  template<class RF, class PR>
1192  void visit_at_end(RF & rf, PR & pr)
1193  {
1194  // ullis original metric and breiman style stuff
1195  oob_per_tree2 = 0;
1196  int totalOobCount =0;
1197  int breimanstyle = 0;
1198  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1199  {
1200  if(oobCount[ll])
1201  {
1202  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1203  ++breimanstyle;
1204  oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1205  ++totalOobCount;
1206  }
1207  }
1208  oob_per_tree2 /= totalOobCount;
1209  oob_breiman = double(breimanstyle)/totalOobCount;
1210  // mean error of each tree
1211  MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
1212  MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
1213  rowStatistics(oob_per_tree, mean, stdDev);
1214  }
1215 };
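/** Example (editor's sketch): collecting the full set of OOB statistics and, if HDF5
 * support is compiled in, saving them to a file. The file and group names as well as
 * rf, f, l are illustrative only.
 *
 * \code
 * visitors::CompleteOOBInfo oob_info;
 * rf.learn(f, l, visitors::create_visitor(oob_info));
 * std::cout << "breiman error: " << oob_info.oob_breiman
 *           << ", per-tree mean: " << oob_info.oob_mean
 *           << " +- " << oob_info.oob_std << std::endl;
 * #ifdef HasHDF5
 * oob_info.save("oob_stats.h5", "forest1");
 * #endif
 * \endcode
 */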
1216 
1217 /** calculate variable importance while learning.
1218  */
1219 class VariableImportanceVisitor : public VisitorBase
1220 {
1221  public:
1222 
1223  /** This Array has the same entries as the R - random forest variable
1224  * importance.
1225  * Matrix is featureCount by (classCount +2)
1226  * variable_importance_(ii,jj) is the variable importance measure of
1227  * the ii-th variable according to:
1228  * jj = 0 - (classCount-1)
1229  * classwise permutation importance
1230  * jj = columnCount(variable_importance_) -2
1231  * permutation importance
1232  * jj = columnCount(variable_importance_) -1
1233  * gini decrease importance.
1234  *
1235  * permutation importance:
1236  * The difference between the fraction of OOB samples classified correctly
1237  * before and after permuting (randomizing) the ii-th column is calculated.
1238  * The ii-th column is permuted rep_cnt times.
1239  *
1240  * class wise permutation importance:
1241  * same as permutation importance. We only look at those OOB samples whose
1242  * response corresponds to class jj.
1243  *
1244  * gini decrease importance:
1245  * row ii corresponds to the sum of all gini decreases induced by variable ii
1246  * in each node of the random forest.
1247  */
1248  MultiArray<2, double> variable_importance_;
1249  int repetition_count_;
1250  bool in_place_;
1251 
1252 #ifdef HasHDF5
1253  void save(std::string filename, std::string prefix)
1254  {
1255  prefix = "variable_importance_" + prefix;
1256  writeHDF5(filename.c_str(),
1257  prefix.c_str(),
1258  variable_importance_);
1259  }
1260 #endif
1261 
1262  /** Constructor
1263  * \param rep_cnt (default: 10) how often the permutation should
1264  * take place. Set to 1 to make the calculation faster (but
1265  * possibly less stable)
1266  */
1267  VariableImportanceVisitor(int rep_cnt = 10)
1268  : repetition_count_(rep_cnt)
1269 
1270  {}
1271 
1272  /** calculates impurity decrease based variable importance after every
1273  * split.
1274  */
1275  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1276  void visit_after_split( Tree & tree,
1277  Split & split,
1278  Region & parent,
1279  Region & leftChild,
1280  Region & rightChild,
1281  Feature_t & features,
1282  Label_t & labels)
1283  {
1284  //resize to right size when called the first time
1285 
1286  Int32 const class_count = tree.ext_param_.class_count_;
1287  Int32 const column_count = tree.ext_param_.column_count_;
1288  if(variable_importance_.size() == 0)
1289  {
1290 
1291  variable_importance_
1292  .reshape(MultiArrayShape<2>::type(column_count,
1293  class_count+2));
1294  }
1295 
1296  if(split.createNode().typeID() == i_ThresholdNode)
1297  {
1298  Node<i_ThresholdNode> node(split.createNode());
1299  variable_importance_(node.column(),class_count+1)
1300  += split.region_gini_ - split.minGini();
1301  }
1302  }
1303 
1304  /**compute permutation based variable importance.
1305  * (Only an array of size oob_sample_count x 1 is created
1306  * - as opposed to oob_sample_count x feature_count in the other method.)
1307  *
1308  * \sa FieldProxy
1309  */
1310  template<class RF, class PR, class SM, class ST>
1311  void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index)
1312  {
1313  typedef MultiArrayShape<2>::type Shp_t;
1314  Int32 column_count = rf.ext_param_.column_count_;
1315  Int32 class_count = rf.ext_param_.class_count_;
1316 
1317  /* This solution saves memory but is not multithreading
1318  * compatible
1319  */
1320  // remove the const cast on the features (yep , I know what I am
1321  // doing here.) data is not destroyed.
1322  //typename PR::Feature_t & features
1323  // = const_cast<typename PR::Feature_t &>(pr.features());
1324 
1325  typedef typename PR::FeatureWithMemory_t FeatureArray;
1326  typedef typename FeatureArray::value_type FeatureValue;
1327 
1328  FeatureArray features = pr.features();
1329 
1330  //find the oob indices of current tree.
1331  ArrayVector<Int32> oob_indices;
1332  ArrayVector<Int32>::iterator
1333  iter;
1334  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1335  if(!sm.is_used()[ii])
1336  oob_indices.push_back(ii);
1337 
1338  //create space to back up a column
1339  ArrayVector<FeatureValue> backup_column;
1340 
1341  // Random foo
1342 #ifdef CLASSIFIER_TEST
1343  RandomMT19937 random(1);
1344 #else
1345  RandomMT19937 random(RandomSeed);
1346 #endif
1347  UniformIntRandomFunctor<RandomMT19937>
1348  randint(random);
1349 
1350 
1351  //make some space for the results
1352  MultiArray<2, double>
1353  oob_right(Shp_t(1, class_count + 1));
1354  MultiArray<2, double>
1355  perm_oob_right (Shp_t(1, class_count + 1));
1356 
1357 
1358  // get the oob success rate with the original samples
1359  for(iter = oob_indices.begin();
1360  iter != oob_indices.end();
1361  ++iter)
1362  {
1363  if(rf.tree(index)
1364  .predictLabel(rowVector(features, *iter))
1365  == pr.response()(*iter, 0))
1366  {
1367  //per class
1368  ++oob_right[pr.response()(*iter,0)];
1369  //total
1370  ++oob_right[class_count];
1371  }
1372  }
1373  //get the oob rate after permuting the ii'th dimension.
1374  for(int ii = 0; ii < column_count; ++ii)
1375  {
1376  perm_oob_right.init(0.0);
1377  //make backup of original column
1378  backup_column.clear();
1379  for(iter = oob_indices.begin();
1380  iter != oob_indices.end();
1381  ++iter)
1382  {
1383  backup_column.push_back(features(*iter,ii));
1384  }
1385 
1386  //get the oob rate after permuting the ii'th dimension.
1387  for(int rr = 0; rr < repetition_count_; ++rr)
1388  {
1389  //permute dimension.
1390  int n = oob_indices.size();
1391  for(int jj = 1; jj < n; ++jj)
1392  std::swap(features(oob_indices[jj], ii),
1393  features(oob_indices[randint(jj+1)], ii));
1394 
1395  //get the oob success rate after permuting
1396  for(iter = oob_indices.begin();
1397  iter != oob_indices.end();
1398  ++iter)
1399  {
1400  if(rf.tree(index)
1401  .predictLabel(rowVector(features, *iter))
1402  == pr.response()(*iter, 0))
1403  {
1404  //per class
1405  ++perm_oob_right[pr.response()(*iter, 0)];
1406  //total
1407  ++perm_oob_right[class_count];
1408  }
1409  }
1410  }
1411 
1412 
1413  //normalise and add to the variable_importance array.
1414  perm_oob_right /= repetition_count_;
1415  perm_oob_right -=oob_right;
1416  perm_oob_right *= -1;
1417  perm_oob_right /= oob_indices.size();
1418  variable_importance_
1419  .subarray(Shp_t(ii,0),
1420  Shp_t(ii+1,class_count+1)) += perm_oob_right;
1421  //copy back permuted dimension
1422  for(int jj = 0; jj < int(oob_indices.size()); ++jj)
1423  features(oob_indices[jj], ii) = backup_column[jj];
1424  }
1425  }
1426 
1427  /** calculate permutation based importance after every tree has been
1428  * learned. The default behaviour is that this happens out of place.
1429  * If you have very big data sets and want to avoid copying of data,
1430  * set the in_place_ flag to true.
1431  */
1432  template<class RF, class PR, class SM, class ST>
1433  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
1434  {
1435  after_tree_ip_impl(rf, pr, sm, st, index);
1436  }
1437 
1438  /** Normalise variable importance after the number of trees is known.
1439  */
1440  template<class RF, class PR>
1441  void visit_at_end(RF & rf, PR & pr)
1442  {
1443  variable_importance_ /= rf.trees_.size();
1444  }
1445 };
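/** Example (editor's sketch): computing permutation and Gini importances and reading the
 * result matrix (one row per feature, classCount+2 columns); rf, f, l are illustrative names.
 *
 * \code
 * visitors::VariableImportanceVisitor varimp_v(10);   // 10 permutation repetitions
 * rf.learn(f, l, visitors::create_visitor(varimp_v));
 * int perm_col = varimp_v.variable_importance_.shape(1) - 2;  // permutation importance column
 * for(int ii = 0; ii < varimp_v.variable_importance_.shape(0); ++ii)
 *     std::cout << "feature " << ii << ": "
 *               << varimp_v.variable_importance_(ii, perm_col) << std::endl;
 * \endcode
 */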
1446 
1447 /** Verbose output
1448  */
1449 class RandomForestProgressVisitor : public VisitorBase {
1450  public:
1451  RandomForestProgressVisitor() : VisitorBase() {}
1452 
1453  template<class RF, class PR, class SM, class ST>
1454  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index){
1455  if(index != rf.options().tree_count_-1) {
1456  std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1457  << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1458  }
1459  else {
1460  std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1461  }
1462  }
1463 
1464  template<class RF, class PR>
1465  void visit_at_end(RF const & rf, PR const & pr) {
1466  std::string a = TOCS;
1467  std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
1468  }
1469 
1470  template<class RF, class PR>
1471  void visit_at_beginning(RF const & rf, PR const & pr) {
1472  TIC;
1473  std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
1474  }
1475 
1476  private:
1477  USETICTOC;
1478 };
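/** Example (editor's sketch): adding progress output to a (possibly long) training run by
 * chaining the progress visitor with any other visitor; rf, f, l are illustrative names.
 *
 * \code
 * visitors::RandomForestProgressVisitor progress_v;
 * visitors::OOB_Error oob_v;
 * rf.learn(f, l, visitors::create_visitor(progress_v, oob_v));
 * \endcode
 */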
1479 
1480 
1481 /** Computes Correlation/Similarity Matrix of features while learning
1482  * random forest.
1483  */
1484 class CorrelationVisitor : public VisitorBase
1485 {
1486  public:
1487  /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1488  * created on variable ii(when variable ii was chosen)
1489  */
1490  MultiArray<2, double> gini_missc;
1491  MultiArray<2, int> tmp_labels;
1492  /** additional noise features.
1493  */
1494  MultiArray<2, double> noise;
1495  MultiArray<2, double> noise_l;
1496  /** how well can a noise column describe a partition created on variable ii.
1497  */
1498  MultiArray<2, double> corr_noise;
1499  MultiArray<2, double> corr_l;
1500 
1501  /** Similarity Matrix
1502  *
1503  * (numberOfFeatures + 1) by (numberOfFeatures + 1) Matrix
1504  * gini_missc
1505  * - row normalized by the number of times the column was chosen
1506  * - mean of corr_noise subtracted
1507  * - and symmetrised.
1508  *
1509  */
1510  MultiArray<2, double> similarity;
1511  /** Distance Matrix 1-similarity
1512  */
1513  MultiArray<2, double> distance;
1514  ArrayVector<int> tmp_cc;
1515 
1516  /** How often was variable ii chosen
1517  */
1518  ArrayVector<int> numChoices;
1519  typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor;
1520  ColumnDecisionFunctor bgfunc;
1521  void save(std::string file, std::string prefix)
1522  {
1523  /*
1524  std::string tmp;
1525 #define VAR_WRITE(NAME) \
1526  tmp = #NAME;\
1527  tmp += "_";\
1528  tmp += prefix;\
1529  vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
1530  VAR_WRITE(gini_missc);
1531  VAR_WRITE(corr_noise);
1532  VAR_WRITE(distance);
1533  VAR_WRITE(similarity);
1534  vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
1535 #undef VAR_WRITE
1536 */
1537  }
1538  template<class RF, class PR>
1539  void visit_at_beginning(RF const & rf, PR & pr)
1540  {
1541  typedef MultiArrayShape<2>::type Shp;
1542  int n = rf.ext_param_.column_count_;
1543  gini_missc.reshape(Shp(n +1,n+ 1));
1544  corr_noise.reshape(Shp(n + 1, 10));
1545  corr_l.reshape(Shp(n +1, 10));
1546 
1547  noise.reshape(Shp(pr.features().shape(0), 10));
1548  noise_l.reshape(Shp(pr.features().shape(0), 10));
1549  RandomMT19937 random(RandomSeed);
1550  for(int ii = 0; ii < noise.size(); ++ii)
1551  {
1552  noise[ii] = random.uniform53();
1553  noise_l[ii] = random.uniform53() > 0.5;
1554  }
1555  bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1556  tmp_labels.reshape(pr.response().shape());
1557  tmp_cc.resize(2);
1558  numChoices.resize(n+1);
1559  // look at all axes
1560  }
1561  template<class RF, class PR>
1562  void visit_at_end(RF const & rf, PR const & pr)
1563  {
1564  typedef MultiArrayShape<2>::type Shp;
1565  similarity = gini_missc;
1566 
1567  MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
1568  rowStatistics(corr_noise, mean_noise);
1569  mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
1570  int rC = similarity.shape(0);
1571  for(int jj = 0; jj < rC-1; ++jj)
1572  {
1573  rowVector(similarity, jj) /= numChoices[jj];
1574  rowVector(similarity, jj) -= mean_noise(jj, 0);
1575  }
1576  for(int jj = 0; jj < rC; ++jj)
1577  {
1578  similarity(rC -1, jj) /= numChoices[jj];
1579  }
1580  rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
1582  FindMinMax<double> minmax;
1583  inspectMultiArray(srcMultiArrayRange(similarity), minmax);
1584 
1585  for(int jj = 0; jj < rC; ++jj)
1586  similarity(jj, jj) = minmax.max;
1587 
1588  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
1589  += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
1590  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
1591  columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
1592  for(int jj = 0; jj < rC; ++jj)
1593  similarity(jj, jj) = 0;
1594 
1595  FindMinMax<double> minmax2;
1596  inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
1597  for(int jj = 0; jj < rC; ++jj)
1598  similarity(jj, jj) = minmax2.max;
1599  distance.reshape(gini_missc.shape(), minmax2.max);
1600  distance -= similarity;
1601  }
1602 
1603  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1604  void visit_after_split( Tree & tree,
1605  Split & split,
1606  Region & parent,
1607  Region & leftChild,
1608  Region & rightChild,
1609  Feature_t & features,
1610  Label_t & labels)
1611  {
1612  if(split.createNode().typeID() == i_ThresholdNode)
1613  {
1614  double wgini;
1615  tmp_cc.init(0);
1616  for(int ii = 0; ii < parent.size(); ++ii)
1617  {
1618  tmp_labels[parent[ii]]
1619  = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1620  ++tmp_cc[tmp_labels[parent[ii]]];
1621  }
1622  double region_gini = bgfunc.loss_of_region(tmp_labels,
1623  parent.begin(),
1624  parent.end(),
1625  tmp_cc);
1626 
1627  int n = split.bestSplitColumn();
1628  ++numChoices[n];
1629  ++(*(numChoices.end()-1));
1630  //this functor does all the work
1631  for(int k = 0; k < features.shape(1); ++k)
1632  {
1633  bgfunc(columnVector(features, k),
1634  tmp_labels,
1635  parent.begin(), parent.end(),
1636  tmp_cc);
1637  wgini = (region_gini - bgfunc.min_gini_);
1638  gini_missc(n, k)
1639  += wgini;
1640  }
1641  for(int k = 0; k < 10; ++k)
1642  {
1643  bgfunc(columnVector(noise, k),
1644  tmp_labels,
1645  parent.begin(), parent.end(),
1646  tmp_cc);
1647  wgini = (region_gini - bgfunc.min_gini_);
1648  corr_noise(n, k)
1649  += wgini;
1650  }
1651 
1652  for(int k = 0; k < 10; ++k)
1653  {
1654  bgfunc(columnVector(noise_l, k),
1655  tmp_labels,
1656  parent.begin(), parent.end(),
1657  tmp_cc);
1658  wgini = (region_gini - bgfunc.min_gini_);
1659  corr_l(n, k)
1660  += wgini;
1661  }
1662  bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1663  wgini = (region_gini - bgfunc.min_gini_);
1664  gini_missc(n, columnCount(gini_missc)-1)
1665  += wgini;
1666 
1667  region_gini = split.region_gini_;
1668 #if 1
1669  Node<i_ThresholdNode> node(split.createNode());
1670  gini_missc(node.column(),
1671  node.column())
1672  +=split.region_gini_ - split.minGini();
1673 #endif
1674  for(int k = 0; k < 10; ++k)
1675  {
1676  split.bgfunc(columnVector(noise, k),
1677  labels,
1678  parent.begin(), parent.end(),
1679  parent.classCounts());
1681  k)
1682  += wgini;
1683  }
1684 #if 0
1685  for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1686  {
1687  wgini = region_gini - split.min_gini_[k];
1688 
1690  split.splitColumns[k])
1691  += wgini;
1692  }
1693 
1694  for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1695  {
1696  split.bgfunc(columnVector(features, split.splitColumns[k]),
1697  labels,
1698  parent.begin(), parent.end(),
1699  parent.classCounts());
1700  wgini = region_gini - split.bgfunc.min_gini_;
1702  split.splitColumns[k]) += wgini;
1703  }
1704 #endif
1705  // remember to partition the data according to the best.
1706  gini_missc(columnCount(gini_missc)-1,
1707  columnCount(gini_missc)-1)
1708  += region_gini;
1709  SortSamplesByDimensions<Feature_t>
1710  sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1711  std::partition(parent.begin(), parent.end(), sorter);
1712  }
1713  }
1714 };
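/** Example (editor's sketch): computing the feature similarity/distance matrices while
 * learning and reading them afterwards. Both matrices are
 * (numberOfFeatures + 1) x (numberOfFeatures + 1); the last row/column refers to the labels
 * themselves. rf, f, l are illustrative names.
 *
 * \code
 * visitors::CorrelationVisitor corr_v;
 * rf.learn(f, l, visitors::create_visitor(corr_v));
 * MultiArray<2, double> const & S = corr_v.similarity;
 * MultiArray<2, double> const & D = corr_v.distance;    // 1 - similarity
 * \endcode
 */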
1715 
1716 
1717 } // namespace visitors
1718 } // namespace rf
1719 } // namespace vigra
1720 
1721 //@}
1722 #endif // RF_VISITORS_HXX