#include <math.h>
#include "rcomplex.hpp"

//Complex methods


/* Returns the modulus |z| = sqrt(re^2 + im^2) of this complex number.
hypot() is used instead of squaring the parts by hand so that very large
or very small components do not overflow or underflow in the
intermediate squares. */
Real Complex::Magnitude()
{
	return hypot(re, im);
}

//Friend functions for Complex

/* Scalar product: multiplies both parts of v by the real factor f. */
Complex operator*(Real f, Complex &v)
{
	Complex scaled(v.re*f, v.im*f);
	return scaled;
}

/* Complex sum u + v, formed part by part. */
Complex operator+(Complex &u, Complex &v)
{
	Complex sum(u.re + v.re, u.im + v.im);
	return sum;
}

/* Complex difference u - v, formed part by part.  Both operands are
taken by value. */
Complex operator-(Complex u, Complex v)
{
	Complex difference(u.re - v.re, u.im - v.im);
	return difference;
}

/* Real plus complex: f is added to the real part only; the imaginary
part of v is carried over unchanged. */
Complex operator+(Real f, Complex &v)
{
	Complex shifted(f + v.re, v.im);
	return shifted;
}

/* Real minus complex: f - (re + i*im) = (f - re) - i*im, so the
imaginary part changes sign. */
Complex operator-(Real f, Complex &v)
{
	Complex result(f - v.re, -v.im);
	return result;
}

/* Square root of u computed in polar form: the modulus is square-rooted
and the argument returned by atan2 is halved. */
Complex Sqrt(Complex &u)
{
	// Zero is its own square root.  Returning early also keeps (0,0)
	// away from atan2, which may report an error for that input.
	if (u.re == 0.0 && u.im == 0.0)
		return u;

	Real radius = sqrt(u.Magnitude());
	Real halfAngle = atan2(u.im,u.re);
	halfAngle /= 2.0;
	return Complex(radius * cos(halfAngle), radius * sin(halfAngle));
}

/* Writes the three complex cube roots of u to root0, root1 and root2.
The roots share the modulus |u|^(1/3); their arguments start at arg(u)/3
and advance in steps of 2*PI/3.  For u == 0 all three roots are set to
u itself. */
void Cuberoot(Complex &u, Complex& root0, Complex& root1, Complex& root2)
{
	// Handle zero up front; this also keeps (0,0) away from atan2,
	// which may report an error for that input.
	if (u.re == 0.0 && u.im == 0.0)
	{
		root2 = u;
		root1 = root2;
		root0 = root1;
		return;
	}

	Real radius = pow(u.Magnitude(), 1.0 / 3.0);
	Real angle = atan2(u.im,u.re);
	angle /= 3.0;

	// Fill the output slots in order, rotating by a third of a turn
	// after each one.
	Complex *slots[3] = { &root0, &root1, &root2 };
	for (int k = 0; k < 3; k++)
	{
		*slots[k] = Complex(radius * cos(angle), radius * sin(angle));
		angle += 2.0 * PI / 3.0;
	}
}

/* Complex product: (a + ib)(c + id) = (ac - bd) + i(ad + bc). */
Complex operator*(Complex &u, Complex &v)
{
	Complex product(u.re*v.re - u.im*v.im, u.re*v.im + u.im*v.re);
	return product;
}

/* Divides both parts of v by the real number f.  No check is made for
f == 0. */
Complex operator/(Complex &v, Real f)
{
	Complex quotient(v.re/f, v.im/f);
	return quotient;
}

/* Complex quotient u / v, computed as (u * ~v) / (v.re^2 + v.im^2).
The ~ operator is declared elsewhere (presumably the complex conjugate
-- confirm in rcomplex.hpp).  No check is made for v == 0. */
Complex operator/(Complex &u, Complex &v)
{
	Complex conjugate = ~(v);
	Complex numerator = u * conjugate;
	Real normSquared = v.re*v.re + v.im*v.im;
	return numerator / normSquared;
}

/* Real divided by complex: promotes f to the complex number (f, 0) and
defers to the complex/complex division. */
Complex operator/(Real f, Complex &v)
{
	Complex numerator(f, 0.0);
	return numerator / v;
}

/* Writes u as "re im": the two parts separated by a single space.
Returns the stream so that insertions can be chained. */
ostream& operator<< (ostream& s, Complex &u)
{
	s << u.re;
	s << " ";
	s << u.im;
	return s;
}

/* Reads the real part and then the imaginary part of u from the stream.
Returns the stream so that extractions can be chained. */
istream& operator>> (istream& s, Complex &u)
{
	s >> u.re;
	s >> u.im;
	return s;
}



/* Morris Kline, Mathematical Thought from Ancient to Modern Times,
p. 269 gives Vieta's method of solving a cubic equation, which we 
implement as CubicSolve.*/

void CubicSolve(Complex b, Complex c, Complex d, Complex *image)
{
/* Solve the equation x3 + bx2 + cx + d = 0 by setting x = y - b/3,
to get the equation y3 + py + q = 0.  Then set y = z - p/(3z),
substitute again and get a quadratic in z3 whose solutions are
z3 = -q/2 +/- sqrt(p3/27 + q2/4).  The three roots x are written to
image[0], image[1] and image[2]. */
	Complex bover3, pover3, qover2, SqrtR, zcubed;
	Complex z[3];
	Complex zero(0.0, 0.0);

	pover3 = (3.0*c - b*b)/9.0;
	qover2 = (2.0*b*b*b - 9.0*b*c + 27.0*d)/54.0;
	SqrtR = Sqrt( pover3*pover3*pover3 + qover2*qover2 );
	bover3 = b/3.0;

	/* Either solution for z3 leads to the same three roots, but z = 0
	cannot be back-substituted because pover3/z would divide by zero.
	The first solution vanishes e.g. when p is zero and SqrtR equals
	qover2, so fall back to the second one in that case. */
	zcubed = -qover2 + SqrtR;
	if (zcubed == zero)
		zcubed = -qover2 - SqrtR;
	Cuberoot(zcubed, z[0], z[1], z[2]);

	/* The x values are given by x = z - p/(3z) - b/3, that is
	z - pover3/z - bover3.  If z is zero even after the fallback, both
	solutions for z3 are zero, so y = 0 is a triple root and the
	pover3/z term is left out. */
	for (int i = 0; i < 3; i++)
	{
		if (z[i] == zero)
			image[i] = z[i] - bover3;
		else
			image[i] = z[i] - (pover3/z[i]) - bover3;
	}
}

void CubicSolveOne(Complex b, Complex c, Complex d, Complex &answer)
{
/* Use this if you only need ONE root of x3 + bx2 + cx + d = 0, as in
the solution to the quartic.  Same method as CubicSolve: x = y - b/3
gives y3 + py + q = 0, and y = z - p/(3z) gives
z3 = -q/2 +/- sqrt(p3/27 + q2/4).  The root is written to answer. */
	Complex pover3, qover2, SqrtR, zcubed;
	Complex root1, root2;
	Complex zero(0.0, 0.0);

	pover3 = (3.0*c - b*b)/9.0;
	qover2 = (2.0*b*b*b - 9.0*b*c + 27.0*d)/54.0;
	SqrtR = Sqrt( pover3*pover3*pover3 + qover2*qover2 );

	/* z = 0 cannot be back-substituted (pover3/z would divide by zero),
	so use the other solution for z3 when the first one vanishes. */
	zcubed = -qover2 + SqrtR;
	if (zcubed == zero)
		zcubed = -qover2 - SqrtR;
	Cuberoot(zcubed, answer, root1, root2);

	/* If z is still zero, both solutions for z3 are zero, so y = 0 is a
	root and the pover3/z term is left out. */
	if (answer == zero)
		answer = answer - b/3.0;
	else
		answer = answer - (pover3/answer) - b/3.0;
}

/* In the case where we know b is zero, we get this form of
CubicSolve*/

void CubicSolveLeadZero(Complex b, Complex c, Complex d, Complex *image)
{
/* Variant of CubicSolve for x3 + bx2 + cx + d = 0 when b is known to be
zero: the b*b and b*b*b terms and the final shift by b/3 are omitted.
The equation is then already of the form y3 + py + q = 0; setting
y = z - p/(3z) gives z3 = -q/2 +/- sqrt(p3/27 + q2/4).  The three roots
are written to image[0], image[1] and image[2]. */
	Complex pover3, qover2, SqrtR, zcubed;
	Complex z[3];
	Complex zero(0.0, 0.0);

	pover3 = (3.0*c)/9.0;
	qover2 = (- 9.0*b*c + 27.0*d)/54.0;
	SqrtR = Sqrt( pover3*pover3*pover3 + qover2*qover2 );

	/* Either solution for z3 leads to the same three roots, but z = 0
	cannot be back-substituted because pover3/z would divide by zero.
	Fall back to the second solution when the first one vanishes. */
	zcubed = -qover2 + SqrtR;
	if (zcubed == zero)
		zcubed = -qover2 - SqrtR;
	Cuberoot(zcubed, z[0], z[1], z[2]);

	/* The x values are given by x = z - p/(3z), that is z - pover3/z.
	If z is zero even after the fallback, both solutions for z3 are
	zero, so x = 0 is a triple root and the pover3/z term is left out. */
	for (int i = 0; i < 3; i++)
	{
		if (z[i] == zero)
			image[i] = z[i];
		else
			image[i] = z[i] - (pover3/z[i]);
	}
}


/* The Chemical Rubber Company Handbook gives a method for solving a
quartic equation x4 + ax3 + bx2 + cx + d = 0 */

#define LEADZERO //Means you know the a param is always 0.

#ifndef LEADZERO
void QuarticSolve(Complex a, Complex b, Complex c, Complex d, Complex *image)
{
/* Solves x4 + ax3 + bx2 + cx + d = 0 and writes the four roots to
image[0..3].  We start by finding a root y of the resolvent cubic
equation y3 - by2 + (ac-4d)y - a2d + 4bd - c2 = 0, then form
R = sqrt(a2/4 - b + y) and the two square roots D and E below. */

	Complex y, R, DE1, DE2, D, E, imagebase;

	CubicSolveOne(-b, a*c-4*d, -a*a*d + 4*b*d - c*c, y);
	R = Sqrt(a*a/4.0 - b + y);
	if (R != Complex(0.0, 0.0))
	{
		DE1 = 3.0*a*a/4.0 - R*R - 2.0*b;
		/* The divisor is the whole product 4R.  Writing "/4.0*R" would
		divide by 4 and then MULTIPLY by R. */
		DE2 = (4.0*a*b - 8.0*c - a*a*a)/(4.0*R);
		D = Sqrt(DE1 + DE2);
		E = Sqrt(DE1 - DE2);
	}
	else
	{
		/* R is zero, so the form above would divide by zero; use the
		alternative expressions for D and E. */
		DE1 = 3.0*a*a/4.0 - 2.0*b;
		DE2 = 2.0*Sqrt(y*y - 4.0*d);
		D = Sqrt(DE1 + DE2);
		E = Sqrt(DE1 - DE2);
	}
	/* Roots: -a/4 + R/2 +/- D/2 and -a/4 - R/2 +/- E/2. */
	imagebase = -a/4.0 + R/2.0;
	image[0] = imagebase + D/2.0;
	image[1] = imagebase - D/2.0;
	imagebase = -a/4.0 - R/2.0;
	image[2] = imagebase + E/2.0;
	image[3] = imagebase - E/2.0;
}

#else //LEADZERO

void QuarticSolve(Complex a, Complex b, Complex c, Complex d, Complex *image)
{
/* LEADZERO form: solves x4 + bx2 + cx + d = 0 and writes the four roots
to image[0..3].  The parameter a is accepted but never read.  With
a = 0 the resolvent cubic y3 - by2 + (ac-4d)y - a2d + 4bd - c2 = 0
reduces to y3 - by2 - 4dy + 4bd - c2 = 0; one root of it is enough. */

	Complex resolventRoot, R, common, split, D, E;
	Complex halfR, halfD, halfE;

	CubicSolveOne(-b, -4*d, 4*b*d - c*c, resolventRoot);
	R = Sqrt(- b + resolventRoot);

	/* D and E are the square roots of common + split and
	common - split; only the two terms depend on whether R is zero. */
	if (R != Complex(0.0, 0.0))
	{
		common = - R*R - 2.0*b;
		split = (- 2.0*c)/R;
	}
	else
	{
		common = - 2.0*b;
		split = 2.0*Sqrt(resolventRoot*resolventRoot - 4.0*d);
	}
	D = Sqrt(common + split);
	E = Sqrt(common - split);

	/* First pair of roots: R/2 +/- D/2. */
	halfR = R/2.0;
	halfD = D/2.0;
	image[0] = halfR + halfD;
	image[1] = halfR - halfD;

	/* Second pair of roots: -R/2 +/- E/2. */
	halfR = - R/2.0;
	halfE = E/2.0;
	image[2] = halfR + halfE;
	image[3] = halfR - halfE;
}
#endif