Ticket #5777: trac_5777-12-pickle_expression.patch

File trac_5777-12-pickle_expression.patch, 12.2 KB (added by burcin, 11 years ago)

support pickling pynac expressions

  • c_lib/include/ccobject.h

    # HG changeset patch
    # User Burcin Erocal <burcin@erocal.org>
    # Date 1241564760 -7200
    # Node ID 34b5f93d4e6990e98057eeb4952e0292218320a7
    # Parent  ee6415a94611aacf0c97a4e38df4ea327691366a
    Add pickling support for sage.symbolic.expression.Expression.
    
    diff --git a/c_lib/include/ccobject.h b/c_lib/include/ccobject.h
    a b  
    8484}
    8585
    8686template <class T>
    87 void _from_str(T* dest, char* src){
     87void _from_str(T* dest, const char* src){
    8888  std::istringstream out(src);
    8989  out >> *dest;
    9090}
    9191
    9292template <class T>
     93void _from_str_len(T* dest, const char* src, unsigned int len){
     94  std::istringstream out(std::string(src, len));
     95  out >> *dest;
     96}
     97
     98template <class T>
    9399PyObject* _to_PyString(const T *x)
    94100{
    95101  std::ostringstream instore;
    96102  instore << (*x);
    97   return PyString_FromString(instore.str().data());
     103  std::string instr = instore.str();
     104  // using PyString_FromString truncates the output if whitespace is
     105  // encountered so we use Py_BuildValue and specify the length
     106  return Py_BuildValue("s#",instr.c_str(), instr.size());
    98107}
    99108
    100109#endif
  • c_lib/include/ginac_wrap.h

    diff --git a/c_lib/include/ginac_wrap.h b/c_lib/include/ginac_wrap.h
    a b  
    1414
    1515using namespace GiNaC;
    1616
    17 
    18 const symbol & get_symbol(const std::string & s)
    19 {
    20     static std::map<std::string, symbol> directory;
    21     std::map<std::string, symbol>::iterator i = directory.find(s);
    22     if (i != directory.end())
    23         return i->second;
    24     else
    25         return directory.insert(std::make_pair(s, symbol(s))).first->second;
    26 }
    27 
    2817void list_symbols(const ex& e, std::set<ex, ex_is_less> &s)
    2918{
    3019    if (is_a<symbol>(e)) {
  • sage/libs/ginac/decl.pxi

    diff --git a/sage/libs/ginac/decl.pxi b/sage/libs/ginac/decl.pxi
    a b  
    4545    ctypedef struct GExList "GiNaC::lst":
    4646        GExListIter begin()
    4747        GExListIter end()
     48        GExList append_sym "append" (GSymbol e)
    4849
    4950    ctypedef struct GEx "ex":
    5051        unsigned int gethash()        except +
     
    191192    GSymbol get_symbol(char* s)              except +
    192193    GEx g_collect_common_factors "collect_common_factors" (GEx e) except +
    193194
     195    # standard library string
     196    ctypedef struct stdstring "std::string":
     197        stdstring assign(char* s, Py_ssize_t l)
     198        char* c_str()
     199        unsigned int size()
     200        char at(unsigned int ind)
     201
     202    stdstring* stdstring_construct_cstr \
     203            "new std::string" (char* s, unsigned int l)
     204    void stdstring_delete "Delete<std::string>"(stdstring* s)
     205
     206    # Archive
     207    ctypedef struct GArchive "archive":
     208        void archive_ex(GEx e, char* name) except +
     209        GEx unarchive_ex(GExList sym_lst, unsigned ind) except +
     210        void printraw "printraw(std::cout); " (int t)
     211
     212    object GArchive_to_str "_to_PyString<archive>"(GArchive *s)
     213    void GArchive_from_str "_from_str_len<archive>"(GArchive *ar, char* s,
     214            unsigned int l)
     215
     216
    194217    GEx g_abs "GiNaC::abs" (GEx x)           except +
    195218    GEx g_step "GiNaC::step" (GEx x)         except +  # step function
    196219    GEx g_csgn "GiNaC::csgn" (GEx x)         except + # complex sign
  • sage/symbolic/expression.pyx

    diff --git a/sage/symbolic/expression.pyx b/sage/symbolic/expression.pyx
    a b  
    149149        """
    150150        GEx_destruct(&self._gobj)
    151151
     152    def __getstate__(self):
     153        """
     154        Returns a tuple describing the state of this expression for pickling.
     155
     156        This should return all information that will be required to unpickle
     157        the object. The functionality for unpickling is implemented in
     158        __setstate__().
     159
     160        In order to pickle Expression objects, we return a tuple containing
     161       
     162         * 0  - as pickle version number
     163                in case we decide to change the pickle format in the feature
     164         * names of symbols of this expression
     165         * a string representation of self stored in a Pynac archive.
     166
     167        TESTS::
     168            sage: var('x,y,z',ns=1)
     169            (x, y, z)
     170            sage: t = 2*x*y^z+3
     171            sage: s = dumps(t)
     172
     173            sage: t.__getstate__()
     174            (0,
     175             ['x', 'y', 'z'],
     176             ...)
     177       
     178        """
     179        cdef GArchive ar
     180        ar.archive_ex(self._gobj, "sage_ex")
     181        ar_str = GArchive_to_str(&ar)
     182        return (0, map(repr, self.variables()), ar_str)
     183
     184    def __setstate__(self, state):
     185        """
     186        Initializes the state of the object from data saved in a pickle.
     187
     188        During unpickling __init__ methods of classes are not called, the saved
     189        data is passed to the class via this function instead.
     190
     191        TESTS::
     192            sage: var('x,y,z',ns=1)
     193            (x, y, z)
     194            sage: t = 2*x*y^z+3
     195            sage: u = loads(dumps(t)) # indirect doctest
     196            sage: u
     197            2*y^z*x + 3
     198            sage: bool(t == u)
     199            True
     200            sage: u.subs(x=z)
     201            2*y^z*z + 3
     202
     203            sage: loads(dumps(x.parent()(2)))
     204            2
     205        """
     206        # check input
     207        if state[0] != 0 or len(state) != 3:
     208            raise ValueError, "unknown state information"
     209        # set parent
     210        self._set_parent(ring.NSR)
     211        # get variables
     212        cdef GExList sym_lst
     213        for name in state[1]:
     214            sym_lst.append_sym(get_symbol(name))
     215
     216        # initialize archive
     217        cdef GArchive ar
     218        GArchive_from_str(&ar, state[2], len(state[2]))
     219
     220        # extract the expression from the archive
     221        GEx_construct_ex(&self._gobj, ar.unarchive_ex(sym_lst, <unsigned>0))
     222
    152223    # TODO: The keyword argument simplify is for compatibility with
    153224    # old symbolics, once the switch is complete, it should be removed
    154225    def _repr_(self, simplify=None):
  • sage/symbolic/pynac.pyx

    diff --git a/sage/symbolic/pynac.pyx b/sage/symbolic/pynac.pyx
    a b  
    131131# We declare the functions defined below as extern here, to prevent Cython
    132132# from generating separate declarations for them which confuse g++
    133133cdef extern from *:
    134     char* py_latex(object o) except +
    135     char* py_latex_variable(char* var_name) except +
    136     char* py_print_function(unsigned id, object args) except +
    137     char* py_latex_function(unsigned id, object args) except +
    138     char* py_print_fderivative(unsigned id, object params, object args) except +
    139     char* py_latex_fderivative(unsigned id, object params, object args) except +
     134    stdstring* py_latex(object o) except +
     135    stdstring* py_latex_variable(char* var_name) except +
     136    stdstring* py_print_function(unsigned id, object args) except +
     137    stdstring* py_latex_function(unsigned id, object args) except +
     138    stdstring* py_print_fderivative(unsigned id, object params, object args) except +
     139    stdstring* py_latex_fderivative(unsigned id, object params, object args) except +
    140140
    141 cdef public char* py_latex(object o) except +:
     141cdef public stdstring* py_latex(object o) except +:
    142142    from sage.misc.latex import latex
    143143    s = latex(o)
    144     return s
     144    return string_from_pystr(s)
    145145
    146 cdef extern from "string.h":
    147     Py_ssize_t strlen(char*)
    148     char* strncpy(char* dest, char* src, Py_ssize_t n)
    149 
    150 cdef char* cstr_from_pystr(object py_str):
     146cdef stdstring* string_from_pystr(object py_str):
    151147    """
    152     Creates a C string with the same contents as the given python string.
     148    Creates a C++ string with the same contents as the given python string.
    153149
    154150    Used when passing string output to pynac for printing, since we don't want
    155151    to mess with reference counts of the python objects and we cannot guarantee
    156152    they won't be garbage collected before the output is printed.
    157153    """
    158     cdef char *t_str, *o_str
    159     t_str = PyString_AsString(py_str)
    160     cdef Py_ssize_t slen = strlen(t_str)+1
    161     o_str = <char*>sage_malloc(sizeof(char)*slen)
    162     strncpy(o_str, t_str, slen)
    163     return o_str
     154    cdef char *t_str = PyString_AsString(py_str)
     155    cdef Py_ssize_t slen = len(py_str)
     156    cdef stdstring* sout = stdstring_construct_cstr(t_str, slen)
     157    return sout
    164158
    165 cdef public char* py_latex_variable(char* var_name) except +:
     159cdef public stdstring* py_latex_variable(char* var_name) except +:
    166160    """
    167     Returns a c string containing the latex representation of the given
     161    Returns a c++ string containing the latex representation of the given
    168162    variable name.
    169163
    170164    Real work is done by the function sage.misc.latex.latex_variable_name.
     
    190184    cdef Py_ssize_t slen
    191185    from sage.misc.latex import latex_variable_name
    192186    py_vlatex = latex_variable_name(var_name)
    193     return cstr_from_pystr(py_vlatex)
     187    return string_from_pystr(py_vlatex)
    194188
    195189def py_latex_variable_for_doctests(x):
    196     cdef char* ostr
    197     ostr = py_latex_variable(PyString_AsString(x))
    198     print(ostr)
    199     sage_free(ostr)
     190    cdef stdstring* ostr = py_latex_variable(PyString_AsString(x))
     191    print(ostr.c_str())
     192    stdstring_delete(ostr)
    200193
    201194def py_print_function_pystring(id, args, fname_paren=False):
    202195    """
     
    262255    olist.extend(['(', ', '.join(map(repr, args)), ')'])
    263256    return ''.join(olist)
    264257
    265 cdef public char* py_print_function(unsigned id, object args) except +:
    266     return cstr_from_pystr(py_print_function_pystring(id, args))
     258cdef public stdstring* py_print_function(unsigned id, object args) except +:
     259    return string_from_pystr(py_print_function_pystring(id, args))
    267260
    268261def py_latex_function_pystring(id, args, fname_paren=False):
    269262    """
     
    346339        r'\right)'] )
    347340    return ''.join(olist)
    348341
    349 cdef public char* py_latex_function(unsigned id, object args) except +:
    350     return cstr_from_pystr(py_latex_function_pystring(id, args))
     342cdef public stdstring* py_latex_function(unsigned id, object args) except +:
     343    return string_from_pystr(py_latex_function_pystring(id, args))
    351344
    352 cdef public char* py_print_fderivative(unsigned id, object params, object args)\
    353         except +:
     345cdef public stdstring* py_print_fderivative(unsigned id, object params,
     346        object args) except +:
    354347    """
    355348    Return a string with the representation of the derivative of the symbolic
    356349    function specified by the given id, lists of params and args.
     
    394387    ostr = ''.join(['D[', ', '.join([repr(int(x)) for x in params]), ']'])
    395388    fstr = py_print_function_pystring(id, args, True)
    396389    py_res = ostr + fstr
    397     return cstr_from_pystr(py_res)
     390    return string_from_pystr(py_res)
    398391
    399392def py_print_fderivative_for_doctests(id, params, args):
    400     cdef char* ostr
    401     ostr = py_print_fderivative(id, params, args)
    402     print(ostr)
    403     sage_free(ostr)
     393    cdef stdstring* ostr = py_print_fderivative(id, params, args)
     394    print(ostr.c_str())
     395    stdstring_delete(ostr)
    404396
    405 cdef public char* py_latex_fderivative(unsigned id, object params, object args)\
    406         except +:
     397cdef public stdstring* py_latex_fderivative(unsigned id, object params,
     398        object args) except +:
    407399    """
    408400    Return a string with the latex representation of the derivative of the
    409401    symbolic function specified by the given id, lists of params and args.
     
    452444    ostr = ''.join(['D[', ', '.join([repr(int(x)) for x in params]), ']'])
    453445    fstr = py_latex_function_pystring(id, args, True)
    454446    py_res = ostr + fstr
    455     return cstr_from_pystr(py_res)
     447    return string_from_pystr(py_res)
    456448
    457449def py_latex_fderivative_for_doctests(id, params, args):
    458     cdef char* ostr
    459     ostr = py_latex_fderivative(id, params, args)
    460     print(ostr)
    461     sage_free(ostr)
     450    cdef stdstring* ostr = py_latex_fderivative(id, params, args)
     451    print(ostr.c_str())
     452    stdstring_delete(ostr)
    462453
     454#################################################################
     455# Archive helpers
     456#################################################################
     457
     458cdef extern from *:
     459    stdstring* py_dumps(object o) except +
     460    object py_loads(object s) except +
     461
     462from sage.structure.sage_object import loads, dumps
     463cdef public stdstring* py_dumps(object o) except +:
     464    s = dumps(o, compress=False)
     465    return string_from_pystr(s)
     466
     467cdef public object py_loads(object s) except +:
     468    return loads(s)
    463469
    464470#################################################################
    465471# Modular helpers