import numpy as np
from scipy import linalg
from cosmoTransitions.generic_potential import generic_potential
from tqdm import tqdm


class singletSM(generic_potential):
    """Real-singlet extension of the Standard Model for cosmoTransitions.

    Field coordinates are ``X[..., 0] = h`` (Higgs, normalised so that
    H^dagger H = h**2/2, mW = g*h/2 and mt = yt*h/sqrt(2)) and
    ``X[..., 1] = s`` (real singlet). The tree-level potential is

        V0 = -mu2 (h^2/2) + lmda (h^2/2)^2 + (h^2/2) s (a1 + a2 s) / 2
             + b1 s + b2 s^2/2 + b3 s^3/3 + b4 s^4/4

    and the finite-temperature potential keeps only the leading O(T^2)
    thermal-mass corrections (see ``Vtot``).
    """

    def init(self, a2 = 0.51, b3 = 20.0, b4 = 0.8 , m1 = 10.0 , cTheta = 0.01):
        """Store the free parameters and derive the remaining couplings.

        The zero-temperature vacuum is fixed at (h, s) = (v0, x0) = (246.221, 0)
        and the second scalar mass at m2 = 125. lmda, a1 and b2 are chosen so
        that the scalar mass matrix at the vacuum has eigenvalues m1**2 and
        m2**2 with mixing cosine cTheta; mu2 and b1 make the vacuum a
        stationary point of V0.

        :param a2: h^2 s^2 portal coupling.
        :param b3: singlet cubic coupling.
        :param b4: singlet quartic coupling; must be positive.
        :param m1: mass of the eigenstate weighted by cTheta**2 in the
            h-direction mass (Mhh = m1^2 cTheta^2 + m2^2 sTheta^2).
        :param cTheta: cosine of the scalar mixing angle; |cTheta| <= 1.
        :raises AssertionError: if |cTheta| > 1, b4 <= 0 or lmda <= 0.
        """
        self.Ndim = 2

        # Fixed inputs: Higgs vev, second scalar mass, singlet vev.
        self.v0 = 246.221
        self.m2 = 125
        self.x0 = 0
        self.a2, self.b3, self.b4, self.m1, self.cTheta = a2 * 1.0, b3 * 1.0, b4 * 1.0, m1 * 1.0, cTheta * 1.0

        # Explicit raises (not ``assert``) so the checks survive ``python -O``;
        # AssertionError is kept so callers that already catch it keep working.
        if not abs(self.cTheta) <= 1:
            raise AssertionError("cTheta must satisfy |cTheta| <= 1, got %r" % (self.cTheta,))
        self.sTheta = np.sqrt(1 - np.square(self.cTheta))

        v02 = np.square(self.v0)
        m12, m22 = np.square(self.m1), np.square(self.m2)
        c2, s2 = np.square(self.cTheta), np.square(self.sTheta)
        # Mhh(v0, x0) = 2*lmda*v0^2 = m1^2 c^2 + m2^2 s^2
        self.lmda = (m12 * c2 + m22 * s2) / (2 * v02)
        # Mhs(v0, x0) = a1*v0/2 = (m2^2 - m1^2) s c
        self.a1 = (m22 - m12) * 2 * self.sTheta * self.cTheta / self.v0
        # Stationarity of V0 at (v0, x0 = 0): dV0/dh = 0 and dV0/ds = 0.
        self.mu2 = self.lmda * v02
        self.b1 = -0.25 * self.a1 * v02
        # Mss(v0, x0) = b2 + a2*v0^2/2 = m2^2 c^2 + m1^2 s^2
        self.b2 = -0.5 * self.a2 * v02 + m22 * c2 + m12 * s2

        # Boundedness of the quartic part of the potential.
        if not self.b4 > 0:
            raise AssertionError("b4 must be positive, got %r" % (self.b4,))
        if not self.lmda > 0:
            raise AssertionError("lmda must be positive, got %r" % (self.lmda,))
        # NOTE: further checks (a2 > -2*sqrt(lmda*b4), positive mass matrix at
        # the vacuum) were considered in earlier versions but are disabled.

        # Gauge and top Yukawa couplings.
        self.g = 2 * 80.385 / 246.0  # mW = g*v/2
        sqrt_ggp = 2 * 91.1876 / 246.0  # mZ = sqrt(g*g+gp*gp) * v/2
        self.gp = np.sqrt(sqrt_ggp * sqrt_ggp - self.g * self.g)
        self.yt = np.sqrt(2) * 173.03 / 246.0  # mt = yt*v/sqrt(2)

        self.Tmax = 3000.  # 10000 was causing problems.
        # Don't let h go too negative while tracing phases.
        self.forbidPhaseCrit = lambda X: bool(X[0] < -5.0)

    # NOTE: an earlier getPhases override was considered. The original method
    # uses T0 as the characteristic temperature scale, which can give screwy
    # results with a tree-level barrier. The idea was to temporarily set
    # T0 = max(|mu2|, |b2|, T0**2)**0.5, call generic_potential.getPhases,
    # then restore T0.

    def V0(self, X):
        """Return the tree-level potential at field point(s) ``X``.

        ``X[..., 0]`` is h and ``X[..., 1]`` is s; leading axes broadcast.
        Uses H^dagger H = h**2/2, the normalisation assumed by
        ``scalarMassSqs`` and by the vacuum conditions solved in ``init``.
        """
        X = np.array(X)
        hh = 0.5 * X[..., 0] ** 2  # H^dagger H
        s = X[..., 1]
        y = - self.mu2 * hh + self.lmda * hh * hh
        y += self.b1 * s + 0.5 * self.b2 * s ** 2 + (1. / 3) * self.b3 * s ** 3 + 0.25 * self.b4 * s ** 4
        y += 0.5 * hh * s * (self.a1 + self.a2 * s)
        return y

    def scalarMassSqs(self, h, s):
        """Return field-dependent scalar mass-squared entries (Mgg, Mhh, Mss, Mhs).

        These are the second derivatives of ``V0``: Goldstone (Mgg = V0_h / h),
        h-h, s-s and the h-s mixing entry. ``h`` and ``s`` may be arrays.
        """
        h2 = h * h
        Mgg = -self.mu2 + self.lmda * h2 + 0.5 * s * (self.a1 + self.a2 * s)
        Mhh = Mgg + 2 * self.lmda * h2
        Mss = self.b2 + 2 * self.b3 * s + 3 * self.b4 * s * s + 0.5 * h2 * self.a2
        Mhs = 0.5 * h * (self.a1 + 2 * self.a2 * s)
        return Mgg, Mhh, Mss, Mhs

    def Vtot(self, X, T, include_radiation=True):
        """Return V0 plus the leading O(T^2) thermal-mass corrections.

        Bosons contribute T^2 * n * m^2 / 24 and fermions T^2 * n * m^2 / 48.
        ``include_radiation`` is accepted for interface compatibility with
        ``generic_potential`` and is ignored.
        """
        X = np.array(X)
        h, s = X[..., 0], X[..., 1]
        y = self.V0(X)
        Mgg, Mhh, Mss, Mhs = self.scalarMassSqs(h, s)
        mW2 = 0.25 * self.g * self.g * h * h
        mZ2 = mW2 + 0.25 * self.gp * self.gp * h * h
        # NOTE(review): the usual high-T counting weights W by 6 and Z by 3
        # degrees of freedom; confirm the 3/3 weights below are intended.
        y += T * T * (Mhh + Mss + 3 * Mgg + 3 * mW2 + 3 * mZ2) / 24.
        mt2 = 0.5 * self.yt * self.yt * h * h
        y += T * T * 12 * mt2 / 48.
        return y

    def V1T_from_X(self, X, T, include_radiation=True):
        """Return the temperature-dependent potential (here: the full Vtot).

        Called by ``dgradV_dT``. Returning the whole potential is fine because
        the T-independent part drops out of the T derivative; it must be
        overridden together with ``Vtot``.
        """
        return self.Vtot(X, T, include_radiation)

    def approxZeroTMin(self):
        """Return approximate zero-temperature minima as a list of field points."""
        m = [np.array([self.v0, self.x0])]
        # init enforces lmda > 0 and sets mu2 = lmda*v0**2, so this branch is
        # currently never taken; it is kept for parameterisations with mu2 < 0.
        if (self.mu2 < 0 and self.b2 > 0):
            # Add a point very close to (not exactly at) the origin so the
            # tracing algorithm doesn't get stuck there.
            m.append(m[0] * 1e-4)
        return m

def main():
    """Run one benchmark point of the singlet model and report its transition.

    Builds ``singletSM`` at a hard-coded parameter point, computes the
    critical-temperature transitions with ``calcTcTrans``, prints them, then
    prints whichever of the 'Tcrit', 'action' and 'trantype' entries the first
    transition has, and flags the point if it is classified as first order.
    """
    print('In main')
    a2l, b3l, b4l, m1l, cThetal = 0.5, 0.8, 0.99, 20.0, 0.01
    xSM = singletSM()
    # A parameter scan used to wrap the lines below, e.g.:
    #   for a2l in tqdm(np.logspace(-4.0, 0.0, num=100)):
    #       for b3l in 246 * np.logspace(-4.0, 0.0, num=100):
    #           for b4l in np.logspace(-5.0, 0.0, num=100):
    xSM.init(a2l, b3l, b4l, m1l, cThetal)
    T = xSM.calcTcTrans()
    print(T)
    if not T:
        # Nothing found: T[0] would raise IndexError.
        return

    first = T[0]
    # .get(): previously a missing key left the local unbound and the final
    # condition crashed with UnboundLocalError.
    Tc = first.get('Tcrit')
    S3 = first.get('action')
    Tr_order = first.get('trantype')
    if Tc is not None:
        print('Tcrit :', Tc)
    if S3 is not None:
        print('action :', S3)
    if Tr_order is not None:
        print('trantype :', Tr_order)

    # The old test `Tr_order == 1 | (...)` parsed as `Tr_order == (1 | ...)`
    # because `|` binds tighter than `==`; a logical `or` was intended.
    is_first_order = Tr_order == 1
    # NOTE(review): the original compared S3/Tc to this threshold with exact
    # float equality (kept here as np.isclose); an inequality such as
    # S3/Tc <= threshold may be what is actually intended -- confirm.
    if not is_first_order and S3 is not None and Tc is not None and Tc > 0:
        is_first_order = bool(np.isclose(S3 / Tc, 117 - 4 * np.log10(Tc / 100.0)))
    if is_first_order:
        print("OK, first order pha transition")
        print('parameter', a2l, b3l, b4l)






if __name__ == "__main__":
    # Only runs when executed as a script, not when imported as a module.
    print('This is the start of the program!')
    main()









