import numpy as np
import math 
from matplotlib import pyplot as plt

### Implement the following methods for solving \ddot{x} = a(x, v), where v = \dot{x}:
### a(x, v) is the acceleration as a function of the position and velocity,
### x is the array of the positions,
### v is the array of the velocities,
### dt is the timestep.

def euler(a, x, v, dt):
    """Integrate ``x'' = a(x, x')`` in place with the explicit Euler method.

    ``x[0]`` and ``v[0]`` must hold the initial position and velocity.
    Entries ``1 .. len(x) - 1`` of both sequences are overwritten with the
    state after each successive step of size ``dt``.

    Args:
        a: Callable ``a(x_i, v_i)`` returning the acceleration for a state.
        x: Mutable sequence of positions (one entry per time step).
        v: Mutable sequence of velocities, same length as ``x``.
        dt: Timestep.

    Returns:
        ``[title, suffix]``: a human-readable method name and a short tag.
    """
    assert len(x) == len(v)
    for i in range(len(x) - 1):
        # Both updates use the state at step i only; x[i + 1] is assigned
        # first, but the velocity update below never reads it.
        x[i + 1] = x[i] + dt * v[i]
        v[i + 1] = v[i] + dt * a(x[i], v[i])
    return ["Euler method","eul"]


def leapfrog(a, x, v, dt):
    """Integrate ``x'' = f(x)`` in place with the leapfrog method.

    The scheme is written in its kick-drift-kick form: a half-step velocity
    ``v_{i+1/2}`` advances the position, and the stored ``v[i + 1]`` is the
    half-step velocity kicked by another ``dt / 2``. This equals
    ``(v_{i+1/2} + v_{i+3/2}) / 2``, so ``x[i]`` and ``v[i]`` refer to the
    same instant and can be used together to compute the energy.

    ``x[0]`` and ``v[0]`` must hold the initial position and velocity.

    Args:
        a: Callable ``a(x_i, v_i)``; it is always called with velocity 0
            because leapfrog only handles accelerations of the form f(x).
        x: Mutable sequence of positions (one entry per time step).
        v: Mutable sequence of velocities, same length as ``x``.
        dt: Timestep.

    Returns:
        ``[title, suffix]``: a human-readable method name and a short tag.
    """
    assert len(x) == len(v)

    def a_(pos):
        # leapfrog works only for equations of the form \ddot x = f(x)
        return a(pos, 0)

    n_steps = len(x) - 1
    if n_steps > 0:
        acc = a_(x[0])
        for i in range(n_steps):
            v_half = v[i] + 0.5 * dt * acc          # kick to t_{i+1/2}
            x[i + 1] = x[i] + dt * v_half           # drift to t_{i+1}
            acc = a_(x[i + 1])                      # reused by the next step
            v[i + 1] = v_half + 0.5 * dt * acc      # kick to t_{i+1}
    return ["Leapfrog method","lf"]


def runge_kutta_4(a, x, v, dt):
    """Integrate ``x'' = a(x, x')`` in place with the classical RK4 method.

    The second-order equation is treated as the first-order system
    ``x' = v``, ``v' = a(x, v)``. ``x[0]`` and ``v[0]`` must hold the initial
    position and velocity; the remaining entries are overwritten.

    Args:
        a: Callable ``a(x_i, v_i)`` returning the acceleration for a state.
        x: Mutable sequence of positions (one entry per time step).
        v: Mutable sequence of velocities, same length as ``x``.
        dt: Timestep.

    Returns:
        ``[title, suffix]``: a human-readable method name and a short tag.
    """
    assert len(x) == len(v)
    half_dt = 0.5 * dt
    for i in range(len(x) - 1):
        x_i, v_i = x[i], v[i]

        # Each stage has a position slope (kNx, the stage velocity) and a
        # velocity slope (kNv, the acceleration at the stage state).
        k1x = v_i
        k1v = a(x_i, v_i)

        k2x = v_i + half_dt * k1v
        k2v = a(x_i + half_dt * k1x, k2x)

        k3x = v_i + half_dt * k2v
        k3v = a(x_i + half_dt * k2x, k3x)

        k4x = v_i + dt * k3v
        k4v = a(x_i + dt * k3x, k4x)

        x[i + 1] = x_i + dt / 6. * (k1x + 2. * k2x + 2. * k3x + k4x)
        v[i + 1] = v_i + dt / 6. * (k1v + 2. * k2v + 2. * k3v + k4v)
    return ["Runge-Kutta 4th order","rk4"]


def solve_simple_oscillator(solver):
    """Solve the harmonic oscillator ``x'' = -w^2 x`` and plot the result.

    The numerical trajectory produced by ``solver`` is compared with the
    exact solution, and the numerical energy per unit mass with its exact
    (conserved) value. The figure is saved as
    ``simple_oscillator_<suffix>.pdf`` in the current directory.

    Args:
        solver: Callable ``solver(a, x, v, dt)`` that fills ``x`` and ``v``
            in place starting from ``x[0]``, ``v[0]`` and returns
            ``[title, suffix]``.
    """
    ### Parameters
    Tmax = 10.
    dt = .1
    x0 = 1.
    v0 = 1.
    w = 2. #alpha/sqrt(m)

    # equation of motion
    def eq_mot(x, v):
        return -w**2 * x

    # energy divided by the mass
    def energy(x, v):
        return 0.5 * (v**2 + w**2 * x**2)

    # exact position
    def x_ex(x0, v0, t):
        return x0 * np.cos(w * t) + v0 / w * np.sin(w * t)

    # exact velocity
    def v_ex(x0, v0, t):
        return v0 * np.cos(w * t) - x0 * w * np.sin(w * t)

    t = np.arange(0, Tmax, dt)
    x = np.zeros(len(t))
    v = np.zeros(len(t))

    x[0] = x0
    v[0] = v0

    ### solve eq of motion
    [title, suffix] = solver(eq_mot, x, v, dt)

    ### compute energy
    E = energy(x, v)
    E_ex = energy(x0, v0)

    ### make plots
    # No plt.title() before the first subplot: it would create an extra
    # full-figure Axes behind the two panels. The title is set by suptitle.
    plt.figure()

    ax = plt.subplot(211)

    plt.plot(t, x_ex(x0, v0, t), '-', c='g', label = "$x_{ex}$")
    plt.plot(t, x, '-', c='b', label = "$x_{num}$")
    plt.plot(t, v_ex(x0, v0, t), '-', c='orange', label = "$\\dot{x}_{ex}$")
    plt.plot(t, v, '-', c='r', label = "$\\dot{x}_{num}$")

    # extend the lower limit to leave room for the legend
    ylim = ax.get_ylim()
    plt.ylim(ylim[0]-(ylim[1]-ylim[0])/5., ylim[1])
    plt.legend(loc = "lower left", ncol=4)
    plt.title("Position and Velocity")

    ax = plt.subplot(212)

    plt.plot([0, Tmax-dt], [E_ex, E_ex], '-', c='g', label = "$E_{ex}$")
    plt.plot(t, E, '-', c='b', label = "$E_{num}$")

    ylim = ax.get_ylim()
    plt.ylim(ylim[0]-E_ex/5., ylim[1]+E_ex/5.)
    plt.legend(loc = "upper left", ncol=2)
    plt.title("Energy / $m$")

    plt.suptitle(title)
    plt.savefig("simple_oscillator_"+suffix+".pdf")
    #plt.show()


def solve_moving_pendulum(solver):
    """Solve a pendulum on a sliding support and plot the result.

    NOTE(review): the physical model is not stated in this file; the
    following is an assumption to confirm against the problem statement.
    A support of mass m_1 slides without friction along a horizontal rail
    (coordinate ``x``); a point mass m_2 hangs from it on a rigid massless
    rod of length ``l`` at angle ``phi`` from the downward vertical, and
    ``mu = m_1 / m_2``. With these conventions, per unit mass m_2:

        (mu + 1) x'' + l phi'' cos(phi) - l phi'^2 sin(phi) = 0
        l phi'' + x'' cos(phi) + g sin(phi) = 0
        E = (mu + 1) x'^2 / 2 + l x' phi' cos(phi) + l^2 phi'^2 / 2
            - g l cos(phi)

    The figure is saved as ``moving_pendulum_<suffix>.pdf`` in the current
    directory.

    Args:
        solver: Callable ``solver(a, x, v, dt)`` that fills ``x`` and ``v``
            in place starting from ``x[0]``, ``v[0]`` and returns
            ``[title, suffix]``. Each state is ``np.array([x, phi])`` and
            each velocity ``np.array([x', phi'])``.
    """
    ### Parameters
    Tmax = 10.
    dt = .05
    l = 2.
    g = 10.
    mu = 5.
    x0 = 0.
    v0 = 0.
    phi0 = 3.*math.pi/4.
    dot_phi0 = 0.

    # equation of motion: the two coupled equations solved for (x'', phi'')
    def eq_mot(x, v):
        phi, dot_phi = x[1], v[1]
        sin_phi, cos_phi = np.sin(phi), np.cos(phi)
        denom = mu + sin_phi**2
        ddot_x = sin_phi * (l * dot_phi**2 + g * cos_phi) / denom
        ddot_phi = -sin_phi * (l * dot_phi**2 * cos_phi + (mu + 1.) * g) / (l * denom)
        return np.array([ddot_x, ddot_phi])

    # energy divided by the mass m_2
    def energy(x, v):
        cos_phi = np.cos(x[1])
        kinetic = (0.5 * (mu + 1.) * v[0]**2
                   + l * v[0] * v[1] * cos_phi
                   + 0.5 * l**2 * v[1]**2)
        return kinetic - g * l * cos_phi

    t = np.arange(0, Tmax, dt)
    # One distinct float array per time step, so that a solver updating an
    # entry in place cannot modify the other time steps.
    x = [np.zeros(2) for _ in t]
    v = [np.zeros(2) for _ in t]

    x[0] = np.array([x0, phi0])
    v[0] = np.array([v0, dot_phi0])

    ### solve eqs of motion
    [title, suffix] = solver(eq_mot, x, v, dt)

    ### compute energy
    E = np.array([energy(x_, v_) for x_, v_ in zip(x, v)])
    E_ex = energy(x[0], v[0])

    ### make plots
    plt.figure()

    ax = plt.subplot(211)

    plt.plot(t, [x_[0] for x_ in x], '-', c='b', label = "$x$")
    plt.plot(t, [v_[0] for v_ in v], '-', c='r', label = "$\\dot{x}$")
    plt.plot(t, [x_[1] for x_ in x], '-', c='g', label = "$\\phi$")
    plt.plot(t, [v_[1] for v_ in v], '-', c='orange', label = "$\\dot{\\phi}$")

    # extend the lower limit to leave room for the legend
    ylim = ax.get_ylim()

    plt.ylim(ylim[0]-(ylim[1]-ylim[0])/5., ylim[1])
    plt.legend(loc = "lower left", ncol=4)
    plt.title("Positions and Velocities")

    ax = plt.subplot(212)

    plt.plot([0, Tmax-dt], [E_ex, E_ex], '-', c='g', label = "$E_{ex}$")
    plt.plot(t, E, '-', c='b', label = "$E_{num}$")

    ylim = ax.get_ylim()
    plt.ylim(ylim[0]-E_ex/5., ylim[1]+E_ex/5.)
    plt.legend(loc = "upper left", ncol=2)
    plt.title("Energy / $m_2$")

    plt.suptitle(title)
    plt.savefig("moving_pendulum_"+suffix+".pdf")
    #plt.show()


### main program: run each solver on the systems it is used for
for solver in (euler, leapfrog, runge_kutta_4):
    solve_simple_oscillator(solver)

# leapfrog is not run on the pendulum (presumably because it only handles
# accelerations that do not depend on the velocity)
for solver in (euler, runge_kutta_4):
    solve_moving_pendulum(solver)
