# Ticket #11837: newton_basins.spyx

File newton_basins.spyx, 7.1 KB (added by kcrisman, 9 years ago)

Very rough draft of the eventual file — not a Mercurial patch.

Line
1from sage.rings.complex_double cimport ComplexDoubleElement
2cimport numpy as cnumpy
3from sage.plot.primitive import GraphicPrimitive
4from sage.misc.decorators import options
5from sage.misc.misc import srange
6from sage.symbolic.ring import var
7
cdef inline ComplexDoubleElement new_CDF_element(double w, double v):
    # Return the CDF element w + v*i, built straight from two C doubles.
    #
    # INPUT:  w -- real part; v -- imaginary part.
    # OUTPUT: a new ComplexDoubleElement.
    #
    # The instance comes from PY_NEW (Sage's stdsage allocation helper) rather
    # than a normal constructor call, so both components of the internal
    # ``_complex.dat`` pair are written explicitly here.
    cdef ComplexDoubleElement elt
    elt = <ComplexDoubleElement>PY_NEW(ComplexDoubleElement)
    elt._complex.dat[1] = v
    elt._complex.dat[0] = w
    return elt
13
14
15
def complex_to_rgb(roots, z_values):
    r"""
    Colour a grid of Newton-basin labels, using one rainbow colour per root.

    INPUT:

    - ``roots`` -- a list or tuple of roots; the `k`-th root is drawn in the
      `k`-th colour of ``rainbow(len(roots))``

    - ``z_values`` -- a grid, as a list of equal-length lists, whose entries
      are expected to be elements of ``roots`` (as produced by
      :func:`basin_plot`)

    OUTPUT:

    An `N \times M \times 3` floating point NumPy array ``X``, where `N` is
    ``len(z_values)`` and `M` is ``len(z_values[0])``. ``X[i, j]`` is the
    `(r, g, b)` triple for ``z_values[i][j]``. An entry that does not compare
    equal to any root is left black `(0, 0, 0)`. An empty grid gives an array
    of shape ``(0, 0, 3)``.

    EXAMPLES::

        sage: rgb = complex_to_rgb([1, -1], [[1, -1], [-1, 1]])  # not tested
        sage: rgb.shape  # not tested
        (2, 2, 3)
    """
    import numpy
    from sage.plot.colors import rainbow, Color

    cdef Py_ssize_t i, j, k, imax, jmax
    imax = len(z_values)
    jmax = len(z_values[0]) if imax else 0

    # zeros rather than empty: pixels whose label is not a root must come out
    # black, not as whatever happened to be in the freshly allocated memory.
    cdef cnumpy.ndarray[cnumpy.float64_t, ndim=3, mode='c'] rgb
    rgb = numpy.zeros(dtype=numpy.float64, shape=(imax, jmax, 3))

    roots = list(roots)
    # Color(...).rgb() turns each rainbow colour into an (r, g, b) float triple.
    palette = [Color(col).rgb() for col in rainbow(len(roots))] if roots else []

    # No sig_on()/sig_off() here: the loop runs Python-level code (indexing,
    # list.index), which must not execute inside a sig_on() block.
    for i in range(imax):
        row = z_values[i]
        for j in range(jmax):
            try:
                k = roots.index(row[j])
            except ValueError:
                # Not one of the roots: leave this pixel black.
                continue
            color = palette[k]
            rgb[i, j, 0] = color[0]
            rgb[i, j, 1] = color[1]
            rgb[i, j, 2] = color[2]
    return rgb
71
72
def newt(func):
    r"""
    Return the Newton-iteration map `z \mapsto z - f(z)/f'(z)` for ``func``.

    INPUT:

    - ``func`` -- a non-constant symbolic expression; its first argument
      (``func.args()[0]``) is taken as the variable

    OUTPUT:

    The symbolic expression ``v - func/func.derivative(v)``, where ``v`` is
    that variable.

    Raises ``ValueError`` if ``func`` has no arguments (a constant).
    """
    args = func.args()
    if not args:
        raise ValueError("newt() needs a non-constant expression, got %s" % (func,))
    vari = args[0]
    # Use the Expression method: the global ``diff`` is not imported into
    # this module.
    fprime = func.derivative(vari)
    return vari - func / fprime
77
78
79
class BasinPlot(GraphicPrimitive):
    r"""
    Graphics primitive showing a precomputed grid of Newton-basin labels.

    The grid entries are coloured with :func:`complex_to_rgb` and drawn with
    matplotlib's ``imshow``. It is normally created by :func:`basin_plot`.
    """
    def __init__(self, roots, z_values, xrange, yrange, options):
        r"""
        INPUT:

        - ``roots`` -- the roots whose colours label the grid

        - ``z_values`` -- list of equal-length rows; ``z_values[i][j]`` is the
          label of the point in column `j` (real part) of row `i` (imaginary
          part), as built by :func:`basin_plot`

        - ``xrange``, ``yrange`` -- pairs ``(min, max)`` giving the extent of
          the picture

        - ``options`` -- dictionary of plot options

        EXAMPLES::

            sage: p = basin_plot([1, -1], (-2, 2), (-2, 2))  # not tested
        """
        self.roots = roots
        self.xrange = xrange
        self.yrange = yrange
        self.z_values = z_values
        # One row per imaginary part, so the number of rows counts y samples
        # and the row length counts x samples.
        self.y_count = len(z_values)
        self.x_count = len(z_values[0]) if z_values else 0
        self.rgb_data = complex_to_rgb(roots, z_values)
        GraphicPrimitive.__init__(self, options)

    def get_minmax_data(self):
        r"""
        Return a dictionary with the bounding box data (keys ``xmin``,
        ``xmax``, ``ymin``, ``ymax``), taken from ``xrange`` and ``yrange``.

        EXAMPLES::

            sage: p = basin_plot([1, -1], (-1, 2), (-3, 4))  # not tested
            sage: sorted(p.get_minmax_data().items())  # not tested
        """
        from sage.plot.plot import minmax_data
        return minmax_data(self.xrange, self.yrange, dict=True)

    def _allowed_options(self):
        r"""
        Return a dictionary mapping each option this primitive accepts to a
        short description of it.
        """
        return {'plot_points':'How many points to use for plotting precision',
                'interpolation':'What interpolation method to use'}

    def _repr_(self):
        r"""
        Return a string naming the grid size as ``columns x rows``.
        """
        return "BasinPlot defined by a %s x %s data grid"%(self.x_count, self.y_count)

    def _render_on_subplot(self, subplot):
        r"""
        Draw the coloured grid on a matplotlib subplot with ``imshow``.

        Row 0 of the grid is placed at the bottom (``origin='lower'``), and
        the image is stretched over ``xrange`` x ``yrange``.
        """
        options = self.options()
        x0, x1 = float(self.xrange[0]), float(self.xrange[1])
        y0, y1 = float(self.yrange[0]), float(self.yrange[1])
        subplot.imshow(self.rgb_data, origin='lower', extent=(x0, x1, y0, y1),
                       interpolation=options['interpolation'])
138
139
140
def _nearest_root(z, roots):
    # Return the element of ``roots`` closest to ``z`` (the first one on ties).
    distances = [abs(z - root) for root in roots]
    return roots[distances.index(min(distances))]


def _newton_root(newton_map, roots, z, max_iter=20, tol=2.0/3.0):
    # Classify starting point ``z`` by the root Newton's method approaches.
    #
    # ``newton_map`` performs one Newton step.
    # - Return the first root in ``roots`` within ``tol`` of an iterate.
    # - If ``max_iter`` steps pass without that happening, or a step divides
    #   by zero (a critical point of f), return the root nearest the last
    #   iterate instead.
    for _ in range(max_iter):
        try:
            z = newton_map(z)
        except ZeroDivisionError:
            break
        for root in roots:
            if abs(z - root) < tol:
                return root
    return _nearest_root(z, roots)


@options(plot_points=3, interpolation='nearest')
def basin_plot(roots, xrange, yrange, **options):
    r"""
    Plot the basins of attraction of Newton's method for the polynomial with
    the given roots.

    The polynomial is `f(x) = \prod_k (x - r_k)` over the given ``roots``.
    Each grid point `t + u i` is used as a starting value for Newton's method
    on `f`. It is coloured by the root the iteration approaches: the `k`-th
    root gets the `k`-th colour of ``rainbow(len(roots))``.

    basin_plot(roots, (xmin, xmax), (ymin, ymax), ...)

    INPUT:

    - ``roots`` -- a nonempty list or tuple of complex numbers

    - ``(xmin, xmax)`` -- 2-tuple, the range of real parts

    - ``(ymin, ymax)`` -- 2-tuple, the range of imaginary parts

    The following inputs must all be passed in as named parameters:

    - ``plot_points`` -- integer (default: 3); number of points to plot
      in each direction of the grid

    - ``interpolation`` -- string (default: ``'nearest'``); the interpolation
      method passed to matplotlib's ``imshow``

    OUTPUT: a ``Graphics`` object containing one :class:`BasinPlot`.

    Raises ``ValueError`` if ``roots`` is empty.

    EXAMPLES::

        sage: basin_plot([1, -1], (-2, 2), (-2, 2), plot_points=100)  # not tested
    """
    from sage.plot.plot import Graphics
    from sage.plot.misc import setup_for_eval_on_grid
    from sage.ext.fast_callable import fast_callable
    from sage.rings.complex_double import CDF

    roots = list(roots)
    if not roots:
        raise ValueError("basin_plot() needs at least one root")

    # Build prod_k (x - r_k) by hand; ``prod`` is not imported in this module.
    X = var('x')
    poly = 1
    for root in roots:
        poly = poly * (X - root)
    newton_map = fast_callable(newt(poly), domain=CDF, expect_one_var=True)

    # Early-exit radius: 2/3 as a float literal (``2/3`` on ints is integer
    # division, i.e. 0, so the early exit could never fire). It is capped at
    # half the smallest gap between distinct roots, so that at most one root
    # can lie that close to an iterate.
    tol = 2.0 / 3.0
    for a in range(len(roots)):
        for b in range(a + 1, len(roots)):
            gap = abs(roots[a] - roots[b])
            if 0 < gap < 2 * tol:
                tol = gap / 2.0

    ignore, ranges = setup_for_eval_on_grid([], [xrange, yrange], options['plot_points'])
    xrange, yrange = [r[:2] for r in ranges]

    # One row per imaginary part u, one column per real part t. No
    # sig_on()/sig_off(): this is Python-level code that may raise.
    z_values = [[_newton_root(newton_map, roots, new_CDF_element(t, u), 20, tol)
                 for t in srange(*ranges[0], include_endpoint=True)]
                for u in srange(*ranges[1], include_endpoint=True)]

    g = Graphics()
    g._set_extra_kwds(Graphics._extract_kwds_for_show(options, ignore=['xmin', 'xmax']))
    g.add_primitive(BasinPlot(roots, z_values, xrange, yrange, options))
    return g