#!/usr/bin/env python3

import sys
import os
import os.path
import re
import math
import time
import numpy as np # to install this lib in deb-type linux, run as root: apt-get install python3-numpy
import xyz_utils
import xyz_statistics

# to test this code you could use e.g.:
# for the 3D version:
# echo | awk 'END{n=1000;L=99; print n; print ""; for(i=1;i<=n;i++){print "C",L*rand(),L*rand(),L*rand()}}' > /tmp/pippo.xyz
# for the 2D version:
# echo | awk 'END{n=1000;L=99; print n; print ""; for(i=1;i<=n;i++){print "C",L*rand(),L*rand(),0}}' > /tmp/pippo.xyz
# and then run the test, e.g.:
# xyz_gofr.py /tmp/pippo.xyz
# or, even better, from inside gnuplot:
#gnuplot> p[:60]'<xyz_gofr.py /tmp/pippo.xyz '


def minus(x, y):
    '''Return x - y; small helper used to build joining vectors component by component.'''
    difference = x - y
    return difference

def fracpart(x):
    '''Shift x by an integer so that the result lies in [-0.5, 0.5).

    The integer removed is floor(x+0.5), so exact half-integers map to -0.5
    (e.g. 0.5 -> -0.5).
    '''
    nearest_integer = math.floor(x + 0.5)
    return x - nearest_integer

def distance_pbc(vect1,vect2,U,Uinv):
    '''Length of vect1-vect2 after periodic wrapping in the cell (PBC).

    U maps a Cartesian vector to cell units and Uinv maps it back (in xyz_gofr,
    Uinv has the primitive vectors as columns and U is its pseudo-inverse).
    Each cell-unit component of the joining vector is shifted by an integer into
    [-0.5,0.5) with fracpart before converting back to Cartesian coordinates.
    '''
    joining=[minus(p,q) for p,q in zip(vect1,vect2)]
    in_cell_units=np.dot(U,joining)
    wrapped=[fracpart(u) for u in in_cell_units]
    minimal=np.dot(Uinv,wrapped)
    return np.linalg.norm(minimal)

def distance_standard(vect1,vect2):
    '''Plain Euclidean distance between vect1 and vect2, without periodic images (OBC).'''
    joining=[minus(p,q) for p,q in zip(vect1,vect2)]
    return np.linalg.norm(joining)

def timewarn(old_time,delaytime,nframe,atom=None,name=None):
    '''Print a progress message on stderr once at least delaytime seconds have
    passed since old_time.

    old_time:  time.time() value of the previous message (or of the start)
    delaytime: minimum number of seconds between two messages
    nframe:    frame number reported in the message
    atom:      atom index; when None the message reports the preprocessing phase
    name:      program name prefixed to the message; by default the module-level
               `commandname` (defined when the file runs as a script) or, when
               the module is imported and that global does not exist, sys.argv[0]
    Returns the value to pass as old_time to the next call: the current time if
    a message was printed, old_time otherwise.
    '''
    new_time=time.time()
    if new_time-old_time<delaytime:
        return old_time
    if name is None:
        # `commandname` is only created by the __main__ section: avoid a NameError
        # when this function is used from an importing module
        name=globals().get('commandname', sys.argv[0])
    if atom is None:
        print(name,"now preprocessing frame #",nframe,file=sys.stderr)
    else:
        print(name,"now processing frame #",nframe,"atom #",atom, file=sys.stderr)
    return new_time


def _read_frames(filen):
    """Yield the successive frames of the xyz file `filen` ("-" means stdin).

    Frames are read with xyz_utils.xyz_read_one_frame until it returns a frame
    whose comment is None (end of input).  A regular file is closed when the
    generator finishes or is closed; stdin is left open.
    """
    if filen=="-":
        f=sys.stdin
    else:
        f=open(filen, 'r')
    try:
        while True:
            snap=xyz_utils.xyz_read_one_frame(f)
            if snap.comment is None:
                return
            yield snap
    finally:
        if f is not sys.stdin:
            f.close()


def _pbc_distances(here,others,U,Uinv):
    """Vectorized counterpart of distance_pbc().

    Returns the lengths of the joining vectors here-others[j] (one per row of the
    2D array `others`) after each cell-unit component has been wrapped into the
    interval [-0.5,0.5).  U maps Cartesian vectors to cell units, Uinv maps back.
    """
    dif=here-others                   # joining vectors, one per row
    frac=np.dot(dif,U.T)              # go to normalized cell units
    frac-=np.floor(frac+0.5)          # each component in the -0.5..0.5 interval
    return np.linalg.norm(np.dot(frac,Uinv.T),axis=1)  # back to Cartesian, then lengths


def _accumulate(histogram,distances,dr):
    """Add the 1D array `distances`, binned with width dr, to the integer `histogram`.

    The dmax estimates used to size the histogram are not an upper bound for every
    cell shape (e.g. a non-cubic OBC box), so the histogram is enlarged when a
    distance falls beyond its last bin.  Returns the (possibly new) histogram.
    """
    if distances.size==0:
        return histogram
    counts=np.bincount(np.floor(distances/dr).astype(int))
    if counts.size>histogram.size:
        extra=np.zeros(counts.size-histogram.size, dtype=histogram.dtype)
        histogram=np.concatenate((histogram,extra))
    histogram[:counts.size]+=counts
    return histogram


def xyz_gofr(args):
    """compute g(r) doing statistics of the distances in a xyz file,
       work in 2D or 3D alike, and assume PBC or OBC alike

    args is the parsed command-line namespace; the fields read are prog, debug,
    a1, a2, a3, dr, PBC, filenames and, in the OBC case, areafix and volumefix.

    PBC: every pair of atoms of a frame is counted once, with distances wrapped
    in the cell spanned by a1, a2, a3.
    OBC: the box enclosing all frames is measured first; every atom inside the
    inner box (the enclosing box shrunk by dmin on each side) is paired with all
    the other atoms of its frame, so that spheres of radius dmin around it are
    fully inside the sample.

    Returns (header, storer, storegofr): a header line, the r value of each
    histogram bin, and the corresponding g(r).
    Calls sys.exit(1) on a non-positive dr, a degenerate cell, or empty input.
    """
    global a1,a2,a3
    global debug

    commandname=args.prog
    debug=args.debug
    old_time=time.time()
    delaytime=10  # seconds before the code tells stderr what it is doing
    pi=math.pi
    dimensions=3  # by default 3D g(r)
    nframe=0

    a1=args.a1
    a2=args.a2
    a3=args.a3
    dr=args.dr
    if dr<=0:
        print("Error in",commandname,"dr must be positive, got",dr, file=sys.stderr)
        sys.exit(1)

    area=abs(a1[0]*a2[1]-a1[1]*a2[0]) # area of parallelogram
    volume=abs(np.linalg.det([a1,a2,a3])) # vol. of parallelepiped

    if volume < 1.e-15:
        print(commandname,"WARNING: volume=",volume, file=sys.stderr)

    ndist=0
    stdin_frames=None  # stdin cannot be rewound: frames read from it in the OBC survey are kept here

# preliminary calculations:
    if args.PBC:  # default is PBC:
        print(commandname,": PBC calculation in cell with primitive vectors:\n ",a1,";",a2,";",a3, file=sys.stderr)
# matrix for cell mapping to [0-1]x[0-1]x[0-1]
        Uinv=np.array([a1,a2,a3]).transpose()  # columns are a1, a2, a3
        U=np.linalg.pinv(Uinv)  # pseudo-inverse also copes with a zero a3 (2D cell)

# shortest and longest cell vector:
        side=[np.linalg.norm(a1),np.linalg.norm(a2),np.linalg.norm(a3)]
        dmax=max(side)*1.1  # to prevent overflow we go slightly outward
        dmin=min(side)
        if dmin<=0:     # 2D case: use the shorter of the two in-plane vectors
            dmin=min(side[:2])
        if dmin<=0:
            print(commandname,"ERROR: side=",side,"dmin=",dmin, file=sys.stderr)
            sys.exit(1)
        dmin=dmin*0.5 # 50% of the shortest side is the maximum trustable range

    else: # OBC case:
        cmin=np.full((3), sys.float_info.max)
        cmax=np.full((3),-sys.float_info.max)
        for filen in args.filenames: # initial survey of the xyz files
            if filen=="-":
                if stdin_frames is None:
                    stdin_frames=list(_read_frames(filen))
                frames=stdin_frames
            else:
                frames=_read_frames(filen)
            for snap in frames:
                old_time=timewarn(old_time,delaytime,nframe)
                nframe+=1
                _,lmin,lmax,_,_,_=xyz_statistics.xyz_frame_statistics(snap)
                for c in range(3):
                    cmin[c]=min(cmin[c],lmin[c])
                    cmax[c]=max(cmax[c],lmax[c])
        if nframe==0:
            print("Error in",commandname,"no frames found in",args.filenames, file=sys.stderr)
            sys.exit(1)

        side=cmax-cmin  # side of the parallelepiped box enclosing the complete sample

# shortest and longest cell side:
        dmax=max(side)*math.sqrt(3.)*0.75
        dmin=min(side)
        if dmin==0:     # 2D case: use the shorter of the two in-plane sides
            dimensions=2
            dmin=min(side[:2])
        if dmin==0:
            print("Error in",commandname,"wrong cell side:",side,dmin, file=sys.stderr)
            sys.exit(1)
        dmin=dmin*0.25  # 25% of the shortest side is the farthest distance where one can safely measure g(r)
        # inner (central) box: the enclosing box shrunk by dmin along the first `dimensions` axes
        bmin=cmin.copy()
        bmax=cmax.copy()
        bmin[:dimensions]+=dmin
        bmax[:dimensions]-=dmin

        area=(cmax[0]-cmin[0])*(cmax[1]-cmin[1]) # area of rectangle
        volume=area*(cmax[2]-cmin[2]) # vol. of parallelepiped

        if args.areafix>0: # overwrite with user-defined info, if available
            area=args.areafix
        if args.volumefix>0:
            volume=args.volumefix
        print(commandname,": OBC calculation in overall cell of corners:\n ",cmin,";",cmax, file=sys.stderr)
        print(commandname,": inner cell of corners\n ",bmin,";",bmax, file=sys.stderr)
        print(commandname,": dimensions=",dimensions,"volume=",volume,"area=",area, file=sys.stderr)

# this part is OK for both PBC and OBC:
    nintervals=int(1.+1.*dmax/dr)
    if debug:
        print("debug dmax =",dmax,"dr =",dr,"nintervals=",nintervals)
# end of preliminary calculations

#              the actual calculation of the histogram:
    histogram=np.zeros([nintervals], dtype="int")
    nframe=0
    totatom=0

    minz=9e99
    maxz=-minz
    for filen in args.filenames:
        if filen=="-" and stdin_frames is not None:
            frames=stdin_frames  # already read from stdin during the OBC survey
        else:
            frames=_read_frames(filen)
        for snap in frames:
            nframe+=1
            natom=len(snap.atoms)
            totatom+=natom
            if natom==0:
                continue
            coords=np.asarray(snap.coords, dtype=float)

            if args.PBC:
                minz=min(minz,np.amin(coords[:,2]))
                maxz=max(maxz,np.amax(coords[:,2]))
                for i in range(natom):
                    old_time=timewarn(old_time,delaytime,nframe,i)
                    # pair atom i with every earlier atom (j<i), so each pair is counted once.
                    # This makes the algorithm O(N^2). Workaround: evaluate S(q) and use Fourier transform
                    dist=_pbc_distances(coords[i],coords[:i],U,Uinv)
                    histogram=_accumulate(histogram,dist,dr)
                    ndist+=dist.size

            else: # OBC case:
                for i in range(natom):
                    old_time=timewarn(old_time,delaytime,nframe,i)
                    here=coords[i]
                    if np.all(bmin<=here) and np.all(here<=bmax):
                        # central atom inside the inner box: pair it with all other atoms, so the
                        # sphere of radius dmin around it is sampled independently of atom ordering
                        others=np.delete(coords,i,axis=0)
                        dist=np.linalg.norm(others-here,axis=1)
                        histogram=_accumulate(histogram,dist,dr)
                        ndist+=dist.size

    if nframe==0:
        print("Error in",commandname,"no frames found in",args.filenames, file=sys.stderr)
        sys.exit(1)
    if minz==maxz:  # only needed for PBC, should not hurt to repeat it in OBC
        dimensions=2
    nintervals=histogram.size  # _accumulate() may have enlarged the histogram

# end of both loops on filenames and if PBC; final wrapup of the statistics:

    storer=np.zeros([nintervals], dtype="float")
    storegofr=np.zeros([nintervals], dtype="float")
    npart=float(totatom)/nframe
#    NB!  in the PBC case, for constant npart, ndist coincides with  nframe*npart*(npart-1)/2
    print(commandname,"- npart:", npart, "; nframe:", nframe, "; dr:", dr,"; #distances:", ndist, file=sys.stderr)
    if ndist==0:
        print(commandname,"WARNING: no distances were sampled, g(r) is undefined", file=sys.stderr)

    if dimensions==2: # 2D case
        if args.PBC:
            print("              note:  minz==maxz, therefore using 2D formula for g(r)", file=sys.stderr)
        for k in range(nintervals):
            r=(k+0.5)*dr  # 2*pi*r*dr is then the exact area of the annulus k*dr < r < (k+1)*dr
            gofr=histogram[k]*area/(ndist*2*pi*r*dr)
            storer[k]=r
            storegofr[k]=gofr
    else:             # 3D case
        for k in range(nintervals):
            r=math.sqrt(k*k +k+1./3)*dr  # a volume-weighted average value of r in the k dr< r < (k+1)dr range
            if debug:
                print("debug r =",r,"dr =",dr,"ndist=",ndist,"volume =",volume)
            gofr=histogram[k]*volume/(ndist*4*pi*r*r*dr)
            storer[k]=r
            storegofr[k]=gofr

    header="#r	g(r)  ["+str(dimensions)+"D formula], obtained with "+str(ndist)+" samples. Reliable for r < "+str(dmin)
    return header,storer,storegofr



# the following code is only executed when the file is run as a script; its purpose is to parse
# the command line and to pass a meaningful parsed-args namespace to the function doing the actual job:
if __name__ == "__main__":
    import argparse
    commandname=sys.argv[0]  # module-level name: timewarn() prefixes its progress messages with it

    desc="""compute g(r) from a xyz file
OUTPUT: a self-explanatory g(r) histogram"""

    epil="""\t\t\tv. 2.1	by Nicola Manini, 02/05/2020"""

##  Argument Parser definition:
    parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter
                                    , description=desc, epilog=epil)

    parser.add_argument( 'filenames', nargs='*', default=['-'],
                         help='Files to be processed. If not given, stdin is used')

    parser.add_argument( '-d', '--debug', action='store_true',
                         dest='debug',
                         help='activate debug mode -- WARNING! output is affected/spoiled!' )

    parser.add_argument( '--dr',
                         dest='dr', type=float, default=0.1,
                         help='the dr interval fixing the histogram fineness')

    parser.add_argument( '--a1', type=float, nargs=3,
                         dest='a1', default=[100.,0.,0.],
                         help='1st primitive vector, expecting 3 components' )

    parser.add_argument( '--a2', type=float, nargs=3,
                         dest='a2', default=[0.,100.,0.],
                         help='2nd primitive vector, expecting 3 components' )

    parser.add_argument( '--a3', type=float, nargs=3,
                         dest='a3', default=[0.,0.,100.],
                         help='3rd primitive vector, expecting 3 components' )

    parser.add_argument( '-o','--OBC', action='store_false',
                         dest='PBC',
                         help='''assume OBC rather than the default PBC:
                           in the OBC case by default the cell is
                           extracted from xyz file boundaries: the first
                           atoms are taken from inside the central
                           parallelepiped of half size ''')

    parser.add_argument( '-a',
                         dest='areafix', type=float, default=0.,
                         help='fix the area for the normalization of OBC g(r)')

    parser.add_argument( '-v',
                         dest='volumefix', type=float, default=0.,
                         help='fix the volume for the normalization of OBC g(r)')

##  End arg parser definition

    args=parser.parse_args(sys.argv[1:])
    if args.dr<=0:  # dr is the bin width: it divides every distance
        parser.error("--dr must be positive")
    # xyz_gofr() reads the program name from args.prog; parse_args() does not store it
    args.prog=parser.prog

#   here the actual function doing the job is called:
    header,storer,storegofr=xyz_gofr(args)

#   the results are printed out:
    print(header)
    for r,gofr in zip(storer,storegofr):
        print(r,gofr)
