source: sage/dsage/scripts/dsage_worker.py @ 3821:f996bb51296c

Revision 3821:f996bb51296c, 24.2 KB checked in by Yi Qiang <yqiang@…>, 6 years ago (diff)

Only instantiate host_info object once.

  • Property exe set to *
Line 
1#!/usr/bin/env python
2############################################################################
3#                                                                     
4#   DSAGE: Distributed SAGE                     
5#                                                                             
6#       Copyright (C) 2006, 2007 Yi Qiang <yqiang@gmail.com>               
7#                                                                           
8#  Distributed under the terms of the GNU General Public License (GPL)       
9#
10#    This code is distributed in the hope that it will be useful,
11#    but WITHOUT ANY WARRANTY; without even the implied warranty of
12#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13#    General Public License for more details.
14#
15#  The full text of the GPL is available at:
16#
17#                  http://www.gnu.org/licenses/
18#
19############################################################################
20
21import sys
22import os
23import ConfigParser
24import uuid
25import cPickle
26import zlib
27
28from twisted.spread import pb
29from twisted.internet import reactor, defer, error, task
30from twisted.python import log
31
32from sage.interfaces.sage0 import Sage
33from sage.misc.preparser import preparse_file
34
35from sage.dsage.database.job import Job, expand_job
36from sage.dsage.misc.hostinfo import HostInfo, ClassicHostInfo
37from sage.dsage.errors.exceptions import NoJobException
38from sage.dsage.twisted.pb import PBClientFactory
39from sage.dsage.misc.constants import delimiter as DELIMITER
40
41pb.setUnjellyableForClass(HostInfo, HostInfo)
42
43DSAGE_DIR = os.path.join(os.getenv('DOT_SAGE'), 'dsage')
44
45# Begin reading configuration
46try:
47    CONF_FILE = os.path.join(DSAGE_DIR, 'worker.conf')
48    CONFIG = ConfigParser.ConfigParser()
49    CONFIG.read(CONF_FILE)
50   
51    LOG_FILE = CONFIG.get('log', 'log_file')
52    LOG_LEVEL = CONFIG.getint('log','log_level')
53    SSL = CONFIG.getint('ssl', 'ssl')
54    WORKERS = CONFIG.getint('general', 'workers')
55    SERVER = CONFIG.get('general', 'server')
56    PORT = CONFIG.getint('general', 'port')
57    DELAY = CONFIG.getint('general', 'delay')
58    NICE_LEVEL = CONFIG.getint('general', 'nice_level')
59    AUTHENTICATE = CONFIG.getboolean('general', 'authenticate')
60except Exception, msg:
61    print msg
62    print "Error reading %s, please fix manually or run dsage.setup()" % \
63    CONF_FILE
64    sys.exit(-1)
65# End reading configuration
66
67# OUTPUT MARKERS shared by Worker and Monitor
68START_MARKER = '___BEGIN___'
69END_MARKER = '___END___'
70
71def unpickle(pickled_job):
72    return cPickle.loads(zlib.decompress(pickled_job))
73   
74class Worker(object):
75    r"""
76    This class represents a worker object that does the actual calculation.
77   
78    Parameters:
79    remoteobj -- reference to the remote PB server
80   
81    """
82   
83    def __init__(self, remoteobj, id):
84        self.remoteobj = remoteobj
85        self.id = id
86        self.free = True
87        self.job = None
88       
89        if LOG_LEVEL > 3:
90            self.sage = Sage(logfile=DSAGE_DIR + '/%s-pexpect.log'\
91                             % self.id)
92        else:
93            self.sage = Sage()
94           
95        # import some basic modules into our Sage() instance
96        self.sage.eval('import time')
97        self.sage.eval('import sys')
98        self.sage.eval('import os')
99       
100        # Initialize getting of jobs
101        self.get_job()
102
103    def get_job(self):
104        try:
105            if LOG_LEVEL > 3:
106                log.msg('Worker %s: Getting job...' % (self.id))
107            d = self.remoteobj.callRemote('get_job')
108        except Exception, msg:
109            log.msg(msg)
110            log.msg('[Worker: %s, get_job] Disconnected from remote server.'\
111                    % self.id)
112            reactor.callLater(DELAY, self.get_job)
113            return
114        d.addCallback(self.gotJob)
115        d.addErrback(self.noJob)
116       
117        return d
118   
119    def gotJob(self, jdict):
120        r"""
121        gotJob is a callback for the remoteobj's get_job method.
122       
123        Parameters:
124        job -- Job object returned by remote's 'get_job' method
125       
126        """
127       
128        if LOG_LEVEL > 3:
129            log.msg('[Worker %s, gotJob] %s' % (self.id, jdict))
130           
131        self.job = expand_job(jdict)
132       
133        if not isinstance(self.job, Job):
134            raise NoJobException
135       
136        log.msg('[Worker: %s] Got job (%s, %s)' % (self.id,
137                                                   self.job.name, 
138                                                   self.job.id))
139        try:
140            self.doJob(self.job)
141        except Exception, e:
142            print e
143            raise
144   
145    def job_done(self, output, result, completed, worker_info):
146        r"""
147        job_done is a callback for doJob.  Called when a job completes.
148       
149        Parameters:
150        output -- the output of the command
151        result -- the result of processing the job, a pickled object
152        completed -- whether or not the job is completely finished (bool)
153        worker_info -- user@host, os.uname() (tuple)
154       
155        """
156       
157        try:
158            d = self.remoteobj.callRemote('job_done',
159                                          self.job.id,
160                                          output,
161                                          result,
162                                          completed,
163                                          worker_info)
164        except Exception, msg:
165            log.msg(msg)
166            log.msg('[Worker: %s, job_done] Disconnected, reconnecting in %s'\
167                    % (self.id, DELAY))
168            reactor.callLater(DELAY, self.job_done, output, 
169                              result, completed, worker_info)
170            d = defer.Deferred()
171            d.errback(error.ConnectionLost())
172            return d
173       
174        if completed:
175            self.restart()
176       
177        return d
178   
179    def noJob(self, failure):
180        # TODO: Probably do not need this errback, look into consolidating
181        # with failedJob
182        r"""
183        noJob is an errback that catches the NoJobException.
184       
185        Parameters:
186        failure -- a twisted.python.failure object (twisted.python.failure)
187       
188        """
189       
190        sleep_time = 5.0
191        if failure.check(NoJobException):
192            if LOG_LEVEL > 3:
193                log.msg('[Worker %s, noJob] Sleeping for %s seconds\
194                ' % (self.id, sleep_time))
195            reactor.callLater(5.0, self.get_job)
196        else:
197            print "Error: ", failure.getErrorMessage()
198            print "Traceback: ", failure.printTraceback()
199   
200    def setupTempDir(self, job):
201        # change to a temporary directory
202        cur_dir = os.getcwd() # keep a reference to the current directory
203        tmp_dir = os.path.join(DSAGE_DIR, 'tmp_worker_files')
204        tmp_job_dir = os.path.join(tmp_dir, job.id)
205        self.tmp_job_dir = tmp_job_dir
206        if not os.path.isdir(tmp_dir):
207            os.mkdir(tmp_dir)
208        os.mkdir(tmp_job_dir)
209        os.chdir(tmp_job_dir)
210        self.sage.eval("os.chdir('%s')" % tmp_job_dir)
211       
212        return tmp_job_dir
213       
214    def extractJobData(self, job):
215        r"""
216        Extracts all the data that is in a job object.
217       
218        """
219       
220        if isinstance(job.data, list):
221            if LOG_LEVEL > 2:
222                log.msg('Extracting job data...')
223            for var, data, kind in job.data:
224                # Uncompress data
225                try:
226                    data = zlib.decompress(data)
227                except Exception, msg:
228                    log.msg(msg)
229                    continue
230                if kind == 'file':
231                    # Write out files to current dir
232                    f = open(var, 'wb')
233                    f.write(data)
234                    if LOG_LEVEL > 2:
235                        log.msg('[Worker: %s] Extracted %s. ' % (self.id, f))
236                if kind == 'object':
237                    # Load object into the SAGE worker
238                    fname = var + '.sobj'
239                    if LOG_LEVEL > 3:
240                        log.msg('Object to be loaded: %s' % fname)
241                    f = open(fname, 'wb')
242                    f.write(data)
243                    f.close()
244                    self.sage.eval("%s = load('%s')" % (var, fname))
245                    if LOG_LEVEL > 2:
246                        log.msg('[Worker: %s] Loaded %s' % (self.id, fname))
247
248    def writeJobFile(self, job):
249        r"""
250        Writes out the job file to be executed to disk.
251       
252        """
253        parsed_file = preparse_file(job.code, magic=False,
254                                    do_time=False, ignore_prompts=False)
255
256        job_filename = str(job.name) + '.py'
257        job_file = open(job_filename, 'w')
258        BEGIN = "print '%s'\n\n" % (START_MARKER)
259        END = "print '%s'\n\n" % (END_MARKER)
260        job_file.write(BEGIN)
261        job_file.write(parsed_file)
262        job_file.write("\n\n")
263        job_file.write(END)
264        job_file.close()
265       
266        if LOG_LEVEL > 2:
267            log.msg('[Worker: %s] Wrote job file. ' % (self.id))
268           
269        return job_filename
270       
271    def doJob(self, job):
272        r"""
273        doJob is the method that drives the execution of a job.
274       
275        Parameters:
276        job -- a Job object (dsage.database.Job)
277       
278        """
279       
280        if LOG_LEVEL > 3:
281            log.msg('[Worker %s, doJob] Beginning job execution...' % (self.id))
282        self.free = False
283        d = defer.Deferred()
284       
285        tmp_job_dir = self.setupTempDir(job)
286        self.extractJobData(job)
287       
288        job_filename = self.writeJobFile(job)
289
290        f = os.path.join(tmp_job_dir, job_filename)
291        self.sage._send("execfile('%s')" % (f))
292        if LOG_LEVEL > 2:
293            log.msg('[Worker: %s] File to execute: %s' % (self.id, f))
294        if LOG_LEVEL > 3:
295            log.msg('[Worker: %s] Called sage._send()' % (self.id))
296       
297        d.callback(True)
298       
299        return d
300   
301    def stop(self):
302        r"""
303        stop() kills the current running job.
304           
305        """
306   
307        self.sage.quit()
308        self.free = True
309        self.job = None
310        self.sage = None
311   
312    def start(self):
313        if LOG_LEVEL > 3:
314            self.sage = Sage(logfile=DSAGE_DIR + '/%s-pexpect.out' % self.id)
315        else:
316            self.sage = Sage()
317        self.get_job()
318   
319    def restart(self):
320        log.msg('[Worker: %s] Restarting...' % (self.id))
321        self.stop()
322        self.start()
323
324class Monitor(object):
325    r"""
326    This class represents a monitor that controls workers.
327   
328    It monitors the workers and checks on their status
329   
330    Parameters:
331    hostname -- the hostname of the server we want to connect to (str)
332    port -- the port of the server we want to connect to (int)
333   
334    """
335   
336    def __init__(self, hostname, port):
337        if hostname is None:
338            hostname = SERVER
339        self.hostname = hostname
340        if port is None:
341            port = PORT
342        self.port = port
343        self.remoteobj = None
344        self.connected = False
345        self.reconnecting = False
346        self.workers = None
347       
348        # Start twisted logging facility
349        self._startLogging(LOG_FILE)
350       
351        if len(CONFIG.get('uuid', 'id')) != 36:
352            CONFIG.set('uuid', 'id', str(uuid.uuid1()))
353            f = open(CONF_FILE, 'w')
354            CONFIG.write(f)
355        self.uuid = CONFIG.get('uuid', 'id')
356       
357        self.host_info = ClassicHostInfo().host_info
358        self.host_info['uuid'] = self.uuid
359        self.host_info['workers'] = WORKERS
360       
361        if AUTHENTICATE:
362            from twisted.cred import credentials
363            from twisted.conch.ssh import keys
364            self._get_auth_info()
365            # public key authentication information
366            self.pubkey_str =keys.getPublicKeyString(filename=self.pubkey_file)
367            # try getting the private key object without a passphrase first
368            try:
369                self.priv_key = keys.getPrivateKeyObject(
370                                    filename=self.privkey_file)
371            except keys.BadKeyError:
372                passphrase = self._getpassphrase()
373                self.priv_key = keys.getPrivateKeyObject(
374                                filename=self.privkey_file,
375                                passphrase=passphrase)
376            self.pub_key = keys.getPublicKeyObject(self.pubkey_str)
377            self.alg_name = 'rsa'
378            self.blob = keys.makePublicKeyBlob(self.pub_key)
379            self.data = self.DATA
380            self.signature = keys.signData(self.priv_key, self.data)
381            self.creds = credentials.SSHPrivateKey(self.username,
382                                                   self.alg_name,
383                                                   self.blob, 
384                                                   self.data,
385                                                   self.signature)
386   
387    def _startLogging(self, log_file):
388        if log_file == 'stdout':
389            log.startLogging(sys.stdout)
390        else:
391            print "Logging to file: ", log_file
392            server_log = open(log_file, 'a')
393            log.startLogging(server_log)
394
395    def _get_auth_info(self):
396        import random
397        self.DATA =  ''.join([chr(i) for i in [random.randint(65, 123) for n in
398                        range(500)]])
399        self.DSAGE_DIR = os.path.join(os.getenv('DOT_SAGE'), 'dsage')
400        # Begin reading configuration
401        try:
402            conf_file = os.path.join(self.DSAGE_DIR, 'client.conf')
403            config = ConfigParser.ConfigParser()
404            config.read(conf_file)
405           
406            self.port = config.getint('general', 'port')
407            self.username = config.get('auth', 'username')
408            self.privkey_file = os.path.expanduser(config.get('auth',
409                                                              'privkey_file'))
410            self.pubkey_file = os.path.expanduser(config.get('auth',
411                                                             'pubkey_file'))
412        except Exception, msg:
413            print msg
414            raise
415   
416    def _getpassphrase(self):
417        import getpass
418        passphrase = getpass.getpass('Passphrase (Hit enter for None): ')
419       
420        return passphrase
421       
422    def _connected(self, remoteobj):
423        self.remoteobj = remoteobj
424        self.remoteobj.notifyOnDisconnect(self._disconnected)
425        self.connected = True
426        self.reconnecting = False
427       
428        if self.workers == None: # Only pool workers the first time
429            self.poolWorkers(self.remoteobj)
430        else:
431            for worker in self.workers:
432                worker.remoteobj = self.remoteobj # Update workers
433        # self.submit_host_info()
434   
435    def _disconnected(self, remoteobj):
436        log.msg('Lost connection to the server.')
437        self.connected = False
438        self._retryConnect()
439   
440    def _gotKilledJobsList(self, killed_jobs):
441        if killed_jobs == None:
442            return
443        # reconstruct the Job objects from the jdicts
444        killed_jobs = [expand_job(jdict) for jdict in killed_jobs]
445        for worker in self.workers:
446            if worker.job is None:
447                continue
448            if worker.free:
449                continue
450            for job in killed_jobs:
451                if job is None or worker.job is None:
452                    continue
453                if worker.job.id == job.id:
454                    log.msg('[Worker: %s] Processing a killed job, \
455                            restarting...' % worker.id)
456                    worker.restart()
457   
458    def _retryConnect(self):
459        log.msg('[Monitor] Disconnected, reconnecting in %s' % DELAY)
460        reactor.callLater(DELAY, self.connect)
461   
462    def _catchConnectionFailure(self, failure):
463        # If we lost the connection to the server somehow
464        # if failure.check(error.ConnectionRefusedError,
465        #                 error.ConnectionLost,
466        #                 pb.DeadReferenceError):
467       
468        self.connected = False
469        self._retryConnect()
470       
471        log.msg("Error: ", failure.getErrorMessage())
472        log.msg("Traceback: ", failure.printTraceback())
473   
474    def _catchFailure(self, failure):
475        log.msg("Error: ", failure.getErrorMessage())
476        log.msg("Traceback: ", failure.printTraceback())
477       
478    def connect(self):
479        r"""
480        This method connects the monitor to a remote PB server.
481       
482        """
483        if self.connected: # Don't connect multiple times
484            return
485   
486        factory = pb.PBClientFactory()
487       
488        log.msg(DELIMITER)
489        log.msg('DSAGE Worker')
490        log.msg('Connecting to %s:%s' % (self.hostname, self.port))
491        log.msg(DELIMITER)
492       
493        self.factory = PBClientFactory()
494        if SSL == 1:
495            from twisted.internet import ssl
496            contextFactory = ssl.ClientContextFactory()
497            reactor.connectSSL(self.hostname, self.port,
498                               self.factory, contextFactory) 
499        else:
500            reactor.connectTCP(self.hostname, self.port, self.factory)
501       
502        if AUTHENTICATE:
503            log.msg('Connecting as authenticated worker...\n')
504            d = self.factory.login(self.creds, (pb.Referenceable(), self.host_info))
505        else:
506            log.msg('Connecting as anonymous worker...\n')
507            d = self.factory.login('Anonymous', (pb.Referenceable(), self.host_info))
508        d.addCallback(self._connected)
509        d.addErrback(self._catchConnectionFailure)
510           
511        return d
512   
513    def poolWorkers(self, remoteobj):
514        r"""
515        poolWorkers creates as many workers as specified in worker.conf.
516       
517        """
518       
519        self.workers = [Worker(remoteobj, x) for x in range(WORKERS)]
520        log.msg('[Monitor] Initialized ' + str(WORKERS) + ' workers.')
521   
522    def checkForJobOutput(self):
523        r"""
524        checkForJobOutput periodically polls workers for new output.
525       
526        This figures out whether or not there is anything new output that we
527        should submit to the server.
528       
529        """
530
531        if self.workers == None:
532            return
533       
534        for worker in self.workers:
535            if worker.job == None:
536                continue
537            if worker.free == True:
538                continue
539           
540            if LOG_LEVEL > 1:
541                log.msg('[Monitor] Checking for job output')
542            try:
543                done, output, new = worker.sage._so_far()
544            except Exception, msg:
545                log.msg(msg)
546                continue
547            if new == '' or new.isspace():
548                continue
549            if done:
550                # Checks to see if the job created a result var
551                sobj = worker.sage.get('DSAGE_RESULT')
552                if sobj == '' or sobj.isspace():
553                    sobj = worker.sage.get('DSAGE_RESULT')
554                    if sobj == '' or sobj.isspace():
555                        sobj = worker.sage.get('DSAGE_RESULT')
556                    else:
557                        if LOG_LEVEL > 1:
558                            log.msg('Got DSAGE_RESULT second time')
559               
560                # DSAGE_RESULT does not exist
561                if 'Error: name \'DSAGE_RESULT\' is not defined' in sobj:
562                    if LOG_LEVEL > 1:
563                        log.msg('DSAGE_RESULT does not exist')
564                    result = cPickle.dumps('No result saved.', 2)
565                else:
566                    os.chdir(worker.tmp_job_dir)
567                    try:
568                        result = open(sobj, 'rb').read()
569                    except Exception, msg:
570                        if LOG_LEVEL > 1:
571                            log.msg(msg)
572                        result = cPickle.dumps('Error in reading result.', 2)
573                worker.free = True
574                log.msg("Job '%s' finished" % worker.job.name)
575            else:
576                result = cPickle.dumps('Job not done yet.', 2)
577           
578            worker_info = (os.getenv('USER') + '@' + os.uname()[1],
579                           ' '.join(os.uname()[2:]))
580                           
581            sanitized_output = self.sanitizeOutput(new)
582           
583            if self.checkOutputForFailure(sanitized_output):
584                s = ['[Monitor] Error in result for ',
585                     'job %s %s done by ' % (worker.job.name, worker.job.id),
586                     'Worker %s' % worker.id
587                     ]
588                log.msg(''.join(s))
589                log.msg('[Monitor] Traceback: \n%s' % sanitized_output)
590                d = self.remoteobj.callRemote('job_failed', worker.job.id)
591               
592            d = worker.job_done(sanitized_output, result, done, worker_info)
593            d.addErrback(self._catchConnectionFailure)
594   
595    def checkOutputForFailure(self, sage_output):
596        if sage_output == None:
597            return False
598        else:
599            sage_output = ''.join(sage_output)
600       
601        if 'Traceback' in sage_output:
602            return True
603        elif 'Error' in sage_output:
604            return True
605        else:
606            return False
607   
608    def checkForKilledJobs(self):
609        r"""
610        checkForKilledJobs retrieves a list of killed job ids.
611       
612        """
613       
614        if not self.connected:
615            return
616           
617        killed_jobs = self.remoteobj.callRemote('get_killed_jobs_list')
618        killed_jobs.addCallback(self._gotKilledJobsList)
619        killed_jobs.addErrback(self._catchFailure)
620   
621    def jobUpdated(self, id):
622        r"""
623        jobUpdated is a callback that gets called when there is new output
624        from checkForJobOutput.
625       
626        """
627       
628        print str(id) + ' was updated!'
629   
630    def sanitizeOutput(self, sage_output):
631        r"""
632        sanitizeOutput attempts to clean up the output string from sage.
633       
634        """
635       
636        # log.msg("Before cleaning output: ", sage_output)
637        begin = sage_output.find(START_MARKER)
638        if begin != -1:
639            begin += len(START_MARKER)
640        else:
641            begin = 0
642        end = sage_output.find(END_MARKER)
643        if end != -1:
644            end -= 1
645        else:
646            end = len(sage_output)
647        output = sage_output[begin:end]
648        output = output.strip()
649        output = output.replace('\r', '')
650       
651        # log.msg("After cleaning output: ", output)
652        return output
653   
654    def _gotHostInfo(self, h):
655       
656        # attach the workers uuid to the dictionary returned by
657        # HostInfo().get_host_info
658        h['uuid'] = self.uuid
659       
660        d = self.remoteobj.callRemote("submit_host_info", h)
661        d.addErrback(self._catchConnectionFailure)
662        log.msg('[Monitor] Submitted host info')
663   
664    def submit_host_info(self):
665        r"""
666        Sends the workers hardware specs to the server.
667       
668        """
669       
670        h = HostInfo().get_host_info()
671        h.addCallback(self._gotHostInfo)
672        h.addErrback(self._catchConnectionFailure)
673   
674    def startLoopingCalls(self):
675        r"""
676        startLoopingCalls prepares and starts our periodic checking methods.
677       
678        """
679   
680        # submits the output to the server
681        self.tsk1 = task.LoopingCall(self.checkForJobOutput)
682        self.tsk1.start(0.1, now=False)
683       
684        # checks for killed jobs
685        self.tsk2 = task.LoopingCall(self.checkForKilledJobs)
686        self.tsk2.start(5.0, now=False)
687   
688    def stopLoopingCalls(self):
689        r"""
690        Stops the looping calls.
691       
692        """
693        self.tsk.stop()
694        self.tsk1.stop()
695        self.tsk2.stop()
696
697def main():
698    r"""
699    argv[1] == hostname
700    argv[2] == port
701   
702    """
703
704    if len(sys.argv) == 3:
705        hostname, port = sys.argv[1:3]       
706        try:
707            port = int(port)
708        except:
709            port = None
710        if hostname == 'None':
711            hostname = None
712        else:
713            try:
714                hostname = str(hostname)
715            except Exception, msg:
716                print msg
717                hostname = None
718    else:
719        hostname = port = None
720       
721    monitor = Monitor(hostname, port)
722
723    monitor.connect()
724    monitor.startLoopingCalls()
725   
726    try:
727        reactor.run()
728    except:
729        sys.exist(-1)
730
731if __name__ == '__main__':
732    main()
Note: See TracBrowser for help on using the repository browser.