Ticket #3726: sage-3726-part8.patch
| File sage-3726-part8.patch, 3.9 kB (added by was, 5 months ago) |
|---|
-
a/sage/stats/hmm/chmm.pyx
old new 527 527 """ 528 528 529 529 cdef double log_p 530 cdef ghmm_cseq* sqd = to_cseq(seq) 530 cdef ghmm_cseq* sqd 531 try: 532 sqd = to_cseq(seq) 533 except ValueError: 534 return float(0) 531 535 cdef int ret = ghmm_cmodel_likelihood(self.m, sqd, &log_p) 532 536 ghmm_cseq_free(&sqd) 533 537 if ret == -1: … … 566 570 seq = TimeSeries(seq) 567 571 cdef TimeSeries T = seq 568 572 cdef int* path = ghmm_cmodel_viterbi(self.m, T._values, T._length, &log_p) 573 if not path: 574 raise RuntimeError, "sequence can't be built from model" 569 575 cdef Py_ssize_t i 570 576 v = [path[i] for i in range(T._length)] 571 577 sage_free(path) … … 583 589 Baum-Welch algorithm to increase the probability of observing O. 584 590 585 591 INPUT: 586 training_seqs -- a list of lists of emission symbols 592 training_seqs -- a list of lists of emission symbols; all sequences of length 0 are ignored. 587 593 max_iter -- integer or None (default: 10000) maximum number 588 594 of Baum-Welch steps to take 589 595 log_likehood_cutoff -- positive float or None (default: 0.00001); … … 609 615 Emission parameters: 610 616 [(1.0, 0.0001)] 611 617 Initial probabilities: [1.0] 618 619 Training sequences of length 0 are gracefully ignored: 620 sage: m.baum_welch([]) 621 sage: m.baum_welch([([],1)]) 612 622 """ 613 623 cdef ghmm_cmodel_baum_welch_context cs 614 624 615 625 cs.smo = self.m 616 cs.sqd = to_cseq(training_seqs) 626 try: 627 cs.sqd = to_cseq(training_seqs) 628 except ValueError: 629 # No sequences 630 return 617 631 cs.logp = <double*> safe_malloc(sizeof(double)) 618 632 cs.eps = log_likelihood_cutoff 619 633 cs.max_iter = max_iter … … 626 640 cdef ghmm_cseq* to_cseq(seq) except NULL: 627 641 """ 628 642 Return a pointer to a ghmm_cseq C struct. 643 644 All empty sequences are ignored. If there are no nonempty 645 sequences a ValueError is raised, since GHMM doesn't treat 646 this degenerate case well. 629 647 """ 630 648 if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], tuple): 631 649 seq = TimeSeries(seq) … … 645 663 else: 646 664 z = (TimeSeries(z), float(1)) 647 665 seq[i] = z 666 seq = [x for x in seq if len(x[0]) > 0] 648 667 649 668 n = len(seq) 669 if n == 0: 670 raise ValueError, "there must be at least one nonempty sequence" 650 671 cdef ghmm_cseq* sqd = <ghmm_cseq*>safe_malloc(sizeof(ghmm_cseq)) 651 672 sqd.seq = <double**>safe_malloc(sizeof(double*) * n) 652 673 sqd.seq_len = to_int_array([len(v) for v,_ in seq]) -
a/sage/stats/hmm/hmm.pyx
old new 558 558 559 559 path = ghmm_dmodel_viterbi(self.m, O, len(seq), &pathlen, &log_p) 560 560 sage_free(O) 561 if not path: 562 raise RuntimeError, "error computing viterbi path" 561 563 p = [path[i] for i in range(pathlen)] 562 564 sage_free(path) 563 565