Ticket #3813: trac_3813_v2.patch

File trac_3813_v2.patch, 12.5 kB (added by anakha, 5 months ago)
  • a/sage/plot/plot.py

    old new  
    34143414 
    34153415    PLOT OPTIONS: 
    34163416    The plot options are 
    3417         plot_points -- the number of points to initially plot before 
    3418                        doing adaptive refinement 
    3419         plot_division -- the maximum number of subdivisions to 
    3420                          introduce in adaptive refinement. 
    3421         max_bend      -- parameter that affects adaptive refinement 
     3417        plot_points -- (default: 200) the number of points to initially plot 
     3418                       before the adaptive refinement. 
     3419        adaptive_recursion -- (default: 5) how many levels of recursion to go  
     3420                              before giving up when doing adaptive refinement. 
     3421                              Setting this to 0 disables adaptive refinement. 
     3422        adaptive_tolerance -- how large a difference should be before the 
     3423                              adaptive refinement code considers it significant. 
     3424                              This depends on the interval you use by default. 
    34223425        xmin -- starting x value 
    34233426        xmax -- ending x value 
    34243427        color -- an rgb-tuple (r,g,b) with each of r,g,b between 0 and 1, or 
     
    34613464        Graphics object consisting of 1 graphics primitive 
    34623465        sage: len(P)     # number of graphics primitives 
    34633466        1 
    3464         sage: len(P[0])  # how many points were computed 
    3465         200 
     3467        sage: len(P[0])  # how many points were computed (random) 
     3468        224 
    34663469        sage: P          # render 
    34673470         
    34683471        sage: P = plot(sin, (0,10), plot_points=10); print P 
    34693472        Graphics object consisting of 1 graphics primitive 
    34703473        sage: len(P[0])  # random output 
    3471         80 
     3474        32 
    34723475        sage: P          # render 
    34733476 
    34743477    We plot with randomize=False, which makes the initial sample  
    34753478    points evenly spaced (hence always the same).  Adaptive plotting  
    3476     might insert other points, however, unless plot_division=0.  
    3477         sage: p=plot(lambda x: 1, (x,0,3), plot_points=4, randomize=False, plot_division=0) 
     3479    might insert other points, however, unless adaptive_recursion=0.  
     3480        sage: p=plot(lambda x: 1, (x,0,3), plot_points=4, randomize=False, adaptive_recursion=0) 
    34783481        sage: list(p[0]) 
    34793482        [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0), (3.0, 1.0)] 
    34803483 
     
    34883491        sage: plot([sin(n*x) for n in [1..4]], (0, pi)) 
    34893492 
    34903493 
    3491     The function $\sin(1/x)$ wiggles wildly near $0$, so the 
    3492     first plot below won't look perfect.  Sage adapts to this 
    3493     and plots extra points near the origin. 
     3494    The function $\sin(1/x)$ wiggles wildly near $0$.  Sage adapts 
     3495    to this and plots extra points near the origin. 
    34943496        sage: plot(sin(1/x), (x, -1, 1)) 
    3495  
    3496     With the \code{plot_points} option you can increase the number 
    3497     of sample points, to obtain a more accurate plot.  
    3498         sage: plot(sin(1/x), (x, -1, 1), plot_points=1000) 
    34993497 
    35003498    Note that the independent variable may be omitted if there is no 
    35013499    ambiguity: 
    3502         sage: plot(sin(1/x), (-1, 1), plot_points=1000) 
     3500        sage: plot(sin(1/x), (-1, 1)) 
     3501 
     3502    The algorithm used to insert extra points is actually pretty simple. On 
     3503    the picture drawn by the lines below: 
     3504        sage: p = plot(x^2, (-0.5, 1.4)) + line([(0,0), (1,1)], rgbcolor='green') 
     3505        sage: p += line([(0.5, 0.5), (0.5, 0.5^2)], rgbcolor='purple') 
     3506        sage: p += point(((0, 0), (0.5, 0.5), (0.5, 0.5^2), (1, 1)), rgbcolor='red', pointsize=20) 
     3507        sage: p += text('A', (-0.05, 0.1), rgbcolor='red') 
     3508        sage: p += text('B', (1.01, 1.1), rgbcolor='red') 
     3509        sage: p += text('C', (0.48, 0.57), rgbcolor='red') 
     3510        sage: p += text('D', (0.53, 0.18), rgbcolor='red') 
     3511        sage: p.show(axes=False, xmin=-0.5, xmax=1.4, ymin=0, ymax=2) 
     3512     
     3513    You have the function (in blue) and its approximation (in green) passing 
     3514    through the points A and B. The algorithm finds the midpoint C of AB and 
     3515    computes the distance between C and D. The point D is added the curve if 
     3516    it exceeds the (nonzero) adaptive_tolerance threshold. If D is added to 
     3517    the curve, then the algorithm is applied recursively to the points A and D, 
     3518    and D and B. It is repeated adaptive_recursion times (10, by default). 
    35033519 
    35043520    The actual sample points are slightly randomized, so the above 
    35053521    plots may look slightly different each time you draw them. 
     
    35433559    def _reset(self): 
    35443560        o = self.options 
    35453561        o['plot_points'] = 200 
    3546         o['plot_division'] = 1000  
    3547         o['max_bend'] = 0.1        
     3562        o['adaptive_recursion'] = 5 
    35483563        o['rgbcolor'] = (0,0,1)    
    35493564 
    35503565    def __repr__(self): 
    3551         """ 
     3566        r""" 
    35523567        Returns a string representation of this PlotFactory object. 
    35533568 
    35543569        TESTS: 
     
    36123627        plot_points = int(options['plot_points']) 
    36133628        del options['plot_points'] 
    36143629        x, data = var_and_list_of_values(xrange, plot_points) 
    3615         data = list(data) 
    36163630        xmin = data[0] 
    36173631        xmax = data[-1] 
    36183632 
     
    36383652            # randomize is true 
    36393653            if i > 0 and i < plot_points-1: 
    36403654                if randomize: 
    3641                     xi += delta*random() 
    3642                 if xi > xmax: 
    3643                     xi = xmax 
     3655                    xi += delta*(random() - 0.5) 
     3656                # we can't get over xmax (xmax - delta + delta*0.5 = xmax - delta*0.5) 
    36443657            elif i == plot_points-1:  
    36453658                xi = xmax  # guarantee that we get the last point. 
    3646                  
     3659             
    36473660            try: 
    36483661                data[i] = (float(xi), float(f(xi))) 
    3649             except (ZeroDivisionError, TypeError, ValueError,OverflowError), msg: 
     3662                # Only NaN is not equal to NaN 
     3663                if data[i][1] != data[i][1]: 
     3664                    raise ValueError, "nan result" 
     3665            except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
    36503666                sage.misc.misc.verbose("%s\nUnable to compute f(%s)"%(msg, x),1) 
    3651                 exceptions += 1 
    3652                 exception_indices.append(i) 
     3667                if i == 0: 
     3668                    for j in range(1, 99): 
     3669                        xj = xi + delta*j/100.0 
     3670                        try: 
     3671                            data[i] = (float(xj), float(f(xj))) 
     3672                            if data[i][1] != data[i][1]: 
     3673                                raise ValueError, "nan result" 
     3674                            break 
     3675                        except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
     3676                            pass 
     3677                elif i == plot_points-1: 
     3678                    for j in range(1, 99): 
     3679                        xj = xi - delta*j/100.0 
     3680                        try: 
     3681                            data[i] = (float(xj), float(f(xj))) 
     3682                            if data[i][1] != data[i][1]: 
     3683                                raise ValueError, "nan result" 
     3684                            break 
     3685                        except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
     3686                            pass 
     3687                else: 
     3688                    exceptions += 1 
     3689                    exception_indices.append(i) 
    36533690        data = [data[i] for i in range(len(data)) if i not in exception_indices] 
    3654              
     3691         
     3692        if 'plot_division' in options: 
     3693            del options['plot_division'] 
     3694            import warnings 
     3695            warnings.warn("plot_division is deprecated. See adaptive_recursion in the documentation for plot()", DeprecationWarning, stacklevel=3) 
    36553696        # adaptive refinement 
    36563697        i, j = 0, 0 
    3657         max_bend = float(options['max_bend']) 
    3658         del options['max_bend'] 
    3659         plot_division = int(options['plot_division']) 
    3660         del options['plot_division'] 
     3698        if 'adaptive_tolerance' in options: 
     3699            adaptive_tolerance = float(options['adaptive_tolerance']) 
     3700            del options['adaptive_tolerance'] 
     3701        else: 
     3702            adaptive_tolerance = delta * 0.01 
     3703        adaptive_recursion = int(options['adaptive_recursion']) 
     3704        del options['adaptive_recursion'] 
     3705         
    36613706        while i < len(data) - 1: 
    3662             if abs(data[i+1][1] - data[i][1]) > max_bend: 
    3663                 x = float((data[i+1][0] + data[i][0])/2) 
    3664                 try: 
    3665                     y = float(f(x)) 
    3666                     data.insert(i+1, (x, y)) 
    3667                 except (ZeroDivisionError, TypeError, ValueError), msg: 
    3668                     sage.misc.misc.verbose("%s\nUnable to compute f(%s)"%(msg, x),1) 
    3669                     exceptions += 1 
    3670                 j += 1 
    3671                 if j > plot_division: 
    3672                     break 
    3673             else: 
     3707            for p in adaptive_refinement(f, data[i], data[i+1],  
     3708                            adaptive_tolerance, adaptive_recursion): 
     3709                data.insert(i+1, p) 
    36743710                i += 1 
    3675  
     3711            i += 1 
     3712         
    36763713        if (len(data) == 0 and exceptions > 0) or exceptions > 10: 
    36773714            sage.misc.misc.verbose("WARNING: When plotting, failed to evaluate function at %s points."%exceptions, level=0) 
    36783715            sage.misc.misc.verbose("Last error message: '%s'"%msg, level=0) 
     
    44494486                g.append(fast_float(f, str(xvar), str(yvar))) 
    44504487             
    44514488    return g, xstep, ystep, xrange, yrange 
     4489 
     4490def adaptive_refinement(f, p1, p2, adaptive_tolerance=0.01, adaptive_recursion=10, level=0): 
     4491    r""" 
     4492    The adaptive refinement algorithm for plotting a function f. See the 
     4493    docstring for plot or PlotFactory for a description of the algorithm. 
     4494 
     4495    INPUT: 
     4496        f -- a function of one variable 
     4497        p1, p2 -- two points to refine between 
     4498        adaptive_recursion -- (default: 10) how many levels of recursion to go  
     4499                              before giving up when doing adaptive refinement. 
     4500                              Setting this to 0 disables adaptive refinement. 
     4501        adaptive_tolerance -- (default 0.01) how large a difference should be  
     4502                              before the adaptive refinement code considers  
     4503                              it significant. 
     4504     
     4505    OUTPUT: 
     4506        list -- a list of points to insert between p1 and p2 to get 
     4507                a better linear approximation between them 
     4508 
     4509    TESTS: 
     4510        sage: from sage.plot.plot import adaptive_refinement 
     4511        sage: adaptive_refinement(sin, (0,0), (pi,0), 0.01, level=10) 
     4512        [] 
     4513        sage: adaptive_refinement(sin, (0,0), (pi,0), 0.01) 
     4514        [(0.125000000000000*pi, 0.38268343236508978), (0.187500000000000*pi, 0.55557023301960218), (0.250000000000000*pi, 0.70710678118654757), (0.312500000000000*pi, 0.83146961230254524), (0.375000000000000*pi, 0.92387953251128674), (0.437500000000000*pi, 0.98078528040323043), (0.500000000000000*pi, 1.0), (0.562500000000000*pi, 0.98078528040323043), (0.625000000000000*pi, 0.92387953251128674), (0.687500000000000*pi, 0.83146961230254546), (0.750000000000000*pi, 0.70710678118654757), (0.812500000000000*pi, 0.55557023301960218), (0.875000000000000*pi, 0.38268343236508989)] 
     4515    """ 
     4516    if level >= adaptive_recursion: 
     4517        return [] 
     4518    x = (p1[0] + p2[0])/2.0 
     4519    try: 
     4520        y = float(f(x)) 
     4521    except (ZeroDivisionError, TypeError, ValueError, OverflowError), msg: 
     4522        sage.misc.misc.verbose("%s\nUnable to compute f(%s)"%(msg, x), 1) 
     4523        # give up for this branch 
     4524        return [] 
     4525    # this distance calculation is not perfect. 
     4526    if abs((p1[1] + p2[1])/2.0 - y) > adaptive_tolerance: 
     4527        return adaptive_refinement(f, p1, (x, y),  
     4528                        adaptive_tolerance=adaptive_tolerance, 
     4529                        adaptive_recursion=adaptive_recursion, 
     4530                        level=level+1) \ 
     4531               + [(x, y)] + \ 
     4532               adaptive_refinement(f, (x, y), p2, 
     4533                        adaptive_tolerance=adaptive_tolerance, 
     4534                        adaptive_recursion=adaptive_recursion, 
     4535                        level=level+1) 
     4536    else: 
     4537        return []