Package smile.regression
Class NeuralNetwork.Trainer
java.lang.Object
    smile.regression.RegressionTrainer<double[]>
        smile.regression.NeuralNetwork.Trainer
Enclosing class:
    NeuralNetwork
public static class NeuralNetwork.Trainer extends RegressionTrainer<double[]>
Trainer for neural networks.
Constructor Summary
Trainer(int... numUnits)
    Constructor. The default activation function is the logistic sigmoid function.
Trainer(NeuralNetwork.ActivationFunction activation, int... numUnits)
    Constructor with an explicit activation function for the output layer.
Method Summary
NeuralNetwork.Trainer  setLearningRate(double eta)
    Sets the learning rate.
NeuralNetwork.Trainer  setMomentum(double alpha)
    Sets the momentum factor.
NeuralNetwork.Trainer  setNumEpochs(int epochs)
    Sets the number of epochs of stochastic learning.
NeuralNetwork.Trainer  setWeightDecay(double lambda)
    Sets the weight decay factor.
NeuralNetwork          train(double[][] x, double[] y)
    Learns a regression model with the given training data.

Methods inherited from class smile.regression.RegressionTrainer:
    setAttributes
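Every setter in the summary returns NeuralNetwork.Trainer, so calls can be chained and finished with train(x, y), which returns a NeuralNetwork. The sketch below illustrates that fluent pattern with a hypothetical stand-in class, FluentTrainerSketch, that only mirrors the setter signatures; it is not Smile's class, trains nothing, and its zero field values are placeholders rather than Smile's defaults.

```java
import java.util.Arrays;

/**
 * Hypothetical stand-in, NOT Smile's NeuralNetwork.Trainer. It mirrors the setter
 * signatures from the method summary to show why each one returns the trainer.
 */
class FluentTrainerSketch {
    private final int[] numUnits;
    // Zero placeholders; Smile's actual defaults are not stated on this page.
    private double eta;
    private double alpha;
    private double lambda;
    private int epochs;

    FluentTrainerSketch(int... numUnits) {
        this.numUnits = numUnits.clone();
    }

    FluentTrainerSketch setLearningRate(double eta) {
        this.eta = eta;
        return this; // returning the same object is what enables chaining
    }

    FluentTrainerSketch setMomentum(double alpha) {
        this.alpha = alpha;
        return this;
    }

    FluentTrainerSketch setWeightDecay(double lambda) {
        this.lambda = lambda;
        return this;
    }

    FluentTrainerSketch setNumEpochs(int epochs) {
        this.epochs = epochs;
        return this;
    }

    @Override
    public String toString() {
        return "layers=" + Arrays.toString(numUnits) + " eta=" + eta + " alpha=" + alpha
                + " lambda=" + lambda + " epochs=" + epochs;
    }

    public static void main(String[] args) {
        // With the real Trainer, a chain like this would end in .train(x, y).
        FluentTrainerSketch trainer = new FluentTrainerSketch(3, 10, 1)
                .setLearningRate(0.05)
                .setMomentum(0.9)
                .setWeightDecay(1e-4)
                .setNumEpochs(100);
        System.out.println(trainer);
    }
}
```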
Constructor Detail
Trainer
public Trainer(int... numUnits)
Constructor. The default activation function is the logistic sigmoid function.
Parameters:
    numUnits - the number of units in each layer.
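For reference, the logistic sigmoid named as the default is sigma(x) = 1 / (1 + e^(-x)), which squashes any real input into (0, 1). The minimal sketch below evaluates it together with its derivative written in terms of the output, s * (1 - s), the form commonly used in backpropagation. The class name LogisticSigmoid is hypothetical and this is not Smile's source code.

```java
/**
 * Illustrative sketch of the logistic sigmoid activation. Hypothetical class,
 * not taken from Smile.
 */
class LogisticSigmoid {

    /** Logistic sigmoid: 1 / (1 + e^(-x)), mapping any real x into (0, 1). */
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Derivative of the sigmoid, expressed through its output s = sigmoid(x). */
    static double sigmoidDerivativeFromOutput(double s) {
        return s * (1.0 - s);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-4.0, 0.0, 4.0}) {
            double s = sigmoid(x);
            System.out.printf("x=%5.1f  sigmoid=%.4f  slope=%.4f%n",
                    x, s, sigmoidDerivativeFromOutput(s));
        }
    }
}
```

The slope peaks at 0.25 when x = 0 and shrinks toward zero for large |x|, which is why inputs far from zero learn slowly through a sigmoid unit.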
Trainer
public Trainer(NeuralNetwork.ActivationFunction activation, int... numUnits)
Constructor.
Parameters:
    activation - the activation function of the output layer.
    numUnits - the number of units in each layer.
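To make numUnits concrete, the sketch below reads it as the layer sizes from input to output, for example {3, 10, 1} for three input features, ten hidden units, and one regression output. It then counts trainable parameters under the assumption of fully connected consecutive layers with one bias per non-input unit; that wiring is an assumption for illustration, not something this page states, and LayerShapes.countParameters is a hypothetical helper.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: treats numUnits as layer sizes (input, hidden..., output)
 * of a fully connected network and counts parameters, ASSUMING one bias per
 * non-input unit. Not part of Smile.
 */
class LayerShapes {

    static int countParameters(int... numUnits) {
        if (numUnits.length < 2) {
            throw new IllegalArgumentException("need at least an input and an output layer");
        }
        int total = 0;
        for (int l = 1; l < numUnits.length; l++) {
            // Each unit in layer l has one weight per unit in layer l-1, plus a bias.
            total += (numUnits[l - 1] + 1) * numUnits[l];
        }
        return total;
    }

    public static void main(String[] args) {
        int[] numUnits = {3, 10, 1}; // 3 inputs, 10 hidden units, 1 output
        System.out.println(Arrays.toString(numUnits) + " -> "
                + countParameters(numUnits) + " parameters");
    }
}
```

For {3, 10, 1} this gives (3 + 1) * 10 + (10 + 1) * 1 = 51 parameters under the stated assumption.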
Method Detail
setLearningRate
public NeuralNetwork.Trainer setLearningRate(double eta)
Sets the learning rate.
Parameters:
    eta - the learning rate.
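The learning rate eta scales each gradient step, w = w - eta * grad. The sketch below shows its effect on a hypothetical one-parameter loss f(w) = (w - 3)^2, whose gradient is 2(w - 3): a small eta converges slowly, a moderate one converges quickly, and for this particular loss any eta above 1 overshoots and diverges. The loss, starting point, and class name are illustrative assumptions, not Smile code.

```java
/**
 * Illustrative sketch: plain gradient descent w <- w - eta * grad(w) on the toy
 * loss f(w) = (w - 3)^2. Hypothetical example, not Smile's training loop.
 */
class LearningRateDemo {

    /** Gradient of f(w) = (w - 3)^2. */
    static double grad(double w) {
        return 2.0 * (w - 3.0);
    }

    /** Runs `steps` gradient-descent updates from w0 with learning rate eta. */
    static double descend(double w0, double eta, int steps) {
        double w = w0;
        for (int t = 0; t < steps; t++) {
            w -= eta * grad(w);
        }
        return w;
    }

    public static void main(String[] args) {
        for (double eta : new double[] {0.01, 0.1, 1.1}) {
            System.out.printf("eta=%.2f -> w after 50 steps = %.4f%n",
                    eta, descend(0.0, eta, 50));
        }
    }
}
```

Each step multiplies the distance to the minimum by (1 - 2 * eta), which explains all three behaviors.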
setMomentum
public NeuralNetwork.Trainer setMomentum(double alpha)
Sets the momentum factor.
Parameters:
    alpha - the momentum factor.
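The exact update rule behind the momentum factor is not spelled out here. The sketch below uses the classical heavy-ball form, v = alpha * v - eta * grad, w = w + v, on the same kind of hypothetical quadratic loss, so it should be read as the textbook technique rather than Smile's internal formula. With alpha = 0 it reduces to plain gradient descent, which makes the speed-up easy to compare.

```java
/**
 * Illustrative sketch of classical (heavy-ball) momentum on the toy loss
 * f(w) = (w - 3)^2. Textbook formulation; not necessarily Smile's exact rule.
 */
class MomentumDemo {

    /** Gradient of f(w) = (w - 3)^2. */
    static double grad(double w) {
        return 2.0 * (w - 3.0);
    }

    /** Gradient descent with momentum factor alpha; alpha = 0 gives plain descent. */
    static double descend(double w0, double eta, double alpha, int steps) {
        double w = w0;
        double v = 0.0; // velocity: exponentially decaying sum of past steps
        for (int t = 0; t < steps; t++) {
            v = alpha * v - eta * grad(w);
            w += v;
        }
        return w;
    }

    public static void main(String[] args) {
        double eta = 0.01;
        int steps = 500;
        System.out.printf("alpha=0.0 -> |w - 3| = %.3e%n",
                Math.abs(descend(0.0, eta, 0.0, steps) - 3.0));
        System.out.printf("alpha=0.9 -> |w - 3| = %.3e%n",
                Math.abs(descend(0.0, eta, 0.9, steps) - 3.0));
    }
}
```

With a small learning rate, the accumulated velocity lets the momentum run cover the same distance in far fewer steps than plain descent.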
setWeightDecay
public NeuralNetwork.Trainer setWeightDecay(double lambda)
Sets the weight decay factor. After each weight update, every weight is "decayed", or shrunk, according to w = w * (1 - eta * lambda), where eta is the learning rate.
Parameters:
    lambda - the weight decay for regularization.
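A minimal sketch of the decay rule w = w * (1 - eta * lambda), applied in place to a plain weight array. The class name and hyperparameter values are hypothetical; in actual training the decay would follow each gradient-based update, whereas here no gradient is applied so only the shrinkage toward zero is visible.

```java
import java.util.Arrays;

/**
 * Illustrative sketch of multiplicative weight decay. Hypothetical class,
 * not Smile code.
 */
class WeightDecayDemo {

    /** Applies one decay step in place: w = w * (1 - eta * lambda) for every weight. */
    static void decay(double[] weights, double eta, double lambda) {
        double factor = 1.0 - eta * lambda;
        for (int i = 0; i < weights.length; i++) {
            weights[i] *= factor;
        }
    }

    public static void main(String[] args) {
        double[] w = {1.0, -2.0, 0.5};
        double eta = 0.1, lambda = 0.01; // hypothetical hyperparameter values
        for (int step = 0; step < 100; step++) {
            // In real training this would follow a gradient step; none is applied here.
            decay(w, eta, lambda);
        }
        // Each weight is now scaled by (1 - 0.001)^100, roughly 0.905.
        System.out.println(Arrays.toString(w));
    }
}
```

After n updates with no gradient, a weight is scaled by (1 - eta * lambda)^n, so larger lambda pulls unused weights toward zero faster.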
setNumEpochs
public NeuralNetwork.Trainer setNumEpochs(int epochs)
Sets the number of epochs of stochastic learning.
Parameters:
    epochs - the number of epochs of stochastic learning.
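In stochastic learning, one epoch is one pass over the training set with a weight update after every sample. The sketch below illustrates this with a hypothetical one-feature linear model y ~ a * x + b standing in for the network, noise-free toy data on y = 2x + 1, and a fresh shuffle of the sample order each epoch. Whether Smile shuffles between epochs is not stated on this page, so treat that detail as an assumption.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Illustrative sketch of epochs of per-sample SGD on a linear model.
 * Model, data, and learning rate are hypothetical; this is not Smile code.
 */
class EpochsDemo {

    /** Returns {a, b} after `epochs` shuffled passes of per-sample SGD on squared error. */
    static double[] train(double[] x, double[] y, double eta, int epochs, long seed) {
        double a = 0.0, b = 0.0;
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < x.length; i++) {
            order.add(i);
        }
        Random rng = new Random(seed);
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order, rng); // new visiting order each epoch
            for (int i : order) {
                double err = (a * x[i] + b) - y[i]; // derivative of 0.5 * err^2 w.r.t. prediction
                a -= eta * err * x[i];
                b -= eta * err;
            }
        }
        return new double[] {a, b};
    }

    /** Mean squared error of the model {a, b} on the data. */
    static double mse(double[] x, double[] y, double[] ab) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double e = ab[0] * x[i] + ab[1] - y[i];
            sum += e * e;
        }
        return sum / x.length;
    }

    /** Noise-free toy data on y = 2x + 1 for x = 0.0, 0.1, ..., 1.0. */
    static double[][] toyData() {
        double[] x = new double[11];
        double[] y = new double[11];
        for (int i = 0; i < 11; i++) {
            x[i] = i / 10.0;
            y[i] = 2.0 * x[i] + 1.0;
        }
        return new double[][] {x, y};
    }

    public static void main(String[] args) {
        double[][] data = toyData();
        for (int epochs : new int[] {1, 10, 500}) {
            double[] ab = train(data[0], data[1], 0.1, epochs, 42L);
            System.out.printf("epochs=%3d  a=%.4f  b=%.4f  mse=%.2e%n",
                    epochs, ab[0], ab[1], mse(data[0], data[1], ab));
        }
    }
}
```

More epochs give the per-sample updates more passes to converge; on real, noisy data too many epochs can also overfit, which is one reason the count is exposed as a setting.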
train
public NeuralNetwork train(double[][] x, double[] y)
Description copied from class: RegressionTrainer
Learns a regression model with the given training data.
Specified by:
    train in class RegressionTrainer<double[]>
Parameters:
    x - the training instances.
    y - the training response values.
Returns:
    a trained regression model.