'''
THIS IS A SAGEMATH PROTOTYPE IMPLEMENTATION OF THE SMOOTH CASE ALGORITHM FOR 
COMPUTING POINTS PER CONNECTED COMPONENTS DEFINED BY A REAL INEQUATION.

ALTHOUGH FUNCTIONAL FOR MANY EXAMPLES, THERE MAY BE BUGS. 
PLEASE FEEL FREE TO REPORT THEM.

EXECUTE WITH ONE ARGUMENT:
    - The name of the file containing the polynomial (only tested with
      .txt files).
      !!! VARIABLE NAMES MUST ONLY BE LETTERS !!!
      e.g. execute as: sage smooth.sage "polynomial.txt"
REQUIRES:
    - V(f) is smooth (the program checks this and raises an error if not)
RETURNS:
    - Finite list of rational points, with at least one point per connected
      component of {x in R^n : f(x) != 0}

TO USE:
    - Requires msolve (latest version advised) to be installed 
    - If necessary, change the msolve_path variable (in the beginning, right 
      below the imports) to the path in which your msolve is installed
    - Other variables that can be modified by hand (at the end of the file) are:
        - verbose (value between 0 and 2), to indicate how verbose msolve is.
        - threads, to indicate the number of threads to be used by msolve.
        - precision (default 128), to indicate how precise the real root
          intervals are (128 should be sufficient, unless you have a
          hypersurface that is very close to being singular).
        - changevar (Boolean, default True), to indicate whether to apply a
          random change of variables matrix, or the identity (if set to 
          False). Mostly useful for testing purposes, but can save computation
          time if set to False if your polynomial is known to be sufficiently
          generic (typically, randomly generated).
'''

import os
import sys
import time
import uuid
import itertools

# Command used to run msolve (it is interpolated at the start of every shell
# command built with os.system below). Replace it with the full path of the
# msolve executable if msolve is not on your PATH.
msolve_path = "msolve"

def ToMSolve(F, finput="/tmp/in.ms"):
    """Convert a system of sage polynomials into a msolve input file.

    The file contains the comma-separated variable names on the first line,
    the characteristic of the base ring on the second line, then the
    polynomials (spaces removed), separated by ",\\n". Zero polynomials are
    dropped, since they carry no equation.

    Inputs :
    F (list of polynomials): system of polynomial to solve; every entry must
        belong to the same polynomial ring.
    finput (string): name of the msolve input file.

    Raises :
    ValueError: if the polynomials do not all share the same parent ring.
    """
    A = F[0].parent()
    if any(parent(g) != A for g in F):
        raise ValueError(
            "The polynomials in the system must belong to the same polynomial ring.")

    variables, char = A.variable_names(), A.characteristic()
    s = (", ".join(variables) + " \n"
            + str(char) + "\n")

    # PolynomialRing_generic rings (presumably the univariate ones) are
    # printed as-is; the other rings are first moved to degrevlex order.
    if isinstance(A, sage.rings.polynomial.polynomial_ring.PolynomialRing_generic):
        B = A
    else:
        B = A.change_ring(order='degrevlex')
    F2 = [str(B(g)).replace(" ", "") for g in F]
    # Drop every zero polynomial (list.remove would only drop the first one).
    F2 = [g for g in F2 if g != "0"]
    s += ",\n".join(F2) + "\n"

    # The context manager closes the file even if the write fails.
    with open(finput, 'w') as fd:
        fd.write(s)

def FormatOutputMSolve(foutput, hasinterval):
    """Convert a msolve output file into a rational parametrization.

    Inputs :
    foutput (string): name of the msolve output file
    hasinterval (boolean): specifies if the rational parametrization has solution intervals added

    Output :
        A list [varstr, linearform, pelim, pden, p, c, S] describing the
    rational parametrization of the zero-dimensional ideal describing the
    solutions. Note : p[i] and c[i] stand for the (i+1)-th coordinate.
    S holds, for every isolated solution, the midpoints of its first nvars
    intervals (it is empty when hasinterval is False).
        When no parametrization can be read, the 7-tuple
    (None, None, code, None, None, None, None) is returned instead, where
    code is A(-1) if the reported dimension is positive, A(-3) if the
    parametrization entry is missing and A(-2) if the elimination polynomial
    is empty (A = QQ[t]).

    """
    with open(foutput, 'r') as fh:
        s = fh.read()
    s = s.replace("\n", "").replace(":", "")
    # NOTE: sage_eval evaluates the file contents as code; only use it on
    # files written by msolve itself.
    R = sage_eval(s)
    A = PolynomialRing(QQ, 't')
    # dimension
    dim = R[0]
    if dim > 0:
        return None, None, A(-1), None, None, None, None

    # parametrization
    try:
        nvars = R[1][1]
    except IndexError:
        return None, None, A(-3), None, None, None, None
    qdim = R[1][2]
    varstr = R[1][3]
    linearform = R[1][4]
    elim = R[1][5][1][0]
    den = R[1][5][1][1]
    polys = R[1][5][1][2]
    # solutions
    intervals = R[2][1] if hasinterval else []

    if len(elim) > 0:
        pelim = A(elim[1])
    else:
        return None, None, A(-2), None, None, None, None

    pden, p, c = A(1), [], []
    if qdim > 0:
        pden = A(den[1])
        for l in polys:
            p.append(A(l[0][1]))
            c.append(l[1])

    # Midpoint of each coordinate interval of every isolated solution.
    S = [[(sol[i][0] + sol[i][1]) / 2 for i in range(nvars)] for sol in intervals]
    return [varstr, linearform, pelim, pden, p, c, S]

def FormatOutputMSolveGrobner(foutput, R):
    """Read the polynomials of an msolve Groebner-basis output file into R.

    The first seven lines of the file are skipped. From every remaining line
    the characters '[', ']', ',', ':' and the newline are deleted; lines that
    become empty, or that start with '#', are ignored, and every other line
    is converted with R(...).

    Inputs :
        foutput (string): name of the msolve output file
        R (PolynomialRing): desired parent ring of the basis

    Output :
        list of polynomials in R, in file order

    !!! Make sure that your previously defined R contains exactly the right variables present in basis output file !!!

    """
    delete_table = str.maketrans('', '', '[]\n,:')
    basis = []
    with open(foutput, 'r') as handle:
        for raw_line in itertools.islice(handle, 7, None):
            cleaned = raw_line.translate(delete_table)
            if not cleaned or cleaned.startswith("#"):
                continue
            basis.append(R(cleaned))
    return basis

def FormatOutputMSolveIntervals(foutput):
    """Read an msolve output file (isolation intervals only) into a list.

    Newlines and ':' are removed from the file contents, which are then
    evaluated with sage_eval.

    NOTE: sage_eval evaluates the file contents as code; only use it on files
    written by msolve itself.
    """
    # The context manager closes the file (the original left it open).
    with open(foutput, 'r') as fh:
        s = fh.read()
    s = s.replace("\n", "").replace(":", "")
    return sage_eval(s)


def IsSmooth(poly, threads, verbose):
    '''
    Function that checks for the smoothness of V(f).

    The system made of all partial derivatives of poly followed by poly
    itself is sent to msolve (flag -g1), and V(poly) is declared smooth when
    the parsed basis is exactly [1], i.e. when these polynomials generate the
    unit ideal (they then have no common zero).
    Input: 
        - poly (polynomial), 
        - threads (number to be used in computation), 
        - verbose (msolve parameter for explicit description of computations, 
            ranging from 0 (no desc) to 2 (full desc)).
    Output: 
        - Boolean answer to the question "Is V(poly) smooth?".
    '''
    R = poly.parent()
    input_list = [derivative(poly, v) for v in R.gens()]
    input_list.append(poly)
    filename = str(uuid.uuid4())
    in_file, out_file = f"{filename}_in.ms", f"{filename}_out.ms"
    try:
        ToMSolve(input_list, in_file)
        os.system(f"{msolve_path} -g1 -v{verbose} -t{threads} -f {in_file} -o {out_file}")
        gb = FormatOutputMSolveGrobner(out_file, R)
    finally:
        # Remove the temporary files even if msolve or the parsing failed.
        for path in (in_file, out_file):
            if os.path.exists(path):
                os.remove(path)
    return gb == [1]

def RandomMatrix(n, changevar=True):
    '''
    Function that generates a random (n x n) change of variables matrix A such
    that every trailing principal submatrix A[k:, k:] (0 <= k < n) is
    invertible, so that all partial inverses B_k exist (aka such that it
    satisfies hypothesis (A2)).
    Entries are drawn with ZZ.random_element(1, 100, "uniform"); a new matrix
    is drawn until every B_k exists.
    Input:
        - n (number of variables)
        - changevar (boolean; uses A = Identity if set to False)
    Output:
        - List of n+1 matrices: [A, B_0 (= A^{-1}), B_1, ..., B_{n-1}],
          where B_k is the inverse of A[k:, k:].
    '''
    while True:
        if changevar:
            A = matrix([[ZZ.random_element(1,100,"uniform") for j in range(n)] for i in range(n)])
        else:
            A = matrix.identity(n)
        try:
            partial_inverses = [A[list(range(k,n)), list(range(k,n))].inverse() for k in range(n)]
        except ZeroDivisionError:
            # Some trailing submatrix is singular: draw another matrix.
            print("Random Matrix choice failed to satisfy (A2) - trying again...")
            continue
        return [A] + partial_inverses

def DerivativeOrder(poly):
    '''
    Function that re-labels the variables of the input polynomial such that its
    partial derivatives have non-decreasing degree.
    Input:
        - poly (polynomial)
    Output:
        - Pair (inv_sigma, relabeled poly):
            - relabeled poly is poly evaluated at sigma(R.gens()), converted
              back into the parent ring R of poly;
            - inv_sigma is built from the reversed quotient of the same two
              standard permutations as sigma (intended as the inverse
              re-labeling), as an element of SymmetricGroup(range(n)).
    '''
    R = poly.parent()
    variables = list(R.gens())
    # Count every generator of R, not only those occurring in poly: der_deg,
    # new_var and variables must have the same length for the permutations
    # built below to be well defined.
    n = len(variables)
    der_deg = {v: derivative(poly, v).degree() for v in variables}
    # sorted() is stable: variables whose derivatives have the same degree
    # keep their original relative order.
    new_var = sorted(der_deg, key=lambda v: der_deg[v])
    perm = Word(variables).standard_permutation() / Word(new_var).standard_permutation()
    inv_perm = Word(new_var).standard_permutation() / Word(variables).standard_permutation()
    sigma = (SymmetricGroup(range(n)))([perm[i]-1 for i in range(n)])
    inv_sigma = (SymmetricGroup(range(n)))([inv_perm[i]-1 for i in range(n)])

    return inv_sigma, R(SR(poly(*sigma(R.gens()))))
#------------
def augment_ring_with_variables(ring, P, x):
    '''
    Return P seen as an element of the larger polynomial ring `ring`.

    Input:
        - ring (polynomial ring)
        - P (polynomial in r variables; not tested with r == 1, may not work)
        - x (variable, or tuple/list of variables, of ring)
    Constraint:
        P.parent() is exactly ring.remove_var(*x)
    Output:
        the same polynomial, as an element of ring (the variables of x get
        exponent 0 in every monomial)
    Raises:
        ValueError if P.parent() is not ring.remove_var(*x)
    '''
    from sage.rings.polynomial.polydict import ETuple

    # if x is just a variable, or is a list, make it a tuple
    if isinstance(x, list):
        x = tuple(x)
    if not isinstance(x, tuple):
        x = (x,)

    # safety check
    if ring.remove_var(*x) != P.parent():
        raise ValueError("P.parent() must be ring.remove_var(*x)")

    # For each generator of ring: the position of its exponent in the
    # monomials of P, or None for the variables being inserted. Inserted
    # variables get an explicit 0 instead of relying on out-of-range ETuple
    # indexing.
    source_idx = []
    next_idx = 0
    for gen in ring.gens():
        if gen in x:
            source_idx.append(None)
        else:
            source_idx.append(next_idx)
            next_idx += 1

    # insert variables into every {monomial : coeff} entry
    Qdic = {ETuple([0 if j is None else et[j] for j in source_idx]): coeff
            for et, coeff in P.monomial_coefficients().items()}
    return ring(Qdic)

def remove_absent_variable(P, x):
    '''
    Return P seen as a polynomial of P.parent() with the variables x removed.

    Input:
        - P (polynomial in r variables; not tested with r == 1, may not work)
        - x (variable or tuple/list of variables of P.parent())
    Constraint:
        each variable in x is in P.parent() and does not appear in P
        (i.e. any monomial involving x is with coefficient zero)
    Output:
        the same polynomial, as an element of P.parent().remove_var(*x)
    Raises:
        ValueError if a variable of x occurs in P (dropping its exponents
        would otherwise silently produce a different polynomial)
    '''
    from sage.rings.polynomial.polydict import ETuple

    # if x is just a variable, or is a list, make it a tuple
    if isinstance(x, list):
        x = tuple(x)
    if not isinstance(x, tuple):
        x = (x,)

    # parent ring and ring with variables removed
    pring = P.parent()
    ring = pring.remove_var(*x)

    # positions of the variables to keep and of those to remove
    gens = pring.gens()
    keep_idx = [i for i in range(pring.ngens()) if gens[i] not in x]
    drop_idx = [i for i in range(pring.ngens()) if gens[i] in x]

    # retrieve dictionary {monomial : coeff}
    Pdic = P.monomial_coefficients()

    # enforce the constraint instead of silently discarding exponents
    if any(et[i] != 0 for et in Pdic for i in drop_idx):
        raise ValueError("The variables to remove must not appear in P")

    # remove variables
    Qdic = {ETuple([et[i] for i in keep_idx]): coeff for et, coeff in Pdic.items()}
    return ring(Qdic)
#------------
def CriticalPoints(f,threads,verbose,precision,k, n, list_of_matrices, sigma, variables, der_list):
    '''
    Function computing the critical points of the projection on X_k-axis.
    Input:
        - f (polynomial),
        - threads (number to be used in computation), 
        - verbose (msolve parameter for explicit description of computations, 
            ranging from 0 (no desc) to 2 (full desc)),
        - precision (value passed to msolve's -p flag),
        - k (integer satisfying 0 <= k < n),
        - n (number of variables),
        - list_of_matrices (change of variables matrix A and its B_k's,
            formatted as in the output of RandomMatrix),
        - sigma (list of n-1 integers),
        - variables (list of x0, ..., x{n-1}),
        - der_list (list of partial derivatives of f).
    Output:
        - substitution: list of k expressions for x_0, ..., x_{k-1} in terms
            of x_k, ..., x_{n-1} (those x with B_0[:k, :] * x = sigma[:k]),
        - the parsed msolve output (-P0 flag) for f together with the
            derivative combinations built below, after substituting the
            first k variables; the system lives in the ring of the variables
            x_k, ..., x_{n-1}.
    '''
    # Computing the actual values to substitute into x_0, ..., x_{k-1}
    if k == 0:
        substitution = []
    else:
        inv_left = list_of_matrices[1][list(range(k)),list(range(k))].inverse()
        right = list_of_matrices[1][list(range(k)),list(range(k,n))]
        sigma_k = matrix(sigma[:k]).transpose()
        vars_k = matrix(variables[k:]).transpose()
        # .list() keeps all k entries; Matrix.coefficients() only returns the
        # nonzero ones, which would shift the indices used below.
        substitution = (inv_left * (sigma_k - (right * vars_k))).list()

    # Computing the system equivalent to f^A, df^A/dX_{k+1}, ..., df^A/dX_n.
    if k == n-1:
        der_system = []
    else:
        left = matrix(der_list[k+1:])
        top_right = (list_of_matrices[0])[list(range(k+1)), list(range(k+1,n))]
        inv_right = list_of_matrices[k+2]
        # .coefficients() drops zero entries, i.e. trivial equations.
        der_system = (left + (matrix(der_list[:k+1]) * top_right * inv_right)).coefficients()
    input_system = [f] + der_system

    # Substituting the former in the latter, with the right parent ring.
    input_system = [p.subs({variables[i] : substitution[i] for i in range(k)}) for p in input_system]
    if k == n-1:
        Rk = PolynomialRing(QQ, variables[k:], n-k)
        input_system = [Rk(p) for p in input_system]
    else:
        input_system = [remove_absent_variable(p,variables[:k]) for p in input_system]

    # Calling msolve to solve the above system; the temporary files are
    # removed even if msolve or the parsing fails.
    filename = str(uuid.uuid4())
    in_file, out_file = f"{filename}_in.ms", f"{filename}_out.ms"
    try:
        ToMSolve(input_system, in_file)
        os.system(f"{msolve_path} -v{verbose} -t{threads} -p{precision} -P0 -f {in_file} -o {out_file}")
        sol = FormatOutputMSolveIntervals(out_file)
    finally:
        for path in (in_file, out_file):
            if os.path.exists(path):
                os.remove(path)

    return substitution, sol
    
def matrix_box(n, point, matrix, substitution):
    '''
    Function computing the isolation box of a real point after
    transformation by a matrix.
    Input:
       - n (size of the matrix / number of coordinates)
       - point (list of isolation intervals for the last n - len(substitution)
         coordinates, msolve format)
       - matrix (n x n)
       - substitution (expressions for the first len(substitution) variables
         in terms of the other ones)
    Output:
       - list of [min, max] isolation intervals for the last
         n - len(substitution) coordinates of matrix*point, msolve format
    Raises:
       - ValueError if one of these intervals has endpoints of different signs
    '''
    # Bounds of the first coordinates, from the substitution expressions
    # evaluated on the box.
    extended_point = [rough_eval(point, item) for item in substitution] + list(point)
    offset = len(substitution)
    new_box = []
    for i in range(offset, n):
        # A linear form reaches its extremes over a box at vertices, and it is
        # separable, so each coordinate contributes its own min/max. This
        # gives the same bounds as enumerating all 2^n vertices.
        low, high = 0, 0
        for j, interval in enumerate(extended_point):
            candidates = [matrix[i, j] * bound for bound in interval]
            low += min(candidates)
            high += max(candidates)
        new_box.append([low, high])
    for coord in new_box:
        if sign(coord[0]) != sign(coord[1]):
            raise ValueError("Msolve not precise enough to determine coordinate sign. Consider increasing the precision.")
    return new_box
#------------
def rough_eval(point,poly):
    '''
    Function computing an interval enclosing the values that a polynomial
    takes on a box approximating a point (although MPFI technically does it
    already, it is not precise enough for us).

    Each term coeff*monom is bounded separately (at the smallest or largest
    absolute values of the coordinates) and the bounds are summed, so the
    result is an enclosure rather than the exact range.
    NOTE(review): this assumes every interval of `point` has a constant sign
    (or is exactly [0, 0]); the sign of a coordinate is read from its lower
    endpoint, so an interval such as [0, b] or one straddling 0 can give
    bounds that do not enclose the range.
    Input:
        - point (list of isolation intervals, msolve format), indexed like the
          generators of poly's parent ring
        - poly (polynomial)
    Output:
        - pair (lower bound, upper bound) for the values of poly on the box;
          for a constant poly both entries are QQ(poly)
    '''

    # varss: all generators of the parent ring; varsss: those occurring in poly
    varss = poly.parent().gens()
    varsss = poly.variables()
    if len(varsss) == 0:
        return QQ(poly), QQ(poly)
    if len(varss) != len(varsss):
        # Only a warning: the evaluation below still indexes by generator.
        print("Careful, some variables might be missing.")
        print(poly)
        print(poly.parent())
        print(poly.variables())
    
    rg = range(len(varss))
    sign_list = [sign(coord[0]) for coord in point]
    # Reorder each interval so that entry 0 is the endpoint of smaller
    # absolute value and entry 1 the one of larger absolute value.
    modified_point = []
    for i in range(len(point)):
        if sign_list[i] == -1:
            modified_point.append([point[i][1], point[i][0]])
        else:
            modified_point.append([point[i][0], point[i][1]])
    min_out_poly = 0
    max_out_poly = 0
    # Presumably done so that iterating over poly below yields
    # (coefficient, monomial) pairs as for multivariate polynomials.
    if len(varss) == 1:
        poly = PolynomialRing(QQ, 1, varss)(poly)
        varss = PolynomialRing(QQ, 1, varss).gens()
    for coeff,monom in poly:
        # Sign of this term over the box, read at the vector of coordinate signs.
        signn = coeff*monom.subs({varss[i] : sign_list[i] for i in rg})
        if sign(signn) == -1:
            # Negative term: smallest at the largest |x|, largest at the smallest |x|.
            min_out_poly += coeff*monom.subs({varss[i] : modified_point[i][1] for i in rg})
            max_out_poly += coeff*monom.subs({varss[i] : modified_point[i][0] for i in rg})
        else:
            # Non-negative term: smallest at the smallest |x|, largest at the largest |x|.
            min_out_poly += coeff*monom.subs({varss[i] : modified_point[i][0] for i in rg})
            max_out_poly += coeff*monom.subs({varss[i] : modified_point[i][1] for i in rg})
    return min_out_poly, max_out_poly

def is_zero_inside(interval):
    '''
    Tell whether 0 lies in the interval spanned by interval[0] and
    interval[1]: either endpoint is 0, or the endpoints have different signs.
    Input:
       - interval (list [a, b])
    Output:
       - Boolean
    '''
    lower, upper = interval[0], interval[1]
    if lower == 0 or upper == 0:
        return True
    return sign(lower) != sign(upper)

def do_boxes_intersect(point_list):
    '''
    Check whether any two point approximation boxes in a list intersect.

    Two axis-aligned boxes intersect exactly when their intervals overlap in
    every coordinate. Intervals are closed, so boxes that only touch count as
    intersecting. Testing only whether a vertex of one box lies inside
    another would miss nested boxes and cross-shaped overlaps.
    Input:
       - point_list: list of boxes in msolve format, each a list of [a, b]
         intervals, one per coordinate
    Output:
       - Boolean: True if at least one pair of boxes intersects
    '''
    # Normalise every interval to (low, high) so the test does not depend on
    # the endpoint order.
    boxes = [[(min(iv[0], iv[1]), max(iv[0], iv[1])) for iv in box] for box in point_list]
    for i, box_a in enumerate(boxes):
        for box_b in boxes[i + 1:]:
            if all(a_lo <= b_hi and b_lo <= a_hi
                   for (a_lo, a_hi), (b_lo, b_hi) in zip(box_a, box_b)):
                return True
    return False
#------------
def TransverseIntersection(poly, vars, point, low_prec, threads=None, verbose=None, precision=None):
    '''
    Function computing points 'to the left' and 'to the right' of the critical
    point, by means of the transverse line and real root approximation.

    The first coordinate is replaced by approx[0] + t, where approx is a
    rational point of the isolation box. msolve isolates the real roots of
    the resulting univariate polynomial in t, and a step lambd is chosen past
    the root closest to 0 but before the next closest interval endpoint.

    Input:
       - poly (polynomial)
       - vars (variables of the parent of the polynomial)
       - point (approximation of point, in msolve format)
       - low_prec (interval [lo, hi] for the first coordinate, at a lower
         precision)
       - threads, verbose, precision (optional msolve settings; when left as
         None, the module-level variables of the same names are used, as the
         original version always did)
    Output:
       - Two points (lists of rationals), to the 'left' and the 'right' of the
         critical point on the transverse line, and sufficiently close to be
         in the right connected component.
    Raises:
       - ValueError if the approximations are not precise enough.
    '''
    module_settings = globals()
    if threads is None:
        threads = module_settings["threads"]
    if verbose is None:
        verbose = module_settings["verbose"]
    if precision is None:
        precision = module_settings["precision"]

    UnivarRing = PolynomialRing(QQ, 'ttttt')
    ttttt = UnivarRing.gen()
    approx = [RealIntervalField(prec=5*precision)(item).simplest_rational(False, False) for item in point]
    list_to_sub = [UnivarRing(ttttt + approx[0])]
    if len(approx) != 1:
        list_to_sub += approx[1:]
    transverse_poly = UnivarRing(poly.subs({vars[i] : list_to_sub[i] for i in range(len(vars))}))

    # Isolate the real roots in t; the temporary files are removed even if
    # msolve or the parsing fails.
    filename = str(uuid.uuid4())
    in_file, out_file = f"{filename}_in.ms", f"{filename}_out.ms"
    try:
        ToMSolve([transverse_poly], in_file)
        os.system(f"{msolve_path} -v{verbose} -t{threads} -p{2*precision} -I1 -f {in_file} -o {out_file}")
        inter = FormatOutputMSolveIntervals(out_file)
    finally:
        for path in (in_file, out_file):
            if os.path.exists(path):
                os.remove(path)
    inter_lambda_values = inter[1][1]

    if inter_lambda_values == []:
        raise ValueError("Coordinates not precise enough to compute a good intersection line. Consider increasing the precision")

    allowed_lambda_interval = [low_prec[0] - approx[0], low_prec[1] - approx[0]]
    endpoints = list(itertools.chain.from_iterable(list(itertools.chain.from_iterable(inter_lambda_values))))
    # Endpoints ordered by distance to t = 0 (the critical point itself).
    sorted_endpoints = sorted(endpoints, key=lambda x: (abs(x), x))

    if (allowed_lambda_interval[0] <= sorted_endpoints[0] <= allowed_lambda_interval[1]) and (allowed_lambda_interval[0] <= sorted_endpoints[1] <= allowed_lambda_interval[1]):
        if len(sorted_endpoints) == 2:
            lambd = ceil(abs(sorted_endpoints[1]))+1
        else:
            lambd = RealIntervalField(prec=5*precision)(abs(sorted_endpoints[1]), abs(sorted_endpoints[2])).simplest_rational(True,True)
    else:
        raise ValueError("Coordinates not precise enough to compute a good intersection line. Consider increasing the precision")

    if len(sorted_endpoints) != 2 and allowed_lambda_interval[0] <= sorted_endpoints[2] <= allowed_lambda_interval[1]:
        raise ValueError("Isolation box not precise enough to gurantee a single intersection point of the transverse line inside it. Consider increasing the precision")

    right_pt = [approx[0] + lambd]
    left_pt = [approx[0] - lambd]
    if len(approx) != 1:
        right_pt += approx[1:]
        left_pt += approx[1:]
    return left_pt, right_pt


def SmoothPointPerConnectedComponent(poly,threads,verbose,precision,changevar=True):
    '''
    Main function computing the points per connected components.
    Input:
        - poly (polynomial),
        - threads (number to be used in computation), 
        - verbose (msolve parameter for explicit description of computations, 
            ranging from 0 (no desc) to 2 (full desc)),
        - precision (msolve precision; critical points are computed with
            2*precision),
        - changevar (boolean; does not apply any change of variable if set
            to False).
    Output:
        - Pair (Sols, Sols_num): Sols contains, for each k in range(n), the
          list of points found for that k, followed by one extra point;
          Sols_num is the total number of points. Every point is mapped
          back to the variable order of the input poly.
    Raises:
        - ValueError when an msolve approximation is not precise enough, when
          the final sanity check fails, or when infinitely many critical
          points are found while changevar is False (a retry would then
          repeat exactly the same computation).
    '''

    # Setting up correct parent rings and variables
    R = poly.parent()
    variables = list(R.gens())
    n = len(variables)

    # Re-naming variables to have the increasing partial derivatives degree.
    # The input poly is kept unchanged: a retry must restart from it, or its
    # points would be mapped back with the wrong permutation.
    inv_permutation, relabeled_poly = DerivativeOrder(poly)
    f = R(relabeled_poly)

    # Generating the change of variable matrix A, with its partial inverses
    list_of_matrices = RandomMatrix(n,changevar)

    # Generating the specialisation point sigma
    if changevar:
        sigma = [ZZ.random_element(1,100,"uniform") for i in range(n-1)]
    else:
        sigma = [1 for i in range(n-1)]

    # Pre-computing partial derivatives of f to avoid unnecessary computations
    der_list = [derivative(f, variables[i]) for i in range(n)]

    # Initialising final solutions list and number of solutions
    Sols = []
    Sols_num = 0

    # Computing f^A (A*x computed once instead of once per variable)
    Ax = list_of_matrices[0]*vector(variables)
    fA = f.subs({variables[i] : Ax[i] for i in range(n)})

    # Main for loop
    for k in range(n):
        start = time.perf_counter()
        print(f"k = {k}\n")

        # Obtaining approximations to critical points
        substitution, crit = CriticalPoints(f, threads, verbose,2*precision,k, n, list_of_matrices, sigma, variables, der_list)

        temp = time.perf_counter()

        Solsk = []

        # In case we have infinitely many of them
        if crit[0] > 0:
            if not changevar:
                # A and sigma are fixed when changevar is False, so retrying
                # would recurse forever on the same computation.
                raise ValueError("Infinitely many critical points, and changevar is False: no other change of variables can be tried.")
            print("Error - Infinitely many critical points. Picking another change of variables matrix...")
            return SmoothPointPerConnectedComponent(poly, threads, verbose, precision, changevar)

        # In case we have finitely many of them, and at least one
        if crit[0] != -1 and len(crit) < 3 and crit[1][1] != []:

            # Substituting the first variables & fixing parent ring issues
            f_sub = f.subs({variables[i] : substitution[i] for i in range(k)})
            fA_sub = fA.subs({variables[i] : sigma[i] for i in range(k)})

            if k == n-1:
                f_sub = f_sub.univariate_polynomial()
                fA_sub = fA_sub.univariate_polynomial()
                substitution = [item.univariate_polynomial() for item in substitution]
            else:
                f_sub = remove_absent_variable(f_sub, variables[:k])
                fA_sub = remove_absent_variable(fA_sub, variables[:k])
                substitution = [remove_absent_variable(item, variables[:k]) for item in substitution]
            variabless = f_sub.parent().gens()

            # Looping over each computed point to obtain A^-1 * point
            if changevar:
                A_inv_list = []

                for point in crit[1][1]:

                    # Checking whether the sign of each coordinate is known
                    for coord in point:
                        if sign(coord[0]) != sign(coord[1]):
                            raise ValueError("Msolve not precise enough to determine coordinate sign. Consider increasing the precision.")

                    # Computing approximation box of original (no A) critical point
                    A_inv_point = matrix_box(n,point,list_of_matrices[1],substitution)

                    # Checking whether the sign of each coordinate is still known
                    for coord in A_inv_point:
                        if sign(coord[0]) != sign(coord[1]):
                            raise ValueError("Msolve not precise enough to determine coordinate sign. Consider increasing the precision.")

                    A_inv_list.append(A_inv_point)

                # Checking once, after all boxes are known, whether any two
                # of the new approximation boxes intersect
                if do_boxes_intersect(A_inv_list):
                    raise ValueError("Msolve not precise enough to isolate critical points. Consider increasing the precision.")

            else:
                A_inv_list = crit[1][1]

            dfAdxk = derivative(fA_sub, variabless[0])
            for point in A_inv_list:
                # Checking whether we do not have an exact point (if we actually
                # do, we can skip verification steps)
                if [item[0] for item in point] != [item[1] for item in point]:

                    # Checking if df^A/dx_k = 0 in the msolve approximation box
                    dfA_interval = rough_eval(point, dfAdxk)
                    if is_zero_inside(dfA_interval):
                        raise ValueError("Msolve not precise enough to guarantee non-zero derivative. Consider increasing the precision.")

                    # Computing rougher approximation of that point in the x_k coordinate only
                    coord_low_prec = [floor(2**precision*point[0][0])/2**precision, ceil(2**precision*point[0][1])/2**precision]

                    # Checking whether f vanishes on the low and high x_k
                    low_f = fA_sub.subs({variabless[0] : coord_low_prec[0]})
                    high_f = fA_sub.subs({variabless[0] : coord_low_prec[1]})
                    if k == n-1:
                        low_inter = [QQ(low_f), QQ(low_f)]
                        high_inter = [QQ(high_f), QQ(high_f)]
                    elif k == n-2:
                        low_f = low_f.univariate_polynomial()
                        high_f = high_f.univariate_polynomial()
                        low_inter = rough_eval(point[1:], low_f)
                        high_inter = rough_eval(point[1:], high_f)
                    else:
                        low_f = remove_absent_variable(low_f, variabless[0])
                        high_f = remove_absent_variable(high_f, variabless[0])
                        low_inter = rough_eval(point[1:], low_f)
                        high_inter = rough_eval(point[1:], high_f)

                    if is_zero_inside(low_inter) or is_zero_inside(high_inter):
                        raise ValueError("Approximation is not sufficiently precise. Consider increasing the precision.")
                else:
                    coord_low_prec = [point[0][0], point[0][1]]
                # Computing the points lying on the transverse line in the
                # corresponding connected components
                left, right = TransverseIntersection(fA_sub, variabless, point, coord_low_prec)

                left = list_of_matrices[0]*vector(sigma[:k] + left)
                right = list_of_matrices[0]*vector(sigma[:k] + right)

                # Sanity check for the whole procedure: both points must lie
                # off V(f), on opposite sides of it.
                f_left = f.subs({variables[i]: left[i] for i in range(n)})
                f_right = f.subs({variables[i]: right[i] for i in range(n)})
                if f_left == 0 or f_right == 0 or sign(f_left) == sign(f_right):
                    raise ValueError("Sanity check failed, something went wrong.")

                Solsk.append(inv_permutation(list(left)))
                Solsk.append(inv_permutation(list(right)))

            print(f"{len(Solsk)} solutions were added in the k = {k} case\n")
            Sols_num += len(Solsk)

        # In case there are no critical points
        else:
            print(f"There are no critical points in the k = {k} case\n")  
        Sols += [Solsk]
        end = time.perf_counter()
        print(f"The k = {k} case took {end-start} seconds to compute")
        print(f"(Overall msolve time: {temp-start} seconds)")
        print(f"(Overall sagemath time: {end-temp} seconds)\n")

    # The extra point A*(sigma, 0), mapped back to the input variable order
    # like every other returned point.
    Sols += [inv_permutation(list(list_of_matrices[0]*vector(sigma + [0])))]
    Sols_num += 1
    return Sols, Sols_num

def main(poly, threads, verbose, precision, changevar=True):
    '''
    Entry point: move poly into a polynomial ring over QQ built from its own
    variables, check that V(poly) is smooth, then compute the points per
    connected component.
    Input:
        - poly (polynomial),
        - threads (number to be used in computation),
        - verbose (msolve parameter for explicit description of computations,
            ranging from 0 (no desc) to 2 (full desc)),
        - precision (msolve precision, forwarded unchanged),
        - changevar (boolean; does not apply any change of variable, and always
            picks [1,...,1] as a specialisation point, if set to False).
    Output:
        - The result of SmoothPointPerConnectedComponent on poly.
    Raises:
        - ValueError if V(poly) is not smooth.
    '''
    poly_vars = list(poly.variables())
    ring = PolynomialRing(QQ, poly_vars, len(poly_vars))
    g = ring(poly)
    if not IsSmooth(g, threads, verbose):
        raise ValueError("Input polynomial does not define a smooth hypersurface.")
    return SmoothPointPerConnectedComponent(g, threads, verbose, precision, changevar)

# Settings that can be changed by hand (see the module docstring). They are
# also read as module-level variables by TransverseIntersection.
verbose = 0
threads = 7
precision = 128
changevar = True

# Lift the limit on int <-> str conversions (huge coefficients can appear);
# this setter does not exist on older Python versions.
if hasattr(sys, "set_int_max_str_digits"):
    sys.set_int_max_str_digits(0)

if len(sys.argv) < 2:
    sys.exit("Usage: sage smooth.sage <polynomial file>")
with open(sys.argv[1], 'r') as input_file:
    poly = SR(input_file.read())

start2 = time.perf_counter()
aaaa = main(poly, threads, verbose, precision, changevar)
end2 = time.perf_counter()

print(f"This took {end2-start2} seconds to compute in total.")
print(f"A total of {aaaa[1]} points were computed.")
# Uncomment to print the computed points:
#print(aaaa[0])
