Ticket #11036: trac_11036_improve_solve_mod_perf.patch

File trac_11036_improve_solve_mod_perf.patch, 9.4 KB (added by dsm, 9 years ago)
  • sage/symbolic/relation.py

    # HG changeset patch
    # User D. S. McNeil <dsm054@gmail.com>
    # Date 1301110343 -28800
    # Node ID 733adae06130875381160304a9845d4ff6a4df47
    # Parent  361a4ad7d52c69b64ae2e658ffd0820af0d87e93
    Trac 11036: improve solve_mod performance
    
    diff -r 361a4ad7d52c -r 733adae06130 sage/symbolic/relation.py
    a b  
    710710        sage: solve_mod([x^2 == 1, 4*x  == 11], 15)
    711711        [(14,)]
    712712   
     713
     714    Note that the solution elements belong to the relevant modular ring::
     715
     716        sage: sols = solve_mod([x^2 == 1, 4*x  == 11], 15)
     717        sage: sols[0][0]
     718        14
     719        sage: parent(sols[0][0])
     720        Ring of integers modulo 15
     721
     722   
    713723    Fermat's equation modulo 3 with exponent 5::
    714724   
    715725        sage: var('x,y,z')
     
    717727        sage: solve_mod([x^5 + y^5 == z^5], 3)
    718728        [(0, 0, 0), (0, 1, 1), (0, 2, 2), (1, 0, 1), (1, 1, 2), (1, 2, 0), (2, 0, 2), (2, 1, 0), (2, 2, 1)]
    719729
     730
    720731    We can solve with respect to a bigger modulus if it consists only of small prime factors::
    721732
    722733        sage: [d] = solve_mod([5*x + y == 3, 2*x - 3*y == 9], 3*5*7*11*19*23*29, solution_dict = True)
     
    725736        sage: d[y]
    726737        8610183
    727738
    728     We solve an simple equation modulo 2::
     739    For cases where there are relatively few solutions and the prime
     740    factors are small, this can be efficient even if the modulus itself
     741    is large::
     742
     743        sage: sorted(solve_mod([x^2 == 41], 10^20))
     744        [(4538602480526452429,), (11445932736758703821,), (38554067263241296179,),
     745        (45461397519473547571,), (54538602480526452429,), (61445932736758703821,),
     746        (88554067263241296179,), (95461397519473547571,)]
     747
     748    We solve a simple equation modulo 2::
    729749   
    730750        sage: x,y = var('x,y')
    731751        sage: solve_mod([x == y], 2)
    732752        [(0, 0), (1, 1)]
    733753
     754
    734755    .. warning::
    735756
    736        The current implementation splits the modulus into prime powers,
    737        then naively enumerates all possible solutions and finally combines
    738        the solution using the Chinese Remainder Theorem.
    739        The interface is good, but the algorithm is horrible if the modulus
    740        has some larger prime factors! Sage *does* have the ability to do
    741        something much faster in certain cases at least by using Groebner
    742        basis, linear algebra techniques, etc. But for a lot of toy problems
    743        this function as is might be useful. At least it establishes an interface.
     757       The current implementation splits the modulus into prime
     758       powers, then naively enumerates all possible solutions and
     759       finally combines the solution using the Chinese Remainder
     760       Theorem.  The interface is good, but the algorithm is horrible
     761       if the modulus has some larger prime factors! Sage *does* have
     762       the ability to do something much faster in certain cases at
     763       least by using Groebner basis, linear algebra techniques,
     764       etc. But for a lot of toy problems this function as is might be
     765       useful. At least it establishes an interface.
     766
     767    TESTS:
     768
     769    Make sure that we short-circuit in at least some cases::
     770
     771        sage: solve_mod([2*x==1], 2*next_prime(10^50))
     772        []
     773
     774    Try multi-equation cases::
     775
     776        sage: x, y, z = var("x y z")
     777        sage: solve_mod([2*x^2 + x*y, -x*y+2*y^2+x-2*y, -2*x^2+2*x*y-y^2-x-y], 12)
     778        [(0, 0), (4, 4), (0, 3), (4, 7)]
     779        sage: eqs = [-y^2+z^2, -x^2+y^2-3*z^2-z-1, -y*z-z^2-x-y+2, -x^2-12*z^2-y+z]
     780        sage: solve_mod(eqs, 11)
     781        [(8, 5, 6)]
     782
     783    Confirm that modulus 1 now behaves as it should::
     784   
     785        sage: x, y = var("x y")
     786        sage: solve_mod([x==1], 1)
     787        [(0,)]
     788        sage: solve_mod([2*x^2+x*y, -x*y+2*y^2+x-2*y, -2*x^2+2*x*y-y^2-x-y], 1)
     789        [(0, 0)]
     790   
     791
    744792    """
    745793    from sage.rings.all import Integer, Integers, PolynomialRing, factor, crt_basis
    746794    from sage.misc.all import cartesian_product_iterator
     
    753801    if modulus < 1:
    754802         raise ValueError, "the modulus must be a positive integer"
    755803    vars = list(set(sum([list(e.variables()) for e in eqns], [])))
    756     vars.sort(cmp = lambda x,y: cmp(repr(x), repr(y)))
    757     n = len(vars)
     804    vars.sort(key=repr)
    758805
    759     factors = [p**i for p,i in factor(modulus)]
    760     crt_basis = vector(Integers(modulus), crt_basis(factors))
    761     solutions = [solve_mod_enumerate(eqns, p) for p in factors]
     806    if modulus == 1: # degenerate case
     807        ans = [tuple(Integers(modulus)(0) for v in vars)]
     808    else:
     809        factors = [p**i for p,i in factor(modulus)]
     810        crt_basis = vector(Integers(modulus), crt_basis(factors))
     811        solutions = []
     812        for p in factors:
     813            s = solve_mod_enumerate(eqns, p)
     814            if not s: return []
     815            solutions.append(s)
    762816
    763     ans = []
    764     for solution in cartesian_product_iterator(solutions):
    765         solution_mat = matrix(Integers(modulus), solution)
    766         ans.append(tuple(c.dot_product(crt_basis) for c in solution_mat.columns()))
     817        ans = []
     818        for solution in cartesian_product_iterator(solutions):
     819            solution_mat = matrix(Integers(modulus), solution)
     820            ans.append(tuple(c.dot_product(crt_basis) for c in solution_mat.columns()))
    767821
    768822    if solution_dict == True:
    769823        sol_dict = [dict(zip(vars, solution)) for solution in ans]
     
    805859        sage: solve_mod([x^5 + y^5 == z^5], 3)
    806860        [(0, 0, 0), (0, 1, 1), (0, 2, 2), (1, 0, 1), (1, 1, 2), (1, 2, 0), (2, 0, 2), (2, 1, 0), (2, 2, 1)]
    807861
    808     We solve an simple equation modulo 2::
     862    We solve a simple equation modulo 2::
    809863   
    810864        sage: x,y = var('x,y')
    811865        sage: solve_mod([x == y], 2)
     
    814868       
    815869    .. warning::
    816870   
    817        Currently this naively enumerates all possible solutions.  The
    818        interface is good, but the algorithm is horrible if the modulus
    819        is at all large! Sage *does* have the ability to do something
    820        much faster in certain cases at least by using the Chinese
    821        Remainder Theorem, Groebner basis, linear algebra techniques,
    822        etc. But for a lot of toy problems this function as is might be
    823        useful. At the very least, it establishes an interface.
    824     """
     871       Currently this constructs possible solutions by building up
     872       from the smallest prime factor of the modulus.  The interface
     873       is good, but the algorithm is horrible if the modulus isn't the
     874       product of many small primes! Sage *does* have the ability to
     875       do something much faster in certain cases at least by using the
     876       Chinese Remainder Theorem, Groebner basis, linear algebra
     877       techniques, etc. But for a lot of toy problems this function as
     878       is might be useful. At the very least, it establishes an
     879       interface.
     880
     881    TESTS:
     882   
     883    Confirm we can reproduce the first few terms of OEIS A187719::
     884   
     885        sage: from sage.symbolic.relation import solve_mod_enumerate
     886        sage: [sorted(solve_mod_enumerate(x^2==41, 10^i))[0][0] for i in [1..13]]
     887        [1, 21, 71, 1179, 2429, 47571, 1296179, 8703821, 26452429, 526452429,
     888        13241296179, 19473547571, 2263241296179]
     889
     890    Confirm that modulus 1 now behaves as it should::
     891   
     892        sage: x, y = var("x y")
     893        sage: solve_mod_enumerate([2*x^2 + x*y, -x*y + 2*y^2 + x - 2*y, -2*x^2 + 2*x*y - y^2 - x - y], 1)
     894        [(0, 0)]
     895 
     896     """
    825897    from sage.rings.all import Integer, Integers, PolynomialRing
     898    from sage.rings.all import factor
    826899    from sage.symbolic.expression import is_Expression
    827900    from sage.misc.all import cartesian_product_iterator
    828    
     901    from sage.modules.all import vector
     902
    829903    if not isinstance(eqns, (list, tuple)):
    830904        eqns = [eqns]
     905    eqns = [eq if is_Expression(eq) else (eq.lhs()-eq.rhs()) for eq in eqns]
    831906    modulus = Integer(modulus)
    832907    if modulus < 1:
    833908         raise ValueError, "the modulus must be a positive integer"
    834909    vars = list(set(sum([list(e.variables()) for e in eqns], [])))
    835     vars.sort(cmp = lambda x,y: cmp(repr(x), repr(y)))
    836     n = len(vars)
    837     R = Integers(modulus)
    838     S = PolynomialRing(R, len(vars), vars)
    839     eqns_mod = [S(eq) if is_Expression(eq) else
    840                 S(eq.lhs() - eq.rhs()) for eq in eqns]
     910    vars.sort(key=repr)
     911
     912    if modulus == 1: # degenerate case
     913        ans = [tuple(Integers(modulus)(0) for v in vars)]
     914        return ans
     915
     916    factors = factor(modulus)
     917    mrunning = 1
    841918    ans = []
    842     for t in cartesian_product_iterator([R]*len(vars)):
    843         is_soln = True
    844         for e in eqns_mod:
    845             if e(t) != 0:
    846                 is_soln = False
    847                 break
    848         if is_soln:
    849             ans.append(t)
    850 
     919    for fi, (p, m) in enumerate(factors):
     920        for mi in xrange(m):
     921            mrunning *= p
     922            R = Integers(mrunning)
     923            S = PolynomialRing(R, len(vars), vars)
     924            eqns_mod = [S(eq) for eq in eqns]
     925            if fi == 0 and mi == 0:
     926                possibles = cartesian_product_iterator([xrange(len(R)) for _ in xrange(len(vars))])
     927            else:
     928                shifts = cartesian_product_iterator([xrange(p) for _ in xrange(len(vars))])
     929                pairs = cartesian_product_iterator([shifts, ans])
     930                possibles = (tuple(vector(t)+vector(shift)*(mrunning//p)) for shift, t in pairs)
     931            ans = list(t for t in possibles if all(e(t) == 0 for e in eqns_mod))
     932            if not ans: return ans
    851933    return ans
    852934
    853935def solve_ineq_univar(ineq):