Ticket #3813: 3813-anakha-adaptive-plot-v3.patch

File 3813-anakha-adaptive-plot-v3.patch, 13.5 kB (added by ncalexan, 5 months ago)
  • a/sage/plot/plot.py

    old new  
    34143414 
    34153415    PLOT OPTIONS: 
    34163416    The plot options are 
    3417         plot_points -- the number of points to initially plot before 
    3418                        doing adaptive refinement 
    3419         plot_division -- the maximum number of subdivisions to 
    3420                          introduce in adaptive refinement. 
    3421         max_bend      -- parameter that affects adaptive refinement 
     3417        plot_points -- (default: 200) the number of points to initially plot 
     3418                       before the adaptive refinement. 
     3419        adaptive_recursion -- (default: 5) how many levels of recursion to go  
     3420                              before giving up when doing adaptive refinement. 
     3421                              Setting this to 0 disables adaptive refinement. 
     3422        adaptive_tolerance -- (default: 0.01) how large a difference should be 
     3423                              before the adaptive refinement code considers it 
     3424                              significant.  Use a smaller value for smoother 
     3425                              plots and a larger value for coarser plots.  In 
     3426                              general, adjust adaptive_recursion before 
     3427                              adaptive_tolerance, and consult the code of 
     3428                              adaptive_refinement for the details. 
     3429 
    34223430        xmin -- starting x value 
    34233431        xmax -- ending x value 
    34243432        color -- an rgb-tuple (r,g,b) with each of r,g,b between 0 and 1, or 
     
    34613469        Graphics object consisting of 1 graphics primitive 
    34623470        sage: len(P)     # number of graphics primitives 
    34633471        1 
    3464         sage: len(P[0])  # how many points were computed 
    3465         200 
     3472        sage: len(P[0])  # how many points were computed (random) 
     3473        224 
    34663474        sage: P          # render 
    34673475         
    34683476        sage: P = plot(sin, (0,10), plot_points=10); print P 
    34693477        Graphics object consisting of 1 graphics primitive 
    34703478        sage: len(P[0])  # random output 
    3471         80 
     3479        32 
    34723480        sage: P          # render 
    34733481 
    34743482    We plot with randomize=False, which makes the initial sample  
    34753483    points evenly spaced (hence always the same).  Adaptive plotting  
    3476     might insert other points, however, unless plot_division=0.  
    3477         sage: p=plot(lambda x: 1, (x,0,3), plot_points=4, randomize=False, plot_division=0) 
     3484    might insert other points, however, unless adaptive_recursion=0.  
     3485        sage: p=plot(lambda x: 1, (x,0,3), plot_points=4, randomize=False, adaptive_recursion=0) 
    34783486        sage: list(p[0]) 
    34793487        [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0), (3.0, 1.0)] 
    34803488 
     
    34883496        sage: plot([sin(n*x) for n in [1..4]], (0, pi)) 
    34893497 
    34903498 
    3491     The function $\sin(1/x)$ wiggles wildly near $0$, so the 
    3492     first plot below won't look perfect.  Sage adapts to this 
    3493     and plots extra points near the origin. 
     3499    The function $\sin(1/x)$ wiggles wildly near $0$.  Sage adapts 
     3500    to this and plots extra points near the origin. 
    34943501        sage: plot(sin(1/x), (x, -1, 1)) 
    3495  
    3496     With the \code{plot_points} option you can increase the number 
    3497     of sample points, to obtain a more accurate plot.  
    3498         sage: plot(sin(1/x), (x, -1, 1), plot_points=1000) 
    34993502 
    35003503    Note that the independent variable may be omitted if there is no 
    35013504    ambiguity: 
    3502         sage: plot(sin(1/x), (-1, 1), plot_points=1000) 
     3505        sage: plot(sin(1/x), (-1, 1)) 
     3506 
     3507    The algorithm used to insert extra points is actually pretty simple. On 
     3508    the picture drawn by the lines below: 
     3509        sage: p = plot(x^2, (-0.5, 1.4)) + line([(0,0), (1,1)], rgbcolor='green') 
     3510        sage: p += line([(0.5, 0.5), (0.5, 0.5^2)], rgbcolor='purple') 
     3511        sage: p += point(((0, 0), (0.5, 0.5), (0.5, 0.5^2), (1, 1)), rgbcolor='red', pointsize=20) 
     3512        sage: p += text('A', (-0.05, 0.1), rgbcolor='red') 
     3513        sage: p += text('B', (1.01, 1.1), rgbcolor='red') 
     3514        sage: p += text('C', (0.48, 0.57), rgbcolor='red') 
     3515        sage: p += text('D', (0.53, 0.18), rgbcolor='red') 
     3516        sage: p.show(axes=False, xmin=-0.5, xmax=1.4, ymin=0, ymax=2) 
     3517     
     3518    You have the function (in blue) and its approximation (in green) passing 
     3519    through the points A and B. The algorithm finds the midpoint C of AB and 
     3520    computes the distance between C and D. The point D is added the curve if 
     3521    it exceeds the (nonzero) adaptive_tolerance threshold. If D is added to 
     3522    the curve, then the algorithm is applied recursively to the points A and D, 
     3523    and D and B. It is repeated adaptive_recursion times (10, by default). 
    35033524 
    35043525    The actual sample points are slightly randomized, so the above 
    35053526    plots may look slightly different each time you draw them. 
     
    35433564    def _reset(self): 
    35443565        o = self.options 
    35453566        o['plot_points'] = 200 
    3546         o['plot_division'] = 1000  
    3547         o['max_bend'] = 0.1        
     3567        o['adaptive_tolerance'] = 0.01 
     3568        o['adaptive_recursion'] = 5 
    35483569        o['rgbcolor'] = (0,0,1)    
    35493570 
    35503571    def __repr__(self): 
    3551         """ 
     3572        r""" 
    35523573        Returns a string representation of this PlotFactory object. 
    35533574 
    35543575        TESTS: 
     
    36123633        plot_points = int(options['plot_points']) 
    36133634        del options['plot_points'] 
    36143635        x, data = var_and_list_of_values(xrange, plot_points) 
    3615         data = list(data) 
    36163636        xmin = data[0] 
    36173637        xmax = data[-1] 
    36183638 
     
    36383658            # randomize is true 
    36393659            if i > 0 and i < plot_points-1: 
    36403660                if randomize: 
    3641                     xi += delta*random() 
    3642                 if xi > xmax: 
    3643                     xi = xmax 
     3661                    xi += delta*(random() - 0.5) 
     3662                # we can't get over xmax (xmax - delta + delta*0.5 = xmax - delta*0.5) 
    36443663            elif i == plot_points-1:  
    36453664                xi = xmax  # guarantee that we get the last point. 
    3646                  
     3665             
    36473666            try: 
    36483667                data[i] = (float(xi), float(f(xi))) 
    3649             except (ZeroDivisionError, TypeError, ValueError,OverflowError), msg: 
     3668                # Only NaN is not equal to NaN 
     3669                if data[i][1] != data[i][1]: 
     3670                    raise ValueError, "nan result" 
     3671            except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
    36503672                sage.misc.misc.verbose("%s\nUnable to compute f(%s)"%(msg, x),1) 
    3651                 exceptions += 1 
    3652                 exception_indices.append(i) 
     3673                if i == 0: 
     3674                    for j in range(1, 99): 
     3675                        xj = xi + delta*j/100.0 
     3676                        try: 
     3677                            data[i] = (float(xj), float(f(xj))) 
     3678                            if data[i][1] != data[i][1]: 
     3679                                raise ValueError, "nan result" 
     3680                            break 
     3681                        except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
     3682                            pass 
     3683                elif i == plot_points-1: 
     3684                    for j in range(1, 99): 
     3685                        xj = xi - delta*j/100.0 
     3686                        try: 
     3687                            data[i] = (float(xj), float(f(xj))) 
     3688                            if data[i][1] != data[i][1]: 
     3689                                raise ValueError, "nan result" 
     3690                            break 
     3691                        except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
     3692                            pass 
     3693                else: 
     3694                    exceptions += 1 
     3695                    exception_indices.append(i) 
    36533696        data = [data[i] for i in range(len(data)) if i not in exception_indices] 
    3654              
     3697         
     3698        if 'plot_division' in options: 
     3699            del options['plot_division'] 
     3700            import warnings 
     3701            warnings.warn("plot_division is deprecated. See adaptive_recursion in the documentation for plot()", DeprecationWarning, stacklevel=3) 
    36553702        # adaptive refinement 
    36563703        i, j = 0, 0 
    3657         max_bend = float(options['max_bend']) 
    3658         del options['max_bend'] 
    3659         plot_division = int(options['plot_division']) 
    3660         del options['plot_division'] 
     3704        adaptive_tolerance = delta * float(options['adaptive_tolerance']) 
     3705        del options['adaptive_tolerance'] 
     3706        adaptive_recursion = int(options['adaptive_recursion']) 
     3707        del options['adaptive_recursion'] 
     3708         
    36613709        while i < len(data) - 1: 
    3662             if abs(data[i+1][1] - data[i][1]) > max_bend: 
    3663                 x = float((data[i+1][0] + data[i][0])/2) 
    3664                 try: 
    3665                     y = float(f(x)) 
    3666                     data.insert(i+1, (x, y)) 
    3667                 except (ZeroDivisionError, TypeError, ValueError), msg: 
    3668                     sage.misc.misc.verbose("%s\nUnable to compute f(%s)"%(msg, x),1) 
    3669                     exceptions += 1 
    3670                 j += 1 
    3671                 if j > plot_division: 
    3672                     break 
    3673             else: 
     3710            for p in adaptive_refinement(f, data[i], data[i+1],  
     3711                                         adaptive_tolerance=adaptive_tolerance, 
     3712                                         adaptive_recursion=adaptive_recursion): 
     3713                data.insert(i+1, p) 
    36743714                i += 1 
    3675  
     3715            i += 1 
     3716         
    36763717        if (len(data) == 0 and exceptions > 0) or exceptions > 10: 
    36773718            sage.misc.misc.verbose("WARNING: When plotting, failed to evaluate function at %s points."%exceptions, level=0) 
    36783719            sage.misc.misc.verbose("Last error message: '%s'"%msg, level=0) 
     
    44494490                g.append(fast_float(f, str(xvar), str(yvar))) 
    44504491             
    44514492    return g, xstep, ystep, xrange, yrange 
     4493 
     4494def adaptive_refinement(f, p1, p2, adaptive_tolerance=0.01, adaptive_recursion=10, level=0): 
     4495    r""" 
     4496    The adaptive refinement algorithm for plotting a function f. See the 
     4497    docstring for plot or PlotFactory for a description of the algorithm. 
     4498 
     4499    INPUT: 
     4500        f -- a function of one variable 
     4501        p1, p2 -- two points to refine between 
     4502        adaptive_recursion -- (default: 10) how many levels of recursion to go  
     4503                              before giving up when doing adaptive refinement. 
     4504                              Setting this to 0 disables adaptive refinement. 
     4505        adaptive_tolerance -- (default: 0.01) how large a difference should be 
     4506                              before the adaptive refinement code considers  
     4507                              it significant.  See the documentation for 
     4508                              plot() for more information. 
     4509     
     4510    OUTPUT: 
     4511        list -- a list of points to insert between p1 and p2 to get 
     4512                a better linear approximation between them 
     4513 
     4514    TESTS: 
     4515        sage: from sage.plot.plot import adaptive_refinement 
     4516        sage: adaptive_refinement(sin, (0,0), (pi,0), adaptive_tolerance=0.01, level=10) 
     4517        [] 
     4518        sage: adaptive_refinement(sin, (0,0), (pi,0), adaptive_tolerance=0.01) 
     4519        [(0.125000000000000*pi, 0.38268343236508978), (0.187500000000000*pi, 0.55557023301960218), (0.250000000000000*pi, 0.707106781186547...), (0.312500000000000*pi, 0.83146961230254524), (0.375000000000000*pi, 0.92387953251128674), (0.437500000000000*pi, 0.98078528040323043), (0.500000000000000*pi, 1.0), (0.562500000000000*pi, 0.98078528040323043), (0.625000000000000*pi, 0.92387953251128674), (0.687500000000000*pi, 0.83146961230254546), (0.750000000000000*pi, 0.70710678118654757), (0.812500000000000*pi, 0.55557023301960218), (0.875000000000000*pi, 0.38268343236508989)] 
     4520 
     4521        This shows that lowering adaptive_tolerance and raising 
     4522        adaptive_recursion both increase the number of subdivision points: 
     4523 
     4524        sage: x = var('x') 
     4525        sage: f = sin(1/x) 
     4526        sage: n1 = len(adaptive_refinement(f, (0,0), (pi,0), adaptive_tolerance=0.01)); n1 
     4527        79 
     4528        sage: n2 = len(adaptive_refinement(f, (0,0), (pi,0), adaptive_recursion=5, adaptive_tolerance=0.01)); n2 
     4529        15 
     4530        sage: n1 > n2 
     4531        True 
     4532 
     4533        sage: n3 = len(adaptive_refinement(f, (0,0), (pi,0), adaptive_tolerance=0.005)); n3 
     4534        88 
     4535        sage: n1 < n3 
     4536        True 
     4537    """ 
     4538    if level >= adaptive_recursion: 
     4539        return [] 
     4540    x = (p1[0] + p2[0])/2.0 
     4541    try: 
     4542        y = float(f(x)) 
     4543    except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
     4544        sage.misc.misc.verbose("%s\nUnable to compute f(%s)"%(msg, x), 1) 
     4545        # give up for this branch 
     4546        return [] 
     4547    # this distance calculation is not perfect. 
     4548    if abs((p1[1] + p2[1])/2.0 - y) > adaptive_tolerance: 
     4549        return adaptive_refinement(f, p1, (x, y),  
     4550                        adaptive_tolerance=adaptive_tolerance, 
     4551                        adaptive_recursion=adaptive_recursion, 
     4552                        level=level+1) \ 
     4553               + [(x, y)] + \ 
     4554               adaptive_refinement(f, (x, y), p2, 
     4555                        adaptive_tolerance=adaptive_tolerance, 
     4556                        adaptive_recursion=adaptive_recursion, 
     4557                        level=level+1) 
     4558    else: 
     4559        return []