# Ticket #10775: streamplot.py
#
# File streamplot.py, 9.5 KB (added by jason, 11 years ago)
"""
Streamline plotting like Mathematica.
Copyright (c) 2011 Tom Flannaghan.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
23
24import numpy
25import pylab
26import matplotlib
27
def streamplot(x, y, u, v, density=1, thickness=None, **kwargs):
    '''Draws streamlines of a vector flow with pylab.

    x and y are 1d arrays defining an *evenly spaced* grid. u and v are
    2d arrays (shape [y,x]) giving velocities; they are read but never
    modified. density controls the closeness of the streamlines - slow
    for large values (int(30*density) must be at least 2, otherwise
    ValueError is raised). thickness can also be specified as a 2d array
    of shape [y,x], which controls line width - slow though. All other
    keyword arguments are passed to both pylab.plot and pylab.arrow.
    Returns None.
    '''

    ## Sanity checks. Convert to arrays first; u and v become float so
    ## that integer input survives the rescaling below. Note asarray()
    ## may return the caller's own array, so nothing below works in place.
    x = numpy.asarray(x)
    y = numpy.asarray(y)
    u = numpy.asarray(u, dtype=float)
    v = numpy.asarray(v, dtype=float)
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    assert u.shape == (len(y), len(x))
    assert v.shape == (len(y), len(x))
    if thickness is not None:
        thickness = numpy.asarray(thickness, dtype=float)
        assert thickness.shape == (len(y), len(x))

    ## Set up some constants - size of the grid used.
    NGX = len(x)
    NGY = len(y)
    ## Constants used to convert between grid index coords and user coords.
    DX = x[1] - x[0]
    DY = y[1] - y[0]
    XOFF = x[0]
    YOFF = y[0]

    ## Now rescale velocity onto a unit square (new arrays, not in place).
    u = u / (x[-1] - x[0])
    v = v / (y[-1] - y[0])
    ## Rescale v so it's in units of NGX.
    v = v * (NGY / float(NGX))
    ## So now s is in units of NGX.

    ## Speed calculation is needed for path integrations
    speed = numpy.sqrt(u*u + v*v)

    ## Blank array: This is the heart of the algorithm. It begins life
    ## zeroed, but is set to one when a streamline passes through each
    ## box. Then streamlines are only allowed to pass through zeroed
    ## boxes. The lower resolution of this grid determines the
    ## approximate spacing between trajectories.
    NB = int(30*density)
    if NB < 2:
        raise ValueError(
            "density too small: int(30*density) must be at least 2, got %r"
            % (density,))
    blank = numpy.zeros((NB, NB))
    ## Constants for conversion between grid-index space and
    ## blank-index space
    bx_spacing = NGX / float(NB - 1)
    by_spacing = NGY / float(NB - 1)
    ## Arc length of one RK4 step, in grid-index units.
    ds = 0.01 * NGX

    def blank_pos(xi, yi):
        ## Takes grid space coords and returns nearest space in
        ## the blank array.
        return int((xi / bx_spacing) + 0.5), int((yi / by_spacing) + 0.5)

    def value_at(a, xi, yi):
        ## Bilinear interpolation of a at grid-index coords (xi, yi).
        ## Raises IndexError whenever the 2x2 stencil would leave a:
        ## without the explicit test, negative coords would silently wrap
        ## to the opposite edge and NaN coords would raise ValueError.
        if not (0 <= xi < a.shape[1] - 1 and 0 <= yi < a.shape[0] - 1):
            raise IndexError("point outside interpolation grid")
        ix = int(xi)
        iy = int(yi)
        a00 = a[iy, ix]
        a01 = a[iy, ix+1]
        a10 = a[iy+1, ix]
        a11 = a[iy+1, ix+1]
        xt = xi - ix
        yt = yi - iy
        a0 = a00*(1-xt) + a01*xt
        a1 = a10*(1-xt) + a11*xt
        return a0*(1-yt) + a1*yt

    def in_domain(xi, yi):
        ## True exactly where value_at can interpolate, so every point
        ## saved on a trajectory can be sampled later (e.g. thickness).
        return 0 <= xi < NGX - 1 and 0 <= yi < NGY - 1

    def rk4_integrate(x0, y0):
        ## RK4 forward and backward trajectories from (x0, y0) in
        ## grid-index coords, stopped by the 'blank array' rules.
        ## Returns (x_traj, y_traj), or None when the combined path is too
        ## short - in which case every blank cell claimed here is released.
        bx_changes = []
        by_changes = []

        def unit_velocity(sign):
            ## Normalised flow (sign=1) or reversed flow (sign=-1).
            def field(xi, yi):
                # float() makes a zero speed raise ZeroDivisionError
                # instead of silently producing inf/nan.
                dt_ds = 1. / float(value_at(speed, xi, yi))
                return (sign*value_at(u, xi, yi)*dt_ds,
                        sign*value_at(v, xi, yi)*dt_ds)
            return field

        def trace(field, stotal):
            ## Integrate from (x0, y0) along field until the path leaves
            ## the domain, hits a used blank cell or the running arc
            ## length stotal exceeds NGX. Returns (stotal, xs, ys).
            xi, yi = x0, y0
            xb, yb = blank_pos(xi, yi)
            xs, ys = [], []
            while in_domain(xi, yi):
                # Time step. First save the point.
                xs.append(xi)
                ys.append(yi)
                # Next, advance one using RK4
                try:
                    k1x, k1y = field(xi, yi)
                    k2x, k2y = field(xi + .5*ds*k1x, yi + .5*ds*k1y)
                    k3x, k3y = field(xi + .5*ds*k2x, yi + .5*ds*k2y)
                    k4x, k4y = field(xi + ds*k3x, yi + ds*k3y)
                except (IndexError, ZeroDivisionError):
                    # An intermediate stage left the domain, or the flow
                    # stagnates here: the streamline ends.
                    break
                xi += ds*(k1x+2*k2x+2*k3x+k4x) / 6.
                yi += ds*(k1y+2*k2y+2*k3y+k4y) / 6.
                # Final position might be out of the domain
                if not in_domain(xi, yi):
                    break
                stotal += ds
                # Entering a new blank cell: claim it, or stop if another
                # streamline already owns it.
                new_xb, new_yb = blank_pos(xi, yi)
                if new_xb != xb or new_yb != yb:
                    if blank[new_yb, new_xb] != 0:
                        break
                    blank[new_yb, new_xb] = 1
                    bx_changes.append(new_xb)
                    by_changes.append(new_yb)
                    xb, yb = new_xb, new_yb
                if stotal > NGX:
                    break
            return stotal, xs, ys

        # The backward trace continues the forward trace's arc length.
        stotal, xf_traj, yf_traj = trace(unit_velocity(1), 0)
        stotal, xb_traj, yb_traj = trace(unit_velocity(-1), stotal)
        # Join the reversed backward half with the forward half; the
        # start point is the first entry of both, so keep it once.
        x_traj = xb_traj[::-1] + xf_traj[1:]
        y_traj = yb_traj[::-1] + yf_traj[1:]

        if stotal > 4*bx_spacing:
            initxb, inityb = blank_pos(x0, y0)
            blank[inityb, initxb] = 1
            return x_traj, y_traj
        for bxi, byi in zip(bx_changes, by_changes):
            blank[byi, bxi] = 0
        return None

    ## A quick function for integrating trajectories if blank==0.
    trajectories = []

    def traj(xb, yb):
        if blank[yb, xb] == 0:
            t = rk4_integrate(xb*bx_spacing, yb*by_spacing)
            if t is not None:
                trajectories.append(t)

    ## Now we build up the trajectory set. I've found it best to look
    ## for blank==0 along the edges first, and work inwards.
    for indent in range(NB // 2):
        for offset in range(NB - 2*indent):
            traj(offset + indent, indent)
            traj(offset + indent, NB - 1 - indent)
            traj(indent, offset + indent)
            traj(NB - 1 - indent, offset + indent)

    ## PLOTTING HERE.
    if thickness is not None:
        max_thickness = thickness.max()
    for t in trajectories:
        # Finally apply the rescale to adjust back to user-coords from
        # grid-index coordinates.
        tx = numpy.array(t[0])*DX + XOFF
        ty = numpy.array(t[1])*DY + YOFF

        # If thickness is specified then break down the line. If not,
        # plot as one trajectory.
        if thickness is not None:
            linewidth = numpy.array(
                [value_at(thickness, xi, yi) for xi, yi in zip(t[0], t[1])],
                dtype=numpy.float32)
            linewidth /= max_thickness
            for i in range(len(tx) - 1):
                pylab.plot(tx[i:i+2], ty[i:i+2], lw=3*linewidth[i], **kwargs)
        else:
            ## More sane and much faster without variable thickness.
            pylab.plot(tx, ty, **kwargs)

        ## Add arrows half way along each trajectory. A direction needs
        ## two points; min() keeps n+1 in range for two-point paths.
        if len(tx) >= 2:
            n = min(len(tx) // 2, len(tx) - 2)
            pylab.arrow(tx[n], ty[n],
                        (tx[n+1] - tx[n])*.01, (ty[n+1] - ty[n])*.01,
                        head_width=DX, **kwargs)
254
def test():
    '''Demo: plot the streamlines of the sample field
    u = -1 - x**2 + y, v = 1 + x - y**2 over [-3, 3] x [-3, 3]
    in figure 1, then show the window.'''
    pylab.figure(1)
    n_points = 100
    xs = numpy.linspace(-3, 3, n_points)
    ys = numpy.linspace(-3, 3, n_points)
    ycol = ys[:, numpy.newaxis]  # column vector, broadcasts against xs
    streamplot(xs, ys,
               -1 - xs**2 + ycol,
               1 + xs - ycol**2,
               color='k', density=1)
    pylab.show()

    ## Variable line width example - very slow, thickness kills speed:
    # pylab.figure(2)
    # pylab.title('Line width proportional to speed')
    # xs = numpy.linspace(-3, 3, 100)
    # ys = numpy.linspace(-3, 3, 200)
    # ycol = ys[:, numpy.newaxis]
    # u = -1 - xs**2 + ycol
    # v = 1 + xs - ycol**2
    # streamplot(xs, ys, u, v, thickness=numpy.sqrt(u*u + v*v), color='k')


if __name__ == "__main__":
    # Running the module directly shows the demo figure.
    test()