Ticket #5777: trac_5777-13-pickle_sfunction.patch

File trac_5777-13-pickle_sfunction.patch, 13.5 KB (added by burcin, 11 years ago)

support pickling symbolic functions

  • sage/symbolic/function.pxd

    # HG changeset patch
    # User Burcin Erocal <burcin@erocal.org>
    # Date 1241564830 -7200
    # Node ID 012d71daae6bd233cc39a5c0ab69ddba6ec294b0
    # Parent  34b5f93d4e6990e98057eeb4952e0292218320a7
    Add pickling support for sage.symbolic.function.SFunction.
    
    diff --git a/sage/symbolic/function.pxd b/sage/symbolic/function.pxd
    a b  
    1616    cdef object print_latex_f
    1717    cdef object latex_name
    1818
     19    # cache hash value
     20    cdef long _hash_(self)
     21    cdef bint __hinit
     22    cdef long __hcache
     23
     24    # common initialization from __init__ and __setstate__
     25    cdef _init_(self)
  • sage/symbolic/function.pyx

    diff --git a/sage/symbolic/function.pyx b/sage/symbolic/function.pyx
    a b  
    2424cdef dict sfunction_serial_dict = {}
    2525
    2626from pynac import get_ginac_serial
     27
     28from sage.misc.fpickle import pickle_function, unpickle_function
     29
    2730cdef class SFunction:
    2831    """
    2932    Return a formal symbolic function.
     
    148151            ValueError: eval_func parameter must be callable
    149152
    150153        """
    151         cdef GFunctionOpt opt
    152         serial = -1
    153         try:
    154             self.serial = find_function(name, nargs)
    155         except ValueError, err:
    156             pass
    157        
    158         opt = g_function_options_args(name, nargs)
    159         opt.set_python_func()
     154        self.name = name
     155        if nargs:
     156            self.nargs = nargs
     157        else:
     158            self.nargs = 0
    160159
    161160        self.eval_f = eval_func
    162161        if eval_func:
    163162            if not callable(eval_func):
    164163                raise ValueError, "eval_func parameter must be callable"
    165             opt.eval_func(eval_func)
    166164
    167165        self.evalf_f = evalf_func
    168166        if evalf_func:
    169167            if not callable(evalf_func):
    170168                raise ValueError, "evalf_func parameter must be callable"
    171             opt.evalf_func(evalf_func)
    172169
    173170        self.conjugate_f = conjugate_func
    174171        if conjugate_func:
    175172            if not callable(conjugate_func):
    176173                raise ValueError, "conjugate_func parameter must be callable"
    177             opt.conjugate_func(conjugate_func)
    178174
    179175        self.real_part_f = real_part_func
    180176        if real_part_func:
    181177            if not callable(real_part_func):
    182178                raise ValueError, "real_part_func parameter must be callable"
    183             opt.real_part_func(real_part_func)
    184179
    185180        self.imag_part_f = imag_part_func
    186181        if imag_part_func:
    187182            if not callable(imag_part_func):
    188183                raise ValueError, "imag_part_func parameter must be callable"
    189             opt.imag_part_func(imag_part_func)
    190184
    191185        self.derivative_f = derivative_func
    192186        if derivative_func:
    193187            if not callable(derivative_func):
    194188                raise ValueError, "derivative_func parameter must be callable"
    195             opt.derivative_func(derivative_func)
    196189
    197190        self.power_f = power_func
    198191        if power_func:
    199192            if not callable(power_func):
    200193                raise ValueError, "power_func parameter must be callable"
    201             opt.power_func(power_func)
    202194
    203195        self.series_f = series_func
    204196        if series_func:
    205197            if not callable(series_func):
    206198                raise ValueError, "series_func parameter must be callable"
    207             opt.series_func(series_func)
    208199
    209200        # handle custom printing
    210201        # if print_func is defined, it is used instead of name
     
    214205            raise ValueError, "only one of latex_name and print_latex_name should be specified."
    215206
    216207        self.print_latex_f = print_latex_func
    217         if print_latex_func:
    218             opt.set_print_latex_func(print_latex_func)
     208        self.latex_name = latex_name
     209        self.print_f = print_func
    219210
    220         self.latex_name = latex_name
    221         if latex_name:
    222             opt.latex_name(latex_name)
     211        self._init_()
    223212
    224         self.print_f = print_func
    225         if print_func:
    226             opt.set_print_dflt_func(print_func)
     213    # this is separated from the constructor since it is also called
     214    # during unpickling
     215    cdef _init_(self):
     216        # see if there is already an SFunction with the same state
     217        cdef SFunction sfunc
     218        cdef long myhash = self._hash_()
     219        for sfunc in sfunction_serial_dict.itervalues():
     220            if myhash == sfunc._hash_():
     221                # found one, set self.serial to be a copy
     222                self.serial = sfunc.serial
     223                return
    227224
     225        cdef GFunctionOpt opt
     226       
     227        opt = g_function_options_args(self.name, self.nargs)
     228        opt.set_python_func()
    228229
    229         if serial is -1 or serial < get_ginac_serial():
    230             self.serial = g_register_new(opt)
    231         else:
    232             self.serial = serial
     230        if self.eval_f:
     231            opt.eval_func(self.eval_f)
     232
     233        if self.evalf_f:
     234            opt.evalf_func(self.evalf_f)
     235
     236        if self.conjugate_f:
     237            opt.conjugate_func(self.conjugate_f)
     238
     239        if self.real_part_f:
     240            opt.real_part_func(self.real_part_f)
     241
     242        if self.imag_part_f:
     243            opt.imag_part_func(self.imag_part_f)
     244
     245        if self.derivative_f:
     246            opt.derivative_func(self.derivative_f)
     247
     248        if self.power_f:
     249            opt.power_func(self.power_f)
     250
     251        if self.series_f:
     252            opt.series_func(self.series_f)
     253
     254        if self.print_latex_f:
     255            opt.set_print_latex_func(self.print_latex_f)
     256
     257        if self.latex_name:
     258            opt.latex_name(self.latex_name)
     259
     260        if self.print_f:
     261            opt.set_print_dflt_func(self.print_f)
     262
     263        self.serial = g_register_new(opt)
    233264
    234265        g_foptions_assign(g_registered_functions().index(self.serial), opt)
    235266        global sfunction_serial_dict
    236267        sfunction_serial_dict[self.serial] = self
    237268
    238         self.name = name
    239         if nargs:
    240             self.nargs = nargs
    241         else:
    242             self.nargs = 0
     269        self.__hinit = False
     270
     271    # cache the hash value of this function
     272    # this is used very often while unpickling to see if there is already
     273    # a function with the same properties
     274    cdef long _hash_(self) except +:
     275        if not self.__hinit:
     276            # create a string representation of this SFunction
     277            slist = [self.nargs, self.name, str(self.latex_name)]
     278            for f in [self.eval_f, self.evalf_f, self.conjugate_f,
     279                    self.real_part_f, self.imag_part_f, self.derivative_f,
     280                    self.power_f, self.series_f, self.print_f,
     281                    self.print_latex_f]:
     282                if f:
     283                    slist.append(hash(f.func_code))
     284                else:
     285                    slist.append(' ')
     286            self.__hcache = hash(tuple(slist))
     287            self.__hinit = True
     288        return self.__hcache
     289
     290    def __hash__(self):
     291        """
     292        TESTS::
     293            sage: from sage.symbolic.function import function as nfunction
     294            sage: foo = nfunction("foo", 2)
     295            sage: hash(foo)
     296            7313648655953480146
     297
     298            sage: def ev(x): return 2*x
     299            sage: foo = nfunction("foo", 2, eval_func = ev)
     300            sage: hash(foo)
     301            4884169210301491732
     302
     303        """
     304        return self.serial*self._hash_()
     305
     306    def __getstate__(self):
     307        """
     308        Returns a tuple describing the state of this object for pickling.
     309
     310        Pickling SFunction objects is limited by the ability to pickle
     311        functions in python. We use sage.misc.fpickle.pickle_function for
     312        this purpose, which only works if there are no nested functions.
     313
     314
     315        This should return all information that will be required to unpickle
     316        the object. The functionality for unpickling is implemented in
     317        __setstate__().
     318
     319        In order to pickle SFunction objects, we return a tuple containing
     320       
     321         * 0  - as pickle version number
     322                in case we decide to change the pickle format in the feature
     323         * name of this function
     324         * number of arguments
     325         * latex_name
     326         * a tuple containing attempts to pickle the following optional
     327           functions, in the order below
     328           * eval_f
     329           * evalf_f
     330           * conjugate_f
     331           * real_part_f
     332           * imag_part_f
     333           * derivative_f
     334           * power_f
     335           * series_f
     336           * print_f
     337           * print_latex_f
     338
     339        EXAMPLES::
     340            sage: from sage.symbolic.function import function as nfunction
     341            sage: foo = nfunction("foo", 2)
     342            sage: foo.__getstate__()
     343            (0, 'foo', 2, None, [None, None, None, None, None, None, None, None, None, None])
     344            sage: t = loads(dumps(foo))
     345            sage: t == foo
     346            True
     347            sage: var('x,y',ns=1)
     348            (x, y)
     349            sage: t(x,y)
     350            foo(x, y)
     351
     352            sage: def ev(x,y): return 2*x
     353            sage: foo = nfunction("foo", 2, eval_func = ev)
     354            sage: foo.__getstate__()
     355            (0, 'foo', 2, None, ["...", None, None, None, None, None, None, None, None, None])
     356
     357            sage: u = loads(dumps(foo))
     358            sage: u == foo
     359            True
     360            sage: t == u
     361            False
     362            sage: u(y,x)
     363            2*y
     364
     365            sage: def evalf_f(x, prec=0): return int(6)
     366            sage: foo = nfunction("foo", 1, evalf_func=evalf_f)
     367            sage: foo.__getstate__()
     368            (0, 'foo', 1, None, [None, "...", None, None, None, None, None, None, None, None])
     369
     370            sage: v = loads(dumps(foo))
     371            sage: v == foo
     372            True
     373            sage: v == u
     374            False
     375            sage: foo(y).n()
     376            6
     377            sage: v(y).n()
     378            6
     379
     380        Test pickling expressions with symbolic functions::
     381           
     382            sage: u = loads(dumps(foo(x)^2 + foo(y) + x^y)); u
     383            x^y + foo(x)^2 + foo(y)
     384            sage: u.subs(y=0)
     385            foo(x)^2 + foo(0) + 1
     386            sage: u.subs(y=0).n()
     387            43.0000000000000
     388        """
     389        return (0, self.name, self.nargs, self.latex_name,
     390                map(pickle_wrapper, [self.eval_f, self.evalf_f,
     391                    self.conjugate_f, self.real_part_f, self.imag_part_f,
     392                    self.derivative_f, self.power_f, self.series_f,
     393                    self.print_f, self.print_latex_f]))
     394
     395    def __setstate__(self, state):
     396        """
     397        Initializes the state of the object from data saved in a pickle.
     398
     399        During unpickling __init__ methods of classes are not called, the saved
     400        data is passed to the class via this function instead.
     401
     402        TESTS::
     403
     404            sage: from sage.symbolic.function import function as nfunction
     405            sage: var('x,y', ns=1)
     406            (x, y)
     407            sage: foo = nfunction("foo", 2)
     408            sage: bar = nfunction("bar", 1)
     409            sage: bar.__setstate__(foo.__getstate__())
     410
     411        Note that the other direction doesn't work here, since foo._hash_()
     412        hash already been initialized.::
     413
     414            sage: bar
     415            foo
     416            sage: bar(x,y)
     417            foo(x, y)
     418        """
     419        # check input
     420        if state[0] != 0 or len(state) != 5:
     421            raise ValueError, "unknown state information"
     422
     423        self.name = state[1]
     424        self.nargs = state[2]
     425        self.latex_name = state[3]
     426        self.eval_f = unpickle_wrapper(state[4][0])
     427        self.evalf_f = unpickle_wrapper(state[4][1])
     428        self.conjugate_f = unpickle_wrapper(state[4][2])
     429        self.real_part_f = unpickle_wrapper(state[4][3])
     430        self.imag_part_f = unpickle_wrapper(state[4][4])
     431        self.derivative_f = unpickle_wrapper(state[4][5])
     432        self.power_f = unpickle_wrapper(state[4][6])
     433        self.series_f = unpickle_wrapper(state[4][7])
     434        self.print_f = unpickle_wrapper(state[4][8])
     435        self.print_latex_f = unpickle_wrapper(state[4][9])
     436
     437        self._init_()
    243438
    244439    def __repr__(self):
    245440        """
     
    404599    global sfunction_serial_dict
    405600    return sfunction_serial_dict.get(serial)
    406601
     602import base64
     603def pickle_wrapper(f):
     604    if f is None:
     605        return None
     606    return pickle_function(f)
     607
     608def unpickle_wrapper(p):
     609    if p is None:
     610        return None
     611    return unpickle_function(p)
     612
    407613def init_sfunction_map():
    408614    """
    409615    Initializes a list mapping GiNaC function serials to the equivalent Sage
  • sage/symbolic/pynac.pyx

    diff --git a/sage/symbolic/pynac.pyx b/sage/symbolic/pynac.pyx
    a b  
    458458cdef extern from *:
    459459    stdstring* py_dumps(object o) except +
    460460    object py_loads(object s) except +
     461    object py_get_sfunction_from_serial(unsigned s) except +
     462    unsigned py_get_serial_from_sfunction(SFunction f) except +
    461463
    462464from sage.structure.sage_object import loads, dumps
    463465cdef public stdstring* py_dumps(object o) except +:
    464466    s = dumps(o, compress=False)
     467    # pynac archive format terminates atoms with zeroes.
     468    # since pickle output can break the archive format
     469    # we use the base64 data encoding
     470    import base64
     471    s = base64.b64encode(s)
    465472    return string_from_pystr(s)
    466473
    467474cdef public object py_loads(object s) except +:
     475    import base64
     476    s = base64.b64decode(s)
    468477    return loads(s)
    469478
     479cdef public object py_get_sfunction_from_serial(unsigned s) except +:
     480    return get_sfunction_from_serial(s)
     481
     482cdef public unsigned py_get_serial_from_sfunction(SFunction f) except +:
     483    return f.serial
     484
    470485#################################################################
    471486# Modular helpers
    472487#################################################################