!!****m* ABINIT/m_vtowfk
!! NAME
!!  m_vtowfk
!!
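!! FUNCTION
!!  This module contains the routine vtowfk, which computes the partial density
!!  at a given k-point and for a given spin-polarization from a fixed Hamiltonian
!!  (or simply the eigenvectors and eigenvalues at that k-point).
!!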
!! COPYRIGHT
!!  Copyright (C) 1998-2025 ABINIT group (DCA, XG, GMR, MT)
!!  This file is distributed under the terms of the
!!  GNU General Public License, see ~abinit/COPYING
!!  or http://www.gnu.org/copyleft/gpl.txt .
!!
!! SOURCE

#if defined HAVE_CONFIG_H
#include "config.h"
#endif

#include "abi_common.h"

! nvtx related macro definition
#include "nvtx_macros.h"

module m_vtowfk

  use, intrinsic :: iso_fortran_env, only: int32, int64, real32, real64

 use defs_basis
 use m_abicore
 use m_errors
 use m_xmpi
 use m_efield
 use m_linalg_interfaces
 use m_cgtools
 use m_dtset
 use m_dtfil
 use m_xomp
 use m_xg
 use m_xg_nonlop

 use defs_abitypes, only : MPI_type
 use m_time,        only : timab, cwtime, cwtime_report, sec2str
 use m_fstrings,    only : sjoin, itoa, ftoa
 use m_hamiltonian, only : gs_hamiltonian_type
 use m_getghc,      only : getghc_nucdip
 use m_paw_dmft,    only : paw_dmft_type
 use m_pawcprj,     only : pawcprj_type, pawcprj_alloc, pawcprj_free, pawcprj_put,pawcprj_copy
 use m_paw_dmft,    only : paw_dmft_type
 use m_gwls_hamiltonian, only : build_H
 use m_fftcore,     only : fftcore_set_mixprec, fftcore_mixprec
 use m_cgwf,        only : cgwf
 use m_cgwf_cprj,   only : cgwf_cprj,mksubovl,cprj_update,cprj_update_oneband
 use m_lobpcgwf_old,only : lobpcgwf
 use m_lobpcgwf,    only : lobpcgwf2
 use m_chebfiwf,    only : chebfiwf2
 use m_chebfiwf_cprj,only : chebfiwf2_cprj
 use m_lobpcgwf_cprj,only : lobpcgwf2_cprj
 use m_spacepar,    only : meanvalue_g
 use m_chebfi,      only : chebfi
 use m_rmm_diis,    only : rmm_diis
 use m_nonlop,      only : nonlop !, nonlop_counter
 use m_prep_kgb,    only : prep_nonlop, prep_fourwf
 use m_cgprj,       only : cprj_rotate,xg_cprj_copy,XG_TO_CPRJ
 use m_fft,         only : fourwf
 use m_cgtk,        only : cgtk_fixphase
 use m_common,      only : get_gemm_nonlop_ompgpu_blocksize
 use m_gemm_nonlop_projectors, only : gemm_nonlop_block_size, gemm_nonlop_is_distributed

#if defined HAVE_YAKL
 use gator_mod
#endif

#if defined(HAVE_GPU_MARKERS)
 use m_nvtx_data
#endif

 implicit none

 private
!!***

 public :: vtowfk
!!***

contains
!!***

!!****f* ABINIT/vtowfk
!! NAME
!! vtowfk
!!
!! FUNCTION
!! This routine computes the partial density at a given k-point,
!! for a given spin-polarization, from a fixed Hamiltonian.
!! It may also simply compute the eigenvectors and eigenvalues at this k-point.
!!
!! INPUTS
!!  cgq = array that holds the WF of the nearest neighbours of
!!        the current k-point (electric field, MPI //)
!!  cpus= cpu time limit in seconds
!!  dtefield <type(efield_type)> = variables related to Berry phase
!!      calculations (see initberry.f)
!!  dtfil <type(datafiles_type)>=variables related to files
!!  dtset <type(dataset_type)>=all input variables for this dataset
!!  fixed_occ=true if electronic occupations are fixed (occopt<3)
!!  gs_hamk <type(gs_hamiltonian_type)>=all data for the Hamiltonian at k
!!  ibg=shift to be applied on the location of data in the array cprj
!!  icg=shift to be applied on the location of data in the array cg
!!  ikpt=number of the k-point
!!  iscf=(<= 0  =>non-SCF), >0 => SCF
!!  isppol= 1 for unpolarized, 2 for spin-polarized
!!  kg_k(3,npw_k)=reduced planewave coordinates.
!!  kinpw(npw_k)=(modified) kinetic energy for each plane wave (Hartree)
!!  mcg=second dimension of the cg array
!!  mcgq=second dimension of the cgq array (electric field, MPI //)
!!  mcprj=size of projected wave-functions array (cprj) =nspinor*mband*mkmem*nsppol
!!  mkgq = second dimension of pwnsfacq
!!  mpi_enreg=information about MPI parallelization
!!  mpw=maximum dimensioned size of npw
!!  natom=number of atoms in cell.
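!!  nband_k=number of bands at this k point for that spin polarization
!!  nbdbuf=number of buffer bands: if >=0, the last nbdbuf bands are excluded from the
!!         maximum residual used to exit the non-self-consistent loop;
!!         if ==-101, the residuals are weighted by the occupation numbers instead.
!!         Any other negative value triggers an error.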
!!  nkpt=number of k points.
!!  istep=index of the current step in the routine scfcv
!!  nnsclo_now=number of non-self-consistent loops for the current vtrial
!!             (often 1 for SCF calculation, =nstep for non-SCF calculations)
!!  npw_k=number of plane waves at this k point
!!  npwarr(nkpt)=number of planewaves in basis at this k point
!!  occ_k(nband_k)=occupation number for each band (usually 2) for each k.
!!  optforces=option for the computation of forces
!!  prtvol=control print volume and debugging output
!!  pwind(pwind_alloc,2,3)= array used to compute
!!           the overlap matrix smat between k-points (see initberry.f)
!!  pwind_alloc= first dimension of pwind
!!  pwnsfac(2,pwind_alloc)= phase factors for non-symmorphic translations
!!                          (see initberry.f)
!!  pwnsfacq(2,mkgq)= phase factors for the nearest neighbours of the
!!                    current k-point (electric field, MPI //)
!!  usebanfft=flag for band-fft parallelism
!!  paw_dmft  <type(paw_dmft_type)>= paw+dmft related data
!!  wtk=weight assigned to the k point.
!!  zshift(nband_k)=energy shifts for the squared shifted hamiltonian algorithm
!!
!! OUTPUT
!!  dphase_k(3)=change in Zak phase for the current k-point
!!  eig_k(nband_k)=array for holding eigenvalues (hartree)
!!  ek_k(nband_k)=contribution from each band to kinetic energy, at this k-point
!!  ek_k_nd(2,nband_k,nband_k*use_dmft)=contribution to kinetic energy,
!!     including non-diagonal terms, at this k-point (useful if use_dmft)
!!  end_k(nband_k)=contribution from each band to nuclear dipole energy, at this k-point
!!  resid_k(nband_k)=residuals for each band over all k points, BEFORE the band rotation.
!!   On input: previous residuals.
!!  ==== if optforces>0 ====
!!    grnl_k(3*natom,nband_k)=nonlocal gradients, at this k-point
!!  ==== if gs_hamk%usepaw==0 ====
!!    enlx_k(nband_k)=contribution from each band to
!!                    nonlocal pseudopotential + Fock-type part of total energy, at this k-point
!!  ==== if (gs_hamk%usepaw==1) ====
!!    cprj(natom,mcprj*usecprj)= wave functions projected with non-local projectors:
!!                               cprj(n,k,i)=<p_i|Cnk> where p_i is a non-local projector.
!!
!! SIDE EFFECTS
!!  cg(2,mcg)=updated wavefunctions
!!  rhoaug(n4,n5,n6,nvloc)= density in electrons/bohr**3, on the augmented fft grid.
!!                    (cumulative, so input as well as output). Update only
!!                    for occopt<3 (fixed occupation numbers)
!!  rmm_diis_status: Status of the RMM-DIIS eigensolver. See m_rmm_diis.
!!
!! NOTES
!!  The cprj are distributed over band and spinor processors.
!!  A given processor does not hold all the cprj.
!!  Only the mod((iband-1)/mpi_enreg%bandpp,mpi_enreg%nproc_band) projectors
!!  are stored on each proc.
!!
!! SOURCE

subroutine vtowfk(cg,cgq,cprj,cpus,dphase_k,dtefield,dtfil,dtset,&
& eig_k,ek_k,ek_k_nd,end_k,enlx_k,fixed_occ,grnl_k,gs_hamk,&
& ibg,icg,ikpt,iscf,isppol,kg_k,kinpw,mband_cprj,mcg,mcgq,mcprj,mkgq,mpi_enreg,&
& mpw,natom,nband_k,nbdbuf,nkpt,istep,nnsclo_now,npw_k,npwarr,occ_k,optforces,prtvol,&
& pwind,pwind_alloc,pwnsfac,pwnsfacq,resid_k,rhoaug,paw_dmft,wtk,xg_nonlop,zshift,rmm_diis_status)

!Arguments ------------------------------------
 integer, intent(in) :: ibg,icg,ikpt,iscf,isppol,mband_cprj,mcg,mcgq,mcprj,mkgq,mpw
 integer, intent(in) :: natom,nband_k,nbdbuf,nkpt,nnsclo_now,npw_k,optforces
 integer, intent(in) :: prtvol,pwind_alloc,istep
 logical,intent(in) :: fixed_occ
 real(dp), intent(in) :: cpus,wtk
 type(datafiles_type), intent(in) :: dtfil
 type(efield_type), intent(inout) :: dtefield
 type(dataset_type), intent(in) :: dtset
 type(gs_hamiltonian_type), intent(inout) :: gs_hamk
 type(MPI_type), intent(inout) :: mpi_enreg
 type(paw_dmft_type), intent(in)  :: paw_dmft
 integer, intent(in) :: kg_k(3,npw_k)
 integer, intent(in) :: npwarr(nkpt),pwind(pwind_alloc,2,3)
 integer, intent(inout) :: rmm_diis_status(2)
 real(dp), intent(in) :: cgq(2,mcgq),kinpw(npw_k),occ_k(nband_k)
 real(dp), intent(in) :: pwnsfac(2,pwind_alloc),pwnsfacq(2,mkgq)
 real(dp), intent(in) :: zshift(nband_k)
 real(dp), target, intent(out) :: eig_k(nband_k)
 real(dP), intent(out) ::ek_k(nband_k),dphase_k(3),ek_k_nd(2,nband_k,nband_k*paw_dmft%use_dmft)
 real(dp), intent(out) :: end_k(nband_k),enlx_k(nband_k)
 real(dp), intent(out),target :: grnl_k(3*natom,nband_k*optforces)
 real(dp), intent(inout) :: resid_k(nband_k)
 real(dp), intent(inout),target :: cg(2,mcg)
 real(dp), intent(inout) :: rhoaug(gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,gs_hamk%nvloc)
 type(pawcprj_type),intent(inout),target :: cprj(natom,mcprj*gs_hamk%usecprj)
 type(xg_nonlop_t),intent(in) :: xg_nonlop

!Local variables-------------------------------
 logical :: has_fock,xg_diago,update_cprj
 logical :: do_subdiago,do_ortho,rotate_subvnlx,use_rmm_diis,is_distrib_tmp
 integer,parameter :: level=112,tim_fourwf=2,tim_nonlop_prep=11,enough=3,tim_getcprj=5
 integer,save :: nskip=0
!     Flag use_subovl: 1 if "subovl" array is computed (see below)
!     subovl should be Identity (in that case we should use use_subovl=0)
!     But this is true only if conjugate gradient algo. converges
 integer :: use_subovl=0
 integer :: use_subvnlx=0
 integer :: use_totvnlx=0
 integer :: bandpp_cprj,blocksize,choice,cpopt,iband,iband1
 integer :: iblock,iblocksize,ibs,idir,ierr,igs,igsc,ii,inonsc
 integer :: iorder_cprj,ipw,ispinor,ispinor_index,istwf_k,iwavef,me_g0,mgsc,my_nspinor,n1,n2,n3 !kk
 integer :: nband_k_cprj,ncols_cprj,nblockbd,ncpgr,ndat,niter,nkpt_max,nnlout,ortalgo,ndat_fft
 integer :: paw_opt,quit,signs,space,spaceComm,tim_nonlop,wfoptalg,wfopta10
 integer :: gpu_option_tmp,nblk_gemm_nonlop,blksize_gemm_nonlop_tmp
 logical :: nspinor1TreatedByThisProc,nspinor2TreatedByThisProc
 real(dp) :: ar,ar_im,eshift,occblock,norm
 real(dp) :: max_resid,weight,cpu,wall,gflops
 character(len=50) :: iter_name
 character(len=500) :: msg
 real(dp) :: dummy(2,1),nonlop_dum(1,1),tsec(2)
 real(dp),allocatable :: cwavef1(:,:),cwavef_x(:,:),cwavef_y(:,:),cwavefb(:,:,:)

#if defined HAVE_GPU && defined HAVE_YAKL
 real(real64), ABI_CONTIGUOUS pointer :: cwavef(:,:)  => null()
#else
 real(dp),allocatable,target :: cwavef(:,:)
#endif

 real(dp),allocatable :: eig_save(:),enlout(:),evec(:,:),gsc(:,:),ghc_vectornd(:,:)
 real(dp),allocatable :: subham(:),subovl(:),subvnlx(:),totvnlx(:,:)


#if defined HAVE_GPU && defined HAVE_YAKL
 real(real64), ABI_CONTIGUOUS pointer :: wfraug(:,:,:,:)
#else
 real(dp),allocatable :: wfraug(:,:,:,:)
#endif

 real(dp),pointer :: cg_k(:,:),cg_k_block(:,:),grnl_k_block(:,:),eig_k_block(:)
 real(dp),pointer :: cwavef_iband(:,:)
 type(pawcprj_type),pointer :: cwaveprj(:,:)
 type(pawcprj_type),pointer :: cprj_cwavef_bands(:,:),cprj_cwavef(:,:)

 real(dp), allocatable :: weight_t(:) ! only allocated and used with GPU fourwf
 type(xgBlock_t) :: xgx0,xgeigen,xgforces
 type(xg_t) :: cprj_xgx0,cprj_work


! **********************************************************************

 DBG_ENTER("COLL")

 call timab(28,1,tsec) ! Keep track of total time spent in "vtowfk"

!Structured debugging if prtvol==-level
 if(prtvol==-level)then
   write(msg,'(80a,a,a)') ('=',ii=1,80),ch10,'vtowfk: enter'
   call wrtout(std_out,msg,'PERS')
 end if

!=========================================================================
!============= INITIALIZATIONS AND ALLOCATIONS ===========================
!=========================================================================

 nkpt_max=50; if(xmpi_paral==1)nkpt_max=-1

 wfoptalg=mod(dtset%wfoptalg,100); wfopta10=mod(wfoptalg,10)
 xg_diago = dtset%wfoptalg == 114 .or. dtset%wfoptalg == 111
 istwf_k=gs_hamk%istwf_k
 has_fock=(associated(gs_hamk%fockcommon))
 quit=0
 igsc=0

!Parallelization over spinors management
 my_nspinor=max(1,dtset%nspinor/mpi_enreg%nproc_spinor)
 if (mpi_enreg%paral_spinor==0) then
   ispinor_index=1
   nspinor1TreatedByThisProc=.true.
   nspinor2TreatedByThisProc=(dtset%nspinor==2)
 else
   ispinor_index=mpi_enreg%me_spinor+1
   nspinor1TreatedByThisProc=(mpi_enreg%me_spinor==0)
   nspinor2TreatedByThisProc=(mpi_enreg%me_spinor==1)
 end if

!Parallelism over FFT and/or bands: define sizes and tabs
 !if (mpi_enreg%paral_kgb==1) then
 nblockbd=nband_k/(mpi_enreg%nproc_band*mpi_enreg%bandpp)
 !else
 !  nblockbd=nband_k/mpi_enreg%nproc_fft
 !  if (nband_k/=nblockbd*mpi_enreg%nproc_fft) nblockbd=nblockbd+1
 !end if
 blocksize=nband_k/nblockbd

!Save eshift
 if(wfoptalg==3)then
   eshift=zshift(1)
   ABI_MALLOC(eig_save,(nband_k))
   eig_save(:)=eshift
 end if

 n1=gs_hamk%ngfft(1); n2=gs_hamk%ngfft(2); n3=gs_hamk%ngfft(3)

 ! Decide whether RMM-DIIS eigensolver should be activated.
 ! rmm_diis > 0 --> Activate it after (3 + rmm_diis) iterations with wfoptalg algorithm.
 ! rmm_diis < 0 --> Start with RMM-DIIS directly (risky)
 use_rmm_diis = .False.
 if (dtset%rmm_diis /= 0 .and. iscf > 0) then
   use_rmm_diis = istep > 3 + dtset%rmm_diis
   !if (use_rmm_diis) call wrtout(std_out, " Activating RMM-DIIS eigensolver in SCF mode.")
 end if
 !nonlop_counter = 0

 mgsc=0
 igsc=0
 if ((.not. xg_diago .and. dtset%cprj_in_memory==0) .or. dtset%rmm_diis /= 0) then
   mgsc=nband_k*npw_k*my_nspinor*gs_hamk%usepaw
   ABI_MALLOC_OR_DIE(gsc,(2,mgsc), ierr)
   gsc=zero
 else
   ABI_MALLOC(gsc,(0,0))
 end if

 if(wfopta10 /= 1 .and. .not. xg_diago) then
   !chebfi already does this stuff inside
   ABI_MALLOC(evec,(2*nband_k,nband_k))
   ABI_MALLOC(subham,(nband_k*(nband_k+1)))

   ABI_MALLOC(subvnlx,(0))
   ABI_MALLOC(totvnlx,(0,0))
   if (wfopta10==4) then
!    Later, will have to generalize to Fock case, like when wfopta10/=4
     if (gs_hamk%usepaw==0) then
       ABI_FREE(totvnlx)
       if (istwf_k==1) then
         ABI_MALLOC(totvnlx,(2*nband_k,nband_k))
       else if (istwf_k==2) then
         ABI_MALLOC(totvnlx,(nband_k,nband_k))
       end if
       use_totvnlx=1
     endif
   else
     if (gs_hamk%usepaw==0 .or. has_fock) then
       ABI_FREE(subvnlx)
       ABI_MALLOC(subvnlx,(nband_k*(nband_k+1)))
       use_subvnlx=1
     end if
   end if

   if (use_subovl==1) then
     ABI_MALLOC(subovl,(nband_k*(nband_k+1)))
   else
     ABI_MALLOC(subovl,(0))
   end if
 end if

 ! Carry out UP TO dtset%nline (or dtset%mdeg_filter) steps, or until resid for every band is < dtset%tolwfr
 if (prtvol/=5 .and. (prtvol>2 .or. ikpt <= nkpt_max)) then
   write(msg,'(a,i5,2x,a,3f9.5,2x,a)')' non-scf iterations; kpt # ',ikpt,', k= (',gs_hamk%kpt_k,'), band residuals:'
   call wrtout(std_out,msg,'PERS')
 end if

 if (dtset%cprj_in_memory==2) then
   if (ikpt==1) then
     write(msg,'(a,i3)') ' In vtowfk : use of cprj in memory with cprj_update_lvl=',dtset%cprj_update_lvl
     call wrtout(std_out,msg,'COLL')
   end if
   cprj_cwavef_bands => cprj(:,1+ibg:nband_k/mpi_enreg%nproc_band*my_nspinor+ibg)
 end if

!Electric field: initialize dphase_k
 dphase_k(:) = zero

!=========================================================================
!==================== NON-SELF-CONSISTENT LOOP ===========================
!=========================================================================

!nnsclo_now=number of non-self-consistent loops for the current vtrial
!(often 1 for SCF calculation, =nstep for non-SCF calculations)
 call timab(39,1,tsec) ! "vtowfk (loop)"

 cg_k => cg(:,1+icg:npw_k*my_nspinor*nband_k+icg)

 do inonsc=1,nnsclo_now
   ABI_NVTX_START_RANGE(NVTX_VTOWFK_EXTRA1)
   if (iscf < 0 .and. (inonsc <= enough .or. mod(inonsc, 10) == 0)) call cwtime(cpu, wall, gflops, "start")

   if (dtset%rmm_diis /= 0 .and. iscf < 0) then
     use_rmm_diis = inonsc > 3 + dtset%rmm_diis
     !if (use_rmm_diis) call wrtout(std_out, " Activating RMM-DIIS eigensolver in NSCF mode.")
   end if

   ! This initialisation is needed for the MPI-parallelisation (gathering using sum)
   if(wfopta10 /= 1 .and. .not. xg_diago) then
     subham(:)=zero
     if (gs_hamk%usepaw==0) then
       if (wfopta10==4) then
         totvnlx(:,:)=zero
       else
         subvnlx(:)=zero
       end if
     end if
     if (use_subovl==1)subovl(:)=zero
   end if

   !resid_k(:)=zero

   !call cg_kfilter(npw_k, my_nspinor, nband_k, kinpw, cg(:, icg+1))

!  Filter the WFs when modified kinetic energy is too large (see routine mkkin.f)
!  !$OMP PARALLEL DO COLLAPSE(2) PRIVATE(igs,iwavef)
   do iband=1,nband_k
     iwavef=(iband-1)*npw_k*my_nspinor+icg
     cwavef_iband => cg(:,1+iwavef:npw_k*my_nspinor+iwavef)
     update_cprj=.False.
     do ispinor=1,my_nspinor
       igs=(ispinor-1)*npw_k
       do ipw=1+igs,npw_k+igs
         if(kinpw(ipw-igs)>huge(zero)*1.d-11)then
           norm=cwavef_iband(1,ipw)**2+cwavef_iband(2,ipw)**2
           if (norm>tol15*tol15) update_cprj=.True.
           cwavef_iband(:,ipw)=zero
         end if
       end do
     end do
     if (dtset%cprj_in_memory==2.and.update_cprj) then
       cprj_cwavef => cprj_cwavef_bands(:,my_nspinor*(iband-1)+1:my_nspinor*iband)
       call cprj_update_oneband(cwavef_iband,cprj_cwavef,gs_hamk,mpi_enreg,tim_getcprj)
     end if
   end do
   ABI_NVTX_END_RANGE()


   ! JLJ 17/10/2014: If it is a GWLS calculation, construct the hamiltonian
   ! as in a usual GS calc., but skip any minimisation procedure.
   ! This would be equivalent to nstep=0, if the latter did work.
   if(dtset%optdriver/=RUNL_GWLS) then

     if(wfopta10==4.or.wfopta10==1) then

       if (dtset%gpu_option==ABI_GPU_KOKKOS) then
         ! Kokkos GPU branch is not OpenMP thread-safe, setting OpenMP num threads to 1
         call xomp_set_num_threads(1)
       end if

!    =========================================================================
!    ============ MINIMIZATION OF BANDS: LOBPCG ==============================
!    =========================================================================
       if (wfopta10==4) then

         if (use_rmm_diis) then
           call rmm_diis(istep, ikpt, isppol, cg_k, dtset, eig_k, occ_k, enlx_k, gs_hamk, kinpw, gsc, &
                         mpi_enreg, nband_k, npw_k, my_nspinor, resid_k, rmm_diis_status)
         else

           if ( .not. xg_diago ) then

             ABI_NVTX_START_RANGE(NVTX_LOBPCG1)
             call lobpcgwf(cg,dtset,gs_hamk,gsc,icg,igsc,kinpw,mcg,mgsc,mpi_enreg,&
&             nband_k,nblockbd,npw_k,prtvol,resid_k,subham,totvnlx,use_totvnlx)
             ! In case of FFT parallelism, exchange subspace arrays
             spaceComm=mpi_enreg%comm_bandspinorfft
             call xmpi_sum(subham,spaceComm,ierr)
             if (gs_hamk%usepaw==0) then
               if (wfopta10==4) then
                 call xmpi_sum(totvnlx,spaceComm,ierr)
               else
                 call xmpi_sum(subvnlx,spaceComm,ierr)
               end if
             end if
             if (use_subovl==1) call xmpi_sum(subovl,spaceComm,ierr)
             ABI_NVTX_END_RANGE()

           else

             ABI_NVTX_START_RANGE(NVTX_LOBPCG2)
             if (dtset%cprj_in_memory==1) then
               call lobpcgwf2_cprj(cg_k,dtset,eig_k,occ_k,enlx_k,gs_hamk,isppol,ikpt,inonsc,istep,&
                 kinpw,mpi_enreg,nband_k,npw_k,my_nspinor,prtvol,resid_k,nbdbuf,xg_nonlop)
             else
               call lobpcgwf2(cg_k,dtset,eig_k,occ_k,enlx_k,gs_hamk,isppol,ikpt,inonsc,istep,kinpw,mpi_enreg,&
&               nband_k,npw_k,my_nspinor,prtvol,resid_k,nbdbuf)
             end if
             ABI_NVTX_END_RANGE()

           end if

         end if

!    =========================================================================
!    ============ MINIMIZATION OF BANDS: CHEBYSHEV FILTERING =================
!    =========================================================================
       else if (wfopta10 == 1) then
         if ( .not. xg_diago) then
           ABI_NVTX_START_RANGE(NVTX_CHEBFI1)
           call chebfi(cg_k,dtset,eig_k,enlx_k,gs_hamk,gsc,kinpw,&
&           mpi_enreg,nband_k,npw_k,my_nspinor,prtvol,resid_k)
           ABI_NVTX_END_RANGE()
         else if (dtset%cprj_in_memory==1) then
           call chebfiwf2_cprj(cg_k,dtset,eig_k,occ_k,enlx_k,gs_hamk,&
             mpi_enreg,nband_k,npw_k,my_nspinor,prtvol,resid_k,xg_nonlop)
         else
           ABI_NVTX_START_RANGE(NVTX_CHEBFI2)
           call chebfiwf2(cg_k,dtset,eig_k,occ_k,enlx_k,gs_hamk,&
&           mpi_enreg,nband_k,npw_k,my_nspinor,prtvol,resid_k)
           ABI_NVTX_END_RANGE()
         end if
       end if

!      =========================================================================
!      ======== MINIMIZATION OF BANDS: CONJUGATE GRADIENT (Teter et al.) =======
!      =========================================================================
     else
       ! use_subvnlx=0; if (gs_hamk%usepaw==0 .or. associated(gs_hamk%fockcommon)) use_subvnlx=1
       ! use_subvnlx=0; if (gs_hamk%usepaw==0) use_subvnlx=1

       if (.not. use_rmm_diis) then

         if (isppol==1.and.ikpt==1.and.inonsc==1.and.istep==1) then
           if (dtset%tolwfr_diago/=zero) then
             write(msg, '(a,es16.6)' ) ' cgwf: tolwfr_diago=',dtset%tolwfr_diago
             call wrtout(std_out,msg,'COLL')
           end if
         end if

         if (dtset%cprj_in_memory==2) then
           call cgwf_cprj(cg,cprj_cwavef_bands,dtset%cprj_update_lvl,eig_k,&
&             gs_hamk,icg,mcg,mpi_enreg,nband_k,dtset%nline,&
&             dtset%ortalg,prtvol,quit,resid_k,subham,dtset%tolrde,dtset%tolwfr_diago,wfoptalg)
         else
           call cgwf(dtset%berryopt,cg,cgq,dtset%chkexit,cpus,dphase_k,dtefield,dtfil%filnam_ds(1),&
&           gsc,gs_hamk,icg,igsc,ikpt,inonsc,isppol,dtset%mband,mcg,mcgq,mgsc,mkgq,&
&           mpi_enreg,mpw,nband_k,dtset%nbdblock,nkpt,dtset%nline,npw_k,npwarr,my_nspinor,&
&           dtset%nsppol,dtset%ortalg,prtvol,pwind,pwind_alloc,pwnsfac,pwnsfacq,quit,resid_k,&
&           subham,subovl,subvnlx,dtset%tolrde,dtset%tolwfr_diago,use_subovl,use_subvnlx,wfoptalg,zshift)
         end if
       else
         call rmm_diis(istep, ikpt, isppol, cg(:,icg+1:), dtset, eig_k, occ_k, enlx_k, gs_hamk, kinpw, gsc, &
                       mpi_enreg, nband_k, npw_k, my_nspinor, resid_k, rmm_diis_status)
       end if

       if (dtset%gpu_option==ABI_GPU_KOKKOS) then
         ! Kokkos GPU branch is not OpenMp thread-safe, restoring OpenMP threads num
         call xomp_set_num_threads(dtset%gpu_kokkos_nthrd)
       end if

     end if

   end if

!  =========================================================================
!  ===================== FIND LARGEST RESIDUAL =============================
!  =========================================================================

!  Find largest resid over bands at this k point
!  Note that this operation is done BEFORE rotation of bands:
!  it would be time-consuming to recompute the residuals after.
   if (nbdbuf >= 0) then
     max_resid = maxval(resid_k(1:max(1,nband_k-nbdbuf)))
   else if (nbdbuf==-101) then
     max_resid = maxval(occ_k(1:nband_k)*resid_k(1:nband_k))
   else
     ABI_ERROR(sjoin('Bad value of nbdbuf:', itoa(nbdbuf)))
   end if

!  Print residuals
   if(prtvol/=5.and.(prtvol>2 .or. ikpt<=nkpt_max))then
     do ii=0,(nband_k-1)/8
       write(msg,'(a,8es10.2)')' res:',(resid_k(iband),iband=1+ii*8,min(nband_k,8+ii*8))
       call wrtout(std_out,msg,'PERS')
     end do
   end if

!  =========================================================================
!  ========== DIAGONALIZATION OF HAMILTONIAN IN WFs SUBSPACE ===============
!  =========================================================================
   do_subdiago = .not. wfopta10 == 1 .and. .not. xg_diago
   if (use_rmm_diis) do_subdiago = .False.  ! subdiago is already performed before RMM-DIIS.

   ABI_NVTX_START_RANGE(NVTX_SUB_SPC_DIAGO)
   if (do_subdiago) then
     if (prtvol > 1) call wrtout(std_out, " Performing subspace diagonalization.")
     call timab(585,1,tsec) !"vtowfk(subdiago)"
     if (dtset%cprj_in_memory==2) then
       call subdiago_low_memory(cg,eig_k,evec,icg,istwf_k,&
&       mcg,nband_k,npw_k,my_nspinor,dtset%paral_kgb,subham)
       call timab(585,2,tsec)
       call timab(578,1,tsec)
       call cprj_rotate(cprj_cwavef_bands,evec,gs_hamk%dimcprj,natom,nband_k,gs_hamk%nspinor)
       call timab(578,2,tsec)
     else
       call subdiago(cg, eig_k, evec, gsc, icg, igsc, istwf_k, &
       mcg, mgsc, nband_k, npw_k, my_nspinor, dtset%paral_kgb, &
       subham, subovl, use_subovl, gs_hamk%usepaw, mpi_enreg%me_g0)
       call timab(585,2,tsec)
     end if
   end if
   ABI_NVTX_END_RANGE()

   !  Print energies
   if(prtvol/=5.and.(prtvol>2 .or. ikpt<=nkpt_max))then
     do ii=0,(nband_k-1)/8
       write(msg, '(a,8es10.2)' )' ene:',(eig_k(iband),iband=1+ii*8,min(nband_k,8+ii*8))
       call wrtout(std_out,msg,'PERS')
     end do
   end if

!  THIS CHANGE OF SHIFT DOES NOT WORK WELL
!  Update zshift in the case of wfoptalg==3
!  if(wfoptalg==3 .and. inonsc/=1)then
!  do iband=1,nband_k
!  if(eig_k(iband)<eshift .and. eig_save(iband)<eshift)then
!  zshift(iband)=max(eig_k(iband),eig_save(iband))
!  end if
!  if(eig_k(iband)>eshift .and. eig_save(iband)>eshift)then
!  zshift(iband)=min(eig_k(iband),eig_save(iband))
!  end if
!  end do
!  eig_save(:)=eig_k(:)
!  end if

!  =========================================================================
!  =============== ORTHOGONALIZATION OF WFs (if needed) ====================
!  =========================================================================

!  Re-orthonormalize the wavefunctions at this k point.
!  this step is redundant but is performed to combat rounding error in wavefunction orthogonality.
!  This step is performed inside rmm_diis if RMM-DIIS is activated.

   call timab(583,1,tsec) ! "vtowfk(pw_orthon)"
   ortalgo = mpi_enreg%paral_kgb
   ! The orthogonalization is completely disabled with ortalg<=-10.
   ! This option is usefull for testing only and is not documented.
   do_ortho = (wfoptalg/=14 .and. wfoptalg /= 1 .and. wfoptalg /= 11 .and. dtset%ortalg>-10) .or. dtset%ortalg > 0
   if (xg_diago) do_ortho = .false.
   if (use_rmm_diis) do_ortho = .False.

   if (do_ortho) then

     ABI_NVTX_START_RANGE(NVTX_ORTHO_WF)

     if (prtvol > 0) call wrtout(std_out, " Calling pw_orthon to orthonormalize bands.")
     if (dtset%cprj_in_memory==2) then
       ABI_FREE(subovl)
       ABI_MALLOC(subovl,(nband_k*(nband_k+1)))
       call mksubovl(cg,cprj_cwavef_bands,gs_hamk,icg,nband_k,subovl,mpi_enreg)
       call pw_orthon_cprj(icg,mcg,npw_k*my_nspinor,my_nspinor,nband_k,ortalgo,subovl,cg,cprj=cprj_cwavef_bands)
     else
       call pw_orthon(icg,igsc,istwf_k,mcg,mgsc,npw_k*my_nspinor,nband_k,ortalgo,gsc,gs_hamk%usepaw,cg,&
&        mpi_enreg%me_g0,mpi_enreg%comm_bandspinorfft)
     end if

     ABI_NVTX_END_RANGE()
   end if
   call timab(583,2,tsec)

   ABI_NVTX_START_RANGE(NVTX_VTOWFK_EXTRA2)

   ! DEBUG seq==par comment next block
   ! Fix phases of all bands
   if (xmpi_paral/=1 .or. mpi_enreg%paral_kgb/=1) then
     !call wrtout(std_out, "Calling cgtk_fixphase")
     if ( (.not.xg_diago) .and. dtset%cprj_in_memory==0 ) then
       call cgtk_fixphase(cg,gsc,icg,igsc,istwf_k,mcg,mgsc,mpi_enreg,nband_k,npw_k*my_nspinor,gs_hamk%usepaw)
     else if ( xg_diago ) then
       ! GSC is local to vtowfk and is completely useless since everything
       ! is calculated in my lobpcg, we don't care about the phase of gsc !
       call cgtk_fixphase(cg,gsc,icg,igsc,istwf_k,mcg,mgsc,mpi_enreg,nband_k,npw_k*my_nspinor,0)
     else ! dtset%cprj_in_memory/=0 .and. .not.xg_diago
       call cgtk_fixphase(cg,gsc,icg,igsc,istwf_k,mcg,mgsc,mpi_enreg,nband_k,npw_k*my_nspinor,0,&
         & cprj=cprj_cwavef_bands,nspinor=dtset%nspinor)
     end if
   end if

   if (iscf < 0) then
     if (max_resid > dtset%tolwfr .and. max_resid < tol7) then
       if (fftcore_mixprec == 1) call wrtout(std_out, " Approaching NSCF convergence. Activating FFT in double-precision")
       ii = fftcore_set_mixprec(0)
     end if

     ! Print residual and wall-time required by NSCF iteration.
     if (inonsc <= enough .or. mod(inonsc, 20) == 0) then
       call cwtime(cpu, wall, gflops, "stop")
       if (ikpt == 1 .or. mod(ikpt, 100) == 0) then
         if (inonsc == 1) call wrtout(std_out, sjoin(" k-point: [", itoa(ikpt), "/", itoa(nkpt), "], spin:", itoa(isppol)))
         call wrtout(std_out, sjoin("   Max resid =", ftoa(max_resid, fmt="es13.5"), &
           " (exclude nbdbuf bands). One NSCF iteration cpu-time:", &
           sec2str(cpu), ", wall-time:", sec2str(wall)), do_flush=.True.)
         if (inonsc == enough) call wrtout(std_out, "   Printing residuals every mod(20) iterations...")
       end if
     end if
   end if
   ABI_NVTX_END_RANGE()

   ! Exit loop over inonsc if converged
   if (max_resid < dtset%tolwfr) then
     if (iscf < 0 .and. (ikpt == 1 .or. mod(ikpt, 100) == 0)) then
       call wrtout(std_out, sjoin("   NSCF loop completed after", itoa(inonsc), "iterations"))
     end if
     exit
   end if
 end do ! inonsc (NON SELF-CONSISTENT LOOP)

 if (dtset%cprj_in_memory==2) then
   update_cprj=dtset%cprj_update_lvl<=3.and.dtset%cprj_update_lvl/=2
   if (update_cprj) call cprj_update(cg,cprj_cwavef_bands,gs_hamk,icg,nband_k,mpi_enreg,tim_getcprj)
 end if

 call timab(39,2,tsec)
 call timab(30,1,tsec) ! "vtowfk  (afterloop)"

 !if (dtset%prtvol > 0)
 !call wrtout(std_out, sjoin(" Number of Vnl|Psi> applications:", itoa(nonlop_counter)))

!###################################################################

!Compute kinetic energy and non-local energy for each band, and in the SCF
!case, contribution to forces, and eventually accumulate rhoaug

 ndat=1;if (mpi_enreg%paral_kgb==1) ndat=mpi_enreg%bandpp
 if(iscf>0 .and. fixed_occ)  then
   ndat_fft=ndat; if(mpi_enreg%paral_kgb==0) ndat_fft=blocksize
   if(dtset%gpu_option==ABI_GPU_KOKKOS) then
#if defined HAVE_GPU && defined HAVE_YAKL
     ABI_MALLOC_MANAGED(wfraug,(/2,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6*ndat_fft/))
#endif
   else
     ABI_MALLOC(wfraug,(2,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6*ndat_fft))
   end if
 end if

!"nonlop" routine input parameters
 nnlout=3*natom*optforces
 signs=1;idir=0
 if (gs_hamk%usepaw==0) then
   choice=1+optforces
   paw_opt=0;cpopt=-1;tim_nonlop=2
 else
   choice=2*optforces
   paw_opt=2;cpopt=0;tim_nonlop=10-8*optforces
   if (dtset%cprj_in_memory==2) cpopt=2 ! cprj are in memory (but not the derivatives)
   if (dtset%usefock==1) then
!     if (dtset%optforces/= 0) then
     if (optforces/= 0) then
       choice=2;cpopt=1; nnlout=3*natom
     end if
   end if
 end if

 ABI_MALLOC(enlout,(nnlout*blocksize))

 ! Allocation of memory space for one block of waveforms containing blocksize waveforms
 if(dtset%gpu_option==ABI_GPU_KOKKOS) then
#if defined HAVE_GPU && defined HAVE_YAKL
   ABI_MALLOC_MANAGED(cwavef, (/2,npw_k*my_nspinor*blocksize/))
#endif
 else
   ABI_MALLOC(cwavef, (2,npw_k*my_nspinor*blocksize))
 end if

 if (dtset%cprj_in_memory/=2) then
   if (gs_hamk%usepaw==1.and.(iscf>0.or.gs_hamk%usecprj==1)) then
     iorder_cprj=0
     nband_k_cprj=nband_k*(mband_cprj/dtset%mband)
     bandpp_cprj=mpi_enreg%bandpp
     ABI_MALLOC(cwaveprj,(natom,my_nspinor*bandpp_cprj))
     ncpgr=0;if (cpopt==1) ncpgr=cprj(1,1)%ncpgr
     call pawcprj_alloc(cwaveprj,ncpgr,gs_hamk%dimcprj)
   else
     ABI_MALLOC(cwaveprj,(0,0))
   end if
 end if

!The code below is more efficient if paral_kgb==1 (less MPI communications)
!however OMP is not compatible with paral_kgb since we should define
!which thread performs the call to MPI_ALL_REDUCE.
!This problem can be easily solved by removing MPI_enreg from meanvalue_g so that
!the MPI call is done only once outside the OMP parallel region.

 !call cwtime(cpu, wall, gflops, "start")

 if (dtset%cprj_in_memory==1) then
   ncols_cprj = blocksize*my_nspinor/mpi_enreg%nproc_band
   call xg_init(cprj_xgx0,xg_nonlop%space_cprj,xg_nonlop%cprjdim,ncols_cprj,comm=xg_nonlop%comm_band)
   call xg_init(cprj_work,xg_nonlop%space_cprj,xg_nonlop%cprjdim,ncols_cprj,comm=xg_nonlop%comm_band)
 end if

 ! In case of GEMM nonlop distribution + force computation,
 ! recompute distribution as projectors arrays are bigger in this case
 gpu_option_tmp=gs_hamk%gpu_option
 if(optforces==1 .and. gs_hamk%gpu_option==ABI_GPU_OPENMP) then
   blksize_gemm_nonlop_tmp = gemm_nonlop_block_size; is_distrib_tmp = gemm_nonlop_is_distributed
   gemm_nonlop_block_size = dtset%gpu_nl_splitsize
   call get_gemm_nonlop_ompgpu_blocksize(ikpt,gs_hamk,mpi_enreg%bandpp,nband_k,&
   &                        dtset%nspinor,mpi_enreg%paral_kgb,mpi_enreg%nproc_band,&
   &                        optforces,0,-1,gs_hamk%gpu_option,(dtset%gpu_nl_distrib/=0),&
   &                        gemm_nonlop_block_size,nblk_gemm_nonlop,warn_on_fail=.true.)
   gemm_nonlop_is_distributed = (dtset%gpu_nl_distrib/=0 .and. nblk_gemm_nonlop > 0)
   if(nblk_gemm_nonlop==-1) then
     gs_hamk%gpu_option=ABI_GPU_DISABLED
     ABI_WARNING("GPU has been disabled for forces computation during SCF step due to memory constraints.")
   end if
 end if

!Loop over bands or blocks of bands. Note that in sequential mode iblock=iband, nblockbd=nband_k and blocksize=1
 do iblock=1,nblockbd
   occblock=maxval(occ_k(1+(iblock-1)*blocksize:iblock*blocksize))
   cwavef(:,:)=cg(:,1+(iblock-1)*npw_k*my_nspinor*blocksize+icg:iblock*npw_k*my_nspinor*blocksize+icg)

!  Compute kinetic energy of each band
   do iblocksize=1,blocksize
     iband=(iblock-1)*blocksize+iblocksize

     call meanvalue_g(ar,kinpw,0,istwf_k,mpi_enreg,npw_k,my_nspinor,&
&     cg(:,1+(iband-1)*npw_k*my_nspinor+icg:iband*npw_k*my_nspinor+icg),&
&     cg(:,1+(iband-1)*npw_k*my_nspinor+icg:iband*npw_k*my_nspinor+icg),0,&
&     gpu_thread_limit=dtset%gpu_thread_limit)

     ek_k(iband)=ar
     if(ANY(ABS(dtset%nucdipmom)>tol8)) then
       ABI_MALLOC(ghc_vectornd,(2,npw_k*my_nspinor))
       call getghc_nucdip(cg(:,1+(iband-1)*npw_k*my_nspinor+icg:iband*npw_k*my_nspinor+icg),&
         & ghc_vectornd,gs_hamk%gbound_k,gs_hamk%istwf_k,kg_k,gs_hamk%kpt_k,&
         & gs_hamk%mgfft,mpi_enreg,ndat,gs_hamk%ngfft,npw_k,gs_hamk%nvloc,&
         & gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,my_nspinor,gs_hamk%vectornd,gs_hamk%gpu_option)
       end_k(iband)=DOT_PRODUCT(cg(1,1+(iband-1)*npw_k*my_nspinor+icg:iband*npw_k*my_nspinor+icg),&
         &                      ghc_vectornd(1,1:npw_k*my_nspinor))+&
         &          DOT_PRODUCT(cg(2,1+(iband-1)*npw_k*my_nspinor+icg:iband*npw_k*my_nspinor+icg),&
         &                      ghc_vectornd(2,1:npw_k*my_nspinor))
       ABI_FREE(ghc_vectornd)
     end if
     if(paw_dmft%use_dmft==1) then
       do iband1=1,nband_k
         call meanvalue_g(ar,kinpw,0,istwf_k,mpi_enreg,npw_k,my_nspinor,&
&         cg(:,1+(iband -1)*npw_k*my_nspinor+icg:iband *npw_k*my_nspinor+icg),&
&         cg(:,1+(iband1-1)*npw_k*my_nspinor+icg:iband1*npw_k*my_nspinor+icg),&
&         paw_dmft%use_dmft,ar_im=ar_im,gpu_thread_limit=dtset%gpu_thread_limit)
         ek_k_nd(1,iband,iband1)=ar
         ek_k_nd(2,iband,iband1)=ar_im
       end do
     end if
   end do

   if (iscf>0) then

     ABI_NVTX_START_RANGE(NVTX_VTOWFK_FOURWF)
     ! In case of fixed occupation numbers, accumulates the partial density
     if (fixed_occ .and. mpi_enreg%paral_kgb/=1) then

       ! treat all bands at once on GPU
       if (dtset%gpu_option /= ABI_GPU_DISABLED) then

         ABI_MALLOC(weight_t,(blocksize))

         ! compute weights
         do iblocksize=1,blocksize
           iband=(iblock-1)*blocksize+iblocksize
           weight_t(iblocksize) = occ_k(iband) * wtk / gs_hamk%ucvol
           if (abs(occ_k(iband)) < tol8) weight_t(iblocksize) = zero
         end do

         if(dtset%nspinor==1) then
           call fourwf(1,rhoaug(:,:,:,1),cwavef(:,:),dummy,wfraug,&
           &    gs_hamk%gbound_k,gs_hamk%gbound_k,istwf_k,kg_k,kg_k,&
           &    gs_hamk%mgfft,mpi_enreg,blocksize,gs_hamk%ngfft,&
           &    npw_k,1,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,tim_fourwf,weight,weight,&
           &    weight_array_r=weight_t,weight_array_i=weight_t,&
           &    gpu_option=dtset%gpu_option)
         else if(dtset%nspinor==2) then
           ABI_MALLOC(cwavefb,(2,npw_k*blocksize,2))
           ibs=(iblock-1)*npw_k*my_nspinor*blocksize+icg
           do iband=1,blocksize
             cwavefb(:,(iband-1)*npw_k+1:iband*npw_k,1)=cg(:,1+(2*iband-2)*npw_k+ibs:(iband*2-1)*npw_k+ibs)
             cwavefb(:,(iband-1)*npw_k+1:iband*npw_k,2)=cg(:,1+(2*iband-1)*npw_k+ibs:iband*2*npw_k+ibs)
           end do

           call fourwf(1,rhoaug(:,:,:,1),cwavefb(:,:,1),dummy,wfraug,&
           &    gs_hamk%gbound_k,gs_hamk%gbound_k,istwf_k,kg_k,kg_k,&
           &    gs_hamk%mgfft,mpi_enreg,blocksize,gs_hamk%ngfft,&
           &    npw_k,1,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,tim_fourwf,weight,weight,&
           &    weight_array_r=weight_t,weight_array_i=weight_t,&
           &    gpu_option=dtset%gpu_option)
           if(dtset%nspden==1) then
             call fourwf(1,rhoaug(:,:,:,1),cwavefb(:,:,2),dummy,wfraug,&
             &    gs_hamk%gbound_k,gs_hamk%gbound_k,istwf_k,kg_k,kg_k,&
             &    gs_hamk%mgfft,mpi_enreg,blocksize,gs_hamk%ngfft,&
             &    npw_k,1,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,tim_fourwf,weight,weight,&
             &    weight_array_r=weight_t,weight_array_i=weight_t,&
             &    gpu_option=dtset%gpu_option)
           else if (dtset%nspden==4) then
             ! Build the four components of rho. We use only norm quantities, and so fourwf.
             ! $\sum_{n} f_n \Psi^{* \alpha}_n \Psi^{\alpha}_n =\rho^{\alpha \alpha}$
             ! $\sum_{n} f_n (\Psi^{1}+\Psi^{2})^*_n (\Psi^{1}+\Psi^{2})_n=rho+m_x$
             ! $\sum_{n} f_n (\Psi^{1}-i \Psi^{2})^*_n (\Psi^{1}-i \Psi^{2})_n=rho+m_y$
             ABI_MALLOC(cwavef_x,(2,npw_k*blocksize))
             ABI_MALLOC(cwavef_y,(2,npw_k*blocksize))

             !$(\Psi^{1}+\Psi^{2})$
             cwavef_x(:,:)=cwavefb(:,1:npw_k*blocksize,1)+cwavefb(:,1:npw_k*blocksize,2)
             !$(\Psi^{1}-i \Psi^{2})$
             cwavef_y(1,:)=cwavefb(1,1:npw_k*blocksize,1)+cwavefb(2,1:npw_k*blocksize,2)
             cwavef_y(2,:)=cwavefb(2,1:npw_k*blocksize,1)-cwavefb(1,1:npw_k*blocksize,2)

             ! z component
             call fourwf(1,rhoaug(:,:,:,4),cwavefb(:,:,2),dummy,wfraug,&
             &    gs_hamk%gbound_k,gs_hamk%gbound_k,istwf_k,kg_k,kg_k,&
             &    gs_hamk%mgfft,mpi_enreg,blocksize,gs_hamk%ngfft,&
             &    npw_k,1,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,tim_fourwf,weight,weight,&
             &    weight_array_r=weight_t,weight_array_i=weight_t,&
             &    gpu_option=dtset%gpu_option)
             ! x component
             call fourwf(1,rhoaug(:,:,:,2),cwavef_x(:,:),dummy,wfraug,&
             &    gs_hamk%gbound_k,gs_hamk%gbound_k,istwf_k,kg_k,kg_k,&
             &    gs_hamk%mgfft,mpi_enreg,blocksize,gs_hamk%ngfft,&
             &    npw_k,1,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,tim_fourwf,weight,weight,&
             &    weight_array_r=weight_t,weight_array_i=weight_t,&
             &    gpu_option=dtset%gpu_option)
             ! y component
             call fourwf(1,rhoaug(:,:,:,3),cwavef_y(:,:),dummy,wfraug,&
             &    gs_hamk%gbound_k,gs_hamk%gbound_k,istwf_k,kg_k,kg_k,&
             &    gs_hamk%mgfft,mpi_enreg,blocksize,gs_hamk%ngfft,&
             &    npw_k,1,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,tim_fourwf,weight,weight,&
             &    weight_array_r=weight_t,weight_array_i=weight_t,&
             &    gpu_option=dtset%gpu_option)
             ABI_FREE(cwavef_x)
             ABI_FREE(cwavef_y)
           end if

           ABI_FREE(cwavefb)
         end if

         ABI_FREE(weight_t)

       else

         do iblocksize=1,blocksize
           iband=(iblock-1)*blocksize+iblocksize
           cwavef_iband => cwavef(:,1+(iblocksize-1)*npw_k*my_nspinor:iblocksize*npw_k*my_nspinor)

           if (abs(occ_k(iband))>=tol8) then
             weight = occ_k(iband) * wtk / gs_hamk%ucvol

             ! Accumulate charge density in real space in array rhoaug

             ! The same section of code is also found in mkrho.F90 : should be rationalized !
             call fourwf(1,rhoaug(:,:,:,1),cwavef_iband,dummy,wfraug,gs_hamk%gbound_k,gs_hamk%gbound_k,&
               &           istwf_k,gs_hamk%kg_k,gs_hamk%kg_k,gs_hamk%mgfft,mpi_enreg,1,gs_hamk%ngfft,npw_k,1,&
               &           gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,&
               &           tim_fourwf,weight,weight,gpu_option=dtset%gpu_option)

             if(dtset%nspinor==2)then
               ABI_MALLOC(cwavef1,(2,npw_k))
               cwavef1(:,:)=cwavef_iband(:,1+npw_k:2*npw_k) ! EB FR spin dn part and used for m_z component (cwavef_z)

               if(dtset%nspden==1) then

                 call fourwf(1,rhoaug(:,:,:,1),cwavef1,dummy,wfraug,&
                   &               gs_hamk%gbound_k,gs_hamk%gbound_k,&
                   &               istwf_k,gs_hamk%kg_k,gs_hamk%kg_k,gs_hamk%mgfft,mpi_enreg,1,gs_hamk%ngfft,npw_k,1,&
                   &               gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,&
                   &               tim_fourwf,weight,weight,gpu_option=dtset%gpu_option)

               else if(dtset%nspden==4) then
                 ! Build the four components of rho. We use only norm quantities, and so fourwf.
                 ! $\sum_{n} f_n \Psi^{* \alpha}_n \Psi^{\alpha}_n =\rho^{\alpha \alpha}$
                 ! $\sum_{n} f_n (\Psi^{1}+\Psi^{2})^*_n (\Psi^{1}+\Psi^{2})_n=rho+m_x$
                 ! $\sum_{n} f_n (\Psi^{1}-i \Psi^{2})^*_n (\Psi^{1}-i \Psi^{2})_n=rho+m_y$
                 ABI_MALLOC(cwavef_x,(2,npw_k))
                 ABI_MALLOC(cwavef_y,(2,npw_k))
                 !$(\Psi^{1}+\Psi^{2})$
                 cwavef_x(:,:)=cwavef_iband(:,1:npw_k)+cwavef1(:,1:npw_k)
                 !$(\Psi^{1}-i \Psi^{2})$
                 cwavef_y(1,:)=cwavef_iband(1,1:npw_k)+cwavef1(2,1:npw_k)
                 cwavef_y(2,:)=cwavef_iband(2,1:npw_k)-cwavef1(1,1:npw_k)
                 ! z component
                 call fourwf(1,rhoaug(:,:,:,4),cwavef1,dummy,wfraug,gs_hamk%gbound_k,gs_hamk%gbound_k,&
                   &             istwf_k,gs_hamk%kg_k,gs_hamk%kg_k,gs_hamk%mgfft,mpi_enreg,1,gs_hamk%ngfft,npw_k,1,&
                   &               gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,&
                   &               tim_fourwf,weight,weight,gpu_option=dtset%gpu_option)
                 ! x component
                 call fourwf(1,rhoaug(:,:,:,2),cwavef_x,dummy,wfraug,gs_hamk%gbound_k,gs_hamk%gbound_k,&
                   &               istwf_k,gs_hamk%kg_k,gs_hamk%kg_k,gs_hamk%mgfft,mpi_enreg,1,gs_hamk%ngfft,npw_k,1,&
                   &               gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,&
                   &               tim_fourwf,weight,weight,gpu_option=dtset%gpu_option)
                 ! y component
                 call fourwf(1,rhoaug(:,:,:,3),cwavef_y,dummy,wfraug,gs_hamk%gbound_k,gs_hamk%gbound_k,&
                   &               istwf_k,gs_hamk%kg_k,gs_hamk%kg_k,gs_hamk%mgfft,mpi_enreg,1,gs_hamk%ngfft,npw_k,1,&
                   &               gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,1,&
                   &               tim_fourwf,weight,weight,gpu_option=dtset%gpu_option)

                 ABI_FREE(cwavef_x)
                 ABI_FREE(cwavef_y)

               end if ! dtset%nspden/=4
               ABI_FREE(cwavef1)
             end if
           else
             nskip=nskip+1
           end if
         end do  ! Loop inside a block of bands

       end if ! dtset%gpu_option

!      In case of fixed occupation numbers, in bandFFT mode accumulates the partial density
     else if (fixed_occ .and. mpi_enreg%paral_kgb==1) then

       if (dtset%nspinor==1) then
         call timab(537,1,tsec) ! "prep_fourwf%vtow"
         call prep_fourwf(rhoaug(:,:,:,1),blocksize,cwavef,wfraug,iblock,istwf_k,&
&         gs_hamk%mgfft,mpi_enreg,nband_k,ndat,gs_hamk%ngfft,npw_k,&
&         gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,occ_k,&
&         1,gs_hamk%ucvol,wtk,gpu_option=dtset%gpu_option)
         call timab(537,2,tsec)
       else if (dtset%nspinor==2) then
         ABI_MALLOC(cwavefb,(2,npw_k*blocksize,2))
         ibs=(iblock-1)*npw_k*my_nspinor*blocksize+icg
!        --- No parallelization over spinors ---
         if (mpi_enreg%paral_spinor==0) then
           do iband=1,blocksize
             cwavefb(:,(iband-1)*npw_k+1:iband*npw_k,1)=cg(:,1+(2*iband-2)*npw_k+ibs:(iband*2-1)*npw_k+ibs)
             cwavefb(:,(iband-1)*npw_k+1:iband*npw_k,2)=cg(:,1+(2*iband-1)*npw_k+ibs:iband*2*npw_k+ibs)
           end do
         else
!          --- Parallelization over spinors ---
!          (split the work between 2 procs)
           cwavefb(:,:,3-ispinor_index)=zero
           do iband=1,blocksize
             cwavefb(:,(iband-1)*npw_k+1:iband*npw_k,ispinor_index) = cg(:,1+(iband-1)*npw_k+ibs:iband*npw_k+ibs)
           end do
           call xmpi_sum(cwavefb,mpi_enreg%comm_spinor,ierr)
         end if

         call timab(537,1,tsec) !"prep_fourwf%vtow"
         if (nspinor1TreatedByThisProc) then
           call prep_fourwf(rhoaug(:,:,:,1),blocksize,cwavefb(:,:,1),wfraug,iblock,&
&           istwf_k,gs_hamk%mgfft,mpi_enreg,nband_k,ndat,gs_hamk%ngfft,npw_k,&
&           gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,occ_k,1,gs_hamk%ucvol,wtk,&
&           gpu_option=dtset%gpu_option)
         end if
         if(dtset%nspden==1) then
           if (nspinor2TreatedByThisProc) then
             call prep_fourwf(rhoaug(:,:,:,1),blocksize,cwavefb(:,:,2),wfraug,&
&             iblock,istwf_k,gs_hamk%mgfft,mpi_enreg,nband_k,ndat,&
&             gs_hamk%ngfft,npw_k,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,occ_k,1,&
&             gs_hamk%ucvol,wtk,gpu_option=dtset%gpu_option)
           end if
         else if(dtset%nspden==4) then
           ABI_MALLOC(cwavef_x,(2,npw_k*blocksize))
           ABI_MALLOC(cwavef_y,(2,npw_k*blocksize))
           cwavef_x(:,:)=cwavefb(:,1:npw_k*blocksize,1)+cwavefb(:,:,2)
           cwavef_y(1,:)=cwavefb(1,1:npw_k*blocksize,1)+cwavefb(2,:,2)
           cwavef_y(2,:)=cwavefb(2,:,1)-cwavefb(1,:,2)
           if (nspinor1TreatedByThisProc) then
             call prep_fourwf(rhoaug(:,:,:,4),blocksize,cwavefb(:,:,2),wfraug,&
&             iblock,istwf_k,gs_hamk%mgfft,mpi_enreg,nband_k,ndat,gs_hamk%ngfft,&
&             npw_k,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,occ_k,1,gs_hamk%ucvol,wtk,&
&             gpu_option=dtset%gpu_option)
           end if
           if (nspinor2TreatedByThisProc) then
             call prep_fourwf(rhoaug(:,:,:,2),blocksize,cwavef_x,wfraug,&
&             iblock,istwf_k,gs_hamk%mgfft,mpi_enreg,nband_k,ndat,gs_hamk%ngfft,&
&             npw_k,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,occ_k,1,gs_hamk%ucvol,wtk,&
&             gpu_option=dtset%gpu_option)
             call prep_fourwf(rhoaug(:,:,:,3),blocksize,cwavef_y,wfraug,&
&             iblock,istwf_k,gs_hamk%mgfft,mpi_enreg,nband_k,ndat,gs_hamk%ngfft,&
&             npw_k,gs_hamk%n4,gs_hamk%n5,gs_hamk%n6,occ_k,1,gs_hamk%ucvol,wtk,&
&             gpu_option=dtset%gpu_option)
           end if
           ABI_FREE(cwavef_x)
           ABI_FREE(cwavef_y)
         end if
         call timab(537,2,tsec)
         ABI_FREE(cwavefb)
       end if
     end if
     ABI_NVTX_END_RANGE()
   end if ! End of SCF calculation

!    Call to nonlocal operator:
!    - Compute nonlocal forces from most recent wfs
!    - PAW: compute projections of WF onto NL projectors (cprj)
   ABI_NVTX_START_RANGE(NVTX_VTOWFK_NONLOP)
   eig_k_block => eig_k(1+(iblock-1)*blocksize:iblock*blocksize)
   cg_k_block => cg_k(:,1+(iblock-1)*blocksize*my_nspinor*npw_k:iblock*blocksize*my_nspinor*npw_k)
   if (dtset%cprj_in_memory==2) then
     if (optforces>0) then
       call timab(554,1,tsec)  ! "vtowfk:rhoij"
!      Treat all wavefunctions in case of PAW
       cwaveprj => cprj(:,1+(iblock-1)*my_nspinor*blocksize+ibg:iblock*my_nspinor*blocksize+ibg)
       call nonlop(choice,cpopt,cwaveprj,enlout,gs_hamk,idir,eig_k_block,&
&       mpi_enreg,blocksize,nnlout,&
&       paw_opt,signs,nonlop_dum,tim_nonlop,cwavef,cwavef)
!      Accumulate forces
       iband=(iblock-1)*blocksize
       do iblocksize=1,blocksize
         ii=0
         if (nnlout>3*natom) ii=6
         iband=iband+1;ibs=ii+nnlout*(iblocksize-1)
         grnl_k(1:nnlout,iband)=enlout(ibs+1:ibs+nnlout)
       end do
       call timab(554,2,tsec)  ! "vtowfk:rhoij"
     end if ! PAW or forces
   else
     if(iscf>0.or.gs_hamk%usecprj==1)then
       if (gs_hamk%usepaw==1.or.optforces/=0) then
!        Treat all wavefunctions in case of varying occupation numbers or PAW
!        Only treat occupied bands in case of fixed occupation numbers and NCPP
         if(fixed_occ.and.abs(occblock)<=tol8.and.gs_hamk%usepaw==0) then
           if (optforces>0) grnl_k(:,(iblock-1)*blocksize+1:iblock*blocksize)=zero
         else
           if (dtset%cprj_in_memory/=1) then
             if(gs_hamk%usepaw==1) then
               call timab(554,1,tsec)  ! "vtowfk:rhoij"
             end if
             if(cpopt==1) then
               iband=1+(iblock-1)*bandpp_cprj
               call pawcprj_copy(cprj(:,1+(iblock-1)*my_nspinor*blocksize+ibg:iblock*my_nspinor*blocksize+ibg),cwaveprj)
             end if
             if (mpi_enreg%paral_kgb==1) then
               call timab(572,1,tsec) ! 'prep_nonlop%vtowfk'
               call prep_nonlop(choice,cpopt,cwaveprj,enlout,gs_hamk,idir, &
  &             eig_k_block,blocksize,&
  &             mpi_enreg,nnlout,paw_opt,signs,nonlop_dum,tim_nonlop_prep,cwavef,cwavef,already_transposed=.false.)
               call timab(572,2,tsec)
             else
               call nonlop(choice,cpopt,cwaveprj,enlout,gs_hamk,idir,eig_k_block,&
  &             mpi_enreg,blocksize,nnlout,&
  &             paw_opt,signs,nonlop_dum,tim_nonlop,cwavef,cwavef)
             end if
             if(gs_hamk%usepaw==1) then
               call timab(554,2,tsec)
             end if
  !          Accumulate forces
             if (optforces>0) then
               iband=(iblock-1)*blocksize
               do iblocksize=1,blocksize
                 ii=0
                 if (nnlout>3*natom) ii=6
                 iband=iband+1;ibs=ii+nnlout*(iblocksize-1)
                 grnl_k(1:nnlout,iband)=enlout(ibs+1:ibs+nnlout)
               end do
             end if
  !          Store cprj (<Pnl|Psi>)
             if (gs_hamk%usepaw==1.and.gs_hamk%usecprj==1) then
               iband=1+(iblock-1)*bandpp_cprj
               call pawcprj_put(gs_hamk%atindx,cwaveprj,cprj,natom,iband,ibg,ikpt,iorder_cprj,isppol,&
  &             mband_cprj,dtset%mkmem,natom,bandpp_cprj,nband_k_cprj,gs_hamk%dimcprj,my_nspinor,&
  &             dtset%nsppol,dtfil%unpaw,mpicomm=mpi_enreg%comm_kpt,proc_distrb=mpi_enreg%proc_distrb)
             end if

           else ! cprj_in_memory==1

             call timab(222,1,tsec) ! 'nonlop%vtowfk'

             if ( gs_hamk%istwf_k > 1 ) then ! Real only
               space = SPACE_CR
             else ! complex
               space = SPACE_C
             end if
             me_g0 = -1
             if (space==SPACE_CR) then
               me_g0 = 0
               if (gs_hamk%istwf_k == 2) then
                 if (mpi_enreg%me_g0 == 1) me_g0 = 1
               end if
             end if
             call xgBlock_map(xgx0,cg_k_block,space,npw_k*my_nspinor,blocksize,comm=mpi_enreg%comm_band,me_g0=me_g0,&
         &   gpu_option=dtset%gpu_option)
             call xgBlock_map_1d(xgeigen,eig_k_block,SPACE_R,blocksize)

             if (optforces/=0.or.gs_hamk%usepaw==1) then
               call xg_nonlop_getcprj(xg_nonlop,xgx0,cprj_xgx0%self,cprj_work%self)
             end if

             if (optforces/=0) then
               grnl_k_block => grnl_k(:,1+(iblock-1)*blocksize:iblock*blocksize)
               call xgBlock_map(xgforces,grnl_k_block,SPACE_R,3*natom,blocksize)
               call xg_nonlop_forces_stress(xg_nonlop,xgx0,cprj_xgx0%self,cprj_work%self,xgeigen,forces=xgforces)
             end if

             call timab(222,2,tsec) ! 'nonlop%vtowfk'

             if (gs_hamk%usepaw==1) then
               cprj_cwavef_bands => cprj(:,1+ibg+(iblock-1)*ncols_cprj:iblock*ncols_cprj+ibg)
               call xg_cprj_copy(cprj_cwavef_bands,cprj_xgx0%self,xg_nonlop,XG_TO_CPRJ)
             end if

           end if
         end if
       end if ! PAW or forces
     end if ! iscf>0 or iscf=-3
   end if
   ABI_NVTX_END_RANGE()
 end do !  End of loop on blocks

 ! restore safe value related to GEMM nonlop slicing and GPU in case of forces compute
 if(optforces==1 .and. gpu_option_tmp==ABI_GPU_OPENMP) then
   gs_hamk%gpu_option = gpu_option_tmp
   gemm_nonlop_block_size = blksize_gemm_nonlop_tmp
   gemm_nonlop_is_distributed = is_distrib_tmp
 end if

 if (dtset%cprj_in_memory==1) then

   call xg_free(cprj_xgx0)
   call xg_free(cprj_work)

 end if

 !call cwtime_report(" Block loop", cpu, wall, gflops)

 if(dtset%gpu_option==ABI_GPU_KOKKOS) then
#if defined HAVE_GPU && defined HAVE_YAKL
   ABI_FREE_MANAGED(cwavef)
#endif
 else
   ABI_FREE(cwavef)
 end if

 ABI_FREE(enlout)

 if (dtset%cprj_in_memory/=2) then
   if (gs_hamk%usepaw==1.and.(iscf>0.or.gs_hamk%usecprj==1)) then
     call pawcprj_free(cwaveprj)
   end if
   ABI_FREE(cwaveprj)
 else
   nullify(cwaveprj)
 end if

 if (fixed_occ.and.iscf>0) then
   if(dtset%gpu_option==ABI_GPU_KOKKOS) then
#if defined HAVE_GPU && defined HAVE_YAKL
     ABI_FREE_MANAGED(wfraug)
#endif
   else
     ABI_FREE(wfraug)
   end if
 end if

!Write the number of one-way 3D ffts skipped until now (in case of fixed occupation numbers)
 if(iscf>0 .and. fixed_occ .and. (prtvol>2 .or. ikpt<=nkpt_max) )then
   write(msg,'(a,i0)')' vtowfk: number of one-way 3D ffts skipped in vtowfk until now =',nskip
   call wrtout(std_out,msg,'PERS')
 end if

 ! Norm-conserving or FockACE: Compute nonlocal+FockACE part of total energy: rotate subvnlx elements
 ! Note the two calls. For (old) lobpcgwf we have a (nband_k, nband_k) matrix, whereas cgwf
 ! returns results in packed form.
 ! CHEBYSHEV, NEW LOBPCG and RMM-DIIS do not need this
 !
 rotate_subvnlx = gs_hamk%usepaw == 0 .and. wfopta10 /= 1 .and. .not. xg_diago
 if (use_rmm_diis) rotate_subvnlx = .False.

 if (rotate_subvnlx) then
   call timab(586,1,tsec)   ! 'vtowfk(nonlocalpart)'
   if (wfopta10==4) then
     call cg_hrotate_and_get_diag(istwf_k, nband_k, totvnlx, evec, enlx_k)
   else
     call cg_hprotate_and_get_diag(nband_k, subvnlx, evec, enlx_k)
   end if
   call timab(586,2,tsec)
 end if

!###################################################################

 if (iscf<=0 .and. max_resid > dtset%tolwfr) then
   write(msg,'(2(a,i0),a,es13.5)')&
    "Wavefunctions not converged for ikpt: ", ikpt, ", nnsclo: ",nnsclo_now,', max resid: ',max_resid
   ABI_WARNING(msg)
 end if

!Print out eigenvalues (hartree)
 if (mod(dtset%wfoptalg,10)==1) then
   niter=dtset%mdeg_filter
   iter_name='as the degree of the polynomial filter'
 else
   niter=dtset%nline
   iter_name='CG line minimizations'
 end if
 if (prtvol/=5.and.(prtvol>2 .or. ikpt<=nkpt_max)) then
   write(msg, '(5x,a,i5,2x,a,a,a,i4,a,i4,2a)' ) &
    'eigenvalues (hartree) for',nband_k,'bands',ch10,&
    '              after ',inonsc,' non-SCF iterations with ',niter,' ',trim(iter_name)
   call wrtout(std_out,msg,'PERS')
   do ii=0,(nband_k-1)/6
     write(msg, '(1p,6e12.4)' ) (eig_k(iband),iband=1+6*ii,min(6+6*ii,nband_k))
     call wrtout(std_out,msg,'PERS')
   end do
 else if(ikpt==nkpt_max+1)then
   call wrtout(std_out,' vtowfk : prtvol=0 or 1, do not print more k-points.','PERS')
 end if

!Print out decomposition of eigenvalues in the non-selfconsistent case or if prtvol>=10
 if( (iscf<0 .and. (prtvol>2 .or. ikpt<=nkpt_max)) .or. prtvol>=10)then
   write(msg, '(5x,a,i5,2x,a,a,a,i4,a,i4,2a)' ) &
&   ' mean kinetic energy (hartree) for ',nband_k,' bands',ch10,&
&   '              after ',inonsc,' non-SCF iterations with ',niter,' ',trim(iter_name)
   call wrtout(std_out,msg,'PERS')

   do ii=0,(nband_k-1)/6
     write(msg, '(1p,6e12.4)' ) (ek_k(iband),iband=1+6*ii,min(6+6*ii,nband_k))
     call wrtout(std_out,msg,'PERS')
   end do

   if (gs_hamk%usepaw==0) then
     write(msg, '(5x,a,i5,2x,a,a,a,i4,a,i4,2a)' ) &
&     ' mean NL+Fock-type energy (hartree) for ',nband_k,' bands',ch10,&
&     '              after ',inonsc,' non-SCF iterations with ',niter,' ',trim(iter_name)
     call wrtout(std_out,msg,'PERS')

     do ii=0,(nband_k-1)/6
       write(msg,'(1p,6e12.4)') (enlx_k(iband),iband=1+6*ii,min(6+6*ii,nband_k))
       call wrtout(std_out,msg,'PERS')
     end do
   end if
 end if

 ! Hamiltonian constructor for gwls_sternheimer
 if (dtset%optdriver==RUNL_GWLS) call build_H(dtset,mpi_enreg,cpopt,cg,gs_hamk,kg_k,kinpw)

 if (dtset%cprj_in_memory==2) nullify(cprj_cwavef_bands)

 if(wfopta10 /= 1 .and. .not. xg_diago) then
   ABI_FREE(evec)
   ABI_FREE(subham)
   ABI_FREE(totvnlx)
   ABI_FREE(subvnlx)
   ABI_FREE(subovl)
 end if

 ABI_SFREE(gsc)

 if(wfoptalg==3) then
   ABI_FREE(eig_save)
 end if

 if (prtvol==-level) then
   ! Structured debugging: if prtvol=-level, stop here.
   write(msg,'(a,a,a,i0,a)')' vtowfk : exit ',ch10,'  prtvol=-',level,', debugging mode => stop '
   ABI_ERROR(msg)
 end if

 call timab(30,2,tsec)
 call timab(28,2,tsec)

 DBG_EXIT("COLL")

end subroutine vtowfk
!!***

end module m_vtowfk
!!***
