# HG changeset patch
# User William Stein <wstein@gmail.com>
# Date 1218544182 25200
# Node ID ac011678b9aaed854de8d339fa469cb0663aafd3
# Parent 3dd3cb24895105810bc0a01ad697ac5e64635fe6
#3773 -- fix a bug pointed out by Jkantor that n.baum_welch(m.sample(10,10)) didn't work.
diff -r 3dd3cb248951 -r ac011678b9aa sage/stats/hmm/chmm.pyx
a
|
b
|
cdef class GaussianHiddenMarkovModel(Con |
604 | 604 | cdef ghmm_cseq* sqd |
605 | 605 | try: |
606 | 606 | sqd = to_cseq(seq) |
607 | | except ValueError: |
| 607 | except RuntimeError: |
| 608 | # no sequences |
608 | 609 | return float(0) |
609 | 610 | cdef int ret = ghmm_cmodel_likelihood(self.m, sqd, &log_p) |
610 | 611 | ghmm_cseq_free(&sqd) |
… |
… |
cdef class GaussianHiddenMarkovModel(Con |
663 | 664 | Baum-Welch algorithm to increase the probability of observing O. |
664 | 665 | |
665 | 666 | INPUT: |
666 | | training_seqs -- a list of lists of emission symbols; all sequences of |
667 | | length 0 are ignored. |
| 667 | training_seqs -- a list of lists of emission symbols, where all sequences of |
| 668 | length 0 are ignored; or, a list of pairs |
| 669 | (sample_sequence, weight), |
| 670 | where sample_sequence is a list or TimeSeries, and weight is |
| 671 | a positive real number. |
668 | 672 | max_iter -- integer or None (default: 10000) maximum number |
669 | 673 | of Baum-Welch steps to take |
670 | 674 | log_likehood_cutoff -- positive float or None (default: 0.00001); |
… |
… |
cdef class GaussianHiddenMarkovModel(Con |
693 | 697 | |
694 | 698 | We train using a list of lists: |
695 | 699 | sage: m = hmm.GaussianHiddenMarkovModel([[1,0],[0,1]], [(0,1),(0,2)], [1/2,1/2]) |
696 | | sage: m.baum_welch([[1,2,], [3,2]]) |
| 700 | sage: m.baum_welch([[1,2], [3,2]]) |
697 | 701 | sage: m |
698 | 702 | Gaussian Hidden Markov Model with 2 States |
699 | 703 | Transition matrix: |
… |
… |
cdef class GaussianHiddenMarkovModel(Con |
725 | 729 | cs.smo = self.m |
726 | 730 | try: |
727 | 731 | cs.sqd = to_cseq(training_seqs) |
728 | | except ValueError: |
729 | | # No sequences |
| 732 | except RuntimeError: |
| 733 | # No nonempty sequences |
730 | 734 | return |
731 | 735 | cs.logp = <double*> safe_malloc(sizeof(double)) |
732 | 736 | cs.eps = log_likelihood_cutoff |
… |
… |
cdef ghmm_cseq* to_cseq(seq) except NULL |
742 | 746 | Return a pointer to a ghmm_cseq C struct. |
743 | 747 | |
744 | 748 | All empty sequences are ignored. If there are no nonempty |
745 | | sequences a ValueError is raised, since GHMM doesn't treat |
746 | | this degenerate case well. |
| 749 | sequences a RuntimeError is raised, since GHMM doesn't treat |
| 750 | this degenerate case well. If there are any nonpositive |
| 751 | weights, then a ValueError is raised. |
747 | 752 | """ |
748 | | if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple)): |
| 753 | if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple, TimeSeries)): |
749 | 754 | seq = TimeSeries(seq) |
750 | 755 | if isinstance(seq, TimeSeries): |
751 | 756 | seq = [(seq,float(1))] |
… |
… |
cdef ghmm_cseq* to_cseq(seq) except NULL |
767 | 772 | |
768 | 773 | n = len(seq) |
769 | 774 | if n == 0: |
770 | | raise ValueError, "there must be at least one nonempty sequence" |
| 775 | raise RuntimeError, "there must be at least one nonempty sequence" |
| 776 | |
| 777 | for _, w in seq: |
| 778 | if w <= 0: |
| 779 | raise ValueError, "each weight must be positive" |
| 780 | |
771 | 781 | cdef ghmm_cseq* sqd = <ghmm_cseq*>safe_malloc(sizeof(ghmm_cseq)) |
772 | 782 | sqd.seq = <double**>safe_malloc(sizeof(double*) * n) |
773 | 783 | sqd.seq_len = to_int_array([len(v) for v,_ in seq]) |