f12cae10f897c2408f47a06dee3bcd52ba1df503
[sfa.git] / tests / testInterfaces.py
1 #!/usr/bin/python
2 import sys
3 import os
4 import random
5 import string
6 import unittest
7 import sfa.util.sfaprotocol as sfaprotocol
8 from unittest import TestCase
9 from optparse import OptionParser
10 from sfa.util.xrn import get_authority
11 from sfa.util.config import *
12 from sfa.trust.certificate import *
13 from sfa.trust.credential import *
14 from sfa.trust.sfaticket import SfaTicket
15 from sfa.client import sfi
16
def random_string(size):
    """Return a string of `size` distinct, randomly chosen ASCII letters.

    Uses ``string.ascii_letters`` instead of ``string.letters``: the latter
    is locale-dependent in Python 2 (and absent in Python 3), so it could
    produce non-ASCII characters in the HRNs, hostnames and e-mail
    addresses this helper is used to build.

    Letters are drawn without replacement (``random.sample``), so `size`
    must not exceed ``len(string.ascii_letters)`` (52); larger values raise
    ``ValueError``.
    """
    return "".join(random.sample(string.ascii_letters, size))
19
class Client(object):
    """Server proxies and credentials for the user named in an sfi config.

    The user's private key is read from the config file's directory. A
    self-signed certificate for that user is written next to it. Proxies
    (via ``sfaprotocol.server_proxy``) are opened to the registry,
    aggregate, slice manager and component manager.
    """
    registry = None      # proxy to config.SFI_REGISTRY
    aggregate = None     # proxy to SFI_SM's URL with port 12347 swapped for 12346
    sm = None            # proxy to config.SFI_SM (slice manager)
    cm = None            # proxy to options.cm_host on port 12346
    key = None           # the user's Keypair
    cert = None          # self-signed Certificate for SFI_USER
    credential = None    # the user's own credential, fetched in __init__
    type = None          # caller type used by the tests ('user' for now)

    def __init__(self, options):
        """Read the sfi config and set up key, certificate and proxies.

        :param options: parsed command-line options; ``config_file`` and
            ``cm_host`` are read.
        Exits the process with status 1 if the config file cannot be read.
        """
        try:
            self.config = config = Config(options.config_file)
        except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must propagate.
            sys.stderr.write("failed to read config_file %s\n" % options.config_file)
            sys.exit(1)
        key_path = os.path.dirname(options.config_file)
        # Key and cert files are named after the last component of the HRN.
        user_name = config.SFI_USER.split('.')[-1]
        # os.path.join (not `key_path + os.sep + ...`) so a config file given
        # without a directory resolves relative to the cwd, not to "/".
        key_file = os.path.join(key_path, user_name + '.pkey')
        cert_file = os.path.join(key_path, user_name + '.cert')
        self.key = Keypair(filename=key_file)
        # Self-signed certificate for the user. It is saved to disk because
        # the proxies below are built from key/cert file paths.
        self.cert = Certificate(subject=config.SFI_USER)
        self.cert.set_pubkey(self.key)
        self.cert.set_issuer(self.key, config.SFI_USER)
        self.cert.sign()
        self.cert.save_to_file(cert_file)
        # NOTE(review): the aggregate URL is derived from the slice-manager
        # URL by swapping port 12347 for 12346. This assumes both services
        # share a host and use these ports.
        sfi_aggregate = config.SFI_SM.replace('12347', '12346')
        sfi_cm = 'http://' + options.cm_host + ':12346'
        self.registry = sfaprotocol.server_proxy(config.SFI_REGISTRY, key_file, cert_file)
        self.aggregate = sfaprotocol.server_proxy(sfi_aggregate, key_file, cert_file)
        self.sm = sfaprotocol.server_proxy(config.SFI_SM, key_file, cert_file)
        self.cm = sfaprotocol.server_proxy(sfi_cm, key_file, cert_file)
        self.hrn = config.SFI_USER
        # XXX defaulting to user, but this should be configurable so we can
        # test from a component's perspective.
        self.type = 'user'
        self.credential = self.GetCredential(self.hrn)

    def GetCredential(self, hrn=None, type='user'):
        """Return a credential of the given `type` for `hrn`.

        :param hrn: target HRN; defaults to this client's own HRN.
        :param type: credential type requested from the registry.

        For the client's own HRN, the registry is asked for a self
        credential. The request carries the certificate (with parents) and
        a request hash computed with the user's key. For any other HRN, the
        user's own credential (fetched first if not yet cached) is passed to
        ``registry.GetCredential``.
        """
        if not hrn:
            hrn = self.hrn
        if hrn == self.hrn:
            cert = self.cert.save_to_string(save_parents=True)
            # Hash exactly the (cert, type, hrn) values sent below. 'user'
            # used to be hard-coded here, so the hash disagreed with the
            # request for any other `type` (presumably the registry
            # recomputes it from the arguments it receives).
            request_hash = self.key.compute_hash([cert, type, hrn])
            return self.registry.get_self_credential(cert, type, hrn, request_hash)
        if not self.credential:
            self.credential = self.GetCredential(self.hrn, 'user')
        return self.registry.GetCredential(self.credential, type, hrn)
67
class BasicTestCase(unittest.TestCase):
    """Shared base for the SFA interface test cases.

    Every case is bound to one shared client object and, optionally, a test
    slice record. setUp copies the client's proxies and identity onto the
    test case so test methods can refer to them directly.
    """

    def __init__(self, testname, client, test_slice=None):
        super(BasicTestCase, self).__init__(testname)
        self.slice = test_slice
        self.client = client

    def setUp(self):
        source = self.client
        # Mirror these client attributes onto the test case, in this order.
        for attr in ('registry', 'aggregate', 'sm', 'cm',
                     'credential', 'hrn', 'type'):
            setattr(self, attr, getattr(source, attr))
82                 
83 # Registry tests
class RegistryTest(BasicTestCase):
    """Exercise the registry interface: credentials, Register/Resolve/Remove,
    update, List and the informational get_* calls."""

    def setUp(self):
        """Copy the shared client's proxies and identity onto the test case."""
        BasicTestCase.setUp(self)

    def testGetSelfCredential(self):
        cred = self.client.GetCredential()
        # This will raise an OpenSSL error if the credential string isn't valid.
        Credential(string=cred)

    def testRegister(self):
        """Register one record of each type, resolve it, then remove it."""
        authority = get_authority(self.hrn)
        auth_cred = self.client.GetCredential(authority, 'authority')
        auth_record = {'hrn': '.'.join([authority, random_string(10).lower()]),
                       'type': 'authority'}
        node_record = {'hrn': '.'.join([authority, random_string(10)]),
                       'type': 'node',
                       'hostname': random_string(6) + '.' + random_string(6)}
        slice_record = {'hrn': '.'.join([authority, random_string(10)]),
                        'type': 'slice', 'researcher': [self.hrn]}
        user_record = {'hrn': '.'.join([authority, random_string(10)]),
                       'type': 'user',
                       'email': random_string(6) + '@' + random_string(5) + '.' + random_string(3),
                       'first_name': random_string(7),
                       'last_name': random_string(7)}

        all_records = [auth_record, node_record, slice_record, user_record]
        for record in all_records:
            try:
                self.registry.Register(auth_cred, record)
                self.registry.Resolve(self.credential, record['hrn'])
            finally:
                # Best-effort cleanup. The record may not exist if Register
                # itself failed, so removal errors are deliberately ignored.
                try:
                    self.registry.Remove(auth_cred, record['type'], record['hrn'])
                except Exception:
                    pass

    def testRegisterPeerObject(self):
        # Placeholder: peer-object registration is not exercised yet.
        pass

    def testUpdate(self):
        authority = get_authority(self.hrn)
        auth_cred = self.client.GetCredential(authority, 'authority')
        records = self.registry.Resolve(self.credential, self.hrn)
        # assertTrue instead of a bare `assert`, which is stripped under -O.
        self.assertTrue(records, "Resolve returned no records for %s" % self.hrn)
        # NOTE(review): the other registry calls here are capitalized
        # (Register, Resolve, Remove, List). Confirm the server still exposes
        # lowercase `update` rather than `Update`.
        self.registry.update(auth_cred, records[0])

    def testResolve(self):
        self.registry.Resolve(self.credential, self.hrn)

    def testRemove(self):
        authority = get_authority(self.hrn)
        auth_cred = self.client.GetCredential(authority, 'authority')
        record = {'hrn': ".".join([authority, random_string(10)]),
                  'type': 'slice'}
        self.registry.Register(auth_cred, record)
        self.registry.Remove(auth_cred, record['type'], record['hrn'])
        # Resolving a removed record should raise. The failure is reported
        # outside the try so the except clause cannot swallow it. Previously
        # `assert False` inside the try was caught by a bare except, so this
        # test could never fail.
        try:
            self.registry.Resolve(self.credential, record['hrn'])
        except Exception:
            pass
        else:
            self.fail("Resolve succeeded for removed record %s" % record['hrn'])

    def testRemovePeerObject(self):
        # Placeholder: peer-object removal is not exercised yet.
        pass

    def testList(self):
        authority = get_authority(self.client.hrn)
        self.registry.List(self.credential, authority)

    def testGetRegistries(self):
        self.registry.get_registries(self.credential)

    def testGetAggregates(self):
        self.registry.get_aggregates(self.credential)

    def testGetTrustedCerts(self):
        """get_trusted_certs should be refused unless we are a node."""
        server_exception = False
        try:
            self.registry.get_trusted_certs(self.credential)
        except sfaprotocol.ServerException:
            server_exception = True
        # Checked after the try (not in `finally`), so any other exception
        # propagates as-is instead of being masked by an assertion failure.
        if self.type in ['user'] and not server_exception:
            self.fail("get_trusted_certs should fail for type %r" % self.type)
178             
179
class AggregateTest(BasicTestCase):
    """Exercise the aggregate interface against the shared test slice.

    Expects `test_slice` to be a slice record dict with an 'hrn' key. The
    __main__ block passes the record returned by CreateSliver().
    """

    def setUp(self):
        BasicTestCase.setUp(self)

    def testGetSlices(self):
        self.aggregate.ListSlices(self.credential)

    def testGetResources(self):
        # Local import: only this test needs an XML parser.
        from xml.dom.minidom import parseString
        # available resources
        agg_rspec = self.aggregate.get_resources(self.credential)
        # resources used by the slice
        slice_rspec = self.aggregate.get_resources(self.credential, self.slice['hrn'])
        # `RSpec` is not imported in this module (none of the star imports
        # appear to provide it), so the previous RSpec(xml=...) calls raised
        # NameError. Parsing at least checks that each rspec is well-formed
        # XML and raises otherwise.
        # NOTE(review): restore schema-level validation through the project's
        # RSpec class once its import path is confirmed.
        parseString(agg_rspec)
        parseString(slice_rspec)

    def testCreateSlice(self):
        # get available resources
        rspec = self.aggregate.get_resources(self.credential)
        slice_credential = self.client.GetCredential(self.slice['hrn'], 'slice')
        self.aggregate.CreateSliver(slice_credential, self.slice['hrn'], rspec)

    def testDeleteSlice(self):
        slice_credential = self.client.GetCredential(self.slice['hrn'], 'slice')
        self.aggregate.DeleteSliver(slice_credential, self.slice['hrn'], "call-id-delete-slice")

    def testGetTicket(self):
        slice_credential = self.client.GetCredential(self.slice['hrn'], 'slice')
        rspec = self.aggregate.get_resources(self.credential)
        ticket = self.aggregate.GetTicket(slice_credential, self.slice['hrn'], rspec)
        # SfaTicket will raise an exception if the ticket isn't valid.
        SfaTicket(string=ticket)
212
class SlicemgrTest(AggregateTest):
    """Run the aggregate tests through the slice manager.

    Identical to AggregateTest except that `self.aggregate` is replaced by
    the slice-manager proxy, so every "aggregate" call goes via the SM.
    """

    def setUp(self):
        super(SlicemgrTest, self).setUp()
        # Route all aggregate calls through the slice manager.
        self.aggregate = self.sm
221         
222
class ComponentTest(BasicTestCase):
    """Exercise the component manager (cm) against the shared test slice."""

    def setUp(self):
        BasicTestCase.setUp(self)
        # Slice credential used by every component call below.
        self.slice_cred = self.client.GetCredential(self.slice['hrn'], 'slice')

    def testStartSlice(self):
        self.cm.start_slice(self.slice_cred, self.slice['hrn'])

    def testStopSlice(self):
        self.cm.stop_slice(self.slice_cred, self.slice['hrn'])

    def testDeleteSlice(self):
        self.cm.DeleteSliver(self.slice_cred, self.slice['hrn'], "call-id-delete-slice-cm")

    def testRestartSlice(self):
        self.cm.restart_slice(self.slice_cred, self.slice['hrn'])

    def testGetSlices(self):
        self.cm.ListSlices(self.slice_cred, self.slice['hrn'])

    def testRedeemTicket(self):
        rspec = self.aggregate.get_resources(self.credential)
        # Use the credential fetched in setUp. The bare name `slice_cred`
        # used here before was undefined and raised NameError.
        ticket = self.aggregate.GetTicket(self.slice_cred, self.slice['hrn'], rspec)
        self.cm.redeem_ticket(self.slice_cred, ticket)
247
248
249 def test_names(testcase):
250     return [name for name in dir(testcase) if name.startswith('test')]
251
def CreateSliver(client):
    """Register a randomly named slice under the client's authority.

    Despite its name, this only registers a slice record (with the
    client's own HRN as researcher); it makes no aggregate calls.
    Returns the registered record dict.
    """
    authority = get_authority(client.hrn)
    credential = client.GetCredential(authority, 'authority')
    record = dict(hrn=".".join([authority, random_string(10)]),
                  type='slice',
                  researcher=[client.hrn])
    client.registry.Register(credential, record)
    return record
260  
def DeleteSliver(client, slice):
    """Remove a slice record previously registered by CreateSliver().

    Does nothing when `slice` is empty or None, e.g. when only registry
    tests ran and no test slice was created. The authority credential is
    requested only when there is something to remove, which avoids a
    needless registry round-trip (and a possible failure) in that case.
    """
    if not slice:
        return
    authority = get_authority(client.hrn)
    auth_cred = client.GetCredential(authority, 'authority')
    client.registry.Remove(auth_cred, 'slice', slice['hrn'])
266     
if __name__ == '__main__':

    prog_name = sys.argv[0]
    default_config_file = os.path.expanduser('~/.sfi/sfi_config')
    default_cm = "echo.cs.princeton.edu"
    parser = OptionParser(usage="%s [options]" % prog_name)
    parser.add_option('-f', '--config_file', dest='config_file', default=default_config_file,
                      help='config file. default is %s' % default_config_file)
    parser.add_option('-r', '--registry', dest='registry', action='store_true',
                      default=False, help='run registry tests')
    parser.add_option('-a', '--aggregate', dest='aggregate', action='store_true',
                      default=False, help='run aggregate tests')
    parser.add_option('-s', '--slicemgr', dest='slicemgr', action='store_true',
                      default=False, help='run slicemgr tests')
    parser.add_option('-c', '--component', dest='component', action='store_true',
                      default=False, help='run component tests')
    parser.add_option('-d', '--cm_host', dest='cm_host', default=default_cm,
                      help='dns name of component to test. default is %s' % default_cm)
    parser.add_option('-A', '--all', dest='all', action='store_true',
                      default=False, help='run all tests')

    options, args = parser.parse_args()
    suite = unittest.TestSuite()
    client = Client(options)
    test_slice = {}

    # Aggregate, slice-manager and component tests need a registered slice.
    if options.all or options.slicemgr or options.aggregate or options.component:
        test_slice = CreateSliver(client)

    try:
        if options.registry or options.all:
            for name in test_names(RegistryTest):
                suite.addTest(RegistryTest(name, client))

        if options.aggregate or options.all:
            for name in test_names(AggregateTest):
                suite.addTest(AggregateTest(name, client, test_slice))

        if options.slicemgr or options.all:
            for name in test_names(SlicemgrTest):
                suite.addTest(SlicemgrTest(name, client, test_slice))

        if options.component or options.all:
            for name in test_names(ComponentTest):
                suite.addTest(ComponentTest(name, client, test_slice))

        # run tests
        result = unittest.TextTestRunner(verbosity=2).run(suite)
    finally:
        # Always remove the test slice, even if building or running the
        # suite raised. DeleteSliver ignores an empty record.
        DeleteSliver(client, test_slice)

    # Non-zero exit status when any test failed or errored.
    sys.exit(not result.wasSuccessful())