/*
    -- MAGMA (version 2.6.2) --
       Univ. of Tennessee, Knoxville
       Univ. of California, Berkeley
       Univ. of Colorado, Denver
       @date March 2022

       @author Hartwig Anzt

       @generated from sparse_hip/src/zcustomilu.cpp, normal z -> s, Mon Mar 21 16:53:02 2022
*/
#include "magmasparse_internal.h"

#include <stdlib.h>  // getenv

// Precision marker for this real (s) instance of the precision-generated
// source. Nothing in the code of this file tests it; presumably it is
// consumed by precision-conditional code elsewhere. Verify before removing.
#define REAL

/* hipSPARSE and hipBLAS spell the scalar type differently. Presumably in the
   z source this line maps one type name onto the other. After z -> s
   generation both names became `float`. A macro name is not re-expanded
   inside its own replacement, so `#define float float` maps `float` to
   itself and has no effect.
   NOTE(review): in C++, #define-ing a name identical to a keyword is not
   permitted in a translation unit that includes standard library headers;
   consider dropping this line. */
#ifdef MAGMA_HAVE_HIP
  #define float float
#endif

// TODO: make this version check more specific.
//
// The legacy solve-analysis info object was removed in CUDA 11 and has no
// hipSPARSE counterpart. In those builds, "creating" it is a no-op that still
// behaves as a single statement (safe in un-braced if/else).
// Otherwise the call is forwarded to the library function wrapped in
// CHECK_CUSPARSE. A macro name is not re-expanded inside its own
// replacement, so the inner call reaches the real function.
#if CUDA_VERSION >= 11000 || defined(MAGMA_HAVE_HIP)
#define cusparseCreateSolveAnalysisInfo(info) do { } while (0)
#else
#define cusparseCreateSolveAnalysisInfo(info)                                                   \
    CHECK_CUSPARSE( cusparseCreateSolveAnalysisInfo( info ))
#endif

// Stand-in for the legacy cusparseScsrsv_analysis in CUDA >= 11 / HIP builds.
// It queries the work-buffer size with hipsparseScsrsv2_bufferSize and runs
// hipsparseScsrsv2_analysis (level policy) on a freshly created csrsv2 info
// object.
//
// TODO: `info` is accepted only for call-site compatibility and is ignored.
// The local info object and work buffer are released before the macro ends,
// so the analysis result is not retained. To keep it, the caller would have
// to own and pass in the info object (and buffer).
//
// Status codes of the hipSPARSE calls are not checked. If the work buffer
// cannot be allocated, the analysis is skipped instead of being handed an
// invalid pointer.
// In other builds no macro is defined here and call sites use the library's
// own cusparseScsrsv_analysis.
#if CUDA_VERSION >= 11000 || defined(MAGMA_HAVE_HIP)
#define cusparseScsrsv_analysis(handle, trans, m, nnz, descr, val, row, col, info)              \
    do {                                                                                        \
        csrsv2Info_t linfo_ = NULL;                                                             \
        int bufsize_ = 0;                                                                       \
        void *buf_ = NULL;                                                                      \
        hipsparseCreateCsrsv2Info(&linfo_);                                                     \
        hipsparseScsrsv2_bufferSize((handle), (trans), (m), (nnz), (descr), (float*)(val),      \
                                    (row), (col), linfo_, &bufsize_);                           \
        /* buf_ stays NULL when no buffer is needed or the allocation fails */                  \
        if (bufsize_ > 0 && magma_malloc(&buf_, bufsize_) != MAGMA_SUCCESS)                     \
            buf_ = NULL;                                                                        \
        if (bufsize_ <= 0 || buf_ != NULL)                                                      \
            hipsparseScsrsv2_analysis((handle), (trans), (m), (nnz), (descr), (float*)(val),    \
                                      (row), (col), linfo_,                                     \
                                      HIPSPARSE_SOLVE_POLICY_USE_LEVEL, buf_);                  \
        if (buf_ != NULL)                                                                       \
            magma_free(buf_);                                                                   \
        /* release the per-call info object (previously leaked) */                              \
        hipsparseDestroyCsrsv2Info(linfo_);                                                     \
    } while (0)
#endif

/*
    Chooses the file that a custom ILU factor is read from.

    Copies into `path` (a buffer of `size` bytes) the value of the environment
    variable `envname` when it is set and non-empty, and `fallback` otherwise.

    Returns MAGMA_SUCCESS, or MAGMA_ERR_ILLEGAL_VALUE when the selected name
    does not fit into `path` together with its terminating NUL.
*/
static magma_int_t
magma_scustomilu_factor_path(
    char *path,
    size_t size,
    const char *envname,
    const char *fallback )
{
    const char *name = getenv( envname );
    if ( name == NULL || name[0] == '\0' ) {
        name = fallback;
    }

    int len = snprintf( path, size, "%s", name );
    if ( len < 0 || (size_t) len >= size ) {
        // Refuse to load a silently truncated file name.
        return MAGMA_ERR_ILLEGAL_VALUE;
    }
    return MAGMA_SUCCESS;
}

/**
    Purpose
    -------

    Reads in an Incomplete LU preconditioner.

    The L and U factors are read from files with magma_s_csr_mtx, L first.
    By default the files are
        /Users/hanzt0114cl306/work/matrices/matrices/ILUT_L.mtx
        /Users/hanzt0114cl306/work/matrices/matrices/ILUT_U.mtx
    These absolute paths only exist on the machine this routine was written
    on. Set the environment variables MAGMA_CUSTOMILU_L and MAGMA_CUSTOMILU_U
    to read the factors from other files.

    For each factor this routine:
      - transfers it to the device (precond->L / precond->U),
      - calls magma_sjacobisetup_diagscal on it (results in precond->d /
        precond->d2),
      - allocates a device vector of zeros with the factor's row count
        (precond->work1 / precond->work2).
    It then runs a triangular-solve analysis on each device factor. L is
    described as lower triangular with unit diagonal, U as upper triangular
    with non-unit diagonal.

    Arguments
    ---------

    @param[in]
    A           magma_s_matrix
                input matrix A (not referenced by this routine)
                
    @param[in]
    b           magma_s_matrix
                input RHS b (not referenced by this routine)

    @param[in,out]
    precond     magma_s_preconditioner*
                preconditioner parameters; L, U, d, d2, work1, work2 are
                written
                
    @param[in]
    queue       magma_queue_t
                Queue to execute in.

    @return     0 on success. Otherwise the error code of the first failing
                call, or MAGMA_ERR_ILLEGAL_VALUE if an environment-supplied
                file name is too long. On failure, members of precond that
                were already filled are left for the caller to free.

    @ingroup magmasparse_sgepr
    ********************************************************************/
extern "C"
magma_int_t
magma_scustomilusetup(
    magma_s_matrix A,
    magma_s_matrix b,
    magma_s_preconditioner *precond,
    magma_queue_t queue )
{
    magma_int_t info = 0;

    // All locals are declared before the first CHECK, because CHECK jumps
    // forward to `cleanup` and must not cross an initialization.
    hipsparseHandle_t cusparseHandle=NULL;
    hipsparseMatDescr_t descrL=NULL;
    hipsparseMatDescr_t descrU=NULL;
    
    magma_s_matrix hA={Magma_CSR};   // host staging copy of the factor being read
    char preconditionermatrix[255];
    
    // first L
    CHECK( magma_scustomilu_factor_path( preconditionermatrix,
                sizeof(preconditionermatrix), "MAGMA_CUSTOMILU_L",
                "/Users/hanzt0114cl306/work/matrices/matrices/ILUT_L.mtx" ));
    
    CHECK( magma_s_csr_mtx( &hA, preconditionermatrix , queue) );
    CHECK( magma_smtransfer( hA, &precond->L, Magma_CPU, Magma_DEV , queue ));
    // diagonal-scaling setup of L, stored in precond->d
    CHECK( magma_sjacobisetup_diagscal( precond->L, &precond->d, queue ));
    CHECK( magma_svinit( &precond->work1, Magma_DEV, hA.num_rows, 1, MAGMA_S_ZERO, queue ));

    // release the host copy of L so hA can be reused for U
    magma_smfree( &hA, queue );
    
    // now U
    CHECK( magma_scustomilu_factor_path( preconditionermatrix,
                sizeof(preconditionermatrix), "MAGMA_CUSTOMILU_U",
                "/Users/hanzt0114cl306/work/matrices/matrices/ILUT_U.mtx" ));

    CHECK( magma_s_csr_mtx( &hA, preconditionermatrix , queue) );
    CHECK( magma_smtransfer( hA, &precond->U, Magma_CPU, Magma_DEV , queue ));
    // diagonal-scaling setup of U, stored in precond->d2
    CHECK( magma_sjacobisetup_diagscal( precond->U, &precond->d2, queue ));
    CHECK( magma_svinit( &precond->work2, Magma_DEV, hA.num_rows, 1, MAGMA_S_ZERO, queue ));


    // CUSPARSE context //
    CHECK_CUSPARSE( hipsparseCreate( &cusparseHandle ));
    CHECK_CUSPARSE( hipsparseCreateMatDescr( &descrL ));
    CHECK_CUSPARSE( hipsparseSetMatType( descrL, HIPSPARSE_MATRIX_TYPE_TRIANGULAR ));
    CHECK_CUSPARSE( hipsparseSetMatDiagType( descrL, HIPSPARSE_DIAG_TYPE_UNIT ));
    CHECK_CUSPARSE( hipsparseSetMatIndexBase( descrL, HIPSPARSE_INDEX_BASE_ZERO ));
    CHECK_CUSPARSE( hipsparseSetMatFillMode( descrL, HIPSPARSE_FILL_MODE_LOWER ));
    cusparseCreateSolveAnalysisInfo( &precond->cuinfoL );
    cusparseScsrsv_analysis( cusparseHandle,
                             HIPSPARSE_OPERATION_NON_TRANSPOSE, precond->L.num_rows,
                             precond->L.nnz, descrL,
                             (float*)precond->L.val, precond->L.row, precond->L.col, 
                             precond->cuinfoL );
    
    
    CHECK_CUSPARSE( hipsparseCreateMatDescr( &descrU ));
    CHECK_CUSPARSE( hipsparseSetMatType( descrU, HIPSPARSE_MATRIX_TYPE_TRIANGULAR ));
    CHECK_CUSPARSE( hipsparseSetMatDiagType( descrU, HIPSPARSE_DIAG_TYPE_NON_UNIT ));
    CHECK_CUSPARSE( hipsparseSetMatIndexBase( descrU, HIPSPARSE_INDEX_BASE_ZERO ));
    CHECK_CUSPARSE( hipsparseSetMatFillMode( descrU, HIPSPARSE_FILL_MODE_UPPER ));
    cusparseCreateSolveAnalysisInfo( &precond->cuinfoU );
    cusparseScsrsv_analysis( cusparseHandle,
                             HIPSPARSE_OPERATION_NON_TRANSPOSE, precond->U.num_rows,
                             precond->U.nnz, descrU,
                             (float*)precond->U.val, precond->U.row, precond->U.col, 
                             precond->cuinfoU );

    
    cleanup:
    // Reached on success and on every CHECK failure. The handle and
    // descriptors are local to this routine and are always destroyed.
    hipsparseDestroy( cusparseHandle );
    hipsparseDestroyMatDescr( descrL );
    hipsparseDestroyMatDescr( descrU );
    cusparseHandle=NULL;
    descrL=NULL;
    descrU=NULL;    
    magma_smfree( &hA, queue );
    
    return info;
}
    
