Ticket #3773: sage-3773-part5.patch
| File sage-3773-part5.patch, 3.7 kB (added by was, 5 months ago) |
|---|
-
a/sage/stats/hmm/chmm.pyx
old new 604 604 cdef ghmm_cseq* sqd 605 605 try: 606 606 sqd = to_cseq(seq) 607 except ValueError: 607 except RuntimeError: 608 # no sequences 608 609 return float(0) 609 610 cdef int ret = ghmm_cmodel_likelihood(self.m, sqd, &log_p) 610 611 ghmm_cseq_free(&sqd) … … 663 664 Baum-Welch algorithm to increase the probability of observing O. 664 665 665 666 INPUT: 666 training_seqs -- a list of lists of emission symbols; all sequences of 667 length 0 are ignored. 667 training_seqs -- a list of lists of emission symbols, where all sequences of 668 length 0 are ignored; or, a list of pairs 669 (sample_sequence, weight), 670 where sample_sequence is a list or TimeSeries, and weight is 671 a positive real number. 668 672 max_iter -- integer or None (default: 10000) maximum number 669 673 of Baum-Welch steps to take 670 674 log_likehood_cutoff -- positive float or None (default: 0.00001); … … 693 697 694 698 We train using a list of lists: 695 699 sage: m = hmm.GaussianHiddenMarkovModel([[1,0],[0,1]], [(0,1),(0,2)], [1/2,1/2]) 696 sage: m.baum_welch([[1,2 ,], [3,2]])700 sage: m.baum_welch([[1,2], [3,2]]) 697 701 sage: m 698 702 Gaussian Hidden Markov Model with 2 States 699 703 Transition matrix: … … 725 729 cs.smo = self.m 726 730 try: 727 731 cs.sqd = to_cseq(training_seqs) 728 except ValueError:729 # No sequences732 except RuntimeError: 733 # No nonempty sequences 730 734 return 731 735 cs.logp = <double*> safe_malloc(sizeof(double)) 732 736 cs.eps = log_likelihood_cutoff … … 742 746 Return a pointer to a ghmm_cseq C struct. 743 747 744 748 All empty sequences are ignored. If there are no nonempty 745 sequences a ValueError is raised, since GHMM doesn't treat 746 this degenerate case well. 749 sequences a RuntimeError is raised, since GHMM doesn't treat 750 this degenerate case well. If there are any nonpositive 751 weights, then a ValueError is raised. 747 752 """ 748 if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple )):753 if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple, TimeSeries)): 749 754 seq = TimeSeries(seq) 750 755 if isinstance(seq, TimeSeries): 751 756 seq = [(seq,float(1))] … … 767 772 768 773 n = len(seq) 769 774 if n == 0: 770 raise ValueError, "there must be at least one nonempty sequence" 775 raise RuntimeError, "there must be at least one nonempty sequence" 776 777 for _, w in seq: 778 if w <= 0: 779 raise ValueError, "each weight must be positive" 780 771 781 cdef ghmm_cseq* sqd = <ghmm_cseq*>safe_malloc(sizeof(ghmm_cseq)) 772 782 sqd.seq = <double**>safe_malloc(sizeof(double*) * n) 773 783 sqd.seq_len = to_int_array([len(v) for v,_ in seq])