Ticket #3971: sage-3971.patch

File sage-3971.patch, 4.0 kB (added by was, 4 months ago)
  • a/sage/stats/hmm/hmm.pxd

    old new  
    1212    cdef int GHMM_kSilentStates 
    1313    cdef int GHMM_kHigherOrderEmissions 
    1414    cdef int GHMM_kDiscreteHMM 
     15    cdef int GHMM_EPS_ITER_BW 
     16    cdef int GHMM_MAX_ITER_BW 
    1517 
    1618cdef extern from "ghmm/model.h": 
    1719    ctypedef struct ghmm_dstate: 
     
    170172 
    171173    # Problem 3: Learning 
    172174    int ghmm_dmodel_baum_welch (ghmm_dmodel *m, ghmm_dseq *sq) 
    173      
     175    int ghmm_dmodel_baum_welch_nstep (ghmm_dmodel *m, ghmm_dseq *sq, 
     176                                int nsteps, double log_likelihood_cutoff) 
    174177 
    175178cdef class HiddenMarkovModel: 
    176179    cdef Matrix_real_double_dense A, B 
  • a/sage/stats/hmm/hmm.pyx

    old new  
    767767            [0.333333333333 0.666666666667] 
    768768            Initial probabilities: [0.5, 0.5] 
    769769 
     770        We compare using a non-default number of steps and non-default log likelihood cutoff: 
     771            sage: h = hmm.DiscreteHiddenMarkovModel([[.1,.9],[.4,.6]], [[1/2]*2]*2, [1/2]*2) 
     772            sage: h.baum_welch([1,1,1,1,1,0,1,1,1,1,1,0,0]) 
     773            sage: h 
     774            Discrete Hidden Markov Model with 2 States and 2 Emissions 
     775            Transition matrix: 
     776            [  0.643888431046   0.356111568954] 
     777            [0.00232031442167   0.997679685578] 
     778            Emission matrix: 
     779            [           0.0            1.0] 
     780            [0.296080269308 0.703919730692] 
     781            Initial probabilities: [1.0, 3.6620347292312443e-20] 
     782            sage: h = hmm.DiscreteHiddenMarkovModel([[.1,.9],[.4,.6]], [[1/2]*2]*2, [1/2]*2) 
     783            sage: h.baum_welch([1,1,1,1,1,0,1,1,1,1,1,0,0], nsteps=1) 
     784            sage: h 
     785            Discrete Hidden Markov Model with 2 States and 2 Emissions 
     786            Transition matrix: 
     787            [0.1 0.9] 
     788            [0.4 0.6] 
     789            Emission matrix: 
     790            [0.222426510432 0.777573489568] 
     791            [0.234678486789 0.765321513211] 
     792            Initial probabilities: [0.5, 0.5] 
     793            sage: h = hmm.DiscreteHiddenMarkovModel([[.1,.9],[.4,.6]], [[1/2]*2]*2, [1/2]*2) 
     794            sage: h.baum_welch([1,1,1,1,1,0,1,1,1,1,1,0,0], log_likelihood_cutoff=0.01) 
     795            sage: h 
     796            Discrete Hidden Markov Model with 2 States and 2 Emissions 
     797            Transition matrix: 
     798            [0.0999471960754  0.900052803925] 
     799            [ 0.399248483887  0.600751516113] 
     800            Emission matrix: 
     801            [0.209328925617 0.790671074383] 
     802            [ 0.24081056723  0.75918943277] 
     803            Initial probabilities: [0.50877982192958637, 0.49122017807041363]         
     804 
    770805        TESTS: 
    771806        We test training with non-default string symbols: 
    772807            sage: a = hmm.DiscreteHiddenMarkovModel([[0.5,0.5],[0.5,0.5]], [[0.5,0.5],[0.5,0.5]], [0.5,0.5], ['up','down']) 
     
    799834             
    800835        cdef ghmm_dseq* d = malloc_ghmm_dseq(seqs) 
    801836         
    802         if ghmm_dmodel_baum_welch(self.m, d): 
    803             raise RuntimeError, "error running Baum-Welch algorithm" 
     837        if nsteps or log_likelihood_cutoff: 
     838            if nsteps is None: 
     839                nsteps = GHMM_MAX_ITER_BW 
     840            if log_likelihood_cutoff is None: 
     841                log_likelihood_cutoff = GHMM_EPS_ITER_BW  
     842            if ghmm_dmodel_baum_welch_nstep(self.m, d, nsteps, log_likelihood_cutoff): 
     843                raise RuntimeError, "error running Baum-Welch algorithm" 
     844        else: 
     845            if ghmm_dmodel_baum_welch(self.m, d): 
     846                raise RuntimeError, "error running Baum-Welch algorithm" 
    804847         
    805848        ghmm_dseq_free(&d) 
    806849