From 9433df83a773d3ccfe0a481ae36e5e3a6e60fd50 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 1 Jul 2008 16:20:06 +0000 Subject: [PATCH] * resurected Flagged::_expression used to optimize m+=(a*b).lazy() (equivalent to the GEMM blas routine) * added a GEMM benchmark --- Eigen/src/Core/Flagged.h | 2 + bench/benchBlasGemm.cpp | 232 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 bench/benchBlasGemm.cpp diff --git a/Eigen/src/Core/Flagged.h b/Eigen/src/Core/Flagged.h index db8e2738e..2b26e3016 100644 --- a/Eigen/src/Core/Flagged.h +++ b/Eigen/src/Core/Flagged.h @@ -113,6 +113,8 @@ template clas m_matrix.const_cast_derived().template writePacket(index, x); } + const ExpressionType& _expression() const { return m_matrix; } + protected: ExpressionTypeNested m_matrix; }; diff --git a/bench/benchBlasGemm.cpp b/bench/benchBlasGemm.cpp new file mode 100644 index 000000000..d22af89da --- /dev/null +++ b/bench/benchBlasGemm.cpp @@ -0,0 +1,232 @@ + + +// #define EIGEN_DEFAULT_TO_ROW_MAJOR +#define _FLOAT + +#include +#include +#include "BenchTimer.h" + +// include the BLAS headers +#include +#include + +#ifdef _FLOAT +typedef float Scalar; +#define CBLAS_GEMM cblas_sgemm +#else +typedef double Scalar; +#define CBLAS_GEMM cblas_dgemm +#endif + + +typedef Eigen::Matrix MyMatrix; +void bench_eigengemm(MyMatrix& mc, const MyMatrix& ma, const MyMatrix& mb, int nbloops); +void bench_eigengemm_normal(MyMatrix& mc, const MyMatrix& ma, const MyMatrix& mb, int nbloops); +void check_product(int M, int N, int K); +void check_product(void); + +int main(int argc, char *argv[]) +{ + { + int aux; + asm( + "stmxcsr %[aux] \n\t" + "orl $32832, %[aux] \n\t" + "ldmxcsr %[aux] \n\t" + : : [aux] "m" (aux)); + } + + int nbtries=1, nbloops=1, M, N, K; + + if (argc==2) + { + if (std::string(argv[1])=="check") + check_product(); + else + M = N = K = atoi(argv[1]); + } + else if ((argc==3) && (std::string(argv[1])=="auto")) + { + M = N = K = atoi(argv[2]); + nbloops = 1000000000/(M*M*M); + if (nbloops<1) + nbloops = 1; + nbtries = 6; + } + else if (argc==4) + { + M = N = K = atoi(argv[1]); + nbloops = atoi(argv[2]); + nbtries = atoi(argv[3]); + } + else if (argc==6) + { + M = atoi(argv[1]); + N = atoi(argv[2]); + K = atoi(argv[3]); + nbloops = atoi(argv[4]); + nbtries = atoi(argv[5]); + } + else + { + std::cout << "Usage: " << argv[0] << " size nbloops nbtries\n"; + std::cout << "Usage: " << argv[0] << " M N K nbloops nbtries\n"; + exit(1); + } + + double nbmad = double(M) * double(N) * double(K) * double(nbloops); + + if (!(std::string(argv[1])=="auto")) + std::cout << M << " x " << N << " x " << K << "\n"; + + Scalar alpha, beta; + MyMatrix ma(M,K), mb(K,N), mc(M,N); + ma = MyMatrix::random(M,K); + mb = MyMatrix::random(K,N); + mc = MyMatrix::random(M,N); + + Eigen::BenchTimer timer; + + // we simply compute c += a*b, so: + alpha = 1; + beta = 1; + + // bench cblas + // ROWS_A, COLS_B, COLS_A, 1.0, A, COLS_A, B, COLS_B, 0.0, C, COLS_B); + if (!(std::string(argv[1])=="auto")) + { + timer.reset(); + for (uint k=0 ; k& >*//*(mc).operator+=( (ma * mb).lazy() );*/ + +// Flagged, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>( +// Product(ma, mb))); + #endif +} + +void bench_eigengemm_normal(MyMatrix& mc, const MyMatrix& ma, const MyMatrix& mb, int nbloops) +{ + for (uint j=0 ; j(ma,mb).lazy(); +} + +#define MYVERIFY(A,M) if (!(A)) { \ + std::cout << "FAIL: " << M << "\n"; \ + } + +void check_product(int M, int N, int K) +{ + MyMatrix ma(M,K), mb(K,N), mc(M,N), maT(K,M), mbT(N,K), meigen(M,N), mref(M,N); + ma = MyMatrix::random(M,K); + mb = MyMatrix::random(K,N); + maT = ma.transpose(); + mbT = mb.transpose(); + mc = MyMatrix::random(M,N); + + MyMatrix::Scalar eps = 1e-4; + + meigen = mref = mc; + CBLAS_GEMM(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, 1, ma.data(), M, mb.data(), K, 1, mref.data(), M); + meigen += ma * mb; + MYVERIFY(meigen.isApprox(mref, eps),". * ."); + +// meigen = mref = mc; +// CBLAS_GEMM(CblasColMajor, CblasTrans, CblasNoTrans, M, N, K, 1, maT.data(), K, mb.data(), K, 1, mref.data(), M); +// meigen += maT.transpose() * mb; +// MYVERIFY(meigen.isApprox(mref, eps),"T * ."); +// +// meigen = mref = mc; +// CBLAS_GEMM(CblasColMajor, CblasTrans, CblasTrans, M, N, K, 1, maT.data(), K, mbT.data(), N, 1, mref.data(), M); +// meigen += (maT.transpose()) * (mbT.transpose()); +// MYVERIFY(meigen.isApprox(mref, eps),"T * T"); +// +// meigen = mref = mc; +// CBLAS_GEMM(CblasColMajor, CblasNoTrans, CblasTrans, M, N, K, 1, ma.data(), M, mbT.data(), N, 1, mref.data(), M); +// meigen += ma * mbT.transpose(); +// MYVERIFY(meigen.isApprox(mref, eps),". * T"); +} + +void check_product(void) +{ + int M, N, K; + for (uint i=0; i<1000; ++i) + { + M = ei_random(1,64); + N = ei_random(1,768); + K = ei_random(1,768); + M = (0 + M) * 1; + std::cout << M << " x " << N << " x " << K << "\n"; + check_product(M, N, K); + } +} +