Defer set-to-zero in triangular = product so that no aliasing issue occur in the common:
A.triangularView() = B*A.sefladjointView()*B.adjoint()
case that used to work in 3.2.
(grafted from 655ba783f8
)
			
			
This commit is contained in:
		
							parent
							
								
									582c96691b
								
							
						
					
					
						commit
						0eaff8fdf2
					
				| @ -543,7 +543,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat | ||||
| 
 | ||||
|     template<typename ProductType> | ||||
|     EIGEN_DEVICE_FUNC | ||||
|     EIGEN_STRONG_INLINE TriangularViewType& _assignProduct(const ProductType& prod, const Scalar& alpha); | ||||
|     EIGEN_STRONG_INLINE TriangularViewType& _assignProduct(const ProductType& prod, const Scalar& alpha, bool beta); | ||||
| }; | ||||
| 
 | ||||
| /***************************************************************************
 | ||||
| @ -950,8 +950,7 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::assign_ | ||||
|     if((dst.rows()!=dstRows) || (dst.cols()!=dstCols)) | ||||
|       dst.resize(dstRows, dstCols); | ||||
| 
 | ||||
|     dst.setZero(); | ||||
|     dst._assignProduct(src, 1); | ||||
|     dst._assignProduct(src, 1, 0); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| @ -962,7 +961,7 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::add_ass | ||||
|   typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType; | ||||
|   static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<Scalar,typename SrcXprType::Scalar> &) | ||||
|   { | ||||
|     dst._assignProduct(src, 1); | ||||
|     dst._assignProduct(src, 1, 1); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| @ -973,7 +972,7 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::sub_ass | ||||
|   typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType; | ||||
|   static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<Scalar,typename SrcXprType::Scalar> &) | ||||
|   { | ||||
|     dst._assignProduct(src, -1); | ||||
|     dst._assignProduct(src, -1, 1); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -199,7 +199,7 @@ struct general_product_to_triangular_selector; | ||||
| template<typename MatrixType, typename ProductType, int UpLo> | ||||
| struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true> | ||||
| { | ||||
|   static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) | ||||
|   static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) | ||||
|   { | ||||
|     typedef typename MatrixType::Scalar Scalar; | ||||
|      | ||||
| @ -217,6 +217,9 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true> | ||||
| 
 | ||||
|     Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); | ||||
| 
 | ||||
|     if(!beta) | ||||
|       mat.template triangularView<UpLo>().setZero(); | ||||
| 
 | ||||
|     enum { | ||||
|       StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor, | ||||
|       UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1, | ||||
| @ -244,7 +247,7 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true> | ||||
| template<typename MatrixType, typename ProductType, int UpLo> | ||||
| struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false> | ||||
| { | ||||
|   static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) | ||||
|   static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) | ||||
|   { | ||||
|     typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs; | ||||
|     typedef internal::blas_traits<Lhs> LhsBlasTraits; | ||||
| @ -260,6 +263,9 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false> | ||||
| 
 | ||||
|     typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); | ||||
| 
 | ||||
|     if(!beta) | ||||
|       mat.template triangularView<UpLo>().setZero(); | ||||
| 
 | ||||
|     enum { | ||||
|       IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0, | ||||
|       LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0, | ||||
| @ -286,11 +292,11 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false> | ||||
| 
 | ||||
| template<typename MatrixType, unsigned int UpLo> | ||||
| template<typename ProductType> | ||||
| TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha) | ||||
| TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta) | ||||
| { | ||||
|   eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols()); | ||||
|    | ||||
|   general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha); | ||||
|   general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta); | ||||
|    | ||||
|   return derived(); | ||||
| } | ||||
|  | ||||
| @ -62,6 +62,19 @@ template<typename Scalar> void mmtr(int size) | ||||
|   CHECK_MMTR(matc, Upper, -= (s*sqc).template triangularView<Upper>()*sqc); | ||||
|   CHECK_MMTR(matc, Lower, = (s*sqr).template triangularView<Lower>()*sqc); | ||||
|   CHECK_MMTR(matc, Upper, += (s*sqc).template triangularView<Lower>()*sqc); | ||||
| 
 | ||||
|   // check aliasing
 | ||||
|   ref2 = ref1 = matc; | ||||
|   ref1 = sqc.adjoint() * matc * sqc; | ||||
|   ref2.template triangularView<Upper>() = ref1.template triangularView<Upper>(); | ||||
|   matc.template triangularView<Upper>() = sqc.adjoint() * matc * sqc; | ||||
|   VERIFY_IS_APPROX(matc, ref2); | ||||
| 
 | ||||
|   ref2 = ref1 = matc; | ||||
|   ref1 = sqc * matc * sqc.adjoint(); | ||||
|   ref2.template triangularView<Lower>() = ref1.template triangularView<Lower>(); | ||||
|   matc.template triangularView<Lower>() = sqc * matc * sqc.adjoint(); | ||||
|   VERIFY_IS_APPROX(matc, ref2); | ||||
| } | ||||
| 
 | ||||
| void test_product_mmtr() | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Gael Guennebaud
						Gael Guennebaud