Ticket #8159: mpmath_cython.patch

File mpmath_cython.patch, 81.7 KB (added by fredrik.johansson, 12 years ago)
  • module_list.py

    # HG changeset patch
    # User Fredrik Johansson <fredrik.johansson@gmail.com>
    # Date 1265304121 -3600
    # Node ID a4e8aa8f473d61a7c168864eb8eee44922991a07
    # Parent  eb27a39a6df43f557bdc8bc838c30f3b5c41e925
    Extended Cython backend for mpmath added to sage.libs.mpmath
    
    diff -r eb27a39a6df4 -r a4e8aa8f473d module_list.py
    a b  
    499499    Extension('sage.libs.mpmath.utils',
    500500              sources = ["sage/libs/mpmath/utils.pyx"],
    501501              libraries = ['mpfr', 'gmp']),
    502    
     502
     503    Extension('sage.libs.mpmath.ext_impl',
     504              sources = ["sage/libs/mpmath/ext_impl.pyx"],
     505              libraries = ['gmp']),
     506
     507    Extension('sage.libs.mpmath.ext_main',
     508              sources = ["sage/libs/mpmath/ext_main.pyx"],
     509              libraries = ['gmp']),
     510
     511    Extension('sage.libs.mpmath.ext_libmp',
     512              sources = ["sage/libs/mpmath/ext_libmp.pyx"],
     513              libraries = ['gmp']),
     514
    503515        ###################################
    504516        ##
    505517        ## sage.libs.cremona
  • sage/libs/mpmath/all.py

    diff -r eb27a39a6df4 -r a4e8aa8f473d sage/libs/mpmath/all.py
    a b  
    22
    33# Patch mpmath to use Cythonized functions
    44import utils as _utils
    5 mpmath.libmpf.normalize = mpmath.libmpf.normalize1 = normalize = _utils.normalize
    6 mpmath.libmpf.from_man_exp = from_man_exp = _utils.from_man_exp
     5#mpmath.libmp.normalize = mpmath.libmp.normalize1 = normalize = _utils.normalize
     6#mpmath.libmp.from_man_exp = from_man_exp = _utils.from_man_exp
    77
    88# Also import internal functions
    9 from mpmath.libmpf import *
    10 from mpmath.libelefun import *
    11 from mpmath.libhyper import *
    12 from mpmath.libmpc import *
    13 from mpmath.mptypes import *
    14 from mpmath.gammazeta import *
     9from mpmath.libmp import *
    1510
    1611# Main namespace
    1712from mpmath import *
  • new file sage/libs/mpmath/ext_impl.pxd

    diff -r eb27a39a6df4 -r a4e8aa8f473d sage/libs/mpmath/ext_impl.pxd
    - +  
     1from sage.libs.gmp.all cimport mpz_t
     2
     3ctypedef struct MPopts:
     4    long prec
     5    int rounding
     6
     7cdef mpz_set_integer(mpz_t v, x)
     8cdef inline mpzi(mpz_t n)
     9cdef inline str rndmode_to_python(int rnd)
     10cdef inline rndmode_from_python(str rnd)
     11
     12ctypedef struct MPF:
     13    mpz_t man
     14    mpz_t exp
     15    int special
     16
     17cdef inline void MPF_init(MPF *x)
     18cdef inline void MPF_clear(MPF *x)
     19cdef inline void MPF_set(MPF *dest, MPF *src)
     20cdef inline void MPF_set_zero(MPF *x)
     21cdef inline void MPF_set_nan(MPF *x)
     22cdef inline void MPF_set_inf(MPF *x)
     23cdef inline void MPF_set_ninf(MPF *x)
     24cdef MPF_set_si(MPF *x, long n)
     25cdef MPF_set_int(MPF *x, n)
     26cdef MPF_set_man_exp(MPF *x, man, exp)
     27cdef MPF_set_tuple(MPF *x, tuple value)
     28cdef MPF_to_tuple(MPF *x)
     29cdef MPF_set_double(MPF *r, double x)
     30cdef double MPF_to_double(MPF *x, bint strict)
     31cdef MPF_to_fixed(mpz_t r, MPF *x, long prec, bint truncate)
     32cdef int MPF_sgn(MPF *x)
     33cdef void MPF_neg(MPF *r, MPF *s)
     34cdef void MPF_abs(MPF *r, MPF *s)
     35cdef MPF_normalize(MPF *x, MPopts opts)
     36cdef MPF_add(MPF *r, MPF *s, MPF *t, MPopts opts)
     37cdef MPF_sub(MPF *r, MPF *s, MPF *t, MPopts opts)
     38cdef bint MPF_eq(MPF *s, MPF *t)
     39cdef bint MPF_ne(MPF *s, MPF *t)
     40cdef int MPF_cmp(MPF *s, MPF *t)
     41cdef bint MPF_lt(MPF *s, MPF *t)
     42cdef bint MPF_le(MPF *s, MPF *t)
     43cdef bint MPF_gt(MPF *s, MPF *t)
     44cdef bint MPF_ge(MPF *s, MPF *t)
     45cdef MPF_mul(MPF *r, MPF *s, MPF *t, MPopts opts)
     46cdef MPF_div(MPF *r, MPF *s, MPF *t, MPopts opts)
     47cdef MPF_sqrt(MPF *r, MPF *s, MPopts opts)
     48cdef MPF_hypot(MPF *r, MPF *a, MPF *b, MPopts opts)
     49cdef MPF_pow_int(MPF *r, MPF *x, mpz_t n, MPopts opts)
     50cdef MPF_set_double(MPF *r, double x)
  • new file sage/libs/mpmath/ext_impl.pyx

    diff -r eb27a39a6df4 -r a4e8aa8f473d sage/libs/mpmath/ext_impl.pyx
    - +  
     1"""
     2This module provides the core implementation of multiprecision
     3floating-point arithmetic. Operations are done in-place.
     4"""
     5
     6include '../../ext/interrupt.pxi'
     7include "../../ext/stdsage.pxi"
     8include "../../ext/python_int.pxi"
     9include "../../ext/python_long.pxi"
     10include "../../ext/python_float.pxi"
     11include "../../ext/python_complex.pxi"
     12include "../../ext/python_number.pxi"
     13
     14cdef extern from "math.h":
     15    cdef double fpow "pow" (double, double)
     16    cdef double frexp "frexp" (double, int*)
     17
     18from sage.libs.gmp.all cimport *
     19from sage.rings.integer cimport Integer
     20
     21cdef extern from "mpz_pylong.h":
     22    cdef mpz_get_pylong(mpz_t src)
     23    cdef mpz_get_pyintlong(mpz_t src)
     24    cdef int mpz_set_pylong(mpz_t dst, src) except -1
     25    cdef long mpz_pythonhash(mpz_t src)
     26
     27cdef mpz_set_integer(mpz_t v, x):
     28    if PyInt_CheckExact(x):
     29        mpz_set_si(v, PyInt_AS_LONG(x))
     30    elif PyLong_CheckExact(x):
     31        mpz_set_pylong(v, x)
     32    elif PY_TYPE_CHECK(x, Integer):
     33        mpz_set(v, (<Integer>x).value)
     34    else:
     35        raise TypeError("cannot convert %s to an integer" % x)
     36
     37cdef inline void mpz_add_si(mpz_t a, mpz_t b, long x):
     38    if x >= 0:
     39        mpz_add_ui(a, b, x)
     40    else:
     41        # careful: overflow when negating INT_MIN
     42        mpz_sub_ui(a, b, <unsigned long>(-x))
     43
     44cdef inline mpzi(mpz_t n):
     45    return mpz_get_pyintlong(n)
     46
     47# This should be done better
     48cdef int mpz_tstbit_abs(mpz_t z, unsigned long bit_index):
     49    cdef int res
     50    if mpz_sgn(z) < 0:
     51        mpz_neg(z, z)
     52        res = mpz_tstbit(z, bit_index)
     53        mpz_neg(z, z)
     54    else:
     55        res = mpz_tstbit(z, bit_index)
     56    return res
     57
     58cdef unsigned long mpz_bitcount(mpz_t z):
     59    if mpz_sgn(z) == 0:
     60        return 0
     61    return mpz_sizeinbase(z, 2)
     62
     63# The following limits allowed exponent shifts. We could use mpz_fits_slong_p,
     64# but then (-LONG_MIN) wraps around; we may also not be able to add large
     65# shifts safely. A higher limit could be used on 64-bit systems, but
     66# it is unlikely that anyone will run into this (adding numbers
     67# that differ by 2^(2^30), at precisions of 2^30 bits).
     68
     69DEF MAX_SHIFT = 1073741824      # 2^30
     70
     71cdef int mpz_reasonable_shift(mpz_t z):
     72    if mpz_sgn(z) > 0:
     73        return mpz_cmp_ui(z, MAX_SHIFT) < 0
     74    else:
     75        return mpz_cmp_si(z, -MAX_SHIFT) > 0
     76
     77DEF ROUND_N = 0
     78DEF ROUND_F = 1
     79DEF ROUND_C = 2
     80DEF ROUND_D = 3
     81DEF ROUND_U = 4
     82
     83DEF S_NORMAL = 0
     84DEF S_ZERO = 1
     85DEF S_NZERO = 2
     86DEF S_INF = 3
     87DEF S_NINF = 4
     88DEF S_NAN = 5
     89
     90cdef inline str rndmode_to_python(int rnd):
     91    if rnd == ROUND_N: return 'n'
     92    if rnd == ROUND_F: return 'f'
     93    if rnd == ROUND_C: return 'c'
     94    if rnd == ROUND_D: return 'd'
     95    if rnd == ROUND_U: return 'u'
     96
     97cdef inline rndmode_from_python(str rnd):
     98    if rnd == 'n': return ROUND_N
     99    if rnd == 'f': return ROUND_F
     100    if rnd == 'c': return ROUND_C
     101    if rnd == 'd': return ROUND_D
     102    if rnd == 'u': return ROUND_U
     103
     104cdef MPopts opts_exact
     105cdef MPopts opts_double_precision
     106cdef MPopts opts_mini_prec
     107
     108opts_exact.prec = 0
     109opts_exact.rounding = ROUND_N
     110opts_double_precision.prec = 53
     111opts_double_precision.rounding = ROUND_N
     112opts_mini_prec.prec = 5
     113opts_mini_prec.rounding = ROUND_D
     114
     115cdef double _double_inf = float("1e300") * float("1e300")
     116cdef double _double_ninf = -_double_inf
     117cdef double _double_nan = _double_inf - _double_inf
     118
     119cdef inline void MPF_init(MPF *x):
     120    """Allocate space and set value to zero.
     121    Must be called exactly once when creating a new MPF."""
     122    x.special = S_ZERO
     123    mpz_init(x.man)
     124    mpz_init(x.exp)
     125
     126cdef inline void MPF_clear(MPF *x):
     127    """Deallocate space. Must be called exactly once when finished with an MPF."""
     128    mpz_clear(x.man)
     129    mpz_clear(x.exp)
     130
     131cdef inline void MPF_set(MPF *dest, MPF *src):
     132    """Clone MPF value. Assumes source value is already normalized."""
     133    if src is dest:
     134        return
     135    dest.special = src.special
     136    mpz_set(dest.man, src.man)
     137    mpz_set(dest.exp, src.exp)
     138
     139cdef inline void MPF_set_zero(MPF *x):
     140    """Set value to 0."""
     141    x.special = S_ZERO
     142
     143cdef inline void MPF_set_nan(MPF *x):
     144    """Set value to NaN (not a number)."""
     145    x.special = S_NAN
     146
     147cdef inline void MPF_set_inf(MPF *x):
     148    """Set value to +infinity."""
     149    x.special = S_INF
     150
     151cdef inline void MPF_set_ninf(MPF *x):
     152    """Set value to -infinity."""
     153    x.special = S_NINF
     154
     155cdef MPF_set_si(MPF *x, long n):
     156    """Set value to that of a given C (long) integer."""
     157    if n:
     158        x.special = S_NORMAL
     159        mpz_set_si(x.man, n)
     160        mpz_set_ui(x.exp, 0)
     161        MPF_normalize(x, opts_exact)
     162    else:
     163        MPF_set_zero(x)
     164
     165cdef MPF_set_int(MPF *x, n):
     166    """Set value to that of a given Python integer."""
     167    x.special = S_NORMAL
     168    mpz_set_integer(x.man, n)
     169    if mpz_sgn(x.man):
     170        mpz_set_ui(x.exp, 0)
     171        MPF_normalize(x, opts_exact)
     172    else:
     173        MPF_set_zero(x)
     174
     175cdef MPF_set_man_exp(MPF *x, man, exp):
     176    """
     177    Set value to man*2^exp where man, exp may be of any appropriate
     178    Python integer types.
     179    """
     180    x.special = S_NORMAL
     181    mpz_set_integer(x.man, man)
     182    mpz_set_integer(x.exp, exp)
     183    MPF_normalize(x, opts_exact)
     184
     185
     186# Temporary variables. Note: not thread-safe.
     187# Used by MPF_add/MPF_sub/MPF_div
     188cdef mpz_t tmp_exponent
     189mpz_init(tmp_exponent)
     190cdef MPF tmp0
     191MPF_init(&tmp0)
     192
     193# Used by MPF_hypot and MPF_cmp, which may call MPF_add/MPF_sub
     194cdef MPF tmp1
     195MPF_init(&tmp1)
     196cdef MPF tmp2
     197MPF_init(&tmp2)
     198
     199
     200# Constants needed in a few places
     201cdef MPF MPF_C_1
     202MPF_init(&MPF_C_1)
     203MPF_set_si(&MPF_C_1, 1)
     204cdef Integer MPZ_ZERO = Integer(0)
     205cdef tuple _mpf_fzero = (0, MPZ_ZERO, 0, 0)
     206cdef tuple _mpf_fnan = (0, MPZ_ZERO, -123, -1)
     207cdef tuple _mpf_finf = (0, MPZ_ZERO, -456, -2)
     208cdef tuple _mpf_fninf = (1, MPZ_ZERO, -789, -3)
     209
     210cdef MPF_set_tuple(MPF *x, tuple value):
     211    """
     212    Set value of an MPF to that of a normalized (sign, man, exp, bc) tuple
     213    in the format used by mpmath.libmp.
     214    """
     215    #cdef int sign
     216    cdef Integer man
     217    sign, _man, exp, bc = value
     218    if PY_TYPE_CHECK(_man, Integer):
     219        man = <Integer>_man
     220    else:
     221        # This is actually very unlikely; it should never happen
     222        # in internal code that man isn't an Integer. Maybe the check
     223        # can be avoided by doing checks in e.g. MPF_set_any?
     224        man = Integer(_man)
     225    if mpz_sgn(man.value):
     226        MPF_set_man_exp(x, man, exp)
     227        if sign:
     228            mpz_neg(x.man, x.man)
     229        return
     230    if value == _mpf_fzero:
     231        MPF_set_zero(x)
     232    elif value == _mpf_finf:
     233        MPF_set_inf(x)
     234    elif value == _mpf_fninf:
     235        MPF_set_ninf(x)
     236    else:
     237        MPF_set_nan(x)
     238
     239cdef MPF_to_tuple(MPF *x):
     240    """Convert MPF value to (sign, man, exp, bc) tuple."""
     241    cdef Integer man
     242    if x.special:
     243        if x.special == S_ZERO: return _mpf_fzero
     244        #if x.special == S_NZERO: return _mpf_fnzero
     245        if x.special == S_INF: return _mpf_finf
     246        if x.special == S_NINF: return _mpf_fninf
     247        return _mpf_fnan
     248    man = PY_NEW(Integer)
     249    if mpz_sgn(x.man) < 0:
     250        mpz_neg(man.value, x.man)
     251        sign = 1
     252    else:
     253        mpz_set(man.value, x.man)
     254        sign = 0
     255    exp = mpz_get_pyintlong(x.exp)
     256    bc = mpz_sizeinbase(x.man, 2)
     257    return (sign, man, exp, bc)
     258
     259cdef MPF_set_double(MPF *r, double x):
     260    """
     261    Set value of a C double.
     262    """
     263    cdef int exp
     264    cdef double man
     265    if x != x:
     266        MPF_set_nan(r)
     267        return
     268    if x == _double_inf:
     269        MPF_set_inf(r)
     270        return
     271    if x == _double_ninf:
     272        MPF_set_ninf(r)
     273        return
     274    man = frexp(x, &exp)
     275    man *= 9007199254740992.0
     276    mpz_set_d(r.man, man)
     277    mpz_set_si(r.exp, exp-53)
     278    r.special = S_NORMAL
     279    MPF_normalize(r, opts_exact)
     280
     281import math as pymath
     282
     283# TODO: implement this function safely without using the Python math module
     284cdef double MPF_to_double(MPF *x, bint strict):
     285    """Convert MPF value to a Python float."""
     286    if x.special == S_NORMAL:
     287        man = mpzi(x.man)
     288        exp = mpzi(x.exp)
     289        bc = mpz_sizeinbase(x.man, 2)
     290        try:
     291            if bc < 100:
     292                return pymath.ldexp(man, exp)
     293            # Try resizing the mantissa. Overflow may still happen here.
     294            n = bc - 53
     295            m = man >> n
     296            return pymath.ldexp(m, exp + n)
     297        except OverflowError:
     298            if strict:
     299                raise
     300            # Overflow to infinity
     301            if exp + bc > 0:
     302                if man < 0:
     303                    return _double_ninf
     304                else:
     305                    return _double_inf
     306            # Underflow to zero
     307            return 0.0
     308    if x.special == S_ZERO:
     309        return 0.0
     310    if x.special == S_INF:
     311        return _double_inf
     312    if x.special == S_NINF:
     313        return _double_ninf
     314    return _double_nan
     315
     316cdef MPF_to_fixed(mpz_t r, MPF *x, long prec, bint truncate):
     317    """
     318    Set r = x, r being in the format of a fixed-point number with prec bits.
     319    Floor division is used unless truncate=True in which case
     320    truncating division is used.
     321    """
     322    cdef long shift
     323    if x.special:
     324        if x.special == S_ZERO or x.special == S_NZERO:
     325            mpz_set_ui(r, 0)
     326            return
     327        raise ValueError("cannot create fixed-point number from special value")
     328    if mpz_reasonable_shift(x.exp):
     329        # XXX: signed integer overflow
     330        shift = mpz_get_si(x.exp) + prec
     331        if shift >= 0:
     332            mpz_mul_2exp(r, x.man, shift)
     333        else:
     334            if truncate:
     335                mpz_tdiv_q_2exp(r, x.man, -shift)
     336            else:
     337                mpz_fdiv_q_2exp(r, x.man, -shift)
     338        return
     339    # Underflow
     340    if mpz_sgn(x.exp) < 0:
     341        mpz_set_ui(r, 0)
     342        return
     343    raise OverflowError("cannot convert huge number to fixed-point format")
     344
     345cdef int MPF_sgn(MPF *x):
     346    """
     347    Gives the sign of an MPF (-1, 0, or 1).
     348    """
     349    if x.special:
     350        if x.special == S_INF:
     351            return 1
     352        if x.special == S_NINF:
     353            return -1
     354        return 0
     355    return mpz_sgn(x.man)
     356
     357cdef void MPF_neg(MPF *r, MPF *s):
     358    """
     359    Sets r = -s. MPF_neg(x, x) negates in place.
     360    """
     361    if s.special:
     362        if   s.special == S_ZERO: r.special = S_ZERO #r.special = S_NZERO
     363        elif s.special == S_NZERO: r.special = S_ZERO
     364        elif s.special == S_INF: r.special = S_NINF
     365        elif s.special == S_NINF: r.special = S_INF
     366        else: r.special = s.special
     367        return
     368    r.special = s.special
     369    mpz_neg(r.man, s.man)
     370    if r is not s:
     371        mpz_set(r.exp, s.exp)
     372
     373cdef void MPF_abs(MPF *r, MPF *s):
     374    if s.special:
     375        if    s.special == S_NINF: r.special = S_INF
     376        else: r.special = s.special
     377        return
     378    r.special = s.special
     379    mpz_abs(r.man, s.man)
     380    if r is not s:
     381        mpz_set(r.exp, s.exp)
     382
     383cdef MPF_normalize(MPF *x, MPopts opts):
     384    """
     385    Normalize.
     386
     387    With prec = 0, trailing zero bits are stripped but no rounding
     388    is performed.
     389    """
     390    cdef int sign
     391    cdef long trail, bc, shift
     392    if x.special != S_NORMAL:
     393        return
     394    sign = mpz_sgn(x.man)
     395    if sign == 0:
     396        x.special = S_ZERO
     397        mpz_set_ui(x.exp, 0)
     398        return
     399    bc = mpz_sizeinbase(x.man, 2)
     400    shift = bc - opts.prec
     401    # Ok if mantissa small and no trailing zero bits
     402    if (shift <= 0 or not opts.prec) and mpz_odd_p(x.man):
     403        return
     404    # Mantissa is too large, so divide by appropriate power of 2
     405    # Need to be careful about rounding
     406    if shift > 0 and opts.prec:
     407        if opts.rounding == ROUND_N:
     408            if mpz_tstbit_abs(&x.man, shift-1):
     409                if mpz_tstbit_abs(&x.man, shift) or mpz_scan1(x.man, 0) < (shift-1):
     410                    if sign < 0:
     411                        mpz_fdiv_q_2exp(x.man, x.man, shift)
     412                    else:
     413                        mpz_cdiv_q_2exp(x.man, x.man, shift)
     414                else:
     415                    mpz_tdiv_q_2exp(x.man, x.man, shift)
     416            else:
     417                mpz_tdiv_q_2exp(x.man, x.man, shift)
     418        elif opts.rounding == ROUND_D:
     419            mpz_tdiv_q_2exp(x.man, x.man, shift)
     420        elif opts.rounding == ROUND_F:
     421            mpz_fdiv_q_2exp(x.man, x.man, shift)
     422        elif opts.rounding == ROUND_C:
     423            mpz_cdiv_q_2exp(x.man, x.man, shift)
     424        elif opts.rounding == ROUND_U:
     425            if sign < 0:
     426                mpz_fdiv_q_2exp(x.man, x.man, shift)
     427            else:
     428                mpz_cdiv_q_2exp(x.man, x.man, shift)
     429        else:
     430            raise ValueError("bad rounding mode")
     431    else:
     432        shift = 0
     433    # Strip trailing bits
     434    trail = mpz_scan1(x.man, 0)
     435    if 0 < trail < bc:
     436        mpz_tdiv_q_2exp(x.man, x.man, trail)
     437        shift += trail
     438    mpz_add_si(x.exp, x.exp, shift)
     439
     440cdef void _add_special(MPF *r, MPF *s, MPF *t):
     441    if s.special == S_ZERO:
     442        # (+0) + (-0) = +0
     443        if t.special == S_NZERO:
     444            MPF_set(r, s)
     445        # (+0) + x = x
     446        else:
     447            MPF_set(r, t)
     448    elif t.special == S_ZERO:
     449        # (-0) + (+0) = +0
     450        if s.special == S_NZERO:
     451            MPF_set(r, t)
     452        # x + (+0) = x
     453        else:
     454            MPF_set(r, s)
     455    # (+/- 0) + x = x
     456    elif s.special == S_NZERO:
     457        MPF_set(r, t)
     458    elif t.special == S_NZERO:
     459        MPF_set(r, s)
     460    # (+/- inf) + (-/+ inf) = nan
     461    elif ((s.special == S_INF and t.special == S_NINF) or
     462       (s.special == S_NINF and t.special == S_INF)):
     463        MPF_set_nan(r)
     464    # nan or +/- inf trumps any finite number
     465    elif s.special == S_NAN or t.special == S_NAN:
     466        MPF_set_nan(r)
     467    elif s.special:
     468        MPF_set(r, s)
     469    else:
     470        MPF_set(r, t)
     471    return
     472
     473cdef void _sub_special(MPF *r, MPF *s, MPF *t):
     474    if s.special == S_ZERO:
     475        # (+0) - (+/-0) = (+0)
     476        if t.special == S_NZERO:
     477            MPF_set(r, s)
     478        else:
     479            # (+0) - x = (-x)
     480            MPF_neg(r, t)
     481    elif t.special == S_ZERO:
     482        # x - (+0) = x; also covers (-0) - (+0) = (-0)
     483        MPF_set(r, s)
     484    # (-0) - x = x
     485    elif s.special == S_NZERO:
     486        # (-0) - (-0) = (+0)
     487        if t.special == S_NZERO:
     488            MPF_set_zero(r)
     489        # (-0) - x = -x
     490        else:
     491            MPF_neg(r, t)
     492    elif t.special == S_NZERO:
     493        # x - (-0) = x
     494        MPF_set(r, s)
     495    # (+/- inf) - (+/- inf) = nan
     496    elif ((s.special == S_INF and t.special == S_INF) or
     497       (s.special == S_NINF and t.special == S_NINF)):
     498        MPF_set_nan(r)
     499    elif s.special == S_NAN or t.special == S_NAN:
     500        MPF_set_nan(r)
     501    # nan - x or (+/-inf) - x = l.h.s
     502    elif s.special:
     503        MPF_set(r, s)
     504    # x - nan or x - (+/-inf) = (- r.h.s)
     505    else:
     506        MPF_neg(r, t)
     507
     508cdef void _mul_special(MPF *r, MPF *s, MPF *t):
     509    if s.special == S_ZERO:
     510        if t.special == S_NORMAL or t.special == S_ZERO:
     511            MPF_set(r, s)
     512        elif t.special == S_NZERO:
     513            MPF_set(r, t)
     514        else:
     515            MPF_set_nan(r)
     516    elif t.special == S_ZERO:
     517        if s.special == S_NORMAL:
     518            MPF_set(r, t)
     519        elif s.special == S_NZERO:
     520            MPF_set(r, s)
     521        else:
     522            MPF_set_nan(r)
     523    elif s.special == S_NZERO:
     524        if t.special == S_NORMAL:
     525            if mpz_sgn(t.man) < 0:
     526                MPF_set_zero(r)
     527            else:
     528                MPF_set(r, s)
     529        else:
     530            MPF_set_nan(r)
     531    elif t.special == S_NZERO:
     532        if s.special == S_NORMAL:
     533            if mpz_sgn(s.man) < 0:
     534                MPF_set_zero(r)
     535            else:
     536                MPF_set(r, t)
     537        else:
     538            MPF_set_nan(r)
     539    elif s.special == S_NAN or t.special == S_NAN:
     540        MPF_set_nan(r)
     541    else:
     542        if MPF_sgn(s) == MPF_sgn(t):
     543            MPF_set_inf(r)
     544        else:
     545            MPF_set_ninf(r)
     546
     547cdef _div_special(MPF *r, MPF *s, MPF *t):
     548    # TODO: handle signed zeros correctly
     549    if s.special == S_NAN or t.special == S_NAN:
     550        MPF_set_nan(r)
     551    elif t.special == S_ZERO or t.special == S_NZERO:
     552        raise ZeroDivisionError
     553    elif s.special == S_ZERO or s.special == S_NZERO:
     554        MPF_set_zero(r)
     555    elif s.special == S_NORMAL:
     556        MPF_set_zero(r)
     557    elif s.special == S_INF or s.special == S_NINF:
     558        if t.special == S_INF or t.special == S_NINF:
     559            MPF_set_nan(r)
     560        elif MPF_sgn(s) == MPF_sgn(t):
     561            MPF_set_inf(r)
     562        else:
     563            MPF_set_ninf(r)
     564    # else:
     565    elif t.special == S_INF or t.special == S_NINF:
     566        MPF_set_zero(r)
     567   
     568cdef _add_perturbation(MPF *r, MPF *s, int sign, MPopts opts):
     569    cdef long shift
     570    if opts.rounding == ROUND_N:
     571        MPF_set(r, s)
     572    else:
     573        shift = opts.prec - mpz_sizeinbase(s, 2) + 8
     574        if shift < 0:
     575            shift = 8
     576        mpz_mul_2exp(r.man, s.man, shift)
     577        mpz_add_si(r.man, r.man, sign)
     578        mpz_sub_ui(r.exp, s.exp, shift)
     579        MPF_normalize(r, opts)
     580
     581cdef MPF_add(MPF *r, MPF *s, MPF *t, MPopts opts):
     582    """
     583    Set r = s + t, with exact rounding.
     584
     585    With prec = 0, the addition is performed exactly. Note that this
     586    may cause overflow if the exponents are huge.
     587    """
     588    cdef long shift, sbc, tbc
     589    #assert (r is not s) and (r is not t)
     590    if s.special or t.special:
     591        _add_special(r, s, t)
     592        return
     593    r.special = S_NORMAL
     594    # Difference between exponents
     595    mpz_sub(tmp_exponent, s.exp, t.exp)
     596    if mpz_reasonable_shift(tmp_exponent):
     597        shift = mpz_get_si(tmp_exponent)
     598        if shift >= 0:
     599            # |s| >> |t|
     600            if shift > 2*opts.prec and opts.prec:
     601                sbc = mpz_sizeinbase(s,2)
     602                tbc = mpz_sizeinbase(t,2)
     603                if shift + sbc - tbc > opts.prec+8:
     604                    _add_perturbation(r, s, mpz_sgn(t.man), opts)
     605                    return
     606            # |s| > |t|
     607            mpz_mul_2exp(tmp0.man, s.man, shift)
     608            mpz_add(r.man, tmp0.man, t.man)
     609            mpz_set(r.exp, t.exp)
     610            MPF_normalize(r, opts)
     611        elif shift < 0:
     612            shift = -shift
     613            # |s| << |t|
     614            if shift > 2*opts.prec and opts.prec:
     615                sbc = mpz_sizeinbase(s,2)
     616                tbc = mpz_sizeinbase(t,2)
     617                if shift + tbc - sbc > opts.prec+8:
     618                    _add_perturbation(r, t, mpz_sgn(s.man), opts)
     619                    return
     620            # |s| < |t|
     621            mpz_mul_2exp(tmp0.man, t.man, shift)
     622            mpz_add(r.man, tmp0.man, s.man)
     623            mpz_set(r.exp, s.exp)
     624            MPF_normalize(r, opts)
     625    else:
     626        if not opts.prec:
     627            raise OverflowError("the exact result does not fit in memory")
     628        # |s| >>> |t|
     629        if mpz_sgn(tmp_exponent) > 0:
     630            _add_perturbation(r, s, mpz_sgn(t.man), opts)
     631        # |s| <<< |t|
     632        else:
     633            _add_perturbation(r, t, mpz_sgn(s.man), opts)
     634
     635cdef MPF_sub(MPF *r, MPF *s, MPF *t, MPopts opts):
     636    """
     637    Set r = s - t, with exact rounding.
     638
     639    With prec = 0, the addition is performed exactly. Note that this
     640    may cause overflow if the exponents are huge.
     641    """
     642    cdef long shift, sbc, tbc
     643    #assert (r is not s) and (r is not t)
     644    if s.special or t.special:
     645        _sub_special(r, s, t)
     646        return
     647    r.special = S_NORMAL
     648    # Difference between exponents
     649    mpz_sub(tmp_exponent, s.exp, t.exp)
     650    if mpz_reasonable_shift(tmp_exponent):
     651        shift = mpz_get_si(tmp_exponent)
     652        if shift >= 0:
     653            # |s| >> |t|
     654            if shift > 2*opts.prec and opts.prec:
     655                sbc = mpz_sizeinbase(s,2)
     656                tbc = mpz_sizeinbase(t,2)
     657                if shift + sbc - tbc > opts.prec+8:
     658                    _add_perturbation(r, s, -mpz_sgn(t.man), opts)
     659                    return
     660            # |s| > |t|
     661            mpz_mul_2exp(tmp0.man, s.man, shift)
     662            mpz_sub(r.man, tmp0.man, t.man)
     663            mpz_set(r.exp, t.exp)
     664            MPF_normalize(r, opts)
     665        elif shift < 0:
     666            shift = -shift
     667            # |s| << |t|
     668            if shift > 2*opts.prec and opts.prec:
     669                sbc = mpz_sizeinbase(s,2)
     670                tbc = mpz_sizeinbase(t,2)
     671                if shift + tbc - sbc > opts.prec+8:
     672                    _add_perturbation(r, t, -mpz_sgn(s.man), opts)
     673                    MPF_neg(r, r)
     674                    return
     675            # |s| < |t|
     676            mpz_mul_2exp(tmp0.man, t.man, shift)
     677            mpz_sub(r.man, s.man, tmp0.man)
     678            mpz_set(r.exp, s.exp)
     679            MPF_normalize(r, opts)
     680    else:
     681        if not opts.prec:
     682            raise OverflowError("the exact result does not fit in memory")
     683        # |s| >>> |t|
     684        if mpz_sgn(tmp_exponent) > 0:
     685            _add_perturbation(r, s, -mpz_sgn(t.man), opts)
     686        # |s| <<< |t|
     687        else:
     688            _add_perturbation(r, t, -mpz_sgn(s.man), opts)
     689            MPF_neg(r, r)
     690
     691cdef bint MPF_eq(MPF *s, MPF *t):
     692    if s.special == S_NAN or t.special == S_NAN:
     693        return False
     694    if s.special == t.special:
     695        if s.special == S_NORMAL:
     696            return (mpz_cmp(s.man, t.man) == 0) and (mpz_cmp(s.exp, t.exp) == 0)
     697        else:
     698            return True
     699    return False
     700
     701cdef bint MPF_ne(MPF *s, MPF *t):
     702    if s.special == S_NAN or t.special == S_NAN:
     703        return True
     704    if s.special == S_NORMAL and t.special == S_NORMAL:
     705        return (mpz_cmp(s.man, t.man) != 0) or (mpz_cmp(s.exp, t.exp) != 0)
     706    return s.special != t.special
     707
     708cdef int MPF_cmp(MPF *s, MPF *t):
     709    cdef long sbc, tbc
     710    cdef int cm
     711    if MPF_eq(s, t):
     712        return 0
     713    if s.special != S_NORMAL or t.special != S_NORMAL:
     714        if s.special == S_ZERO: return -MPF_sgn(t)
     715        if t.special == S_ZERO: return MPF_sgn(s)
     716        if t.special == S_NAN: return 1
     717        if s.special == S_INF: return 1
     718        if t.special == S_NINF: return 1
     719        return -1
     720    if mpz_sgn(s.man) != mpz_sgn(t.man):
     721        if mpz_sgn(s.man) < 0:
     722            return -1
     723        else:
     724            return 1
     725    if not mpz_cmp(s.exp, t.exp):
     726        return mpz_cmp(s.man, t.man)
     727    mpz_add_ui(tmp1.exp, s.exp, mpz_sizeinbase(s.man, 2))
     728    mpz_add_ui(tmp2.exp, t.exp, mpz_sizeinbase(t.man, 2))
     729    cm = mpz_cmp(tmp1.exp, tmp2.exp)
     730    if mpz_sgn(s.man) < 0:
     731        if cm < 0: return 1
     732        if cm > 0: return -1
     733    else:
     734        if cm < 0: return -1
     735        if cm > 0: return 1
     736    MPF_sub(&tmp1, s, t, opts_mini_prec)
     737    return MPF_sgn(&tmp1)
     738
     739cdef bint MPF_lt(MPF *s, MPF *t):
     740    if s.special == S_NAN or t.special == S_NAN:
     741        return False
     742    return MPF_cmp(s, t) < 0
     743
     744cdef bint MPF_le(MPF *s, MPF *t):
     745    if s.special == S_NAN or t.special == S_NAN:
     746        return False
     747    return MPF_cmp(s, t) <= 0
     748
     749cdef bint MPF_gt(MPF *s, MPF *t):
     750    if s.special == S_NAN or t.special == S_NAN:
     751        return False
     752    return MPF_cmp(s, t) > 0
     753
     754cdef bint MPF_ge(MPF *s, MPF *t):
     755    if s.special == S_NAN or t.special == S_NAN:
     756        return False
     757    return MPF_cmp(s, t) >= 0
     758
     759cdef MPF_mul(MPF *r, MPF *s, MPF *t, MPopts opts):
     760    """
     761    Set r = s * t, with correct rounding.
     762
     763    With prec = 0, the multiplication is performed exactly,
     764    i.e. no rounding is performed.
     765    """
     766    if s.special or t.special:
     767        _mul_special(r, s, t)
     768    else:
     769        r.special = S_NORMAL
     770        mpz_mul(r.man, s.man, t.man)
     771        mpz_add(r.exp, s.exp, t.exp)
     772        if opts.prec:
     773            MPF_normalize(r, opts)
     774
     775cdef MPF_div(MPF *r, MPF *s, MPF *t, MPopts opts):
     776    """
     777    Set r = s / t, with correct rounding.
     778    """
     779    cdef int sign
     780    cdef long sbc, tbc, extra
     781    cdef mpz_t rem
     782    #assert (r is not s) and (r is not t)
     783    if s.special or t.special:
     784        _div_special(r, s, t)
     785        return
     786    r.special = S_NORMAL
     787    # Division by a power of two <=> shift exponents
     788    if mpz_cmp_si(t.man, 1) == 0:
     789        MPF_set(&tmp0, s)
     790        mpz_sub(tmp0.exp, tmp0.exp, t.exp)
     791        MPF_normalize(&tmp0, opts)
     792        MPF_set(r, &tmp0)
     793        return
     794    elif mpz_cmp_si(t.man, -1) == 0:
     795        MPF_neg(&tmp0, s)
     796        mpz_sub(tmp0.exp, tmp0.exp, t.exp)
     797        MPF_normalize(&tmp0, opts)
     798        MPF_set(r, &tmp0)
     799        return
     800    sign = mpz_sgn(s.man) != mpz_sgn(t.man)
     801    # Same strategy as for addition: if there is a remainder, perturb
     802    # the result a few bits outside the precision range before rounding
     803    extra = opts.prec - mpz_sizeinbase(s.man,2) + mpz_sizeinbase(t.man,2) + 5
     804    if extra < 5:
     805        extra = 5
     806    mpz_init(rem)
     807    mpz_mul_2exp(tmp0.man, s.man, extra)
     808    mpz_tdiv_qr(r.man, rem, tmp0.man, t.man)
     809    if mpz_sgn(rem):
     810        mpz_mul_2exp(r.man, r.man, 1)
     811        if sign:
     812            mpz_sub_ui(r.man, r.man, 1)
     813        else:
     814            mpz_add_ui(r.man, r.man, 1)
     815        extra  += 1
     816    mpz_clear(rem)
     817    mpz_sub(r.exp, s.exp, t.exp)
     818    mpz_sub_ui(r.exp, r.exp, extra)
     819    MPF_normalize(r, opts)
     820
     821cdef MPF_sqrt(MPF *r, MPF *s, MPopts opts):
     822    """
     823    Set r = sqrt(s), with correct rounding.
     824    """
     825    cdef long shift
     826    cdef mpz_t rem
     827    #assert r is not s
     828    if s.special:
     829        if s.special == S_ZERO or s.special == S_INF:
     830            MPF_set(r, s)
     831        else:
     832            MPF_set_nan(r)
     833        return
     834    if mpz_sgn(s.man) < 0:
     835        MPF_set_nan(r)
     836        return
     837    r.special = S_NORMAL
     838    if mpz_odd_p(s.exp):
     839        mpz_sub_ui(r.exp, s.exp, 1)
     840        mpz_mul_2exp(r.man, s.man, 1)
     841    elif mpz_cmp_ui(s.man, 1) == 0:
     842        # Square of a power of two
     843        mpz_set_ui(r.man, 1)
     844        mpz_tdiv_q_2exp(r.exp, s.exp, 1)
     845        MPF_normalize(r, opts)
     846        return
     847    else:
     848        mpz_set(r.man, s.man)
     849        mpz_set(r.exp, s.exp)
     850    shift = 2*opts.prec - mpz_sizeinbase(r.man,2) + 4
     851    if shift < 4:
     852        shift = 4
     853    shift += shift & 1
     854    mpz_mul_2exp(r.man, r.man, shift)
     855    if opts.rounding == ROUND_F or opts.rounding == ROUND_D:
     856        mpz_sqrt(r.man, r.man)
     857    else:
     858        mpz_init(rem)
     859        mpz_sqrtrem(r.man, rem, r.man)
     860        if mpz_sgn(rem):
     861            mpz_mul_2exp(r.man, r.man, 1)
     862            mpz_add_ui(r.man, r.man, 1)
     863            shift += 2
     864        mpz_clear(rem)
     865    mpz_add_si(r.exp, r.exp, -shift)
     866    mpz_tdiv_q_2exp(r.exp, r.exp, 1)
     867    MPF_normalize(r, opts)
     868
     869cdef MPF_hypot(MPF *r, MPF *a, MPF *b, MPopts opts):
     870    """
     871    Set r = sqrt(a^2 + b^2)
     872    """
     873    cdef MPopts tmp_opts
     874    if a.special == S_ZERO:
     875        MPF_abs(r, b)
     876        MPF_normalize(r, opts)
     877        return
     878    if b.special == S_ZERO:
     879        MPF_abs(r, a)
     880        MPF_normalize(r, opts)
     881        return
     882    tmp_opts = opts
     883    tmp_opts.prec += 30
     884    MPF_mul(&tmp1, a, a, opts_exact)
     885    MPF_mul(&tmp2, b, b, opts_exact)
     886    MPF_add(r, &tmp1, &tmp2, tmp_opts)
     887    MPF_sqrt(r, r, opts)
     888
     889cdef MPF_pow_int(MPF *r, MPF *x, mpz_t n, MPopts opts):
     890    """
     891    Set r = x ** n. Currently falls back to mpmath.libmp
     892    unless n is tiny.
     893    """
     894    cdef long m, absm
     895    cdef unsigned long bc
     896    cdef int nsign
     897    if x.special != S_NORMAL:
     898        nsign = mpz_sgn(n)
     899        if x.special == S_ZERO:
     900            if nsign < 0:
     901                raise ZeroDivisionError
     902            elif nsign == 0:
     903                MPF_set(r, &MPF_C_1)
     904            else:
     905                MPF_set_zero(r)
     906        elif x.special == S_INF:
     907            if nsign > 0:
     908                MPF_set(r, x)
     909            elif nsign == 0:
     910                MPF_set_nan(r)
     911            else:
     912                MPF_set_zero(r)
     913        elif x.special == S_NINF:
     914            if nsign > 0:
     915                if mpz_odd_p(n):
     916                    MPF_set(r, x)
     917                else:
     918                    MPF_neg(r, x)
     919            elif nsign == 0:
     920                MPF_set_nan(r)
     921            else:
     922                MPF_set_zero(r)
     923        else:
     924            MPF_set_nan(r)
     925        return
     926    bc = mpz_sizeinbase(r.man,2)
     927    r.special = S_NORMAL
     928    if mpz_reasonable_shift(n):
     929        m = mpz_get_si(n)
     930        if m == 0:
     931            MPF_set(r, &MPF_C_1)
     932            return
     933        if m == 1:
     934            MPF_set(r, x)
     935            MPF_normalize(r, opts)
     936            return
     937        if m == 2:
     938            MPF_mul(r, x, x, opts)
     939            return
     940        if m == -1:
     941            MPF_div(r, &MPF_C_1, x, opts)
     942            return
     943        if m == -2:
     944            MPF_mul(r, x, x, opts_exact)
     945            MPF_div(r, &MPF_C_1, r, opts)
     946            return
     947        absm = abs(m)
     948        if bc * absm < 10000:
     949            mpz_pow_ui(r.man, x.man, absm)
     950            mpz_mul_ui(r.exp, x.exp, absm)
     951            if m < 0:
     952                MPF_div(r, &MPF_C_1, r, opts)
     953            else:
     954                MPF_normalize(r, opts)
     955            return
     956    r.special = S_NORMAL
     957    # (2^p)^n
     958    if mpz_cmp_si(x.man, 1) == 0:
     959        mpz_set(r.man, x.man)
     960        mpz_mul(r.exp, x.exp, n)
     961        return
     962    # (-2^p)^n
     963    if mpz_cmp_si(x.man, -1) == 0:
     964        if mpz_odd_p(n):
     965            mpz_set(r.man, x.man)
     966        else:
     967            mpz_neg(r.man, x.man)
     968        mpz_mul(r.exp, x.exp, n)
     969        return
     970    # TODO: implement efficiently here
     971    import mpmath.libmp
     972    MPF_set_tuple(r,
     973        mpmath.libmp.mpf_pow_int(MPF_to_tuple(x), mpzi(n),
     974        opts.prec, rndmode_to_python(opts.rounding)))
     975
  • new file sage/libs/mpmath/ext_libmp.pyx

    diff -r eb27a39a6df4 -r a4e8aa8f473d sage/libs/mpmath/ext_libmp.pyx
    - +  
     1"""
     2Faster versions of some key functions in mpmath.libmp.
     3"""
     4
     5include "../../ext/stdsage.pxi"
     6from ext_impl cimport *
     7from sage.libs.gmp.all cimport *
     8from sage.rings.integer cimport Integer
     9
     10# Note: not thread-safe
     11cdef MPF tmp1
     12cdef MPF tmp2
     13MPF_init(&tmp1)
     14MPF_init(&tmp2)
     15
     16def mpf_add(tuple x, tuple y, int prec=0, str rnd='d'):
     17    cdef MPopts opts
     18    MPF_set_tuple(&tmp1, x)
     19    MPF_set_tuple(&tmp2, y)
     20    opts.rounding = rndmode_from_python(rnd)
     21    opts.prec = prec
     22    MPF_add(&tmp1, &tmp1, &tmp2, opts)
     23    return MPF_to_tuple(&tmp1)
     24
     25def mpf_sub(tuple x, tuple y, int prec=0, str rnd='d'):
     26    cdef MPopts opts
     27    MPF_set_tuple(&tmp1, x)
     28    MPF_set_tuple(&tmp2, y)
     29    opts.rounding = rndmode_from_python(rnd)
     30    opts.prec = prec
     31    MPF_sub(&tmp1, &tmp1, &tmp2, opts)
     32    return MPF_to_tuple(&tmp1)
     33
     34def mpf_mul(tuple x, tuple y, int prec=0, str rnd='d'):
     35    cdef MPopts opts
     36    MPF_set_tuple(&tmp1, x)
     37    MPF_set_tuple(&tmp2, y)
     38    opts.rounding = rndmode_from_python(rnd)
     39    opts.prec = prec
     40    MPF_mul(&tmp1, &tmp1, &tmp2, opts)
     41    return MPF_to_tuple(&tmp1)
     42
     43def mpf_div(tuple x, tuple y, int prec, str rnd='d'):
     44    cdef MPopts opts
     45    MPF_set_tuple(&tmp1, x)
     46    MPF_set_tuple(&tmp2, y)
     47    opts.rounding = rndmode_from_python(rnd)
     48    opts.prec = prec
     49    MPF_div(&tmp1, &tmp1, &tmp2, opts)
     50    return MPF_to_tuple(&tmp1)
     51
     52def mpf_sqrt(tuple x, int prec, str rnd='d'):
     53    if x[0]:
     54        import mpmath.libmp as libmp
     55        raise libmp.ComplexResult("square root of a negative number")
     56    cdef MPopts opts
     57    MPF_set_tuple(&tmp1, x)
     58    opts.rounding = rndmode_from_python(rnd)
     59    opts.prec = prec
     60    MPF_sqrt(&tmp1, &tmp1, opts)
     61    return MPF_to_tuple(&tmp1)
  • new file sage/libs/mpmath/ext_main.pxd

    diff -r eb27a39a6df4 -r a4e8aa8f473d sage/libs/mpmath/ext_main.pxd
    - +  
     1from ext_impl cimport *
     2
  • new file sage/libs/mpmath/ext_main.pyx

    diff -r eb27a39a6df4 -r a4e8aa8f473d sage/libs/mpmath/ext_main.pyx
    - +  
     1"""
     2Implements mpf and mpc types, with binary operations and support
     3for interaction with other types. Also implements the main
     4context class, and related utilities.
     5"""
     6
     7include '../../ext/interrupt.pxi'
     8include "../../ext/stdsage.pxi"
     9include "../../ext/python_int.pxi"
     10include "../../ext/python_long.pxi"
     11include "../../ext/python_float.pxi"
     12include "../../ext/python_complex.pxi"
     13include "../../ext/python_number.pxi"
     14
     15from sage.libs.gmp.all cimport *
     16from sage.rings.integer cimport Integer
     17
     18cdef extern from "mpz_pylong.h":
     19    cdef mpz_get_pylong(mpz_t src)
     20    cdef mpz_get_pyintlong(mpz_t src)
     21    cdef int mpz_set_pylong(mpz_t dst, src) except -1
     22    cdef long mpz_pythonhash(mpz_t src)
     23
     24DEF ROUND_N = 0
     25DEF ROUND_F = 1
     26DEF ROUND_C = 2
     27DEF ROUND_D = 3
     28DEF ROUND_U = 4
     29DEF S_NORMAL = 0
     30DEF S_ZERO = 1
     31DEF S_NZERO = 2
     32DEF S_INF = 3
     33DEF S_NINF = 4
     34DEF S_NAN = 5
     35
     36from ext_impl cimport *
     37
     38import mpmath.rational as rationallib
     39import mpmath.libmp as libmp
     40import mpmath.function_docs as function_docs
     41from mpmath.libmp import to_str
     42from mpmath.libmp import repr_dps, prec_to_dps, dps_to_prec
     43
     44DEF OP_ADD = 0
     45DEF OP_SUB = 1
     46DEF OP_MUL = 2
     47DEF OP_DIV = 3
     48DEF OP_POW = 4
     49DEF OP_MOD = 5
     50DEF OP_RICHCMP = 6
     51DEF OP_EQ = (OP_RICHCMP + 2)
     52DEF OP_NE = (OP_RICHCMP + 3)
     53DEF OP_LT = (OP_RICHCMP + 0)
     54DEF OP_GT = (OP_RICHCMP + 4)
     55DEF OP_LE = (OP_RICHCMP + 1)
     56DEF OP_GE = (OP_RICHCMP + 5)
     57
     58cdef MPopts opts_exact
     59cdef MPopts opts_double_precision
     60cdef MPopts opts_mini_prec
     61
     62opts_exact.prec = 0
     63opts_exact.rounding = ROUND_N
     64opts_double_precision.prec = 53
     65opts_double_precision.rounding = ROUND_N
     66opts_mini_prec.prec = 5
     67opts_mini_prec.rounding = ROUND_D
     68
     69cdef MPF MPF_C_0
     70cdef MPF MPF_C_1
     71cdef MPF MPF_C_2
     72
     73MPF_init(&MPF_C_0); MPF_set_zero(&MPF_C_0);
     74MPF_init(&MPF_C_1); MPF_set_si(&MPF_C_1, 1);
     75MPF_init(&MPF_C_2); MPF_set_si(&MPF_C_2, 2);
     76
     77# Temporaries used for operands in binary operations
     78cdef mpz_t tmp_mpz
     79mpz_init(tmp_mpz)
     80
     81cdef MPF tmp1
     82cdef MPF tmp2
     83cdef MPF tmp_opx_re
     84cdef MPF tmp_opx_im
     85cdef MPF tmp_opy_re
     86cdef MPF tmp_opy_im
     87
     88MPF_init(&tmp1)
     89MPF_init(&tmp2)
     90MPF_init(&tmp_opx_re)
     91MPF_init(&tmp_opx_im)
     92MPF_init(&tmp_opy_re)
     93MPF_init(&tmp_opy_im)
     94
     95cdef class Context
     96cdef class mpnumber
     97cdef class mpf_base
     98cdef class mpf
     99cdef class mpc
     100cdef class constant
     101cdef class wrapped_libmp_function
     102cdef class wrapped_specfun
     103
     104
     105cdef int MPF_set_any(MPF *re, MPF *im, x, MPopts opts, bint str_tuple_ok) except -1:
     106    """
     107    Sets re + im*i = x, where x is any Python number.
     108
     109    Returns 0 if unable to coerce x; 1 if x is real and re was set;
     110    2 if x is complex and both re and im were set.
     111
     112    If str_tuple_ok=True, strings and tuples are accepted and converted
     113    (useful for parsing arguments, but not for arithmetic operands).
     114    """
     115    if PY_TYPE_CHECK(x, mpf):
     116        MPF_set(re, &(<mpf>x).value)
     117        return 1
     118    if PY_TYPE_CHECK(x, mpc):
     119        MPF_set(re, &(<mpc>x).re)
     120        MPF_set(im, &(<mpc>x).im)
     121        return 2
     122    if PyInt_Check(x) or PyLong_Check(x) or PY_TYPE_CHECK(x, Integer):
     123        MPF_set_int(re, x)
     124        return 1
     125    if PyFloat_Check(x):
     126        MPF_set_double(re, x)
     127        return 1
     128    if PyComplex_Check(x):
     129        MPF_set_double(re, x.real)
     130        MPF_set_double(im, x.imag)
     131        return 2
     132    if PY_TYPE_CHECK(x, constant):
     133        MPF_set_tuple(re, x.func(opts.prec, rndmode_to_python(opts.rounding)))
     134        return 1
     135    if hasattr(x, "_mpf_"):
     136        MPF_set_tuple(re, x._mpf_)
     137        return 1
     138    if hasattr(x, "_mpc_"):
     139        r, i = x._mpc_
     140        MPF_set_tuple(re, r)
     141        MPF_set_tuple(im, i)
     142        return 2
     143    if hasattr(x, "_mpmath_"):
     144        return MPF_set_any(re, im, x._mpmath_(opts.prec,
     145            rndmode_to_python(opts.rounding)), opts, False)
     146    if PY_TYPE_CHECK(x, rationallib.mpq):
     147        p, q = x
     148        MPF_set_tuple(re, libmp.from_rational(p, q, opts.prec,
     149            rndmode_to_python(opts.rounding)))
     150        return 1
     151    if str_tuple_ok:
     152        if PY_TYPE_CHECK(x, tuple):
     153            if len(x) == 2:
     154                MPF_set_man_exp(re, x[0], x[1])
     155                return 1
     156            elif len(x) == 4:
     157                MPF_set_tuple(re, x)
     158                return 1
     159        if isinstance(x, basestring):
     160            try:
     161                st = libmp.from_str(x, opts.prec,
     162                    rndmode_to_python(opts.rounding))
     163            except ValueError:
     164                return 0
     165            MPF_set_tuple(re, st)
     166            return 1
     167    return 0
     168
     169cdef binop(int op, x, y, MPopts opts):
     170    cdef int typx
     171    cdef int typy
     172    cdef MPF xre, xim, yre, yim
     173    cdef mpf rr
     174    cdef mpc rc
     175    cdef MPopts altopts
     176
     177    if PY_TYPE_CHECK(x, mpf):
     178        xre = (<mpf>x).value
     179        typx = 1
     180    elif PY_TYPE_CHECK(x, mpc):
     181        xre = (<mpc>x).re
     182        xim = (<mpc>x).im
     183        typx = 2
     184    else:
     185        typx = MPF_set_any(&tmp_opx_re, &tmp_opx_im, x, opts, False)
     186        if typx == 0:
     187            return NotImplemented
     188        xre = tmp_opx_re
     189        xim = tmp_opx_im
     190
     191    if PY_TYPE_CHECK(y, mpf):
     192        yre = (<mpf>y).value
     193        typy = 1
     194    elif PY_TYPE_CHECK(y, mpc):
     195        yre = (<mpc>y).re
     196        yim = (<mpc>y).im
     197        typy = 2
     198    else:
     199        typy = MPF_set_any(&tmp_opy_re, &tmp_opy_im, y, opts, False)
     200        if typy == 0:
     201            return NotImplemented
     202        yre = tmp_opy_re
     203        yim = tmp_opy_im
     204
     205    if op == OP_ADD:
     206        if typx == 1 and typy == 1:
     207            # Real result
     208            rr = PY_NEW(mpf)
     209            MPF_add(&rr.value, &xre, &yre, opts)
     210            return rr
     211        else:
     212            # Complex result
     213            rc = PY_NEW(mpc)
     214            MPF_add(&rc.re, &xre, &yre, opts)
     215            if typx == 1:
     216                MPF_set(&rc.im, &yim)
     217                #MPF_normalize(&rc.im, opts)
     218            elif typy == 1:
     219                MPF_set(&rc.im, &xim)
     220                #MPF_normalize(&rc.im, opts)
     221            else:
     222                MPF_add(&rc.im, &xim, &yim, opts)
     223            return rc
     224
     225    elif op == OP_SUB:
     226        if typx == 1 and typy == 1:
     227            # Real result
     228            rr = PY_NEW(mpf)
     229            MPF_sub(&rr.value, &xre, &yre, opts)
     230            return rr
     231        else:
     232            # Complex result
     233            rc = PY_NEW(mpc)
     234            MPF_sub(&rc.re, &xre, &yre, opts)
     235            if typx == 1:
     236                MPF_neg(&rc.im, &yim)
     237                MPF_normalize(&rc.im, opts)
     238            elif typy == 1:
     239                MPF_set(&rc.im, &xim)
     240                MPF_normalize(&rc.im, opts)
     241            else:
     242                MPF_sub(&rc.im, &xim, &yim, opts)
     243            return rc
     244
     245    elif op == OP_MUL:
     246        if typx == 1 and typy == 1:
     247            # Real result
     248            rr = PY_NEW(mpf)
     249            MPF_mul(&rr.value, &xre, &yre, opts)
     250            return rr
     251        else:
     252            # Complex result
     253            rc = PY_NEW(mpc)
     254            if typx == 1:
     255                MPF_mul(&rc.re, &yre, &xre, opts)
     256                MPF_mul(&rc.im, &yim, &xre, opts)
     257            elif typy == 1:
     258                MPF_mul(&rc.re, &xre, &yre, opts)
     259                MPF_mul(&rc.im, &xim, &yre, opts)
     260            else:
     261                # a*c - b*d
     262                MPF_mul(&rc.re, &xre,   &yre,  opts_exact)
     263                MPF_mul(&tmp1,  &xim,   &yim,  opts_exact)
     264                MPF_sub(&rc.re, &rc.re, &tmp1, opts)
     265                # a*d + b*c
     266                MPF_mul(&rc.im, &xre,   &yim,  opts_exact)
     267                MPF_mul(&tmp1,  &xim,   &yre,  opts_exact)
     268                MPF_add(&rc.im, &rc.im, &tmp1, opts)
     269            return rc
     270
     271    elif op == OP_DIV:
     272        if typx == 1 and typy == 1:
     273            # Real result
     274            rr = PY_NEW(mpf)
     275            MPF_div(&rr.value, &xre, &yre, opts)
     276            return rr
     277        else:
     278            rc = PY_NEW(mpc)
     279            if typy == 1:
     280                MPF_div(&rc.re, &xre, &yre, opts)
     281                MPF_div(&rc.im, &xim, &yre, opts)
     282            else:
     283                if typx == 1:
     284                    xim = MPF_C_0
     285                altopts = opts
     286                altopts.prec += 10
     287                # m = c*c + d*d
     288                MPF_mul(&tmp1, &yre,  &yre, opts_exact)
     289                MPF_mul(&tmp2, &yim,  &yim, opts_exact)
     290                MPF_add(&tmp1, &tmp1, &tmp2, altopts)
     291                # (a*c+b*d)/m
     292                MPF_mul(&rc.re, &xre,   &yre, opts_exact)
     293                MPF_mul(&tmp2,  &xim,   &yim, opts_exact)
     294                MPF_add(&rc.re, &rc.re, &tmp2, altopts)
     295                MPF_div(&rc.re, &rc.re, &tmp1, opts)
     296                # (b*c-a*d)/m
     297                MPF_mul(&rc.im, &xim,   &yre, opts_exact)
     298                MPF_mul(&tmp2,  &xre,   &yim, opts_exact)
     299                MPF_sub(&rc.im, &rc.im, &tmp2, altopts)
     300                MPF_div(&rc.im, &rc.im, &tmp1, opts)
     301            return rc
     302
     303    elif op == OP_POW:
     304        if typx == 1 and typy == 1:
     305            if yre.special == S_NORMAL and mpz_sgn(yre.exp) >= 0:
     306                # check if size is reasonable
     307                mpz_add_ui(tmp1.man, yre.exp, mpz_sizeinbase(yre.man,2))
     308                mpz_abs(tmp1.man, tmp1.man)
     309                if mpz_cmp_ui(tmp1.man, 10000) < 0:
     310                    # man * 2^exp
     311                    mpz_mul_2exp(tmp1.man, yre.man, mpz_get_ui(yre.exp))
     312                    rr = PY_NEW(mpf)
     313                    MPF_pow_int(&rr.value, &xre, tmp1.man, opts)
     314                    return rr
     315
     316        # TODO: optimize me
     317        xret = MPF_to_tuple(&xre)
     318        yret = MPF_to_tuple(&yre)
     319        if typx == 1 and typy == 1:
     320            try:
     321                v = libmp.mpf_pow(xret, yret, opts.prec, rndmode_to_python(opts.rounding))
     322                rr = PY_NEW(mpf)
     323                MPF_set_tuple(&rr.value, v)
     324                return rr
     325            except libmp.ComplexResult:
     326                xim = MPF_C_0
     327                yim = MPF_C_0
     328        else:
     329            if typx == 1: xim = MPF_C_0
     330            if typy == 1: yim = MPF_C_0
     331        ximt = MPF_to_tuple(&xim)
     332        yimt = MPF_to_tuple(&yim)
     333        vr, vi = libmp.mpc_pow((xret,ximt), (yret,yimt), opts.prec, rndmode_to_python(opts.rounding))
     334        rc = PY_NEW(mpc)
     335        MPF_set_tuple(&rc.re, vr)
     336        MPF_set_tuple(&rc.im, vi)
     337        return rc
     338
     339    elif op == OP_MOD:
     340        if typx != 1 or typx != 1:
     341            raise TypeError("mod for complex numbers")
     342        xret = MPF_to_tuple(&xre)
     343        yret = MPF_to_tuple(&yre)
     344        v = libmp.mpf_mod(xret, yret, opts.prec, rndmode_to_python(opts.rounding))
     345        rr = PY_NEW(mpf)
     346        MPF_set_tuple(&rr.value, v)
     347        return rr
     348
     349    elif op == OP_EQ:
     350        if typx == 1 and typy == 1:
     351            return MPF_eq(&xre, &yre)
     352        if typx == 1:
     353            return MPF_eq(&xre, &yre) and MPF_eq(&yim, &MPF_C_0)
     354        if typy == 1:
     355            return MPF_eq(&xre, &yre) and MPF_eq(&xim, &MPF_C_0)
     356        return MPF_eq(&xre, &yre) and MPF_eq(&xim, &yim)
     357
     358    elif op == OP_NE:
     359        if typx == 1 and typy == 1:
     360            return MPF_ne(&xre, &yre)
     361        if typx == 1:
     362            return MPF_ne(&xre, &yre) or MPF_ne(&yim, &MPF_C_0)
     363        if typy == 1:
     364            return MPF_ne(&xre, &yre) or MPF_ne(&xim, &MPF_C_0)
     365        return MPF_ne(&xre, &yre) or MPF_ne(&xim, &yim)
     366
     367    elif op == OP_LT:
     368        if typx != 1 or typy != 1:
     369            raise ValueError("cannot compare complex numbers")
     370        return MPF_lt(&xre, &yre)
     371
     372    elif op == OP_GT:
     373        if typx != 1 or typy != 1:
     374            raise ValueError("cannot compare complex numbers")
     375        return MPF_gt(&xre, &yre)
     376
     377    elif op == OP_LE:
     378        if typx != 1 or typy != 1:
     379            raise ValueError("cannot compare complex numbers")
     380        return MPF_le(&xre, &yre)
     381
     382    elif op == OP_GE:
     383        if typx != 1 or typy != 1:
     384            raise ValueError("cannot compare complex numbers")
     385        return MPF_ge(&xre, &yre)
     386
     387    return NotImplemented
     388
     389
     390cdef MPopts global_opts
     391
     392global_context = None
     393
     394cdef class Context:
     395    cdef public mpf, mpc, constant #, def_mp_function
     396    cdef public trap_complex
     397    cdef public pretty
     398
     399    def __cinit__(ctx):
     400        global global_opts, global_context
     401        global_opts = opts_double_precision
     402        global_context = ctx
     403        ctx.mpf = mpf
     404        ctx.mpc = mpc
     405        ctx.constant = constant
     406        #ctx.def_mp_function = def_mp_function
     407        ctx._mpq = rationallib.mpq
     408
     409    def default(ctx):
     410        global_opts = opts_double_precision
     411        ctx.trap_complex = False
     412        ctx.pretty = False
     413
     414    def _get_prec(ctx): return global_opts.prec
     415    def _set_prec(ctx, prec): global_opts.prec = prec
     416    def _set_dps(ctx, n): global_opts.prec = dps_to_prec(int(n))
     417    def _get_dps(ctx): return libmp.prec_to_dps(global_opts.prec)
     418    dps = property(_get_dps, _set_dps)
     419    prec = property(_get_prec, _set_prec)
     420    _dps = property(_get_dps, _set_dps)
     421    _prec = property(_get_prec, _set_prec)
     422
     423    def _get_prec_rounding(ctx):
     424        return global_opts.prec, rndmode_to_python(global_opts.rounding)
     425
     426    _prec_rounding = property(_get_prec_rounding)
     427
     428    cpdef mpf make_mpf(ctx, tuple v):
     429        cdef mpf x
     430        x = PY_NEW(mpf)
     431        MPF_set_tuple(&x.value, v)
     432        return x
     433
     434    cpdef mpc make_mpc(ctx, tuple v):
     435        cdef mpc x
     436        x = PY_NEW(mpc)
     437        MPF_set_tuple(&x.re, v[0])
     438        MPF_set_tuple(&x.im, v[1])
     439        return x
     440
     441    def convert(ctx, x, strings=True):
     442        """
     443        Converts *x* to an ``mpf``, ``mpc`` or ``mpi``. If *x* is of type ``mpf``,
     444        ``mpc``, ``int``, ``float``, ``complex``, the conversion
     445        will be performed losslessly.
     446
     447        If *x* is a string, the result will be rounded to the present
     448        working precision. Strings representing fractions or complex
     449        numbers are permitted.
     450
     451            >>> from mpmath import *
     452            >>> mp.dps = 15; mp.pretty = False
     453            >>> mpmathify(3.5)
     454            mpf('3.5')
     455            >>> mpmathify('2.1')
     456            mpf('2.1000000000000001')
     457            >>> mpmathify('3/4')
     458            mpf('0.75')
     459            >>> mpmathify('2+3j')
     460            mpc(real='2.0', imag='3.0')
     461
     462        """
     463        cdef mpf rr
     464        cdef mpc rc
     465        if PY_TYPE_CHECK(x, mpnumber):
     466            return x
     467        typx = MPF_set_any(&tmp_opx_re, &tmp_opx_im, x, global_opts, strings)
     468        if typx == 1:
     469            rr = PY_NEW(mpf)
     470            MPF_set(&rr.value, &tmp_opx_re)
     471            return rr
     472        if typx == 2:
     473            rc = PY_NEW(mpc)
     474            MPF_set(&rc.re, &tmp_opx_re)
     475            MPF_set(&rc.im, &tmp_opx_im)
     476            return rc
     477        return ctx._convert_fallback(x, strings)
     478
     479    def isnan(ctx, x):
     480        """
     481        For an ``mpf`` *x*, determines whether *x* is not-a-number (nan)::
     482
     483            >>> from mpmath import *
     484            >>> isnan(nan), isnan(3)
     485            (True, False)
     486        """
     487        cdef int typ
     488        if PY_TYPE_CHECK(x, mpf):
     489            return (<mpf>x).value.special == S_NAN
     490        #else:
     491        #    typ = MPF_set_any(&tmp_opx_re, &tmp_opx_im, x, global_opts, 0)
     492        #    if typ == 1:
     493        #        return tmp_opx_re.special == S_NAN
     494        return False
     495
     496    def isinf(ctx, x):
     497        """
     498        For an ``mpf`` *x*, determines whether *x* is infinite::
     499
     500            >>> from mpmath import *
     501            >>> isinf(inf), isinf(-inf), isinf(3)
     502            (True, True, False)
     503        """
     504        cdef int s, typ
     505        if PY_TYPE_CHECK(x, mpf):
     506            s = (<mpf>x).value.special
     507            return s == S_INF or s == S_NINF
     508        #else:
     509        #    typ = MPF_set_any(&tmp_opx_re, &tmp_opx_im, x, global_opts, 0)
     510        #    if typ == 1:
     511        #        s = tmp_opx_re.special
     512        #        return s == S_INF or s == S_NINF
     513        return False
     514
     515    def isint(ctx, x):
     516        """
     517        For an ``mpf`` *x*, or any type that can be converted
     518        to ``mpf``, determines whether *x* is exactly
     519        integer-valued::
     520
     521            >>> from mpmath import *
     522            >>> isint(3), isint(mpf(3)), isint(3.2)
     523            (True, True, False)
     524        """
     525        cdef MPF v
     526        cdef int typ
     527        if PyInt_CheckExact(x) or PyLong_CheckExact(x) or PY_TYPE_CHECK(x, Integer):
     528            return True
     529        if PY_TYPE_CHECK(x, mpf):
     530            v = (<mpf>x).value
     531            return v.special == S_ZERO or (v.special == S_NORMAL and mpz_sgn(v.exp) >= 0)
     532        if PY_TYPE_CHECK(x, mpc):
     533            return False
     534        if PY_TYPE_CHECK(x, rationallib.mpq):
     535            p, q = x
     536            return not (p % q)
     537        typ = MPF_set_any(&tmp_opx_re, &tmp_opx_im, x, global_opts, 0)
     538        if typ == 1:
     539            v = tmp_opx_re
     540            return v.special == S_ZERO or (v.special == S_NORMAL and mpz_sgn(v.exp) >= 0)
     541        return False
     542
     543    def fsum(ctx, terms, bint absolute=False, bint squared=False):
     544        """
     545        Calculates a sum containing a finite number of terms (for infinite
     546        series, see :func:`nsum`). The terms will be converted to
     547        mpmath numbers. For len(terms) > 2, this function is generally
     548        faster and produces more accurate results than the builtin
     549        Python function :func:`sum`.
     550
     551            >>> from mpmath import *
     552            >>> mp.dps = 15; mp.pretty = False
     553            >>> fsum([1, 2, 0.5, 7])
     554            mpf('10.5')
     555
     556        With squared=True each term is squared, and with absolute=True
     557        the absolute value of each term is used.
     558        """
     559        cdef MPF sre, sim, tre, tim, tmp
     560        cdef mpf rr
     561        cdef mpc rc
     562        cdef MPopts workopts
     563        cdef int styp, ttyp
     564        MPF_init(&sre)
     565        MPF_init(&sim)
     566        MPF_init(&tre)
     567        MPF_init(&tim)
     568        MPF_init(&tmp)
     569        workopts = global_opts
     570        workopts.prec = workopts.prec * 2 + 50
     571        workopts.rounding = ROUND_D
     572        unknown = global_context.zero
     573        _sig_on
     574        styp = 1
     575        for term in terms:
     576            ttyp = MPF_set_any(&tre, &tim, term, workopts, 0)
     577            if ttyp == 0:
     578                if absolute: term = ctx.absmax(term)
     579                if squared: term = term**2
     580                unknown += term
     581                continue
     582            if absolute:
     583                if squared:
     584                    if ttyp == 1:
     585                        MPF_mul(&tre, &tre, &tre, opts_exact)
     586                        MPF_add(&sre, &sre, &tre, workopts)
     587                    elif ttyp == 2:
     588                        # |(a+bi)^2| = a^2+b^2
     589                        MPF_mul(&tre, &tre, &tre, opts_exact)
     590                        MPF_add(&sre, &sre, &tre, workopts)
     591                        MPF_mul(&tim, &tim, &tim, opts_exact)
     592                        MPF_add(&sre, &sre, &tim, workopts)
     593                else:
     594                    if ttyp == 1:
     595                        MPF_abs(&tre, &tre)
     596                        MPF_add(&sre, &sre, &tre, workopts)
     597                    elif ttyp == 2:
     598                        # |a+bi| = sqrt(a^2+b^2)
     599                        MPF_mul(&tre, &tre, &tre, opts_exact)
     600                        MPF_mul(&tim, &tim, &tim, opts_exact)
     601                        MPF_add(&tre, &tre, &tim, workopts)
     602                        MPF_sqrt(&tre, &tre, workopts)
     603                        MPF_add(&sre, &sre, &tre, workopts)
     604            elif squared:
     605                if ttyp == 1:
     606                    MPF_mul(&tre, &tre, &tre, opts_exact)
     607                    MPF_add(&sre, &sre, &tre, workopts)
     608                elif ttyp == 2:
     609                    # (a+bi)^2 = a^2-b^2 + 2i*ab
     610                    MPF_mul(&tmp, &tre, &tim, opts_exact)
     611                    MPF_mul(&tmp, &tmp, &MPF_C_2, opts_exact)
     612                    MPF_add(&sim, &sim, &tmp, workopts)
     613                    MPF_mul(&tre, &tre, &tre, opts_exact)
     614                    MPF_add(&sre, &sre, &tre, workopts)
     615                    MPF_mul(&tim, &tim, &tim, opts_exact)
     616                    MPF_sub(&sre, &sre, &tim, workopts)
     617                    styp = 2
     618            else:
     619                if ttyp == 1:
     620                    MPF_add(&sre, &sre, &tre, workopts)
     621                elif ttyp == 2:
     622                    MPF_add(&sre, &sre, &tre, workopts)
     623                    MPF_add(&sim, &sim, &tim, workopts)
     624                    styp = 2
     625        MPF_clear(&tre)
     626        MPF_clear(&tim)
     627        if styp == 1:
     628            rr = PY_NEW(mpf)
     629            MPF_set(&rr.value, &sre)
     630            MPF_clear(&sre)
     631            MPF_clear(&sim)
     632            MPF_normalize(&rr.value, global_opts)
     633            if unknown is not global_context.zero:
     634                return ctx._stupid_add(rr, unknown)
     635            return rr
     636        elif styp == 2:
     637            rc = PY_NEW(mpc)
     638            MPF_set(&rc.re, &sre)
     639            MPF_set(&rc.im, &sim)
     640            MPF_clear(&sre)
     641            MPF_clear(&sim)
     642            MPF_normalize(&rc.re, global_opts)
     643            MPF_normalize(&rc.im, global_opts)
     644            if unknown is not global_context.zero:
     645                return ctx._stupid_add(rc, unknown)
     646            return rc
     647        else:
     648            MPF_clear(&sre)
     649            MPF_clear(&sim)
     650            return +unknown
     651
     652    def fdot(ctx, A, B=None):
     653        r"""
     654        Computes the dot product of the iterables `A` and `B`,
     655
     656        .. math ::
     657
     658            \sum_{k=0} A_k B_k.
     659
     660        Alternatively, :func:`fdot` accepts a single iterable of pairs.
     661        In other words, ``fdot(A,B)`` and ``fdot(zip(A,B))`` are equivalent.
     662
     663        The elements are automatically converted to mpmath numbers.
     664
     665        Examples::
     666
     667            >>> from mpmath import *
     668            >>> mp.dps = 15; mp.pretty = False
     669            >>> A = [2, 1.5, 3]
     670            >>> B = [1, -1, 2]
     671            >>> fdot(A, B)
     672            mpf('6.5')
     673            >>> zip(A, B)
     674            [(2, 1), (1.5, -1), (3, 2)]
     675            >>> fdot(_)
     676            mpf('6.5')
     677        """
     678        if B:
     679            A = zip(A, B)
     680        cdef MPF sre, sim, tre, tim, ure, uim, tmp
     681        cdef mpf rr
     682        cdef mpc rc
     683        cdef MPopts workopts
     684        cdef int styp, ttyp, utyp
     685        MPF_init(&sre)
     686        MPF_init(&sim)
     687        MPF_init(&tre)
     688        MPF_init(&tim)
     689        MPF_init(&ure)
     690        MPF_init(&uim)
     691        MPF_init(&tmp)
     692        workopts = global_opts
     693        workopts.prec = workopts.prec * 2 + 50
     694        workopts.rounding = ROUND_D
     695        unknown = global_context.zero
     696        styp = 1
     697        for a, b in A:
     698            ttyp = MPF_set_any(&tre, &tim, a, workopts, 0)
     699            utyp = MPF_set_any(&ure, &uim, b, workopts, 0)
     700            if ttyp == 0 or utyp == 0:
     701                unknown += a * b
     702                continue
     703            styp = max(styp, ttyp)
     704            styp = max(styp, utyp)
     705            if ttyp == 1:
     706                if utyp == 1:
     707                    MPF_mul(&tre, &tre, &ure, opts_exact)
     708                    MPF_add(&sre, &sre, &tre, workopts)
     709                elif utyp == 2:
     710                    MPF_mul(&ure, &ure, &tre, opts_exact)
     711                    MPF_mul(&uim, &uim, &tre, opts_exact)
     712                    MPF_add(&sre, &sre, &ure, workopts)
     713                    MPF_add(&sim, &sim, &uim, workopts)
     714                    styp = 2
     715            elif ttyp == 2:
     716                styp = 2
     717                if utyp == 1:
     718                    MPF_mul(&tre, &tre, &ure, opts_exact)
     719                    MPF_mul(&tim, &tim, &ure, opts_exact)
     720                    MPF_add(&sre, &sre, &tre, workopts)
     721                    MPF_add(&sim, &sim, &tim, workopts)
     722                elif utyp == 2:
     723                    MPF_mul(&tmp, &tre, &ure, opts_exact)
     724                    MPF_add(&sre, &sre, &tmp, workopts)
     725                    MPF_mul(&tmp, &tim, &uim, opts_exact)
     726                    MPF_sub(&sre, &sre, &tmp, workopts)
     727                    MPF_mul(&tmp, &tim, &ure, opts_exact)
     728                    MPF_add(&sim, &sim, &tmp, workopts)
     729                    MPF_mul(&tmp, &tre, &uim, opts_exact)
     730                    MPF_add(&sim, &sim, &tmp, workopts)
     731        MPF_clear(&tre)
     732        MPF_clear(&tim)
     733        MPF_clear(&ure)
     734        MPF_clear(&uim)
     735        if styp == 1:
     736            rr = PY_NEW(mpf)
     737            MPF_set(&rr.value, &sre)
     738            MPF_clear(&sre)
     739            MPF_clear(&sim)
     740            MPF_normalize(&rr.value, global_opts)
     741            if unknown is not global_context.zero:
     742                return ctx._stupid_add(rr, unknown)
     743            return rr
     744        elif styp == 2:
     745            rc = PY_NEW(mpc)
     746            MPF_set(&rc.re, &sre)
     747            MPF_set(&rc.im, &sim)
     748            MPF_clear(&sre)
     749            MPF_clear(&sim)
     750            MPF_normalize(&rc.re, global_opts)
     751            MPF_normalize(&rc.im, global_opts)
     752            if unknown is not global_context.zero:
     753                return ctx._stupid_add(rc, unknown)
     754            return rc
     755        else:
     756            MPF_clear(&sre)
     757            MPF_clear(&sim)
     758            return +unknown
     759
     760    # Doing a+b directly doesn't work with mpi, presumably due to
     761    # Cython trying to be clever with the operation resolution
     762    cdef _stupid_add(ctx, a, b):
     763        return a + b
     764
     765    def _convert_param(ctx, x):
     766        cdef MPF v
     767        cdef bint ismpf, ismpc
     768        if PyInt_Check(x) or PyLong_Check(x) or PY_TYPE_CHECK(x, Integer):
     769            return int(x), 'Z'
     770        if PY_TYPE_CHECK(x, tuple):
     771            p, q = x
     772            p = int(p)
     773            q = int(q)
     774            if not p % q:
     775                return p // q, 'Z'
     776            return rationallib.mpq((p,q)), 'Q'
     777        if PY_TYPE_CHECK(x, basestring) and '/' in x:
     778            p, q = x.split('/')
     779            p = int(p)
     780            q = int(q)
     781            if not p % q:
     782                return p // q, 'Z'
     783            return rationallib.mpq((p,q)), 'Q'
     784        if PY_TYPE_CHECK(x, constant):
     785            return x, 'R'
     786        ismpf = PY_TYPE_CHECK(x, mpf)
     787        ismpc = PY_TYPE_CHECK(x, mpc)
     788        if not (ismpf or ismpc):
     789            x = global_context.convert(x)
     790            ismpf = PY_TYPE_CHECK(x, mpf)
     791            ismpc = PY_TYPE_CHECK(x, mpc)
     792            if not (ismpf or ismpc):
     793                return x, 'U'
     794        if ismpf:
     795            v = (<mpf>x).value
     796        elif ismpc:
     797            if (<mpc>x).im.special != S_ZERO:
     798                return x, 'C'
     799            x = (<mpc>x).real
     800            v = (<mpc>x).re
     801        # A real number
     802        if v.special == S_ZERO:
     803            return 0, 'Z'
     804        if v.special != S_NORMAL:
     805            return x, 'U'
     806        if mpz_sgn(v.exp) >= 0:
     807            return (mpzi(v.man) << mpzi(v.exp)), 'Z'
     808        if mpz_cmp_si(v.exp, -4) > 0:
     809            p = mpzi(v.man)
     810            q = 1 << (-mpzi(v.exp))
     811            return rationallib.mpq((p,q)), 'Q'
     812        return x, 'R'
     813
     814    def mag(ctx, x):
     815        """
     816        Quick logarithmic magnitude estimate of a number.
     817        Returns an integer or infinity `m` such that `|x| <= 2^m`.
     818        It is not guaranteed that `m` is an optimal bound,
     819        but it will never be off by more than 2 (and probably not
     820        more than 1).
     821        """
     822        cdef int typ
     823        if PyInt_Check(x) or PyLong_Check(x) or PY_TYPE_CHECK(x, Integer):
     824            mpz_set_integer(tmp_opx_re.man, x)
     825            if mpz_sgn(tmp_opx_re.man) == 0:
     826                return global_context.ninf
     827            else:
     828                return mpz_sizeinbase(tmp_opx_re.man,2)
     829        if PY_TYPE_CHECK(x, rationallib.mpq):
     830            p, q = x
     831            mpz_set_integer(tmp_opx_re.man, int(p))
     832            if mpz_sgn(tmp_opx_re.man) == 0:
     833                return global_context.ninf
     834            mpz_set_integer(tmp_opx_re.exp, int(q))
     835            return 1 + mpz_sizeinbase(tmp_opx_re.man,2) + mpz_sizeinbase(tmp_opx_re.exp,2)
     836        typ = MPF_set_any(&tmp_opx_re, &tmp_opx_im, x, global_opts, False)
     837        if typ == 1:
     838            if tmp_opx_re.special == S_ZERO:
     839                return global_context.ninf
     840            if tmp_opx_re.special == S_INF or tmp_opx_re.special == S_NINF:
     841                return global_context.inf
     842            if tmp_opx_re.special != S_NORMAL:
     843                return global_context.nan
     844            mpz_add_ui(tmp_opx_re.exp, tmp_opx_re.exp, mpz_sizeinbase(tmp_opx_re.man, 2))
     845            return mpzi(tmp_opx_re.exp)
     846        if typ == 2:
     847            if tmp_opx_re.special == S_NAN or tmp_opx_im.special == S_NAN:
     848                return global_context.nan
     849            if tmp_opx_re.special == S_INF or tmp_opx_im.special == S_NINF or \
     850               tmp_opx_im.special == S_INF or tmp_opx_im.special == S_NINF:
     851                return global_context.inf
     852            if tmp_opx_re.special == S_ZERO:
     853                if tmp_opx_im.special == S_ZERO:
     854                    return global_context.ninf
     855                else:
     856                    mpz_add_ui(tmp_opx_im.exp, tmp_opx_im.exp, mpz_sizeinbase(tmp_opx_im.man, 2))
     857                    return mpzi(tmp_opx_im.exp)
     858            elif tmp_opx_im.special == S_ZERO:
     859                mpz_add_ui(tmp_opx_re.exp, tmp_opx_re.exp, mpz_sizeinbase(tmp_opx_re.man, 2))
     860                return mpzi(tmp_opx_re.exp)
     861            mpz_add_ui(tmp_opx_im.exp, tmp_opx_im.exp, mpz_sizeinbase(tmp_opx_im.man, 2))
     862            mpz_add_ui(tmp_opx_re.exp, tmp_opx_re.exp, mpz_sizeinbase(tmp_opx_re.man, 2))
     863            if mpz_cmp(tmp_opx_re.exp, tmp_opx_im.exp) >= 0:
     864                mpz_add_ui(tmp_opx_re.exp, tmp_opx_re.exp, 1)
     865                return mpzi(tmp_opx_re.exp)
     866            else:
     867                mpz_add_ui(tmp_opx_im.exp, tmp_opx_im.exp, 1)
     868                return mpzi(tmp_opx_im.exp)
     869        raise TypeError("requires an mpf/mpc")
     870
     871    def _wrap_libmp_function(ctx, mpf_f, mpc_f=None, mpi_f=None, doc="<no doc>"):
     872        name = mpf_f.__name__[4:]
     873        doc = function_docs.__dict__.get(name, "Computes the %s of x" % doc)
     874        # workaround lack of closures in Cython
     875        f_cls = type(name, (wrapped_libmp_function,), {'__doc__':doc})
     876        f = f_cls(mpf_f, mpc_f, mpi_f, doc)
     877        return f
     878
     879    @classmethod
     880    def _wrap_specfun(cls, name, f, wrap):
     881        doc = function_docs.__dict__.get(name, "<no doc>")
     882        if wrap:
     883            # workaround lack of closures in Cython
     884            f_wrapped_cls = type(name, (wrapped_specfun,), {'__doc__':doc})
     885            f_wrapped = f_wrapped_cls(name, f)
     886        else:
     887            f_wrapped = f
     888        f_wrapped.__doc__ = doc
     889        setattr(cls, name, f_wrapped)
     890
     891
     892cdef class wrapped_libmp_function:
     893
     894    cdef public mpf_f, mpc_f, mpi_f, name, __doc__
     895
     896    def __init__(self, mpf_f, mpc_f=None, mpi_f=None, doc="<no doc>"):
     897        self.mpf_f = mpf_f
     898        self.mpc_f = mpc_f
     899        self.mpi_f = mpi_f
     900        self.name = mpf_f.__name__[4:]
     901        self.__doc__ = function_docs.__dict__.get(self.name, "Computes the %s of x" % doc)
     902
     903    def __call__(self, x, **kwargs):
     904        cdef int typx
     905        cdef tuple rev, imv, reu, cxu
     906        cdef mpf rr
     907        cdef mpc rc
     908        prec = global_opts.prec
     909        rounding = rndmode_to_python(global_opts.rounding)
     910        if kwargs:
     911            if 'prec' in kwargs: prec = int(kwargs['prec'])
     912            if 'dps'  in kwargs: prec = libmp.dps_to_prec(int(kwargs['dps']))
     913            if 'rounding' in kwargs: rounding = kwargs['rounding']
     914        typx = MPF_set_any(&tmp_opx_re, &tmp_opx_im, x, global_opts, 1)
     915        if typx == 1:
     916            rev = MPF_to_tuple(&tmp_opx_re)
     917            try:
     918                reu = self.mpf_f(rev, prec, rounding)
     919                rr = PY_NEW(mpf)
     920                MPF_set_tuple(&rr.value, reu)
     921                return rr
     922            except libmp.ComplexResult:
     923                if global_context.trap_complex:
     924                    raise
     925                cxu = self.mpc_f((rev, libmp.fzero), prec, rounding)
     926                rc = PY_NEW(mpc)
     927                MPF_set_tuple(&rc.re, cxu[0])
     928                MPF_set_tuple(&rc.im, cxu[1])
     929                return rc
     930        if typx == 2:
     931            rev = MPF_to_tuple(&tmp_opx_re)
     932            imv = MPF_to_tuple(&tmp_opx_im)
     933            cxu = self.mpc_f((rev, imv), prec, rounding)
     934            rc = PY_NEW(mpc)
     935            MPF_set_tuple(&rc.re, cxu[0])
     936            MPF_set_tuple(&rc.im, cxu[1])
     937            return rc
     938        x = global_context.convert(x)
     939        if hasattr(x, "_mpi_"):
     940            if self.mpi_f:
     941                return global_context.make_mpi(self.mpi_f(x._mpi_, prec))
     942        raise NotImplementedError("%s of a %s" % (self.name, type(x)))
     943
     944
     945
     946cdef class wrapped_specfun:
     947    cdef public f, name, __doc__
     948
     949    def __init__(self, name, f):
     950        self.name = name
     951        self.f = f
     952        self.__doc__ = function_docs.__dict__.get(name, "<no doc>")
     953
     954    def __call__(self, *args, **kwargs):
     955        cdef int origprec
     956        args = [global_context.convert(a) for a in args]
     957        origprec = global_opts.prec
     958        global_opts.prec += 10
     959        try:
     960            retval = self.f(global_context, *args, **kwargs)
     961        finally:
     962            global_opts.prec = origprec
     963        return +retval
     964
     965
     966cdef class mpnumber:
     967    def __richcmp__(self, other, int op):
     968        return binop(OP_RICHCMP+op, self, other, global_opts)
     969    def __add__(self, other): return binop(OP_ADD, self, other, global_opts)
     970    def __sub__(self, other): return binop(OP_SUB, self, other, global_opts)
     971    def __mul__(self, other): return binop(OP_MUL, self, other, global_opts)
     972    def __div__(self, other): return binop(OP_DIV, self, other, global_opts)
     973    def __mod__(self, other): return binop(OP_MOD, self, other, global_opts)
     974    def __pow__(self, other, mod):
     975        if mod is not None:
     976            raise ValueError("three-argument pow not supported")
     977        return binop(OP_POW, self, other, global_opts)
     978    def ae(s, t, rel_eps=None, abs_eps=None):
     979        return global_context.almosteq(s, t, rel_eps, abs_eps)
     980
     981
     982cdef class mpf_base(mpnumber):
     983
     984    # Shared methods for mpf, constant. However, somehow some methods
     985    # (hash?, __richcmp__?) aren't inerited, so they have to
     986    # be defined multiple times. TODO: fix this.
     987
     988    def __hash__(self):
     989        return libmp.mpf_hash(self._mpf_)
     990
     991    def __repr__(self):
     992        if global_context.pretty:
     993            return self.__str__()
     994        n = repr_dps(global_opts.prec)
     995        return "mpf('%s')" % to_str(self._mpf_, n)
     996
     997    def __str__(self):
     998        return to_str(self._mpf_, global_context._str_digits)
     999
     1000    @property
     1001    def real(self): return self
     1002
     1003    @property
     1004    def imag(self): return global_context.zero
     1005
     1006    def conjugate(self): return self
     1007
     1008    @property
     1009    def man(self): return self._mpf_[1]
     1010    @property
     1011    def exp(self): return self._mpf_[2]
     1012    @property
     1013    def bc(self): return self._mpf_[3]
     1014
     1015    # XXX: optimize
     1016    def __int__(self): return int(libmp.to_int(self._mpf_))
     1017    def __long__(self): return long(self.__int__())
     1018    def __float__(self): return libmp.to_float(self._mpf_)
     1019    def __complex__(self): return complex(float(self))
     1020    def to_fixed(self, prec): return libmp.to_fixed(self._mpf_, prec)
     1021    def __getstate__(self): return libmp.to_pickable(self._mpf_)
     1022    def __setstate__(self, val): self._mpf_ = libmp.from_pickable(val)
     1023
     1024
     1025cdef class mpf(mpf_base):
     1026    """
     1027    An mpf instance holds a real-valued floating-point number. mpf:s
     1028    work analogously to Python floats, but support arbitrary-precision
     1029    arithmetic.
     1030    """
     1031
     1032    cdef MPF value
     1033
     1034    def __init__(self, x=0, **kwargs):
     1035        cdef MPopts opts
     1036        opts = global_opts
     1037        if kwargs:
     1038            if 'prec' in kwargs: opts.prec = int(kwargs['prec'])
     1039            if 'dps'  in kwargs: opts.prec = libmp.dps_to_prec(int(kwargs['dps']))
     1040            if 'rounding' in kwargs: opts.rounding = rndmode_from_python(kwargs['rounding'])
     1041        if MPF_set_any(&self.value, &self.value, x, opts, 1) != 1:
     1042            raise TypeError
     1043
     1044    def __reduce__(self): return (mpf, (), self._mpf_)
     1045    def _get_mpf(self): return MPF_to_tuple(&self.value)
     1046    def _set_mpf(self, v): MPF_set_tuple(&self.value, v)
     1047    _mpf_ = property(_get_mpf, _set_mpf)
     1048
     1049    def __nonzero__(self): return self.value.special != S_ZERO
     1050    def __hash__(self): return libmp.mpf_hash(self._mpf_)
     1051
     1052    @property
     1053    def real(self): return self
     1054
     1055    @property
     1056    def imag(self): return global_context.zero
     1057
     1058    def conjugate(self): return self
     1059
     1060    @property
     1061    def man(self): return self._mpf_[1]
     1062
     1063    @property
     1064    def exp(self): return self._mpf_[2]
     1065
     1066    @property
     1067    def bc(self): return self._mpf_[3]
     1068
     1069    def to_fixed(self, long prec):
     1070        # return libmp.to_fixed(self._mpf_, prec)
     1071        MPF_to_fixed(tmp_mpz, &self.value, prec, False)
     1072        cdef Integer r
     1073        r = PY_NEW(Integer)
     1074        mpz_set(r.value, tmp_mpz)
     1075        return r
     1076
     1077    def __int__(self):
     1078        MPF_to_fixed(tmp_mpz, &self.value, 0, True)
     1079        return mpzi(tmp_mpz)
     1080
     1081    def __float__(self):
     1082        return MPF_to_double(&self.value, False)
     1083
     1084    def __getstate__(self):
     1085        return libmp.to_pickable(self._mpf_)
     1086
     1087    def __setstate__(self, val):
     1088        self._mpf_ = libmp.from_pickable(val)
     1089
     1090    def __cinit__(self):
     1091        MPF_init(&self.value)
     1092
     1093    def  __dealloc__(self):
     1094        MPF_clear(&self.value)
     1095
     1096    def __neg__(s):
     1097        cdef mpf r = PY_NEW(mpf)
     1098        MPF_neg(&r.value, &s.value)
     1099        MPF_normalize(&r.value, global_opts)
     1100        return r
     1101
     1102    def __pos__(s):
     1103        cdef mpf r = PY_NEW(mpf)
     1104        MPF_set(&r.value, &s.value)
     1105        MPF_normalize(&r.value, global_opts)
     1106        return r
     1107
     1108    def __abs__(s):
     1109        cdef mpf r = PY_NEW(mpf)
     1110        MPF_abs(&r.value, &s.value)
     1111        MPF_normalize(&r.value, global_opts)
     1112        return r
     1113
     1114    def sqrt(s):
     1115        cdef mpf r = PY_NEW(mpf)
     1116        MPF_sqrt(&r.value, &s.value, global_opts)
     1117        return r
     1118
     1119    def __richcmp__(self, other, int op):
     1120        return binop(OP_RICHCMP+op, self, other, global_opts)
     1121
     1122
     1123
     1124cdef class constant(mpf_base):
     1125    """
     1126    Represents a mathematical constant with dynamic precision.
     1127    When printed or used in an arithmetic operation, a constant
     1128    is converted to a regular mpf at the working precision. A
     1129    regular mpf can also be obtained using the operation +x.
     1130    """
     1131
     1132    cdef public name, func, __doc__
     1133
     1134    def __init__(self, func, name, docname=''):
     1135        self.name = name
     1136        self.func = func
     1137        self.__doc__ = getattr(function_docs, docname, '')
     1138
     1139    def __call__(self, prec=None, dps=None, rounding=None):
     1140        prec2 = global_opts.prec
     1141        rounding2 = rndmode_to_python(global_opts.rounding)
     1142        if not prec: prec = prec2
     1143        if not rounding: rounding = rounding2
     1144        if dps: prec = dps_to_prec(dps)
     1145        return global_context.make_mpf(self.func(prec, rounding))
     1146
     1147    @property
     1148    def _mpf_(self):
     1149        prec = global_opts.prec
     1150        rounding = rndmode_to_python(global_opts.rounding)
     1151        return self.func(prec, rounding)
     1152
     1153    def __repr__(self):
     1154        if global_context.pretty:
     1155            return self.__str__()
     1156        return "<%s: %s~>" % (self.name, global_context.nstr(self))
     1157
     1158    def __nonzero__(self): return self._mpf_ != libmp.fzero
     1159    def __neg__(self): return -mpf(self)
     1160    def __pos__(self): return mpf(self)
     1161    def __abs__(self): return abs(mpf(self))
     1162    def sqrt(self): return mpf(self).sqrt()
     1163
     1164    # XXX: optimize
     1165    def to_fixed(self, prec):
     1166        return libmp.to_fixed(self._mpf_, prec)
     1167    def __getstate__(self):
     1168        return libmp.to_pickable(self._mpf_)
     1169    def __setstate__(self, val):
     1170        self._mpf_ = libmp.from_pickable(val)
     1171
     1172    # WHY?
     1173    def __hash__(self):
     1174        return libmp.mpf_hash(self._mpf_)
     1175    def __richcmp__(self, other, int op):
     1176        return binop(OP_RICHCMP+op, self, other, global_opts)
     1177
     1178
     1179cdef class mpc(mpnumber):
     1180    """
     1181    An mpc represents a complex number using a pair of mpf:s (one
     1182    for the real part and another for the imaginary part.) The mpc
     1183    class behaves fairly similarly to Python's complex type.
     1184    """
     1185
     1186    cdef MPF re
     1187    cdef MPF im
     1188
     1189    def __init__(self, real=0, imag=0):
     1190        cdef int typx, typy
     1191        typx = MPF_set_any(&self.re, &self.im, real, global_opts, 1)
     1192        if typx == 2:
     1193            typy = 1
     1194        else:
     1195            typy = MPF_set_any(&self.im, &self.im, imag, global_opts, 1)
     1196        if typx == 0 or typy != 1:
     1197            raise TypeError
     1198
     1199    def __cinit__(self):
     1200        MPF_init(&self.re)
     1201        MPF_init(&self.im)
     1202
     1203    def  __dealloc__(self):
     1204        MPF_clear(&self.re)
     1205        MPF_clear(&self.im)
     1206
     1207    def __reduce__(self):
     1208        return (mpc, (), self._mpc_)
     1209
     1210    def __setstate__(self, val):
     1211        self._mpc_ = val[0], val[1]
     1212
     1213    def __repr__(self):
     1214        if global_context.pretty:
     1215            return self.__str__()
     1216        re, im = self._mpc_
     1217        n = repr_dps(global_opts.prec)
     1218        return "mpc(real='%s', imag='%s')" % (to_str(re, n), to_str(im, n))
     1219
     1220    def __str__(s):
     1221        return "(%s)" % libmp.mpc_to_str(s._mpc_, global_context._str_digits)
     1222
     1223    def __nonzero__(self):
     1224        return self.re.special != S_ZERO or self.im.special != S_ZERO
     1225
     1226    #def __complex__(self):
     1227    #    a, b = self._mpc_
     1228    #    return complex(libmp.to_float(a), libmp.to_float(b))
     1229
     1230    def __complex__(self):
     1231        return complex(MPF_to_double(&self.re, False), MPF_to_double(&self.im, False))
     1232
     1233    def _get_mpc(self):
     1234        return MPF_to_tuple(&self.re), MPF_to_tuple(&self.im)
     1235
     1236    def _set_mpc(self, tuple v):
     1237        MPF_set_tuple(&self.re, v[0])
     1238        MPF_set_tuple(&self.im, v[1])
     1239
     1240    _mpc_ = property(_get_mpc, _set_mpc)
     1241
     1242    @property
     1243    def real(self):
     1244        cdef mpf r = PY_NEW(mpf)
     1245        MPF_set(&r.value, &self.re)
     1246        return r
     1247
     1248    @property
     1249    def imag(self):
     1250        cdef mpf r = PY_NEW(mpf)
     1251        MPF_set(&r.value, &self.im)
     1252        return r
     1253
     1254    def __hash__(self):
     1255        return libmp.mpc_hash(self._mpc_)
     1256
     1257    def __neg__(s):
     1258        cdef mpc r = PY_NEW(mpc)
     1259        MPF_neg(&r.re, &s.re)
     1260        MPF_neg(&r.im, &s.im)
     1261        MPF_normalize(&r.re, global_opts)
     1262        MPF_normalize(&r.im, global_opts)
     1263        return r
     1264
     1265    def conjugate(s):
     1266        cdef mpc r = PY_NEW(mpc)
     1267        MPF_set(&r.re, &s.re)
     1268        MPF_neg(&r.im, &s.im)
     1269        MPF_normalize(&r.re, global_opts)
     1270        MPF_normalize(&r.im, global_opts)
     1271        return r
     1272
     1273    def __pos__(s):
     1274        cdef mpc r = PY_NEW(mpc)
     1275        MPF_set(&r.re, &s.re)
     1276        MPF_set(&r.im, &s.im)
     1277        MPF_normalize(&r.re, global_opts)
     1278        MPF_normalize(&r.im, global_opts)
     1279        return r
     1280
     1281    def __abs__(s):
     1282        cdef mpf r = PY_NEW(mpf)
     1283        MPF_hypot(&r.value, &s.re, &s.im, global_opts)
     1284        return r
     1285
     1286    def __richcmp__(self, other, int op):
     1287        return binop(OP_RICHCMP+op, self, other, global_opts)
     1288
  • sage/libs/mpmath/utils.pyx

    diff -r eb27a39a6df4 -r a4e8aa8f473d sage/libs/mpmath/utils.pyx
    a b  
    1616from sage.rings.complex_field import ComplexField
    1717from sage.rings.real_mpfr import RealField_constructor as RealField
    1818
     19cpdef int bitcount(n):
     20    cdef Integer m
     21    if PY_TYPE_CHECK(n, Integer):
     22        m = <Integer>n
     23    else:
     24        m = Integer(n)
     25    if mpz_sgn(m.value) == 0:
     26        return 0
     27    return mpz_sizeinbase(m.value, 2)
     28
    1929cpdef from_man_exp(man, exp, long prec = 0, str rnd = 'd'):
    2030    """
    2131    Create normalized mpf value tuple from mantissa and exponent.
     
    2434
    2535    EXAMPLES::
    2636
    27         sage: from mpmath.libmpf import from_man_exp
     37        sage: from mpmath.libmp import from_man_exp
    2838        sage: from_man_exp(-6, -1)
    2939        (1, 3, 0, 2)
    3040        sage: from_man_exp(-6, -1, 1, 'd')
     
    5060
    5161    EXAMPLES::
    5262   
    53         sage: from mpmath.libmpf import normalize
     63        sage: from mpmath.libmp import normalize
    5464        sage: normalize(0, 4, 5, 3, 53, 'n')
    5565        (0, 1, 7, 1)
    5666    """
    5767    cdef long shift
    58     cdef unsigned long haverem
    5968    cdef Integer res
    6069    cdef unsigned long trail
    6170    if mpz_sgn(man.value) == 0:
    62         from mpmath.libmpf import fzero
     71        from mpmath.libmp import fzero
    6372        return fzero
    6473    if bc <= prec and mpz_odd_p(man.value):
    6574        return (sign, man, exp, bc)
     
    109118            mpfr_neg(res, res, GMP_RNDZ)
    110119        mpfr_mul_2si(res, res, exp, GMP_RNDZ)
    111120        return
    112     from mpmath.libmpf import finf, fninf
     121    from mpmath.libmp import finf, fninf
    113122    if exp == 0:
    114123        mpfr_set_ui(res, 0, GMP_RNDZ)
    115124    elif x == finf:
     
    125134    the same number.
    126135    """
    127136    if mpfr_nan_p(value):
    128         from mpmath.libmpf import fnan
     137        from mpmath.libmp import fnan
    129138        return fnan
    130139    if mpfr_inf_p(value):
    131         from mpmath.libmpf import finf, fninf
     140        from mpmath.libmp import finf, fninf
    132141        if mpfr_sgn(value) > 0:
    133142            return finf
    134143        else:
    135144            return fninf
    136145    if mpfr_sgn(value) == 0:
    137         from mpmath.libmpf import fzero
     146        from mpmath.libmp import fzero
    138147        return fzero
    139148    sign = 0
    140149    cdef Integer man = PY_NEW(Integer)