8 from unittest import TestCase
9 from optparse import OptionParser
10 from sfa.util.xrn import get_authority
11 from sfa.util.config import *
12 from sfa.trust.certificate import *
13 from sfa.trust.credential import *
14 from sfa.trust.sfaticket import SfaTicket
15 from sfa.client import sfi
16 from sfa.client.sfaserverproxy import SfaServerProxy, ServerException
def random_string(size):
    """Return a random string of ``size`` distinct ASCII letters.

    ``string.ascii_letters`` is used instead of ``string.letters``. In
    Python 2 the latter is locale-dependent and may contain non-ASCII
    letters, which would produce invalid hostnames and email addresses in the
    generated test records. It also no longer exists in Python 3.

    Letters are drawn without replacement (``random.sample``), so ``size``
    must be between 0 and 52; a larger value raises ``ValueError``.
    """
    return "".join(random.sample(string.ascii_letters, size))
# Build a client for the SFA interfaces under test from an sfi config file.
# Reads the config named by options.config_file, loads the user's key pair
# (<user>.pkey next to the config), writes a certificate issued by the user to
# itself (<user>.cert, same directory), and opens SfaServerProxy connections to
# the registry, aggregate, slice manager and component manager.
# NOTE(review): some lines of this method are not visible in this view (e.g. the
# except clause paired with the try below, and the line after the print);
# verify against the full file.
30 def __init__(self, options):
31 try: self.config = config = Config(options.config_file)
33 print "failed to read config_file %s" % options.config_file
# Key and certificate files live beside the config file; their stem is the last
# dot-separated component of SFI_USER.
35 key_path = os.path.dirname(options.config_file)
36 user_name = self.config.SFI_USER.split('.')[-1:][0]
37 key_file = key_path + os.sep + user_name + '.pkey'
38 cert_file = key_path + os.sep + user_name + '.cert'
39 self.key = Keypair(filename=key_file)
# Subject and issuer are both SFI_USER, keyed with the user's own key pair.
40 self.cert = Certificate(subject=self.config.SFI_USER)
41 self.cert.set_pubkey(self.key)
42 self.cert.set_issuer(self.key, self.config.SFI_USER)
44 self.cert.save_to_file(cert_file)
# The aggregate URL is derived from the slice manager URL by swapping port 12347
# for 12346; the component manager is contacted at http://<cm_host>:12346.
45 SFI_AGGREGATE = config.SFI_SM.replace('12347', '12346')
46 SFI_CM = 'http://' + options.cm_host + ':12346'
47 self.registry = SfaServerProxy(config.SFI_REGISTRY, key_file, cert_file)
48 self.aggregate = SfaServerProxy(SFI_AGGREGATE, key_file, cert_file)
49 self.sm = SfaServerProxy(config.SFI_SM, key_file, cert_file)
50 self.cm = SfaServerProxy(SFI_CM, key_file, cert_file)
51 self.hrn = config.SFI_USER
52 # XX defaulting to user, but this should be configurable so we can
53 # test from the component's perspective
# Cache the user's own credential; GetCredential reuses self.credential for
# later delegated requests.
55 self.credential = self.GetCredential(self.hrn)
# Return a credential for `hrn` of the given `type` from the registry.
# hrn defaults to the client's own hrn; type defaults to 'user'.
# NOTE(review): the lines choosing between the two paths below are not visible
# in this view; presumably the self-credential path is used when hrn is the
# client's own hrn and the delegated path otherwise -- confirm against the file.
57 def GetCredential(self, hrn = None, type = 'user'):
58 if not hrn: hrn = self.hrn
# Self-credential path: send the client's certificate (with parents) plus a
# request hash computed with the client's key to get_self_credential.
# NOTE(review): the hash covers the literal 'user' rather than `type`; confirm
# this matches what the registry verifies when type != 'user'.
60 cert = self.cert.save_to_string(save_parents=True)
61 request_hash = self.key.compute_hash([cert, 'user', hrn])
62 credential = self.registry.get_self_credential(cert, type, hrn, request_hash)
# Delegated path: make sure the client's own user credential is cached, then use
# it to ask the registry for a credential of `type` for `hrn`.
65 if not self.credential:
66 self.credential = self.GetCredential(self.hrn, 'user')
67 return self.registry.GetCredential(self.credential, type, hrn)
# Base class for the interface tests. Each instance runs one test method
# (`testname`) against a shared Client, optionally with a registered test slice
# record, and exposes the client's proxies, credential, hrn and type as
# attributes of the test.
69 class BasicTestCase(unittest.TestCase):
70 def __init__(self, testname, client, test_slice=None):
71 unittest.TestCase.__init__(self, testname)
73 self.slice = test_slice
# NOTE(review): the lines storing `client` on self and opening the method that
# holds the assignments below (presumably setUp, which subclasses call as
# BasicTestCase.setUp(self)) are not visible in this view.
76 self.registry = self.client.registry
77 self.aggregate = self.client.aggregate
78 self.sm = self.client.sm
79 self.cm = self.client.cm
80 self.credential = self.client.credential
81 self.hrn = self.client.hrn
82 self.type = self.client.type
# Registry interface tests: self-credential, Register/Resolve/Remove of
# randomly named records under the user's authority, update, list and the
# registry's informational calls.
85 class RegistryTest(BasicTestCase):
89 Make sure test records dont exsit
91 BasicTestCase.setUp(self)
93 def testGetSelfCredential(self):
94 cred = self.client.GetCredential()
95 # this will raise an openssl error if the credential string isn't valid
96 Credential(string=cred)
# Register one record of each kind (authority, node, slice, user) under the
# caller's authority with a random name, then resolve it.
98 def testRegister(self):
99 authority = get_authority(self.hrn)
100 auth_cred = self.client.GetCredential(authority, 'authority')
101 auth_record = {'hrn': '.'.join([authority, random_string(10).lower()]),
103 node_record = {'hrn': '.'.join([authority, random_string(10)]),
105 'hostname': random_string(6) + '.' + random_string(6)}
106 slice_record = {'hrn': '.'.join([authority, random_string(10)]),
107 'type': 'slice', 'researcher': [self.hrn]}
108 user_record = {'hrn': '.'.join([authority, random_string(10)]),
110 'email': random_string(6) +'@'+ random_string(5) +'.'+ random_string(3),
111 'first_name': random_string(7),
112 'last_name': random_string(7)}
114 all_records = [auth_record, node_record, slice_record, user_record]
115 for record in all_records:
117 self.registry.Register(auth_cred, record)
118 self.registry.Resolve(self.credential, record['hrn'])
# Clean up each registered record again.
122 try: self.registry.Remove(auth_cred, record['type'], record['hrn'])
126 def testRegisterPeerObject(self):
# Re-submit the caller's own record through the registry's update call.
# NOTE(review): `record` is not assigned in the visible lines of this method
# (presumably taken from `records`); also `update` is lowercase while the other
# registry calls are Register/Resolve/Remove -- confirm the server method name.
129 def testUpdate(self):
130 authority = get_authority(self.hrn)
131 auth_cred = self.client.GetCredential(authority, 'authority')
132 records = self.registry.Resolve(self.credential, self.hrn)
133 if not records: assert False
135 self.registry.update(auth_cred, record)
137 def testResolve(self):
138 authority = get_authority(self.hrn)
139 self.registry.Resolve(self.credential, self.hrn)
# Register a random record, remove it, then expect Resolve on it to fail.
141 def testRemove(self):
142 authority = get_authority(self.hrn)
143 auth_cred = self.client.GetCredential(authority, 'authority')
144 record = {'hrn': ".".join([authority, random_string(10)]),
146 self.registry.Register(auth_cred, record)
147 self.registry.Remove(auth_cred, record['type'], record['hrn'])
148 # should generate an exception
150 self.registry.Resolve(self.credential, record['hrn'])
155 def testRemovePeerObject(self):
# NOTE(review): the method header for the List call below is not visible in
# this view (presumably a testList method).
159 authority = get_authority(self.client.hrn)
160 self.registry.List(self.credential, authority)
162 def testGetRegistries(self):
163 self.registry.get_registries(self.credential)
165 def testGetAggregates(self):
166 self.registry.get_aggregates(self.credential)
# get_trusted_certs is expected to raise ServerException for a 'user' caller.
# NOTE(review): the local name `callable` shadows the builtin.
168 def testGetTrustedCerts(self):
169 # this should fail unless we are a node
170 callable = self.registry.get_trusted_certs
171 server_exception = False
173 callable(self.credential)
174 except ServerException:
175 server_exception = True
177 if self.type in ['user'] and not server_exception:
# Aggregate interface tests run against the registered test slice
# (self.slice['hrn']): list slices, fetch resources, create/delete a sliver and
# obtain a ticket. SlicemgrTest reuses these tests through the slice manager.
181 class AggregateTest(BasicTestCase):
183 BasicTestCase.setUp(self)
185 def testGetSlices(self):
186 self.aggregate.ListSlices(self.credential)
# NOTE(review): agg_rspec is fetched but never validated; only slice_rspec is
# parsed. RSpec is not among the imports visible in this view -- confirm it is
# imported.
188 def testGetResources(self):
189 # available resources
190 agg_rspec = self.aggregate.get_resources(self.credential)
191 # resources used by a slice
192 slice_rspec = self.aggregate.get_resources(self.credential, self.slice['hrn'])
193 # will raise an exception if the rspec isn't valid
195 RSpec(xml=slice_rspec)
# Request a sliver for the test slice using the full available-resources rspec.
197 def testCreateSlice(self):
198 # get available resources
199 rspec = self.aggregate.get_resources(self.credential)
200 slice_credential = self.client.GetCredential(self.slice['hrn'], 'slice')
201 self.aggregate.CreateSliver(slice_credential, self.slice['hrn'], rspec)
203 def testDeleteSlice(self):
204 slice_credential = self.client.GetCredential(self.slice['hrn'], 'slice')
205 self.aggregate.DeleteSliver(slice_credential, self.slice['hrn'],"call-id-delete-slice")
207 def testGetTicket(self):
208 slice_credential = self.client.GetCredential(self.slice['hrn'], 'slice')
209 rspec = self.aggregate.get_resources(self.credential)
210 ticket = self.aggregate.GetTicket(slice_credential, self.slice['hrn'], rspec)
211 # will raise an exception if the ticket isn't valid
212 SfaTicket(string=ticket)
# Runs every AggregateTest test, but with self.aggregate rebound to the slice
# manager proxy so the same calls go through the slice manager.
214 class SlicemgrTest(AggregateTest):
216 AggregateTest.setUp(self)
218 # force calls to go through slice manager
219 self.aggregate = self.sm
# NOTE(review): the statements that follow this comment are not visible in this
# view.
221 # get the slice credential
class ComponentTest(BasicTestCase):
    """Exercise the component manager (CM) interface against the test slice.

    The test slice record must be passed to the constructor, because setUp
    and every test read ``self.slice['hrn']``.
    """

    def setUp(self):
        BasicTestCase.setUp(self)
        # Slice-scoped credential shared by every CM call below.
        self.slice_cred = self.client.GetCredential(self.slice['hrn'], 'slice')

    def testStartSlice(self):
        self.cm.start_slice(self.slice_cred, self.slice['hrn'])

    def testStopSlice(self):
        self.cm.stop_slice(self.slice_cred, self.slice['hrn'])

    def testDeleteSlice(self):
        self.cm.DeleteSliver(self.slice_cred, self.slice['hrn'], "call-id-delete-slice-cm")

    def testRestartSlice(self):
        self.cm.restart_slice(self.slice_cred, self.slice['hrn'])

    def testGetSlices(self):
        self.cm.ListSlices(self.slice_cred, self.slice['hrn'])

    def testRedeemTicket(self):
        """Obtain a ticket from the aggregate and redeem it at the component."""
        rspec = self.aggregate.get_resources(self.credential)
        # The slice credential is stored on the instance by setUp; there is no
        # local `slice_cred` in this method.
        ticket = self.aggregate.GetTicket(self.slice_cred, self.slice['hrn'], rspec)
        self.cm.redeem_ticket(self.slice_cred, ticket)
250 def test_names(testcase):
251 return [name for name in dir(testcase) if name.startswith('test')]
def CreateSliver(client):
    """Register a randomly named slice under the client's authority.

    The slice gives the aggregate, slice manager and component test cases
    something to operate on. ``client.hrn`` is listed as its researcher.

    Returns:
        The registered slice record: a dict with ``hrn``, ``type`` and
        ``researcher`` keys. The ``__main__`` driver keeps it as
        ``test_slice``, and the test cases read ``test_slice['hrn']``, so it
        must be returned rather than dropped.
    """
    # register a slice that will be used for some test
    authority = get_authority(client.hrn)
    auth_cred = client.GetCredential(authority, 'authority')
    slice_record = {'hrn': ".".join([authority, random_string(10)]),
                    'type': 'slice', 'researcher': [client.hrn]}
    client.registry.Register(auth_cred, slice_record)
    return slice_record
# Remove the slice record `slice` (a dict with an 'hrn' key, as registered by
# CreateSliver) from the registry, using an authority credential for the
# client's authority.
# NOTE(review): the parameter name `slice` shadows the builtin; lines around the
# Remove call are not visible in this view (presumably error handling).
262 def DeleteSliver(client, slice):
263 authority = get_authority(client.hrn)
264 auth_cred = client.GetCredential(authority, 'authority')
266 client.registry.Remove(auth_cred, 'slice', slice['hrn'])
# Command-line driver: parse options, build a Client, register a temporary test
# slice when slice-based suites are selected, run the selected suites, and then
# remove the test slice.
268 if __name__ == '__main__':
# NOTE(review): the usage string interpolates `prog_name` from locals(); its
# definition is not visible in this view.
272 default_config_dir = os.path.expanduser('~/.sfi/sfi_config')
273 default_cm = "echo.cs.princeton.edu"
274 parser = OptionParser(usage="%(prog_name)s [options]" % locals())
275 parser.add_option('-f', '--config_file', dest='config_file', default=default_config_dir,
276 help='config file. default is %s' % default_config_dir)
277 parser.add_option('-r', '--registry', dest='registry', action='store_true',
278 default=False, help='run registry tests')
279 parser.add_option('-a', '--aggregate', dest='aggregate', action='store_true',
280 default=False, help='run aggregate tests')
281 parser.add_option('-s', '--slicemgr', dest='slicemgr', action='store_true',
282 default=False, help='run slicemgr tests')
283 parser.add_option('-c', '--component', dest='component', action='store_true',
284 default=False, help='run component tests')
285 parser.add_option('-d', '--cm_host', dest='cm_host', default=default_cm,
286 help='dns name of component to test. default is %s' % default_cm)
# NOTE(review): the help text for -A/--all says 'run component tests', but the
# flag enables every suite below.
287 parser.add_option('-A', '--all', dest='all', action='store_true',
288 default=False, help='run component tests')
290 options, args = parser.parse_args()
291 suite = unittest.TestSuite()
292 client = Client(options)
# NOTE(review): test_slice is only assigned here; if no slice-based suite is
# selected (e.g. only -r), the DeleteSliver call at the end needs a fallback
# value that is not visible in this view.
295 # create the test slice if necessary
296 if options.all or options.slicemgr or options.aggregate \
297 or options.component:
298 test_slice = CreateSliver(client)
# One TestCase instance per test* method of each selected suite.
300 if options.registry or options.all:
301 for name in test_names(RegistryTest):
302 suite.addTest(RegistryTest(name, client))
304 if options.aggregate or options.all:
305 for name in test_names(AggregateTest):
306 suite.addTest(AggregateTest(name, client, test_slice))
308 if options.slicemgr or options.all:
309 for name in test_names(SlicemgrTest):
310 suite.addTest(SlicemgrTest(name, client, test_slice))
312 if options.component or options.all:
313 for name in test_names(ComponentTest):
314 suite.addTest(ComponentTest(name, client, test_slice))
317 unittest.TextTestRunner(verbosity=2).run(suite)
# Remove the temporary slice registered above.
320 DeleteSliver(client, test_slice)