#!/usr/bin/env python3

import sys
from rkf45 import r8_rkf45   # file rkf45.py needed: it provides RKF integration
import numpy as np           # to install this lib in deb-type linux, run as root: apt-get install python3-numpy


# This graphics function below is entirely optional.
# If you are unwilling to install matplotlib, just run with option --np
# or you can even delete it, or comment it out:
def makeplots(tlist,xcmlist,vcmlist,title):
    """Plot x_cm(t) (top panel, titled `title`) and v_cm(t) (bottom panel), then show them.

    tlist, xcmlist and vcmlist are sequences of the same length: the sampling
    times and the corresponding centre-of-mass position and velocity.
    plt.show() blocks until the plot window is closed.
    """
    # imported here so that matplotlib is needed only when a plot is requested
    # (to install this lib in deb-type linux, run as root: apt-get install python3-matplotlib)
    from matplotlib import pyplot as plt

############################    subfigure 1
    plt.subplot( 2, 1, 1 )
    # The label is attached to the curve itself: the old call legend(("xcm")) passed a
    # bare string (the parentheses do not make a tuple), which matplotlib either splits
    # into one-character labels or rejects, depending on its version.
    plt.plot( tlist, xcmlist, 'b-o', label='$x_{cm}$')
    plt.xlabel( '$t$' )
    plt.ylabel( '$x_{cm}$' )
    plt.title(title)
    plt.legend( loc='upper left' )

###########################    subfigure 2
    plt.subplot( 2, 1, 2 )
    plt.plot( tlist, vcmlist, 'b-o', label='$v_{cm}$')
    plt.xlabel( '$t$' )
    plt.ylabel( '$v_{cm}$' )
    plt.legend( loc='lower right' )

    plt.show()  # draw plot, stop & wait


# Object-oriented (Figure/Axes) variant of makeplots above. Like makeplots it is entirely
# optional (main_fk only calls makeplots), so if you are unwilling to install matplotlib
# you can delete it or comment it out:
def oo_makeplots(tlist,xcmlist,vcmlist,title):
    """Object-oriented (Figure/Axes) version of makeplots: the same two panels.

    Top panel: x_cm(t) with `title`; bottom panel: v_cm(t). tlist, xcmlist and
    vcmlist are sequences of the same length. Shows the figure (plt.show blocks
    until the window is closed) and returns the matplotlib Figure.
    """
    # imported here so that matplotlib is needed only when a plot is requested
    # (to install this lib in deb-type linux, run as root: apt-get install python3-matplotlib)
    from matplotlib import pyplot as plt

    # plt.subplots (with an "s") returns (figure, array of axes); the former call
    # plt.subplot(2, 1, 1) returns a single Axes, so unpacking it always failed
    fig, (ax_x, ax_v) = plt.subplots(2, 1)

    # subfigure 1: centre-of-mass position
    ax_x.plot(tlist, xcmlist, 'b-o', label='$x_{cm}$')
    ax_x.set_xlabel('$t$')
    ax_x.set_ylabel('$x_{cm}$')
    ax_x.set_title(title)
    ax_x.legend(loc='upper left')

    # subfigure 2: centre-of-mass velocity
    ax_v.plot(tlist, vcmlist, 'b-o', label='$v_{cm}$')
    ax_v.set_xlabel('$t$')
    ax_v.set_ylabel('$v_{cm}$')
    ax_v.legend(loc='lower right')

    plt.show()  # draw plot, stop & wait
    return fig



def derivs(t, y):
    """Right-hand side of the FK equations of motion, written as a first-order system.

    t is the time (not used: the model is autonomous, but the integrator passes it);
    y is a flat sequence whose first half holds the particle coordinates and whose
    second half holds the corresponding velocities.
    Returns dy/dt as a list of the same length: the velocities, followed by the
    accelerations (unit mass). This function contains all the physics of the model.

    Reads the module-level parameters force, f_at_end, force_amplitude, twopi,
    spring_const, spacing, gamma, BC (1 -> PBC, 0 -> OBC) and lengthM1.
    """
    yarr = np.asarray(y, dtype=float)
    nhalf = len(yarr) // 2
    pos = yarr[:nhalf]
    vel = yarr[nhalf:]

    # spring term of each particle, before multiplication by spring_const
    spring = np.zeros(nhalf)
    if nhalf > 1:
        # interior particles: the rest length cancels between left and right springs
        spring[1:-1] = pos[2:] + pos[:-2] - 2 * pos[1:-1]
        # end particles with boundary conditions: BC==1 -> PBC; BC==0 -> OBC.
        # With PBC the first/last particle is linked to the periodic image of the
        # other end, which is shifted by the chain length lengthM1 + spacing.
        spring[0] = pos[1] - pos[0] - spacing + BC * (pos[-1] - pos[0] - lengthM1)
        spring[-1] = BC * (pos[0] - pos[-1] + lengthM1) + pos[-2] - pos[-1] + spacing
    # a lone particle has no neighbour (its PBC image spring sits at rest length),
    # so it gets no spring force; formerly y[1] (a velocity) was read as a position

    acc = force - force_amplitude * np.sin(twopi * pos) + spring_const * spring - gamma * vel
    if nhalf > 0:
        acc[0] += f_at_end   # extra force applied only to the leftmost particle

    return np.concatenate((vel, acc)).tolist()

def vcmcompute(tlist,xcmlist):
    """Average centre-of-mass speed over the second half of a run.

    Estimated as Delta x_cm / Delta t between the middle sample and the last
    sample (indices taken from len(tlist)). This finite difference is safe and
    reliable, unlike averaging the instantaneous v_cm samples.
    """
    last = len(tlist) - 1
    middle = len(tlist) // 2
    displacement = xcmlist[last] - xcmlist[middle]
    elapsed = tlist[last] - tlist[middle]
    return displacement / elapsed


def fk_one_simulation(BCstr,npart,timefin,time_step,spacing,spring_const,\
                     gamma,pot_amplitude,forc,forc_at_end):
    """integration of the standard driven Frenkel-Kontorova model

    Starts from npart particles at rest, equally spaced by `spacing`, and
    integrates with r8_rkf45 from t=0 up to about timefin, sampling every
    time_step. The positions/velocities at each sample are optionally written to
    "output_<label>.dat".

    forc is the uniform driving force, forc_at_end the extra force on the
    leftmost particle. BCstr, spring_const, gamma and pot_amplitude enter only
    the label/file name; spacing also sets the initial positions.
    NOTE: only `force` and `f_at_end` are pushed to module globals here; the
    other parameters read by derivs() (spring_const, spacing, gamma, BC,
    lengthM1, force_amplitude, twopi) are the module globals set by main_fk.
    The output file is written unless the module-level args.writeotputf is
    false.

    Returns (endvcm, tlist, xcmlist, vcmlist): endvcm is the centre-of-mass
    speed Delta x_cm / Delta t over the last few samples; the lists hold the
    sampling times and the corresponding x_cm and v_cm.
    Raises ValueError when timefin is too short for even one sampling step.
    """
    global force, f_at_end
    force=forc
    f_at_end=forc_at_end
    label=BCstr+'__n_'+str(npart)+'__a_'+str(spacing)+'__k_'+str(spring_const)+'__gam_'+str(gamma)+'__U0_'+str(pot_amplitude)+'__F_'+str(force)+'__FaE_'+str(f_at_end)
    print("# starting a FK calculation for")
    print("#\t"+label)
    filen="output_"+label+".dat"
    # `args` is a module global set when run as a script; when imported without it,
    # fall back to the command-line default (write the file)
    write_output = getattr(globals().get("args"), "writeotputf", True)

    nstep=int(round(1.*timefin/time_step))
    if nstep < 1:
        raise ValueError("timefin="+str(timefin)+" is shorter than half a time_step="
                         +str(time_step)+": nothing to integrate")
#   end of setting up stuff

#   putting equally spaced particles as initial condition:
    x0=(spacing*np.arange(npart)).tolist()
    v0=[0.]*npart	# all 0 speeds at start
    y=x0+v0
    if debug:
        print("initial condition:",x0,v0,y)
    neq=len(y)
    yp=derivs(0.,y)

    if debug:
        print("#starting point:",y)
        print("#starting deriv:",yp)

    relerr=1.e-6
    abserr=1.e-10
    # NOTE(review): presumably 1 = "first call" for r8_rkf45; the loop below treats
    # any return other than 2 as anomalous -- confirm against rkf45.py
    flag=1

    xcm_start=sum(x0)/npart   # centre of mass at t=0: reference point for short runs
    tlist=[]
    xcmlist=[]
    vcmlist=[]
    outputf = open(filen, 'w') if write_output else None
    try:   # make sure the output file is closed even if the integration fails
        if outputf is not None:
            print(0, "".join(" "+str(i) for i in y), file=outputf)
        for it in range(nstep):
            ti=it*time_step
            tf=(it+1)*time_step
            tlist.append(tf)
# in a future version of scipy one may rather use:      scipy.integrate.RK45
            y, yp, t, flag = r8_rkf45( derivs, neq, y, yp, ti, tf, relerr, abserr, flag )
            if flag!=2:
                print("Warning! flag =",flag,".... trying to keep on going")
                flag=2
            if outputf is not None:
                print(tf, "".join(" "+str(i) for i in y), file=outputf)
            xcm=sum(y[:npart])/npart
            vcm=sum(y[npart:2*npart])/npart
            print("integrated from",ti,"to",tf,"xcm=",xcm,"vcm=",vcm)
            xcmlist.append(xcm)
            vcmlist.append(vcm)
    finally:
        if outputf is not None:
            outputf.close()

    # number of final samples over which the end speed is measured: fewer for
    # short runs (test the smallest threshold first, so every branch is reachable)
    if nstep<5:
        nback=2
    elif nstep<10:
        nback=3
    elif nstep<100:
        nback=5
    else:
        nback=10
    nback=min(nback,nstep)
    # prepend the t=0 state so that indices never run before the first sample
    times=[0.]+tlist
    xcms=[xcm_start]+xcmlist
    endvcm=(xcms[-1]-xcms[-1-nback])/(times[-1]-times[-1-nback])

    print("#end of the integration, final vCM=",endvcm)
    return endvcm,tlist,xcmlist,vcmlist

def bisection(f,otherarguments,a,b,tol):
    """bisect to find 0 of function f in interval [a,b], with tolerance tol
f(x,otherarguments) expects 2 variables: a float x, plus
           a second variable collecting whatever other arguments f may need

    Returns a (or b) if f vanishes exactly there, a point where f is exactly 0
    if one is hit, otherwise the midpoint of the final bracket once the bracket
    is narrower than tol or cannot be split further in floating point.
    The interval endpoints may be given in either order.
    Raises ValueError if f has the same nonzero sign at a and at b.
    """
    fa=f(a,otherarguments)
    fb=f(b,otherarguments)
    if fa==0.:
        return a
    if fb==0.:
        return b
    # compare signs directly: the product fa*fb can underflow to 0 for tiny values
    if (fa > 0) == (fb > 0):
        raise ValueError("bisection error: function has same sign at "+str(a)+": "
                         +str(fa)+" and at "+str(b)+": "+str(fb))

    c = (a+b)/2.0
    while True:
        fc=f(c,otherarguments)
        if fc == 0:     # got there!
            break
        elif (fc > 0) == (fa > 0): # same sign as f @ a: replace & get closer to solution
            a = c
            fa=fc
        else:           # same sign as f @ b: replace & get closer to solution
            b = c
        c = (a+b)/2.0   # bisect
        # abs() allows a > b; the c==a/c==b test stops when the bracket can no
        # longer be split in floating point (otherwise tol <= 0 would loop forever)
        if abs(b-a) < tol or c == a or c == b:
            break
    return c

def zerofunc(x,otherarguments):
    '''Target of the bisection: final CM speed at driving force x, minus the threshold.

    otherarguments packs, in this order:
    [BCstr, npart, timefin, time_step, spacing, spring_const, gamma,
     pot_amplitude, f_at_end, thresholdspeed].
    Runs one full simulation per call.
    '''
    (bc_label, n_particles, t_final, dt, lattice_a, k_spring,
     damping, u0, end_force, v_threshold) = otherarguments
    final_speed, _, _, _ = fk_one_simulation(bc_label, n_particles, t_final, dt,
                                             lattice_a, k_spring, damping, u0,
                                             x, end_force)
    return final_speed - v_threshold

def main_fk(args):
    '''driver for the FK calculations. Can do one of:
- a single FK run;
- a sequence of FK runs with an incremental sequence of driving forces;
- a bisection search of the depinning force + a single FK run @ the final threshold force.

args is the parsed command-line namespace built in the __main__ section.
The model parameters are copied into module globals, which derivs() and
fk_one_simulation() read. If args.deltaf is nonzero, the force sweep from
args.fin to args.ffin runs even after a depinning-force search.
If plotting is enabled, the last run is plotted.
'''
    global debug
    global BC,gamma,spring_const,force_amplitude,lengthM1,spacing
    global pi,twopi
    # fk_one_simulation() reads the module-level `args` (args.writeotputf): publish it
    # here, so that main_fk also works when called from another module
    globals()['args'] = args
    debug = args.debug
    pi=np.pi
    twopi=2*pi

    spacing=args.spacing
    BC=args.BC
    force=args.force
    fin=args.fin
    ffin=args.ffin
    deltaf=args.deltaf
    f_at_end=args.f_at_end
    gamma=args.gamma
    npart=args.npart
    spring_const=args.spring_const
    timefin=args.timefin
    time_step=args.time_step
    pot_amplitude=args.pot_amplitude

    force_amplitude=pot_amplitude*pi # derivative of Vext(x)=-U0/2 cos(2 pi x)
    lengthM1=spacing*(npart-1)  # used in PBC
    if BC==1:
        BCstr="PBC"
    else:
        BCstr="OBC"
    print("# dimensionless ratio k/ (2 U0 (pi/a_substrate)^2) =",\
               spring_const/(2*pot_amplitude*pi*pi))

    if args.finddepforce:
        # the initial bracket for bisection: [0, args.force]
        force_min=0
        force_max=force
        otherarguments=[BCstr,npart,timefin,time_step,spacing,spring_const,gamma,pot_amplitude,f_at_end,args.thresholdspeed]
        force_dep=bisection(zerofunc,otherarguments,\
                             force_min,force_max,args.tol_force)
        print("# depinning force =",force_dep)
        force=force_dep
    if deltaf==0: # execute a single calculation:
        endvcm,tlist,xcmlist,vcmlist = fk_one_simulation(BCstr,npart,timefin,\
          time_step,spacing,spring_const,gamma,pot_amplitude,force,f_at_end)
        print("#F= ",force," <vcm>=",vcmcompute(tlist,xcmlist))
    else:	# a sequence of calculations for changing force:
        for force in np.linspace(fin,ffin,int(round((ffin-fin)/deltaf+1))):
            endvcm,tlist,xcmlist,vcmlist = fk_one_simulation(BCstr,npart,timefin,\
              time_step,spacing,spring_const,gamma,pot_amplitude,force,f_at_end)
            print("F= ",force," <vcm>=",vcmcompute(tlist,xcmlist),"\n")

    sys.stdout.flush()

# This graphics part in the next lines here is entirely optional.
# If you are unwilling to install matplotlib just run with the --np option
    if args.doplot:
        # raw string for the LaTeX part: '\g' is an invalid escape sequence in a normal literal
        title=('FK in '+BCstr+', $N='+str(npart)+'$, $a='+str(spacing)
               +'$, $k = '+str(spring_const)+r'$, $\gamma='+str(gamma)
               +'$, $F='+str(force)+'$')
        makeplots(tlist,xcmlist,vcmlist,title)



# the following block is only executed when the code is run as a script: its purpose is to parse
# the command line and to pass the resulting parsed args to the function doing the actual job:
if __name__ == "__main__":
    import argparse

    desc="""simulate the FK model using accurate RKF integration.
OUTPUT: an array of time + positions of all particles
"""

    epil="""\t\t\tv. 1.3 by Nicola Manini, 03/05/2020"""

    ##  Argument Parser definition
    parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter
                                      , description=desc, epilog=epil)

    parser.add_argument( '-d', '--debug', action='store_true',
                         dest='debug',
                         help='activate debug mode -- WARNING! output is affected/spoiled!' )

    parser.add_argument( '-a',
                         dest='spacing', type=float, default=1.0,
                         help='lattice spacing of particles (units of sine period)')
    parser.add_argument( '-b',
                         dest='BC', type=int, default=0,
                         help='boundary condition: 0=OBC; 1=PBC')
    parser.add_argument( '-f',
                         dest='force', type=float, default=0.0,
                         help='external driving force')
    parser.add_argument( '--Fin',
                         dest='fin', type=float, default=0.0,
                         help='initial driving force')
    parser.add_argument( '--Ffin',
                         dest='ffin', type=float, default=0.0,
                         help='final driving force')
    parser.add_argument( '--Fde',
                         dest='deltaf', type=float, default=0.0,
                         help='driving force increment')
    # the option string must not contain a space: the old '--f-at-end force' was only
    # reachable through argparse's prefix abbreviation; 'force' is the value's metavar
    parser.add_argument( '--f-at-end', metavar='FORCE',
                         dest='f_at_end', type=float, default=0.0,
                         help='external force applied only to the leftmost particle')
    parser.add_argument( '--find-depinning-force',
                         dest='finddepforce', action='store_true',
                         help='execute a bisection calculation of the depinning force')
    parser.add_argument( '--tol_force',
                         dest='tol_force', type=float, default=1.e-3,
                         help='tolerance in the determination of the depinning force')
    parser.add_argument( '--thresholdspeed',
                         dest='thresholdspeed', type=float, default=1.e-6,
                         help='small speed threshold for computing the depinning force: below this speed the chain is considered statically pinned')
    parser.add_argument( '-g',
                         dest='gamma', type=float, default=0.1,
                         help='viscous damping coefficient')
    parser.add_argument( '-n',
                         dest='npart', type=int, default=10,
                         help='the number of particles')
    parser.add_argument( '--nof',
                         dest='writeotputf', action='store_false',
                         help='generate no output file(s)')
    parser.add_argument( '--np',
                         dest='doplot', action='store_false',
                         help='generate no plot at the end')
    parser.add_argument( '-k',
                         dest='spring_const', type=float, default=1.0,
                         help='the value of the spring constant')
    parser.add_argument( '-t',
                         dest='timefin', type=float, default=20.0,
                         help='the total simulation time')
    parser.add_argument( '--dt',
                         dest='time_step', type=float, default=2.0,
                         help='the time interval, used only for writing')
    parser.add_argument( '-U',
                         dest='pot_amplitude', type=float, default=1.0,
                         help='peak-valley potential amplitude U0 in -(U0/2)*cos(2 pi x)')

    ## End arg parser definition
    args=parser.parse_args(sys.argv[1:])
    # argparse.Namespace holds only the parsed options; the program name lives on
    # the parser, so attach it for code that wants args.prog
    args.prog=parser.prog

#   here the actual function doing the job is called:
    stored=main_fk(args)
# and here some optional print out could be added:
