Class RandomForest

  • All Implemented Interfaces:
    java.io.Serializable, Regression<double[]>

    public class RandomForest
    extends java.lang.Object
    implements Regression<double[]>
    Random forest for regression. A random forest is an ensemble method that consists of many regression trees and outputs the average of the individual trees' predictions. The method combines bagging with the random selection of features.

    Each tree is constructed using the following algorithm:

    1. If the number of cases in the training set is N, randomly sample N cases with replacement from the original data. This sample will be the training set for growing the tree.
    2. If there are M input variables, a number m << M is specified such that at each node, m variables are selected at random out of the M and the best split on these m is used to split the node. The value of m is held constant during the forest growing.
    3. Each tree is grown to the largest extent possible. There is no pruning.
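The two sources of randomness in steps 1 and 2 can be sketched in plain Java. This is a minimal illustration under stated assumptions, not the library's implementation: the class `BootstrapSketch` and its method names are hypothetical, and the random split candidates are drawn with a partial Fisher-Yates shuffle chosen here for brevity.

```java
import java.util.Arrays;
import java.util.Random;

/** Illustrative sketch only; names are hypothetical, not part of the library. */
final class BootstrapSketch {
    /** Step 1: draw n row indices with replacement from 0..n-1. */
    static int[] bootstrap(int n, Random rng) {
        int[] sample = new int[n];
        for (int i = 0; i < n; i++) {
            sample[i] = rng.nextInt(n);
        }
        return sample;
    }

    /** Step 2: pick m distinct variable indices out of M for one node. */
    static int[] randomFeatures(int M, int m, Random rng) {
        int[] idx = new int[M];
        for (int i = 0; i < M; i++) {
            idx[i] = i;
        }
        // Partial Fisher-Yates shuffle: only the first m positions are needed.
        for (int i = 0; i < m; i++) {
            int j = i + rng.nextInt(M - i);
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        return Arrays.copyOf(idx, m);
    }
}
```

A tree builder would call `bootstrap` once per tree and `randomFeatures` once per node, then search for the best split among the returned variables only.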
    The advantages of random forest are:
    • For many data sets, it produces a highly accurate model.
    • It runs efficiently on large data sets.
    • It can handle thousands of input variables without variable deletion.
    • It gives estimates of which variables are important in the regression.
    • It generates an internal unbiased estimate of the generalization error as the forest building progresses.
    • It has an effective method for estimating missing data and maintains accuracy when a large proportion of the data are missing.
    The disadvantages are:
    • Random forests are prone to over-fitting for some datasets. This is even more pronounced in noisy classification/regression tasks.
    • For data including categorical variables with different numbers of levels, random forests are biased in favor of the attributes with more levels. Therefore, the variable importance scores from random forest are not reliable for this type of data.
    Author:
    Haifeng Li
    See Also:
    Serialized Form
    • Nested Class Summary

      Nested Classes 
      Modifier and Type Class Description
      static class  RandomForest.Trainer
      Trainer for random forest.
    • Constructor Summary

      Constructors 
      Constructor Description
      RandomForest​(double[][] x, double[] y, int ntrees)
      Constructor.
      RandomForest​(double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry)
      Constructor.
      RandomForest​(Attribute[] attributes, double[][] x, double[] y, int ntrees)
      Constructor.
      RandomForest​(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes)
      Constructor.
      RandomForest​(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize)
      Constructor.
      RandomForest​(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry)
      Constructor.
      RandomForest​(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample)
      Constructor.
      RandomForest​(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, double[] monotonicRegression)
      Constructor.
      RandomForest​(AttributeDataset data, int ntrees)
      Constructor.
      RandomForest​(AttributeDataset data, int ntrees, int maxNodes)
      Constructor.
      RandomForest​(AttributeDataset data, int ntrees, int maxNodes, int nodeSize)
      Constructor.
      RandomForest​(AttributeDataset data, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, double[] monotonicRegression)
      Constructor.
    • Method Summary

      All Methods Instance Methods Concrete Methods 
      Modifier and Type Method Description
      double error()
      Returns the out-of-bag estimate of RMSE.
      RegressionTree[] getTrees()
      Returns the regression trees.
      double[] importance()
      Returns the variable importance.
      RandomForest merge​(RandomForest other)
      Merges together two random forests and returns a new forest consisting of trees from both input forests.
      double predict​(double[] x)
      Predicts the dependent variable of an instance.
      int size()
      Returns the number of trees in the model.
      double[] test​(double[][] x, double[] y)
      Tests the model on a validation dataset.
      double[][] test​(double[][] x, double[] y, RegressionMeasure[] measures)
      Tests the model on a validation dataset.
      void trim​(int ntrees)
      Trims the tree model set to a smaller size in case of over-fitting.
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Constructor Detail

      • RandomForest

        public RandomForest​(double[][] x,
                            double[] y,
                            int ntrees)
        Constructor. Learns a random forest for regression.
        Parameters:
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
      • RandomForest

        public RandomForest​(double[][] x,
                            double[] y,
                            int ntrees,
                            int maxNodes,
                            int nodeSize,
                            int mtry)
        Constructor. Learns a random forest for regression.
        Parameters:
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
        mtry - the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
        nodeSize - the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
        maxNodes - the maximum number of leaf nodes in the tree.
      • RandomForest

        public RandomForest​(Attribute[] attributes,
                            double[][] x,
                            double[] y,
                            int ntrees)
        Constructor. Learns a random forest for regression.
        Parameters:
        attributes - the attribute properties.
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
      • RandomForest

        public RandomForest​(AttributeDataset data,
                            int ntrees)
        Constructor. Learns a random forest for regression.
        Parameters:
        data - the dataset.
        ntrees - the number of trees.
      • RandomForest

        public RandomForest​(Attribute[] attributes,
                            double[][] x,
                            double[] y,
                            int ntrees,
                            int maxNodes)
        Constructor. Learns a random forest for regression.
        Parameters:
        attributes - the attribute properties.
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
        maxNodes - the maximum number of leaf nodes in the tree.
      • RandomForest

        public RandomForest​(AttributeDataset data,
                            int ntrees,
                            int maxNodes)
        Constructor. Learns a random forest for regression.
        Parameters:
        data - the dataset.
        ntrees - the number of trees.
        maxNodes - the maximum number of leaf nodes in the tree.
      • RandomForest

        public RandomForest​(Attribute[] attributes,
                            double[][] x,
                            double[] y,
                            int ntrees,
                            int maxNodes,
                            int nodeSize)
        Constructor. Learns a random forest for regression.
        Parameters:
        attributes - the attribute properties.
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
        nodeSize - the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
        maxNodes - the maximum number of leaf nodes in the tree.
      • RandomForest

        public RandomForest​(AttributeDataset data,
                            int ntrees,
                            int maxNodes,
                            int nodeSize)
        Constructor. Learns a random forest for regression.
        Parameters:
        data - the dataset.
        ntrees - the number of trees.
        maxNodes - the maximum number of leaf nodes in the tree.
        nodeSize - the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
      • RandomForest

        public RandomForest​(Attribute[] attributes,
                            double[][] x,
                            double[] y,
                            int ntrees,
                            int maxNodes,
                            int nodeSize,
                            int mtry)
        Constructor. Learns a random forest for regression.
        Parameters:
        attributes - the attribute properties.
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
        mtry - the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
        nodeSize - the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
        maxNodes - the maximum number of leaf nodes in the tree.
      • RandomForest

        public RandomForest​(Attribute[] attributes,
                            double[][] x,
                            double[] y,
                            int ntrees,
                            int maxNodes,
                            int nodeSize,
                            int mtry,
                            double subsample)
        Constructor. Learns a random forest for regression.
        Parameters:
        attributes - the attribute properties.
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
        mtry - the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
        nodeSize - the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
        maxNodes - the maximum number of leaf nodes in the tree.
        subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.
      • RandomForest

        public RandomForest​(AttributeDataset data,
                            int ntrees,
                            int maxNodes,
                            int nodeSize,
                            int mtry,
                            double subsample,
                            double[] monotonicRegression)
        Constructor. Learns a random forest for regression.
        Parameters:
        data - the dataset.
        ntrees - the number of trees.
        mtry - the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
        nodeSize - the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
        maxNodes - the maximum number of leaf nodes in the tree.
        subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.
      • RandomForest

        public RandomForest​(Attribute[] attributes,
                            double[][] x,
                            double[] y,
                            int ntrees,
                            int maxNodes,
                            int nodeSize,
                            int mtry,
                            double subsample,
                            double[] monotonicRegression)
        Constructor. Learns a random forest for regression.
        Parameters:
        attributes - the attribute properties.
        x - the training instances.
        y - the response variable.
        ntrees - the number of trees.
        mtry - the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
        nodeSize - the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
        maxNodes - the maximum number of leaf nodes in the tree.
        subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.
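The two sampling modes described for the subsample parameter can be illustrated with a small standalone sketch. It assumes that a rate below 1.0 selects round(n * subsample) distinct rows; that rounding rule, the class name `SubsampleSketch`, and the count-array representation are assumptions made for illustration and are not confirmed by this page.

```java
import java.util.Random;

/** Illustrative sketch only; names and the rounding rule are assumptions. */
final class SubsampleSketch {
    /** Returns counts[i] = number of times row i is used to train one tree. */
    static int[] sampleCounts(int n, double subsample, Random rng) {
        int[] counts = new int[n];
        if (subsample == 1.0) {
            // Sampling with replacement: n draws, rows may repeat.
            for (int i = 0; i < n; i++) {
                counts[rng.nextInt(n)]++;
            }
        } else {
            // Sampling without replacement: shuffle, then keep a prefix.
            int[] perm = new int[n];
            for (int i = 0; i < n; i++) {
                perm[i] = i;
            }
            for (int i = n - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = perm[i];
                perm[i] = perm[j];
                perm[j] = tmp;
            }
            int size = (int) Math.round(n * subsample);
            for (int i = 0; i < size; i++) {
                counts[perm[i]] = 1;
            }
        }
        return counts;
    }
}
```

Rows whose count is zero are the out-of-bag rows for that tree.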
    • Method Detail

      • merge

        public RandomForest merge​(RandomForest other)
        Merges together two random forests and returns a new forest consisting of trees from both input forests.
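Conceptually, merging concatenates the two tree collections into a new one and leaves both inputs untouched. The sketch below shows only that idea with a generic list standing in for the trees; `MergeSketch` is a hypothetical name, and how the library combines other state (for example the out-of-bag error or importance scores) is not described on this page and is not modeled here.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative sketch only; a generic list stands in for the forest's trees. */
final class MergeSketch {
    /** Returns a new list holding the trees of a followed by the trees of b. */
    static <T> List<T> merge(List<T> a, List<T> b) {
        List<T> merged = new ArrayList<>(a.size() + b.size());
        merged.addAll(a);
        merged.addAll(b);
        return merged;
    }
}
```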
      • error

        public double error()
        Returns the out-of-bag (OOB) estimate of RMSE. The OOB estimate is quite accurate provided that enough trees have been grown; otherwise it can be biased upward.
        Returns:
        the out-of-bag estimate of RMSE
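As a rough illustration of how an out-of-bag RMSE can be computed, the sketch below averages, for each training row, the predictions of only those trees that did not see the row during training. The class `OobSketch` and the array layout are hypothetical; the library's internal bookkeeping may differ.

```java
/** Illustrative sketch only; array layout and names are assumptions. */
final class OobSketch {
    /**
     * @param inBag inBag[t][i] is true if row i was used to train tree t
     * @param pred  pred[t][i] is the prediction of tree t for row i
     * @param y     the observed responses
     * @return RMSE over rows that are out-of-bag for at least one tree
     */
    static double oobRmse(boolean[][] inBag, double[][] pred, double[] y) {
        double sse = 0.0;
        int counted = 0;
        for (int i = 0; i < y.length; i++) {
            double sum = 0.0;
            int votes = 0;
            for (int t = 0; t < pred.length; t++) {
                if (!inBag[t][i]) {
                    sum += pred[t][i];
                    votes++;
                }
            }
            if (votes > 0) {
                double residual = y[i] - sum / votes;
                sse += residual * residual;
                counted++;
            }
        }
        return counted == 0 ? Double.NaN : Math.sqrt(sse / counted);
    }
}
```

With few trees, each row is out-of-bag for only a handful of them, so its averaged prediction is noisy. This is one way to see why the estimate can be biased upward until enough trees have been grown.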
      • importance

        public double[] importance()
        Returns the variable importance. Every time a node is split on a variable, the impurity criterion for the two descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over all trees in the forest gives a fast measure of variable importance that is often very consistent with the permutation importance measure.
        Returns:
        the variable importance
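The accumulation described above can be sketched as follows. The example assumes the impurity criterion is the sum of squared errors around the node mean, which is a common choice for regression trees but is not stated on this page; `ImportanceSketch` and its methods are hypothetical names.

```java
/** Illustrative sketch only; SSE impurity is an assumption. */
final class ImportanceSketch {
    /** Sum of squared deviations of the responses from their mean. */
    static double sse(double[] y) {
        double mean = 0.0;
        for (double v : y) {
            mean += v;
        }
        mean /= y.length;
        double s = 0.0;
        for (double v : y) {
            s += (v - mean) * (v - mean);
        }
        return s;
    }

    /** Impurity decrease achieved by splitting parent into left and right. */
    static double decrease(double[] parent, double[] left, double[] right) {
        return sse(parent) - (sse(left) + sse(right));
    }

    /** Adds up the decreases per variable over all recorded splits. */
    static double[] importance(int p, int[] splitVariable, double[] decreases) {
        double[] imp = new double[p];
        for (int k = 0; k < splitVariable.length; k++) {
            imp[splitVariable[k]] += decreases[k];
        }
        return imp;
    }
}
```

For example, splitting responses {1, 2, 3, 4} into {1, 2} and {3, 4} lowers the SSE from 5 to 1, so that split contributes 4 to the importance of the variable it was made on.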
      • size

        public int size()
        Returns the number of trees in the model.
        Returns:
        the number of trees in the model
      • trim

        public void trim​(int ntrees)
        Trims the tree model set to a smaller size in case of over-fitting. Alternatively, if extra trees in the model do not improve performance, they can be removed to reduce the model size and speed up prediction.
        Parameters:
        ntrees - the new (smaller) size of tree model set.
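Trimming amounts to keeping only the first ntrees trees. The sketch below is a hypothetical standalone version that returns a shortened copy, whereas the documented method is void and modifies the model itself; the argument checks are an assumption about sensible behavior, not a statement of what the library does.

```java
import java.util.Arrays;

/** Illustrative sketch only; validation rules are assumptions. */
final class TrimSketch {
    /** Returns a copy holding only the first ntrees elements. */
    static <T> T[] trim(T[] trees, int ntrees) {
        if (ntrees <= 0 || ntrees > trees.length) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        return Arrays.copyOf(trees, ntrees);
    }
}
```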
      • predict

        public double predict​(double[] x)
        Description copied from interface: Regression
        Predicts the dependent variable of an instance.
        Specified by:
        predict in interface Regression<double[]>
        Parameters:
        x - the instance.
        Returns:
        the predicted value of the dependent variable.
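Since the forest outputs the average of its individual trees, prediction can be sketched as a plain mean over per-tree predictions. Here each tree is represented by a `ToDoubleFunction<double[]>` stand-in, and `PredictSketch` is a hypothetical name; this is not the library's code.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

/** Illustrative sketch only; functions stand in for regression trees. */
final class PredictSketch {
    /** Averages the predictions of all trees for instance x. */
    static double predict(List<ToDoubleFunction<double[]>> trees, double[] x) {
        double sum = 0.0;
        for (ToDoubleFunction<double[]> tree : trees) {
            sum += tree.applyAsDouble(x);
        }
        return sum / trees.size();
    }
}
```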
      • test

        public double[] test​(double[][] x,
                             double[] y)
        Tests the model on a validation dataset.
        Parameters:
        x - the test data set.
        y - the test data response values.
        Returns:
        the RMSEs obtained with the first 1, 2, ..., regression trees.
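The returned array can be read as a learning curve over the ensemble size. A minimal sketch of that computation, assuming per-tree predictions on the test set are already available as `pred[t][i]` (a hypothetical layout, as is the class name `CumulativeRmseSketch`):

```java
/** Illustrative sketch only; array layout and names are assumptions. */
final class CumulativeRmseSketch {
    /**
     * @param pred pred[t][i] is the prediction of tree t for test row i
     * @param y    the test responses
     * @return rmse[t] = RMSE when averaging trees 0..t
     */
    static double[] cumulativeRmse(double[][] pred, double[] y) {
        int n = y.length;
        double[] sum = new double[n];
        double[] rmse = new double[pred.length];
        for (int t = 0; t < pred.length; t++) {
            double sse = 0.0;
            for (int i = 0; i < n; i++) {
                sum[i] += pred[t][i];
                double residual = y[i] - sum[i] / (t + 1);
                sse += residual * residual;
            }
            rmse[t] = Math.sqrt(sse / n);
        }
        return rmse;
    }
}
```

If the curve flattens after some number of trees, that number is a natural candidate for the argument to trim.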
      • test

        public double[][] test​(double[][] x,
                               double[] y,
                               RegressionMeasure[] measures)
        Tests the model on a validation dataset.
        Parameters:
        x - the test data set.
        y - the test data output values.
        measures - the performance measures of regression.
        Returns:
        the performance measures obtained with the first 1, 2, ..., regression trees.
      • getTrees

        public RegressionTree[] getTrees()
        Returns the regression trees.