Ticket #3311: trac_3311_sage.patch

File trac_3311_sage.patch, 44.5 KB (added by mabshoff, 9 years ago)
  • new file sage/dsage/scripts/dsage_setup.py

    # HG changeset patch
    # User mabshoff@sage.math.washington.edu
    # Date 1211858170 25200
    # Node ID b0fe5a4b514ac5da2a903f7cb5004309295b551b
    # Parent  5ce556fc4ec1124d0ddb02a2513291edf4a5457f
    Revert #3097 by adding dsage_setup.py and dsage_worker.py back into the repo
    
    diff -r 5ce556fc4ec1 -r b0fe5a4b514a sage/dsage/scripts/dsage_setup.py
    - +  
     1############################################################################
     2#                                                                     
     3#   DSAGE: Distributed SAGE                     
     4#                                                                             
     5#       Copyright (C) 2006, 2007 Yi Qiang <yqiang@gmail.com>               
     6#                                                                           
     7#  Distributed under the terms of the GNU General Public License (GPL)       
     8#
     9#    This code is distributed in the hope that it will be useful,
     10#    but WITHOUT ANY WARRANTY; without even the implied warranty of
     11#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
     12#    General Public License for more details.
     13#
     14#  The full text of the GPL is available at:
     15#
     16#                  http://www.gnu.org/licenses/
     17############################################################################
     18
     19import os
     20import random
     21import socket
     22import ConfigParser
     23import subprocess
     24import sys
     25import sqlite3
     26
     27from sage.dsage.database.clientdb import ClientDatabaseSA as ClientDatabase
     28from sage.dsage.database.db_config import create_schema
     29from sage.dsage.misc.constants import (DELIMITER, DSAGE_DIR, DSAGE_DB_DIR,
     30                                       DSAGE_DB)
     31from sage.dsage.misc.config import check_dsage_dir
     32from sage.dsage.__version__ import version
     33
     34from sage.misc.viewer import cmd_exists
     35
     36DB_DIR = os.path.join(DSAGE_DIR, 'db/')
     37SAGE_ROOT = os.getenv('SAGE_ROOT')
     38DSAGE_VERSION = version
     39
     40def get_config(type):
     41    config = ConfigParser.ConfigParser()
     42    config.add_section('general')
     43    config.set('general', 'version', DSAGE_VERSION)
     44    config.add_section('ssl')
     45    if type == 'client':
     46        config.add_section('auth')
     47        config.add_section('log')
     48    elif type == 'worker':
     49        config.add_section('uuid')
     50        config.add_section('log')
     51    elif type == 'server':
     52        config.add_section('auth')
     53        config.add_section('server')
     54        config.add_section('server_log')
     55        config.add_section('db')
     56        config.add_section('db_log')
     57    return config
     58
     59def add_default_client(Session):
     60    """
     61    Adds the default client.
     62   
     63    """
     64   
     65    from twisted.conch.ssh import keys
     66    from getpass import getuser
     67   
     68    clientdb = ClientDatabase(Session)
     69   
     70    username = getuser()
     71    pubkey_file = os.path.join(DSAGE_DIR, 'dsage_key.pub')
     72    pubkey = keys.Key.fromFile(pubkey_file)
     73    if clientdb.get_client(username) is None:
     74        clientdb.add_client(username, pubkey.toString(type='openssh'))
     75        print 'Added user %s.\n' % (username)
     76    else:
     77        client = clientdb.get_client(username)
     78        if client.public_key != pubkey:
     79            clientdb.del_client(username)
     80            clientdb.add_client(username, pubkey)
     81            print "User %s's pubkey changed, setting to new one." % (username)
     82        else:
     83            print 'User %s already exists.' % (username)
     84
     85def setup_client(testing=False):
     86    check_dsage_dir()
     87    key_file = os.path.join(DSAGE_DIR, 'dsage_key')
     88    if testing:
     89        cmd = ["ssh-keygen", "-q", "-trsa", "-P ''", "-f%s" % key_file]
     90        return
     91   
     92    if not cmd_exists('ssh-keygen'):
     93        print DELIMITER
     94        print "Could NOT find ssh-keygen."
     95        print "Aborting."
     96        return
     97       
     98    print DELIMITER
     99    print "Generating public/private key pair for authentication..."
     100    print "Your key will be stored in %s/dsage_key" % DSAGE_DIR
     101    print "Just hit enter when prompted for a passphrase"
     102    print DELIMITER
     103   
     104    cmd = ["ssh-keygen", "-q", "-trsa", "-f%s" % key_file]   
     105    ld = os.environ['LD_LIBRARY_PATH']
     106    try:
     107        del os.environ['LD_LIBRARY_PATH']
     108        p = subprocess.call(cmd)
     109    finally:
     110        os.environ['LD_LIBRARY_PATH'] = ld
     111       
     112    print "\n"
     113    print "Client configuration finished.\n"
     114
     115def setup_worker():
     116    check_dsage_dir()
     117    print "Worker configuration finished.\n"
     118
     119def setup_server(template=None):
     120    check_dsage_dir()
     121    print "Choose a domain name for your SAGE notebook server,"
     122    print "for example, localhost (personal use) or %s (to allow outside connections)." % socket.getfqdn()
     123    dn = raw_input("Domain name [localhost]: ").strip()
     124    if dn == '':
     125        print "Using default localhost"
     126        dn = 'localhost'
     127   
     128    template_dict = {'organization': 'SAGE (at %s)' % (dn),
     129                'unit': '389',
     130                'locality': None,
     131                'state': 'Washington',
     132                'country': 'US',
     133                'cn': dn,
     134                'uid': 'sage_user',
     135                'dn_oid': None,
     136                'serial': str(random.randint(1,2**31)),
     137                'dns_name': None,
     138                'crl_dist_points': None,
     139                'ip_address': None,
     140                'expiration_days': 10000,
     141                'email': 'sage@sagemath.org',
     142                'ca': None,
     143                'tls_www_client': None,
     144                'tls_www_server': True,
     145                'signing_key': True,
     146                'encryption_key': True,
     147                }
     148               
     149    if isinstance(template, dict):
     150        template_dict.update(template)
     151   
     152    s = ""
     153    for key, val in template_dict.iteritems():
     154        if val is None:
     155            continue
     156        if val == True:
     157            w = ''
     158        elif isinstance(val, list):
     159            w = ' '.join(['"%s"' % x for x in val])
     160        else:
     161            w = '"%s"' % val
     162        s += '%s = %s \n' % (key, w)
     163   
     164    template_file = os.path.join(DSAGE_DIR, 'cert.cfg')
     165    f = open(template_file, 'w')
     166    f.write(s)
     167    f.close()
     168   
     169    # Disable certificate generation -- not used right now anyways
     170    privkey_file = os.path.join(DSAGE_DIR, 'cacert.pem')
     171    pubkey_file = os.path.join(DSAGE_DIR, 'pubcert.pem')
     172   
     173    print DELIMITER
     174    print "Generating SSL certificate for server..."
     175   
     176    if False and os.uname()[0] != 'Darwin' and cmd_exists('openssl'):
     177        # We use openssl by default if it exists, since it is *vastly*
     178        # faster on Linux.
     179        cmd = ['openssl genrsa > %s' % privkey_file]
     180        print "Using openssl to generate key"
     181        print cmd[0]
     182        subprocess.call(cmd, shell=True)
     183    else:
     184        cmd = ['certtool --generate-privkey --outfile %s' % privkey_file]
     185        print "Using certtool to generate key"
     186        print cmd[0]
     187        # cmd = ['openssl genrsa > %s' % privkey_file]
     188        subprocess.call(cmd, shell=True)
     189       
     190    cmd = ['certtool --generate-self-signed --template %s --load-privkey %s \
     191           --outfile %s' % (template_file, privkey_file, pubkey_file)]
     192    subprocess.call(cmd, shell=True)
     193    print DELIMITER
     194   
     195    # Set read only permissions on cert
     196    os.chmod(os.path.join(DSAGE_DIR, 'cacert.pem'), 0600)
     197   
     198    # create database schemas
     199    from sage.dsage.database.db_config import init_db_sa as init_db
     200    Session = init_db(DSAGE_DB)
     201   
     202    # add default user
     203    add_default_client(Session)
     204           
     205    print "Server configuration finished.\n\n"
     206   
     207def setup(template=None):
     208    setup_client()
     209    setup_worker()
     210    setup_server(template=template)
     211    print "Configuration finished.."
     212
     213if __name__ == '__main__':
     214    if len(sys.argv) == 1:
     215        setup()
     216    if len(sys.argv) == 2:
     217        if sys.argv[1] == 'server':
     218            setup_server()
     219        elif sys.argv[1] == 'worker':
     220            setup_worker()
     221        elif sys.argv[1] == 'client':
     222            setup_client()
     223
  • new file sage/dsage/scripts/dsage_worker.py

    diff -r 5ce556fc4ec1 -r b0fe5a4b514a sage/dsage/scripts/dsage_worker.py
    - +  
     1#!/usr/bin/env python
     2############################################################################
     3#                                                                     
     4#   DSAGE: Distributed SAGE                     
     5#                                                                             
     6#       Copyright (C) 2006, 2007 Yi Qiang <yqiang@gmail.com>               
     7#                                                                           
     8#  Distributed under the terms of the GNU General Public License (GPL)       
     9#
     10#    This code is distributed in the hope that it will be useful,
     11#    but WITHOUT ANY WARRANTY; without even the implied warranty of
     12#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
     13#    General Public License for more details.
     14#
     15#  The full text of the GPL is available at:
     16#
     17#                  http://www.gnu.org/licenses/
     18#
     19############################################################################
     20__docformat__ = "restructuredtext en"
     21
     22import sys
     23import os
     24import cPickle
     25import zlib
     26import pexpect
     27import datetime
     28from math import ceil
     29from getpass import getuser
     30
     31from twisted.spread import pb
     32from twisted.internet import reactor, defer, error, task
     33from twisted.python import log
     34from twisted.spread import banana
     35banana.SIZE_LIMIT = 100*1024*1024 # 100 MegaBytes
     36
     37from gnutls.constants import *
     38from gnutls.crypto import *
     39from gnutls.errors import *
     40from gnutls.interfaces.twisted import X509Credentials
     41
     42from sage.interfaces.sage0 import Sage
     43from sage.misc.preparser import preparse_file
     44
     45from sage.dsage.database.job import Job, expand_job
     46from sage.dsage.misc.hostinfo import HostInfo
     47from sage.dsage.errors.exceptions import NoJobException
     48from sage.dsage.twisted.pb import ClientFactory
     49from sage.dsage.misc.constants import DELIMITER
     50from sage.dsage.misc.constants import DSAGE_DIR
     51from sage.dsage.misc.constants import TMP_WORKER_FILES
     52from sage.dsage.misc.misc import random_str, get_uuid
     53
     54START_MARKER = '\x01r\x01e'
     55END_MARKER = '\x01r\x01b'
     56LOG_PREFIX = "[Worker %s] "
     57
     58class Worker(object):
     59    """
     60    Workers perform the computation of dsage jobs.
     61   
     62    """
     63   
     64    def __init__(self, remoteobj, id, log_level=0, poll=1.0):
     65        """
     66        :type remoteobj: remoteobj
     67        :param remoteobj: Reference to the remote dsage server
     68       
     69        :type id: integer
     70        :param id: numerical identifier of worker
     71       
     72        :type log_level: integer
     73        :param log_level: log level, higher means more verbose
     74       
     75        :type poll: integer
     76        :param poll: rate (in seconds) a worker talks to the server
     77       
     78        """
     79       
     80        self.remoteobj = remoteobj
     81        self.id = id
     82        self.free = True
     83        self.job = None
     84        self.log_level = log_level
     85        self.poll_rate = poll
     86        self.checker_task = task.LoopingCall(self.check_work)
     87        self.checker_timeout = 0.5
     88        self.got_output = False
     89        self.job_start_time = None
     90        self.orig_poll = poll
     91        self.start()
     92       
     93    def _catch_failure(self, failure):
     94        log.msg("Error: ", failure.getErrorMessage())
     95        log.msg("Traceback: ", failure.printTraceback())
     96   
     97    def _increase_poll_rate(self):
     98        if self.poll_rate >= 15: # Cap the polling interval to 15 seconds
     99            self.poll_rate = 15
     100            if self.log_level > 3:
     101                log.msg('[Worker %s] Capping poll rate to %s'
     102                         % (self.id, self.poll_rate))
     103        else:
     104            self.poll_rate = ceil(self.poll_rate * 1.5)
     105            if self.log_level > 3:
     106                log.msg('[Worker %s] Increased polling rate to %s'
     107                        % (self.id, self.poll_rate))
     108   
     109    def get_job(self):
     110        try:
     111            if self.log_level > 3:
     112                log.msg(LOG_PREFIX % self.id +  'Getting job...')
     113            d = self.remoteobj.callRemote('get_job')
     114        except Exception, msg:
     115            log.msg(msg)
     116            log.msg(LOG_PREFIX % self.id +  'Disconnected...')
     117            self._increase_poll_rate()
     118            reactor.callLater(self.poll_rate, self.get_job)
     119            return
     120        d.addCallback(self.gotJob)
     121        d.addErrback(self.noJob)
     122       
     123        return d
     124   
     125    def gotJob(self, jdict):
     126        """
     127        callback for the remoteobj's get_job method.
     128       
     129        :type jdict: dict
     130        :param jdict: job dictionary
     131
     132        """
     133       
     134        if self.log_level > 1:
     135            if jdict is None:
     136                log.msg(LOG_PREFIX % self.id + 'No new job.')
     137        if self.log_level > 3:
     138            if jdict is not None:
     139                log.msg(LOG_PREFIX % self.id + 'Got Job: %s' % jdict)
     140        self.job = expand_job(jdict)
     141        if not isinstance(self.job, Job):
     142            raise NoJobException
     143        try:
     144            self.poll_rate = self.orig_poll
     145            self.doJob(self.job)
     146        except Exception, msg:
     147            log.msg(msg)
     148            self.report_failure(msg)
     149            self.restart()
     150   
     151    def job_done(self, output, result, completed, cpu_time):
     152        """
     153        Reports to the server that a job has finished. It also reports partial
     154        completeness by presenting the server with new output.
     155       
     156        Parameters:
     157        :type output: string
     158        :param output: output of command (to sys.stdout)
     159       
     160        :type result: python pickle
     161        :param result: result of the job
     162       
     163        :type completed: bool
     164        :param completed: whether or not the job is finished
     165       
     166        :type cpu_time: string
     167        :param cpu_time: how long the job took
     168       
     169        """
     170       
     171        job_id = self.job.job_id
     172        wait = 5.0
     173        try:
     174            d = self.remoteobj.callRemote('job_done', job_id, output, result,
     175                                          completed, cpu_time)
     176        except Exception, msg:
     177            log.msg('Error trying to submit job status...')
     178            log.msg('Retrying to submit again in %s seconds...' % wait)
     179            log.err(msg)
     180            reactor.callLater(wait, self.job_done, output, result,
     181                              completed, cpu_time)
     182            d = defer.Deferred()
     183            d.errback(error.ConnectionLost())     
     184            return d
     185       
     186        if completed:
     187            log.msg('[Worker %s] Finished job %s' % (self.id, job_id))
     188            self.restart()
     189   
     190        return d
     191       
     192       
     193    def noJob(self, failure):
     194        """
     195        Errback that catches the NoJobException.
     196       
     197        :type failure: twisted.python.failure
     198        :param failure: a twisted failure object
     199       
     200        """
     201       
     202        if failure.check(NoJobException):
     203            if self.log_level > 1:
     204                msg = 'Sleeping for %s seconds' % self.poll_rate
     205                log.msg(LOG_PREFIX % self.id + msg)
     206            self._increase_poll_rate()
     207            reactor.callLater(self.poll_rate, self.get_job)
     208        else:
     209            log.msg("Error: ", failure.getErrorMessage())
     210            log.msg("Traceback: ", failure.printTraceback())
     211   
     212    def setup_tmp_dir(self, job):
     213        """
     214        Creates the temporary directory for the worker.
     215       
     216        :type job: sage.dsage.database.job.Job
     217        :param job: a Job object
     218       
     219        """
     220       
     221        cur_dir = os.getcwd() # keep a reference to the current directory
     222        tmp_job_dir = os.path.join(TMP_WORKER_FILES, job.job_id)
     223        if not os.path.isdir(TMP_WORKER_FILES):
     224            os.mkdir(TMP_WORKER_FILES)
     225        if not os.path.isdir(tmp_job_dir):
     226            os.mkdir(tmp_job_dir)
     227        os.chdir(tmp_job_dir)
     228        self.sage.eval("os.chdir('%s')" % tmp_job_dir)
     229       
     230        return tmp_job_dir
     231
     232    def extract_and_load_job_data(self, job):
     233        """
     234        Extracts all the data that is in a job object.
     235       
     236        :type job: sage.dsage.database.job.Job
     237        :param job: a Job object
     238       
     239        """
     240       
     241        if isinstance(job.data, list):
     242            if self.log_level > 2:
     243                msg = 'Extracting job data...'
     244                log.msg(LOG_PREFIX % self.id + msg)
     245            try:
     246                for var, data, kind in job.data:
     247                    try:
     248                        data = zlib.decompress(data)
     249                    except Exception, msg:
     250                        log.msg(msg)
     251                        continue
     252                    if kind == 'file':
     253                        data = preparse_file(data, magic=True, do_time=False,
     254                                             ignore_prompts=False)
     255                        f = open(var, 'wb')
     256                        f.write(data)
     257                        f.close()
     258                        if self.log_level > 2:
     259                            msg = 'Extracted %s' % f
     260                            log.msg(LOG_PREFIX % self.id + msg)
     261                        self.sage.eval("execfile('%s')" % var)
     262                    if kind == 'object':
     263                        fname = var + '.sobj'
     264                        if self.log_level > 2:
     265                            log.msg('Object to be loaded: %s' % fname)
     266                        f = open(fname, 'wb')
     267                        f.write(data)
     268                        f.close()
     269                        self.sage.eval("%s = load('%s')" % (var, fname))
     270                        if self.log_level > 2:
     271                            msg = 'Loaded %s' % fname
     272                            log.msg(LOG_PREFIX % self.id + msg)
     273            except Exception, msg:
     274                log.msg(LOG_PREFIX % self.id + msg)
     275
     276    def write_job_file(self, job):
     277        """
     278        Writes out the job file to be executed to disk.
     279       
     280        :type job: sage.dsage.database.job.Job
     281        :param job: A Job object
     282       
     283        """
     284       
     285        parsed_file = preparse_file(job.code, magic=True,
     286                                    do_time=False, ignore_prompts=False)
     287
     288        job_filename = str(job.name) + '.py'
     289        job_file = open(job_filename, 'w')
     290        BEGIN = "print '%s'\n\n" % (START_MARKER)
     291        END = "print '%s'\n\n" % (END_MARKER)
     292        GO_TO_TMP_DIR = """os.chdir('%s')\n""" % self.tmp_job_dir
     293        SAVE_TIME = """save((time.time()-dsage_start_time), 'cpu_time.sobj', compress=False)\n"""
     294        SAVE_RESULT = """try:
     295    save(DSAGE_RESULT, 'result.sobj', compress=True)
     296except:
     297    save('No DSAGE_RESULT', 'result.sobj', compress=True)
     298"""
     299        job_file.write("alarm(%s)\n\n" % (job.timeout))
     300        job_file.write("import time\n\n")
     301        job_file.write(BEGIN)
     302        job_file.write('dsage_start_time = time.time()\n')
     303        job_file.write(parsed_file)
     304        job_file.write("\n\n")
     305        job_file.write(END)
     306        job_file.write("\n")
     307        job_file.write(GO_TO_TMP_DIR)
     308        job_file.write(SAVE_RESULT)
     309        job_file.write(SAVE_TIME)
     310        job_file.close()
     311        if self.log_level > 2:
     312            log.msg('[Worker: %s] Wrote job file. ' % (self.id))
     313           
     314        return job_filename
     315       
     316    def doJob(self, job):
     317        """
     318        Executes a job
     319       
     320        :type job: sage.dsage.database.job.Job
     321        :param job: A Job object
     322
     323        """
     324       
     325        log.msg(LOG_PREFIX % self.id + 'Starting job %s ' % job.job_id)
     326           
     327        self.free = False
     328        self.got_output = False
     329        d = defer.Deferred()
     330       
     331        try:
     332            self.checker_task.start(self.checker_timeout, now=False)
     333        except AssertionError:
     334            self.checker_task.stop()
     335            self.checker_task.start(self.checker_timeout, now=False)
     336        if self.log_level > 2:
     337            log.msg(LOG_PREFIX % self.id + 'Starting checker task...')
     338       
     339        self.tmp_job_dir = self.setup_tmp_dir(job)
     340        self.extract_and_load_job_data(job)
     341       
     342        job_filename = self.write_job_file(job)
     343
     344        f = os.path.join(self.tmp_job_dir, job_filename)
     345        self.sage._send("execfile('%s')" % (f))
     346        self.job_start_time = datetime.datetime.now()
     347        if self.log_level > 2:
     348            msg = 'File to execute: %s' % f
     349            log.msg(LOG_PREFIX % self.id + msg)
     350       
     351        d.callback(True)
     352
     353    def reset_checker(self):
     354        """
     355        Resets the output/result checker for the worker.
     356       
     357        """
     358       
     359        if self.checker_task.running:
     360            self.checker_task.stop()
     361        self.checker_timeout = 1.0
     362        self.checker_task = task.LoopingCall(self.check_work)
     363
     364    def check_work(self):
     365        """
     366        check_work periodically polls workers for new output. The period is
     367        determined by an exponential back off algorithm.
     368       
     369        This figures out whether or not there is anything new output that we
     370        should submit to the server.
     371       
     372        """
     373       
     374        if self.sage == None:
     375            return
     376        if self.job == None or self.free == True:
     377            if self.checker_task.running:
     378                self.checker_task.stop()
     379            return
     380        if self.log_level > 1:
     381            msg = 'Checking job %s' % self.job.job_id
     382            log.msg(LOG_PREFIX % self.id + msg)
     383        os.chdir(self.tmp_job_dir)
     384        try:
     385            # foo, output, new = self.sage._so_far()
     386            # This sucks and is a very bad way to tell when a calculation is
     387            # finished           
     388            done, new = self.sage._get()
     389            # If result.sobj exists, our calculation is done
     390            result = open('result.sobj', 'rb').read()
     391            done = True
     392        except RuntimeError, msg: # Error in calling worker.sage._so_far()
     393            done = False
     394            if self.log_level > 1:
     395                log.msg(LOG_PREFIX % self.id + 'RuntimeError: %s' % msg)
     396                log.msg("Don't worry, the RuntimeError above " +
     397                        "is a non-fatal SAGE failure")
     398            self.increase_checker_task_timeout()
     399            return
     400        except IOError, msg: # File does not exist yet
     401            done = False
     402           
     403        if done:
     404            try:
     405                cpu_time = cPickle.loads(open('cpu_time.sobj', 'rb').read())
     406            except IOError:
     407                cpu_time = -1
     408            self.free = True
     409            self.reset_checker()
     410        else:
     411            result = cPickle.dumps('Job not done yet.', 2)
     412            cpu_time = None
     413           
     414        if self.check_failure(new):
     415            self.report_failure(new)
     416            self.restart()
     417            return
     418       
     419        sanitized_output = self.clean_output(new)   
     420        if self.log_level > 3:
     421            print 'Output before sanitizing: \n' , sanitized_output
     422        if self.log_level > 3:
     423            print 'Output after sanitizing: \n', sanitized_output
     424        if sanitized_output == '' and not done:
     425            self.increase_checker_task_timeout()
     426        else:
     427            d = self.job_done(sanitized_output, result, done, cpu_time)
     428            d.addErrback(self._catch_failure)
     429
     430    def report_failure(self, failure):
     431        """
     432        Reports failure of a job.
     433       
     434        :type failure: twisted.python.failure
     435        :param failure: A twisted failure object
     436       
     437        """
     438       
     439        msg = 'Job %s failed!' % (self.job.job_id)
     440        import shutil
     441        failed_dir = self.tmp_job_dir + '_failed'
     442        if os.path.exists(failed_dir):
     443            shutil.rmtree(failed_dir)
     444        shutil.move(self.tmp_job_dir, failed_dir)
     445        log.msg(LOG_PREFIX % self.id + msg)
     446        log.msg('Traceback: \n%s' % failure)
     447        d = self.remoteobj.callRemote('job_failed', self.job.job_id, failure)
     448        d.addErrback(self._catch_failure)
     449       
     450        return d
     451       
     452    def increase_checker_task_timeout(self):
     453        """
     454        Quickly decreases the number of times a worker checks for output
     455       
     456        """
     457       
     458        if self.checker_task.running:
     459            self.checker_task.stop()
     460       
     461        self.checker_timeout = self.checker_timeout * 1.5
     462        if self.checker_timeout > 300.0:
     463            self.checker_timeout = 300.0
     464        self.checker_task = task.LoopingCall(self.check_work)
     465        self.checker_task.start(self.checker_timeout, now=False)
     466        if self.log_level > 0:
     467            msg = 'Checking output again in %s' % self.checker_timeout
     468            log.msg(LOG_PREFIX % self.id + msg)
     469       
     470    def clean_output(self, sage_output):
     471        """
     472        clean_output attempts to clean up the output string from sage.
     473
     474        :type sage_output: string
     475        :param sage_output: sys.stdout output from the child sage instance
     476       
     477        """
     478       
     479        begin = sage_output.find(START_MARKER)
     480        if begin != -1:
     481            self.got_output = True
     482            begin += len(START_MARKER)
     483        else:
     484            begin = 0
     485        end = sage_output.find(END_MARKER)
     486        if end != -1:
     487            end -= 1
     488        else:
     489            if not self.got_output:
     490                end = 0
     491            else:
     492                end = len(sage_output)
     493        output = sage_output[begin:end]
     494        output = output.strip()
     495        output = output.replace('\r', '')
     496       
     497        if ('execfile' in output or 'load' in output) and self.got_output:
     498            output = ''           
     499           
     500        return output
     501 
     502    def check_failure(self, sage_output):
     503        """
     504        Checks for signs of exceptions or errors in the output.
     505       
     506        :type sage_output: string
     507        :param sage_output: output from the sage instance
     508       
     509        """
     510
     511        if sage_output == None:
     512            return False
     513        else:
     514            sage_output = ''.join(sage_output)
     515
     516        if 'Traceback' in sage_output:
     517            return True
     518        elif 'Error' in sage_output:
     519            return True
     520        else:
     521            return False
     522
     523    def kill_sage(self):
     524        """
     525        Try to hard kill the SAGE instance.
     526       
     527        """
     528       
     529        try:
     530            self.sage.quit()
     531            del self.sage
     532        except Exception, msg:
     533            pid = self.sage.pid()
     534            cmd = 'kill -9 %s' % pid
     535            os.system(cmd)
     536            log.msg(msg)
     537           
     538    def stop(self, hard_reset=False):
     539        """
     540        Stops the current worker and resets it's internal state.
     541       
     542        :type hard_reset: boolean
     543        :param hard_reset: Specifies whether to kill -9 the sage instances
     544           
     545        """
     546       
     547        # Set status to free and delete any current jobs we have
     548        self.free = True
     549        self.job = None
     550       
     551        if hard_reset:
     552            log.msg(LOG_PREFIX % self.id + 'Performing hard reset.')
     553            self.kill_sage()
     554        else: # try for a soft reset
     555            INTERRUPT_TRIES = 20
     556            timeout = 0.3
     557            e = self.sage._expect
     558            try:
     559                for i in range(INTERRUPT_TRIES):   
     560                    self.sage._expect.sendline('q')
     561                    self.sage._expect.sendline(chr(3))  # send ctrl-c
     562                    try:
     563                        e.expect(self.sage._prompt, timeout=timeout)           
     564                        success = True
     565                        break
     566                    except (pexpect.TIMEOUT, pexpect.EOF), msg:
     567                        success = False
     568                        if self.log_level > 3:
     569                            msg = 'Interrupting SAGE (try %s)' % i
     570                            log.msg(LOG_PREFIX % self.id + msg)
     571            except Exception, msg:
     572                success = False
     573                log.msg(msg)
     574                log.msg(LOG_PREFIX % self.id + "Performing hard reset.")
     575       
     576            if not success:
     577                self.kill_sage()
     578            else:
     579                self.sage.reset()
     580   
     581    def start(self):
     582        """
     583        Starts a new worker if it does not exist already.
     584       
     585        """
     586       
     587        log.msg('[Worker %s] Started...' % (self.id))
     588        if not hasattr(self, 'sage'):
     589            if self.log_level > 3:
     590                logfile = DSAGE_DIR + '/%s-pexpect.log' % self.id
     591                self.sage = Sage(maxread=1, logfile=logfile, python=True)
     592            else:
     593                self.sage = Sage(maxread=1, python=True)
     594            try:
     595                self.sage._start(block_during_init=True)
     596            except RuntimeError, msg: # Could not start SAGE
     597                print msg
     598                print 'Failed to start a worker, probably Expect issues.'
     599                reactor.stop()
     600                sys.exit(-1)
     601        E = self.sage.expect()
     602        E.sendline('\n')
     603        E.expect('>>>')
     604        cmd = 'from sage.all import *;'
     605        cmd += 'from sage.all_notebook import *;'
     606        cmd += 'import sage.server.support as _support_; '
     607        cmd += 'import time;'
     608        cmd += 'import os;'
     609        E.sendline(cmd)
     610       
     611        if os.uname()[0].lower() == 'linux':
     612            try:
     613                self.base_mem = int(self.sage.get_memory_usage())
     614            except:
     615                pass
     616   
     617        self.get_job()
     618   
     619    def restart(self):
     620        """
     621        Restarts the current worker.
     622       
     623        """
     624       
     625        log.msg('[Worker: %s] Restarting...' % (self.id))
     626       
     627        if hasattr(self, 'base_mem'):
     628            try:
     629                cur_mem = int(self.sage.get_memory_usage())
     630            except:
     631                cur_mem = 0
     632        try:
     633            if hasattr(self, 'base_mem'):
     634                if cur_mem >= (2 * self.base_mem):
     635                    self.stop(hard_reset=True)
     636            else:
     637                from sage.dsage.misc.misc import timedelta_to_seconds
     638                delta = datetime.datetime.now() - self.job_start_time
     639                secs = timedelta_to_seconds(delta)
     640                if secs >= (3*60): # more than 3 minutes, do a hard reset
     641                    self.stop(hard_reset=True)
     642                else:
     643                    self.stop(hard_reset=False)
     644        except TypeError:
     645            self.stop(hard_reset=True)
     646        self.job_start_time = None
     647        self.start()
     648        self.reset_checker()
     649   
     650   
     651class Monitor(pb.Referenceable):
     652    """
     653    Monitors control workers.
     654    They are able to shutdown workers and spawn them, as well as check on
     655    their status.
     656   
     657    """
     658   
     659    def __init__(self, server='localhost', port=8081, username=getuser(),
     660                 ssl=True, workers=2, authenticate=False, priority=20,
     661                 poll=1.0, log_level=0,
     662                 log_file=os.path.join(DSAGE_DIR, 'worker.log'),
     663                 pubkey_file=None, privkey_file=None):
     664        """
     665        :type server: string
     666        :param server: hostname of remote server
     667       
     668        :type port: integer
     669        :param port: port of remote server
     670       
     671        :type username: string
     672        :param username: username to use for authentication
     673       
     674        :type ssl: boolean
     675        :param ssl: specify whether or not to use SSL for the connection
     676       
     677        :type workers: integer
     678        :param workers: specifies how many workers to launch
     679       
     680        :type authenticate: boolean
     681        :param authenticate: specifies whether or not to authenticate
     682       
     683        :type priority: integer
     684        :param priority: specifies the UNIX priority of the workers
     685       
     686        :type poll: float
     687        :param poll: specifies how fast workers talk to the server in seconds
     688       
     689        :type log_level: integer
     690        :param log_level: specifies verbosity of logging, higher equals more
     691       
     692        :type log_file: string
     693        :param log_file: specifies the location of the log_file
     694           
     695        """
     696       
     697        self.server = server
     698        self.port = port
     699        self.username = username
     700        self.ssl = ssl
     701        self.workers = workers
     702        self.authenticate = authenticate
     703        self.priority = priority
     704        self.poll_rate = poll
     705        self.log_level = log_level
     706        self.log_file = log_file
     707        self.pubkey_file = pubkey_file
     708        self.privkey_file = privkey_file
     709       
     710        self.remoteobj = None
     711        self.connected = False
     712        self.reconnecting = False
     713        self.worker_pool = None
     714        self.sleep_time = 1.0
     715       
     716        self.host_info = HostInfo().host_info
     717       
     718        self.host_info['uuid'] = get_uuid()
     719        self.host_info['workers'] = self.workers
     720        self.host_info['username'] = self.username
     721       
     722        self._startLogging(self.log_file)
     723       
     724        try:
     725            os.nice(self.priority)
     726        except OSError, msg:
     727            log.msg('Error setting priority: %s' % (self.priority))
     728            pass       
     729        if self.authenticate:
     730            from twisted.cred import credentials
     731            from twisted.conch.ssh import keys
     732            self.DATA =  random_str(500)
     733            # public key authentication information
     734            self.pubkey = keys.Key.fromFile(self.pubkey_file)
     735            # try getting the private key object without a passphrase first
     736            try:
     737                self.privkey = keys.Key.fromFile(self.privkey_file)
     738            except keys.BadKeyError:
     739                pphrase = self._getpassphrase()
     740                self.privkey = keys.Key.fromFile(self.privkey_file,
     741                                                  passphrase=pphrase)
     742            self.algorithm = 'rsa'
     743            self.blob = self.pubkey.blob()
     744            self.data = self.DATA
     745            self.signature = self.privkey.sign(self.data)
     746            self.creds = credentials.SSHPrivateKey(self.username,
     747                                                   self.algorithm,
     748                                                   self.blob,
     749                                                   self.data,
     750                                                   self.signature)
     751   
     752    def _startLogging(self, log_file):
     753        """
     754        :type log_file: string
     755        :param log_file: file name to log to
     756       
     757        """
     758       
     759        if log_file == 'stdout':
     760            log.startLogging(sys.stdout)
     761            log.msg('WARNING: Only loggint to stdout!')
     762        else:
     763            worker_log = open(log_file, 'a')
     764            log.startLogging(sys.stdout)
     765            log.startLogging(worker_log)
     766            log.msg("Logging to file: ", log_file)
     767           
     768    def _getpassphrase(self):
     769        import getpass
     770        passphrase = getpass.getpass('Passphrase (Hit enter for None): ')
     771       
     772        return passphrase
     773       
     774    def _connected(self, remoteobj):
     775        """
     776        Callback for connect.
     777       
     778        :type remoteobj: remote object
     779        :param remoteobj: remote obj
     780       
     781        """
     782       
     783        self.remoteobj = remoteobj
     784        self.remoteobj.notifyOnDisconnect(self._disconnected)
     785        self.connected = True
     786       
     787        if self.worker_pool == None: # Only pool workers the first time
     788            self.pool_workers(self.remoteobj)
     789        else:
     790            for worker in self.worker_pool:
     791                worker.remoteobj = self.remoteobj # Update workers
     792                if worker.job == None:
     793                    worker.restart()
     794   
     795    def _disconnected(self, remoteobj):
     796        """
     797        :type remoteobj: remote object
     798        :param remoteobj: remote obj
     799       
     800        """
     801       
     802        log.msg('Closed connection to the server.')
     803        self.connected = False
     804   
     805    def _got_killed_jobs(self, killed_jobs):
     806        """
     807        Callback for check_killed_jobs.
     808       
     809        :type killed_jobs: dict
     810        :param killed_jobs: dict of job jdicts which were killed
     811       
     812        """
     813       
     814        if killed_jobs == None:
     815            return
     816        killed_jobs = [expand_job(jdict) for jdict in killed_jobs]
     817        for worker in self.worker_pool:
     818            if worker.job is None:
     819                continue
     820            if worker.free:
     821                continue
     822            for job in killed_jobs:
     823                if job is None or worker.job is None:
     824                    continue
     825                if worker.job.job_id == job.job_id:
     826                    msg = 'Processing killed job, restarting...'
     827                    log.msg(LOG_PREFIX % worker.id + msg)
     828                    worker.restart()
     829   
     830    def _retryConnect(self):
     831        log.msg('[Monitor] Disconnected, reconnecting in %s' % (5.0))
     832        if not self.connected:
     833            reactor.callLater(5.0, self.connect)
     834   
     835    def _catchConnectionFailure(self, failure):               
     836        log.msg("Error: ", failure.getErrorMessage())
     837        log.msg("Traceback: ", failure.printTraceback())
     838        self._disconnected(None)
     839   
     840    def _catch_failure(self, failure):
     841        log.msg("Error: ", failure.getErrorMessage())
     842        log.msg("Traceback: ", failure.printTraceback())
     843       
     844    def connect(self):
     845        """
     846        This method connects the monitor to a remote PB server.
     847       
     848        """
     849       
     850        if self.connected: # Don't connect multiple times
     851            return
     852       
     853        self.factory = ClientFactory(self._login, (), {})
     854        cred = None
     855        if self.ssl:
     856            cred = X509Credentials()
     857            reactor.connectTLS(self.server, self.port, self.factory, cred)
     858        else:
     859            reactor.connectTCP(self.server, self.port, self.factory)
     860       
     861        log.msg(DELIMITER)
     862        log.msg('DSAGE Worker')
     863        log.msg('Started with PID: %s' % (os.getpid()))
     864        log.msg('Connecting to %s:%s' % (self.server, self.port))
     865        if cred is not None:
     866            log.msg('Using SSL: True')
     867        else:
     868            log.msg('Using SSL: False')
     869        log.msg(DELIMITER)
     870   
     871    def _login(self, *args, **kwargs):
     872        if self.authenticate:
     873            log.msg('Connecting as authenticated worker...\n')
     874            d = self.factory.login(self.creds, (self, self.host_info))
     875        else:
     876            from twisted.cred.credentials import Anonymous
     877            log.msg('Connecting as unauthenticated worker...\n')
     878            d = self.factory.login(Anonymous(), (self, self.host_info))
     879        d.addCallback(self._connected)
     880        d.addErrback(self._catchConnectionFailure)
     881           
     882        return d
     883       
     884    def pool_workers(self, remoteobj):
     885        """
     886        Creates the worker pool.
     887       
     888        """
     889
     890        log.msg('[Monitor] Starting %s workers...' % (self.workers))
     891        self.worker_pool = [Worker(remoteobj, x, self.log_level,
     892                            self.poll_rate)
     893                            for x in range(self.workers)]
     894
     895       
     896    def remote_set_uuid(self, uuid):
     897        """
     898        Sets the workers uuid.
     899        This is called by the server.
     900       
     901        """
     902       
     903        from sage.dsage.misc.misc import set_uuid
     904        set_uuid(uuid)
     905   
     906
     907    def remote_calc_score(self, script):
     908        """
     909        Calculuates the worker score.
     910       
     911        :type script: string
     912        :param script: script to score the worker
     913       
     914        """
     915       
     916        from sage.dsage.misc.misc import exec_wrs
     917       
     918        return exec_wrs(script)
     919
     920   
     921    def remote_kill_job(self, job_id):
     922        """
     923        Kills the job given the job id.
     924       
     925        :type job_id: string
     926        :param job_id: the unique job identifier.
     927       
     928        """
     929       
     930        print 'Killing %s' % (job_id)
     931        for worker in self.worker_pool:
     932            if worker.job != None:
     933                if worker.job.job_id == job_id:
     934                    worker.restart()
     935       
     936       
     937def usage():
     938    """
     939    Prints usage help.
     940
     941    """
     942   
     943    from optparse import OptionParser
     944   
     945    usage = ['usage: %prog [options]\n',
     946              'Bug reports to <yqiang@gmail.com>']
     947    parser = OptionParser(usage=''.join(usage))
     948    parser.add_option('-s', '--server',
     949                      dest='server',
     950                      default='localhost',
     951                      help='hostname. Default is localhost')
     952    parser.add_option('-p', '--port',
     953                      dest='port',
     954                      type='int',
     955                      default=8081,
     956                      help='port to connect to. default=8081')
     957    parser.add_option('--poll',
     958                      dest='poll',
     959                      type='float',
     960                      default=5.0,
     961                      help='poll rate before checking for new job. default=5')
     962    parser.add_option('-a', '--authenticate',
     963                      dest='authenticate',
     964                      default=False,
     965                      action='store_true',
     966                      help='Connect as authenticate worker. default=True')
     967    parser.add_option('-f', '--logfile',
     968                      dest='logfile',
     969                      default=os.path.join(DSAGE_DIR, 'worker.log'),
     970                      help='log file')
     971    parser.add_option('-l', '--loglevel',
     972                      dest='loglevel',
     973                      type='int',
     974                      default=0,
     975                      help='log level. default=0')
     976    parser.add_option('--ssl',
     977                      dest='ssl',
     978                      action='store_true',
     979                      default=False,
     980                      help='enable or disable ssl')
     981    parser.add_option('--privkey',
     982                      dest='privkey_file',
     983                      default=os.path.join(DSAGE_DIR, 'dsage_key'),
     984                      help='private key file. default = ' +
     985                           '~/.sage/dsage/dsage_key')
     986    parser.add_option('--pubkey',
     987                      dest='pubkey_file',
     988                      default=os.path.join(DSAGE_DIR, 'dsage_key.pub'),
     989                      help='public key file. default = ' +
     990                           '~/.sage/dsage/dsage_key.pub')
     991    parser.add_option('-w', '--workers',
     992                      dest='workers',
     993                      type='int',
     994                      default=2,
     995                      help='number of workers. default=2')
     996    parser.add_option('--priority',
     997                      dest='priority',
     998                      type='int',
     999                      default=20,
     1000                      help='priority of workers. default=20')
     1001    parser.add_option('-u', '--username',
     1002                      dest='username',
     1003                      default=getuser(),
     1004                      help='username')
     1005    parser.add_option('--noblock',
     1006                      dest='noblock',
     1007                      action='store_true',
     1008                      default=False,
     1009                      help='tells that the server was ' +
     1010                           'started in blocking mode')
     1011    (options, args) = parser.parse_args()
     1012
     1013    return options
     1014       
     1015def main():
     1016    options = usage()
     1017    SSL = options.ssl
     1018    monitor = Monitor(server=options.server, port=options.port,
     1019                      username=options.username, ssl=SSL,
     1020                      workers=options.workers,
     1021                      authenticate=options.authenticate,
     1022                      priority=options.priority, poll=options.poll,
     1023                      log_file=options.logfile,
     1024                      log_level=options.loglevel,
     1025                      pubkey_file=options.pubkey_file,
     1026                      privkey_file=options.privkey_file)
     1027    monitor.connect()
     1028    try:
     1029        if options.noblock:
     1030            reactor.run(installSignalHandlers=0)
     1031        else:
     1032            reactor.run(installSignalHandlers=1)
     1033    except:
     1034        log.msg('Error starting the twisted reactor, exiting...')
     1035        sys.exit()
     1036
     1037if __name__ == '__main__':
     1038    usage()
     1039    main()
  • setup.py

    diff -r 5ce556fc4ec1 -r b0fe5a4b514a setup.py
    a b code = setup(name = 'sage', 
    14391439                     'sage.dsage.web',
    14401440                     'sage.dsage.scripts',
    14411441                     ],
    1442      
    1443       scripts = [ 'spkg-debian-maybe' ],
     1442
     1443      scripts = ['sage/dsage/scripts/dsage_worker.py',
     1444                 'sage/dsage/scripts/dsage_setup.py',
     1445                 'spkg-debian-maybe',
     1446                ],
    14441447
    14451448      data_files = [('dsage/web/static',                       
    14461449                    ['sage/dsage/web/static/dsage_web.css',