Ticket #3773: sage-3773-part5.patch

File sage-3773-part5.patch, 3.7 KB (added by was, 13 years ago)

addresses another referee remark

  • sage/stats/hmm/chmm.pyx

    # HG changeset patch
    # User William Stein <wstein@gmail.com>
    # Date 1218544182 25200
    # Node ID ac011678b9aaed854de8d339fa469cb0663aafd3
    # Parent  3dd3cb24895105810bc0a01ad697ac5e64635fe6
    #3773 -- fix a bug pointed out by Jkantor that n.baum_wealch(m.sample(10,10)) didn't work.
    
    diff -r 3dd3cb248951 -r ac011678b9aa sage/stats/hmm/chmm.pyx
    a b cdef class GaussianHiddenMarkovModel(Con 
    604604        cdef ghmm_cseq* sqd
    605605        try:
    606606            sqd = to_cseq(seq)
    607         except ValueError:
     607        except RuntimeError:
     608            # no sequences
    608609            return float(0)
    609610        cdef int ret = ghmm_cmodel_likelihood(self.m, sqd, &log_p)
    610611        ghmm_cseq_free(&sqd)
    cdef class GaussianHiddenMarkovModel(Con 
    663664        Baum-Welch algorithm to increase the probability of observing O.
    664665
    665666        INPUT:
    666             training_seqs -- a list of lists of emission symbols; all sequences of
    667                       length 0 are ignored.
     667            training_seqs -- a list of lists of emission symbols, where all sequences of
     668                      length 0 are ignored; or, a list of pairs
     669                            (sample_sequence, weight),
     670                      where sample_sequence is a list or TimeSeries, and weight is
     671                      a positive real number.
    668672            max_iter -- integer or None (default: 10000) maximum number
    669673                      of Baum-Welch steps to take
    670674            log_likehood_cutoff -- positive float or None (default: 0.00001);
    cdef class GaussianHiddenMarkovModel(Con 
    693697
    694698        We train using a list of lists:
    695699            sage: m = hmm.GaussianHiddenMarkovModel([[1,0],[0,1]], [(0,1),(0,2)], [1/2,1/2])
    696             sage: m.baum_welch([[1,2,], [3,2]])
     700            sage: m.baum_welch([[1,2], [3,2]])
    697701            sage: m
    698702            Gaussian Hidden Markov Model with 2 States
    699703            Transition matrix:
    cdef class GaussianHiddenMarkovModel(Con 
    725729        cs.smo      = self.m
    726730        try:
    727731            cs.sqd      = to_cseq(training_seqs)
    728         except ValueError:
    729             # No sequences
     732        except RuntimeError:
     733            # No nonempty sequences
    730734            return
    731735        cs.logp     = <double*> safe_malloc(sizeof(double))
    732736        cs.eps      = log_likelihood_cutoff
    cdef ghmm_cseq* to_cseq(seq) except NULL 
    742746    Return a pointer to a ghmm_cseq C struct.
    743747
    744748    All empty sequences are ignored.  If there are no nonempty
    745     sequences a ValueError is raised, since GHMM doesn't treat
    746     this degenerate case well.
     749    sequences a RuntimeError is raised, since GHMM doesn't treat
     750    this degenerate case well.   If there are any nonpositive
     751    weights, then a ValueError is raised.
    747752    """
    748     if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple)):
     753    if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple, TimeSeries)):
    749754        seq = TimeSeries(seq)
    750755    if isinstance(seq, TimeSeries):
    751756        seq = [(seq,float(1))]
    cdef ghmm_cseq* to_cseq(seq) except NULL 
    767772
    768773    n = len(seq)
    769774    if n == 0:
    770         raise ValueError, "there must be at least one nonempty sequence"
     775        raise RuntimeError, "there must be at least one nonempty sequence"
     776
     777    for _, w in seq:
     778        if w <= 0:
     779            raise ValueError, "each weight must be positive"
     780   
    771781    cdef ghmm_cseq* sqd = <ghmm_cseq*>safe_malloc(sizeof(ghmm_cseq))
    772782    sqd.seq        = <double**>safe_malloc(sizeof(double*) * n)
    773783    sqd.seq_len    = to_int_array([len(v) for v,_ in seq])