sfaprotocol was renamed to sfaserverproxy, with class SfaServerProxy
[sfa.git] / tests / testInterfaces.py
1 #!/usr/bin/python
2 import sys
3 import os
4 import random
5 import string
6 import unittest
7
8 from unittest import TestCase
9 from optparse import OptionParser
10 from sfa.util.xrn import get_authority
11 from sfa.util.config import *
12 from sfa.trust.certificate import *
13 from sfa.trust.credential import *
14 from sfa.trust.sfaticket import SfaTicket
15 from sfa.client import sfi
16 from sfa.client.sfaserverproxy import SfaServerProxy, ServerException
17
def random_string(size):
    """Return a string of `size` distinct random ASCII letters.

    Characters are drawn without replacement, so `size` must not exceed
    52 (random.sample raises ValueError otherwise).
    """
    # string.ascii_letters rather than string.letters: the latter is
    # locale-dependent in Python 2 and does not exist in Python 3, and the
    # result is used to build hrns, hostnames and e-mail addresses.
    return "".join(random.sample(string.ascii_letters, size))
20
class Client:
    """Identity and server proxies shared by every test case.

    Construction reads the sfi config file, writes a self-signed
    certificate for SFI_USER next to that user's private key, opens
    SfaServerProxy connections to the registry, aggregate, slice manager
    and component manager, and fetches the user's own credential.
    """
    registry = None
    aggregate = None
    sm = None
    cm = None
    key = None
    cert = None
    credential = None
    type = None

    def __init__(self, options):
        """Build the client from parsed command-line `options`.

        Reads options.config_file and options.cm_host. Exits the process
        with status 1 if the config file cannot be loaded.
        """
        try:
            self.config = config = Config(options.config_file)
        except Exception:
            # Exception, not a bare except: Ctrl-C / SystemExit must still
            # propagate instead of being reported as a config error.
            print("failed to read config_file %s" % options.config_file)
            sys.exit(1)

        # The key pair is expected next to the config file, named after the
        # last dot-separated component of the user's hrn.
        key_path = os.path.dirname(options.config_file)
        user_name = config.SFI_USER.split('.')[-1]
        # os.path.join avoids producing "/<user>.pkey" (filesystem root)
        # when the config file path has no directory part.
        key_file = os.path.join(key_path, user_name + '.pkey')
        cert_file = os.path.join(key_path, user_name + '.cert')

        # Self-signed certificate for SFI_USER, written to cert_file so the
        # server proxies can use it for their connections.
        self.key = Keypair(filename=key_file)
        self.cert = Certificate(subject=config.SFI_USER)
        self.cert.set_pubkey(self.key)
        self.cert.set_issuer(self.key, config.SFI_USER)
        self.cert.sign()
        self.cert.save_to_file(cert_file)

        # The aggregate URL is the slice manager URL with '12347' replaced by
        # '12346'; the component manager host comes from --cm_host.
        aggregate_url = config.SFI_SM.replace('12347', '12346')
        cm_url = 'http://' + options.cm_host + ':12346'
        self.registry = SfaServerProxy(config.SFI_REGISTRY, key_file, cert_file)
        self.aggregate = SfaServerProxy(aggregate_url, key_file, cert_file)
        self.sm = SfaServerProxy(config.SFI_SM, key_file, cert_file)
        self.cm = SfaServerProxy(cm_url, key_file, cert_file)

        self.hrn = config.SFI_USER
        # XX defaulting to user, but this should be configurable so we can
        # test from the component's perspective
        self.type = 'user'
        self.credential = self.GetCredential(self.hrn)

    def GetCredential(self, hrn=None, type='user'):
        """Return a credential of the given `type` for `hrn`.

        For our own hrn (the default), the credential is requested with
        registry.get_self_credential. That call sends our certificate
        (with parents) and a request hash from self.key.compute_hash.
        For any other hrn, our own user credential is fetched once and
        cached in self.credential. It is then presented to
        registry.GetCredential.
        """
        if not hrn:
            hrn = self.hrn
        if hrn == self.hrn:
            cert = self.cert.save_to_string(save_parents=True)
            # Hash exactly the (cert, type, hrn) triple being sent; the
            # previous hard-coded 'user' disagreed with any other `type`.
            request_hash = self.key.compute_hash([cert, type, hrn])
            return self.registry.get_self_credential(cert, type, hrn, request_hash)
        if not self.credential:
            self.credential = self.GetCredential(self.hrn, 'user')
        return self.registry.GetCredential(self.credential, type, hrn)
68
class BasicTestCase(unittest.TestCase):
    """Common base for the interface tests.

    Each instance runs one test method (`testname`) and shares a single
    Client with the rest of the suite. The optional `test_slice` record is
    kept as self.slice for the slice-oriented tests.
    """

    def __init__(self, testname, client, test_slice=None):
        unittest.TestCase.__init__(self, testname)
        self.client = client
        self.slice = test_slice

    def setUp(self):
        # Expose the client's proxies, credential and identity directly on
        # the test case so test methods can write self.registry, etc.
        shared = self.client
        for attr_name in ('registry', 'aggregate', 'sm', 'cm',
                          'credential', 'hrn', 'type'):
            setattr(self, attr_name, getattr(shared, attr_name))
83                 
# Registry tests
class RegistryTest(BasicTestCase):
    """Exercise the registry interface: credentials, Register/Resolve/
    Update/Remove/List and the registry/aggregate/trusted-cert listings."""

    def setUp(self):
        """Populate the shared client state; no registry records are touched here."""
        BasicTestCase.setUp(self)

    def testGetSelfCredential(self):
        cred = self.client.GetCredential()
        # Credential() raises (an openssl error) if the string isn't a valid credential
        Credential(string=cred)

    def testRegister(self):
        # Register one record of each type, check it resolves, then remove it.
        authority = get_authority(self.hrn)
        auth_cred = self.client.GetCredential(authority, 'authority')
        auth_record = {'hrn': '.'.join([authority, random_string(10).lower()]),
                       'type': 'authority'}
        node_record = {'hrn': '.'.join([authority, random_string(10)]),
                       'type': 'node',
                       'hostname': random_string(6) + '.' + random_string(6)}
        slice_record = {'hrn': '.'.join([authority, random_string(10)]),
                        'type': 'slice', 'researcher': [self.hrn]}
        user_record = {'hrn': '.'.join([authority, random_string(10)]),
                       'type': 'user',
                       'email': random_string(6) + '@' + random_string(5) + '.' + random_string(3),
                       'first_name': random_string(7),
                       'last_name': random_string(7)}

        for record in (auth_record, node_record, slice_record, user_record):
            try:
                self.registry.Register(auth_cred, record)
                self.registry.Resolve(self.credential, record['hrn'])
            finally:
                # Best-effort cleanup: if Register failed there is nothing
                # to remove, so a failing Remove is expected and ignored.
                try:
                    self.registry.Remove(auth_cred, record['type'], record['hrn'])
                except Exception:
                    pass

    def testRegisterPeerObject(self):
        # Placeholder: peer objects are not tested yet.
        pass

    def testUpdate(self):
        authority = get_authority(self.hrn)
        auth_cred = self.client.GetCredential(authority, 'authority')
        records = self.registry.Resolve(self.credential, self.hrn)
        self.assertTrue(records, "Resolve returned no record for %s" % self.hrn)
        # Update, matching the renamed Register/Resolve/Remove/List calls
        # used elsewhere in this class (was the stale lowercase `update`).
        self.registry.Update(auth_cred, records[0])

    def testResolve(self):
        self.registry.Resolve(self.credential, self.hrn)

    def testRemove(self):
        authority = get_authority(self.hrn)
        auth_cred = self.client.GetCredential(authority, 'authority')
        record = {'hrn': ".".join([authority, random_string(10)]),
                  'type': 'slice'}
        self.registry.Register(auth_cred, record)
        self.registry.Remove(auth_cred, record['type'], record['hrn'])
        # Resolving the removed record should raise. The fail() call must
        # stay in the else branch: inside the try, the broad except would
        # swallow its AssertionError and the test could never fail.
        try:
            self.registry.Resolve(self.credential, record['hrn'])
        except Exception:
            pass
        else:
            self.fail("Resolve succeeded for removed record %s" % record['hrn'])

    def testRemovePeerObject(self):
        # Placeholder: peer objects are not tested yet.
        pass

    def testList(self):
        authority = get_authority(self.client.hrn)
        self.registry.List(self.credential, authority)

    def testGetRegistries(self):
        self.registry.get_registries(self.credential)

    def testGetAggregates(self):
        self.registry.get_aggregates(self.credential)

    def testGetTrustedCerts(self):
        # This should fail unless we are a node. A ServerException is the
        # expected outcome; any other error propagates as a test error
        # instead of being masked by an assertion.
        try:
            self.registry.get_trusted_certs(self.credential)
        except ServerException:
            return
        if self.type in ['user']:
            self.fail("get_trusted_certs succeeded for a %s credential" % self.type)
179             
180
class AggregateTest(BasicTestCase):
    """Exercise the aggregate manager interface against the shared test slice.

    Relies on self.slice being a slice record (with an 'hrn' key) passed in
    at construction.
    """

    def setUp(self):
        BasicTestCase.setUp(self)

    def testGetSlices(self):
        self.aggregate.ListSlices(self.credential)

    def testGetResources(self):
        # NOTE(review): the other calls here use the renamed API names
        # (ListSlices, CreateSliver, ...); confirm get_resources is still
        # served by the aggregate.
        everything = self.aggregate.get_resources(self.credential)
        in_slice = self.aggregate.get_resources(self.credential, self.slice['hrn'])
        # RSpec() raises if a document is not a valid rspec.
        # NOTE(review): RSpec is not imported by name in this file; presumably
        # it arrives through one of the star imports - confirm.
        for rspec_xml in (everything, in_slice):
            RSpec(xml=rspec_xml)

    def testCreateSlice(self):
        # ask for everything the aggregate currently advertises
        rspec = self.aggregate.get_resources(self.credential)
        slice_hrn = self.slice['hrn']
        slice_cred = self.client.GetCredential(slice_hrn, 'slice')
        self.aggregate.CreateSliver(slice_cred, slice_hrn, rspec)

    def testDeleteSlice(self):
        slice_hrn = self.slice['hrn']
        slice_cred = self.client.GetCredential(slice_hrn, 'slice')
        self.aggregate.DeleteSliver(slice_cred, slice_hrn, "call-id-delete-slice")

    def testGetTicket(self):
        slice_hrn = self.slice['hrn']
        slice_cred = self.client.GetCredential(slice_hrn, 'slice')
        rspec = self.aggregate.get_resources(self.credential)
        ticket = self.aggregate.GetTicket(slice_cred, slice_hrn, rspec)
        # SfaTicket() raises if the returned ticket is not valid
        SfaTicket(string=ticket)
213
class SlicemgrTest(AggregateTest):
    """Run every AggregateTest case through the slice manager instead."""

    def setUp(self):
        AggregateTest.setUp(self)
        # The inherited tests call self.aggregate; pointing it at the slice
        # manager proxy routes all of them through the SM.
        self.aggregate = self.sm
222         
223
class ComponentTest(BasicTestCase):
    """Exercise the component manager interface on the shared test slice.

    Relies on self.slice being a slice record (with an 'hrn' key) passed in
    at construction.
    """

    def setUp(self):
        BasicTestCase.setUp(self)
        # Every component call below acts on the test slice, so fetch its
        # slice credential once per test.
        self.slice_cred = self.client.GetCredential(self.slice['hrn'], 'slice')

    def testStartSlice(self):
        self.cm.start_slice(self.slice_cred, self.slice['hrn'])

    def testStopSlice(self):
        self.cm.stop_slice(self.slice_cred, self.slice['hrn'])

    def testDeleteSlice(self):
        self.cm.DeleteSliver(self.slice_cred, self.slice['hrn'], "call-id-delete-slice-cm")

    def testRestartSlice(self):
        self.cm.restart_slice(self.slice_cred, self.slice['hrn'])

    def testGetSlices(self):
        self.cm.ListSlices(self.slice_cred, self.slice['hrn'])

    def testRedeemTicket(self):
        # Get a ticket for the slice from the aggregate, then redeem it at
        # the component. Uses self.slice_cred (set in setUp); a bare
        # `slice_cred` is undefined here and raises NameError.
        rspec = self.aggregate.get_resources(self.credential)
        ticket = self.aggregate.GetTicket(self.slice_cred, self.slice['hrn'], rspec)
        self.cm.redeem_ticket(self.slice_cred, ticket)
248
249
250 def test_names(testcase):
251     return [name for name in dir(testcase) if name.startswith('test')]
252
def CreateSliver(client):
    """Register a throw-away slice record under the client's authority.

    Only the registry record is created (via client.registry.Register). It
    lists client.hrn as researcher and is returned so that the caller can
    pass it to the slice tests and later remove it.
    """
    authority = get_authority(client.hrn)
    auth_cred = client.GetCredential(authority, 'authority')
    record = dict(hrn=".".join([authority, random_string(10)]),
                  type='slice',
                  researcher=[client.hrn])
    client.registry.Register(auth_cred, record)
    return record
261  
def DeleteSliver(client, slice):
    """Remove the slice record `slice` (a dict with an 'hrn' key) from the registry.

    Does nothing when `slice` is empty or None, e.g. in registry-only runs
    where no test slice was created. In that case no authority credential
    is requested, which avoids a needless remote call that could fail.
    """
    if not slice:
        return
    authority = get_authority(client.hrn)
    auth_cred = client.GetCredential(authority, 'authority')
    client.registry.Remove(auth_cred, 'slice', slice['hrn'])
267     
if __name__ == '__main__':

    prog_name = sys.argv[0]
    default_config_file = os.path.expanduser('~/.sfi/sfi_config')
    default_cm = "echo.cs.princeton.edu"
    parser = OptionParser(usage="%s [options]" % prog_name)
    parser.add_option('-f', '--config_file', dest='config_file', default=default_config_file,
                      help='config file. default is %s' % default_config_file)
    parser.add_option('-r', '--registry', dest='registry', action='store_true',
                      default=False, help='run registry tests')
    parser.add_option('-a', '--aggregate', dest='aggregate', action='store_true',
                      default=False, help='run aggregate tests')
    parser.add_option('-s', '--slicemgr', dest='slicemgr', action='store_true',
                      default=False, help='run slicemgr tests')
    parser.add_option('-c', '--component', dest='component', action='store_true',
                      default=False, help='run component tests')
    parser.add_option('-d', '--cm_host', dest='cm_host', default=default_cm,
                      help='dns name of component to test. default is %s' % default_cm)
    parser.add_option('-A', '--all', dest='all', action='store_true',
                      default=False, help='run all tests')

    options, _ = parser.parse_args()
    suite = unittest.TestSuite()
    client = Client(options)
    test_slice = {}

    # The aggregate, slice manager and component tests need a test slice.
    if options.all or options.slicemgr or options.aggregate or options.component:
        test_slice = CreateSliver(client)

    try:
        # (selected?, test class, extra constructor args) for each group
        groups = (
            (options.registry, RegistryTest, ()),
            (options.aggregate, AggregateTest, (test_slice,)),
            (options.slicemgr, SlicemgrTest, (test_slice,)),
            (options.component, ComponentTest, (test_slice,)),
        )
        for selected, case_class, extra_args in groups:
            if selected or options.all:
                for name in test_names(case_class):
                    suite.addTest(case_class(name, client, *extra_args))

        result = unittest.TextTestRunner(verbosity=2).run(suite)
    finally:
        # Remove the test slice even if building or running the suite raised.
        DeleteSliver(client, test_slice)

    # Report failures through the exit status so scripts/CI can detect them.
    sys.exit(0 if result.wasSuccessful() else 1)