Ticket #3773: sage-3773-part5.patch

File sage-3773-part5.patch, 3.7 kB (added by was, 5 months ago)

addresses another referee remark

  • a/sage/stats/hmm/chmm.pyx

    old new  
    604604        cdef ghmm_cseq* sqd 
    605605        try: 
    606606            sqd = to_cseq(seq) 
    607         except ValueError: 
     607        except RuntimeError: 
     608            # no nonempty sequences 
    608609            return float(0) 
    609610        cdef int ret = ghmm_cmodel_likelihood(self.m, sqd, &log_p) 
    610611        ghmm_cseq_free(&sqd) 
     
    663664        Baum-Welch algorithm to increase the probability of observing O. 
    664665 
    665666        INPUT: 
    666             training_seqs -- a list of lists of emission symbols; all sequences of 
    667                       length 0 are ignored. 
     667            training_seqs -- a list of lists of emission symbols, where all sequences of 
     668                      length 0 are ignored; or, a list of pairs 
     669                            (sample_sequence, weight), 
     670                      where sample_sequence is a list or TimeSeries, and weight is 
     671                      a positive real number.  
    668672            max_iter -- integer or None (default: 10000) maximum number 
    669673                      of Baum-Welch steps to take 
    670674            log_likelihood_cutoff -- positive float or None (default: 0.00001); 
     
    693697 
    694698        We train using a list of lists: 
    695699            sage: m = hmm.GaussianHiddenMarkovModel([[1,0],[0,1]], [(0,1),(0,2)], [1/2,1/2]) 
    696             sage: m.baum_welch([[1,2,], [3,2]]) 
     700            sage: m.baum_welch([[1,2], [3,2]]) 
    697701            sage: m 
    698702            Gaussian Hidden Markov Model with 2 States 
    699703            Transition matrix: 
     
    725729        cs.smo      = self.m 
    726730        try: 
    727731            cs.sqd      = to_cseq(training_seqs) 
    728         except ValueError: 
    729             # No sequences 
     732        except RuntimeError: 
     733            # No nonempty sequences 
    730734            return 
    731735        cs.logp     = <double*> safe_malloc(sizeof(double)) 
    732736        cs.eps      = log_likelihood_cutoff 
     
    742746    Return a pointer to a ghmm_cseq C struct. 
    743747 
    744748    All empty sequences are ignored.  If there are no nonempty 
    745     sequences a ValueError is raised, since GHMM doesn't treat 
    746     this degenerate case well. 
     749    sequences a RuntimeError is raised, since GHMM doesn't treat 
     750    this degenerate case well.   If there are any nonpositive 
     751    weights, then a ValueError is raised. 
    747752    """ 
    748     if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple)): 
     753    if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple, TimeSeries)): 
    749754        seq = TimeSeries(seq) 
    750755    if isinstance(seq, TimeSeries): 
    751756        seq = [(seq,float(1))] 
     
    767772 
    768773    n = len(seq) 
    769774    if n == 0: 
    770         raise ValueError, "there must be at least one nonempty sequence" 
     775        raise RuntimeError, "there must be at least one nonempty sequence" 
     776 
     777    for _, w in seq: 
     778        if w <= 0: 
     779            raise ValueError, "each weight must be positive" 
     780     
    771781    cdef ghmm_cseq* sqd = <ghmm_cseq*>safe_malloc(sizeof(ghmm_cseq)) 
    772782    sqd.seq        = <double**>safe_malloc(sizeof(double*) * n) 
    773783    sqd.seq_len    = to_int_array([len(v) for v,_ in seq])