17#ifndef B2PASTIX_SOLVER_H_ 
   18#define B2PASTIX_SOLVER_H_ 
   19#include "b2sparse_solver.H" 
   26namespace b2000 { 
namespace b2linalg {
 
   // Generic declaration of the sequential PaStiX LDL^T solver for scalar
   // types T; the working implementation is the double specialization
   // further below. In the visible lines the matrix-setup overload and
   // update_value() have empty bodies.
   // NOTE(review): the constructor and solve bodies are not visible in this
   // listing; presumably they reject unsupported T (e.g. via
   // UnimplementedError) -- confirm.
   29class PASTIX_LDLt_seq_sparse_direct_solver : 
public LDLt_sparse_solver<T> {

   31    PASTIX_LDLt_seq_sparse_direct_solver() {

   36          size_t s, 
size_t nnz, 
const size_t* colind, 
const size_t* rowind, 
const T* value,

   37          const int connectivity, 
const Dictionary& dictionary) {}

   // No-op in this generic version.
   39    void update_value() {}

   // Solve overload (b/ldb in, x/ldx out); its body is not visible here.
   42          size_t s, 
size_t nrhs, 
const T* b, 
size_t ldb, T* x, 
size_t ldx,

   43          char left_or_right = 
' ') {
 
   // Generic declaration of the extension (bordered-system) variant of the
   // sequential PaStiX LDL^T solver for scalar types T; the working
   // implementation is the double specialization further below. In the
   // visible lines the matrix-setup overload and update_value() have empty
   // bodies.
   // NOTE(review): the constructor and solve bodies are not visible in this
   // listing; presumably they reject unsupported T -- confirm.
   49class PASTIX_LDLt_seq_extension_sparse_direct_solver : 
public LDLt_extension_sparse_solver<T> {

   51    PASTIX_LDLt_seq_extension_sparse_direct_solver() {

   56          size_t s, 
size_t nnz, 
const size_t* colind, 
const size_t* rowind, 
const T* value,

   57          size_t s_ext, 
const int connectivity, 
const Dictionary& dictionary) {}

   // No-op in this generic version.
   59    void update_value() {}

   // Solve overload with optional border data (ma, mb, mc default to null);
   // its body is not visible here.
   62          size_t s, 
size_t nrhs, 
const T* b, 
size_t ldb, T* x, 
size_t ldx, 
const T* ma = 0,

   63          const T* mb = 0, 
const T* mc = 0, 
char left_or_right = 
' ') {
 
   // double specialization: sequential PaStiX LDL^T sparse direct solver.
   // The constructor configures PaStiX for LDL^T factorization with one
   // thread and the sequential scheduler. init() (the name used by the
   // extension class below) copies the CSC structure into an spmatrix_t and
   // runs analysis plus numerical factorization. resolve() solves for nrhs
   // right-hand sides.
   // NOTE(review): the destructor calls pastixFinalize(&pastix_data), but no
   // copy/move control is visible in this listing. If the class is copyable,
   // two copies would finalize the same PaStiX instance -- confirm that copy
   // operations are deleted in the full source.
   107class PASTIX_LDLt_seq_sparse_direct_solver<double> : 
public LDLt_sparse_solver<double> {

   110    PASTIX_LDLt_seq_sparse_direct_solver() {

   // Start from the pastixInitParam() defaults, then override them:
   // LDL^T factorization, 1 thread, sequential scheduler.
   112        pastixInitParam(iparm, dparm);

   114        iparm[IPARM_FACTORIZATION] = PastixFactLDLT;

   117        iparm[IPARM_THREAD_NBR] = 1;

   120        iparm[IPARM_SCHEDULER] = PastixSchedSequential;

   // NOTE(review): negative magnitude-control epsilon. In PaStiX the sign
   // presumably selects whether the static-pivoting threshold is absolute or
   // relative to the matrix norm -- confirm against the linked PaStiX version.
   124        dparm[DPARM_EPSILON_MAGN_CTRL] = -1e-14;

   // Create the PaStiX instance on MPI_COMM_WORLD; it is released in the
   // destructor.
   126        pastixInit(&pastix_data, MPI_COMM_WORLD, iparm, dparm);

   130    ~PASTIX_LDLt_seq_sparse_direct_solver() { pastixFinalize(&pastix_data); }

   // init(): set up and factor the s x s symmetric matrix.
   //   colind: CSC column pointers (s + 1 entries), copied into colptr_loc
   //   rowind: row indices (nnz entries), copied into row
   //   value:  nonzero values, NOT copied -- spm->values aliases this array,
   //           so it must outlive the solver (presumably the caller refreshes
   //           it in place before calling update_value()).
   // An empty matrix (s == 0) is a no-op. The int vectors are passed to spm
   // by pointer, which presumably requires spm_int_t to be int (32-bit SPM
   // indices) -- confirm for the SPM build in use.
   133          size_t s, 
size_t nnz, 
const size_t* colind, 
const size_t* rowind, 
const double* value,

   134          const int connectivity, 
const Dictionary& dictionary) {

   135        if (s == 0) { 
return; }

   136        logging::Logger logger =

   140        colptr_loc.assign(colind, colind + s + 1);

   141        row.assign(rowind, rowind + nnz);

   144        spm = std::make_unique<spmatrix_t>();

   147        spm->values = 
const_cast<double*
>(value);

   153        spm->fmttype = SpmCSC;

   156        spm->mtxtype = SpmSymmetric;

   157        spm->colptr = colptr_loc.data();

   158        spm->rowptr = row.data();

   162        spm->flttype = SpmDouble;

   165        spmUpdateComputedFields(spm.get());

   // Debug-only (PASTIX_DEBUG_OUTPUT): check the matrix and append it to
   // "spm.out". Closing of myMatrix is not visible in this listing.
   167#ifdef PASTIX_DEBUG_OUTPUT 
   174        rc = spmCheckAndCorrect(spm.get(), &spm2);

   182        FILE* myMatrix = fopen(
"spm.out", 
"a");

   183        spmSave(spm.get(), myMatrix);

   // Symbolic analysis followed by numerical factorization. After this the
   // stored factors match the current values, so the dirty flag is cleared.
   191        pastix_task_analyze(pastix_data, spm.get());

   197        pastix_task_numfact(pastix_data, spm.get());

   199        updated_value = 
false;

   // Mark the values as changed. The refactorization is deferred to the next
   // resolve(), which re-runs pastix_task_numfact (presumably guarded by this
   // flag).
   202    void update_value() { updated_value = 
true; }

   // resolve(): solve A * X = B for nrhs right-hand sides of length s.
   //   b, ldb: input columns with leading dimension ldb
   //   x, ldx: output columns with leading dimension ldx (b == x is handled)
   //   left_or_right: not used in the lines visible here
   // An empty system (s == 0) is a no-op.
   205          size_t s, 
size_t nrhs, 
const double* b, 
size_t ldb, 
double* x, 
size_t ldx,

   206          char left_or_right = 
' ') {

   207        if (s == 0) { 
return; }

   208        logging::Logger logger =

   212            pastix_task_numfact(pastix_data, spm.get());

   213            updated_value = 
false;

   216#ifdef PASTIX_DEBUG_OUTPUT 
   219        FILE* myMatrix = fopen(
"rhs.out", 
"a");

   220        spmPrintRHS(spm.get(), nrhs, x, ldx, myMatrix);

   // Direct path: solve in place in x. Only the first s entries of b are
   // copied, which is sufficient only for a single column or when b == x.
   // NOTE(review): confirm that the guarding condition (not visible here)
   // enforces this.
   226            if (b != x) { std::copy(b, b + s, x); }

   227            pastix_task_solve(pastix_data, nrhs, x, ldx);

   // General path: gather each column from stride ldb into a buffer with
   // stride ldx, solve in the buffer, then scatter the s-entry results back
   // into x.
   230            std::vector<double> x_value(ldx * nrhs);

   232            for (
size_t i = 0; i != nrhs; ++i) {

   233                std::copy(b + i * ldb, b + i * ldb + s, x_value.begin() + i * ldx);

   235            pastix_task_solve(pastix_data, nrhs, x_value.data(), ldx);

   238            for (
size_t i = 0; i != nrhs; ++i) {

   239                std::copy(x_value.begin() + i * ldx, x_value.begin() + i * ldx + s, x + i * ldx);

   // PaStiX instance handle: created in the constructor, finalized in the
   // destructor.
   246    pastix_data_t* pastix_data{
nullptr};

   // Set by update_value(); cleared after each numerical factorization.
   247    bool updated_value{
false};

   // Local int copies of the CSC structure. spm points into these, so they
   // must live as long as spm.
   249    std::vector<int> colptr_loc;

   250    std::vector<int> row;

   252    std::unique_ptr<spmatrix_t> spm;

   // PaStiX integer/real parameter arrays passed to pastixInit().
   253    spm_int_t iparm[IPARM_SIZE];

   254    double dparm[DPARM_SIZE];
 
   // double specialization of the extension solver. The system is the
   // PaStiX-factored block A (order s - 1 in resolve()) bordered by one extra
   // row and column, and it is solved by block elimination:
   //   [ A    ma ] [x1]   [b1]
   //   [ mb^T mc ] [x2] = [b2]
   // This layout is inferred from the arithmetic in resolve(). It relies on A
   // being symmetric, which holds for the LDL^T factorization.
   258class PASTIX_LDLt_seq_extension_sparse_direct_solver<double>

   259    : 
public LDLt_extension_sparse_solver<double>,

   260      public PASTIX_LDLt_seq_sparse_direct_solver<double> {

   // div caches 1 / (Schur complement); it is 0 until first computed.
   262    PASTIX_LDLt_seq_extension_sparse_direct_solver()

   263        : PASTIX_LDLt_seq_sparse_direct_solver<double>(), div(0) {}

   // init(): factor the base block A through the parent init(), then size
   // m_ab for two work vectors of length size_.
   // NOTE(review): size_ext_ is not used in the visible lines, so only a
   // single bordering row/column appears to be supported -- confirm.
   266          size_t size_, 
size_t nnz_, 
const size_t* colind_, 
const size_t* rowind_,

   267          const double* value_, 
size_t size_ext_, 
const int connectivity,

   268          const Dictionary& dictionary) {

   271        PASTIX_LDLt_seq_sparse_direct_solver<double>::init(

   272              size_, nnz_, colind_, rowind_, value_, connectivity, dictionary);

   273        m_ab.resize(size_ * 2);

   // Forward to the base class; refactorization happens on the next solve.
   276    void update_value() { PASTIX_LDLt_seq_sparse_direct_solver<double>::update_value(); }

   // resolve(): s is the full system size, so A is (s - 1) x (s - 1).
   // left_or_right is not used in the lines visible here.
   279          size_t s, 
size_t nrhs, 
const double* b, 
size_t ldb, 
double* x, 
size_t ldx,

   280          const double* ma_ = 0, 
const double* mb_ = 0, 
const double* mc_ = 0,

   281          char left_or_right = 
' ') {

   // Border setup. The guarding condition is not visible here; presumably
   // this runs only when ma_/mb_/mc_ are supplied, and otherwise m_ab and div
   // from a previous call are reused.
   //   m_ab[0 .. s-2]    <- A^{-1} ma   (one in-place solve with 2 columns)
   //   m_ab[s-1 .. 2s-3] <- A^{-1} mb
   //   div = 1 / (mc - ma^T A^{-1} mb)   (inverse Schur complement)
   283            std::copy(ma_, ma_ + s - 1, &m_ab[0]);

   284            std::copy(mb_, mb_ + s - 1, &m_ab[s - 1]);

   285            PASTIX_LDLt_seq_sparse_direct_solver<double>::resolve(

   286                  s - 1, 2, &m_ab[0], s - 1, &m_ab[0], s - 1);

   287            div = 1 / (*mc_ - blas::dot(s - 1, ma_, 1, &m_ab[s - 1], 1));

   // For each right-hand side:
   //   x2 = (b2 - b1^T A^{-1} mb) * div   (stored in the last entry of x)
   //   x1 = A^{-1} b1 - x2 * A^{-1} ma
   290        for (
size_t i = 0; i != nrhs; ++i) {

   291            const double x2 = x[ldx * i + s - 1] =

   292                  div * (b[ldb * i + s - 1] - blas::dot(s - 1, b + ldb * i, 1, &m_ab[s - 1], 1));

   293            PASTIX_LDLt_seq_sparse_direct_solver<double>::resolve(

   294                  s - 1, 1, b + ldb * i, ldb, x + ldx * i, ldx);

   295            blas::axpy(s - 1, -x2, &m_ab[0], 1, x + ldx * i, 1);

   // Work storage holding A^{-1} ma followed by A^{-1} mb.
   300    std::vector<double> m_ab;
 
#define THROW
Definition b2exception.H:198
 
Logger & get_logger(const std::string &logger_name="")
Definition b2logging.H:829
 
Contains the base classes for implementing Finite Elements.
Definition b2boundary_condition.H:32
 
GenericException< UnimplementedError_name > UnimplementedError
Definition b2exception.H:314