Ticket #3726: sage-3726-part8.patch

File sage-3726-part8.patch, 3.9 kB (added by was, 5 months ago)
  • a/sage/stats/hmm/chmm.pyx

    old new  
    527527        """ 
    528528 
    529529        cdef double log_p 
    530         cdef ghmm_cseq* sqd = to_cseq(seq) 
     530        cdef ghmm_cseq* sqd 
     531        try: 
     532            sqd = to_cseq(seq) 
     533        except ValueError: 
     534            return float(0) 
    531535        cdef int ret = ghmm_cmodel_likelihood(self.m, sqd, &log_p) 
    532536        ghmm_cseq_free(&sqd) 
    533537        if ret == -1: 
     
    566570            seq = TimeSeries(seq) 
    567571        cdef TimeSeries T = seq 
    568572        cdef int* path = ghmm_cmodel_viterbi(self.m, T._values, T._length, &log_p) 
     573        if not path: 
     574            raise RuntimeError, "sequence can't be built from model" 
    569575        cdef Py_ssize_t i 
    570576        v = [path[i] for i in range(T._length)] 
    571577        sage_free(path) 
     
    583589        Baum-Welch algorithm to increase the probability of observing O. 
    584590 
    585591        INPUT: 
    586             training_seqs -- a list of lists of emission symbols 
     592            training_seqs -- a list of lists of emission symbols; all sequences of length 0 are ignored. 
    587593            max_iter -- integer or None (default: 10000) maximum number 
    588594                      of Baum-Welch steps to take 
    589595            log_likehood_cutoff -- positive float or None (default: 0.00001); 
     
    609615            Emission parameters: 
    610616            [(1.0, 0.0001)] 
    611617            Initial probabilities: [1.0] 
     618 
     619        Training sequences of length 0 are gracefully ignored: 
     620            sage: m.baum_welch([]) 
     621            sage: m.baum_welch([([],1)]) 
    612622        """ 
    613623        cdef ghmm_cmodel_baum_welch_context cs 
    614624 
    615625        cs.smo      = self.m 
    616         cs.sqd      = to_cseq(training_seqs) 
     626        try: 
     627            cs.sqd      = to_cseq(training_seqs) 
     628        except ValueError: 
     629            # No sequences 
     630            return 
    617631        cs.logp     = <double*> safe_malloc(sizeof(double)) 
    618632        cs.eps      = log_likelihood_cutoff 
    619633        cs.max_iter = max_iter 
     
    626640cdef ghmm_cseq* to_cseq(seq) except NULL: 
    627641    """ 
    628642    Return a pointer to a ghmm_cseq C struct. 
     643 
     644    All empty sequences are ignored.  If there are no nonempty 
     645    sequences a ValueError is raised, since GHMM doesn't treat 
     646    this degenerate case well. 
    629647    """ 
    630648    if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], tuple): 
    631649        seq = TimeSeries(seq) 
     
    645663            else: 
    646664                z = (TimeSeries(z), float(1)) 
    647665        seq[i] = z 
     666    seq = [x for x in seq if len(x[0]) > 0] 
    648667 
    649668    n = len(seq) 
     669    if n == 0: 
     670        raise ValueError, "there must be at least one nonempty sequence" 
    650671    cdef ghmm_cseq* sqd = <ghmm_cseq*>safe_malloc(sizeof(ghmm_cseq)) 
    651672    sqd.seq        = <double**>safe_malloc(sizeof(double*) * n) 
    652673    sqd.seq_len    = to_int_array([len(v) for v,_ in seq]) 
  • a/sage/stats/hmm/hmm.pyx

    old new  
    558558 
    559559        path = ghmm_dmodel_viterbi(self.m, O, len(seq), &pathlen, &log_p) 
    560560        sage_free(O) 
     561        if not path: 
     562            raise RuntimeError, "error computing viterbi path" 
    561563        p = [path[i] for i in range(pathlen)] 
    562564        sage_free(path) 
    563565