Skip to content

Commit

Permalink
Do using std::sqrt; sqrt( x ); instead of std::sqrt( x );, and similar functions, when x is a user-defined type (scalar_t). Resolves #43.
Browse files Browse the repository at this point in the history
  • Loading branch information
mgates3 committed Nov 7, 2023
1 parent 3c47832 commit 577b5b1
Show file tree
Hide file tree
Showing 34 changed files with 113 additions and 62 deletions.
5 changes: 3 additions & 2 deletions include/blas/gemv.hh
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ void gemv(
blas::scalar_type<TA, TX, TY> beta,
TY *y, int64_t incy )
{
typedef blas::scalar_type<TA, TX, TY> scalar_t;
using std::swap;
using scalar_t = blas::scalar_type<TA, TX, TY>;

#define A(i_, j_) A[ (i_) + (j_)*lda ]

Expand Down Expand Up @@ -118,7 +119,7 @@ void gemv(
bool doconj = false;
if (layout == Layout::RowMajor) {
// A => A^T; A^T => A; A^H => A & conj
std::swap( m, n );
swap( m, n );
if (trans == Op::NoTrans) {
trans = Op::Trans;
}
Expand Down
5 changes: 3 additions & 2 deletions include/blas/hemm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ void hemm(
scalar_type<TA, TB, TC> beta,
TC *C, int64_t ldc )
{
typedef blas::scalar_type<TA, TB, TC> scalar_t;
using std::swap;
using scalar_t = blas::scalar_type<TA, TB, TC>;

#define A(i_, j_) A[ (i_) + (j_)*lda ]
#define B(i_, j_) B[ (i_) + (j_)*ldb ]
Expand Down Expand Up @@ -123,7 +124,7 @@ void hemm(
uplo = Uplo::Upper;
else if (uplo == Uplo::Upper)
uplo = Uplo::Lower;
std::swap( m, n );
swap( m, n );
}

// check remaining arguments
Expand Down
5 changes: 3 additions & 2 deletions include/blas/nrm2.hh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ nrm2(
int64_t n,
T const * x, int64_t incx )
{
typedef real_type<T> real_t;
using std::sqrt;
using real_t = real_type<T>;

// check arguments
blas_error_if( n < 0 ); // standard BLAS returns, doesn't fail
Expand All @@ -58,7 +59,7 @@ nrm2(
ix += incx;
}
}
return std::sqrt( result );
return sqrt( result );
}

} // namespace blas
Expand Down
6 changes: 4 additions & 2 deletions include/blas/swap.hh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ void swap(
TX *x, int64_t incx,
TY *y, int64_t incy )
{
using std::swap;

// check arguments
blas_error_if( n < 0 ); // standard BLAS returns, doesn't fail
blas_error_if( incx == 0 ); // standard BLAS doesn't detect inc[xy] == 0
Expand All @@ -50,15 +52,15 @@ void swap(
if (incx == 1 && incy == 1) {
// unit stride
for (int64_t i = 0; i < n; ++i) {
std::swap( x[i], y[i] );
swap( x[i], y[i] );
}
}
else {
// non-unit stride
int64_t ix = (incx > 0 ? 0 : (-n + 1)*incx);
int64_t iy = (incy > 0 ? 0 : (-n + 1)*incy);
for (int64_t i = 0; i < n; ++i) {
std::swap( x[ix], y[iy] );
swap( x[ix], y[iy] );
ix += incx;
iy += incy;
}
Expand Down
5 changes: 3 additions & 2 deletions include/blas/symm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ void symm(
scalar_type<TA, TB, TC> beta,
TC *C, int64_t ldc )
{
typedef blas::scalar_type<TA, TB, TC> scalar_t;
using std::swap;
// Common scalar type of A, B, and C; must include TC so it matches the
// declared type of beta (scalar_type<TA, TB, TC>) and the typedef it replaces.
using scalar_t = blas::scalar_type<TA, TB, TC>;

#define A(i_, j_) A[ (i_) + (j_)*lda ]
#define B(i_, j_) B[ (i_) + (j_)*ldb ]
Expand Down Expand Up @@ -116,7 +117,7 @@ void symm(
uplo = Uplo::Upper;
else if (uplo == Uplo::Upper)
uplo = Uplo::Lower;
std::swap( m, n );
swap( m, n );
}

// check remaining arguments
Expand Down
5 changes: 3 additions & 2 deletions include/blas/trmm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ void trmm(
TA const *A, int64_t lda,
TB *B, int64_t ldb )
{
typedef blas::scalar_type<TA, TB> scalar_t;
using std::swap;
using scalar_t = blas::scalar_type<TA, TB>;

#define A(i_, j_) A[ (i_) + (j_)*lda ]
#define B(i_, j_) B[ (i_) + (j_)*ldb ]
Expand Down Expand Up @@ -129,7 +130,7 @@ void trmm(
uplo = Uplo::Upper;
else if (uplo == Uplo::Upper)
uplo = Uplo::Lower;
std::swap( m, n );
swap( m, n );
}

// check remaining arguments
Expand Down
5 changes: 3 additions & 2 deletions include/blas/trsm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ void trsm(
TA const *A, int64_t lda,
TB *B, int64_t ldb )
{
typedef blas::scalar_type<TA, TB> scalar_t;
using std::swap;
using scalar_t = blas::scalar_type<TA, TB>;

#define A(i_, j_) A[ (i_) + (j_)*lda ]
#define B(i_, j_) B[ (i_) + (j_)*ldb ]
Expand Down Expand Up @@ -134,7 +135,7 @@ void trsm(
uplo = Uplo::Upper;
else if (uplo == Uplo::Upper)
uplo = Uplo::Lower;
std::swap( m, n );
swap( m, n );
}

// check remaining arguments
Expand Down
6 changes: 4 additions & 2 deletions include/blas/util.hh
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,15 @@ private:
template <typename T>
T abs1( T x )
{
return std::abs( x );
using std::abs;
return abs( x );
}

template <typename T>
T abs1( std::complex<T> x )
{
return std::abs( real(x) ) + std::abs( imag(x) );
using std::abs;
return abs( real( x ) ) + abs( imag( x ) );
}

// -----------------------------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion src/device_batch_trsm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ void trsm(
#ifndef BLAS_HAVE_DEVICE
throw blas::Error( "device BLAS not available", __func__ );
#else
using std::swap;

blas_error_if( layout != Layout::ColMajor && layout != Layout::RowMajor );
blas_error_if( batch_size < 0 );
blas_error_if( info.size() != 0
Expand Down Expand Up @@ -81,7 +83,7 @@ void trsm(
// swap lower <=> upper, left <=> right, m <=> n
uplo_ = ( uplo_ == blas::Uplo::Lower ? blas::Uplo::Upper : blas::Uplo::Lower );
side_ = ( side_ == blas::Side::Left ? blas::Side::Right : blas::Side::Left );
std::swap( m_, n_ );
swap( m_, n_ );
}

// trsm needs only 2 ptr arrays (A and B). Allocate usual
Expand Down
4 changes: 3 additions & 1 deletion src/device_hemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void hemm(
#ifndef BLAS_HAVE_DEVICE
throw blas::Error( "device BLAS not available", __func__ );
#else
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -72,7 +74,7 @@ void hemm(
// swap left <=> right, lower <=> upper, m <=> n
side = (side == Side::Left ? Side::Right : Side::Left);
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
std::swap( m_, n_ );
swap( m_, n_ );
}

blas::internal_set_device( queue.device() );
Expand Down
4 changes: 3 additions & 1 deletion src/device_symm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void symm(
#ifndef BLAS_HAVE_DEVICE
throw blas::Error( "device BLAS not available", __func__ );
#else
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -72,7 +74,7 @@ void symm(
// swap left <=> right, lower <=> upper, m <=> n
side = (side == Side::Left ? Side::Right : Side::Left);
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
std::swap( m_, n_ );
swap( m_, n_ );
}

blas::internal_set_device( queue.device() );
Expand Down
4 changes: 3 additions & 1 deletion src/device_trmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void trmm(
#ifndef BLAS_HAVE_DEVICE
throw blas::Error( "device BLAS not available", __func__ );
#else
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -70,7 +72,7 @@ void trmm(
// swap lower <=> upper, left <=> right, m <=> n
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
side = (side == Side::Left ? Side::Right : Side::Left);
std::swap( m_, n_ );
swap( m_, n_ );
}

blas::internal_set_device( queue.device() );
Expand Down
4 changes: 3 additions & 1 deletion src/device_trsm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void trsm(
#ifndef BLAS_HAVE_DEVICE
throw blas::Error( "device BLAS not available", __func__ );
#else
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -70,7 +72,7 @@ void trsm(
// swap lower <=> upper, left <=> right, m <=> n
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
side = (side == Side::Left ? Side::Right : Side::Left);
std::swap( m_, n_ );
swap( m_, n_ );
}

blas::internal_set_device( queue.device() );
Expand Down
4 changes: 3 additions & 1 deletion src/gemv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ void gemv(
scalar_t beta,
scalar_t* y, int64_t incy )
{
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -158,7 +160,7 @@ void gemv(
}
}
// A => A^T; A^T => A; A^H => A + conj
std::swap( m_, n_ );
swap( m_, n_ );
trans2 = (trans == Op::NoTrans ? Op::Trans : Op::NoTrans);
}
char trans_ = op2char( trans2 );
Expand Down
4 changes: 3 additions & 1 deletion src/hemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ void hemm(
scalar_t beta,
scalar_t* C, int64_t ldc )
{
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -115,7 +117,7 @@ void hemm(
// swap left <=> right, lower <=> upper, m <=> n
side = (side == Side::Left ? Side::Right : Side::Left);
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
std::swap( m_, n_ );
swap( m_, n_ );
}
char side_ = side2char( side );
char uplo_ = uplo2char( uplo );
Expand Down
4 changes: 3 additions & 1 deletion src/symm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ void symm(
scalar_t beta,
scalar_t* C, int64_t ldc )
{
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -149,7 +151,7 @@ void symm(
// swap left <=> right, lower <=> upper, m <=> n
side = (side == Side::Left ? Side::Right : Side::Left);
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
std::swap( m_, n_ );
swap( m_, n_ );
}
char side_ = side2char( side );
char uplo_ = uplo2char( uplo );
Expand Down
4 changes: 3 additions & 1 deletion src/trmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ void trmm(
scalar_t const* A, int64_t lda,
scalar_t* B, int64_t ldb )
{
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -143,7 +145,7 @@ void trmm(
// swap lower <=> upper, left <=> right, m <=> n
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
side = (side == Side::Left ? Side::Right : Side::Left);
std::swap( m_, n_ );
swap( m_, n_ );
}
char side_ = side2char( side );
char uplo_ = uplo2char( uplo );
Expand Down
4 changes: 3 additions & 1 deletion src/trsm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ void trsm(
scalar_t const* A, int64_t lda,
scalar_t* B, int64_t ldb )
{
using std::swap;

// check arguments
blas_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
Expand Down Expand Up @@ -143,7 +145,7 @@ void trsm(
// swap lower <=> upper, left <=> right, m <=> n
uplo = (uplo == Uplo::Lower ? Uplo::Upper : Uplo::Lower);
side = (side == Side::Left ? Side::Right : Side::Left);
std::swap( m_, n_ );
swap( m_, n_ );
}
char side_ = side2char( side );
char uplo_ = uplo2char( uplo );
Expand Down
22 changes: 13 additions & 9 deletions test/check_gemm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ void check_gemm(
#define C(i_, j_) C[ (i_) + (j_)*ldc ]
#define Cref(i_, j_) Cref[ (i_) + (j_)*ldcref ]

typedef blas::real_type<T> real_t;
using std::sqrt;
using std::abs;
using real_t = blas::real_type<T>;

assert( m >= 0 );
assert( n >= 0 );
Expand All @@ -52,15 +54,15 @@ void check_gemm(
real_t work[1], Cout_norm;
Cout_norm = lapack_lange( "f", m, n, C, ldc, work );
error[0] = Cout_norm
/ (sqrt(real_t(k)+2)*std::abs(alpha)*Anorm*Bnorm
+ 2*std::abs(beta)*Cnorm);
/ (sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
+ 2 * abs( beta ) * Cnorm);
if (verbose) {
printf( "error: ||Cout||=%.2e / (sqrt(k=%lld + 2)"
" * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e"
" + 2 * |beta|=%.2e * ||C||=%.2e) = %.2e\n",
Cout_norm, llong( k ),
std::abs( alpha ), Anorm, Bnorm,
std::abs( beta ), Cnorm, error[0] );
abs( alpha ), Anorm, Bnorm,
abs( beta ), Cnorm, error[0] );
}

// complex needs extra factor; see Higham, 2002, sec. 3.6.
Expand Down Expand Up @@ -103,6 +105,8 @@ void check_herk(
#define C(i_, j_) C[ (i_) + (j_)*ldc ]
#define Cref(i_, j_) Cref[ (i_) + (j_)*ldcref ]

using std::sqrt;
using std::abs;
typedef blas::real_type<T> real_t;

assert( n >= 0 );
Expand Down Expand Up @@ -133,15 +137,15 @@ void check_herk(
real_t work[1], Cout_norm;
Cout_norm = lapack_lanhe( "f", uplo2str(uplo), n, C, ldc, work );
error[0] = Cout_norm
/ (sqrt(real_t(k)+2)*std::abs(alpha)*Anorm*Bnorm
+ 2*std::abs(beta)*Cnorm);
/ (sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
+ 2 * abs( beta ) * Cnorm);
if (verbose) {
printf( "error: ||Cout||=%.2e / (sqrt(k=%lld + 2)"
" * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e"
" + 2 * |beta|=%.2e * ||C||=%.2e) = %.2e\n",
Cout_norm, llong( k ),
std::abs( alpha ), Anorm, Bnorm,
std::abs( beta ), Cnorm, error[0] );
abs( alpha ), Anorm, Bnorm,
abs( beta ), Cnorm, error[0] );
}

// complex needs extra factor; see Higham, 2002, sec. 3.6.
Expand Down
Loading

0 comments on commit 577b5b1

Please sign in to comment.