Ticket #3542: 3542-dharvey-bernoulli-numbers-multimodular.patch

File 3542-dharvey-bernoulli-numbers-multimodular.patch, 61.5 KB (added by ncalexan, 19 months ago)
  • sage/rings/arith.py

    # HG changeset patch
    # User dmharvey@math.harvard.edu
    # Date 1214353095 14400
    # Node ID 8b61a8e0a9f77189f7d9bec7366b6e3847c83a32
    # Parent  6a6766d05f3bc23e52c4ce7477132b9a7ce1607d
    bernmm
    * * *
    fixed overflow bug, added "r" prefix to comments, added algorithm="default" option to bernoulli(), per reviewer's comment on #3542
    
    diff -r dcd714fbd837 sage/rings/arith.py
    a b  
    145145 
    146146algebraic_dependency = algdep 
    147147 
    148 def bernoulli(n, algorithm='pari'): 
     148def bernoulli(n, algorithm='default', num_threads=1): 
    149149    r""" 
    150150    Return the n-th Bernoulli number, as a rational number. 
    151151 
    152152    INPUT: 
    153153        n -- an integer 
    154154        algorithm: 
    155             'pari' -- (default) use the PARI C library, which is 
    156                       by *far* the fastest. 
     155            'default' -- (default) use 'pari' for n <= 30000, and 
     156                         'bernmm' for n > 30000 (this is just a heuristic, 
     157                         and not guaranteed to be optimal on all hardware) 
     158            'pari' -- use the PARI C library 
    157159            'gap'  -- use GAP 
    158160            'gp'   -- use PARI/GP interpreter 
    159161            'magma' -- use MAGMA (optional) 
    160162            'python' -- use pure Python implementation 
     163            'bernmm' -- use bernmm package (a multimodular algorithm) 
     164        num_threads -- positive integer, number of threads to use 
     165                       (only used for bernmm algorithm) 
    161166 
    162167    EXAMPLES: 
    163168        sage: bernoulli(12) 
     
    176181        -691/2730 
    177182        sage: bernoulli(12, algorithm='python') 
    178183        -691/2730 
     184        sage: bernoulli(12, algorithm='bernmm') 
     185        -691/2730 
     186        sage: bernoulli(12, algorithm='bernmm', num_threads=4) 
     187        -691/2730 
    179188 
    180189    \note{If $n>50000$ then algorithm = 'gp' is used instead of 
    181190    algorithm = 'pari', since the C-library interface to PARI 
     
    185194    """ 
    186195    from sage.rings.all import Integer, Rational 
    187196    n = Integer(n) 
     197 
     198    if algorithm == 'default': 
     199        algorithm = 'pari' if n <= 30000 else 'bernmm' 
     200 
    188201    if n > 50000 and algorithm == 'pari': 
    189202        algorithm = 'gp' 
     203 
    190204    if algorithm == 'pari': 
    191205        x = pari(n).bernfrac()         # Use the PARI C library 
    192206        return Rational(x) 
     
    205219    elif algorithm == 'python': 
    206220        import sage.rings.bernoulli 
    207221        return sage.rings.bernoulli.bernoulli_python(n) 
     222    elif algorithm == 'bernmm': 
     223        import sage.rings.bernmm 
     224        return sage.rings.bernmm.bernmm_bern_rat(n, num_threads) 
    208225    else: 
    209226        raise ValueError, "invalid choice of algorithm" 
    210227 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm.pyx

    diff -r dcd714fbd837 sage/rings/bernmm.pyx
    a b  
     1r""" 
     2Cython wrapper for bernmm library 
     3 
     4AUTHOR:  
     5    - David Harvey (2008-06): initial version 
     6""" 
     7 
     8#***************************************************************************** 
     9#       Copyright (C) 2008 William Stein <wstein@gmail.com> 
     10#                     2008 David Harvey <dmharvey@math.harvard.edu> 
     11# 
     12#  Distributed under the terms of the GNU General Public License (GPL) 
     13#                  http://www.gnu.org/licenses/ 
     14#***************************************************************************** 
     15 
     16include "../ext/cdefs.pxi" 
     17include "../ext/interrupt.pxi" 
     18 
     19 
     20cdef extern from "bernmm/bern_rat.h": 
     21    void bern_rat "bernmm::bern_rat" (mpq_t res, long k, int num_threads) 
     22 
     23cdef extern from "bernmm/bern_modp.h": 
     24    long bern_modp "bernmm::bern_modp" (long p, long k) 
     25 
     26 
     27 
     28from sage.rings.rational cimport Rational 
     29 
     30 
     31def bernmm_bern_rat(long k, int num_threads = 1): 
     32    r""" 
     33    Computes k-th Bernoulli number using a multimodular algorithm. 
     34    (Wrapper for bernmm library.) 
     35 
     36    INPUT: 
     37        k -- non-negative integer 
     38        num_threads -- integer >= 1, number of threads to use 
     39 
     40    COMPLEXITY: 
     41        Pretty much quadratic in $k$. See the paper ``A multimodular algorithm 
     42        for computing Bernoulli numbers'', David Harvey, 2008, for more details. 
     43 
     44    EXAMPLES: 
     45        sage: from sage.rings.bernmm import bernmm_bern_rat 
     46 
     47        sage: bernmm_bern_rat(0) 
     48        1 
     49        sage: bernmm_bern_rat(1) 
     50        -1/2 
     51        sage: bernmm_bern_rat(2) 
     52        1/6 
     53        sage: bernmm_bern_rat(3) 
     54        0 
     55        sage: bernmm_bern_rat(100) 
     56        -94598037819122125295227433069493721872702841533066936133385696204311395415197247711/33330 
     57        sage: bernmm_bern_rat(100, 3) 
     58        -94598037819122125295227433069493721872702841533066936133385696204311395415197247711/33330 
     59         
     60    TESTS: 
     61        sage: lst1 = [ bernoulli(2*k, algorithm='bernmm', num_threads=2) for k in [2932, 2957, 3443, 3962, 3973] ] 
     62        sage: lst2 = [ bernoulli(2*k, algorithm='pari') for k in [2932, 2957, 3443, 3962, 3973] ] 
     63        sage: lst1 == lst2 
     64        True 
     65        sage: [ Zmod(101)(t) for t in lst1 ] 
     66        [77, 72, 89, 98, 86] 
     67        sage: [ Zmod(101)(t) for t in lst2 ] 
     68        [77, 72, 89, 98, 86] 
     69    """ 
     70    cdef Rational x 
     71 
     72    if k < 0: 
     73        raise ValueError, "k must be non-negative" 
     74 
     75    x = Rational() 
     76    _sig_on 
     77    bern_rat(x.value, k, num_threads) 
     78    _sig_off 
     79 
     80    return x 
     81 
     82 
     83def bernmm_bern_modp(long p, long k): 
     84    r""" 
     85    Computes $B_k \mod p$, where $B_k$ is the k-th Bernoulli number. 
     86 
     87    If $B_k$ is not $p$-integral, returns -1. 
     88 
     89    INPUT: 
     90        p -- a prime 
     91        k -- non-negative integer 
     92 
     93    COMPLEXITY: 
     94        Pretty much linear in $p$. 
     95 
     96    EXAMPLES: 
     97        sage: from sage.rings.bernmm import bernmm_bern_modp 
     98 
     99        sage: bernoulli(0) % 5, bernmm_bern_modp(5, 0) 
     100        (1, 1) 
     101        sage: bernoulli(1) % 5, bernmm_bern_modp(5, 1) 
     102        (2, 2) 
     103        sage: bernoulli(2) % 5, bernmm_bern_modp(5, 2) 
     104        (1, 1) 
     105        sage: bernoulli(3) % 5, bernmm_bern_modp(5, 3) 
     106        (0, 0) 
     107        sage: bernoulli(4), bernmm_bern_modp(5, 4) 
     108        (-1/30, -1) 
     109        sage: bernoulli(18) % 5, bernmm_bern_modp(5, 18) 
     110        (4, 4) 
     111        sage: bernoulli(19) % 5, bernmm_bern_modp(5, 19) 
     112        (0, 0) 
     113 
     114        sage: p = 10000019; k = 1000 
     115        sage: bernoulli(k) % p 
     116        1972762 
     117        sage: bernmm_bern_modp(p, k) 
     118        1972762 
     119 
     120    """ 
     121    cdef long x 
     122 
     123    if k < 0: 
     124        raise ValueError, "k must be non-negative" 
     125 
     126    _sig_on 
     127    x = bern_modp(p, k) 
     128    _sig_off 
     129 
     130    return x 
     131     
     132 
     133# ============ end of file 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/README.txt

    diff -r dcd714fbd837 sage/rings/bernmm/README.txt
    a b  
     1This directory contains the bernmm library, by David Harvey. 
     2 
     3bernmm is maintained as a separate project, see my website dharvey.net. 
     4 
     5If you make any changes/improvements, please tell me about it! 
     6I might want to merge those changes into the upstream version. 
     7 
     82008/06/19 
     9------------------------------------- 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/bern_modp.cpp

    diff -r dcd714fbd837 sage/rings/bernmm/bern_modp.cpp
    a b  
     1/* 
     2   bern_modp.cpp:  computing isolated Bernoulli numbers modulo p 
     3    
     4   Copyright (C) 2008, David Harvey 
     5    
     6   This file is part of the bernmm package (version 1.0). 
     7 
     8   This program is free software: you can redistribute it and/or modify 
     9   it under the terms of the GNU General Public License as published by 
     10   the Free Software Foundation, either version 2 of the License, or 
     11   (at your option) any later version. 
     12 
     13   This program is distributed in the hope that it will be useful, 
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of 
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     16   GNU General Public License for more details. 
     17 
     18   You should have received a copy of the GNU General Public License 
     19   along with this program.  If not, see <http://www.gnu.org/licenses/>. 
     20*/ 
     21 
     22 
     23#include <limits.h> 
     24#include <signal.h> 
     25#include <gmp.h> 
     26#include <NTL/ZZ.h> 
     27#include "bern_modp_util.h" 
     28#include "bern_modp.h" 
     29 
     30 
     31NTL_CLIENT; 
     32 
     33 
     34using namespace std; 
     35 
     36 
     37namespace bernmm { 
     38 
     39 
     40/****************************************************************************** 
     41 
     42   Computing the main sum (general case) 
     43 
     44******************************************************************************/ 
     45 
     46/* 
     47   Returns (1 - g^k) B_k / 2k mod p. 
     48    
     49   PRECONDITIONS: 
     50      5 <= p < NTL_SP_BOUND, p prime 
     51      2 <= k <= p-3, k even 
     52      pinv = 1 / ((double) p) 
     53      g = a multiplicative generator of GF(p), in [0, p) 
     54*/ 
     55long bernsum_powg(long p, double pinv, long k, long g) 
     56{ 
     57   long half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p 
     58   long g_to_jm1 = 1; 
     59   long g_to_km1 = PowerMod(g, k-1, p, pinv); 
     60   long g_to_km1_to_j = g_to_km1; 
     61   long sum = 0; 
     62   double g_pinv = ((double) g) / ((double) p); 
     63   mulmod_precon_t g_to_km1_pinv = PrepMulModPrecon(g_to_km1, p, pinv); 
     64    
     65   for (long j = 1; j <= (p-1)/2; j++) 
     66   { 
     67      // at this point, 
     68      //    g_to_jm1 holds g^(j-1) mod p 
     69      //    g_to_km1_to_j holds (g^(k-1))^j mod p 
     70    
     71      // update g_to_jm1 and compute q = (g*(g^(j-1) mod p) - (g^j mod p)) / p 
     72      long q; 
     73      g_to_jm1 = MulDivRem(q, g_to_jm1, g, p, g_pinv); 
     74       
     75      // compute h = -h_g(g^j) = q - (g-1)/2 
     76      long h = SubMod(q, half_gm1, p); 
     77       
     78      // add h_g(g^j) * (g^(k-1))^j to running total 
     79      sum = SubMod(sum, MulMod(h, g_to_km1_to_j, p, pinv), p); 
     80       
     81      // update g_to_km1_to_j 
     82      g_to_km1_to_j = MulModPrecon(g_to_km1_to_j, g_to_km1, p, g_to_km1_pinv); 
     83   } 
     84 
     85   return sum; 
     86} 
     87 
     88 
     89 
     90/****************************************************************************** 
     91 
     92   Computing the main sum (c = 1/2 case) 
     93 
     94******************************************************************************/ 
     95 
     96 
     97/* 
     98   The Expander class stores precomputed information for a fixed integer p, 
     99   that subsequently permits fast computation of the binary expansion of s/p 
     100   for any 0 < s < p. 
     101    
     102   The constructor takes p and max_words as input. Must have 1 <= max_words <= 
     103   MAX_INV. It computes an approximation to 1/p. 
     104    
     105   The function expand(word_t* res, long s, long n) computes n words of s/p. 
     106   Must have 0 < s < p and 1 <= n <= max_words. The output is written to res. 
     107   The first word of output is junk. The next n words are the digits of s/p, 
     108   from least to most significant. The buffer must be at least n+2 words long 
     109   (even though the first and last words are never used for output). 
     110 
     111   A "word" is a word_t, and contains WORD_BITS bits. On most systems this 
     112   will be an mp_limb_t. 
     113*/ 
     114 
     115#define MAX_INV 256 
     116 
     117#if (GMP_NAIL_BITS == 0) && (GMP_LIMB_BITS >= LONG_BIT) 
     118// fast mpn-based version 
     119 
     120typedef mp_limb_t word_t; 
     121#define WORD_BITS GMP_LIMB_BITS 
     122 
     123class Expander 
     124{ 
     125private: 
     126   // Approximation to 1/p. We store (max_words + 1) limbs. 
     127   mp_limb_t pinv[MAX_INV + 2]; 
     128   mp_limb_t p; 
     129   int max_words; 
     130 
     131public: 
     132   Expander(long p, int max_words) 
     133   { 
     134      assert(max_words >= 1); 
     135      assert(max_words <= MAX_INV); 
     136       
     137      this->max_words = max_words; 
     138      this->p = p; 
     139      mp_limb_t one = 1; 
     140      mpn_divrem_1(pinv, max_words + 1, &one, 1, p); 
     141   } 
     142 
     143   void expand(word_t* res, long s, int n) 
     144   { 
     145      assert(s > 0 && s < p); 
     146      assert(n >= 1); 
     147      assert(n <= max_words); 
     148    
     149      if (s == 1) 
     150      { 
     151         // already have 1/p; just copy it 
     152         for (int i = 1; i <= n; i++) 
     153            res[i] = pinv[max_words - n + i]; 
     154      } 
     155      else 
     156      { 
     157         mpn_mul_1(res, pinv + max_words - n, n + 1, (mp_limb_t) s); 
     158 
     159         // If the first output limb is really close to 0xFFFF..., then there's 
     160         // a possibility of overflow, so fall back on doing division directly. 
     161         // This should happen extremely rarely --- essentially never on a 
     162         // 64-bit system, and very occasionally on a 32-bit system. 
     163         if (res[0] > -((mp_limb_t) s)) 
     164         { 
     165            mp_limb_t ss = s; 
     166            mpn_divrem_1(res, n + 1, &ss, 1, p); 
     167         } 
     168      } 
     169   } 
     170}; 
     171 
     172 
     173#else 
     174// slow mpz-based version, since GMP is using nails, or mp_limb_t is 
     175// absurdly narrow 
     176 
     177typedef unsigned long word_t; 
     178#define WORD_BITS LONG_BIT 
     179 
     180class Expander 
     181{ 
     182private: 
     183   mp_limb_t p; 
     184   mpz_t temp; 
     185 
     186public: 
     187   Expander(long p, int max_words) 
     188   { 
     189      this->p = p; 
     190      mpz_init(temp); 
     191   } 
     192    
     193   ~Expander() 
     194   { 
     195      mpz_clear(temp); 
     196   } 
     197 
     198   void expand(word_t* res, long s, int n) 
     199   { 
     200      assert(s > 0 && s < p); 
     201      assert(n >= 1); 
     202       
     203      mpz_set_ui(temp, s); 
     204      mpz_mul_2exp(temp, temp, WORD_BITS * n); 
     205      mpz_fdiv_q_ui(temp, temp, p); 
     206      mpz_export(res + 1, NULL, -1, sizeof(word_t), 0, 0, temp); 
     207   } 
     208}; 
     209 
     210#endif 
     211 
     212 
     213 
     214/* 
     215   Returns (2^(-k) - 1) 2 B_k / k  mod p. 
     216    
     217   (Note: this is useless if 2^k = 1 mod p.) 
     218       
     219   PRECONDITIONS: 
     220      5 <= p < NTL_SP_BOUND, p prime 
     221      2 <= k <= p-3, k even 
     222      pinv = 1 / ((double) p) 
     223      g = a multiplicative generator of GF(p), in [0, p) 
     224      n = multiplicative order of 2 in GF(p) 
     225*/ 
     226 
     227#define TABLE_LG_SIZE 8 
     228#define TABLE_SIZE (((word_t) 1) << TABLE_LG_SIZE) 
     229#define TABLE_MASK (TABLE_SIZE - 1) 
     230#define NUM_TABLES (WORD_BITS / TABLE_LG_SIZE) 
     231 
     232#if WORD_BITS % TABLE_LG_SIZE != 0 
     233#error Number of bits in a long must be divisible by TABLE_LG_SIZE 
     234#endif 
     235 
     236long bernsum_pow2(long p, double pinv, long k, long g, long n) 
     237{ 
     238   // In the main summation loop we accumulate data into the _tables_ array; 
     239   // tables[y][z] contributes to the final answer with a weight of 
     240   //  
     241   // sum(-(-1)^z[t] * (2^(k-1))^(WORD_BITS - 1 - y * TABLE_LG_SIZE - t) : 
     242   //                                                  0 <= t < TABLE_LG_SIZE), 
     243   // 
     244   // where z[t] denotes the t-th binary digit of z (LSB is t = 0). 
     245   // The memory footprint for _tables_ is 4KB on a 32-bit machine, or 16KB 
     246   // on a 64-bit machine, so should fit easily into L1 cache. 
     247   long tables[NUM_TABLES][TABLE_SIZE]; 
     248   memset(tables, 0, sizeof(long) * NUM_TABLES * TABLE_SIZE); 
     249    
     250   long m = (p-1) / n; 
     251 
     252   // take advantage of symmetry (n' and m' from the paper) 
     253   if (n & 1) 
     254      m >>= 1; 
     255   else 
     256      n >>= 1; 
     257 
     258   // g^(k-1) 
     259   long g_to_km1 = PowerMod(g, k-1, p, pinv); 
     260   // 2^(k-1) 
     261   long two_to_km1 = PowerMod(2, k-1, p, pinv); 
     262   // B^(k-1), where B = 2^WORD_BITS 
     263   long B_to_km1 = PowerMod(two_to_km1, WORD_BITS, p, pinv); 
     264   // B^(MAX_INV) 
     265   long s_jump = PowerMod(2, MAX_INV * WORD_BITS, p, pinv); 
     266 
     267   // help speed up modmuls 
     268   mulmod_precon_t g_pinv = PrepMulModPrecon(g, p, pinv); 
     269   mulmod_precon_t g_to_km1_pinv = PrepMulModPrecon(g_to_km1, p, pinv); 
     270   mulmod_precon_t two_to_km1_pinv = PrepMulModPrecon(two_to_km1, p, pinv); 
     271   mulmod_precon_t B_to_km1_pinv = PrepMulModPrecon(B_to_km1, p, pinv); 
     272   mulmod_precon_t s_jump_pinv = PrepMulModPrecon(s_jump, p, pinv); 
     273 
     274   long g_to_km1_to_i = 1; 
     275   long g_to_i = 1; 
     276   long sum = 0; 
     277 
     278   // Precompute some of the binary expansion of 1/p; at most MAX_INV words, 
     279   // or possibly less if n is sufficiently small 
     280   Expander expander(p, (n >= MAX_INV * WORD_BITS) 
     281                                       ? MAX_INV : ((n - 1) / WORD_BITS + 1)); 
     282 
     283   // =========== phase 1: main summation loop 
     284    
     285   // loop over outer sum 
     286   for (long i = 0; i < m; i++) 
     287   { 
     288      // s keeps track of g^i*2^j mod p 
     289      long s = g_to_i; 
     290      // x keeps track of (g^i*2^j)^(k-1) mod p 
     291      long x = g_to_km1_to_i; 
     292       
     293      // loop over inner sum; break it up into chunks of length at most 
     294      // MAX_INV * WORD_BITS. If n is large, this allows us to do most of 
     295      // the work with mpn_mul_1 instead of mpn_divrem_1, and also improves 
     296      // memory locality. 
     297      for (long nn = n; nn > 0; nn -= MAX_INV * WORD_BITS) 
     298      { 
     299         word_t s_over_p[MAX_INV + 2]; 
     300         long bits, words; 
     301 
     302         if (nn >= MAX_INV * WORD_BITS) 
     303         { 
     304            // do one chunk of length exactly MAX_INV * WORD_BITS 
     305            bits = MAX_INV * WORD_BITS; 
     306            words = MAX_INV; 
     307         } 
     308         else 
     309         { 
     310            // last chunk of length less than MAX_INV * WORD_BITS 
     311            bits = nn; 
     312            words = (nn - 1) / WORD_BITS + 1; 
     313         } 
     314          
     315         // compute some bits of the binary expansion of s/p 
     316         expander.expand(s_over_p, s, words); 
     317         word_t* next = s_over_p + words; 
     318 
     319         // loop over whole words 
     320         for (; bits >= WORD_BITS; bits -= WORD_BITS, next--) 
     321         { 
     322            word_t y = *next; 
     323 
     324#if NUM_TABLES != 8 && NUM_TABLES != 4 
     325            // generic version 
     326            for (long h = 0; h < NUM_TABLES; h++) 
     327            { 
     328               long& target = tables[h][y & TABLE_MASK]; 
     329               target = SubMod(target, x, p); 
     330               y >>= TABLE_LG_SIZE; 
     331            } 
     332#else 
     333            // unrolled versions for 32-bit/64-bit machines 
     334            long& target0 = tables[0][y & TABLE_MASK]; 
     335            target0 = SubMod(target0, x, p); 
     336 
     337            long& target1 = tables[1][(y >> TABLE_LG_SIZE) & TABLE_MASK]; 
     338            target1 = SubMod(target1, x, p); 
     339 
     340            long& target2 = tables[2][(y >> (2*TABLE_LG_SIZE)) & TABLE_MASK]; 
     341            target2 = SubMod(target2, x, p); 
     342 
     343            long& target3 = tables[3][(y >> (3*TABLE_LG_SIZE)) & TABLE_MASK]; 
     344            target3 = SubMod(target3, x, p); 
     345#if NUM_TABLES == 8 
     346            long& target4 = tables[4][(y >> (4*TABLE_LG_SIZE)) & TABLE_MASK]; 
     347            target4 = SubMod(target4, x, p); 
     348 
     349            long& target5 = tables[5][(y >> (5*TABLE_LG_SIZE)) & TABLE_MASK]; 
     350            target5 = SubMod(target5, x, p); 
     351 
     352            long& target6 = tables[6][(y >> (6*TABLE_LG_SIZE)) & TABLE_MASK]; 
     353            target6 = SubMod(target6, x, p); 
     354 
     355            long& target7 = tables[7][(y >> (7*TABLE_LG_SIZE)) & TABLE_MASK]; 
     356            target7 = SubMod(target7, x, p); 
     357#endif 
     358#endif 
     359 
     360            x = MulModPrecon(x, B_to_km1, p, B_to_km1_pinv); 
     361         } 
     362          
     363         // loop over remaining bits in the last word 
     364         word_t y = *next; 
     365         for (; bits > 0; bits--) 
     366         { 
     367            if (y & (((word_t) 1) << (WORD_BITS - 1))) 
     368               sum = SubMod(sum, x, p); 
     369            else 
     370               sum = AddMod(sum, x, p); 
     371 
     372            x = MulModPrecon(x, two_to_km1, p, two_to_km1_pinv); 
     373            y <<= 1; 
     374         } 
     375          
     376         // update s 
     377         s = MulModPrecon(s, s_jump, p, s_jump_pinv); 
     378      } 
     379 
     380      // update g^i and (g^(k-1))^i 
     381      g_to_i = MulModPrecon(g_to_i, g, p, g_pinv); 
     382      g_to_km1_to_i = MulModPrecon(g_to_km1_to_i, g_to_km1, p, g_to_km1_pinv); 
     383   } 
     384 
     385   // =========== phase 2: consolidate table data 
     386 
     387   // compute weights[z] = sum((-1)^z[t] * (2^(k-1))^(TABLE_LG_SIZE - 1 - t) : 
     388   //                                                  0 <= t < TABLE_LG_SIZE). 
     389 
     390   long weights[TABLE_SIZE]; 
     391   weights[0] = 0; 
     392   for (long h = 0, x = 1; h < TABLE_LG_SIZE; 
     393        h++, x = MulModPrecon(x, two_to_km1, p, two_to_km1_pinv)) 
     394   { 
     395      for (long i = (1L << h) - 1; i >= 0; i--) 
     396      { 
     397         weights[2*i+1] = SubMod(weights[i], x, p); 
     398         weights[2*i]   = AddMod(weights[i], x, p); 
     399      } 
     400   } 
     401 
     402   // combine table data with weights 
     403 
     404   long x_jump = PowerMod(two_to_km1, TABLE_LG_SIZE, p, pinv); 
     405 
     406   for (long h = NUM_TABLES - 1, x = 1; h >= 0; h--) 
     407   { 
     408      mulmod_precon_t x_pinv = PrepMulModPrecon(x, p, pinv); 
     409 
     410      for (long i = 0; i < TABLE_SIZE; i++) 
     411      { 
     412         long y = MulMod(tables[h][i], weights[i], p, pinv); 
     413         y = MulModPrecon(y, x, p, x_pinv); 
     414         sum = SubMod(sum, y, p); 
     415      } 
     416 
     417      x = MulModPrecon(x_jump, x, p, x_pinv); 
     418   } 
     419 
     420   return sum; 
     421} 
     422 
     423 
     424/****************************************************************************** 
     425 
     426   Computing the main sum (c = 1/2 case, with REDC arithmetic) 
     427    
     428   Throughout this section F denotes 2^(LONG_BIT / 2). 
     429 
     430******************************************************************************/ 
     431 
     432 
     433/* 
     434   Returns x/F mod n. Output is in [0, 2n), i.e. *not* reduced completely 
     435   into [0, n). 
     436 
     437   PRECONDITIONS: 
     438      3 <= n < F, n odd 
     439      0 <= x < nF    (if n < F/2) 
     440      0 <= x < nF/2  (if n > F/2) 
     441      ninv2 = -1/n mod F 
     442*/ 
     443#define LOW_MASK ((1L << (LONG_BIT / 2)) - 1) 
     444static inline long RedcFast(long x, long n, long ninv2) 
     445{ 
     446   unsigned long y = (x * ninv2) & LOW_MASK; 
     447   unsigned long z = x + (n * y); 
     448   return z >> (LONG_BIT / 2); 
     449} 
     450 
     451 
     452/* 
     453   Same as RedcFast(), but reduces output into [0, n). 
     454*/ 
     455static inline long Redc(long x, long n, long ninv2) 
     456{ 
     457   long y = RedcFast(x, n, ninv2); 
     458   if (y >= n) 
     459      y -= n; 
     460   return y; 
     461} 
     462 
     463 
     464/* 
     465   Computes -1/n mod F, in [0, F). 
     466    
     467   PRECONDITIONS: 
     468      3 <= n < F, n odd 
     469*/ 
     470long PrepRedc(long n) 
     471{ 
     472   long ninv2 = -n;   // already correct mod 8 
     473    
     474   // newton's method for 2-adic inversion 
     475   for (long bits = 3; bits < LONG_BIT/2; bits *= 2) 
     476      ninv2 = 2*ninv2 + n * ninv2 * ninv2; 
     477       
     478   return ninv2 & LOW_MASK; 
     479} 
     480 
     481 
     482/* 
     483   Same as bernsum_pow2(), but uses REDC arithmetic, and various delayed 
     484   reduction strategies. 
     485 
     486   PRECONDITIONS: 
     487      Same as bernsum_pow2(), and in addition: 
     488      p < 2^(LONG_BIT/2 - 1) 
     489       
     490   (See bernsum_pow2() for code comments; we only add comments here where 
     491   something is different from bernsum_pow2()) 
     492*/ 
     493long bernsum_pow2_redc(long p, double pinv, long k, long g, long n) 
     494{ 
     495   long pinv2 = PrepRedc(p); 
     496   long F = (1L << (LONG_BIT/2)) % p; 
     497    
     498   long tables[NUM_TABLES][TABLE_SIZE]; 
     499   memset(tables, 0, sizeof(long) * NUM_TABLES * TABLE_SIZE); 
     500 
     501   long m = (p-1) / n; 
     502 
     503   if (n & 1) 
     504      m >>= 1; 
     505   else 
     506      n >>= 1; 
     507    
     508   long g_to_km1 = PowerMod(g, k-1, p, pinv); 
     509   long two_to_km1 = PowerMod(2, k-1, p, pinv); 
     510   long B_to_km1 = PowerMod(two_to_km1, WORD_BITS, p, pinv); 
     511   long s_jump = PowerMod(2, MAX_INV * WORD_BITS, p, pinv); 
     512    
     513   long g_redc = MulMod(g, F, p, pinv); 
     514   long g_to_km1_redc = MulMod(g_to_km1, F, p, pinv); 
     515   long two_to_km1_redc = MulMod(two_to_km1, F, p, pinv); 
     516   long B_to_km1_redc = MulMod(B_to_km1, F, p, pinv); 
     517   long s_jump_redc = MulMod(s_jump, F, p, pinv); 
     518 
     519   long g_to_km1_to_i = 1;    // always in [0, 2p) 
     520   long g_to_i = 1;           // always in [0, 2p) 
     521   long sum = 0; 
     522 
     523   Expander expander(p, (n >= MAX_INV * WORD_BITS) 
     524                                       ? MAX_INV : ((n - 1) / WORD_BITS + 1)); 
     525 
     526   // =========== phase 1: main summation loop 
     527    
     528   for (long i = 0; i < m; i++) 
     529   { 
     530      long s = g_to_i;           // always in [0, p) 
     531      if (s >= p) 
     532         s -= p; 
     533       
     534      long x = g_to_km1_to_i;    // always in [0, 2p) 
     535       
     536      for (long nn = n; nn > 0; nn -= MAX_INV * WORD_BITS) 
     537      { 
     538         word_t s_over_p[MAX_INV + 2]; 
     539         long bits, words; 
     540          
     541         if (nn >= MAX_INV * WORD_BITS) 
     542         { 
     543            bits = MAX_INV * WORD_BITS; 
     544            words = MAX_INV; 
     545         } 
     546         else 
     547         { 
     548            bits = nn; 
     549            words = (nn - 1) / WORD_BITS + 1; 
     550         } 
     551 
     552         expander.expand(s_over_p, s, words); 
     553         word_t* next = s_over_p + words; 
     554          
     555         for (; bits >= WORD_BITS; bits -= WORD_BITS, next--) 
     556         { 
     557            word_t y = *next; 
     558             
     559            // note: we add the values into tables *without* reduction mod p 
     560 
     561#if NUM_TABLES != 8 && NUM_TABLES != 4 
     562            // generic version 
     563            for (long h = 0; h < NUM_TABLES; h++) 
     564            { 
     565               tables[h][y & TABLE_MASK] += x; 
     566               y >>= TABLE_LG_SIZE; 
     567            } 
     568#else 
     569            // unrolled versions for 32-bit/64-bit machines 
     570            tables[0][ y                       & TABLE_MASK] += x; 
     571            tables[1][(y >>    TABLE_LG_SIZE ) & TABLE_MASK] += x; 
     572            tables[2][(y >> (2*TABLE_LG_SIZE)) & TABLE_MASK] += x; 
     573            tables[3][(y >> (3*TABLE_LG_SIZE)) & TABLE_MASK] += x; 
     574#if NUM_TABLES == 8 
     575            tables[4][(y >> (4*TABLE_LG_SIZE)) & TABLE_MASK] += x; 
     576            tables[5][(y >> (5*TABLE_LG_SIZE)) & TABLE_MASK] += x; 
     577            tables[6][(y >> (6*TABLE_LG_SIZE)) & TABLE_MASK] += x; 
     578            tables[7][(y >> (7*TABLE_LG_SIZE)) & TABLE_MASK] += x; 
     579#endif             
     580#endif 
     581 
     582            x = RedcFast(x * B_to_km1_redc, p, pinv2); 
     583         } 
     584          
     585         // bring x into [0, p) for next loop 
     586         if (x >= p) 
     587            x -= p; 
     588 
     589         word_t y = *next; 
     590         for (; bits > 0; bits--) 
     591         { 
     592            if (y & (((word_t) 1) << (WORD_BITS - 1))) 
     593               sum = SubMod(sum, x, p); 
     594            else 
     595               sum = AddMod(sum, x, p); 
     596 
     597            x = Redc(x * two_to_km1_redc, p, pinv2); 
     598            y <<= 1; 
     599         } 
     600          
     601         s = Redc(s * s_jump_redc, p, pinv2); 
     602      } 
     603 
     604      g_to_i = RedcFast(g_to_i * g_redc, p, pinv2); 
     605      g_to_km1_to_i = RedcFast(g_to_km1_to_i * g_to_km1_redc, p, pinv2); 
     606   } 
     607    
     608   // At this point, each table entry is at most p^2 (since x was always 
     609   // in [0, 2p), and the inner loop was called at most (p/2) / WORD_BITS 
     610   // times, and 2p * p/2 / WORD_BITS * TABLE_LG_SIZE <= p^2). 
     611    
     612   // =========== phase 2: consolidate table data 
     613    
     614   long weights[TABLE_SIZE]; 
     615   weights[0] = 0; 
     616   // we store the weights multiplied by a factor of 2^(3*LONG_BIT/2) to 
     617   // compensate for the three rounds of REDC reduction in the loop below 
     618   for (long h = 0, x = PowerMod(2, 3*LONG_BIT/2, p, pinv); 
     619        h < TABLE_LG_SIZE; h++, x = Redc(x * two_to_km1_redc, p, pinv2)) 
     620   { 
     621      for (long i = (1L << h) - 1; i >= 0; i--) 
     622      { 
     623         weights[2*i+1] = SubMod(weights[i], x, p); 
     624         weights[2*i]   = AddMod(weights[i], x, p); 
     625      } 
     626   } 
     627 
     628   long x_jump = PowerMod(two_to_km1, TABLE_LG_SIZE, p, pinv); 
     629   long x_jump_redc = MulMod(x_jump, F, p, pinv); 
     630 
     631   for (long h = NUM_TABLES - 1, x = 1; h >= 0; h--) 
     632   { 
     633      for (long i = 0; i < TABLE_SIZE; i++) 
     634      { 
     635         long y; 
     636         y = RedcFast(tables[h][i], p, pinv2); 
     637         y = RedcFast(y * weights[i], p, pinv2); 
     638         y = RedcFast(y * x, p, pinv2); 
     639         sum += y; 
     640      } 
     641 
     642      x = Redc(x * x_jump_redc, p, pinv2); 
     643   } 
     644 
     645   return sum % p; 
     646} 
     647 
     648 
     649 
     650/****************************************************************************** 
     651 
     652   Wrappers for bernsum_* 
     653 
     654******************************************************************************/ 
     655 
     656 
     657/* 
     658   Returns B_k/k mod p, in the range [0, p). 
     659    
     660   PRECONDITIONS: 
     661      5 <= p < NTL_SP_BOUND, p prime 
     662      2 <= k <= p-3, k even 
     663      pinv = 1 / ((double) p) 
     664    
     665   Algorithm: uses bernsum_powg() to compute the main sum. 
     666*/ 
     667long _bern_modp_powg(long p, double pinv, long k) 
     668{ 
     669   Factorisation F(p-1); 
     670   long g = primitive_root(p, pinv, F); 
     671 
     672   // compute main sum 
     673   long x = bernsum_powg(p, pinv, k, g); 
     674 
     675   // divide by (1 - g^k) and multiply by 2 
     676   long g_to_k = PowerMod(g, k, p, pinv); 
     677   long t = InvMod(p + 1 - g_to_k, p); 
     678   x = MulMod(x, t, p, pinv); 
     679   x = AddMod(x, x, p); 
     680    
     681   return x; 
     682} 
     683 
     684 
     685/* 
     686   Returns B_k/k mod p, in the range [0, p). 
     687    
     688   PRECONDITIONS: 
     689      5 <= p < NTL_SP_BOUND, p prime 
     690      2 <= k <= p-3, k even 
     691      pinv = 1 / ((double) p) 
     692      2^k != 1 mod p 
     693 
     694   Algorithm: uses bernsum_pow2() (or bernsum_pow2_redc() if p is small 
     695   enough) to compute the main sum. 
     696*/ 
     697long _bern_modp_pow2(long p, double pinv, long k) 
     698{ 
     699   Factorisation F(p-1); 
     700   long g = primitive_root(p, pinv, F); 
     701   long n = order(2, p, pinv, F); 
     702 
     703   // compute main sum 
     704   long x; 
     705   if (p < (1L << (LONG_BIT/2 - 1))) 
     706      x = bernsum_pow2_redc(p, pinv, k, g, n); 
     707   else 
     708      x = bernsum_pow2(p, pinv, k, g, n); 
     709 
     710   // divide by 2*(2^(-k) - 1) 
     711   long t = PowerMod(2, -k, p, pinv) - 1; 
     712   t = AddMod(t, t, p); 
     713   t = InvMod(t, p); 
     714   x = MulMod(x, t, p, pinv); 
     715    
     716   return x; 
     717} 
     718 
     719 
     720 
     721/* 
     722   Returns B_k/k mod p, in the range [0, p). 
     723    
     724   PRECONDITIONS: 
     725      5 <= p < NTL_SP_BOUND, p prime 
     726      2 <= k <= p-3, k even 
     727      pinv = 1 / ((double) p) 
     728*/ 
     729long _bern_modp(long p, double pinv, long k) 
     730{ 
     731   if (PowerMod(2, k, p, pinv) != 1) 
     732      // 2^k != 1 mod p, so we use the faster version 
     733      return _bern_modp_pow2(p, pinv, k); 
     734   else 
     735      // forced to use slower version 
     736      return _bern_modp_powg(p, pinv, k); 
     737} 
     738 
     739 
     740 
     741/****************************************************************************** 
     742 
     743   Main bern_modp() routine 
     744 
     745******************************************************************************/ 
     746 
     747long bern_modp(long p, long k) 
     748{ 
     749   assert(k >= 0); 
     750   assert(2 <= p && p < NTL_SP_BOUND); 
     751 
     752   // B_0 = 1 
     753   if (k == 0) 
     754      return 1; 
     755 
     756   // B_1 = -1/2 mod p 
     757   if (k == 1) 
     758   { 
     759      if (p == 2) 
     760         return -1; 
     761      return (p-1)/2; 
     762   } 
     763 
     764   // B_k = 0 for odd k >= 3 
     765   if (k & 1) 
     766      return 0; 
     767 
     768   // denominator of B_k is always divisible by 6 for k >= 2 
     769   if (p <= 3) 
     770      return -1; 
     771       
     772   // use Kummer's congruence (k = m mod p-1  =>  B_k/k = B_m/m mod p) 
     773   long m = k % (p-1); 
     774   if (m == 0) 
     775      return -1; 
     776 
     777   double pinv = 1 / ((double) p); 
     778   long x = _bern_modp(p, pinv, m);    // = B_m/m mod p 
     779   return MulMod(x, k, p, pinv); 
     780} 
     781 
     782 
     783};    // end namespace 
     784 
     785 
     786 
     787// end of file ================================================================ 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/bern_modp.h

    diff -r dcd714fbd837 sage/rings/bernmm/bern_modp.h
    a b  
     1/* 
     2   bern_modp.cpp:  computing isolated Bernoulli numbers modulo p 
     3    
     4   Copyright (C) 2008, David Harvey 
     5    
     6   This file is part of the bernmm package (version 1.0). 
     7 
     8   This program is free software: you can redistribute it and/or modify 
     9   it under the terms of the GNU General Public License as published by 
     10   the Free Software Foundation, either version 2 of the License, or 
     11   (at your option) any later version. 
     12 
     13   This program is distributed in the hope that it will be useful, 
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of 
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     16   GNU General Public License for more details. 
     17 
     18   You should have received a copy of the GNU General Public License 
     19   along with this program.  If not, see <http://www.gnu.org/licenses/>. 
     20*/ 
     21 
     22#ifndef BERNMM_BERN_MODP_H 
     23#define BERNMM_BERN_MODP_H 
     24 
     25 
     26namespace bernmm { 
     27 
     28 
     29/* 
     30   Returns B_k mod p, in [0, p), or -1 if B_k is not p-integral. 
     31 
     32   PRECONDITIONS: 
     33      2 <= p < NTL_SP_BOUND, p prime 
     34      k >= 0 
     35*/ 
     36long bern_modp(long p, long k); 
     37 
     38 
     39/* 
     40   Exported for testing. 
     41*/ 
     42long _bern_modp_powg(long p, double pinv, long k); 
     43long _bern_modp_pow2(long p, double pinv, long k); 
     44 
     45 
     46}; 
     47 
     48 
     49#endif 
     50 
     51// end of file ================================================================ 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/bern_modp_util.cpp

    diff -r dcd714fbd837 sage/rings/bernmm/bern_modp_util.cpp
    a b  
     1/* 
     2   bern_modp_util.cpp:  number-theoretic utility functions 
     3    
     4   Copyright (C) 2008, David Harvey 
     5    
     6   This file is part of the bernmm package (version 1.0). 
     7 
     8   This program is free software: you can redistribute it and/or modify 
     9   it under the terms of the GNU General Public License as published by 
     10   the Free Software Foundation, either version 2 of the License, or 
     11   (at your option) any later version. 
     12 
     13   This program is distributed in the hope that it will be useful, 
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of 
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     16   GNU General Public License for more details. 
     17 
     18   You should have received a copy of the GNU General Public License 
     19   along with this program.  If not, see <http://www.gnu.org/licenses/>. 
     20*/ 
     21 
     22 
     23#include <NTL/ZZ.h> 
     24#include "bern_modp_util.h" 
     25 
     26 
     27NTL_CLIENT; 
     28 
     29 
     30namespace bernmm { 
     31 
     32 
     33long PowerMod(long a, long ee, long n, double ninv) 
     34{ 
     35   long x, y; 
     36 
     37   unsigned long e; 
     38 
     39   if (ee < 0) 
     40      e = - ((unsigned long) ee); 
     41   else 
     42      e = ee; 
     43 
     44   x = 1; 
     45   y = a; 
     46   while (e) { 
     47      if (e & 1) x = MulMod(x, y, n, ninv); 
     48      y = MulMod(y, y, n, ninv); 
     49      e = e >> 1; 
     50   } 
     51 
     52   if (ee < 0) x = InvMod(x, n); 
     53 
     54   return x; 
     55} 
     56 
     57 
     58 
     59void Factorisation::helper(long k, long m) 
     60{ 
     61   if (m == 1) 
     62      return; 
     63 
     64   for (long i = k + 1; i * i <= m; i++) 
     65   { 
     66      if (m % i == 0) 
     67      { 
     68         // found a factor 
     69         factors.push_back(i); 
     70         // remove that factor entirely 
     71         for (m /= i; m % i == 0; m /= i); 
     72         // recurse 
     73         helper(i, m); 
     74         return; 
     75      } 
     76   } 
     77 
     78   // no more factors 
     79   factors.push_back(m); 
     80} 
     81 
     82 
     83Factorisation::Factorisation(long n) 
     84{ 
     85   this->n = n; 
     86   helper(1, n); 
     87} 
     88 
     89 
     90PrimeTable::PrimeTable(long bound) 
     91{ 
     92   long size = (bound - 1) / LONG_BIT + 1;   // = ceil(bound / LONG_BIT) 
     93   data.resize(size); 
     94    
     95   for (long i = 2; i * i < bound; i++) 
     96      if (is_prime(i)) 
     97         for (long j = 2*i; j < bound; j += i) 
     98            set(j); 
     99} 
     100 
     101 
     102long order(long x, long p, double pinv, const Factorisation& F) 
     103{ 
     104   // in the loop below, m is always some multiple of the order of x 
     105   long m = p - 1; 
     106 
     107   // try to remove factors from m until we can't remove any more 
     108   for (int i = 0; i < F.factors.size(); i++) 
     109   { 
     110      long q = F.factors[i]; 
     111 
     112      while (m % q == 0) 
     113      { 
     114         long mm = m / q; 
     115         if (PowerMod(x, mm, p, pinv) != 1) 
     116            break; 
     117         m = mm; 
     118      } 
     119   } 
     120 
     121   return m; 
     122} 
     123 
     124 
     125 
     126long primitive_root(long p, double pinv, const Factorisation& F) 
     127{ 
     128   if (p == 2) 
     129      return 1; 
     130 
     131   long g = 2; 
     132   for (; g < p; g++) 
     133      if (order(g, p, pinv, F) == p - 1) 
     134         return g; 
     135          
     136   // no generator exists!? 
     137   abort(); 
     138} 
     139 
     140 
     141 
     142};    // end namespace 
     143 
     144 
     145// end of file ================================================================ 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/bern_modp_util.h

    diff -r dcd714fbd837 sage/rings/bernmm/bern_modp_util.h
    a b  
     1/* 
     2   bern_modp_util.h:  number-theoretic utility functions 
     3    
     4   Copyright (C) 2008, David Harvey 
     5    
     6   This file is part of the bernmm package (version 1.0). 
     7 
     8   This program is free software: you can redistribute it and/or modify 
     9   it under the terms of the GNU General Public License as published by 
     10   the Free Software Foundation, either version 2 of the License, or 
     11   (at your option) any later version. 
     12 
     13   This program is distributed in the hope that it will be useful, 
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of 
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     16   GNU General Public License for more details. 
     17 
     18   You should have received a copy of the GNU General Public License 
     19   along with this program.  If not, see <http://www.gnu.org/licenses/>. 
     20*/ 
     21 
     22 
     23#ifndef BERNMM_BERN_MODP_UTIL_H 
     24#define BERNMM_BERN_MODP_UTIL_H 
     25 
     26 
     27#include <vector> 
     28#include <cassert> 
     29 
     30 
     31namespace bernmm { 
     32 
     33 
     34/* 
     35   Same as NTL's PowerMod, but also accepts an _ninv_ parameter, which is the 
     36   same as the ninv parameter for NTL's MulMod routines, i.e. should have 
     37   ninv = 1 / ((double) n). 
     38    
     39   (Implementation is adapted from ZZ.c in NTL 5.4.1.) 
     40*/ 
     41long PowerMod(long a, long ee, long n, double ninv); 
     42 
     43 
     44/* 
     45   Represents the factorisation of an integer n into distinct prime factors. 
     46    
     47   (Very naive implementation!) 
     48*/ 
     49class Factorisation 
     50{ 
     51protected: 
     52   /* 
     53      Finds distinct prime factors of m in the range k < p <= m. 
     54      Assumes that m does not have any prime factors p <= k. 
     55      Appends factors found to _factors_. 
     56   */ 
     57   void helper(long k, long m); 
     58 
     59public: 
     60   // the integer 
     61   long n; 
     62 
     63   // the distinct factors (in increasing order) 
     64   std::vector<long> factors; 
     65 
     66   // initialises with given integer 
     67   Factorisation(long n); 
     68}; 
     69 
     70 
     71 
     72class PrimeTable 
     73{ 
     74private: 
     75   std::vector<long> data;   // bit-vector; 0 means prime, 1 means composite 
     76    
     77   // read bit from index i 
     78   inline bool get(long i) const 
     79   { 
     80      return (data[i / LONG_BIT] >> (i % LONG_BIT)) & 1; 
     81   } 
     82 
     83   // set bit at index i 
     84   inline void set(long i) 
     85   { 
     86      data[i / LONG_BIT] |= (1L << (i % LONG_BIT)); 
     87   } 
     88 
     89 
     90public: 
     91   // initialise with primes up to given bound 
     92   PrimeTable(long bound); 
     93    
     94   // test whether n is prime by table lookup 
     95   inline bool is_prime(long n) const 
     96   { 
     97      return !get(n); 
     98   } 
     99    
     100   // returns smallest prime p that is larger than n 
     101   long next_prime(long n) const 
     102   { 
     103      for (n++; !is_prime(n); n++); 
     104      return n; 
     105   } 
     106}; 
     107 
     108 
     109 
     110/* 
     111   Returns 1 if n is prime. 
     112*/ 
     113int is_prime(long n); 
     114 
     115 
     116/* 
     117   Returns smallest prime larger than p. 
     118*/ 
     119long next_prime(long p); 
     120 
     121 
     122/* 
     123   Computes order of x mod p, given the factorisation F of p-1. 
     124*/ 
     125long order(long x, long p, double pinv, const Factorisation& F); 
     126 
     127 
     128/* 
     129   Finds the smallest primitive root mod p, given the factorisation F of p-1. 
     130*/ 
     131long primitive_root(long p, double pinv, const Factorisation& F); 
     132 
     133 
     134};    // end namespace 
     135 
     136 
     137#endif 
     138 
     139// end of file ================================================================ 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/bern_rat.cpp

    diff -r dcd714fbd837 sage/rings/bernmm/bern_rat.cpp
    a b  
     1/* 
     2   bern_rat.cpp:  multi-modular algorithm for computing Bernoulli numbers 
     3    
     4   Copyright (C) 2008, David Harvey 
     5    
     6   This file is part of the bernmm package (version 1.0). 
     7 
     8   This program is free software: you can redistribute it and/or modify 
     9   it under the terms of the GNU General Public License as published by 
     10   the Free Software Foundation, either version 2 of the License, or 
     11   (at your option) any later version. 
     12 
     13   This program is distributed in the hope that it will be useful, 
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of 
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     16   GNU General Public License for more details. 
     17 
     18   You should have received a copy of the GNU General Public License 
     19   along with this program.  If not, see <http://www.gnu.org/licenses/>. 
     20*/ 
     21 
     22#include <gmp.h> 
     23#include <NTL/ZZ.h> 
     24#include <cmath> 
     25#include <vector> 
     26#include <set> 
     27#include "bern_modp_util.h" 
     28#include "bern_modp.h" 
     29#include "bern_rat.h" 
     30 
     31#ifdef USE_THREADS 
     32#include <pthread.h> 
     33#endif 
     34 
     35 
     36using namespace std; 
     37using namespace NTL; 
     38 
     39 
     40namespace bernmm { 
     41 
     42 
     43/* 
     44   Computes the denominator of B_k using Clausen/von Staudt. 
     45*/ 
     46void bern_den(mpz_t res, long k, const PrimeTable& table) 
     47{ 
     48   mpz_set_ui(res, 1); 
     49 
     50   // loop through factors of k 
     51   for (long f = 1; f*f <= k; f++) 
     52   { 
     53      // if f divides k.... 
     54      if (k % f == 0) 
     55      { 
     56         // ... then both f + 1 and k/f + 1 are candidates for primes 
     57         // dividing the denominator of B_k 
     58         if (table.is_prime(f + 1)) 
     59            mpz_mul_ui(res, res, f + 1); 
     60 
     61         if (f*f != k) 
     62            if (table.is_prime(k/f + 1)) 
     63               mpz_mul_ui(res, res, k/f + 1); 
     64      } 
     65   } 
     66} 
     67 
     68 
     69// width of interval for each block 
     70#define BLOCK_SIZE 1000 
     71 
     72 
     73/* 
     74   Represents that B_k is congruent to _residue_ modulo _modulus_. 
     75*/ 
     76struct Item 
     77{ 
     78   mpz_t modulus; 
     79   mpz_t residue; 
     80    
     81   Item() 
     82   { 
     83      mpz_init(modulus); 
     84      mpz_init(residue); 
     85   } 
     86    
     87   ~Item() 
     88   { 
     89      mpz_clear(residue); 
     90      mpz_clear(modulus); 
     91   } 
     92}; 
     93 
     94 
     95/* 
     96   Items get sorted by modulus. 
     97*/ 
     98struct Item_cmp 
     99{ 
     100   bool operator()(const Item* x, const Item* y) 
     101   { 
     102      return mpz_cmp(x->modulus, y->modulus) < 0; 
     103   } 
     104}; 
     105 
     106 
     107/* 
     108   Returns new Item that combines information from op1 and op2 via CRT. 
     109*/ 
     110Item* CRT(Item* op1, Item* op2) 
     111{ 
     112   Item* res = new Item; 
     113 
     114   // let n1, n2 be the moduli, and r1, r2 be the residues 
     115    
     116   // res->modulus = t, where t = 0 mod n1, t = 1 mod n2 
     117   mpz_invert(res->modulus, op1->modulus, op2->modulus); 
     118   mpz_mul(res->modulus, res->modulus, op1->modulus); 
     119 
     120   // res->residue = r2 - r1 
     121   mpz_sub(res->residue, op2->residue, op1->residue); 
     122   // res->residue = t * (r2 - r1) 
     123   mpz_mul(res->residue, res->residue, res->modulus); 
     124   // res->residue = r1 + t * (r2 - r1) 
     125   mpz_add(res->residue, res->residue, op1->residue); 
     126   // res->modulus = n1 * n2 
     127   mpz_mul(res->modulus, op1->modulus, op2->modulus); 
     128   // res->residue = r1 mod n1, r2 = mod n2 
     129   mpz_mod(res->residue, res->residue, res->modulus); 
     130    
     131   return res; 
     132} 
     133 
     134 
     135struct State 
     136{ 
     137   long k; 
     138   long bound;   // only use primes less than this bound 
     139   const PrimeTable* table; 
     140    
     141   // index of block that should be processed next 
     142   long next; 
     143    
     144   std::set<Item*, Item_cmp> items; 
     145#ifdef USE_THREADS 
     146   pthread_mutex_t lock; 
     147#endif 
     148    
     149   State(long k, long bound, const PrimeTable& table) 
     150   { 
     151      this->k = k; 
     152      this->bound = bound; 
     153      this->next = 0; 
     154      this->table = &table; 
     155#ifdef USE_THREADS 
     156      pthread_mutex_init(&lock, NULL); 
     157#endif 
     158   } 
     159    
     160   ~State() 
     161   { 
     162#ifdef USE_THREADS 
     163      pthread_mutex_destroy(&lock); 
     164#endif 
     165   } 
     166}; 
     167 
     168 
     169void* worker(void* arg) 
     170{ 
     171   State& state = *((State*) arg); 
     172   long k = state.k; 
     173 
     174#ifdef USE_THREADS 
     175   pthread_mutex_lock(&state.lock); 
     176#endif 
     177 
     178   while (1) 
     179   { 
     180      if (state.next * BLOCK_SIZE < state.bound) 
     181      { 
     182         // need to generate more modular data 
     183          
     184         long next = state.next++; 
     185#ifdef USE_THREADS 
     186         pthread_mutex_unlock(&state.lock); 
     187#endif 
     188 
     189         Item* item = new Item; 
     190 
     191         mpz_set_ui(item->modulus, 1); 
     192         mpz_set_ui(item->residue, 0); 
     193          
     194         for (long p = max(5, state.table->next_prime(next * BLOCK_SIZE)); 
     195              p < state.bound && p < (next+1) * BLOCK_SIZE; 
     196              p = state.table->next_prime(p)) 
     197         { 
     198            if (k % (p-1) == 0) 
     199               continue; 
     200 
     201            // compute B_k mod p 
     202            long b = bern_modp(p, k); 
     203             
     204            // CRT into running total 
     205            long x = MulMod(SubMod(b, mpz_fdiv_ui(item->residue, p), p), 
     206                            InvMod(mpz_fdiv_ui(item->modulus, p), p), p); 
     207            mpz_addmul_ui(item->residue, item->modulus, x); 
     208            mpz_mul_ui(item->modulus, item->modulus, p); 
     209         } 
     210          
     211#ifdef USE_THREADS 
     212         pthread_mutex_lock(&state.lock); 
     213#endif 
     214         state.items.insert(item); 
     215      } 
     216      else 
     217      { 
     218         // all modular data has been generated 
     219 
     220         if (state.items.size() <= 1) 
     221         { 
     222            // no more CRTs for this thread to perform 
     223#ifdef USE_THREADS 
     224            pthread_mutex_unlock(&state.lock); 
     225#endif 
     226            return NULL; 
     227         } 
     228          
     229         // CRT two smallest items together 
     230         Item* item1 = *(state.items.begin()); 
     231         state.items.erase(state.items.begin()); 
     232         Item* item2 = *(state.items.begin()); 
     233         state.items.erase(state.items.begin()); 
     234#ifdef USE_THREADS 
     235         pthread_mutex_unlock(&state.lock); 
     236#endif 
     237          
     238         Item* item3 = CRT(item1, item2); 
     239         delete item1; 
     240         delete item2; 
     241 
     242#ifdef USE_THREADS 
     243         pthread_mutex_lock(&state.lock); 
     244#endif 
     245         state.items.insert(item3); 
     246      } 
     247   } 
     248} 
     249 
     250 
     251void bern_rat(mpq_t res, long k, int num_threads) 
     252{ 
     253   // special cases 
     254 
     255   if (k == 0) 
     256   { 
     257      // B_0 = 1 
     258      mpq_set_ui(res, 1, 1); 
     259      return; 
     260   } 
     261    
     262   if (k == 1) 
     263   { 
     264      // B_1 = -1/2 
     265      mpq_set_si(res, -1, 2); 
     266      return; 
     267   } 
     268    
     269   if (k == 2) 
     270   { 
     271      // B_2 = 1/6 
     272      mpq_set_si(res, 1, 6); 
     273      return; 
     274   } 
     275    
     276   if (k & 1) 
     277   { 
     278      // B_k = 0 if k is odd 
     279      mpq_set_ui(res, 0, 1); 
     280      return; 
     281   } 
     282 
     283   if (num_threads <= 0) 
     284      num_threads = 1; 
     285 
     286   mpz_t num, den; 
     287   mpz_init(num); 
     288   mpz_init(den); 
     289    
     290   const double log2 =    0.69314718055994528622676; 
     291   const double invlog2 = 1.44269504088896340735992;   // = 1/log(2) 
     292 
     293   // compute preliminary prime bound and build prime table 
     294   long bound1 = (long) max(37.0, ceil((k + 0.5) * log(k) * invlog2)); 
     295   PrimeTable table(bound1); 
     296 
     297   // compute denominator of B_k 
     298   bern_den(den, k, table); 
     299 
     300   // compute number of bits we need to resolve the numerator 
     301   long bits = (long) ceil((k + 0.5) * log(k) * invlog2 - 4.094 * k + 2.470 
     302                                      + log(mpz_get_d(den)) * invlog2); 
     303 
     304   // compute tighter prime bound 
     305   // (note: we can safely get away with double-precision here. It would 
     306   // only start being insufficient around k = 10^13 or so, which is totally 
     307   // impractical at present.) 
     308   double prod = 1.0; 
     309   long prod_bits = 0; 
     310   long p; 
     311   for (p = 5; prod_bits < bits + 1; p = table.next_prime(p)) 
     312   { 
     313      if (p >= NTL_SP_BOUND) 
     314         abort();   // !!!!! not sure what else we can do here... 
     315      if (k % (p-1) != 0) 
     316         prod *= (double) p; 
     317      int exp; 
     318      prod = frexp(prod, &exp); 
     319      prod_bits += exp; 
     320   } 
     321   long bound2 = p; 
     322 
     323   State state(k, bound2, table); 
     324 
     325#ifdef USE_THREADS 
     326   vector<pthread_t> threads(num_threads - 1); 
     327 
     328   // spawn worker threads to process blocks 
     329   for (long i = 0; i < num_threads - 1; i++) 
     330      pthread_create(&threads[i], NULL, worker, &state); 
     331#endif 
     332    
     333   worker(&state);    // make this thread a worker too 
     334 
     335#ifdef USE_THREADS 
     336   for (long i = 0; i < num_threads - 1; i++) 
     337      pthread_join(threads[i], NULL); 
     338#endif 
     339       
     340   // reconstruct B_k as a rational number 
     341   Item* item = *(state.items.begin()); 
     342   mpz_mul(num, item->residue, den); 
     343   mpz_mod(num, num, item->modulus); 
     344    
     345   if (k % 4 == 0) 
     346   { 
     347      // B_k is negative 
     348      mpz_sub(num, item->modulus, num); 
     349      mpz_neg(num, num); 
     350   } 
     351 
     352   delete item; 
     353    
     354   mpz_swap(num, mpq_numref(res)); 
     355   mpz_swap(den, mpq_denref(res)); 
     356    
     357   mpz_clear(num); 
     358   mpz_clear(den); 
     359} 
     360 
     361 
     362 
     363};    // end namespace 
     364 
     365 
     366// end of file ================================================================ 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/bern_rat.h

    diff -r dcd714fbd837 sage/rings/bernmm/bern_rat.h
    a b  
     1/* 
     2   bern_rat.h:  multi-modular algorithm for computing Bernoulli numbers 
     3    
     4   Copyright (C) 2008, David Harvey 
     5    
     6   This file is part of the bernmm package (version 1.0). 
     7 
     8   This program is free software: you can redistribute it and/or modify 
     9   it under the terms of the GNU General Public License as published by 
     10   the Free Software Foundation, either version 2 of the License, or 
     11   (at your option) any later version. 
     12 
     13   This program is distributed in the hope that it will be useful, 
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of 
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     16   GNU General Public License for more details. 
     17 
     18   You should have received a copy of the GNU General Public License 
     19   along with this program.  If not, see <http://www.gnu.org/licenses/>. 
     20*/ 
     21 
     22#ifndef BERNMM_BERN_RAT_H 
     23#define BERNMM_BERN_RAT_H 
     24 
     25#include <gmp.h> 
     26 
     27 
     28namespace bernmm { 
     29 
     30 
     31/* 
     32   Returns B_k as a rational number, stored at _res_. 
     33    
     34   k must be >= 0. 
     35    
     36   Uses _num_threads_ threads. (If USE_THREADS is not #defined, the code just 
     37   uses one thread.) 
     38*/ 
     39void bern_rat(mpq_t res, long k, int num_threads); 
     40 
     41 
     42 
     43};    // end namespace 
     44 
     45 
     46#endif 
     47 
     48// end of file ================================================================ 
  • (a) /dev/null vs. (b) b/sage/rings/bernmm/bernmm-test.cpp

    diff -r dcd714fbd837 sage/rings/bernmm/bernmm-test.cpp
    a b  
     1/* 
     2   bernmm-test.cpp:  test module 
     3    
     4   Copyright (C) 2008, David Harvey 
     5    
     6   This file is part of the bernmm package (version 1.0). 
     7 
     8   This program is free software: you can redistribute it and/or modify 
     9   it under the terms of the GNU General Public License as published by 
     10   the Free Software Foundation, either version 2 of the License, or 
     11   (at your option) any later version. 
     12 
     13   This program is distributed in the hope that it will be useful, 
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of 
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
     16   GNU General Public License for more details. 
     17 
     18   You should have received a copy of the GNU General Public License 
     19   along with this program.  If not, see <http://www.gnu.org/licenses/>. 
     20*/ 
     21 
     22#include <iostream> 
     23#include <NTL/ZZ.h> 
     24#include <gmp.h> 
     25#include "bern_modp_util.h" 
     26#include "bern_modp.h" 
     27#include "bern_rat.h" 
     28 
     29 
     30NTL_CLIENT; 
     31 
     32 
     33using namespace bernmm; 
     34using namespace std; 
     35 
     36 
     37/* 
     38   Computes B_0, B_1, ..., B_{n-1} using naive algorithm, writes them to res. 
     39*/ 
     40void bern_naive(mpq_t* res, long n) 
     41{ 
     42   mpq_t t, u; 
     43   mpq_init(t); 
     44   mpq_init(u); 
     45 
     46   // compute res[j] = B_j / j! for 0 <= j < n 
     47   if (n > 0) 
     48      mpq_set_si(res[0], 1, 1); 
     49 
     50   for (long j = 1; j < n; j++) 
     51   { 
     52      mpq_set_si(res[j], 0, 1); 
     53      mpq_set_ui(t, 1, 1); 
     54      for (long k = 0; k < j; k++) 
     55      { 
     56         mpz_mul_ui(mpq_denref(t), mpq_denref(t), k + 2); 
     57         mpq_mul(u, res[j - 1 - k], t); 
     58         mpq_sub(res[j], res[j], u); 
     59      } 
     60   } 
     61 
     62   // multiply through by j! for 0 <= j < n 
     63   mpq_set_ui(t, 1, 1); 
     64   for (long j = 2; j < n; j++) 
     65   { 
     66      mpz_mul_ui(mpq_numref(t), mpq_numref(t), j); 
     67      mpq_mul(res[j], res[j], t); 
     68   } 
     69 
     70   mpq_clear(u); 
     71   mpq_clear(t); 
     72} 
     73 
     74 
     75/* 
     76   Tests _bern_modp_powg() for a given p and k by comparing against the 
     77   rational number B_k (must be supplied in b). 
     78    
     79   Returns 1 on success. 
     80*/ 
     81int testcase__bern_modp_powg(long p, long k, mpq_t b) 
     82{ 
     83   double pinv = 1 / ((double) p); 
     84    
     85   // compute B_k mod p using _bern_modp_powg() 
     86   long x = _bern_modp_powg(p, pinv, k); 
     87   x = MulMod(x, k, p, pinv); 
     88    
     89   // compute B_k mod p from rational B_k 
     90   long y = mpz_fdiv_ui(mpq_numref(b), p); 
     91   long z = mpz_fdiv_ui(mpq_denref(b), p); 
     92   return y == MulMod(z, x, p, pinv); 
     93} 
     94 
     95 
     96 
     97/* 
     98   Tests _bern_modp_powg() by comparing against naive computation of B_k 
     99   (as a rational) for a range of small p and k. 
     100    
     101   Returns 1 on success. 
     102*/ 
     103int test__bern_modp_powg() 
     104{ 
     105   int success = 1; 
     106 
     107   const long MAX = 300; 
     108   mpq_t bern[MAX]; 
     109    
     110   // compute B_k's as rational numbers using naive algorithm 
     111   for (long i = 0; i < MAX; i++) 
     112      mpq_init(bern[i]); 
     113   bern_naive(bern, MAX); 
     114 
     115   // try a range of k's 
     116   for (long k = 2; k < MAX && success; k += 2) 
     117   { 
     118      // try a range of small p's 
     119      for (long p = k + 3; p < 2*MAX && success; p += 2) 
     120      { 
     121         if (!ProbPrime(p)) 
     122            continue; 
     123         success = success && testcase__bern_modp_powg(p, k, bern[k]); 
     124      } 
     125       
     126      // try a single larger p 
     127      success = success && testcase__bern_modp_powg(1000003, k, bern[k]); 
     128   } 
     129 
     130   // if we're on a 32-bit machine, try a single example with p right near 
     131   // NTL's boundary (this is infeasible on a 64-bit machine) 
     132   if (NTL_SP_NBITS <= 32) 
     133   { 
     134      long p = NTL_SP_BOUND - 1; 
     135      while (!ProbPrime(p)) 
     136         p--; 
     137 
     138      long k = (MAX/2)*2 - 2; 
     139      success = success && testcase__bern_modp_powg(p, k, bern[k]); 
     140   } 
     141 
     142   for (long i = 0; i < MAX; i++) 
     143      mpq_clear(bern[i]); 
     144    
     145   return success; 
     146} 
     147 
     148 
     149 
     150/* 
     151   Tests _bern_modp_pow2() for a given p and k by comparing against result 
     152   from _bern_modp_powg(). 
     153 
     154   Returns 1 on success. 
     155    
     156   If 2^k = 1 mod p, then _bern_modp_pow2() won't work, so it just returns 1. 
     157*/ 
     158int testcase__bern_modp_pow2(long p, long k) 
     159{ 
     160   double pinv = 1 / ((double) p); 
     161 
     162   if (PowerMod(2, k, p, pinv) == 1) 
     163      return 1; 
     164 
     165   long x = _bern_modp_powg(p, pinv, k); 
     166   long y = _bern_modp_pow2(p, pinv, k); 
     167    
     168   return x == y; 
     169} 
     170 
     171 
     172 
     173/* 
     174   Tests _bern_modp_pow2() by comparing against _bern_modp_powg() for 
     175   a range of p and k. 
     176    
     177   Returns 1 on success. 
     178*/ 
     179int test__bern_modp_pow2() 
     180{ 
     181   int success = 1; 
     182 
     183   // exhaustive comparison over some small p and k 
     184   for (long p = 5; p < 2000 && success; p += 2) 
     185   { 
     186      if (!ProbPrime(p)) 
     187         continue; 
     188 
     189      for (long k = 2; k <= p - 3 && success; k += 2) 
     190         success = success && testcase__bern_modp_pow2(p, k); 
     191   } 
     192 
     193   // a few larger values of p 
     194   for (long p = 1000000; p < 1030000; p++) 
     195   { 
     196      if (!ProbPrime(p)) 
     197         continue; 
     198       
     199      long k = 2 * (rand() % ((p-3)/2)) + 2; 
     200      success = success && testcase__bern_modp_pow2(p, k); 
     201   } 
     202    
     203   // if we're on a 32-bit machine, try a single example with p right near 
     204   // NTL's boundary (this is infeasible on a 64-bit machine) 
     205   if (NTL_SP_NBITS <= 32) 
     206   { 
     207      long p = NTL_SP_BOUND - 1; 
     208      while (!ProbPrime(p)) 
     209         p--; 
     210      success = success & testcase__bern_modp_pow2(p, 10); 
     211   } 
     212    
     213   // try a few just below the REDC barrier 
     214   if (LONG_BIT == 32) 
     215   { 
     216      long boundary = 1L << (LONG_BIT/2 - 1); 
     217      for (long p = boundary - 1000; p < boundary && success; p++) 
     218      { 
     219         if (ProbPrime(p)) 
     220         { 
     221            for (long trial = 0; trial < 1000 && success; trial++) 
     222            { 
     223               long k = 2 * (rand() % ((p-3)/2)) + 2; 
     224               success = success && testcase__bern_modp_pow2(p, k); 
     225            } 
     226         } 
     227      } 
     228   } 
     229   else 
     230   { 
     231      // on a 64-bit machine, only try one, since these are huge! 
     232      long p = 1L << (LONG_BIT/2 - 1); 
     233      while (!ProbPrime(p)) 
     234         p--; 
     235      success = success && testcase__bern_modp_pow2(p, 10); 
     236   } 
     237 
     238   return success; 
     239} 
     240 
     241 
     242/* 
     243   Tests bern_rat() by comparing against the naive algorithm for several small 
     244   k, and testing against bern_modp() for a couple of larger k. 
     245    
     246   Returns 1 on success. 
     247*/ 
     248int test_bern_rat() 
     249{ 
     250   int success = 1; 
     251 
     252   const long MAX = 300; 
     253   mpq_t bern[MAX]; 
     254    
     255   // compute B_k's as rational numbers using naive algorithm 
     256   for (long i = 0; i < MAX; i++) 
     257      mpq_init(bern[i]); 
     258   bern_naive(bern, MAX); 
     259 
     260   mpq_t x; 
     261   mpq_init(x); 
     262 
     263   // exhaustive test for small k 
     264   for (long k = 0; k < MAX && success; k++) 
     265   { 
     266      bern_rat(x, k, 4);    // try with 4 threads just for fun 
     267      success = success && mpq_equal(x, bern[k]); 
     268   } 
     269    
     270   // try a few larger k 
     271   for (long i = 0; i < 50 && success; i++) 
     272   { 
     273      long k = ((random() % 20000) / 2) * 2; 
     274      bern_rat(x, k, 4); 
     275       
     276      // compare with modular information 
     277      long p = 1000003; 
     278      long num = mpz_fdiv_ui(mpq_numref(x), p); 
     279      long den = mpz_fdiv_ui(mpq_denref(x), p); 
     280      success = success && (MulMod(bern_modp(p, k), den, p) == num); 
     281   } 
     282 
     283   mpq_clear(x); 
     284   for (long i = 0; i < MAX; i++) 
     285      mpq_clear(bern[i]); 
     286    
     287   return success; 
     288} 
     289 
     290 
     291void report(int success) 
     292{ 
     293   if (success) 
     294      cout << "ok" << endl; 
     295   else 
     296   { 
     297      cout << "failed!" << endl; 
     298      abort(); 
     299   } 
     300} 
     301 
     302 
     303int main(int argc, char* argv[]) 
     304{ 
     305   if (argc == 1) 
     306   { 
     307      cout << "bernmm test module" << endl; 
     308      cout << endl; 
     309      cout << "   bernmm-test --test" << endl; 
     310      cout << "        runs test suite" << endl; 
     311      cout << "   bernmm-test --rational <k> <threads>" << endl; 
     312      cout << "        computes B_k with <threads> threads" << endl; 
     313      cout << "   bernmm-test --modular <p> <k>" << endl; 
     314      cout << "        computes B_k mod p" << endl; 
     315      return 0; 
     316   } 
     317    
     318   if (!strcmp(argv[1], "--test")) 
     319   { 
     320      cout << "testing _bern_modp_powg()... " << flush; 
     321      report(test__bern_modp_powg()); 
     322 
     323      cout << "testing _bern_modp_pow2()... " << flush; 
     324      report(test__bern_modp_pow2()); 
     325 
     326      cout << "testing bern_rat()... " << flush; 
     327      report(test_bern_rat()); 
     328   } 
     329   else if (!strcmp(argv[1], "--rational")) 
     330   { 
     331      if (argc <= 3) 
     332      { 
     333         cout << "not enough arguments" << endl; 
     334         return 0; 
     335      } 
     336      long k = atol(argv[2]); 
     337      long threads = atol(argv[3]); 
     338      mpq_t r; 
     339      mpq_init(r); 
     340      bern_rat(r, k, threads); 
     341      gmp_printf("%Zd/%Zd\n", mpq_numref(r), mpq_denref(r)); 
     342      mpq_clear(r); 
     343   } 
     344   else if (!strcmp(argv[1], "--modular")) 
     345   { 
     346      if (argc <= 3) 
     347      { 
     348         cout << "not enough arguments" << endl; 
     349         return 0; 
     350      } 
     351      long p = atol(argv[2]); 
     352      long k = atol(argv[3]); 
     353      cout << bern_modp(p, k) << endl; 
     354   } 
     355   else 
     356   { 
     357      cout << "unknown command" << endl; 
     358   } 
     359 
     360   return 0; 
     361} 
     362 
     363 
     364// end of file ================================================================ 
  • sage/rings/bernoulli_mod_p.pyx

    diff -r dcd714fbd837 sage/rings/bernoulli_mod_p.pyx
    a b  
    66    - William Stein (2006-07-28): some touch up.   
    77    - David Harvey (2006-08-06): new, faster algorithm, also using faster NTL interface 
    88    - David Harvey (2007-08-31): algorithm for a single bernoulli number mod p 
     9    - David Harvey (2008-06): added interface to bernmm, removed old code 
    910""" 
    1011 
    1112#***************************************************************************** 
     
    2930from sage.libs.ntl.ntl_ZZ_pX cimport ntl_ZZ_pX 
    3031import sage.libs.pari.gen 
    3132from sage.rings.integer_mod_ring import Integers 
     33from sage.rings.bernmm import bernmm_bern_modp 
    3234 
    3335 
    3436 
     
    208210 
    209211 
    210212 
    211 def bernoulli_mod_p_single(int p, int k): 
     213def bernoulli_mod_p_single(long p, long k): 
    212214    r""" 
    213215    Returns the bernoulli number $B_k$ mod $p$. 
    214216 
     217    If $B_k$ is not $p$-integral, an ArithmeticError is raised. 
     218 
    215219    INPUT: 
    216220        p -- integer, a prime 
    217         k -- even integer in the range $0 \leq k \leq p-3$ 
     221        k -- non-negative integer 
    218222     
    219223    OUTPUT: 
    220224        The $k$-th bernoulli number mod $p$. 
    221  
    222     ALGORITHM: 
    223         Uses the identity 
    224           $$ (1-g^k) B_k/k = 2\sum_{r=1}^{(p-1)/2} g^{r(k-1)} ( [g^r/p] - g [g^(r-1)/p] + (g-1)/2 ), $$ 
    225         where $g$ is a primitive root mod $p$, and where square brackets 
    226         denote the fractional part. This identity can be derived from 
    227         Theorem 2.3, chapter 2 of Lang's book "Cyclotomic fields". 
    228      
    229     PERFORMANCE: 
    230         Linear in $p$. In particular the running time doesn't depend on k. 
    231          
    232         It's much faster than computing *all* bernoulli numbers by using 
    233         bernoulli_mod_p(). For p = 1000003, the latter takes about 3s on my 
    234         laptop, whereas this function takes only 0.06s. 
    235          
    236         It may or may not be faster than computing literally bernoulli(k) % p, 
    237         depending on how big k and p are relative to each other. For example on 
    238         my laptop, computing bernoulli(2000) % p only takes 0.01s. But 
    239         computing bernoulli(100000) % p takes 40s, whereas this function still 
    240         takes only 0.06s. 
    241225 
    242226    EXAMPLES: 
    243227        sage: bernoulli_mod_p_single(1009, 48) 
     
    256240        ValueError: p (=100) must be a prime 
    257241         
    258242        sage: bernoulli_mod_p_single(19, 5) 
    259         Traceback (most recent call last): 
    260         ... 
    261         ValueError: k (=5) must be even 
     243        0 
    262244         
    263245        sage: bernoulli_mod_p_single(19, 18) 
    264246        Traceback (most recent call last): 
    265247        ... 
    266         ValueError: k (=18) must be non-negative, and at most p-3 
     248        ArithmeticError: B_k is not integral at p 
    267249         
    268250        sage: bernoulli_mod_p_single(19, -4) 
    269251        Traceback (most recent call last): 
    270252        ... 
    271         ValueError: k (=-4) must be non-negative, and at most p-3 
     253        ValueError: k must be non-negative 
    272254     
    273255    Check results against bernoulli_mod_p: 
    274256     
     
    299281          
    300282    AUTHOR: 
    301283        -- David Harvey (2007-08-31) 
     284        -- David Harvey (2008-06): rewrote to use bernmm library 
    302285 
    303286    """ 
    304287    if p <= 2: 
     
    309292 
    310293    R = Integers(p) 
    311294 
    312     if k == 0: 
    313         return R(1) 
    314          
    315     if k & 1: 
    316         raise ValueError, "k (=%s) must be even" % k 
    317          
    318     if k < 0 or k > p-3: 
    319         raise ValueError, "k (=%s) must be non-negative, and at most p-3" % k 
    320      
    321     g = R.multiplicative_generator() 
    322     cdef llong g_lift = g.lift() 
    323     cdef llong g_to_km1 = (g**(k-1)).lift() 
    324     cdef llong g_to_km1_pow = g_to_km1 
    325     cdef llong c = ((g-1)/2).lift() 
    326     cdef llong g_pow = 1 
    327     cdef llong g_pow_new, quot 
    328     cdef llong sum = 0 
    329     cdef int r 
     295    cdef long x = bernmm_bern_modp(p, k) 
     296    if x == -1: 
     297        raise ArithmeticError, "B_k is not integral at p" 
     298    return x 
    330299 
    331     for r from 0 <= r < (p-1)/2: 
    332         g_pow_new = g_pow * g_lift 
    333         quot = g_pow_new / p 
    334         sum = sum + g_to_km1_pow * (p + c - quot) 
    335         sum = sum % p 
    336         g_pow = g_pow_new % p 
    337         g_to_km1_pow = (g_to_km1_pow * g_to_km1) % p 
    338      
    339     return R(sum) * 2 * k / (1 - g**k) 
    340      
    341300 
    342301# ============ end of file 
  • setup.py

    diff -r dcd714fbd837 setup.py
    a b  
    812812              libraries=['ntl','stdc++'], 
    813813              language = 'c++', 
    814814              include_dirs=debian_include_dirs + ['sage/libs/ntl/']), \ 
     815 
     816    Extension('sage.rings.bernmm', 
     817              sources = ['sage/rings/bernmm.pyx', 
     818                         'sage/rings/bernmm/bern_modp.cpp', 
     819                         'sage/rings/bernmm/bern_modp_util.cpp', 
     820                         'sage/rings/bernmm/bern_rat.cpp'], 
     821              libraries = ['gmp', 'ntl', 'stdc++', 'pthread'], 
     822              language = 'c++', 
     823              define_macros=[('USE_THREADS', '1')]), \ 
    815824 
    816825    Extension('sage.schemes.hyperelliptic_curves.hypellfrob', 
    817826                 sources = ['sage/schemes/hyperelliptic_curves/hypellfrob.pyx',