Ticket #3773: sage-3773-part2.patch

File sage-3773-part2.patch, 2.1 KB (added by was, 13 years ago)
  • sage/stats/hmm/hmm.pyx

    # HG changeset patch
    # User William Stein <wstein@gmail.com>
    # Date 1217967177 25200
    # Node ID 83635164ac474bbfc37e21e843bc62b41216d0e1
    # Parent  7df8dba964e0131a58cf2d15d64e786dc85c4114
    fix some small hmm bugs.
    
    diff -r 7df8dba964e0 -r 83635164ac47 sage/stats/hmm/hmm.pyx
    a b cdef class DiscreteHiddenMarkovModel(Hid 
    544544            sage: a.log_likelihood([1,1])
    545545            -1.3862943611198906
    546546        """
     547        if self._emission_symbols_dict:
     548            seq = [self._emission_symbols_dict[z] for z in seq]
    547549        cdef double log_p
    548550        cdef int* O = to_int_array(seq)
    549551        cdef int ret = ghmm_dmodel_logp(self.m, O, len(seq), &log_p)
    cdef class DiscreteHiddenMarkovModel(Hid 
    587589            sage: a.viterbi([3/4, 'abc', 'abc'] + [3/4]*10)
    588590            ([0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], -25.299405845367794)
    589591        """
     592        if len(seq) == 0:
     593            return [], 0.0
    590594        if self._emission_symbols_dict:
    591595            seq = [self._emission_symbols_dict[z] for z in seq]
    592596        cdef int* path
    cdef class DiscreteHiddenMarkovModel(Hid 
    615619        Baum-Welch algorithm to increase the probability of observing O.
    616620
    617621        INPUT:
    618             training_seqs -- a list of lists of emission symbols
     622            training_seqs -- a list of lists of emission symbols (or a single list)
    619623            nsteps -- integer or None (default: None) maximum number
    620624                      of Baum-Welch steps to take
    621625            log_likehood_cutoff -- positive float or None (default:
    cdef class DiscreteHiddenMarkovModel(Hid 
    675679            Applications in Speech Recognition"', Proceedings of the IEEE,
    676680            77, no 2, 1989, pp 257--285.
    677681        """
     682        if len(training_seqs) > 0 and not isinstance(training_seqs[0], (list, tuple)):
     683            training_seqs = [training_seqs]
     684           
    678685        if self._emission_symbols_dict:
    679686            seqs = [[self._emission_symbols_dict[z] for z in x] for x in training_seqs]
    680687        else: