source: sage/dsage/scripts/dsage_worker.py @ 3830:387049c1b87d

Revision 3830:387049c1b87d, 24.3 KB checked in by Yi Qiang <yqiang@…>, 6 years ago (diff)

Use sage.restart versus hard restarting now.

  • Property exe set to *
Line 
1#!/usr/bin/env python
2############################################################################
3#                                                                     
4#   DSAGE: Distributed SAGE                     
5#                                                                             
6#       Copyright (C) 2006, 2007 Yi Qiang <yqiang@gmail.com>               
7#                                                                           
8#  Distributed under the terms of the GNU General Public License (GPL)       
9#
10#    This code is distributed in the hope that it will be useful,
11#    but WITHOUT ANY WARRANTY; without even the implied warranty of
12#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13#    General Public License for more details.
14#
15#  The full text of the GPL is available at:
16#
17#                  http://www.gnu.org/licenses/
18#
19############################################################################
20
21import sys
22import os
23import ConfigParser
24import uuid
25import cPickle
26import zlib
27
28from twisted.spread import pb
29from twisted.internet import reactor, defer, error, task
30from twisted.python import log
31
32from sage.interfaces.sage0 import Sage
33from sage.misc.preparser import preparse_file
34
35from sage.dsage.database.job import Job, expand_job
36from sage.dsage.misc.hostinfo import HostInfo, ClassicHostInfo
37from sage.dsage.errors.exceptions import NoJobException
38from sage.dsage.twisted.pb import PBClientFactory
39from sage.dsage.misc.constants import delimiter as DELIMITER
40
41pb.setUnjellyableForClass(HostInfo, HostInfo)
42
43DSAGE_DIR = os.path.join(os.getenv('DOT_SAGE'), 'dsage')
44
45# Begin reading configuration
46try:
47    CONF_FILE = os.path.join(DSAGE_DIR, 'worker.conf')
48    CONFIG = ConfigParser.ConfigParser()
49    CONFIG.read(CONF_FILE)
50   
51    LOG_FILE = CONFIG.get('log', 'log_file')
52    LOG_LEVEL = CONFIG.getint('log','log_level')
53    SSL = CONFIG.getint('ssl', 'ssl')
54    WORKERS = CONFIG.getint('general', 'workers')
55    SERVER = CONFIG.get('general', 'server')
56    PORT = CONFIG.getint('general', 'port')
57    DELAY = CONFIG.getint('general', 'delay')
58    NICE_LEVEL = CONFIG.getint('general', 'nice_level')
59    AUTHENTICATE = CONFIG.getboolean('general', 'authenticate')
60except Exception, msg:
61    print msg
62    print "Error reading %s, please fix manually or run dsage.setup()" % \
63    CONF_FILE
64    sys.exit(-1)
65# End reading configuration
66
67# OUTPUT MARKERS shared by Worker and Monitor
68START_MARKER = '___BEGIN___'
69END_MARKER = '___END___'
70
71def unpickle(pickled_job):
72    return cPickle.loads(zlib.decompress(pickled_job))
73   
74class Worker(object):
75    r"""
76    This class represents a worker object that does the actual calculation.
77   
78    Parameters:
79    remoteobj -- reference to the remote PB server
80   
81    """
82   
83    def __init__(self, remoteobj, id):
84        self.remoteobj = remoteobj
85        self.id = id
86        self.free = True
87        self.job = None
88       
89        if LOG_LEVEL > 3:
90            self.sage = Sage(logfile=DSAGE_DIR + '/%s-pexpect.log'\
91                             % self.id)
92        else:
93            self.sage = Sage()
94           
95        # import some basic modules into our Sage() instance
96        self.sage.eval('import time')
97        self.sage.eval('import sys')
98        self.sage.eval('import os')
99       
100        # Initialize getting of jobs
101        self.get_job()
102
103    def get_job(self):
104        try:
105            if LOG_LEVEL > 3:
106                log.msg('Worker %s: Getting job...' % (self.id))
107            d = self.remoteobj.callRemote('get_job')
108        except Exception, msg:
109            log.msg(msg)
110            log.msg('[Worker: %s, get_job] Disconnected from remote server.'\
111                    % self.id)
112            reactor.callLater(DELAY, self.get_job)
113            return
114        d.addCallback(self.gotJob)
115        d.addErrback(self.noJob)
116       
117        return d
118   
119    def gotJob(self, jdict):
120        r"""
121        gotJob is a callback for the remoteobj's get_job method.
122       
123        Parameters:
124        job -- Job object returned by remote's 'get_job' method
125       
126        """
127       
128        if LOG_LEVEL > 3:
129            log.msg('[Worker %s, gotJob] %s' % (self.id, jdict))
130           
131        self.job = expand_job(jdict)
132       
133        if not isinstance(self.job, Job):
134            raise NoJobException
135       
136        log.msg('[Worker: %s] Got job (%s, %s)' % (self.id,
137                                                   self.job.name, 
138                                                   self.job.id))
139        try:
140            self.doJob(self.job)
141        except Exception, e:
142            print e
143            raise
144   
145    def job_done(self, output, result, completed, worker_info):
146        r"""
147        job_done is a callback for doJob.  Called when a job completes.
148       
149        Parameters:
150        output -- the output of the command
151        result -- the result of processing the job, a pickled object
152        completed -- whether or not the job is completely finished (bool)
153        worker_info -- user@host, os.uname() (tuple)
154       
155        """
156       
157        try:
158            d = self.remoteobj.callRemote('job_done',
159                                          self.job.id,
160                                          output,
161                                          result,
162                                          completed,
163                                          worker_info)
164        except Exception, msg:
165            log.msg(msg)
166            log.msg('[Worker: %s, job_done] Disconnected, reconnecting in %s'\
167                    % (self.id, DELAY))
168            reactor.callLater(DELAY, self.job_done, output, 
169                              result, completed, worker_info)
170            d = defer.Deferred()
171            d.errback(error.ConnectionLost())
172            return d
173       
174        if completed:
175            self.restart()
176       
177        return d
178   
179    def noJob(self, failure):
180        # TODO: Probably do not need this errback, look into consolidating
181        # with failedJob
182        r"""
183        noJob is an errback that catches the NoJobException.
184       
185        Parameters:
186        failure -- a twisted.python.failure object (twisted.python.failure)
187       
188        """
189       
190        sleep_time = 5.0
191        if failure.check(NoJobException):
192            if LOG_LEVEL > 3:
193                log.msg('[Worker %s, noJob] Sleeping for %s seconds\
194                ' % (self.id, sleep_time))
195            reactor.callLater(5.0, self.get_job)
196        else:
197            print "Error: ", failure.getErrorMessage()
198            print "Traceback: ", failure.printTraceback()
199   
200    def setupTempDir(self, job):
201        # change to a temporary directory
202        cur_dir = os.getcwd() # keep a reference to the current directory
203        tmp_dir = os.path.join(DSAGE_DIR, 'tmp_worker_files')
204        tmp_job_dir = os.path.join(tmp_dir, job.id)
205        self.tmp_job_dir = tmp_job_dir
206        if not os.path.isdir(tmp_dir):
207            os.mkdir(tmp_dir)
208        os.mkdir(tmp_job_dir)
209        os.chdir(tmp_job_dir)
210        self.sage.eval("os.chdir('%s')" % tmp_job_dir)
211       
212        return tmp_job_dir
213       
214    def extractJobData(self, job):
215        r"""
216        Extracts all the data that is in a job object.
217       
218        """
219       
220        if isinstance(job.data, list):
221            if LOG_LEVEL > 2:
222                log.msg('Extracting job data...')
223            for var, data, kind in job.data:
224                # Uncompress data
225                try:
226                    data = zlib.decompress(data)
227                except Exception, msg:
228                    log.msg(msg)
229                    continue
230                if kind == 'file':
231                    # Write out files to current dir
232                    f = open(var, 'wb')
233                    f.write(data)
234                    if LOG_LEVEL > 2:
235                        log.msg('[Worker: %s] Extracted %s. ' % (self.id, f))
236                if kind == 'object':
237                    # Load object into the SAGE worker
238                    fname = var + '.sobj'
239                    if LOG_LEVEL > 3:
240                        log.msg('Object to be loaded: %s' % fname)
241                    f = open(fname, 'wb')
242                    f.write(data)
243                    f.close()
244                    self.sage.eval("%s = load('%s')" % (var, fname))
245                    if LOG_LEVEL > 2:
246                        log.msg('[Worker: %s] Loaded %s' % (self.id, fname))
247
248    def writeJobFile(self, job):
249        r"""
250        Writes out the job file to be executed to disk.
251       
252        """
253        parsed_file = preparse_file(job.code, magic=False,
254                                    do_time=False, ignore_prompts=False)
255
256        job_filename = str(job.name) + '.py'
257        job_file = open(job_filename, 'w')
258        BEGIN = "print '%s'\n\n" % (START_MARKER)
259        END = "print '%s'\n\n" % (END_MARKER)
260        job_file.write(BEGIN)
261        job_file.write(parsed_file)
262        job_file.write("\n\n")
263        job_file.write(END)
264        job_file.close()
265       
266        if LOG_LEVEL > 2:
267            log.msg('[Worker: %s] Wrote job file. ' % (self.id))
268           
269        return job_filename
270       
271    def doJob(self, job):
272        r"""
273        doJob is the method that drives the execution of a job.
274       
275        Parameters:
276        job -- a Job object (dsage.database.Job)
277       
278        """
279       
280        if LOG_LEVEL > 3:
281            log.msg('[Worker %s, doJob] Beginning job execution...' % (self.id))
282        self.free = False
283        d = defer.Deferred()
284       
285        tmp_job_dir = self.setupTempDir(job)
286        self.extractJobData(job)
287       
288        job_filename = self.writeJobFile(job)
289
290        f = os.path.join(tmp_job_dir, job_filename)
291        self.sage._send("execfile('%s')" % (f))
292        if LOG_LEVEL > 2:
293            log.msg('[Worker: %s] File to execute: %s' % (self.id, f))
294        if LOG_LEVEL > 3:
295            log.msg('[Worker: %s] Called sage._send()' % (self.id))
296       
297        d.callback(True)
298       
299        return d
300   
301    def stop(self):
302        r"""
303        stop() kills the current running job.
304           
305        """
306       
307        self.sage.reset()
308        self.free = True
309        self.job = None
310   
311    def start(self):
312        if self.sage is None:
313            if LOG_LEVEL > 3:
314                self.sage = Sage(logfile=DSAGE_DIR + '/%s-pexpect.out' % self.id)
315            else:
316                self.sage = Sage()
317        self.get_job()
318   
319    def restart(self):
320        r"""
321        Restarts the current worker.
322       
323        """
324       
325        log.msg('[Worker: %s] Restarting...' % (self.id))
326        self.stop()
327        self.start()
328
329class Monitor(object):
330    r"""
331    This class represents a monitor that controls workers.
332   
333    It monitors the workers and checks on their status
334   
335    Parameters:
336    hostname -- the hostname of the server we want to connect to (str)
337    port -- the port of the server we want to connect to (int)
338   
339    """
340   
341    def __init__(self, hostname, port):
342        if hostname is None:
343            hostname = SERVER
344        self.hostname = hostname
345        if port is None:
346            port = PORT
347        self.port = port
348        self.remoteobj = None
349        self.connected = False
350        self.reconnecting = False
351        self.workers = None
352       
353        # Start twisted logging facility
354        self._startLogging(LOG_FILE)
355       
356        if len(CONFIG.get('uuid', 'id')) != 36:
357            CONFIG.set('uuid', 'id', str(uuid.uuid1()))
358            f = open(CONF_FILE, 'w')
359            CONFIG.write(f)
360        self.uuid = CONFIG.get('uuid', 'id')
361       
362        self.host_info = ClassicHostInfo().host_info
363        self.host_info['uuid'] = self.uuid
364        self.host_info['workers'] = WORKERS
365       
366        if AUTHENTICATE:
367            from twisted.cred import credentials
368            from twisted.conch.ssh import keys
369            self._get_auth_info()
370            # public key authentication information
371            self.pubkey_str =keys.getPublicKeyString(filename=self.pubkey_file)
372            # try getting the private key object without a passphrase first
373            try:
374                self.priv_key = keys.getPrivateKeyObject(
375                                    filename=self.privkey_file)
376            except keys.BadKeyError:
377                passphrase = self._getpassphrase()
378                self.priv_key = keys.getPrivateKeyObject(
379                                filename=self.privkey_file,
380                                passphrase=passphrase)
381            self.pub_key = keys.getPublicKeyObject(self.pubkey_str)
382            self.alg_name = 'rsa'
383            self.blob = keys.makePublicKeyBlob(self.pub_key)
384            self.data = self.DATA
385            self.signature = keys.signData(self.priv_key, self.data)
386            self.creds = credentials.SSHPrivateKey(self.username,
387                                                   self.alg_name,
388                                                   self.blob, 
389                                                   self.data,
390                                                   self.signature)
391   
392    def _startLogging(self, log_file):
393        if log_file == 'stdout':
394            log.startLogging(sys.stdout)
395        else:
396            print "Logging to file: ", log_file
397            server_log = open(log_file, 'a')
398            log.startLogging(server_log)
399
400    def _get_auth_info(self):
401        import random
402        self.DATA =  ''.join([chr(i) for i in [random.randint(65, 123) for n in
403                        range(500)]])
404        self.DSAGE_DIR = os.path.join(os.getenv('DOT_SAGE'), 'dsage')
405        # Begin reading configuration
406        try:
407            conf_file = os.path.join(self.DSAGE_DIR, 'client.conf')
408            config = ConfigParser.ConfigParser()
409            config.read(conf_file)
410           
411            self.port = config.getint('general', 'port')
412            self.username = config.get('auth', 'username')
413            self.privkey_file = os.path.expanduser(config.get('auth',
414                                                              'privkey_file'))
415            self.pubkey_file = os.path.expanduser(config.get('auth',
416                                                             'pubkey_file'))
417        except Exception, msg:
418            print msg
419            raise
420   
421    def _getpassphrase(self):
422        import getpass
423        passphrase = getpass.getpass('Passphrase (Hit enter for None): ')
424       
425        return passphrase
426       
427    def _connected(self, remoteobj):
428        self.remoteobj = remoteobj
429        self.remoteobj.notifyOnDisconnect(self._disconnected)
430        self.connected = True
431        self.reconnecting = False
432       
433        if self.workers == None: # Only pool workers the first time
434            self.poolWorkers(self.remoteobj)
435        else:
436            for worker in self.workers:
437                worker.remoteobj = self.remoteobj # Update workers
438        # self.submit_host_info()
439   
440    def _disconnected(self, remoteobj):
441        log.msg('Lost connection to the server.')
442        self.connected = False
443        self._retryConnect()
444   
445    def _gotKilledJobsList(self, killed_jobs):
446        if killed_jobs == None:
447            return
448        # reconstruct the Job objects from the jdicts
449        killed_jobs = [expand_job(jdict) for jdict in killed_jobs]
450        for worker in self.workers:
451            if worker.job is None:
452                continue
453            if worker.free:
454                continue
455            for job in killed_jobs:
456                if job is None or worker.job is None:
457                    continue
458                if worker.job.id == job.id:
459                    log.msg('[Worker: %s] Processing a killed job, \
460                            restarting...' % worker.id)
461                    worker.restart()
462   
463    def _retryConnect(self):
464        log.msg('[Monitor] Disconnected, reconnecting in %s' % DELAY)
465        reactor.callLater(DELAY, self.connect)
466   
467    def _catchConnectionFailure(self, failure):
468        # If we lost the connection to the server somehow
469        # if failure.check(error.ConnectionRefusedError,
470        #                 error.ConnectionLost,
471        #                 pb.DeadReferenceError):
472       
473        self.connected = False
474        self._retryConnect()
475       
476        log.msg("Error: ", failure.getErrorMessage())
477        log.msg("Traceback: ", failure.printTraceback())
478   
479    def _catchFailure(self, failure):
480        log.msg("Error: ", failure.getErrorMessage())
481        log.msg("Traceback: ", failure.printTraceback())
482       
483    def connect(self):
484        r"""
485        This method connects the monitor to a remote PB server.
486       
487        """
488        if self.connected: # Don't connect multiple times
489            return
490   
491        factory = pb.PBClientFactory()
492       
493        log.msg(DELIMITER)
494        log.msg('DSAGE Worker')
495        log.msg('Connecting to %s:%s' % (self.hostname, self.port))
496        log.msg(DELIMITER)
497       
498        self.factory = PBClientFactory()
499        if SSL == 1:
500            from twisted.internet import ssl
501            contextFactory = ssl.ClientContextFactory()
502            reactor.connectSSL(self.hostname, self.port,
503                               self.factory, contextFactory) 
504        else:
505            reactor.connectTCP(self.hostname, self.port, self.factory)
506       
507        if AUTHENTICATE:
508            log.msg('Connecting as authenticated worker...\n')
509            d = self.factory.login(self.creds, (pb.Referenceable(), self.host_info))
510        else:
511            log.msg('Connecting as anonymous worker...\n')
512            d = self.factory.login('Anonymous', (pb.Referenceable(), self.host_info))
513        d.addCallback(self._connected)
514        d.addErrback(self._catchConnectionFailure)
515           
516        return d
517   
518    def poolWorkers(self, remoteobj):
519        r"""
520        poolWorkers creates as many workers as specified in worker.conf.
521       
522        """
523       
524        self.workers = [Worker(remoteobj, x) for x in range(WORKERS)]
525        log.msg('[Monitor] Initialized ' + str(WORKERS) + ' workers.')
526   
527    def checkForJobOutput(self):
528        r"""
529        checkForJobOutput periodically polls workers for new output.
530       
531        This figures out whether or not there is anything new output that we
532        should submit to the server.
533       
534        """
535
536        if self.workers == None:
537            return
538       
539        for worker in self.workers:
540            if worker.job == None:
541                continue
542            if worker.free == True:
543                continue
544           
545            if LOG_LEVEL > 1:
546                log.msg('[Monitor] Checking for job output')
547            try:
548                done, output, new = worker.sage._so_far()
549            except Exception, msg:
550                log.msg(msg)
551                continue
552            if new == '' or new.isspace():
553                continue
554            if done:
555                # Checks to see if the job created a result var
556                sobj = worker.sage.get('DSAGE_RESULT')
557                if sobj == '' or sobj.isspace():
558                    sobj = worker.sage.get('DSAGE_RESULT')
559                    if sobj == '' or sobj.isspace():
560                        sobj = worker.sage.get('DSAGE_RESULT')
561                    else:
562                        if LOG_LEVEL > 1:
563                            log.msg('Got DSAGE_RESULT second time')
564               
565                # DSAGE_RESULT does not exist
566                if 'Error: name \'DSAGE_RESULT\' is not defined' in sobj:
567                    if LOG_LEVEL > 1:
568                        log.msg('DSAGE_RESULT does not exist')
569                    result = cPickle.dumps('No result saved.', 2)
570                else:
571                    os.chdir(worker.tmp_job_dir)
572                    try:
573                        result = open(sobj, 'rb').read()
574                    except Exception, msg:
575                        if LOG_LEVEL > 1:
576                            log.msg(msg)
577                        result = cPickle.dumps('Error in reading result.', 2)
578                worker.free = True
579                log.msg("Job '%s' finished" % worker.job.name)
580            else:
581                result = cPickle.dumps('Job not done yet.', 2)
582           
583            worker_info = (os.getenv('USER') + '@' + os.uname()[1],
584                           ' '.join(os.uname()[2:]))
585                           
586            sanitized_output = self.sanitizeOutput(new)
587           
588            if self.checkOutputForFailure(sanitized_output):
589                s = ['[Monitor] Error in result for ',
590                     'job %s %s done by ' % (worker.job.name, worker.job.id),
591                     'Worker %s' % worker.id
592                     ]
593                log.msg(''.join(s))
594                log.msg('[Monitor] Traceback: \n%s' % sanitized_output)
595                d = self.remoteobj.callRemote('job_failed', worker.job.id)
596               
597            d = worker.job_done(sanitized_output, result, done, worker_info)
598            d.addErrback(self._catchConnectionFailure)
599   
600    def checkOutputForFailure(self, sage_output):
601        if sage_output == None:
602            return False
603        else:
604            sage_output = ''.join(sage_output)
605       
606        if 'Traceback' in sage_output:
607            return True
608        elif 'Error' in sage_output:
609            return True
610        else:
611            return False
612   
613    def checkForKilledJobs(self):
614        r"""
615        checkForKilledJobs retrieves a list of killed job ids.
616       
617        """
618       
619        if not self.connected:
620            return
621           
622        killed_jobs = self.remoteobj.callRemote('get_killed_jobs_list')
623        killed_jobs.addCallback(self._gotKilledJobsList)
624        killed_jobs.addErrback(self._catchFailure)
625   
626    def jobUpdated(self, id):
627        r"""
628        jobUpdated is a callback that gets called when there is new output
629        from checkForJobOutput.
630       
631        """
632       
633        print str(id) + ' was updated!'
634   
635    def sanitizeOutput(self, sage_output):
636        r"""
637        sanitizeOutput attempts to clean up the output string from sage.
638       
639        """
640       
641        # log.msg("Before cleaning output: ", sage_output)
642        begin = sage_output.find(START_MARKER)
643        if begin != -1:
644            begin += len(START_MARKER)
645        else:
646            begin = 0
647        end = sage_output.find(END_MARKER)
648        if end != -1:
649            end -= 1
650        else:
651            end = len(sage_output)
652        output = sage_output[begin:end]
653        output = output.strip()
654        output = output.replace('\r', '')
655       
656        # log.msg("After cleaning output: ", output)
657        return output
658   
659    def _gotHostInfo(self, h):
660       
661        # attach the workers uuid to the dictionary returned by
662        # HostInfo().get_host_info
663        h['uuid'] = self.uuid
664       
665        d = self.remoteobj.callRemote("submit_host_info", h)
666        d.addErrback(self._catchConnectionFailure)
667        log.msg('[Monitor] Submitted host info')
668   
669    def submit_host_info(self):
670        r"""
671        Sends the workers hardware specs to the server.
672       
673        """
674       
675        h = HostInfo().get_host_info()
676        h.addCallback(self._gotHostInfo)
677        h.addErrback(self._catchConnectionFailure)
678   
679    def startLoopingCalls(self):
680        r"""
681        startLoopingCalls prepares and starts our periodic checking methods.
682       
683        """
684   
685        # submits the output to the server
686        self.tsk1 = task.LoopingCall(self.checkForJobOutput)
687        self.tsk1.start(0.1, now=False)
688       
689        # checks for killed jobs
690        self.tsk2 = task.LoopingCall(self.checkForKilledJobs)
691        self.tsk2.start(5.0, now=False)
692   
693    def stopLoopingCalls(self):
694        r"""
695        Stops the looping calls.
696       
697        """
698        self.tsk.stop()
699        self.tsk1.stop()
700        self.tsk2.stop()
701
702def main():
703    r"""
704    argv[1] == hostname
705    argv[2] == port
706   
707    """
708
709    if len(sys.argv) == 3:
710        hostname, port = sys.argv[1:3]       
711        try:
712            port = int(port)
713        except:
714            port = None
715        if hostname == 'None':
716            hostname = None
717        else:
718            try:
719                hostname = str(hostname)
720            except Exception, msg:
721                print msg
722                hostname = None
723    else:
724        hostname = port = None
725       
726    monitor = Monitor(hostname, port)
727
728    monitor.connect()
729    monitor.startLoopingCalls()
730   
731    try:
732        reactor.run()
733    except:
734        sys.exist(-1)
735
736if __name__ == '__main__':
737    main()
Note: See TracBrowser for help on using the repository browser.