From f975b9bd3eb0a862efef290a63a3d1d20a03c099 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Fri, 30 Oct 2009 08:51:33 -0400 Subject: [PATCH] SVD::solve() : port to new API and improvements --- Eigen/src/LU/FullPivLU.h | 2 +- Eigen/src/SVD/SVD.h | 124 +++++++++++++++++++++++++++------------ test/svd.cpp | 4 +- 3 files changed, 89 insertions(+), 41 deletions(-) diff --git a/Eigen/src/LU/FullPivLU.h b/Eigen/src/LU/FullPivLU.h index a28a536b6..067b59549 100644 --- a/Eigen/src/LU/FullPivLU.h +++ b/Eigen/src/LU/FullPivLU.h @@ -200,7 +200,7 @@ template class FullPivLU return ei_fullpivlu_image_impl(*this, originalMatrix.derived()); } - /** This method returns a solution x to the equation Ax=b, where A is the matrix of which + /** \return a solution x to the equation Ax=b, where A is the matrix of which * *this is the LU decomposition. * * \param b the right-hand-side of the equation to solve. Can be a vector or a matrix, diff --git a/Eigen/src/SVD/SVD.h b/Eigen/src/SVD/SVD.h index da01cf396..807e7058c 100644 --- a/Eigen/src/SVD/SVD.h +++ b/Eigen/src/SVD/SVD.h @@ -25,6 +25,8 @@ #ifndef EIGEN_SVD_H #define EIGEN_SVD_H +template struct ei_svd_solve_impl; + /** \ingroup SVD_Module * \nonstableyet * @@ -40,24 +42,24 @@ */ template class SVD { - private: + public: typedef typename MatrixType::Scalar Scalar; typedef typename NumTraits::Real RealScalar; enum { + RowsAtCompileTime = MatrixType::RowsAtCompileTime, + ColsAtCompileTime = MatrixType::ColsAtCompileTime, PacketSize = ei_packet_traits::size, AlignmentMask = int(PacketSize)-1, - MinSize = EIGEN_ENUM_MIN(MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime) + MinSize = EIGEN_ENUM_MIN(RowsAtCompileTime, ColsAtCompileTime) }; - typedef Matrix ColVector; - typedef Matrix RowVector; + typedef Matrix ColVector; + typedef Matrix RowVector; - typedef Matrix MatrixUType; - typedef Matrix MatrixVType; - typedef Matrix SingularValuesType; - - public: + typedef Matrix MatrixUType; + typedef Matrix MatrixVType; + typedef Matrix SingularValuesType; /** * \brief Default Constructor. @@ -76,8 +78,24 @@ template class SVD compute(matrix); } - template - bool solve(const MatrixBase &b, ResultType* result) const; + /** \returns a solution of \f$ A x = b \f$ using the current SVD decomposition of A. + * + * \param b the right-hand-side of the equation to solve. + * + * \note_about_checking_solutions + * + * \note_about_arbitrary_choice_of_solution + * \note_about_using_kernel_to_study_multiple_solutions + * + * \sa MatrixBase::svd(), + */ + template + inline const ei_svd_solve_impl + solve(const MatrixBase& b) const + { + ei_assert(m_isInitialized && "SVD is not initialized."); + return ei_svd_solve_impl(*this, b.derived()); + } const MatrixUType& matrixU() const { @@ -108,6 +126,18 @@ template class SVD template void computeScalingRotation(ScalingType *positive, RotationType *unitary) const; + inline int rows() const + { + ei_assert(m_isInitialized && "SVD is not initialized."); + return m_rows; + } + + inline int cols() const + { + ei_assert(m_isInitialized && "SVD is not initialized."); + return m_cols; + } + protected: // Computes (a^2 + b^2)^(1/2) without destructive underflow or overflow. inline static Scalar pythag(Scalar a, Scalar b) @@ -133,6 +163,7 @@ template class SVD /** \internal */ SingularValuesType m_sigma; bool m_isInitialized; + int m_rows, m_cols; }; /** Computes / recomputes the SVD decomposition A = U S V^* of \a matrix @@ -144,8 +175,8 @@ template class SVD template SVD& SVD::compute(const MatrixType& matrix) { - const int m = matrix.rows(); - const int n = matrix.cols(); + const int m = m_rows = matrix.rows(); + const int n = m_cols = matrix.cols(); m_matU.resize(m, m); m_matU.setZero(); @@ -397,40 +428,57 @@ SVD& SVD::compute(const MatrixType& matrix) return *this; } -/** \returns the solution of \f$ A x = b \f$ using the current SVD decomposition of A. - * The parts of the solution corresponding to zero singular values are ignored. - * - * \sa MatrixBase::svd(), LU::solve(), LLT::solve() - */ -template -template -bool SVD::solve(const MatrixBase &b, ResultType* result) const +template +struct ei_traits > { - ei_assert(m_isInitialized && "SVD is not initialized."); + typedef Matrix ReturnMatrixType; +}; - const int rows = m_matU.rows(); - ei_assert(b.rows() == rows); +template +struct ei_svd_solve_impl : public ReturnByValue > +{ + typedef typename ei_cleantype::type RhsNested; + typedef SVD SVDType; + typedef typename MatrixType::RealScalar RealScalar; + typedef typename MatrixType::Scalar Scalar; + const SVDType& m_svd; + const typename Rhs::Nested m_rhs; - result->resize(m_matV.rows(), b.cols()); + ei_svd_solve_impl(const SVDType& svd, const Rhs& rhs) + : m_svd(svd), m_rhs(rhs) + {} - Scalar maxVal = m_sigma.cwise().abs().maxCoeff(); - for (int j=0; j void evalTo(Dest& dst) const { - Matrix aux = m_matU.transpose() * b.col(j); + ei_assert(m_rhs.rows() == m_svd.rows()); - for (int i = 0; i aux = m_svd.matrixU().adjoint() * m_rhs.col(j); - result->col(j) = m_matV * aux; + for (int i = 0; i void svd(const MatrixType& m) a += a * a.adjoint() + a1 * a1.adjoint(); } SVD svd(a); - svd.solve(b, &x); + x = svd.solve(b); VERIFY_IS_APPROX(a * x,b); } @@ -87,7 +87,7 @@ template void svd_verify_assert() MatrixType tmp; SVD svd; - VERIFY_RAISES_ASSERT(svd.solve(tmp, &tmp)) + VERIFY_RAISES_ASSERT(svd.solve(tmp)) VERIFY_RAISES_ASSERT(svd.matrixU()) VERIFY_RAISES_ASSERT(svd.singularValues()) VERIFY_RAISES_ASSERT(svd.matrixV())