Actual source code: baijsolvtrann.c

  1: #include <../src/mat/impls/baij/seq/baij.h>
  2: #include <petsc/private/kernels/blockinvert.h>

  4: /* ----------------------------------------------------------- */
  5: PetscErrorCode MatSolveTranspose_SeqBAIJ_N_inplace(Mat A,Vec bb,Vec xx)
  6: {
  7:   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
  8:   IS                iscol=a->col,isrow=a->row;
  9:   PetscErrorCode    ierr;
 10:   const PetscInt    *r,*c,*rout,*cout,*ai=a->i,*aj=a->j,*vi;
 11:   PetscInt          i,nz,j;
 12:   const PetscInt    n  =a->mbs,bs=A->rmap->bs,bs2=a->bs2;
 13:   const MatScalar   *aa=a->a,*v;
 14:   PetscScalar       *x,*t,*ls;
 15:   const PetscScalar *b;

 18:   VecGetArrayRead(bb,&b);
 19:   VecGetArray(xx,&x);
 20:   t    = a->solve_work;

 22:   ISGetIndices(isrow,&rout); r = rout;
 23:   ISGetIndices(iscol,&cout); c = cout;

 25:   /* copy the b into temp work space according to permutation */
 26:   for (i=0; i<n; i++) {
 27:     for (j=0; j<bs; j++) {
 28:       t[i*bs+j] = b[c[i]*bs+j];
 29:     }
 30:   }

 32:   /* forward solve the upper triangular transpose */
 33:   ls = a->solve_work + A->cmap->n;
 34:   for (i=0; i<n; i++) {
 35:     PetscArraycpy(ls,t+i*bs,bs);
 36:     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*a->diag[i],t+i*bs);
 37:     v  = aa + bs2*(a->diag[i] + 1);
 38:     vi = aj + a->diag[i] + 1;
 39:     nz = ai[i+1] - a->diag[i] - 1;
 40:     while (nz--) {
 41:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
 42:       v += bs2;
 43:     }
 44:   }

 46:   /* backward solve the lower triangular transpose */
 47:   for (i=n-1; i>=0; i--) {
 48:     v  = aa + bs2*ai[i];
 49:     vi = aj + ai[i];
 50:     nz = a->diag[i] - ai[i];
 51:     while (nz--) {
 52:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
 53:       v += bs2;
 54:     }
 55:   }

 57:   /* copy t into x according to permutation */
 58:   for (i=0; i<n; i++) {
 59:     for (j=0; j<bs; j++) {
 60:       x[bs*r[i]+j]   = t[bs*i+j];
 61:     }
 62:   }

 64:   ISRestoreIndices(isrow,&rout);
 65:   ISRestoreIndices(iscol,&cout);
 66:   VecRestoreArrayRead(bb,&b);
 67:   VecRestoreArray(xx,&x);
 68:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
 69:   return(0);
 70: }

 72: PetscErrorCode MatSolveTranspose_SeqBAIJ_N(Mat A,Vec bb,Vec xx)
 73: {
 74:   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
 75:   IS                iscol=a->col,isrow=a->row;
 76:   PetscErrorCode    ierr;
 77:   const PetscInt    *r,*c,*rout,*cout;
 78:   const PetscInt    n=a->mbs,*ai=a->i,*aj=a->j,*vi,*diag=a->diag;
 79:   PetscInt          i,j,nz;
 80:   const PetscInt    bs =A->rmap->bs,bs2=a->bs2;
 81:   const MatScalar   *aa=a->a,*v;
 82:   PetscScalar       *x,*t,*ls;
 83:   const PetscScalar *b;

 86:   VecGetArrayRead(bb,&b);
 87:   VecGetArray(xx,&x);
 88:   t    = a->solve_work;

 90:   ISGetIndices(isrow,&rout); r = rout;
 91:   ISGetIndices(iscol,&cout); c = cout;

 93:   /* copy the b into temp work space according to permutation */
 94:   for (i=0; i<n; i++) {
 95:     for (j=0; j<bs; j++) {
 96:       t[i*bs+j] = b[c[i]*bs+j];
 97:     }
 98:   }

100:   /* forward solve the upper triangular transpose */
101:   ls = a->solve_work + A->cmap->n;
102:   for (i=0; i<n; i++) {
103:     PetscArraycpy(ls,t+i*bs,bs);
104:     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*diag[i],t+i*bs);
105:     v  = aa + bs2*(diag[i] - 1);
106:     vi = aj + diag[i] - 1;
107:     nz = diag[i] - diag[i+1] - 1;
108:     for (j=0; j>-nz; j--) {
109:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
110:       v -= bs2;
111:     }
112:   }

114:   /* backward solve the lower triangular transpose */
115:   for (i=n-1; i>=0; i--) {
116:     v  = aa + bs2*ai[i];
117:     vi = aj + ai[i];
118:     nz = ai[i+1] - ai[i];
119:     for (j=0; j<nz; j++) {
120:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
121:       v += bs2;
122:     }
123:   }

125:   /* copy t into x according to permutation */
126:   for (i=0; i<n; i++) {
127:     for (j=0; j<bs; j++) {
128:       x[bs*r[i]+j]   = t[bs*i+j];
129:     }
130:   }

132:   ISRestoreIndices(isrow,&rout);
133:   ISRestoreIndices(iscol,&cout);
134:   VecRestoreArrayRead(bb,&b);
135:   VecRestoreArray(xx,&x);
136:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
137:   return(0);
138: }