Ticket #9663: stirling2.patch

File stirling2.patch, 11.5 KB (added by fredrik.johansson, 9 years ago)

fast implementation of stirling_number2 -- updated patch

  • module_list.py

    # HG changeset patch
    # User Fredrik Johansson <fredrik.johansson@gmail.com>
    # Date 1281047071 -7200
    # Node ID 1d52c99c455a162c525d98ba4157a93d35e057cf
    # Parent  7d20ef3a988c8054f41b4179735eb765d9a4980c
    #9663 fast Cython implementation of stirling_number2
    
    diff -r 7d20ef3a988c -r 1d52c99c455a module_list.py
    a b  
    203203    Extension('sage.combinat.permutation_cython',
    204204              sources=['sage/combinat/permutation_cython.pyx']),
    205205
     206    Extension('sage.combinat.combinat_cython',
     207              sources=['sage/combinat/combinat_cython.pyx']),
     208
    206209    ################################
    207210    ##
    208211    ## sage.crypto
  • sage/combinat/combinat.py

    diff -r 7d20ef3a988c -r 1d52c99c455a sage/combinat/combinat.py
    a b  
    1616
    1717- Florent Hivert (2009-02): combinatorial class cleanup
    1818
     19- Fredrik Johansson (2010-07): fast implementation of ``stirling_number2``
     20
    1921This module implements some combinatorial functions, as listed
    2022below. For a more detailed description, see the relevant
    2123docstrings.
     
    233235from sage.structure.parent import Parent
    234236from sage.misc.lazy_attribute import lazy_attribute
    235237from sage.misc.misc import deprecation
     238from combinat_cython import _stirling_number2
    236239######### combinatorial sequences
    237240
    238241def bell_number(n):
     
    555558    """
    556559    return ZZ(gap.eval("Stirling1(%s,%s)"%(ZZ(n),ZZ(k))))
    557560
    558 def stirling_number2(n,k):
     561def stirling_number2(n, k, algorithm=None):
    559562    """
    560563    Returns the n-th Stirling number `S_2(n,k)` of the second
    561564    kind (the number of ways to partition a set of n elements into k
    562565    pairwise disjoint nonempty subsets). (The n-th Bell number is the
    563     sum of the `S_2(n,k)`'s, `k=0,...,n`.) Wraps GAP's
    564     Stirling2.
    565    
    566     EXAMPLES: Stirling numbers satisfy
    567     `S_2(n,k) = S_2(n-1,k-1) + kS_2(n-1,k)`::
    568    
    569         sage: 5*stirling_number2(9,5) + stirling_number2(9,4)
    570         42525
    571         sage: stirling_number2(10,5)
    572         42525
    573    
    574     ::
    575    
    576         sage: n = stirling_number2(20,11); n
     566    sum of the `S_2(n,k)`'s, `k=0,...,n`.)
     567
     568    INPUT:
     569
     570       *  ``n`` - nonnegative machine-size integer
     571       *  ``k`` - nonnegative machine-size integer
     572       * ``algorithm``:
     573
     574         * None (default) - use native implementation
     575         * ``"maxima"`` - use Maxima's stirling2 function
     576         * ``"gap"`` - use GAP's Stirling2 function
     577
     578    EXAMPLES:
     579
     580    Print a table of the first several Stirling numbers of the second kind::
     581
     582        sage: for n in range(10):
     583        ...       for k in range(10):
     584        ...           print str(stirling_number2(n,k)).rjust(k and 6),
     585        ...       print
     586        ...
     587        1      0      0      0      0      0      0      0      0      0
     588        0      1      0      0      0      0      0      0      0      0
     589        0      1      1      0      0      0      0      0      0      0
     590        0      1      3      1      0      0      0      0      0      0
     591        0      1      7      6      1      0      0      0      0      0
     592        0      1     15     25     10      1      0      0      0      0
     593        0      1     31     90     65     15      1      0      0      0
     594        0      1     63    301    350    140     21      1      0      0
     595        0      1    127    966   1701   1050    266     28      1      0
     596        0      1    255   3025   7770   6951   2646    462     36      1
     597
     598    Stirling numbers satisfy `S_2(n,k) = S_2(n-1,k-1) + kS_2(n-1,k)`::
     599
     600         sage: 5*stirling_number2(9,5) + stirling_number2(9,4)
     601         42525
     602         sage: stirling_number2(10,5)
     603         42525
     604
     605    TESTS::
     606
     607        sage: stirling_number2(500,501)
     608        0
     609        sage: stirling_number2(500,500)
     610        1
     611        sage: stirling_number2(500,499)
     612        124750
     613        sage: stirling_number2(500,498)
     614        7739801875
     615        sage: stirling_number2(500,497)
     616        318420320812125
     617        sage: stirling_number2(500,0)
     618        0
     619        sage: stirling_number2(500,1)
     620        1
     621        sage: stirling_number2(500,2)
     622        1636695303948070935006594848413799576108321023021532394741645684048066898202337277441635046162952078575443342063780035504608628272942696526664263794687
     623        sage: stirling_number2(500,3)
     624        6060048632644989473730877846590553186337230837666937173391005972096766698597315914033083073801260849147094943827552228825899880265145822824770663507076289563105426204030498939974727520682393424986701281896187487826395121635163301632473646
     625        sage: stirling_number2(500,30)
     626        13707767141249454929449108424328432845001327479099713037876832759323918134840537229737624018908470350134593241314462032607787062188356702932169472820344473069479621239187226765307960899083230982112046605340713218483809366970996051181537181362810003701997334445181840924364501502386001705718466534614548056445414149016614254231944272872440803657763210998284198037504154374028831561296154209804833852506425742041757849726214683321363035774104866182331315066421119788248419742922490386531970053376982090046434022248364782970506521655684518998083846899028416459701847828711541840099891244700173707021989771147674432503879702222276268661726508226951587152781439224383339847027542755222936463527771486827849728880
     627        sage: stirling_number2(500,31)
     628        5832088795102666690960147007601603328246123996896731854823915012140005028360632199516298102446004084519955789799364757997824296415814582277055514048635928623579397278336292312275467402957402880590492241647229295113001728653772550743446401631832152281610081188041624848850056657889275564834450136561842528589000245319433225808712628826136700651842562516991245851618481622296716433577650218003181535097954294609857923077238362717189185577756446945178490324413383417876364657995818830270448350765700419876347023578011403646501685001538551891100379932684279287699677429566813471166558163301352211170677774072447414719380996777162087158124939742564291760392354506347716119002497998082844612434332155632097581510486912
     629        sage: n = stirling_number2(20,11)
     630        sage: n
    577631        1900842429486
    578632        sage: type(n)
    579633        <type 'sage.rings.integer.Integer'>
    580     """
    581     return ZZ(gap.eval("Stirling2(%s,%s)"%(ZZ(n),ZZ(k))))
     634        sage: n = stirling_number2(20,11,algorithm='gap')
     635        sage: n
     636        1900842429486
     637        sage: type(n)
     638        <type 'sage.rings.integer.Integer'>
     639        sage: n = stirling_number2(20,11,algorithm='maxima')
     640        sage: n
     641        1900842429486
     642        sage: type(n)
     643        <type 'sage.rings.integer.Integer'>
    582644
     645     """
     646    if algorithm is None:
     647        return _stirling_number2(n, k)
     648    elif algorithm == 'gap':
     649        return ZZ(gap.eval("Stirling2(%s,%s)"%(ZZ(n),ZZ(k))))
     650    elif algorithm == 'maxima':
     651        return ZZ(maxima.eval("stirling2(%s,%s)"%(ZZ(n),ZZ(k))))
     652    else:
     653        raise ValueError("unknown algorithm: %s" % algorithm)
    583654
    584655class CombinatorialObject(SageObject):
    585656    def __init__(self, l):
  • new file sage/combinat/combinat_cython.pxd

    diff -r 7d20ef3a988c -r 1d52c99c455a sage/combinat/combinat_cython.pxd
    - +  
     1from sage.libs.gmp.all cimport mpz_t
     2
     3cdef mpz_stirling_s2(mpz_t s, unsigned long n, unsigned long k)
  • new file sage/combinat/combinat_cython.pyx

    diff -r 7d20ef3a988c -r 1d52c99c455a sage/combinat/combinat_cython.pyx
    - +  
     1"""
     2Fast computation of combinatorial functions (Cython + mpz).
     3
     4Currently implemented:
     5- Stirling numbers of the second kind
     6
     7AUTHORS:
     8- Fredrik Johansson (2010-10): Stirling numbers of second kind
     9
     10"""
     11
     12include "../ext/stdsage.pxi"
     13
     14from stdlib cimport malloc, free
     15
     16from sage.libs.gmp.all cimport *
     17from sage.rings.integer cimport Integer
     18
     19cdef void mpz_addmul_alt(mpz_t s, mpz_t t, mpz_t u, unsigned long parity):
     20    """
     21    Set s = s + t*u * (-1)^parity
     22    """
     23    if parity & 1:
     24        mpz_submul(s, t, u)
     25    else:
     26        mpz_addmul(s, t, u)
     27
     28
     29cdef mpz_stirling_s2(mpz_t s, unsigned long n, unsigned long k):
     30    """
     31    Set s = S(n,k) where S(n,k) denotes a Stirling number of the
     32    second kind.
     33
     34    Algorithm: S(n,k) = (sum_{j=0}^k (-1)^(k-j) C(k,j) j^n) / k!
     35
     36    TODO: compute S(n,k) efficiently for large n when n-k is small
     37    (e.g. when k > 20 and n-k < 20)
     38    """
     39    cdef mpz_t t, u
     40    cdef mpz_t *bc
     41    cdef unsigned long j, max_bc
     42    # Some important special cases
     43    if k+1 >= n:
     44        # Upper triangle of n\k table
     45        if k > n:
     46            mpz_set_ui(s, 0)
     47        elif n == k:
     48            mpz_set_ui(s, 1)
     49        elif k+1 == n:
     50            # S(n,n-1) = C(n,2)
     51            mpz_set_ui(s, n)
     52            mpz_mul_ui(s, s, n-1)
     53            mpz_tdiv_q_2exp(s, s, 1)
     54    elif k <= 2:
     55        # Leftmost three columns of n\k table
     56        if k == 0:
     57            mpz_set_ui(s, 0)
     58        elif k == 1:
     59            mpz_set_ui(s, 1)
     60        elif k == 2:
     61            # 2^(n-1)-1
     62            mpz_set_ui(s, 1)
     63            mpz_mul_2exp(s, s, n-1)
     64            mpz_sub_ui(s, s, 1)
     65    # Direct sequential evaluation of the sum
     66    elif n < 200:
     67        mpz_init(t)
     68        mpz_init(u)
     69        mpz_set_ui(t, 1)
     70        mpz_set_ui(s, 0)
     71        for j in range(1, k//2+1):
     72            mpz_mul_ui(t, t, k+1-j)
     73            mpz_tdiv_q_ui(t, t, j)
     74            mpz_set_ui(u, j)
     75            mpz_pow_ui(u, u, n)
     76            mpz_addmul_alt(s, t, u, k+j)
     77            if 2*j != k:
     78                # Use the fact that C(k,j) = C(k,k-j)
     79                mpz_set_ui(u, k-j)
     80                mpz_pow_ui(u, u, n)
     81                mpz_addmul_alt(s, t, u, j)
     82        # Last term not included because loop starts from 1
     83        mpz_set_ui(u, k)
     84        mpz_pow_ui(u, u, n)
     85        mpz_add(s, s, u)
     86        mpz_fac_ui(t, k)
     87        mpz_tdiv_q(s, s, t)
     88        mpz_clear(t)
     89        mpz_clear(u)
     90    # Only compute odd powers, saving about half of the time for large n.
     91    # We need to precompute binomial coefficients since they will be accessed
     92    # out of order, adding overhead that makes this slower for small n.
     93    else:
     94        mpz_init(t)
     95        mpz_init(u)
     96        max_bc = (k+1)//2
     97        bc = <mpz_t*> malloc((max_bc+1) * sizeof(mpz_t))
     98        mpz_init_set_ui(bc[0], 1)
     99        for j in range(1, max_bc+1):
     100            mpz_init_set(bc[j], bc[j-1])
     101            mpz_mul_ui(bc[j], bc[j], k+1-j)
     102            mpz_tdiv_q_ui(bc[j], bc[j], j)
     103        mpz_set_ui(s, 0)
     104        for j in range(1, k+1, 2):
     105            mpz_set_ui(u, j)
     106            mpz_pow_ui(u, u, n)
     107            # Process each 2^p * j, where j is odd
     108            while 1:
     109                if j > max_bc:
     110                    mpz_addmul_alt(s, bc[k-j], u, k+j)
     111                else:
     112                    mpz_addmul_alt(s, bc[j], u, k+j)
     113                j *= 2
     114                if j > k:
     115                    break
     116                mpz_mul_2exp(u, u, n)
     117        for j in range(max_bc+1):   # careful: 0 ... max_bc
     118            mpz_clear(bc[j])
     119        free(bc)
     120        mpz_fac_ui(t, k)
     121        mpz_tdiv_q(s, s, t)
     122        mpz_clear(t)
     123        mpz_clear(u)
     124
     125def _stirling_number2(n, k):
     126    """
     127    Python wrapper of mpz_stirling_s2.
     128
     129        sage: from sage.combinat.combinat_cython import _stirling_number2
     130        sage: _stirling_number2(3, 2)
     131        3
     132
     133    This is wrapped again by stirling_number2 in combinat.py.
     134    """
     135    cdef Integer s
     136    s = PY_NEW(Integer)
     137    mpz_stirling_s2(s.value, n, k)
     138    return s