source: sage/dsage/twisted/tests/test_remote.py @ 3823:bb2c55b47723

Revision 3823:bb2c55b47723, 13.6 KB checked in by Yi Qiang <yqiang@…>, 6 years ago (diff)

changed user_id to username, more intuitive.

Line 
1############################################################################
2#                                                                     
3#   DSAGE: Distributed SAGE                     
4#                                                                             
5#       Copyright (C) 2006, 2007 Yi Qiang <yqiang@gmail.com>               
6#                                                                           
7#  Distributed under the terms of the GNU General Public License (GPL)       
8#
9#    This code is distributed in the hope that it will be useful,
10#    but WITHOUT ANY WARRANTY; without even the implied warranty of
11#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12#    General Public License for more details.
13#
14#  The full text of the GPL is available at:
15#
16#                  http://www.gnu.org/licenses/
17############################################################################
18
19import ConfigParser
20import os
21import random
22from glob import glob
23import cPickle
24import zlib
25import uuid
26
27from twisted.trial import unittest
28from twisted.spread import pb
29from twisted.internet import reactor
30from twisted.cred import portal, credentials
31from twisted.conch.ssh import keys
32
33from sage.dsage.twisted.pb import Realm
34from sage.dsage.server.server import DSageServer
35from sage.dsage.twisted.pb import _SSHKeyPortalRoot
36from sage.dsage.twisted.pb import PBClientFactory
37from sage.dsage.twisted.pubkeyauth import PublicKeyCredentialsCheckerDB
38from sage.dsage.database.jobdb import JobDatabaseSQLite
39from sage.dsage.database.monitordb import MonitorDatabase
40from sage.dsage.database.clientdb import ClientDatabase
41from sage.dsage.database.job import Job
42from sage.dsage.errors.exceptions import BadJobError
43from sage.dsage.misc.hostinfo import ClassicHostInfo
44
45DSAGE_DIR = os.path.join(os.getenv('DOT_SAGE'), 'dsage')
46# Begin reading configuration
47try:
48    conf_file = os.path.join(DSAGE_DIR, 'server.conf')
49    config = ConfigParser.ConfigParser()
50    config.read(conf_file)
51
52    LOG_FILE = config.get('server_log', 'log_file')
53    SSL = config.getint('ssl', 'ssl')
54    WORKER_PORT = config.getint('server', 'worker_port')
55    CLIENT_PORT = config.getint('server', 'client_port')
56    PUBKEY_DATABASE = os.path.expanduser(config.get('auth',
57                                                    'pubkey_database'))
58
59    conf_file = os.path.join(DSAGE_DIR, 'client.conf')
60    config = ConfigParser.ConfigParser()
61    config.read(conf_file)
62
63    LOG_FILE = config.get('log', 'log_file')
64    SSL = config.getint('ssl', 'ssl')
65    USERNAME = config.get('auth', 'username')
66    PRIVKEY_FILE = os.path.expanduser(config.get('auth', 'privkey_file'))
67    PUBKEY_FILE = os.path.expanduser(config.get('auth', 'pubkey_file'))
68   
69    conf_file = os.path.join(DSAGE_DIR, 'worker.conf')
70    config = ConfigParser.ConfigParser()
71    config.read(conf_file)
72    if len(config.get('uuid', 'id')) != 36:
73        config.set('uuid', 'id', str(uuid.uuid1()))
74        f = open(conf_file, 'w')
75        config.write(f)
76    UUID = config.get('uuid', 'id')
77    WORKERS = config.getint('general', 'workers')
78   
79except Exception, msg:
80    print msg
81    raise 
82# End reading configuration
83hf = ClassicHostInfo().host_info
84hf['uuid'] = UUID
85hf['workers'] = WORKERS
86
87Data =  ''.join([chr(i) for i in [random.randint(65, 123) for n in
88                range(500)]])
89
90class ClientRemoteCallsTest(unittest.TestCase):
91    r"""
92    Tests of remote procedure calls go here.
93   
94    """
95   
96    def unpickle(self, pickled_job):
97        return cPickle.loads(zlib.decompress(pickled_job))
98   
99    def setUp(self):
100        self.jobdb = JobDatabaseSQLite(test=True)
101        self.monitordb = MonitorDatabase(test=True)
102        self.clientdb = ClientDatabase(test=True)
103        self.dsage_server = DSageServer(self.jobdb, 
104                                        self.monitordb, 
105                                        self.clientdb,
106                                        log_level=5)
107        self.realm = Realm(self.dsage_server)
108        self.p = _SSHKeyPortalRoot(portal.Portal(self.realm))
109        self.clientdb = ClientDatabase(test=True)
110        self.p.portal.registerChecker(
111        PublicKeyCredentialsCheckerDB(self.clientdb))
112        self.client_factory = pb.PBServerFactory(self.p)
113        self.hostname = 'localhost'
114        self.port = CLIENT_PORT
115        self.server = reactor.listenTCP(CLIENT_PORT, self.client_factory)
116
117        # public key authentication information
118        self.username = USERNAME
119        self.pubkey_file = PUBKEY_FILE
120        self.privkey_file = PRIVKEY_FILE
121        self.public_key_string = keys.getPublicKeyString(
122                                 filename=self.pubkey_file)
123        self.private_key = keys.getPrivateKeyObject(filename=self.privkey_file)
124        self.public_key = keys.getPublicKeyObject(self.public_key_string)
125        self.alg_name = 'rsa'
126        self.blob = keys.makePublicKeyBlob(self.public_key)
127        self.data = Data
128        self.signature = keys.signData(self.private_key, self.data)
129        self.creds = credentials.SSHPrivateKey(self.username,
130                                               self.alg_name,
131                                               self.blob, 
132                                               self.data,
133                                               self.signature)
134        c = ConfigParser.ConfigParser()
135        c.read(os.path.join(DSAGE_DIR, 'client.conf'))
136        username = c.get('auth', 'username')
137        pubkey_file = c.get('auth', 'pubkey_file')
138        self.clientdb.add_user(username, pubkey_file)
139       
140    def tearDown(self):
141        self.connection.disconnect()
142        self.jobdb._shutdown()
143        files = glob('*.db*')
144        for file in files:
145            os.remove(file)
146        return self.server.stopListening()
147
148    def _catchFailure(self, failure, *args):
149        log.msg("Error: ", failure.getErrorMessage())
150        log.msg("Traceback: ", failure.printTraceback())
151       
152    def testremoteSubmitJob(self):
153        """tests perspective_submit_job"""
154        jobs = self.create_jobs(1)
155
156        factory = PBClientFactory()
157        self.connection = reactor.connectTCP(self.hostname, 
158                                             self.port, 
159                                             factory)
160
161        d = factory.login(self.creds, None)
162        d.addCallback(self._LoginConnected2, jobs)
163        d.addErrback(self._catchFailure)
164        return d
165
166    def _LoginConnected2(self, remoteobj, jobs):
167        job = jobs[0]
168        job.code = "2+2"
169        d = remoteobj.callRemote('submit_job', job.reduce())
170        d.addCallback(self._got_jdict)
171        return d
172
173    def _got_jdict(self, jdict):
174        self.assertEquals(type(jdict), dict)
175        self.assertEquals(type(jdict['job_id']), str)
176
177    def testremoteSubmitBadJob(self):
178        """tests perspective_submit_job"""
179
180        factory = PBClientFactory()
181        self.connection = reactor.connectTCP(self.hostname, 
182                                             self.port, 
183                                             factory)
184
185        d = factory.login(self.creds, None)
186        d.addCallback(self._LoginConnected3)
187        return d
188
189    def _LoginConnected3(self, remoteobj):
190        d = remoteobj.callRemote('submit_job', None)
191        d.addErrback(self._gotNoJobID)
192        return d
193
194    def _gotNoJobID(self, failure):
195        self.assertEquals(BadJobError, failure.check(BadJobError))
196
197    def create_jobs(self, n):
198        """This method creates n jobs. """
199
200        jobs = []
201        for i in range(n):
202            jobs.append(Job(name='unittest', username='yqiang'))
203
204        return jobs
205
206class MonitorRemoteCallsTest(unittest.TestCase):
207    r"""
208    Tests remote calls for monitors.
209   
210    """
211   
212    def setUp(self):
213        self.jobdb = JobDatabaseSQLite(test=True)
214        self.monitordb = MonitorDatabase(test=True)
215        self.clientdb = ClientDatabase(test=True)
216        self.dsage_server = DSageServer(self.jobdb, 
217                                        self.monitordb,
218                                        self.clientdb,
219                                        log_level=5)
220        self.realm = Realm(self.dsage_server)
221        self.p = _SSHKeyPortalRoot(portal.Portal(self.realm))
222        self.p.portal.registerChecker(
223        PublicKeyCredentialsCheckerDB(self.clientdb))
224        self.client_factory = pb.PBServerFactory(self.p)
225        self.hostname = 'localhost'
226        self.port = CLIENT_PORT
227        self.server = reactor.listenTCP(CLIENT_PORT, self.client_factory)
228
229        # public key authentication information
230        self.username = USERNAME
231        self.pubkey_file = PUBKEY_FILE
232        self.privkey_file = PRIVKEY_FILE
233        self.public_key_string = keys.getPublicKeyString(
234                                 filename=self.pubkey_file)
235        self.private_key = keys.getPrivateKeyObject(filename=self.privkey_file)
236        self.public_key = keys.getPublicKeyObject(self.public_key_string)
237        self.alg_name = 'rsa'
238        self.blob = keys.makePublicKeyBlob(self.public_key)
239        self.data = Data
240        self.signature = keys.signData(self.private_key, self.data)
241        self.creds = credentials.SSHPrivateKey(self.username,
242                                               self.alg_name,
243                                               self.blob, 
244                                               self.data,
245                                               self.signature)
246        c = ConfigParser.ConfigParser()
247        c.read(os.path.join(DSAGE_DIR, 'client.conf'))
248        username = c.get('auth', 'username')
249        pubkey_file = c.get('auth', 'pubkey_file')
250        self.clientdb.add_user(username, pubkey_file) 
251       
252    def tearDown(self):
253        self.connection.disconnect()
254        self.jobdb._shutdown()
255        files = glob('*.db*')
256        for file in files:
257            os.remove(file)
258        return self.server.stopListening() 
259   
260    def testremote_get_job(self):
261        job = Job()
262        job.code = "2+2"
263        self.dsage_server.submit_job(job.reduce())
264        factory = PBClientFactory()
265        self.connection = reactor.connectTCP(self.hostname, 
266                                             self.port, 
267                                             factory)                                       
268        d = factory.login(self.creds, (pb.Referenceable(), hf))
269        d.addCallback(self._logged_in)
270        d.addCallback(self._get_job)
271       
272        return d
273   
274    def _logged_in(self, remoteobj):
275        self.assert_(remoteobj is not None)
276       
277        return remoteobj
278       
279    def _get_job(self, remoteobj):
280        d = remoteobj.callRemote('get_job')
281        d.addCallback(self._got_job)
282       
283        return d
284       
285    def _got_job(self, jdict):
286        self.assertEquals(type(jdict), dict)
287   
288    def testremote_job_done(self):
289        factory = PBClientFactory()
290        self.connection = reactor.connectTCP(self.hostname, 
291                                             self.port, 
292                                             factory)                                       
293        d = factory.login(self.creds, (pb.Referenceable(), hf))
294        job = Job()
295        job.code = "2+2"
296        jdict = self.dsage_server.submit_job(job.reduce())
297        d.addCallback(self._logged_in)
298        d.addCallback(self._job_done, jdict)
299       
300        return d
301   
302    def _job_done(self, remoteobj, jdict):
303        job_id = jdict['job_id']
304        result = jdict['result']
305        d = remoteobj.callRemote('job_done', 
306                                 job_id, 
307                                 'Nothing.',
308                                 result,
309                                 False,
310                                 'lalal')
311        d.addCallback(self._done_job)
312       
313        return d
314   
315    def _done_job(self, jdict):
316        self.assertEquals(type(jdict), dict)
317        self.assertEquals(jdict['status'], 'new')
318        self.assertEquals(jdict['output'], 'Nothing.')
319   
320    def testremote_job_failed(self):
321        factory = PBClientFactory()
322        self.connection = reactor.connectTCP(self.hostname, 
323                                             self.port, 
324                                             factory)
325        job = Job()
326        job.code = "2+2"
327        jdict = self.dsage_server.submit_job(job.reduce())
328        d = factory.login(self.creds, (pb.Referenceable(), hf))
329        d.addCallback(self._logged_in)
330        d.addCallback(self._job_failed, jdict)
331       
332        return d
333       
334    def _job_failed(self, remoteobj, jdict):
335        d = remoteobj.callRemote('job_failed', jdict['job_id'])
336        d.addCallback(self._failed_job)
337       
338        return d
339       
340    def _failed_job(self, jdict):
341        self.assertEquals(type(jdict), dict) 
342        self.assertEquals(jdict['failures'], 1)
343   
344    def testget_killed_jobs_list(self):
345        factory = PBClientFactory()
346        self.connection = reactor.connectTCP(self.hostname, 
347                                             self.port, 
348                                             factory)
349 
350        job = Job()
351        job.code = "2+2"
352        job.killed = True
353        jdict = self.dsage_server.submit_job(job.reduce())
354        d = factory.login(self.creds, (pb.Referenceable(), hf))
355        d.addCallback(self._logged_in)
356        d.addCallback(self._get_killed_jobs_list)
357        d.addCallback(self._got_killed_jobs_list, jdict)
358       
359        return d
360   
361    def _get_killed_jobs_list(self, remoteobj):
362        d = remoteobj.callRemote('get_killed_jobs_list')
363       
364        return d
365   
366    def _got_killed_jobs_list(self, killed_jobs_list, jdict):
367        self.assertEquals(len(killed_jobs_list), 1)
368        self.assertEquals(killed_jobs_list[0]['job_id'], jdict['job_id'])
369       
370       
371if __name__ == 'main':
372    unittest.main()
Note: See TracBrowser for help on using the repository browser.