# Run with `bumps --fit=dream --burn=1000 --samples=1e6 --init=eps --parallel --store=output fit_thin_film_bloch.py`

import pandas as pd
import numpy as np

from bumps.names import Parameter, FitProblem

from bumps.parameter import Parameter, varying

class BlochLawThinFilm:
    """Thin-film Bloch-law model of magnetization versus temperature and field.

    The magnetization is
    ``Ms * (1 + low_T * sum_m log(1 - exp(-E_m / (kB*T))))`` with
    ``E_m = hbar*gamma*B + D*(m*pi/Lz)**2`` for ``m = 0 .. int(Lz/a)``.

    Fit parameters are wrapped in bumps ``Parameter`` objects whose names are
    prefixed by *name*:
        Ms: magnetization scale, in kA/m (the model output is in kA/m too).
        D:  stiffness constant; presumably J*m^2 so that D*(m*pi/Lz)**2 is an
            energy in J -- TODO confirm.
        Lz: film thickness; presumably in m given the SI constants used.
        a:  spacing that sets how many thickness modes m are summed.
    """
    def __init__(self, Ms, D, Lz, a, name=""):
        self.Ms = Parameter(Ms, name=name+"Ms")
        self.D = Parameter(D, name=name+"D")
        self.Lz = Parameter(Lz, name=name+"Lz")
        self.a = Parameter(a, name=name+"a")

    def parameters(self):
        """Return the model parameters keyed by short name."""
        return dict(Ms=self.Ms, D=self.D, Lz=self.Lz, a=self.a)

    def __call__(self, x, y):
        """Evaluate the model magnetization.

        Args:
            x: temperature in K. Must be > 0; T = 0 divides by zero.
            y: applied field in T; sets the Zeeman gap hbar*gamma*y.

        Returns:
            Magnetization in kA/m, broadcast over *x* and *y*.
        """
        hbar = 1.054571817e-34  # J s (CODATA 2018)
        gamma = 1.760859e11     # rad s^-1 T^-1, free-electron gyromagnetic ratio
        kB = 1.380649e-23       # J/K, exact in the 2019 SI

        Ms = self.Ms.value
        D0 = self.D.value
        Lz = self.Lz.value
        a = self.a.value

        kT = kB*x
        # Zeeman gap in units of kT; the same for every thickness mode.
        zeeman = hbar*gamma*y/kT
        # Modes m = 0 .. int(Lz/a) are included in the sum.
        N = int(Lz/a + 1)
        thinfilm_sum = 0.0
        for m in range(N):
            exponent = zeeman + (D0/kT)*(m*np.pi/Lz)**2
            # log(1 - exp(-e)) evaluated as log(-expm1(-e)): avoids the
            # catastrophic cancellation in 1 - exp(-e) when e << 1
            # (small gap compared with kT).
            thinfilm_sum += np.log(-np.expm1(-exponent))
        # Ms*1000 converts kA/m -> A/m so the prefactor is dimensionless.
        low_T = ((hbar*gamma)/(4*np.pi*Lz)) * (kT/(1000*Ms*D0))

        # No conversion on the leading Ms: the result stays in kA/m, the
        # same unit as the measured data it is compared against.
        return Ms * (1 + low_T*thinfilm_sum)

def plot(X, Y, theory, data, err):
    """Draw the measured magnetization with error bars and the model curve.

    Args:
        X: temperatures in K (x axis).
        Y: accepted for interface compatibility; not used in the plot.
        theory: model values at X, in kA/m.
        data: measured values at X, in kA/m.
        err: error bars on *data*.
    """
    # pyplot drives the same current-figure state machine as pylab, without
    # the discouraged pylab namespace dump.
    import matplotlib.pyplot as plt

    plt.errorbar(X, data, err, fmt='o')
    plt.plot(X, theory, '-')
    plt.xlabel("T (K)")
    plt.ylabel("Ms (kA/m)")
    plt.xlim(0, 1.1*np.max(X))

class Curve2D:
    """bumps fitness object comparing ``model(X, Y)`` against measured data.

    Attributes:
        X, Y: independent variables passed to ``model(X, Y)``.
        data: measured values, same shape as the model output.
        err: uncertainties on *data*; zero entries are treated as 1 when
            forming residuals.
        model: callable with a ``parameters()`` method.
    """
    def __init__(self, model, X, Y, data, err):
        self.X, self.Y, self.data, self.err = X, Y, data, err
        self.model = model

    def numpoints(self):
        """Return the number of data points."""
        return np.prod(self.data.shape)

    def parameters(self):
        """Return the model's parameters."""
        return self.model.parameters()

    def theory(self):
        """Evaluate the model at (X, Y)."""
        return self.model(self.X, self.Y)

    def residuals(self):
        """Return the normalized residuals (theory - data) / err."""
        # err == 0 points are divided by 1 instead of 0.
        return (self.theory() - self.data)/(self.err+(self.err==0.))

    def nllf(self):
        """Return the negative log-likelihood 0.5 * sum(residuals**2)."""
        R = self.residuals()
        return 0.5*np.sum(R**2)

    def __call__(self):
        """Return the reduced chi-squared 2*nllf/dof.

        Uses ``self.dof`` if it has been assigned; otherwise dof is the number
        of data points minus the number of varying parameters, floored at 1
        so an over-parameterized fit does not divide by zero.
        """
        dof = getattr(self, "dof", None)
        if dof is None:
            n_varying = len(varying(self.parameters()))
            dof = max(self.numpoints() - n_varying, 1)
        return 2*self.nllf()/dof

    def plot(self, view='linear'):
        """Plot data and theory; *view* is accepted but not used."""
        plot(self.X, self.Y, self.theory(), self.data, self.err)

    def save(self, basename):
        """Write theory, data, inputs and varying parameters to basename.json."""
        import json
        pars = [(p.name, p.value) for p in varying(self.parameters())]
        out = json.dumps(dict(theory=self.theory().tolist(),
                              data=self.data.tolist(),
                              err=self.err.tolist(),
                              X=self.X.tolist(),
                              Y=self.Y.tolist(),
                              pars=pars))
        # Context manager guarantees the file is flushed and closed.
        with open(basename+".json", "w") as fh:
            fh.write(out)

    def update(self):
        """No-op: this object keeps no derived state that needs refreshing."""
        pass

def read_data(filename='CoB_mag.tsv'):
    """Load the magnetization measurements from a whitespace-delimited table.

    The table may contain ``#`` comments and must have the columns
    ``T_K``, ``Hc_T``, ``M_kA_per_m`` and ``M_err_kA_per_m``.

    Args:
        filename: path of the table; a relative path (including the
            default) is resolved against the current working directory.

    Returns:
        Tuple ``(X, Y, data, err)`` of NumPy arrays taken from those four
        columns, in that order.
    """
    df = pd.read_csv(filename, sep=r'\s+', comment='#')
    columns = ('T_K', 'Hc_T', 'M_kA_per_m', 'M_err_kA_per_m')
    return tuple(df[col].to_numpy() for col in columns)

def build_problem():
    """Assemble the bumps FitProblem for the thin-film Bloch-law fit.

    Loads the data via read_data(), builds the model with a fixed film
    thickness and mode spacing, and gives Ms and D bounded fit ranges.

    Returns:
        FitProblem wrapping a Curve2D around a BlochLawThinFilm.
    """
    # Estimate the mode spacing `a` from an atomic radius. Densely packed
    # FCC (111) planes, spaced lattice_constant/sqrt(3), should be close to
    # the shortest possible distance between atomic planes and so indicate
    # how many thickness wavevectors are allowed; 2*atomic_radius would be a
    # simple-cubic, very loosely packed system.
    # NOTE(review): the model below uses the full FCC lattice constant, not
    # the (111) spacing the reasoning above suggests -- confirm the intent.
    atomic_radius = 0.152e-9
    lattice_constant = atomic_radius*2*np.sqrt(2)

    model = BlochLawThinFilm(Ms=800.0, D=1e-40, Lz=0.8e-9, a=lattice_constant)
    M = Curve2D(model, *read_data())

    # Lz and a get no range here; presumably bumps therefore keeps them
    # fixed. To fit a as well, one could try e.g.
    # M.model.a.range(atomic_radius, 2*lattice_constant)

    # Start Ms at the largest measurement; allow it between the smallest
    # measurement and 1.5x the largest.
    M.model.Ms.value = np.max(M.data)
    M.model.Ms.range(np.min(M.data), 1.5*np.max(M.data))

    # D starts at the constructor value (1e-40) and may vary in [5e-41, 1e-39].
    M.model.D.range(5e-41, 10e-40)

    return FitProblem(M)

# Module-level fit problem, built at import time (this reads the data file).
# Presumably this is the name the `bumps` command from the header comment
# looks up when given this file -- confirm against the bumps model-file docs.
problem = build_problem()