pl: we need to distinguish between local pl sites/slices and foreign ones coming...
[sfa.git] / tests / testCred.py
1 import unittest
2 from sfa.trust.credential import *
3 from sfa.trust.rights import *
4 from sfa.trust.gid import *
5 from sfa.trust.certificate import *
6
class TestCred(unittest.TestCase):
    """Unit tests for sfa.trust.credential.Credential.

    Covers construction, round-tripping credential fields through
    save_to_string(), signing and verifying against a trusted authority GID
    file, and delegation (including several credentials that must be rejected).
    """

    def setUp(self):
        # Credential.set_issuer_keys() and verify() take file paths, so the
        # tests must write keys/GIDs to disk. A fresh private directory per test
        # avoids collisions between concurrent runs and stale files from
        # earlier runs, which fixed /tmp names would suffer from.
        import tempfile
        self.tmpdir = tempfile.mkdtemp(prefix="testCred-")

    def tearDown(self):
        import shutil
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _tmp_path(self, name):
        """Return the path of file *name* inside this test's scratch directory."""
        import os
        return os.path.join(self.tmpdir, name)

    def _assertNotVerifiable(self, cred, trusted_gid_files):
        """Assert that cred.verify(trusted_gid_files) raises CredentialNotVerifiable.

        assertRaises is used rather than ``assert(1==0)`` inside try/except:
        bare asserts are stripped under ``python -O``, which would make these
        negative checks pass silently.
        """
        self.assertRaises(CredentialNotVerifiable, cred.verify, trusted_gid_files)

    def testCreate(self):
        # Only checks that construction with create=True does not raise.
        Credential(create=True)

    def testDefaults(self):
        cred = Credential(subject="testCredential")

        # A freshly built credential has no caller or object GID yet.
        self.assertEqual(cred.get_gid_caller(), None)
        self.assertEqual(cred.get_gid_object(), None)

    def testLoadSave(self):
        """Fields set on a credential survive encode() + save_to_string() + reload."""
        cred = Credential(subject="testCredential")

        gidCaller = GID(subject="caller", uuid=create_uuid(), hrn="foo.caller")
        gidObject = GID(subject="object", uuid=create_uuid(), hrn="foo.object")
        lifeTime = 12345
        delegate = True
        rights = "embed:1,bind:1"

        cred.set_gid_caller(gidCaller)
        self.assertEqual(cred.get_gid_caller().get_subject(), gidCaller.get_subject())

        cred.set_gid_object(gidObject)
        self.assertEqual(cred.get_gid_object().get_subject(), gidObject.get_subject())

        cred.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=lifeTime))

        cred.set_privileges(rights)
        self.assertEqual(cred.get_privileges().save_to_string(), rights)

        cred.get_privileges().delegate_all_privileges(delegate)

        cred.encode()

        cred_str = cred.save_to_string()

        # Re-load the credential from a string and make sure its fields are
        # intact.
        cred2 = Credential(string=cred_str)
        self.assertEqual(cred2.get_gid_caller().get_subject(), gidCaller.get_subject())
        self.assertEqual(cred2.get_gid_object().get_subject(), gidObject.get_subject())
        self.assertEqual(cred2.get_privileges().get_all_delegate(), delegate)
        self.assertEqual(cred2.get_privileges().save_to_string(), rights)

    def createSignedGID(self, subject, urn, issuer_pkey=None, issuer_gid=None):
        """Create, encode and sign a GID with a newly generated keypair.

        If issuer_pkey is given, the GID's issuer is set to issuer_pkey, with
        the issuer name taken from issuer_gid.get_issuer(). Otherwise the GID
        is self-issued with its own new keys and *subject*.

        Returns a (gid, keys) tuple, where keys is the GID's own new keypair.
        """
        # NOTE(review): every GID built here gets uuid=1.
        gid = GID(subject=subject, uuid=1, urn=urn)
        keys = Keypair(create=True)
        gid.set_pubkey(keys)
        if issuer_pkey:
            # NOTE(review): this uses the issuer GID's *issuer* name, not its
            # subject. That is the same value only when issuer_gid is
            # self-issued, as it is in this file. Confirm this is intended.
            gid.set_issuer(issuer_pkey, str(issuer_gid.get_issuer()))
        else:
            gid.set_issuer(keys, subject)

        gid.encode()
        gid.sign()
        return gid, keys

    def testDelegationAndVerification(self):
        """Sign, copy and delegate a credential, then check tampered versions are rejected."""
        auth_gid_file = self._tmp_path("auth_gid")
        auth_key_file = self._tmp_path("auth_key")
        caller_gid_file = self._tmp_path("caller_gid")
        caller_key_file = self._tmp_path("caller_pkey")

        gidAuthority, keys = self.createSignedGID("site", "urn:publicid:IDN+plc+authority+site")
        gidCaller, ckeys = self.createSignedGID("site.foo", "urn:publicid:IDN+plc:site+user+foo",
                                                keys, gidAuthority)
        gidObject, _ = self.createSignedGID("site.slice", "urn:publicid:IDN+plc:site+slice+bar_slice",
                                            keys, gidAuthority)
        gidDelegatee, _ = self.createSignedGID("site.delegatee", "urn:publicid:IDN+plc:site+user+delegatee",
                                               keys, gidAuthority)

        cred = Credential()
        cred.set_gid_caller(gidCaller)
        cred.set_gid_object(gidObject)
        cred.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=3600))
        cred.set_privileges("embed:1, bind:1")
        cred.encode()

        # The authority signs the credential; its GID file is the trust anchor.
        gidAuthority.save_to_file(auth_gid_file)
        keys.save_to_file(auth_key_file)
        cred.set_issuer_keys(auth_key_file, auth_gid_file)
        cred.sign()

        cred.verify([auth_gid_file])

        # Test copying.
        cred2 = Credential(string=cred.save_to_string())
        cred2.verify([auth_gid_file])

        # Test delegation: the original caller delegates to gidDelegatee with
        # a shorter lifetime than the parent credential.
        delegated = Credential()
        delegated.set_gid_caller(gidDelegatee)
        delegated.set_gid_object(gidObject)
        delegated.set_parent(cred)
        delegated.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=600))
        delegated.set_privileges("embed:1, bind:1")
        gidCaller.save_to_file(caller_gid_file)
        ckeys.save_to_file(caller_key_file)

        delegated.set_issuer_keys(caller_key_file, caller_gid_file)

        delegated.encode()

        delegated.sign()

        # This should verify.
        delegated.verify([auth_gid_file])

        backup = Credential(string=delegated.get_xml())

        # Test that verify catches an incorrect lifetime (the delegated
        # expiration now exceeds the parent's 3600 s).
        delegated.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=6000))
        delegated.encode()
        delegated.sign()
        self._assertNotVerifiable(delegated, [auth_gid_file])

        # Test that verify catches an incorrect signer.
        delegated = Credential(string=backup.get_xml())
        delegated.set_issuer_keys(auth_key_file, auth_gid_file)
        delegated.encode()
        delegated.sign()
        self._assertNotVerifiable(delegated, [auth_gid_file])

        # Test that verify catches a changed gid.
        delegated = Credential(string=backup.get_xml())
        delegated.set_gid_object(delegated.get_gid_caller())
        delegated.encode()
        delegated.sign()
        self._assertNotVerifiable(delegated, [auth_gid_file])

        # Test that verify catches a credential with the wrong authority for
        # the object.
        test = Credential(string=cred.get_xml())
        test.set_issuer_keys(caller_key_file, caller_gid_file)
        test.encode()
        test.sign()
        self._assertNotVerifiable(test, [auth_gid_file])

        # TODO: test that "*" privileges get translated properly (not yet written).
171
if __name__ == "__main__":
    # Script entry point: hand control to the unittest command-line runner,
    # which collects and runs the TestCase classes defined in this module.
    unittest.main(module=__name__)