37 #ifndef VIGRA_REGRESSION_HXX
38 #define VIGRA_REGRESSION_HXX
41 #include "linear_solve.hxx"
42 #include "singular_value_decomposition.hxx"
43 #include "numerictraits.hxx"
44 #include "functorexpression.hxx"
// leastSquares(): least-squares solve for x from A and b (inferred from the
// name; the body is not visible). Declared, per the symbol summary at the end
// of this listing, as
//   bool leastSquares(MultiArrayView<2,T,C1> const & A, MultiArrayView<2,T,C2> const & b,
//                     MultiArrayView<2,T,C3> & x, std::string method = "QR")
// NOTE(review): partial listing -- the leading parameter lines and the whole
// function body are missing here, so the accepted `method` values and the
// meaning of the bool result cannot be confirmed from these lines.
78 template <
class T,
class C1,
class C2,
class C3>
// Solver selection string; defaults to "QR".
82 std::string method =
"QR")
// weightedLeastSquares(A, b, weights, x, method): least-squares solve in which
// each row k of A and b is first multiplied by a factor `w` derived from
// weights(k,0). The declaration is in the symbol summary at the end of this
// listing (returns bool, method defaults to "QR").
// NOTE(review): partial listing -- the lines computing `cols`, `rhsCount`,
// `w`, the allocation of wa/wb, and the final solve are missing here.
120 template <
class T,
class C1,
class C2,
class C3,
class C4>
128 const unsigned int rows =
rowCount(A);
// Shape checks: A must have rowCount >= columnCount, and b must have as many
// rows as A. The conditions of the weight and result shape checks are on
// lines not shown.
131 vigra_precondition(rows >= cols,
132 "weightedLeastSquares(): Input matrix A must be rectangular with rowCount >= columnCount.");
133 vigra_precondition(
rowCount(b) == rows,
134 "weightedLeastSquares(): Shape mismatch between matrices A and b.");
136 "weightedLeastSquares(): Weight matrix has wrong shape.");
138 "weightedLeastSquares(): Result matrix x has wrong shape.");
// Build the row-scaled system wa = w*A, wb = w*b one row at a time.
142 for(
unsigned int k=0; k<rows; ++k)
// NOTE(review): this test accepts weights == 0, but the message says
// "positive". Either the message should say "non-negative" or the test
// should be `> 0`.
144 vigra_precondition(weights(k,0) >= 0,
145 "weightedLeastSquares(): Weights must be positive.");
// `w` is assigned on a line missing from this listing. Presumably it is
// sqrt(weights(k,0)), so that wa^T*wa == A^T*W*A -- confirm.
147 for(
unsigned int l=0; l<cols; ++l)
148 wa(k,l) = w * A(k,l);
149 for(
unsigned int l=0; l<rhsCount; ++l)
150 wb(k,l) = w * b(k,l);
// ridgeRegression(A, b, x, lambda): ridge (Tikhonov) regression. The filter
// factor s/(s^2 + lambda) applied below is the SVD form of the minimizer of
// |A*x - b|^2 + lambda*|x|^2. The declaration is in the symbol summary at
// the end of this listing (returns bool).
// NOTE(review): partial listing -- the SVD call producing `rank`, `s` and `t`,
// the x shape-check condition, and the back-transformation into x are missing.
180 template <
class T,
class C1,
class C2,
class C3>
187 const unsigned int rows =
rowCount(A);
190 vigra_precondition(rows >= cols,
191 "ridgeRegression(): Input matrix A must be rectangular with rowCount >= columnCount.");
192 vigra_precondition(
rowCount(b) == rows,
193 "ridgeRegression(): Shape mismatch between matrices A and b.");
195 "ridgeRegression(): Result matrix x has wrong shape.");
196 vigra_precondition(lambda >= 0.0,
197 "ridgeRegression(): lambda >= 0.0 required.");
199 unsigned int m = rows;
200 unsigned int n = cols;
// Without regularization (lambda == 0) the solution is only unique when A
// has full column rank, so rank is compared against n (== cols). The
// action taken is on a missing line (presumably `return false`).
205 if(rank < n && lambda == 0.0)
// Apply the ridge filter factor s_k / (s_k^2 + lambda) to every right-hand
// side column. `s` presumably holds A's singular values (column vector).
209 for(
unsigned int k=0; k<cols; ++k)
210 for(
unsigned int l=0; l<rhsCount; ++l)
211 t(k,l) *= s(k,0) / (
sq(s(k,0)) + lambda);
// weightedRidgeRegression(A, b, weights, x, lambda): ridge regression on the
// row-scaled system (each row k of A and b multiplied by `w`, derived from
// weights(k,0)). The declaration is in the symbol summary at the end of this
// listing (returns bool).
// NOTE(review): partial listing -- `cols`, `rhsCount`, `w`, wa/wb allocation,
// and the final ridge solve are on lines missing here.
253 template <
class T,
class C1,
class C2,
class C3,
class C4>
261 const unsigned int rows =
rowCount(A);
// Shape and parameter checks (some conditions are on lines not shown).
264 vigra_precondition(rows >= cols,
265 "weightedRidgeRegression(): Input matrix A must be rectangular with rowCount >= columnCount.");
266 vigra_precondition(
rowCount(b) == rows,
267 "weightedRidgeRegression(): Shape mismatch between matrices A and b.");
269 "weightedRidgeRegression(): Weight matrix has wrong shape.");
271 "weightedRidgeRegression(): Result matrix x has wrong shape.");
272 vigra_precondition(lambda >= 0.0,
273 "weightedRidgeRegression(): lambda >= 0.0 required.");
// Build the row-scaled system wa = w*A, wb = w*b one row at a time.
277 for(
unsigned int k=0; k<rows; ++k)
// NOTE(review): this test accepts weights == 0, but the message says
// "positive". Either the message should say "non-negative" or the test
// should be `> 0`.
279 vigra_precondition(weights(k,0) >= 0,
280 "weightedRidgeRegression(): Weights must be positive.");
// `w` is assigned on a line missing from this listing. Presumably it is
// sqrt(weights(k,0)) -- confirm.
282 for(
unsigned int l=0; l<cols; ++l)
283 wa(k,l) = w * A(k,l);
284 for(
unsigned int l=0; l<rhsCount; ++l)
285 wb(k,l) = w * b(k,l);
// ridgeRegressionSeries(A, b, x, lambda): ridge regression for every value in
// the array `lambda` (anything with size() and operator[]). For each lambda[i]
// a filtered coefficient vector xt is computed from the single right-hand
// side xl. The declaration is in the symbol summary at the end of this
// listing (returns bool).
// NOTE(review): partial listing -- the SVD, the shape-check conditions, and
// the per-lambda write-back into x are missing here.
307 template <
class T,
class C1,
class C2,
class C3,
class Array>
314 const unsigned int rows =
rowCount(A);
316 const unsigned int lambdaCount = lambda.size();
317 vigra_precondition(rows >= cols,
318 "ridgeRegressionSeries(): Input matrix A must be rectangular with rowCount >= columnCount.");
320 "ridgeRegressionSeries(): Shape mismatch between matrices A and b.");
322 "ridgeRegressionSeries(): Result matrix x has wrong shape.");
324 unsigned int m = rows;
325 unsigned int n = cols;
333 for(
unsigned int i=0; i<lambdaCount; ++i)
335 vigra_precondition(lambda[i] >= 0.0,
336 "ridgeRegressionSeries(): lambda >= 0.0 required.");
// NOTE(review): likely bug. ridgeRegression() tests `rank < n` (n == cols),
// but this tests `rank < rows`. The rank of A is at most cols, so whenever
// rows > cols this is always true, and lambda == 0 is rejected even for a
// full-column-rank A. It should probably be `rank < n`.
337 if(lambda[i] == 0.0 && rank < rows)
// Ridge filter factor s_k / (s_k^2 + lambda[i]) applied to xl.
339 for(
unsigned int k=0; k<cols; ++k)
340 xt(k,0) = xl(k,0) * s(k,0) / (
sq(s(k,0)) + lambda[i]);
// Fragment of class LeastAngleRegressionOptions ("Pass options to
// leastAngleRegression()"). The class header, the `mode` member declaration
// and several member functions are on lines missing from this listing.
// Algorithm variant: plain LARS, LASSO modification, or non-negative LASSO.
354 enum Mode { LARS, LASSO, NNLASSO };
// Default constructor. max_solution_count 0 means "let the main loop pick a
// default". Least-squares solutions are requested by default. The initializer
// on the missing line 361 presumably sets `mode` -- confirm.
359 : max_solution_count(0),
360 unconstrained_dimension_count(0),
362 least_squares_solutions(true)
// maxSolutionCount(n): store the solution limit.
// NOTE(review): C-style cast of an unsigned argument to int; values above
// INT_MAX do not round-trip.
374 max_solution_count = (int)n;
// setMode(mode): map a mode string to a Mode; unrecognized strings fail.
390 else if(mode ==
"lasso")
392 else if(mode ==
"nnlasso")
395 vigra_fail(
"LeastAngleRegressionOptions.setMode(): Invalid mode.");
// leastSquaresSolutions(select): choose whether least-squares solutions
// (rather than LASSO/LARS solutions) are reported.
441 least_squares_solutions = select;
445 int max_solution_count, unconstrained_dimension_count;
447 bool least_squares_solutions;
// LarsData: working state for the LARS/LASSO iterations. It holds:
//   - the design matrix A and right-hand side b;
//   - R/qtb, which are updated by the Householder/column-swap helpers;
//   - the current LARS and next least-squares solutions and predictions;
//   - the column permutation, whose first activeSetSize entries are the
//     active set.
// NOTE(review): partial listing -- the struct header and the declarations of
// activeSetSize, A, b and columnPermutation are not visible here.
452 template <
class T,
class C1,
class C2>
460 Matrix<T> R, qtb, lars_solution, lars_prediction, next_lsq_solution, next_lsq_prediction, searchVector;
// Constructor from (Ai, bi): R and qtb are initialized from A and b.
// Solution vectors have A.shape(1) rows; predictions/searchVector have
// A.shape(0) rows. The column permutation starts as the identity.
466 A(Ai), b(bi), R(A), qtb(b),
467 lars_solution(A.shape(1), 1), lars_prediction(A.shape(0), 1),
468 next_lsq_solution(A.shape(1), 1), next_lsq_prediction(A.shape(0), 1), searchVector(A.shape(0), 1),
469 columnPermutation(A.shape(1))
471 for(
unsigned int k=0; k<columnPermutation.size(); ++k)
472 columnPermutation[k] = k;
// Sub-problem constructor. It restricts to the first asetSize columns of
// d.R, uses d.qtb as the right-hand side, and resets to an identity
// permutation over those columns.
// NOTE(review): the initializers of A and lars_solution read activeSetSize,
// and columnPermutation reads A. Members initialize in declaration order,
// so this relies on activeSetSize being declared before A, and A before
// columnPermutation. The declarations are not visible -- confirm.
476 LarsData(LarsData
const & d,
int asetSize)
477 : activeSetSize(asetSize),
478 A(d.R.subarray(Shape(0,0), Shape(d.A.shape(0), activeSetSize))), b(d.qtb), R(A), qtb(b),
479 lars_solution(d.lars_solution.subarray(Shape(0,0), Shape(activeSetSize, 1))), lars_prediction(d.lars_prediction),
480 next_lsq_solution(d.next_lsq_solution.subarray(Shape(0,0), Shape(activeSetSize, 1))),
481 next_lsq_prediction(d.next_lsq_prediction), searchVector(d.searchVector),
482 columnPermutation(A.shape(1))
484 for(
unsigned int k=0; k<columnPermutation.size(); ++k)
485 columnPermutation[k] = k;
// leastAngleRegressionMainLoop(): continue the LARS / LASSO / non-negative
// LASSO path from the active set already stored in `d`. Each step appends
// the active set to `activeSets` and, when the pointers are non-null, the
// LARS and least-squares solutions. Returns the number of solutions produced.
// NOTE(review): partial listing -- the return-type line, the `activeSets`
// parameter line, and many interior lines (declarations of cols, maxRank,
// currentSolutionCount, cmaxIndex, columnToBeAdded/Removed, the inner loop
// headers, and braces) are missing here.
489 template <
class T,
class C1,
class C2,
class Array1,
class Array2,
class Array3>
491 leastAngleRegressionMainLoop(LarsData<T, C1, C2> & d,
493 Array2 * lars_solutions, Array3 * lsq_solutions,
494 LeastAngleRegressionOptions
const & options)
496 using namespace vigra::functor;
498 typedef typename MultiArrayShape<2>::type Shape;
499 typedef typename Matrix<T>::view_type Subarray;
500 typedef ArrayVector<MultiArrayIndex> Permutation;
501 typedef typename Permutation::view_type ColumnSet;
503 vigra_precondition(d.activeSetSize > 0,
504 "leastAngleRegressionMainLoop() must not be called with empty active set.");
// NNLASSO also enforces non-negative coefficients. Both LASSO variants
// allow columns to leave the active set.
506 bool enforce_positive = (options.mode == LeastAngleRegressionOptions::NNLASSO);
507 bool lasso_modification = (options.mode != LeastAngleRegressionOptions::LARS);
// A solution limit of 0 means "use a default". The rest of this expression
// is on a missing line.
514 if(maxSolutionCount == 0)
515 maxSolutionCount = lasso_modification
519 bool needToRemoveColumn =
false;
522 while(currentSolutionCount < maxSolutionCount)
// Permutation entries of the columns not yet in the active set.
525 ColumnSet inactiveSet = d.columnPermutation.subarray((
unsigned int)d.activeSetSize, (
unsigned int)cols);
// Correlations of every column of A with two residuals: that of the current
// LARS prediction (cLARS) and that of the next least-squares prediction (cLSQ).
528 Matrix<T> cLARS =
transpose(d.A) * (d.b - d.lars_prediction),
529 cLSQ =
transpose(d.A) * (d.b - d.next_lsq_prediction);
537 : argMax(
abs(cLARS));
// C: magnitude of the correlation at cmaxIndex.
538 T C =
abs(cLARS(cmaxIndex, 0));
// ac(k): candidate step fraction for the k-th inactive column. The argMin
// is added next and its value becomes gamma.
540 Matrix<T> ac(cols - d.activeSetSize, 1);
543 T rho = cLSQ(inactiveSet[k], 0),
544 cc = C -
sign(rho)*cLARS(inactiveSet[k], 0);
549 ac(k,0) = cc / (cc + rho);
550 else if(enforce_positive)
553 ac(k,0) = cc / (cc - rho);
// NNLASSO: the column removed in the previous step gets fraction 1.0. This
// presumably keeps it from being re-selected immediately, and relies on it
// now sitting at the first inactive position -- confirm.
558 if(enforce_positive && needToRemoveColumn)
559 ac(columnToBeRemoved-d.activeSetSize,0) = 1.0;
564 columnToBeAdded =
argMin(ac);
// gamma: interpolation weight between the current LARS solution and the
// next least-squares solution (used in the updates below).
567 T
gamma = (d.activeSetSize == maxRank)
569 : ac(columnToBeAdded, 0);
// Convert the index into inactiveSet into a position in columnPermutation.
572 if(columnToBeAdded >= 0)
573 columnToBeAdded += d.activeSetSize;
576 needToRemoveColumn =
false;
577 if(lasso_modification)
// s(k): fraction at which active coefficient k changes sign (LASSO) or
// becomes negative (NNLASSO). Entries that never cross stay at max().
580 Matrix<T> s(Shape(d.activeSetSize, 1), NumericTraits<T>::max());
583 if(( enforce_positive && d.next_lsq_solution(k,0) < 0.0) ||
584 (!enforce_positive &&
sign(d.lars_solution(k,0))*
sign(d.next_lsq_solution(k,0)) == -1.0))
585 s(k,0) = d.lars_solution(k,0) / (d.lars_solution(k,0) - d.next_lsq_solution(k,0));
// If a crossing happens no later than gamma, remove the earliest such
// column and shorten the step to that crossing.
588 columnToBeRemoved =
argMinIf(s, Arg1() <= Param(gamma));
589 if(columnToBeRemoved >= 0)
591 needToRemoveColumn =
true;
592 gamma = s(columnToBeRemoved, 0);
// Move the LARS prediction/solution a fraction gamma toward the
// least-squares prediction/solution.
597 d.lars_prediction = gamma * d.next_lsq_prediction + (1.0 -
gamma) * d.lars_prediction;
598 d.lars_solution = gamma * d.next_lsq_solution + (1.0 - gamma) * d.lars_solution;
599 if(needToRemoveColumn)
// Set the leaving coefficient exactly to zero (presumably to drop
// round-off from the interpolation).
600 d.lars_solution(columnToBeRemoved, 0) = 0.0;
// Record the current active set (first activeSetSize permutation entries).
603 ++currentSolutionCount;
604 activeSets.push_back(
typename Array1::value_type(d.columnPermutation.begin(), d.columnPermutation.begin()+d.activeSetSize));
606 if(lsq_solutions != 0)
// Presumably the NNLASSO branch (its condition is on a missing line). It
// runs a nested NNLASSO on the reduced active-set problem and scatters
// that run's last solution into nnlsq_solution.
// NOTE(review): nnactiveSets.back() and nnresults.back() assume the nested
// run produced at least one solution; this is not checked here.
610 ArrayVector<Matrix<T> > nnresults;
611 ArrayVector<ArrayVector<MultiArrayIndex> > nnactiveSets;
612 LarsData<T, C1, C2> nnd(d, d.activeSetSize);
614 leastAngleRegressionMainLoop(nnd, nnactiveSets, &nnresults, (Array3*)0,
615 LeastAngleRegressionOptions().leastSquaresSolutions(
false).nnlasso());
617 typename Array2::value_type nnlsq_solution(Shape(d.activeSetSize, 1));
618 for(
unsigned int k=0; k<nnactiveSets.back().size(); ++k)
620 nnlsq_solution(nnactiveSets.back()[k],0) = nnresults.back()[k];
623 lsq_solutions->push_back(
typename Array3::value_type());
624 lsq_solutions->back() = nnlsq_solution;
// Otherwise report the unconstrained least-squares solution on the active set.
629 lsq_solutions->push_back(
typename Array3::value_type());
630 lsq_solutions->back() = d.next_lsq_solution.subarray(Shape(0,0), Shape(d.activeSetSize, 1));
633 if(lars_solutions != 0)
636 lars_solutions->push_back(
typename Array2::value_type());
637 lars_solutions->back() = d.lars_solution.subarray(Shape(0,0), Shape(d.activeSetSize, 1));
// Downdate: move the removed column to position activeSetSize, updating
// R, qtb and the permutation, then zero its coefficients.
644 if(needToRemoveColumn)
647 if(columnToBeRemoved != d.activeSetSize)
651 detail::upperTriangularSwapColumns(columnToBeRemoved, d.activeSetSize, d.R, d.qtb, d.columnPermutation);
654 std::swap(d.lars_solution(columnToBeRemoved, 0), d.lars_solution(d.activeSetSize,0));
655 std::swap(d.next_lsq_solution(columnToBeRemoved, 0), d.next_lsq_solution(d.activeSetSize,0));
656 columnToBeRemoved = d.activeSetSize;
658 d.lars_solution(d.activeSetSize,0) = 0.0;
659 d.next_lsq_solution(d.activeSetSize,0) = 0.0;
// Update: bring the new column to position activeSetSize, then apply one
// Householder step to R/qtb for that column.
663 vigra_invariant(columnToBeAdded >= 0,
664 "leastAngleRegression(): internal error (columnToBeAdded < 0)");
666 if(d.activeSetSize != columnToBeAdded)
668 std::swap(d.columnPermutation[d.activeSetSize], d.columnPermutation[columnToBeAdded]);
670 columnToBeAdded = d.activeSetSize;
674 d.next_lsq_solution(d.activeSetSize,0) = 0.0;
675 d.lars_solution(d.activeSetSize,0) = 0.0;
678 detail::qrColumnHouseholderStep(d.activeSetSize, d.R, d.qtb);
// Views of the active part of R, qtb and the next least-squares solution.
// The next LSQ prediction is rebuilt as a sum over the active columns of A.
683 Subarray Ractive = d.R.subarray(Shape(0,0), Shape(d.activeSetSize, d.activeSetSize));
684 Subarray qtbactive = d.qtb.subarray(Shape(0,0), Shape(d.activeSetSize, 1));
685 Subarray next_lsq_solution_view = d.next_lsq_solution.subarray(Shape(0,0), Shape(d.activeSetSize, 1));
689 d.next_lsq_prediction.init(0.0);
691 d.next_lsq_prediction += next_lsq_solution_view(k,0)*
columnVector(d.A, d.columnPermutation[k]);
694 return (
unsigned int)currentSolutionCount;
// leastAngleRegressionImpl(): shared implementation behind both public
// leastAngleRegression() overloads. It:
//   - selects an initial column and moves it to the front of the permutation;
//   - performs the first Householder step;
//   - initializes the one-column least-squares solution and prediction;
//   - hands off to leastAngleRegressionMainLoop().
// Null solution pointers mean "do not collect".
// NOTE(review): partial listing -- the return-type line, the shape-check
// condition, the computation of `initialColumn`, and the body of the
// `== -1` branch are missing here.
697 template <
class T,
class C1,
class C2,
class Array1,
class Array2>
699 leastAngleRegressionImpl(MultiArrayView<2, T, C1>
const & A, MultiArrayView<2, T, C2>
const &b,
700 Array1 & activeSets, Array2 * lasso_solutions, Array2 * lsq_solutions,
701 LeastAngleRegressionOptions
const & options)
703 using namespace vigra::functor;
708 "leastAngleRegression(): Shape mismatch between matrices A and b.");
// Not used on any visible line. Presumably consumed by the missing
// initial-column selection -- confirm.
710 bool enforce_positive = (options.mode == LeastAngleRegressionOptions::NNLASSO);
712 detail::LarsData<T, C1, C2> d(A, b);
719 if(initialColumn == -1)
// Put the initial column first in the permutation and start the QR
// factorization with it.
723 std::swap(d.columnPermutation[0], d.columnPermutation[initialColumn]);
725 detail::qrColumnHouseholderStep(0, d.R, d.qtb);
// One-column least-squares coefficient and its prediction
// x0 * A(:, initialColumn). searchVector starts equal to that prediction.
726 d.next_lsq_solution(0,0) = d.qtb(0,0) / d.R(0,0);
727 d.next_lsq_prediction = d.next_lsq_solution(0,0) *
columnVector(A, d.columnPermutation[0]);
728 d.searchVector = d.next_lsq_solution(0,0) *
columnVector(A, d.columnPermutation[0]);
730 return leastAngleRegressionMainLoop(d, activeSets, lasso_solutions, lsq_solutions, options);
// leastAngleRegression(A, b, activeSets, solutions, options): single-output
// overload. If options.least_squares_solutions is set (the default in the
// options constructor), `solutions` receives the least-squares solutions;
// otherwise it receives the LASSO/LARS solutions.
// NOTE(review): partial listing -- the return-type line and the A/b
// parameter line are missing here.
880 template <
class T,
class C1,
class C2,
class Array1,
class Array2>
883 Array1 & activeSets, Array2 & solutions,
884 LeastAngleRegressionOptions
const & options = LeastAngleRegressionOptions())
886 if(options.least_squares_solutions)
887 return detail::leastAngleRegressionImpl(A, b, activeSets, (Array2*)0, &solutions, options);
889 return detail::leastAngleRegressionImpl(A, b, activeSets, &solutions, (Array2*)0, options);
// leastAngleRegression(A, b, activeSets, lasso_solutions, lsq_solutions,
// options): two-output overload. It collects both the LASSO/LARS solutions
// and the least-squares solutions for every active set, regardless of
// options.least_squares_solutions.
// NOTE(review): partial listing -- the return-type line and the A/b
// parameter line are missing here.
892 template <
class T,
class C1,
class C2,
class Array1,
class Array2>
895 Array1 & activeSets, Array2 & lasso_solutions, Array2 & lsq_solutions,
896 LeastAngleRegressionOptions
const & options = LeastAngleRegressionOptions())
898 return detail::leastAngleRegressionImpl(A, b, activeSets, &lasso_solutions, &lsq_solutions, options);
// nonnegativeLeastSquares(A, b, x): non-negative least squares for a single
// right-hand side. The declaration is in the symbol summary at the end of
// this listing. x is zeroed, then the entries of the last computed solution
// are written to the row indices given by the last active set.
// NOTE(review): partial listing -- the shape-check conditions and the lines
// producing `activeSets` and `results` (presumably an NNLASSO
// leastAngleRegression run) are missing here.
919 template <
class T,
class C1,
class C2,
class C3>
925 "nonnegativeLeastSquares(): Matrix shape mismatch.");
927 "nonnegativeLeastSquares(): RHS and solution must be vectors (i.e. columnCount == 1).");
// Coefficients outside the final active set remain zero.
934 x.
init(NumericTraits<T>::zero());
// With no active set, x stays all-zero.
935 if(activeSets.
size() > 0)
936 for(
unsigned int k=0; k<activeSets.
back().size(); ++k)
937 x(activeSets.
back()[k],0) = results.
back()[k];
952 using linalg::LeastAngleRegressionOptions;
956 #endif // VIGRA_REGRESSION_HXX
LeastAngleRegressionOptions & maxSolutionCount(unsigned int n)
Definition: regression.hxx:372
bool ridgeRegression(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > &x, double lambda)
Definition: regression.hxx:182
bool weightedLeastSquares(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > const &weights, MultiArrayView< 2, T, C4 > &x, std::string method="QR")
Definition: regression.hxx:122
reference back()
Definition: array_vector.hxx:293
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:725
LeastAngleRegressionOptions & nnlasso()
Definition: regression.hxx:426
std::string tolower(std::string s)
Definition: utilities.hxx:93
Pass options to leastAngleRegression().
Definition: regression.hxx:351
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:669
void transpose(const MultiArrayView< 2, T, C1 > &v, MultiArrayView< 2, T, C2 > &r)
Definition: matrix.hxx:963
int argMin(MultiArrayView< 2, T, C > const &a)
Definition: matrix.hxx:1932
const difference_type & shape() const
Definition: multi_array.hxx:1602
linalg::TemporaryMatrix< T > sign(MultiArrayView< 2, T, C > const &v)
LeastAngleRegressionOptions & leastSquaresSolutions(bool select=true)
Definition: regression.hxx:439
int argMinIf(MultiArrayView< 2, T, C > const &a, UnaryFunctor condition)
Definition: matrix.hxx:2000
LeastAngleRegressionOptions & lars()
Definition: regression.hxx:404
int argMaxIf(MultiArrayView< 2, T, C > const &a, UnaryFunctor condition)
Definition: matrix.hxx:2035
std::ptrdiff_t MultiArrayIndex
Definition: multi_iterator.hxx:348
void nonnegativeLeastSquares(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > &x)
Definition: regression.hxx:921
bool ridgeRegressionSeries(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > &x, Array const &lambda)
Definition: regression.hxx:309
bool leastSquares(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > &x, std::string method="QR")
Definition: regression.hxx:80
LeastAngleRegressionOptions()
Definition: regression.hxx:358
linalg::TemporaryMatrix< T > sq(MultiArrayView< 2, T, C > const &v)
bool weightedRidgeRegression(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > const &weights, MultiArrayView< 2, T, C4 > &x, double lambda)
Definition: regression.hxx:255
LeastAngleRegressionOptions & setMode(std::string mode)
Definition: regression.hxx:385
Class for fixed size vectors. This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:939
int argMax(MultiArrayView< 2, T, C > const &a)
Definition: matrix.hxx:1965
double gamma(double x)
Definition: mathutil.hxx:1500
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:682
LeastAngleRegressionOptions & lasso()
Definition: regression.hxx:415
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1214
bool linearSolveUpperTriangular(const MultiArrayView< 2, T, C1 > &r, const MultiArrayView< 2, T, C2 > &b, MultiArrayView< 2, T, C3 > x)
Definition: linear_solve.hxx:1014
unsigned int singularValueDecomposition(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > &U, MultiArrayView< 2, T, C3 > &S, MultiArrayView< 2, T, C4 > &V)
Definition: singular_value_decomposition.hxx:75
size_type size() const
Definition: array_vector.hxx:330
linalg::TemporaryMatrix< T > abs(MultiArrayView< 2, T, C > const &v)
unsigned int leastAngleRegression(...)
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616
bool linearSolve(const MultiArrayView< 2, T, C1 > &A, const MultiArrayView< 2, T, C2 > &b, MultiArrayView< 2, T, C3 > &res, std::string method="QR")
Definition: linear_solve.hxx:1173