import numpy as np
import math 
from matplotlib import pyplot as plt

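### Solver interface shared by euler, leapfrog and runge_kutta_4:
### a(x, v) returns the acceleration for position x and velocity v,
### x is the array of positions and v the array of velocities; x[0] and v[0]
###   must hold the initial conditions, the solver fills the remaining entries in place,
### dt is the time step.
### Each solver returns [title, file_suffix] for labelling plots and output files.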

def euler(a, x, v, dt):
    """Integrate x'' = a(x, x') with the explicit (forward) Euler scheme.

    ``x`` and ``v`` are filled in place: ``x[0]`` and ``v[0]`` must already
    hold the initial conditions, and each later entry ``k`` is computed from
    entry ``k - 1``. Entries may be scalars or numpy arrays.

    Returns a ``[title, file_suffix]`` pair used for plot labelling.
    """
    assert len(x) == len(v)
    for k in range(1, len(x)):
        x_prev, v_prev = x[k - 1], v[k - 1]
        # Both updates use only the previous step's values (explicit scheme);
        # in particular the position advances with the *old* velocity.
        v[k] = v_prev + a(x_prev, v_prev) * dt
        x[k] = x_prev + v_prev * dt
    return ["Euler method", "eul"]


def leapfrog(a, x, v, dt):
    """Integrate x'' = f(x) with the leapfrog (kick-drift-kick) scheme.

    Leapfrog is only valid for equations whose acceleration does not depend
    on the velocity: the callback ``a`` is always evaluated as ``a(x, 0)``.

    ``x`` and ``v`` are filled in place; ``x[0]`` and ``v[0]`` must already
    hold the initial conditions. Entries may be scalars (e.g. a float numpy
    array) or vectors (e.g. a list of numpy arrays), like the other solvers.

    Returns a ``[title, file_suffix]`` pair used for plot labelling.
    """
    assert len(x) == len(v)
    if len(x) == 0:
        # Nothing to integrate; the other solvers also accept empty input.
        return ["Leapfrog method", "lf"]
    ### leapfrog works only for equations of the form \ddot x = f(x)
    a_ = lambda pos: a(pos, 0)
    # Velocity at the current half step, v_{i+1/2}. Only the latest value is
    # ever read, so a running variable replaces the former float-only
    # np.zeros buffer (which could not hold vector-valued velocities).
    v_half = v[0] + a_(x[0]) * dt / 2.
    for i in range(0, len(x) - 1):
        x[i + 1] = x[i] + v_half * dt
        v_half_next = v_half + a_(x[i + 1]) * dt
        # Whole-step velocity as the mean of the neighbouring half steps;
        # only needed for plots/energy, the scheme itself never reads v.
        v[i + 1] = (v_half + v_half_next) / 2.
        v_half = v_half_next
    return ["Leapfrog method", "lf"]


def runge_kutta_4(a, x, v, dt):
    """Integrate x'' = a(x, x') with the classical 4th-order Runge-Kutta scheme.

    The second-order equation is treated as the first-order system
    y = (x, v), y' = (v, a(x, v)). ``x`` and ``v`` are filled in place;
    ``x[0]`` and ``v[0]`` must already hold the initial conditions.

    Returns a ``[title, file_suffix]`` pair used for plot labelling.
    """
    assert len(x) == len(v)

    def rhs(state):
        # state[0] is the position, state[1] the velocity
        pos, vel = state[0], state[1]
        return np.array([vel, a(pos, vel)])

    for n in range(1, len(x)):
        y = np.array([x[n - 1], v[n - 1]])
        s1 = dt * rhs(y)
        s2 = dt * rhs(y + s1 / 2.)
        s3 = dt * rhs(y + s2 / 2.)
        s4 = dt * rhs(y + s3)
        y = y + (s1 + 2. * (s2 + s3) + s4) / 6.
        x[n], v[n] = y[0], y[1]
    return ["Runge-Kutta 4th order", "rk4"]


def solve_simple_oscillator(solver, *, Tmax=10., dt=.1, x0=1., v0=1., w=2.):
    """Solve the harmonic oscillator x'' = -w**2 x with *solver* and save a PDF plot.

    The numerical position, velocity and energy are plotted against the exact
    solution and written to ``simple_oscillator_<suffix>.pdf``, where
    ``suffix`` is the second item returned by *solver*.

    Parameters
    ----------
    solver : callable
        ``solver(a, x, v, dt)`` fills ``x`` and ``v`` in place and returns
        ``[title, suffix]`` (e.g. ``euler``, ``leapfrog``, ``runge_kutta_4``).
    Tmax : float, keyword-only
        End of the simulated interval (exclusive, as in ``np.arange``).
    dt : float, keyword-only
        Time step.
    x0, v0 : float, keyword-only
        Initial position and velocity.
    w : float, keyword-only
        Angular frequency (alpha / sqrt(m)).
    """
    ### equation of motion
    eq_mot = lambda x, v: -w**2 * x
    ### energy divided by the mass
    energy = lambda x, v: .5 * v**2 + w**2 / 2. * x**2
    ### exact position
    x_ex = lambda x0, v0, t: x0 * np.cos(w * t) + v0 / w * np.sin(w * t)
    ### exact velocity
    v_ex = lambda x0, v0, t: -w * x0 * np.sin(w * t) + v0 * np.cos(w * t)

    t = np.arange(0, Tmax, dt)
    x = np.zeros(len(t))
    v = np.zeros(len(t))

    x[0] = x0
    v[0] = v0

    ### solve eq of motion
    [title, suffix] = solver(eq_mot, x, v, dt)

    ### compute energy: per-step numerical value and the conserved exact value
    E = energy(x, v)
    E_ex = energy(x0, v0)

    ### make plots
    # No plt.title() before the subplots: on an empty figure it creates a
    # full-size axes that plt.subplot() no longer removes (Matplotlib >= 3.6),
    # leaving a stray frame behind both panels. The method name is shown via
    # plt.suptitle() below instead.
    plt.figure()

    ax = plt.subplot(211)

    plt.plot(t, x_ex(x0, v0, t), '-', c='g', label="$x_{ex}$")
    plt.plot(t, x, '-', c='b', label="$x_{num}$")
    plt.plot(t, v_ex(x0, v0, t), '-', c='orange', label="$\\dot{x}_{ex}$")
    plt.plot(t, v, '-', c='r', label="$\\dot{x}_{num}$")

    # extend the lower y-limit to leave room for the legend below the curves
    ylim = ax.get_ylim()
    plt.ylim(ylim[0] - (ylim[1] - ylim[0]) / 5., ylim[1])
    plt.legend(loc="lower left", ncol=4)
    plt.title("Position and Velocity")

    ax = plt.subplot(212)

    # reference line spans exactly the sampled time range
    plt.plot([t[0], t[-1]], [E_ex, E_ex], '-', c='g', label="$E_{ex}$")
    plt.plot(t, E, '-', c='b', label="$E_{num}$")

    ylim = ax.get_ylim()
    plt.ylim(ylim[0] - E_ex / 5., ylim[1] + E_ex / 5.)
    plt.legend(loc="upper left", ncol=2)
    plt.title("Energy / $m$")

    plt.suptitle(title)
    plt.savefig("simple_oscillator_" + suffix + ".pdf")
    #plt.show()


def animate_moving_pendulum(x, l, dt, title, filename):
    """Animate a pendulum hanging from a horizontally moving pivot and save it.

    Parameters
    ----------
    x : sequence of length-2 arrays
        Coordinates per time step: ``x[i][0]`` is the pivot's horizontal
        position (the pivot moves along y = 0) and ``x[i][1]`` the angle phi
        measured from the downward vertical.
    l : float
        Pendulum length.
    dt : float
        Time step between frames; sets the playback interval in ms.
    title : str
        Plot title.
    filename : str
        Output path without extension; ``.mp4`` and ``.webm`` are appended.
    """
    from matplotlib import animation
    fig = plt.figure()
    # pivot positions
    x1 = np.array([x_[0] for x_ in x])
    y1 = np.zeros(len(x1))
    phi = np.array([x_[1] for x_ in x])

    # bob positions
    x2 = x1 + l * np.sin(phi)
    y2 = -l * np.cos(phi)

    xmin = min(min(x1), min(x2)) - 1
    xmax = max(max(x1), max(x2)) + 1

    plt.plot([xmin, xmax], [0, 0], '-', c='k')
    line, = plt.plot([], [], 'o-', c='r', lw=2)
    plt.title(title)
    plt.ylim(-1.1 * l, 1.1 * l)
    plt.xlim(xmin, xmax)
    plt.gca().set_aspect('equal', adjustable='box')

    ### init function
    def init():
        line.set_data([], [])
        # with blit=True FuncAnimation needs an iterable of artists,
        # not a bare Line2D
        return (line,)

    ### animation function
    def frame(i):
        line.set_data([x1[i], x2[i]], [y1[i], y2[i]])
        return (line,)

    anim = animation.FuncAnimation(fig, frame, init_func=init, frames=len(x1),
                                   interval=dt * 1000, blit=True)

    ### Saving needs a video encoder with the requested codec installed.
    ### Both formats are written; drop a line if its codec is unavailable.
    anim.save(filename + '.mp4', extra_args=['-vcodec', 'libx264'])
    anim.save(filename + '.webm', extra_args=['-vcodec', 'libvpx'])
    # the animation is on disk; release the figure
    plt.close(fig)


def solve_moving_pendulum(solver, *, animate=True):
    """Simulate a pendulum on a horizontally moving pivot with *solver*.

    Plots positions, velocities and energy to ``moving_pendulum_<suffix>.pdf``
    and, when *animate* is true, writes an animation via
    ``animate_moving_pendulum`` (this requires video codecs; pass
    ``animate=False`` if they are not installed).

    Parameters
    ----------
    solver : callable
        ``solver(a, x, v, dt)`` fills ``x`` and ``v`` in place and returns
        ``[title, suffix]``. The acceleration here depends on the velocities,
        so ``leapfrog`` is not applicable.
    animate : bool, keyword-only
        Whether to also render and save the animation (default True).
    """
    ### Parameters
    Tmax = 10.
    dt = .05
    l = 2.           # pendulum length
    g = 10.          # gravitational acceleration
    # NOTE(review): presumably (m_1 + m_2) / m_2 given the mu * xdot**2 kinetic term
    mu = 5.
    x0 = 0.
    v0 = 0.
    phi0 = 3. * math.pi / 4.
    dot_phi0 = 0.

    # equation of motion: x = (pivot position, phi), v = their time
    # derivatives; returns the two accelerations
    def eq_mot(x, v):
        s, c = np.sin(x[1]), np.cos(x[1])
        denom = mu - c**2
        return np.array([
            s * (l * v[1]**2 + g * c) / denom,
            -s * (l * v[1]**2 * c + mu * g) / l / denom,
        ])

    # energy divided by the mass m_2
    energy = lambda x, v: (.5 * (mu * v[0]**2 + l**2 * v[1]**2
                                 + 2 * l * v[0] * v[1] * np.cos(x[1]))
                           + g * l * (1 - np.cos(x[1])))

    t = np.arange(0, Tmax, dt)
    # one independent float array per step, instead of a single shared
    # integer array repeated len(t) times
    x = [np.zeros(2) for _ in t]
    v = [np.zeros(2) for _ in t]

    x[0] = np.array([x0, phi0])
    v[0] = np.array([v0, dot_phi0])

    ### solve eqs of motion
    [title, suffix] = solver(eq_mot, x, v, dt)

    ### compute energy: per-step numerical value and the conserved initial value
    E = np.array([energy(x_, v_) for x_, v_ in zip(x, v)])
    E_ex = energy(x[0], v[0])

    ### make plots
    plt.figure()

    ax = plt.subplot(211)

    plt.plot(t, [x_[0] for x_ in x], '-', c='b', label="$x$")
    plt.plot(t, [v_[0] for v_ in v], '-', c='r', label="$\\dot{x}$")
    plt.plot(t, [x_[1] for x_ in x], '-', c='g', label="$\\phi$")
    plt.plot(t, [v_[1] for v_ in v], '-', c='orange', label="$\\dot{\\phi}$")

    # extend the lower y-limit to leave room for the legend below the curves
    ylim = ax.get_ylim()
    plt.ylim(ylim[0] - (ylim[1] - ylim[0]) / 5., ylim[1])
    plt.legend(loc="lower left", ncol=4)
    plt.title("Positions and Velocities")

    ax = plt.subplot(212)

    # reference line spans exactly the sampled time range
    plt.plot([t[0], t[-1]], [E_ex, E_ex], '-', c='g', label="$E_{ex}$")
    plt.plot(t, E, '-', c='b', label="$E_{num}$")

    ylim = ax.get_ylim()
    plt.ylim(ylim[0] - E_ex / 5., ylim[1] + E_ex / 5.)
    plt.legend(loc="upper left", ncol=2)
    plt.title("Energy / $m_2$")

    plt.suptitle(title)
    plt.savefig("moving_pendulum_" + suffix + ".pdf")

    ### the animation requires video codecs to save to file
    if animate:
        animate_moving_pendulum(x, l, dt, title, "moving_pendulum_" + suffix)

    #plt.show()


### main program: run every applicable solver on both test problems and save
### the resulting plots; guarded so importing this module has no side effects
if __name__ == "__main__":
    solve_simple_oscillator(euler)

    solve_simple_oscillator(leapfrog)

    solve_simple_oscillator(runge_kutta_4)

    # leapfrog is not used here: the pendulum's acceleration depends on the
    # velocities, while leapfrog assumes x'' = f(x)
    solve_moving_pendulum(euler)

    solve_moving_pendulum(runge_kutta_4)
