# Ticket #11837: newton_basins_more_cython.spyx

File newton_basins_more_cython.spyx, 7.3 KB (added by SimonKing, 9 years ago)

Put more Cython into the basins.

Line
1from sage.rings.complex_double cimport ComplexDoubleElement
2cimport numpy as cnumpy
3from sage.plot.primitive import GraphicPrimitive
4from sage.misc.decorators import options
5from sage.misc.misc import srange
6from sage.symbolic.ring import var
7
cdef inline ComplexDoubleElement new_CDF_element(double w, double v):
    """
    Return a new CDF element with real part ``w`` and imaginary part ``v``.

    The element is allocated with ``PY_NEW`` and its ``_complex.dat`` pair is
    written directly, instead of going through the ``CDF`` parent.
    """
    cdef ComplexDoubleElement result
    result = <ComplexDoubleElement>PY_NEW(ComplexDoubleElement)
    # dat[1] is the imaginary part, dat[0] the real part.
    result._complex.dat[1] = v
    result._complex.dat[0] = w
    return result
13
14
15
def complex_to_rgb(roots, z_values):
    r"""
    Colour a grid of roots.

    Each entry of ``z_values`` is replaced by the colour of the corresponding
    root in ``roots``. Root number ``k`` receives the ``k``-th colour of
    ``rainbow(len(roots))``.

    INPUT:

    - ``roots`` -- list of roots; their order fixes the colours

    - ``z_values`` -- a grid, as a list of lists, whose entries are expected
      to be elements of ``roots`` (for instance as returned by
      ``RootFinder.which_root``)

    OUTPUT:

    An `N \times M \times 3` floating point Numpy array ``X``, where
    ``X[i, j]`` is the ``(r, g, b)`` colour of ``z_values[i][j]``. Entries
    that do not occur in ``roots`` are left black.
    """
    import numpy
    from sage.plot.colors import rainbow, Color

    cdef Py_ssize_t i, j, imax, jmax
    imax = len(z_values)
    jmax = len(z_values[0]) if imax else 0
    # zeros rather than empty: a pixel whose value is not a known root must not
    # show uninitialised memory.
    cdef cnumpy.ndarray[cnumpy.float64_t, ndim=3, mode='c'] rgb = numpy.zeros(dtype=numpy.float64, shape=(imax, jmax, 3))

    colors = [Color(col).rgb() for col in rainbow(len(roots))]

    # No sig_on()/sig_off() here: the loop body runs Python code
    # (list.index, exception handling), which must not be executed inside a
    # sig_on/sig_off section. Python's own interrupt handling applies instead.
    for i in range(imax):
        row = z_values[i]
        for j in range(jmax):
            try:
                color = colors[roots.index(row[j])]
            except ValueError:
                # Not one of the roots: keep the pixel black.
                continue
            rgb[i, j, 0] = color[0]
            rgb[i, j, 1] = color[1]
            rgb[i, j, 2] = color[2]
    return rgb
71
72
def newt(func):
    """
    Return the Newton iteration map of ``func``.

    INPUT:

    - ``func`` -- a symbolic expression; the first entry of ``func.args()``
      is used as the variable ``v``

    OUTPUT:

    The symbolic expression ``v - func/func'``, where ``func'`` is the
    derivative of ``func`` with respect to ``v``.
    """
    vari = func.args()[0]
    # Use the expression's own derivative method: the global ``diff`` is not
    # imported in this module.
    fprime = func.derivative(vari)
    return vari - func / fprime
77
78
79
class BasinPlot(GraphicPrimitive):
    """
    Graphics primitive that draws a precomputed grid of roots as an image.

    INPUT:

    - ``roots`` -- list of roots; their order determines the colours

    - ``z_values`` -- grid (list of rows) of entries of ``roots``. Row ``i``
      is drawn at the ``i``-th y position and column ``j`` at the ``j``-th
      x position, because ``_render_on_subplot`` uses ``imshow`` with
      ``origin='lower'``.

    - ``xrange``, ``yrange`` -- pairs ``(min, max)`` giving the plotted extent

    - ``options`` -- dict of plot options; ``interpolation`` is used when
      rendering
    """
    def __init__(self, roots, z_values, xrange, yrange, options):
        """
        TESTS::

            sage: p = basin_plot([1, -1], (-2, 2), (-2, 2))
        """
        self.roots = roots
        self.xrange = xrange
        self.yrange = yrange
        self.z_values = z_values
        # Rows are y samples and columns are x samples (see the class
        # docstring), so the x count is the row length.
        self.y_count = len(z_values)
        self.x_count = len(z_values[0]) if z_values else 0
        self.rgb_data = complex_to_rgb(roots, z_values)
        GraphicPrimitive.__init__(self, options)

    def get_minmax_data(self):
        """
        Return a dictionary with the bounding box data.

        EXAMPLES::

            sage: p = basin_plot([1, -1], (-1, 2), (-3, 4))
            sage: sorted(p[0].get_minmax_data().items())
            [('xmax', 2.0), ('xmin', -1.0), ('ymax', 4.0), ('ymin', -3.0)]
        """
        from sage.plot.plot import minmax_data
        return minmax_data(self.xrange, self.yrange, dict=True)

    def _allowed_options(self):
        """
        Return the options understood by this primitive, with descriptions.

        TESTS::

            sage: isinstance(basin_plot([1, -1], (-1,1), (-1,1))[0]._allowed_options(), dict)
            True
        """
        return {'plot_points':'How many points to use for plotting precision',
                'interpolation':'What interpolation method to use'}

    def _repr_(self):
        """
        Return a string describing the size of the data grid (x by y).

        TESTS::

            sage: isinstance(basin_plot([1, -1], (-1,1), (-1,1))[0]._repr_(), str)
            True
        """
        return "BasinPlot defined by a %s x %s data grid"%(self.x_count, self.y_count)

    def _render_on_subplot(self, subplot):
        """
        Draw the colour grid on the matplotlib ``subplot``.

        TESTS::

            sage: basin_plot([1, -1], (-5, 5), (-5, 5))
        """
        options = self.options()
        x0, x1 = float(self.xrange[0]), float(self.xrange[1])
        y0, y1 = float(self.yrange[0]), float(self.yrange[1])
        subplot.imshow(self.rgb_data, origin='lower', extent=(x0, x1, y0, y1),
                       interpolation=options['interpolation'])
138
139
cdef class RootFinder:
    """
    Decide which root Newton's method reaches from a given starting point.

    INPUT:

    - ``roots`` -- nonempty iterable of ``CDF`` elements

    - ``newtf`` -- callable performing one Newton step; it is called on
      ``CDF`` elements and must return ``CDF`` elements

    - ``cutoff`` -- positive float (default: 2/3); an iterate counts as
      converged once it is strictly closer than this to a root

    - ``max_iterations`` -- integer (default: 20); maximal number of Newton
      steps before falling back to the nearest root
    """
    cdef list roots
    cdef object newtf
    cdef double cutoff
    cdef int max_iterations

    def __init__(self, roots, newtf, cutoff=2./3, max_iterations=20):
        # list() copies the container but keeps the same root objects, so
        # which_root still returns elements identical to those in ``roots``.
        self.roots = list(roots)
        if not self.roots:
            raise ValueError("RootFinder needs at least one root")
        if not cutoff > 0:
            raise ValueError("cutoff must be positive")
        self.newtf = newtf
        self.cutoff = cutoff
        self.max_iterations = max_iterations

    cpdef ComplexDoubleElement which_root(self, Varia):
        """
        Return the element of ``roots`` that Newton's method approaches when
        started at ``Varia``.

        The method applies ``newtf`` up to ``max_iterations`` times. After each
        step it returns the first root, in list order, that lies within
        ``cutoff`` of the current iterate. If no root is reached, or a step
        raises ``ZeroDivisionError`` (for example at a critical point), it
        returns the root nearest to the last iterate. Ties go to the earliest
        root.
        """
        cdef int counter
        cdef Py_ssize_t k, best_index
        cdef double dx, dy, dist2, best_dist2
        # Compare squared distances in C instead of calling .abs() and
        # coercing the result for every root and iteration.
        cdef double cutoff2 = self.cutoff * self.cutoff
        cdef ComplexDoubleElement root
        cdef ComplexDoubleElement varia = Varia

        for counter in range(self.max_iterations):
            try:
                varia = self.newtf(varia)
            except ZeroDivisionError:
                break
            for root in self.roots:
                dx = varia._complex.dat[0] - root._complex.dat[0]
                dy = varia._complex.dat[1] - root._complex.dat[1]
                if dx * dx + dy * dy < cutoff2:
                    return root

        # No convergence: choose the nearest root (first one on ties).
        best_index = 0
        best_dist2 = 0.0
        for k in range(len(self.roots)):
            root = self.roots[k]
            dx = varia._complex.dat[0] - root._complex.dat[0]
            dy = varia._complex.dat[1] - root._complex.dat[1]
            dist2 = dx * dx + dy * dy
            if k == 0 or dist2 < best_dist2:
                best_index = k
                best_dist2 = dist2
        return self.roots[best_index]
162
163
@options(plot_points=3, interpolation='nearest')
def basin_plot(roots, xrange, yrange, **options):
    r"""
    Plot the basins of attraction of Newton's method for the polynomial with
    the given roots.

    The polynomial `p(x) = \prod_k (x - r_k)` is built from ``roots``. Newton's
    method for `p` is started at every point of a grid covering the rectangle
    given by ``xrange`` and ``yrange``. Each grid point is coloured according
    to the root that ``RootFinder.which_root`` assigns to it. The colours come
    from ``rainbow(len(roots))``, in the order of ``roots``.

    basin_plot(roots, (xmin, xmax), (ymin, ymax), ...)

    INPUT:

    - ``roots`` -- nonempty list (or other iterable) of numbers coercible to
      ``CDF``

    - ``(xmin, xmax)`` -- 2-tuple, the range of real parts

    - ``(ymin, ymax)`` -- 2-tuple, the range of imaginary parts

    The following inputs must all be passed in as named parameters:

    - ``plot_points`` -- integer (default: 3); number of grid points in each
      direction

    - ``interpolation`` -- string (default: 'nearest'); the interpolation
      method passed to matplotlib's ``imshow``, e.g. 'nearest', 'bilinear',
      'bicubic', 'spline16', 'spline36', 'quadric', 'gaussian', 'sinc',
      'bessel', 'mitchell', 'lanczos', 'catrom', 'hermite', 'hanning',
      'hamming', 'kaiser'

    OUTPUT:

    A ``Graphics`` object containing a single ``BasinPlot``.
    """
    from sage.plot.plot import Graphics
    from sage.plot.misc import setup_for_eval_on_grid
    from sage.ext.fast_callable import fast_callable
    from sage.rings.complex_double import CDF
    from sage.misc.misc_c import prod

    roots = [CDF(temp) for temp in roots]
    if not roots:
        raise ValueError("basin_plot needs at least one root")

    x = var('x')
    prefunc = prod([(x - root) for root in roots])
    newtf = fast_callable(newt(prefunc), domain=CDF, expect_one_var=True)

    cdef RootFinder finder = RootFinder(roots, newtf)
    which_root = finder.which_root

    cdef double t, u
    ignore, ranges = setup_for_eval_on_grid([], [xrange, yrange], options['plot_points'])
    xrange, yrange = [r[:2] for r in ranges]
    # The x samples are the same for every row, so compute them once.
    x_samples = srange(*ranges[0], include_endpoint=True)
    # Rows are y samples and columns are x samples, the layout BasinPlot
    # renders with imshow. No sig_on()/sig_off(): this is Python-level code.
    z_values = [[which_root(new_CDF_element(t, u)) for t in x_samples]
                for u in srange(*ranges[1], include_endpoint=True)]

    g = Graphics()
    g._set_extra_kwds(Graphics._extract_kwds_for_show(options, ignore=['xmin', 'xmax']))
    g.add_primitive(BasinPlot(roots, z_values, xrange, yrange, options))
    return g