source: sage/dsage/twisted/tests/test_remote.py @ 3831:62161b21fb93

Revision 3831:62161b21fb93, 13.7 KB checked in by Yi Qiang <yqiang@…>, 6 years ago (diff)

Use twisted logging instead of printing whenever possible.

Line 
1############################################################################
2#                                                                     
3#   DSAGE: Distributed SAGE                     
4#                                                                             
5#       Copyright (C) 2006, 2007 Yi Qiang <yqiang@gmail.com>               
6#                                                                           
7#  Distributed under the terms of the GNU General Public License (GPL)       
8#
9#    This code is distributed in the hope that it will be useful,
10#    but WITHOUT ANY WARRANTY; without even the implied warranty of
11#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12#    General Public License for more details.
13#
14#  The full text of the GPL is available at:
15#
16#                  http://www.gnu.org/licenses/
17############################################################################
18
19import ConfigParser
20import os
21import random
22from glob import glob
23import cPickle
24import zlib
25import uuid
26
27from twisted.trial import unittest
28from twisted.spread import pb
29from twisted.internet import reactor
30from twisted.cred import portal, credentials
31from twisted.conch.ssh import keys
32from twisted.python import log
33
34from sage.dsage.twisted.pb import Realm
35from sage.dsage.server.server import DSageServer
36from sage.dsage.twisted.pb import _SSHKeyPortalRoot
37from sage.dsage.twisted.pb import PBClientFactory
38from sage.dsage.twisted.pubkeyauth import PublicKeyCredentialsCheckerDB
39from sage.dsage.database.jobdb import JobDatabaseSQLite
40from sage.dsage.database.monitordb import MonitorDatabase
41from sage.dsage.database.clientdb import ClientDatabase
42from sage.dsage.database.job import Job
43from sage.dsage.errors.exceptions import BadJobError
44from sage.dsage.misc.hostinfo import ClassicHostInfo
45
46DSAGE_DIR = os.path.join(os.getenv('DOT_SAGE'), 'dsage')
47# Begin reading configuration
48try:
49    conf_file = os.path.join(DSAGE_DIR, 'server.conf')
50    config = ConfigParser.ConfigParser()
51    config.read(conf_file)
52
53    LOG_FILE = config.get('server_log', 'log_file')
54    SSL = config.getint('ssl', 'ssl')
55    WORKER_PORT = config.getint('server', 'worker_port')
56    CLIENT_PORT = config.getint('server', 'client_port')
57    PUBKEY_DATABASE = os.path.expanduser(config.get('auth',
58                                                    'pubkey_database'))
59
60    conf_file = os.path.join(DSAGE_DIR, 'client.conf')
61    config = ConfigParser.ConfigParser()
62    config.read(conf_file)
63
64    LOG_FILE = config.get('log', 'log_file')
65    SSL = config.getint('ssl', 'ssl')
66    USERNAME = config.get('auth', 'username')
67    PRIVKEY_FILE = os.path.expanduser(config.get('auth', 'privkey_file'))
68    PUBKEY_FILE = os.path.expanduser(config.get('auth', 'pubkey_file'))
69   
70    conf_file = os.path.join(DSAGE_DIR, 'worker.conf')
71    config = ConfigParser.ConfigParser()
72    config.read(conf_file)
73    if len(config.get('uuid', 'id')) != 36:
74        config.set('uuid', 'id', str(uuid.uuid1()))
75        f = open(conf_file, 'w')
76        config.write(f)
77    UUID = config.get('uuid', 'id')
78    WORKERS = config.getint('general', 'workers')
79   
80except Exception, msg:
81    log.msg(msg)
82    raise 
83# End reading configuration
84hf = ClassicHostInfo().host_info
85hf['uuid'] = UUID
86hf['workers'] = WORKERS
87
88Data =  ''.join([chr(i) for i in [random.randint(65, 123) for n in
89                range(500)]])
90
91class ClientRemoteCallsTest(unittest.TestCase):
92    r"""
93    Tests of remote procedure calls go here.
94   
95    """
96   
97    def unpickle(self, pickled_job):
98        return cPickle.loads(zlib.decompress(pickled_job))
99   
100    def setUp(self):
101        self.jobdb = JobDatabaseSQLite(test=True)
102        self.monitordb = MonitorDatabase(test=True)
103        self.clientdb = ClientDatabase(test=True)
104        self.dsage_server = DSageServer(self.jobdb, 
105                                        self.monitordb, 
106                                        self.clientdb,
107                                        log_level=5)
108        self.realm = Realm(self.dsage_server)
109        self.p = _SSHKeyPortalRoot(portal.Portal(self.realm))
110        self.clientdb = ClientDatabase(test=True)
111        self.p.portal.registerChecker(
112        PublicKeyCredentialsCheckerDB(self.clientdb))
113        self.client_factory = pb.PBServerFactory(self.p)
114        self.hostname = 'localhost'
115        self.port = CLIENT_PORT
116        self.server = reactor.listenTCP(CLIENT_PORT, self.client_factory)
117
118        # public key authentication information
119        self.username = USERNAME
120        self.pubkey_file = PUBKEY_FILE
121        self.privkey_file = PRIVKEY_FILE
122        self.public_key_string = keys.getPublicKeyString(
123                                 filename=self.pubkey_file)
124        self.private_key = keys.getPrivateKeyObject(filename=self.privkey_file)
125        self.public_key = keys.getPublicKeyObject(self.public_key_string)
126        self.alg_name = 'rsa'
127        self.blob = keys.makePublicKeyBlob(self.public_key)
128        self.data = Data
129        self.signature = keys.signData(self.private_key, self.data)
130        self.creds = credentials.SSHPrivateKey(self.username,
131                                               self.alg_name,
132                                               self.blob, 
133                                               self.data,
134                                               self.signature)
135        c = ConfigParser.ConfigParser()
136        c.read(os.path.join(DSAGE_DIR, 'client.conf'))
137        username = c.get('auth', 'username')
138        pubkey_file = c.get('auth', 'pubkey_file')
139        self.clientdb.add_user(username, pubkey_file)
140       
141    def tearDown(self):
142        self.connection.disconnect()
143        self.jobdb._shutdown()
144        files = glob('*.db*')
145        for file in files:
146            os.remove(file)
147        return self.server.stopListening()
148
149    def _catch_failure(self, failure, *args):
150        log.msg("Error: ", failure.getErrorMessage())
151        log.msg("Traceback: ", failure.printTraceback())
152       
153    def testremoteSubmitJob(self):
154        """tests perspective_submit_job"""
155        jobs = self.create_jobs(1)
156
157        factory = PBClientFactory()
158        self.connection = reactor.connectTCP(self.hostname, 
159                                             self.port, 
160                                             factory)
161
162        d = factory.login(self.creds, None)
163        d.addCallback(self._LoginConnected2, jobs)
164        d.addErrback(self._catch_failure)
165        return d
166
167    def _LoginConnected2(self, remoteobj, jobs):
168        job = jobs[0]
169        job.code = "2+2"
170        d = remoteobj.callRemote('submit_job', job.reduce())
171        d.addCallback(self._got_jdict)
172        return d
173
174    def _got_jdict(self, jdict):
175        self.assertEquals(type(jdict), dict)
176        self.assertEquals(type(jdict['job_id']), str)
177
178    def testremoteSubmitBadJob(self):
179        """tests perspective_submit_job"""
180
181        factory = PBClientFactory()
182        self.connection = reactor.connectTCP(self.hostname, 
183                                             self.port, 
184                                             factory)
185
186        d = factory.login(self.creds, None)
187        d.addCallback(self._LoginConnected3)
188        return d
189
190    def _LoginConnected3(self, remoteobj):
191        d = remoteobj.callRemote('submit_job', None)
192        d.addErrback(self._gotNoJobID)
193        return d
194
195    def _gotNoJobID(self, failure):
196        self.assertEquals(BadJobError, failure.check(BadJobError))
197
198    def create_jobs(self, n):
199        """This method creates n jobs. """
200
201        jobs = []
202        for i in range(n):
203            jobs.append(Job(name='unittest', username='yqiang'))
204
205        return jobs
206
207class MonitorRemoteCallsTest(unittest.TestCase):
208    r"""
209    Tests remote calls for monitors.
210   
211    """
212   
213    def setUp(self):
214        self.jobdb = JobDatabaseSQLite(test=True)
215        self.monitordb = MonitorDatabase(test=True)
216        self.clientdb = ClientDatabase(test=True)
217        self.dsage_server = DSageServer(self.jobdb, 
218                                        self.monitordb,
219                                        self.clientdb,
220                                        log_level=5)
221        self.realm = Realm(self.dsage_server)
222        self.p = _SSHKeyPortalRoot(portal.Portal(self.realm))
223        self.p.portal.registerChecker(
224        PublicKeyCredentialsCheckerDB(self.clientdb))
225        self.client_factory = pb.PBServerFactory(self.p)
226        self.hostname = 'localhost'
227        self.port = CLIENT_PORT
228        self.server = reactor.listenTCP(CLIENT_PORT, self.client_factory)
229
230        # public key authentication information
231        self.username = USERNAME
232        self.pubkey_file = PUBKEY_FILE
233        self.privkey_file = PRIVKEY_FILE
234        self.public_key_string = keys.getPublicKeyString(
235                                 filename=self.pubkey_file)
236        self.private_key = keys.getPrivateKeyObject(filename=self.privkey_file)
237        self.public_key = keys.getPublicKeyObject(self.public_key_string)
238        self.alg_name = 'rsa'
239        self.blob = keys.makePublicKeyBlob(self.public_key)
240        self.data = Data
241        self.signature = keys.signData(self.private_key, self.data)
242        self.creds = credentials.SSHPrivateKey(self.username,
243                                               self.alg_name,
244                                               self.blob, 
245                                               self.data,
246                                               self.signature)
247        c = ConfigParser.ConfigParser()
248        c.read(os.path.join(DSAGE_DIR, 'client.conf'))
249        username = c.get('auth', 'username')
250        pubkey_file = c.get('auth', 'pubkey_file')
251        self.clientdb.add_user(username, pubkey_file) 
252       
253    def tearDown(self):
254        self.connection.disconnect()
255        self.jobdb._shutdown()
256        files = glob('*.db*')
257        for file in files:
258            os.remove(file)
259        return self.server.stopListening() 
260   
261    def testremote_get_job(self):
262        job = Job()
263        job.code = "2+2"
264        self.dsage_server.submit_job(job.reduce())
265        factory = PBClientFactory()
266        self.connection = reactor.connectTCP(self.hostname, 
267                                             self.port, 
268                                             factory)                                       
269        d = factory.login(self.creds, (pb.Referenceable(), hf))
270        d.addCallback(self._logged_in)
271        d.addCallback(self._get_job)
272       
273        return d
274   
275    def _logged_in(self, remoteobj):
276        self.assert_(remoteobj is not None)
277       
278        return remoteobj
279       
280    def _get_job(self, remoteobj):
281        d = remoteobj.callRemote('get_job')
282        d.addCallback(self._got_job)
283       
284        return d
285       
286    def _got_job(self, jdict):
287        self.assertEquals(type(jdict), dict)
288   
289    def testremote_job_done(self):
290        factory = PBClientFactory()
291        self.connection = reactor.connectTCP(self.hostname, 
292                                             self.port, 
293                                             factory)                                       
294        d = factory.login(self.creds, (pb.Referenceable(), hf))
295        job = Job()
296        job.code = "2+2"
297        jdict = self.dsage_server.submit_job(job.reduce())
298        d.addCallback(self._logged_in)
299        d.addCallback(self._job_done, jdict)
300       
301        return d
302   
303    def _job_done(self, remoteobj, jdict):
304        job_id = jdict['job_id']
305        result = jdict['result']
306        d = remoteobj.callRemote('job_done', 
307                                 job_id, 
308                                 'Nothing.',
309                                 result,
310                                 False,
311                                 'lalal')
312        d.addCallback(self._done_job)
313       
314        return d
315   
316    def _done_job(self, jdict):
317        self.assertEquals(type(jdict), dict)
318        self.assertEquals(jdict['status'], 'new')
319        self.assertEquals(jdict['output'], 'Nothing.')
320   
321    def testremote_job_failed(self):
322        factory = PBClientFactory()
323        self.connection = reactor.connectTCP(self.hostname, 
324                                             self.port, 
325                                             factory)
326        job = Job()
327        job.code = "2+2"
328        jdict = self.dsage_server.submit_job(job.reduce())
329        d = factory.login(self.creds, (pb.Referenceable(), hf))
330        d.addCallback(self._logged_in)
331        d.addCallback(self._job_failed, jdict)
332       
333        return d
334       
335    def _job_failed(self, remoteobj, jdict):
336        d = remoteobj.callRemote('job_failed', jdict['job_id'], 'Failure')
337        d.addCallback(self._failed_job)
338       
339        return d
340       
341    def _failed_job(self, jdict):
342        self.assertEquals(type(jdict), dict) 
343        self.assertEquals(jdict['failures'], 1)
344        self.assertEquals(jdict['output'], 'Failure')
345       
346    def testget_killed_jobs_list(self):
347        factory = PBClientFactory()
348        self.connection = reactor.connectTCP(self.hostname, 
349                                             self.port, 
350                                             factory)
351 
352        job = Job()
353        job.code = "2+2"
354        job.killed = True
355        jdict = self.dsage_server.submit_job(job.reduce())
356        d = factory.login(self.creds, (pb.Referenceable(), hf))
357        d.addCallback(self._logged_in)
358        d.addCallback(self._get_killed_jobs_list)
359        d.addCallback(self._got_killed_jobs_list, jdict)
360       
361        return d
362   
363    def _get_killed_jobs_list(self, remoteobj):
364        d = remoteobj.callRemote('get_killed_jobs_list')
365       
366        return d
367   
368    def _got_killed_jobs_list(self, killed_jobs_list, jdict):
369        self.assertEquals(len(killed_jobs_list), 1)
370        self.assertEquals(killed_jobs_list[0]['job_id'], jdict['job_id'])
371       
372       
373if __name__ == 'main':
374    unittest.main()
Note: See TracBrowser for help on using the repository browser.