Package smile.regression
Class NeuralNetwork
- java.lang.Object
-
- smile.regression.NeuralNetwork
-
- All Implemented Interfaces:
java.io.Serializable, OnlineRegression<double[]>, Regression<double[]>
public class NeuralNetwork extends java.lang.Object implements OnlineRegression<double[]>
Multilayer perceptron neural network for regression. An MLP consists of several layers of nodes, interconnected through weighted acyclic arcs from each preceding layer to the following, without lateral or feedback connections. Each node computes a transformed weighted linear combination of its inputs (the output activations of the preceding layer), with one of the weights acting as a trainable bias connected to a constant input. The transformation, called the activation function, is a bounded non-decreasing (non-linear) function such as the logistic sigmoid, which ranges from 0 to 1. Another popular activation function is the hyperbolic tangent, which has the same shape as the sigmoid but ranges from -1 to 1.
- Author:
- Sam Erickson
- See Also:
- Serialized Form
-
-
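The per-node computation described above (a weighted sum of inputs plus a bias, passed through a bounded activation) can be sketched in a few lines of standalone Java. The class name, weights, and input values below are illustrative only and are not part of the SMILE API:

```java
public class NodeDemo {
    // logistic sigmoid activation, bounded in (0, 1)
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // one node: bias plus weighted sum of inputs, then the activation
    static double node(double[] w, double bias, double[] x) {
        double z = bias;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        // z = 0.1 + 0.5*1.0 + (-0.25)*2.0 = 0.1, so output = sigmoid(0.1)
        double y = node(new double[]{0.5, -0.25}, 0.1, new double[]{1.0, 2.0});
        System.out.println(y);
    }
}
```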
Nested Class Summary
Nested Classes
Modifier and Type    Class    Description
static class    NeuralNetwork.ActivationFunction
static class    NeuralNetwork.Trainer    Trainer for neural networks.
-
Constructor Summary
Constructors
Constructor    Description
NeuralNetwork(int... numUnits)    Constructor.
NeuralNetwork(NeuralNetwork.ActivationFunction activation, int... numUnits)    Constructor.
NeuralNetwork(NeuralNetwork.ActivationFunction activation, double alpha, double lambda, int... numUnits)    Constructor.
-
Method Summary
All Methods  Instance Methods  Concrete Methods
Modifier and Type    Method    Description
NeuralNetwork    clone()
double    getLearningRate()    Returns the learning rate.
double    getMomentum()    Returns the momentum factor.
double[][]    getWeight(int layer)    Returns the weights of a layer.
double    getWeightDecay()    Returns the weight decay factor.
void    learn(double[][] x, double[] y)    Trains the neural network with the given dataset for one epoch by stochastic gradient descent.
void    learn(double[] x, double y)    Online update of the regression model with a new training instance.
double    learn(double[] x, double y, double weight)    Updates the neural network with the given instance and associated target value.
double    predict(double[] x)    Predicts the dependent variable of an instance.
void    setLearningRate(double eta)    Sets the learning rate.
void    setMomentum(double alpha)    Sets the momentum factor.
void    setWeightDecay(double lambda)    Sets the weight decay factor.
-
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
-
Methods inherited from interface smile.regression.Regression
predict
-
Constructor Detail
-
NeuralNetwork
public NeuralNetwork(int... numUnits)
Constructor. The default activation function is the logistic sigmoid function.
- Parameters:
numUnits - the number of units in each layer.
-
NeuralNetwork
public NeuralNetwork(NeuralNetwork.ActivationFunction activation, int... numUnits)
Constructor.
- Parameters:
activation - the activation function of the output layer.
numUnits - the number of units in each layer.
-
NeuralNetwork
public NeuralNetwork(NeuralNetwork.ActivationFunction activation, double alpha, double lambda, int... numUnits)
Constructor.
- Parameters:
activation - the activation function of the output layer.
alpha - the momentum factor.
lambda - the weight decay factor for regularization.
numUnits - the number of units in each layer.
-
-
Method Detail
-
clone
public NeuralNetwork clone()
- Overrides:
clone in class java.lang.Object
-
setLearningRate
public void setLearningRate(double eta)
Sets the learning rate.
- Parameters:
eta - the learning rate.
-
getLearningRate
public double getLearningRate()
Returns the learning rate.
-
setMomentum
public void setMomentum(double alpha)
Sets the momentum factor.
- Parameters:
alpha - the momentum factor.
-
getMomentum
public double getMomentum()
Returns the momentum factor.
-
setWeightDecay
public void setWeightDecay(double lambda)
Sets the weight decay factor. After each weight update, every weight is simply "decayed" or shrunk according to w = w * (1 - eta * lambda), where eta is the learning rate.
- Parameters:
lambda - the weight decay factor for regularization.
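The shrink rule above can be checked with a short standalone Java sketch; the class name and the sample values of w, eta, and lambda are illustrative assumptions:

```java
public class WeightDecayDemo {
    // applies one decay step: w = w * (1 - eta * lambda)
    static double decay(double w, double eta, double lambda) {
        return w * (1.0 - eta * lambda);
    }

    public static void main(String[] args) {
        // with eta = 0.1 and lambda = 0.01, each update shrinks a weight by 0.1%
        System.out.println(decay(10.0, 0.1, 0.01));
    }
}
```

Note that the decay is multiplicative, so a zero learning rate or zero lambda leaves the weight unchanged.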
-
getWeightDecay
public double getWeightDecay()
Returns the weight decay factor.
-
getWeight
public double[][] getWeight(int layer)
Returns the weights of a layer.
- Parameters:
layer - the layer of the neural network, 0 for the input layer.
-
predict
public double predict(double[] x)
Description copied from interface: Regression
Predicts the dependent variable of an instance.
- Specified by:
predict in interface Regression<double[]>
- Parameters:
x - the instance.
- Returns:
the predicted value of the dependent variable.
-
learn
public double learn(double[] x, double y, double weight)
Updates the neural network with the given instance and associated target value. Note that this method is NOT multi-thread safe.
- Parameters:
x - the training instance.
y - the target value.
weight - a positive weight value associated with the training instance.
- Returns:
the weighted training error before back-propagation.
-
learn
public void learn(double[] x, double y)
Description copied from interface: OnlineRegression
Online update of the regression model with a new training instance. In general, this method may NOT be multi-thread safe.
- Specified by:
learn in interface OnlineRegression<double[]>
- Parameters:
x - the training instance.
y - the response variable.
-
learn
public void learn(double[][] x, double[] y)
Trains the neural network with the given dataset for one epoch by stochastic gradient descent.
- Parameters:
x - the training instances.
y - the response variables.
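A typical use of the constructors and methods documented above might look like the following sketch. The layer sizes, hyperparameter values, data arrays, and epoch count are illustrative assumptions, not values prescribed by the library:

```java
import smile.regression.NeuralNetwork;

// Illustrative only: a 2-input network with one hidden layer of
// 10 units and a single output unit.
double[][] x = {{0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}};
double[] y = {1.0, 1.0, 0.0};

NeuralNetwork net = new NeuralNetwork(2, 10, 1);
net.setLearningRate(0.1);
net.setMomentum(0.5);
net.setWeightDecay(1e-4);

// each call to learn(double[][], double[]) runs one epoch of SGD
for (int epoch = 0; epoch < 25; epoch++) {
    net.learn(x, y);
}

double yhat = net.predict(new double[]{0.0, 1.0});
```

Since learn is not multi-thread safe, a single training loop like this should own the model exclusively while it updates the weights.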
-
-