# HG changeset patch
# User William Stein <wstein@gmail.com>
# Date 1218544182 25200
# Node ID ac011678b9aaed854de8d339fa469cb0663aafd3
# Parent 3dd3cb24895105810bc0a01ad697ac5e64635fe6
#3773  fix a bug pointed out by Jkantor that n.baum_welch(m.sample(10,10)) didn't work.
diff -r 3dd3cb248951 -r ac011678b9aa sage/stats/hmm/chmm.pyx
a

b

cdef class GaussianHiddenMarkovModel(Con 
604  604  cdef ghmm_cseq* sqd 
605  605  try: 
606  606  sqd = to_cseq(seq) 
607   except ValueError: 
 607  except RuntimeError: 
 608  # no sequences 
608  609  return float(0) 
609  610  cdef int ret = ghmm_cmodel_likelihood(self.m, sqd, &log_p) 
610  611  ghmm_cseq_free(&sqd) 
… 
… 
cdef class GaussianHiddenMarkovModel(Con 
663  664  Baum-Welch algorithm to increase the probability of observing O. 
664  665  
665  666  INPUT: 
666   training_seqs  a list of lists of emission symbols; all sequences of 
667   length 0 are ignored. 
 667  training_seqs  a list of lists of emission symbols, where all sequences of 
 668  length 0 are ignored; or, a list of pairs 
 669  (sample_sequence, weight), 
 670  where sample_sequence is a list or TimeSeries, and weight is 
 671  a positive real number. 
668  672  max_iter  integer or None (default: 10000) maximum number 
669  673  of Baum-Welch steps to take 
670  674  log_likehood_cutoff  positive float or None (default: 0.00001); 
… 
… 
cdef class GaussianHiddenMarkovModel(Con 
693  697  
694  698  We train using a list of lists: 
695  699  sage: m = hmm.GaussianHiddenMarkovModel([[1,0],[0,1]], [(0,1),(0,2)], [1/2,1/2]) 
696   sage: m.baum_welch([[1,2,], [3,2]]) 
 700  sage: m.baum_welch([[1,2], [3,2]]) 
697  701  sage: m 
698  702  Gaussian Hidden Markov Model with 2 States 
699  703  Transition matrix: 
… 
… 
cdef class GaussianHiddenMarkovModel(Con 
725  729  cs.smo = self.m 
726  730  try: 
727  731  cs.sqd = to_cseq(training_seqs) 
728   except ValueError: 
729   # No sequences 
 732  except RuntimeError: 
 733  # No nonempty sequences 
730  734  return 
731  735  cs.logp = <double*> safe_malloc(sizeof(double)) 
732  736  cs.eps = log_likelihood_cutoff 
… 
… 
cdef ghmm_cseq* to_cseq(seq) except NULL 
742  746  Return a pointer to a ghmm_cseq C struct. 
743  747  
744  748  All empty sequences are ignored. If there are no nonempty 
745   sequences a ValueError is raised, since GHMM doesn't treat 
746   this degenerate case well. 
 749  sequences a RuntimeError is raised, since GHMM doesn't treat 
 750  this degenerate case well. If there are any nonpositive 
 751  weights, then a ValueError is raised. 
747  752  """ 
748   if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple)): 
 753  if isinstance(seq, list) and len(seq) > 0 and not isinstance(seq[0], (list, tuple, TimeSeries)): 
749  754  seq = TimeSeries(seq) 
750  755  if isinstance(seq, TimeSeries): 
751  756  seq = [(seq,float(1))] 
… 
… 
cdef ghmm_cseq* to_cseq(seq) except NULL 
767  772  
768  773  n = len(seq) 
769  774  if n == 0: 
770   raise ValueError, "there must be at least one nonempty sequence" 
 775  raise RuntimeError, "there must be at least one nonempty sequence" 
 776  
 777  for _, w in seq: 
 778  if w <= 0: 
 779  raise ValueError, "each weight must be positive" 
 780  
771  781  cdef ghmm_cseq* sqd = <ghmm_cseq*>safe_malloc(sizeof(ghmm_cseq)) 
772  782  sqd.seq = <double**>safe_malloc(sizeof(double*) * n) 
773  783  sqd.seq_len = to_int_array([len(v) for v,_ in seq]) 