Merge branch 'master' of ssh://bakers@git.planet-lab.org/git/sfa
authorsmbaker <smbaker@fc8clean.lan>
Tue, 30 Aug 2011 16:43:05 +0000 (09:43 -0700)
committersmbaker <smbaker@fc8clean.lan>
Tue, 30 Aug 2011 16:43:05 +0000 (09:43 -0700)
18 files changed:
sfa.spec
sfa/client/sfi.py
sfa/managers/aggregate_manager_pl.py
sfa/managers/registry_manager_pl.py
sfa/methods/CreateGid.py [new file with mode: 0644]
sfa/methods/CreateSliver.py
sfa/methods/Register.py
sfa/methods/Update.py
sfa/methods/__init__.py
sfa/plc/sfa-nuke-plc.py
sfa/plc/slices.py
sfa/rspecs/sfa_rspec.py
sfa/server/sfa-server.py
sfa/trust/certificate.py
sfa/trust/credential.py
sfa/trust/gid.py
sfa/trust/hierarchy.py
sfa/util/xrn.py

index 9a5988d..286bfbf 100644 (file)
--- a/sfa.spec
+++ b/sfa.spec
@@ -1,6 +1,6 @@
 %define name sfa
 %define version 1.0
-%define taglevel 29
+%define taglevel 34
 
 %define release %{taglevel}%{?pldistro:.%{pldistro}}%{?date:.%{date}}
 %global python_sitearch        %( python -c "from distutils.sysconfig import get_python_lib; print get_python_lib(1)" )
@@ -196,6 +196,33 @@ fi
 [ "$1" -ge "1" ] && service sfa-cm restart || :
 
 %changelog
+* Mon Aug 29 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-34
+- new option -c to sfa-nuke-plc.py
+- CreateSliver fixed for admin-only slice tags
+
+* Wed Aug 24 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-32
+- Fixed exploit that allowed an authorities to issue certs for objects that dont belong to them.
+- Fixed holes in certificate verification logic.
+- Aggregates no longer try to lookup slice and person records when processing CreateSliver requests. Clients are now required to specify this info in the 'users' argument. 
+- Added 'boot_state' as an attribute of the node element in SFA rspec.
+- Non authority certificates are marked as CA:FALSE.
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-32
+- fix typo in sfa-1.0-31 tag.
+- added CreateGid() Registry interface method.
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-31
+- fix typo in sfa-1.0-30 tag
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-30
+- Declare namespace and schema location in the credential.
+- Fix bug that prevetend connections from timing out.
+- Fix slice delegation.
+- Add statistics to slicemaanger listresources/createsliver rspec.
+- Added SFA_MAX_SLICE_RENEW which allows operators to configure the max ammout
+  of days a user can extend their slice expiration.
+- CA certs are only issued to objects of type authority
+   
 * Fri Aug 05 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-29
 - tag 1.0-28 was broken due to typo in the changelog
 - new class sfa/util/httpsProtocol.py that supports timeouts
index 6e0c82c..764ce18 100755 (executable)
@@ -154,6 +154,7 @@ class Sfi:
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
+                  "create_gid": "[name]",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
@@ -217,7 +218,7 @@ class Sfi:
                                 help="optional component information", default=None)
 
 
-        if command in ("resources", "show", "list"):
+        if command in ("resources", "show", "list", "create_gid"):
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
         
@@ -275,7 +276,7 @@ class Sfi:
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
-        parser.add_option("-t", "--timeout", dest="timeout", default=30,
+        parser.add_option("-t", "--timeout", dest="timeout", default=None,
                          help="Amout of time tom wait before timing out the request")
         parser.disable_interspersed_args()
 
@@ -645,7 +646,22 @@ class Sfi:
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
+
+    def create_gid(self, opts, args):
+        if len(args) < 1:
+            self.print_help()
+            sys.exit(1)
+        target_hrn = args[0]
+        user_cred = self.get_user_cred().save_to_string(save_parents=True)
+        gid = self.registry.CreateGid(user_cred, target_hrn, self.cert.save_to_string())
+        if opts.file:
+            filename = opts.file
+        else:
+            filename = os.sep.join([self.sfi_dir, '%s.gid' % target_hrn])
+        self.logger.info("writing %s gid to %s" % (target_hrn, filename))
+        GID(string=gid).save_to_file(filename)
+         
+     
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
@@ -696,7 +712,6 @@ class Sfi:
                 record.dump()  
             else:
                 print record.save_to_string() 
         if opts.file:
             file = opts.file
             if not file.startswith(os.sep):
@@ -926,6 +941,8 @@ class Sfi:
     
     # created named slice with given rspec
     def create(self, opts, args):
+        server = self.get_server_from_opts(opts)
+        server_version = self.get_cached_server_version(server)
         slice_hrn = args[0]
         slice_urn = hrn_to_urn(slice_hrn, 'slice') 
         user_cred = self.get_user_cred()
@@ -937,29 +954,40 @@ class Sfi:
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
+        # need to pass along user keys to the aggregate.  
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
         #    keys: [<ssh key A>, <ssh key B>] 
         #  }]
         users = []
-        server = self.get_server_from_opts(opts)
-        version = self.get_cached_server_version(server)
-        if 'sfa' not in version:
-            # need to pass along user keys if this request is going to a ProtoGENI aggregate 
+        all_keys = []
+        all_key_ids = []
+        slice_records = self.registry.Resolve(slice_urn, [user_cred.save_to_string(save_parents=True)])
+        if slice_records and 'researcher' in slice_records[0]:
+            slice_record = slice_records[0]
+            user_hrns = slice_record['researcher']
+            user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns]
+            user_records = self.registry.Resolve(user_urns, [user_cred.save_to_string(save_parents=True)])
+            for user_record in user_records:
+                #user = {'urn': user_cred.get_gid_caller().get_urn(),'keys': []}
+                user = {'urn': user_cred.get_gid_caller().get_urn(), #
+                        'keys': user_record['keys'],
+                        'email': user_record['email'], #  needed for MyPLC
+                        'person_id': user_record['person_id'], # needed for MyPLC
+                        'first_name': user_record['first_name'], # needed for MyPLC
+                        'last_name': user_record['last_name'], # needed for MyPLC
+                        'slice_record': slice_record, # needed for legacy refresh peer
+                        'key_ids': user_record['key_ids'] # needed for legacy refresh peer
+                } 
+                users.append(user)
+                all_keys.extend(user_record['keys'])
+                all_key_ids.extend(user_record['key_ids'])
             # ProtoGeni Aggregates will only install the keys of the user that is issuing the
-            # request. So we will only pass in one user that contains the keys for all
-            # users of the slice 
-            user = {'urn': user_cred.get_gid_caller().get_urn(),
-                    'keys': []}
-            slice_record = self.registry.Resolve(slice_urn, creds)
-            if slice_record and 'researchers' in slice_record:
-                user_hrns = slice_record['researchers']
-                user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns] 
-                user_records = self.registry.Resolve(user_urns, creds)
-                for user_record in user_records:
-                    if 'keys' in user_record:
-                        user['keys'].extend(user_record['keys'])
-            users.append(user)
+            # request. So we will add all to the current caller's list of keys
+            if 'sfa' not in server_version:
+                for user in users:
+                    if user['urn'] == user_cred.get_gid_caller().get_urn():
+                        user['keys'] = all_keys  
 
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
index 7dd0607..f723f49 100644 (file)
@@ -114,7 +114,6 @@ def SliverStatus(api, slice_xrn, creds, call_id):
 
     (hrn, type) = urn_to_hrn(slice_xrn)
     # find out where this slice is currently running
-    api.logger.info(hrn)
     slicename = hrn_to_pl_slicename(hrn)
     
     slices = api.plshell.GetSlices(api.plauth, [slicename], ['slice_id', 'node_ids','person_ids','name','expires'])
@@ -160,9 +159,6 @@ def SliverStatus(api, slice_xrn, creds, call_id):
         
     result['geni_status'] = top_level_status
     result['geni_resources'] = resources
-    # XX remove me
-    #api.logger.info(result)
-    # XX remove me
     return result
 
 def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):
@@ -172,76 +168,36 @@ def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):
     """
     if Callids().already_handled(call_id): return ""
 
-    reg_objects = __get_registry_objects(slice_xrn, creds, users)
-
-    (hrn, type) = urn_to_hrn(slice_xrn)
-    peer = None
     aggregate = Aggregate(api)
     slices = Slices(api)
+    (hrn, type) = urn_to_hrn(slice_xrn)
     peer = slices.get_peer(hrn)
     sfa_peer = slices.get_sfa_peer(hrn)
-    registry = api.registries[api.hrn]
-    credential = api.getCredential()
-    (site_id, remote_site_id) = slices.verify_site(registry, credential, hrn, 
-                                                   peer, sfa_peer, reg_objects)
-
-    slice = slices.verify_slice(registry, credential, hrn, site_id, 
-                                       remote_site_id, peer, sfa_peer, reg_objects)
-     
-    nodes = api.plshell.GetNodes(api.plauth, slice['node_ids'], ['hostname'])
-    current_slivers = [node['hostname'] for node in nodes] 
+    slice_record=None    
+    if users:
+        slice_record = users[0].get('slice_record', {})
+
+    # parse rspec
     rspec = parse_rspec(rspec_string)
+    requested_attributes = rspec.get_slice_attributes()
+    
+    # ensure site record exists
+    site = slices.verify_site(hrn, slice_record, peer, sfa_peer)
+    # ensure slice record exists
+    slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)
+    # ensure person records exists
+    persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)
+    # ensure slice attributes exists
+    slices.verify_slice_attributes(slice, requested_attributes)
+    
+    # add/remove slice from nodes
     requested_slivers = [str(host) for host in rspec.get_nodes_with_slivers()]
-    # remove nodes not in rspec
-    deleted_nodes = list(set(current_slivers).difference(requested_slivers))
-
-    # add nodes from rspec
-    added_nodes = list(set(requested_slivers).difference(current_slivers))
-
-    # get sliver attributes
-    requested_slice_attributes = rspec.get_slice_attributes()
-    removed_slice_attributes = []
-    existing_slice_attributes = []    
-    for slice_tag in api.plshell.GetSliceTags(api.plauth, {'slice_id': slice['slice_id']}):
-        attribute_found=False
-        for requested_attribute in requested_slice_attributes:
-            if requested_attribute['name'] == slice_tag['tagname'] and \
-               requested_attribute['value'] == slice_tag['value']:
-                attribute_found=True
-
-        if not attribute_found: 
-            removed_slice_attributes.append(slice_tag)
-        else:
-            existing_slice_attributes.append(slice_tag)  
-         
-    try:
-        if peer:
-            api.plshell.UnBindObjectFromPeer(api.plauth, 'slice', slice['slice_id'], peer)
-
-        api.plshell.AddSliceToNodes(api.plauth, slice['name'], added_nodes) 
-        api.plshell.DeleteSliceFromNodes(api.plauth, slice['name'], deleted_nodes)
-        # remove stale attributes
-        for attribute in removed_slice_attributes:
-            try:
-                api.plshell.DeleteSliceTag(api.plauth, attribute['slice_tag_id'])
-            except Exception, e:
-                api.logger.warn('Failed to remove sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
-                                % (name, value,  node_id, str(e)))
-
-        # add requested_attributes
-        for attribute in requested_slice_attributes:
-            try:
-                name, value, node_id = attribute['name'], attribute['value'], attribute.get('node_id', None)
-                api.plshell.AddSliceTag(api.plauth, slice['name'], name, value, node_id)
-            except Exception, e:
-                api.logger.warn('Failed to add sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
-                                % (name, value,  node_id, str(e)))
-
-    finally:
-        if peer:
-            api.plshell.BindObjectToPeer(api.plauth, 'slice', slice['slice_id'], peer, 
-                                         slice['peer_id'])
+    slices.verify_slice_nodes(slice, requested_slivers, peer) 
 
+    # hanlde MyPLC peer association.
+    # only used by plc and ple.
+    slices.handle_peer(site, slice, persons, peer)
+    
     return aggregate.get_rspec(slice_xrn=slice_xrn, version=rspec.version)
 
 
index 8bec1f6..6052eee 100644 (file)
@@ -174,6 +174,18 @@ def list(api, xrn, origin_hrn=None):
     return records
 
 
+def create_gid(api, xrn, cert):
+    # get the authority
+    authority = Xrn(xrn=xrn).get_authority_hrn()
+    auth_info = api.auth.get_auth_info(authority)
+    if not cert:
+        pkey = Keypair(create=True)
+    else:
+        certificate = Certificate(string=cert)
+        pkey = certificate.get_pubkey()    
+    gid = api.auth.hierarchy.create_gid(xrn, create_uuid(), pkey) 
+    return gid.save_to_string(save_parents=True)
+    
 def register(api, record):
 
     hrn, type = record['hrn'], record['type']
@@ -192,7 +204,6 @@ def register(api, record):
     record['authority'] = get_authority(record['hrn'])
     type = record['type']
     hrn = record['hrn']
-    api.auth.verify_object_permission(hrn)
     auth_info = api.auth.get_auth_info(record['authority'])
     pub_key = None
     # make sure record has a gid
@@ -288,7 +299,6 @@ def update(api, record_dict):
     type = new_record['type']
     hrn = new_record['hrn']
     urn = hrn_to_urn(hrn,type)
-    api.auth.verify_object_permission(hrn)
     table = SfaTable()
     # make sure the record exists
     records = table.findObjects({'type': type, 'hrn': hrn})
diff --git a/sfa/methods/CreateGid.py b/sfa/methods/CreateGid.py
new file mode 100644 (file)
index 0000000..7c1bf8a
--- /dev/null
@@ -0,0 +1,50 @@
+### $Id: register.py 16477 2010-01-05 16:31:37Z thierry $
+### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/register.py $
+
+from sfa.util.xrn import urn_to_hrn
+from sfa.util.method import Method
+from sfa.util.parameter import Parameter, Mixed
+from sfa.trust.credential import Credential
+
+class CreateGid(Method):
+    """
+    Create a signed credential for the s object with the registry. In addition to being stored in the
+    SFA database, the appropriate records will also be created in the
+    PLC databases
+    
+    @param xrn urn or hrn of certificate owner
+    @param cert caller's certificate
+    @param cred credential string
+    
+    @return gid string representation
+    """
+
+    interfaces = ['registry']
+    
+    accepts = [
+        Mixed(Parameter(str, "Credential string"),
+              Parameter(type([str]), "List of credentials")),
+        Parameter(str, "URN or HRN of certificate owner"),
+        Parameter(str, "Certificate string"),
+        ]
+
+    returns = Parameter(int, "String representation of gid object")
+    
+    def call(self, creds, xrn, cert=None):
+        # TODO: is there a better right to check for or is 'update good enough? 
+        valid_creds = self.api.auth.checkCredentials(creds, 'update')
+
+        # verify permissions
+        hrn, type = urn_to_hrn(xrn)
+        self.api.auth.verify_object_permission(hrn)
+
+        #log the call
+        origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
+
+        # log
+        origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
+        self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrn, self.name))
+
+        manager = self.api.get_interface_manager()
+
+        return manager.create_gid(self.api, xrn, cert)
index e62e6f4..7895de3 100644 (file)
@@ -37,6 +37,11 @@ class CreateSliver(Method):
         valid_creds = self.api.auth.checkCredentials(creds, 'createsliver', hrn)
         origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
 
+        # make sure users info is specified
+        if not users:
+            msg = "'users' musst be specified and cannot be null. You may need to update your client." 
+            raise SfaInvalidArgument(name='users', extra=msg)  
+
         manager = self.api.get_interface_manager()
         
         # flter rspec through sfatables
index 1233fa8..6f61870 100644 (file)
@@ -34,15 +34,15 @@ class Register(Method):
     returns = Parameter(int, "String representation of gid object")
     
     def call(self, record, creds):
-        
+        # validate cred    
         valid_creds = self.api.auth.checkCredentials(creds, 'register')
+        
+        # verify permissions
+        hrn = record.get('hrn', '')
+        self.api.auth.verify_object_permission(hrn)
 
         #log the call
         origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
-
-        hrn = None
-        if 'hrn' in record:
-            hrn = record['hrn']
         self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name))
         
         manager = self.api.get_interface_manager()
index d36ea36..aa881ea 100644 (file)
@@ -1,6 +1,3 @@
-### $Id: update.py 16477 2010-01-05 16:31:37Z thierry $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/update.py $
-
 import time
 from sfa.util.faults import *
 from sfa.util.method import Method
@@ -31,8 +28,14 @@ class Update(Method):
     def call(self, record_dict, creds):
         # validate the cred
         valid_creds = self.api.auth.checkCredentials(creds, "update")
+        
+        # verify permissions
+        hrn = record_dict.get('hrn', '')  
+        self.api.auth.verify_object_permission(hrn)
+    
+        # log
         origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
-        self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, None, self.name))
+        self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name))
        
         manager = self.api.get_interface_manager()
  
index a585d93..eef24de 100644 (file)
@@ -1,6 +1,7 @@
 ## Please use make index to update this file
 all = """
 CreateSliver
+CreateGid
 DeleteSliver
 GetCredential
 GetGids
index 7ba6337..fb84020 100755 (executable)
@@ -21,6 +21,8 @@ def main():
    parser = OptionParser(usage=usage)
    parser.add_option('-f','--file-system',dest='clean_fs',action='store_true',default=False,
                      help='Clean up the /var/lib/sfa/authorities area as well')
+   parser.add_option('-c','--certs',dest='clean_certs',action='store_true',default=False,
+                     help='Remove all cached certs/gids found in /var/lib/sfa/authorities area as well')
    (options,args)=parser.parse_args()
    if args:
       parser.print_help()
@@ -28,8 +30,23 @@ def main():
    logger.info("Purging SFA records from database")
    table = SfaTable()
    table.sfa_records_purge()
+
+   if options.clean_certs:
+      # remove the server certificate and all gids found in /var/lib/sfa/authorities
+      logger.info("Purging cached certificates")
+      for (dir, _, files) in os.walk('/var/lib/sfa/authorities'):
+         for file in files:
+            if file.endswith('.gid') or file == 'server.cert':
+               path=dir+os.sep+file
+               os.unlink(path)
+               if not os.path.exists(path):
+                  logger.info("Unlinked file %s"%path)
+               else:
+                  logger.error("Could not unlink file %s"%path)
+
    if options.clean_fs:
       # just remove all files that do not match 'server.key' or 'server.cert'
+      logger.info("Purging registry filesystem cache")
       preserved_files = [ 'server.key', 'server.cert']
       for (dir,_,files) in os.walk('/var/lib/sfa/authorities'):
          for file in files:
index f99ddc1..c35f753 100644 (file)
@@ -5,7 +5,7 @@ import sys
 
 from types import StringTypes
 from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn, urn_to_hrn
-from sfa.util.plxrn import hrn_to_pl_slicename
+from sfa.util.plxrn import hrn_to_pl_slicename, hrn_to_pl_login_base
 from sfa.util.rspec import *
 from sfa.util.specdict import *
 from sfa.util.faults import *
@@ -24,6 +24,8 @@ class Slices:
         #filepath = path + os.sep + filename
         self.policy = Policy(self.api)    
         self.origin_hrn = origin_hrn
+        self.registry = api.registries[api.hrn]
+        self.credential = api.getCredential()
 
     def get_slivers(self, xrn, node=None):
         hrn, type = urn_to_hrn(xrn)
@@ -148,7 +150,7 @@ class Slices:
         for peer_record in peers:
             names = [name.lower() for name in peer_record.values() if isinstance(name, StringTypes)]
             if site_authority in names:
-                peer = peer_record['shortname']
+                peer = peer_record
 
         return peer
 
@@ -163,206 +165,335 @@ class Slices:
         if site_authority != self.api.hrn:
             sfa_peer = site_authority
 
-        return sfa_peer 
+        return sfa_peer
 
-    def verify_site(self, registry, credential, slice_hrn, peer, sfa_peer, reg_objects=None):
-        authority = get_authority(slice_hrn)
-        authority_urn = hrn_to_urn(authority, 'authority')
-        login_base = None
-        if reg_objects:
-            site = reg_objects['site']
-            login_base = site['login_base']
-        else:
-            site_records = registry.Resolve(authority_urn, [credential])
-            site = {}            
-            for site_record in site_records:            
-                if site_record['type'] == 'authority':
-                    site = site_record
-            if not site:
-                raise RecordNotFound(authority)
-            
-        remote_site_id = site.pop('site_id')    
+    def verify_slice_nodes(self, slice, requested_slivers, peer):
         
-        if login_base is None:
-            login_base = get_leaf(authority)
-        sites = self.api.plshell.GetSites(self.api.plauth, login_base)
+        nodes = self.api.plshell.GetNodes(self.api.plauth, slice['node_ids'], ['hostname'])
+        current_slivers = [node['hostname'] for node in nodes]
 
-        if not sites:
-            site_id = self.api.plshell.AddSite(self.api.plauth, site)
-            if peer:
-                try:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)   
-                except Exception,e:
-                    self.api.plshell.DeleteSite(self.api.plauth, site_id)
-                    raise e
-            # mark this site as an sfa peer record
-            if sfa_peer and not reg_objects:
-                peer_dict = {'type': 'authority', 'hrn': authority, 'peer_authority': sfa_peer, 'pointer': site_id}
-                registry.register_peer_object(credential, peer_dict)
+        # remove nodes not in rspec
+        deleted_nodes = list(set(current_slivers).difference(requested_slivers))
 
-            # exempt federated sites from monitor policies
-            self.api.plshell.AddSiteTag(site_id, 'exempt_site_until', "20200101")
-             
-        else:
-            site_id = sites[0]['site_id']
-            remote_site_id = sites[0]['peer_site_id']
-            old_site = sites[0]
-            #the site is already on the remote agg. Let us update(e.g. max_slices field) it with the latest info.
-            self.sync_site(old_site, site, peer)
+        # add nodes from rspec
+        added_nodes = list(set(requested_slivers).difference(current_slivers))        
 
+        try:
+            if peer:
+                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', slice['slice_id'], peer['shortname'])
+            self.api.plshell.AddSliceToNodes(self.api.plauth, slice['name'], added_nodes)
+            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slice['name'], deleted_nodes)
 
-        return (site_id, remote_site_id) 
+        except: 
+            self.api.logger.log_exc('Failed to add/remove slice from nodes')
 
-    def verify_slice(self, registry, credential, slice_hrn, site_id, remote_site_id, peer, sfa_peer, reg_objects=None):
-        slice = {}
-        slice_record = None
-        authority = get_authority(slice_hrn)
+    def handle_peer(self, site, slice, persons, peer):
+        if peer:
+            # bind site
+            try:
+                if site:
+                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', \
+                       site['site_id'], peer['shortname'], slice['site_id'])
+            except Exception,e:
+                self.api.plshell.DeleteSite(self.api.plauth, site['site_id'])
+                raise e
+            
+            # bind slice
+            try:
+                if slice:
+                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', \
+                       slice['slice_id'], peer['shortname'], slice['slice_id'])
+            except Exception,e:
+                self.api.plshell.DeleteSlice(self.api.plauth, slice['slice_id'])
+                raise e 
+
+            # bind persons
+            for person in persons:
+                try:
+                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', \
+                        person['person_id'], peer['shortname'], person['peer_person_id'])
+
+                    for (key, remote_key_id) in zip(person['keys'], person['key_ids']):
+                        try:
+                            self.api.plshell.BindObjectToPeer(self.api.plauth, 'key',\
+                                key['key_id'], peer['shortname'], remote_key_id)
+                        except:
+                            self.api.plshell.DeleteKey(self.api.plauth, key['key_id'])
+                            self.api.logger("failed to bind key: %s to peer: %s " % (key['key_id'], peer['shortname']))
+                except Exception,e:
+                    self.api.plshell.DeletePerson(self.api.plauth, person['person_id'])
+                    raise e       
 
-        if reg_objects:
-            slice_record = reg_objects['slice_record']
-        else:
-            slice_records = registry.Resolve(slice_hrn, [credential])
-    
-            for record in slice_records:
-                if record['type'] in ['slice']:
-                    slice_record = record
-            if not slice_record:
-                raise RecordNotFound(hrn)
+        return slice
+
+    def verify_site(self, slice_xrn, slice_record={}, peer=None, sfa_peer=None):
+        (slice_hrn, type) = urn_to_hrn(slice_xrn)
+        site_hrn = get_authority(slice_hrn)
+        # login base can't be longer than 20 characters
+        authority_name = get_leaf(site_hrn) 
+        login_base = authority_name[:20]
+        sites = self.api.plshell.GetSites(self.api.plauth, login_base)
+        if not sites:
+            # create new site record
+            site = {'name': 'geni.%s' % authority_name,
+                    'abbreviated_name': authority_name,
+                    'login_base': login_base,
+                    'max_slices': 100,
+                    'max_slivers': 1000,
+                    'enabled': True,
+                    'peer_site_id': None}
+            if peer:
+                site['peer_site_id'] = slice_record.get('site_id', None)
+            site['site_id'] = self.api.plshell.AddSite(self.api.plauth, site)
+            # exempt federated sites from monitor policies
+            self.api.plshell.AddSiteTag(self.api.plauth, site['site_id'], 'exempt_site_until', "20200101")
             
+            # is this still necessary?
+            # add record to the local registry 
+            if sfa_peer and slice_record:
+                peer_dict = {'type': 'authority', 'hrn': site_hrn, \
+                             'peer_authority': sfa_peer, 'pointer': site['site_id']}
+                self.registry.register_peer_object(self.credential, peer_dict)
+        else:
+            site =  sites[0]
+            if peer:
+                # unbind from peer so we can modify if necessary. Will bind back later
+                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site['site_id'], peer['shortname']) 
         
+        return site        
+
+    def verify_slice(self, slice_hrn, slice_record, peer, sfa_peer):
         slicename = hrn_to_pl_slicename(slice_hrn)
         parts = slicename.split("_")
         login_base = parts[0]
         slices = self.api.plshell.GetSlices(self.api.plauth, [slicename]) 
         if not slices:
-            slice_fields = {}
-            slice_keys = ['name', 'url', 'description']
-            for key in slice_keys:
-                if key in slice_record and slice_record[key]:
-                    slice_fields[key] = slice_record[key]
+            slice = {'name': slicename,
+                     'url': slice_record.get('url', slice_hrn), 
+                     'description': slice_record.get('description', slice_hrn)}
             # add the slice                          
-            slice_id = self.api.plshell.AddSlice(self.api.plauth, slice_fields)
-            slice = slice_fields
-            slice['slice_id'] = slice_id
-
+            slice['slice_id'] = self.api.plshell.AddSlice(self.api.plauth, slice)
+            slice['node_ids'] = []
+            slice['person_ids'] = []
+            if peer:
+                slice['peer_slice_id'] = slice_record.get('slice_id', None) 
             # mark this slice as an sfa peer record
             if sfa_peer:
-                peer_dict = {'type': 'slice', 'hrn': slice_hrn, 'peer_authority': sfa_peer, 'pointer': slice_id}
-                registry.register_peer_object(credential, peer_dict)
-
-            #this belongs to a peer
-            if peer:
-                try:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice_id, peer, slice_record['pointer'])
-                except Exception,e:
-                    self.api.plshell.DeleteSlice(self.api.plauth,slice_id)
-                    raise e
-            slice['node_ids'] = []
+                peer_dict = {'type': 'slice', 'hrn': slice_hrn, 
+                             'peer_authority': sfa_peer, 'pointer': slice['slice_id']}
+                self.registry.register_peer_object(self.credential, peer_dict)
         else:
             slice = slices[0]
-            slice_id = slice['slice_id']
-            site_id = slice['site_id']
-           #the slice is alredy on the remote agg. Let us update(e.g. expires field) it with the latest info.
-           self.sync_slice(slice, slice_record, peer)
-
-        slice['peer_slice_id'] = slice_record['pointer']
-        self.verify_persons(registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer, reg_objects)
-    
+            if peer:
+                slice['peer_slice_id'] = slice_record.get('slice_id', None)
+                # unbind from peer so we can modify if necessary. Will bind back later
+                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice',\
+                             slice['slice_id'], peer['shortname'])
+               #Update existing record (e.g. expires field) it with the latest info.
+            if slice_record and slice['expires'] != slice_record['expires']:
+                self.api.plshell.UpdateSlice(self.api.plauth, slice['slice_id'],\
+                             {'expires' : slice_record['expires']})
+       
         return slice        
 
-    def verify_persons(self, registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer, reg_objects=None):
-        # get the list of valid slice users from the registry and make 
-        # sure they are added to the slice 
-        slicename = hrn_to_pl_slicename(slice_record['hrn'])
-        if reg_objects:
-            researchers = reg_objects['users'].keys()
-        else:
-            researchers = slice_record.get('researcher', [])
-        for researcher in researchers:
-            if reg_objects:
-                person_dict = reg_objects['users'][researcher]
-            else:
-                person_records = registry.Resolve(researcher, [credential])
-                for record in person_records:
-                    if record['type'] in ['user'] and record['enabled']:
-                        person_record = record
-                if not person_record:
-                    return 1
-                person_dict = person_record
-
-            local_person=False
-            if peer:
-                peer_id = self.api.plshell.GetPeers(self.api.plauth, {'shortname': peer}, ['peer_id'])[0]['peer_id']
-                persons = self.api.plshell.GetPersons(self.api.plauth, {'email': [person_dict['email']], 'peer_id': peer_id}, ['person_id', 'key_ids'])
-                if not persons:
-                    persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])
-                    if persons:
-                        local_person=True
-                        
+    def verify_persons(self, slice_hrn, slice_record, users, peer, sfa_peer):
+        slicename = hrn_to_pl_slicename(slice_hrn)
+        login_base = hrn_to_pl_login_base(slice_hrn)
+        # create a dict users keyed on the user's email
+        users_dict = {}
+        for user in users:
+            if 'email' in user:     
+                users_dict[user['email'].lower()] = user
             else:
-                persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])   
+                fake_email = hrn_to_pl_slicename(slice_hrn) + "@geni.net"
+                user['email'] = fake_email.lower()
+                users_dict[fake_email] = user
         
-            if not persons:
-                person_id=self.api.plshell.AddPerson(self.api.plauth, person_dict)
-                self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled' : True})
-                
-                # mark this person as an sfa peer record
-                if sfa_peer:
-                    peer_dict = {'type': 'user', 'hrn': researcher, 'peer_authority': sfa_peer, 'pointer': person_id}
-                    registry.register_peer_object(credential, peer_dict)
-
-                if peer:
-                    try:
-                        self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
-                    except Exception,e:
-                        self.api.plshell.DeletePerson(self.api.plauth,person_id)
-                        raise e
-                key_ids = []
-            else:
-                person_id = persons[0]['person_id']
-                key_ids = persons[0]['key_ids']
+        # requested slice users        
+        requested_user_ids = users_dict.keys()
+        
+        # existing users
+        existing_users_filter = {'email': requested_user_ids}
+        existing_users = self.api.plshell.GetPersons(self.api.plauth, \
+            existing_users_filter, ['person_id', 'key_ids', 'email'])
+        existing_user_ids = [user['email'] for user in existing_users]
+            
+        # existing slice users
+        existing_slice_users_filter = {'person_id': slice_record.get('person_ids', [])}
+        existing_slice_users = self.api.plshell.GetPersons(self.api.plauth, \
+             existing_slice_users_filter, ['person_id', 'key_ids', 'email'])
+        existing_slice_user_ids = [user['email'] for user in existing_slice_users]
+        
+        # users to be added, removed or updated
+        added_user_ids = set(requested_user_ids).difference(existing_user_ids)
+        added_slice_user_ids = set(requested_user_ids).difference(existing_slice_user_ids)
+        removed_user_ids = set(existing_slice_user_ids).difference(requested_user_ids)
+        updated_user_ids = set(existing_slice_user_ids).intersection(requested_user_ids)
+
+        # remove stale users
+        for removed_user_id in removed_user_ids:
+            self.api.plshell.DeletePersonFromSlice(self.api.plauth, removed_user_id, slicename)
+
+        # update_existing users
+        updated_users_list = [user for user in users if user['email'] in updated_user_ids]
+        self.verify_keys(existing_slice_users, updated_users_list, peer)
+
+        added_persons = []
+        # add new users
+        for added_user_id in added_user_ids:
+            added_user = users_dict[added_user_id]
+            hrn, type = urn_to_hrn(added_user['urn'])  
+            person = {
+                'first_name': added_user.get('first_name', hrn),
+                'last_name': added_user.get('last_name', hrn),
+                'email': added_user_id,
+                'peer_person_id': None,
+                'keys': [],
+                'key_ids': added_user.get('key_ids', []),
+            }
+            person['person_id'] = self.api.plshell.AddPerson(self.api.plauth, person)
+            if peer:
+                person['peer_person_id'] = added_user['person_id']
+            added_persons.append(person)
+           
+            # enable the account 
+            self.api.plshell.UpdatePerson(self.api.plauth, person['person_id'], {'enabled': True})
+            
+            # add person to site
+            self.api.plshell.AddPersonToSite(self.api.plauth, added_user_id, login_base)
 
+            for key_string in added_user.get('keys', []):
+                key = {'key':key_string, 'key_type':'ssh'}
+                key['key_id'] = self.api.plshell.AddPersonKey(self.api.plauth, person['person_id'], key)
+                person['keys'].append(key)
+
+            # add the registry record
+            if sfa_peer:
+                peer_dict = {'type': 'user', 'hrn': hrn, 'peer_authority': sfa_peer, \
+                    'pointer': person['person_id']}
+                self.registry.register_peer_object(self.credential, peer_dict)
+    
+        for added_slice_user_id in added_slice_user_ids.union(added_user_ids):
+            # add person to the slice 
+            self.api.plshell.AddPersonToSlice(self.api.plauth, added_slice_user_id, slicename)
 
-            # if this is a peer person, we must unbind them from the peer or PLCAPI will throw
-            # an error
-            try:
-                if peer and not local_person:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
-                if peer:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site_id,  peer)
-
-                self.api.plshell.AddPersonToSlice(self.api.plauth, person_dict['email'], slicename)
-                self.api.plshell.AddPersonToSite(self.api.plauth, person_dict['email'], site_id)
-            finally:
-                if peer:
-                    try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)
-                    except: pass
-                if peer and not local_person:
-                    try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
-                    except: pass
             
-            self.verify_keys(registry, credential, person_dict, key_ids, person_id, peer, local_person)
+            # if this is a peer record then it should already be bound to a peer.
+            # no need to return worry about it getting bound later 
 
-    def verify_keys(self, registry, credential, person_dict, key_ids, person_id,  peer, local_person):
-        keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
-        keys = [key['key'] for key in keylist]
+        return added_persons
+            
+
+    def verify_keys(self, persons, users, peer):
+        # existing keys 
+        key_ids = []
+        for person in persons:
+            key_ids.extend(person['key_ids'])
+        keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key_id', 'key'])
+        keydict = {}
+        for key in keylist:
+            keydict[key['key']] = key['key_id']     
+        existing_keys = keydict.keys()
+        persondict = {}
+        for person in persons:
+            persondict[person['email']] = person    
+    
+        # add new keys
+        requested_keys = []
+        updated_persons = []
+        for user in users:
+            user_keys = user.get('keys', [])
+            updated_persons.append(user)
+            for key_string in user_keys:
+                requested_keys.append(key_string)
+                if key_string not in existing_keys:
+                    key = {'key': key_string, 'key_type': 'ssh'}
+                    try:
+                        if peer:
+                            person = persondict[user['email']]
+                            self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person['person_id'], peer['shortname'])
+                        key['key_id'] = self.api.plshell.AddPersonKey(self.api.plauth, user['email'], key)
+                        if peer:
+                            key_index = user_keys.index(key['key'])
+                            remote_key_id = user['key_ids'][key_index]
+                            self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key['key_id'], peer['shortname'], remote_key_id)
+                            
+                    finally:
+                        if peer:
+                            self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person['person_id'], peer['shortname'], user['person_id'])
         
-        #add keys that arent already there
-        key_ids = person_dict['key_ids']
-        for personkey in person_dict['keys']:
-            if personkey not in keys:
-                key = {'key_type': 'ssh', 'key': personkey}
+        # remove old keys
+        removed_keys = set(existing_keys).difference(requested_keys)
+        for existing_key_id in keydict:
+            if keydict[existing_key_id] in removed_keys:
                 try:
                     if peer:
-                        self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
-                    key_id = self.api.plshell.AddPersonKey(self.api.plauth, person_dict['email'], key)
-                finally:
-                    if peer and not local_person:
-                        self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
-                    if peer:
-                        # xxx - thierry how are we getting the peer_key_id in here ?
-                        try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key_id, peer, key_ids.pop(0))
-                        except: pass   
+                        self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'key', existing_key_id, peer['shortname'])
+                    self.api.plshell.DeleteKey(self.api.plauth, existing_key_id)
+                except:
+                    pass   
+
+    def verify_slice_attributes(self, slice, requested_slice_attributes):
+        # get list of attributes users ar able to manage
+        slice_attributes = self.api.plshell.GetTagTypes(self.api.plauth, {'category': '*slice*', '|roles': ['user']})
+        valid_slice_attribute_names = [attribute['tagname'] for attribute in slice_attributes]
+
+        # get sliver attributes
+        added_slice_attributes = []
+        removed_slice_attributes = []
+        ignored_slice_attribute_names = []
+        existing_slice_attributes = self.api.plshell.GetSliceTags(self.api.plauth, {'slice_id': slice['slice_id']})
+
+        # get attributes that should be removed
+        for slice_tag in existing_slice_attributes:
+            if slice_tag['tagname'] in ignored_slice_attribute_names:
+                # If a slice already has a admin only role it was probably given to them by an
+                # admin, so we should ignore it.
+                ignored_slice_attribute_names.append(slice_tag['tagname'])
+            else:
+                # If an existing slice attribute was not found in the request it should
+                # be removed
+                attribute_found=False
+                for requested_attribute in requested_slice_attributes:
+                    if requested_attribute['name'] == slice_tag['tagname'] and \
+                       requested_attribute['value'] == slice_tag['value']:
+                        attribute_found=True
+                        break
+
+            if not attribute_found:
+                removed_slice_attributes.append(slice_tag)
+        
+        # get attributes that should be added:
+        for requested_attribute in requested_slice_attributes:
+            # if the requested attribute wasn't found  we should add it
+            if requested_attribute['name'] in valid_slice_attribute_names:
+                attribute_found = False
+                for existing_attribute in existing_slice_attributes:
+                    if requested_attribute['name'] == existing_attribute['tagname'] and \
+                       requested_attribute['value'] == existing_attribute['value']:
+                        attribute_found=True
+                        break
+                if not attribute_found:
+                    added_slice_attributes.append(requested_attribute)
+
+
+        # remove stale attributes
+        for attribute in removed_slice_attributes:
+            try:
+                self.api.plshell.DeleteSliceTag(self.api.plauth, attribute['slice_tag_id'])
+            except Exception, e:
+                self.api.logger.warn('Failed to remove sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
+                                % (name, value,  node_id, str(e)))
+
+        # add requested_attributes
+        for attribute in added_slice_attributes:
+            try:
+                name, value, node_id = attribute['name'], attribute['value'], attribute.get('node_id', None)
+                self.api.plshell.AddSliceTag(self.api.plauth, slice['name'], name, value, node_id)
+            except Exception, e:
+                self.api.logger.warn('Failed to add sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
+                                % (name, value,  node_id, str(e)))
 
     def create_slice_aggregate(self, xrn, rspec):
         hrn, type = urn_to_hrn(xrn)
@@ -443,27 +574,3 @@ class Slices:
 
         return 1
 
-    def sync_site(self, old_record, new_record, peer):
-        if old_record['max_slices'] != new_record['max_slices'] or old_record['max_slivers'] != new_record['max_slivers']:
-            try:
-                if peer:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', old_record['site_id'], peer)
-                if old_record['max_slices'] != new_record['max_slices']:
-                    self.api.plshell.UpdateSite(self.api.plauth, old_record['site_id'], {'max_slices' : new_record['max_slices']})
-                if old_record['max_slivers'] != new_record['max_slivers']:
-                    self.api.plshell.UpdateSite(self.api.plauth, old_record['site_id'], {'max_slivers' : new_record['max_slivers']})
-            finally:
-                if peer:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', old_record['site_id'], peer, old_record['peer_site_id'])
-       return 1
-
-    def sync_slice(self, old_record, new_record, peer):
-        if old_record['expires'] != new_record['expires']:
-            try:
-                if peer:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', old_record['slice_id'], peer)
-                self.api.plshell.UpdateSlice(self.api.plauth, old_record['slice_id'], {'expires' : new_record['expires']})
-            finally:
-                if peer:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', old_record['slice_id'], peer, old_record['peer_slice_id'])
-       return 1
index 5c62404..3d729a6 100755 (executable)
@@ -223,6 +223,8 @@ class SfaRSpec(RSpec):
                 node_tag.set('site_id', node['site_urn'])
             if 'node_id' in node: 
                 node_tag.set('node_id', 'n'+str(node['node_id']))
+            if 'boot_state' in node:
+                node_tag.set('boot_state', node['boot_state']) 
             if 'hostname' in node:
                 hostname_tag = etree.SubElement(node_tag, 'hostname').text = node['hostname']
             if 'interfaces' in node:
index 03c5d35..de891cd 100755 (executable)
@@ -232,6 +232,7 @@ def update_cert_records(gids):
     """
     # import SfaTable here so this module can be loaded by ComponentAPI
     from sfa.util.table import SfaTable
+    from sfa.util.record import SfaRecord
     if not gids:
         return
     table = SfaTable()
index 959b763..8cbe172 100644 (file)
-#----------------------------------------------------------------------
-# Copyright (c) 2008 Board of Trustees, Princeton University
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and/or hardware specification (the "Work") to
-# deal in the Work without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Work, and to permit persons to whom the Work
-# is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Work.
-#
-# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
-# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
-# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
-# IN THE WORK.
-#----------------------------------------------------------------------
-
-##
-# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
-# the necessary crypto functionality. Ideally just one of these libraries
-# would be used, but unfortunately each of these libraries is independently
-# lacking. The pyOpenSSL library is missing many necessary functions, and
-# the M2Crypto library has crashed inside of some of the functions. The
-# design decision is to use pyOpenSSL whenever possible as it seems more
-# stable, and only use M2Crypto for those functions that are not possible
-# in pyOpenSSL.
-#
-# This module exports two classes: Keypair and Certificate.
-##
-#
-
-import functools
-import os
-import tempfile
-import base64
-import traceback
-from tempfile import mkstemp
-
-from OpenSSL import crypto
-import M2Crypto
-from M2Crypto import X509
-
-from sfa.util.sfalogging import logger
-from sfa.util.xrn import urn_to_hrn
-from sfa.util.faults import *
-from sfa.util.sfalogging import logger
-
-glo_passphrase_callback = None
-
-##
-# A global callback msy be implemented for requesting passphrases from the
-# user. The function will be called with three arguments:
-#
-#    keypair_obj: the keypair object that is calling the passphrase
-#    string: the string containing the private key that's being loaded
-#    x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto
-#
-# The callback should return a string containing the passphrase.
-
-def set_passphrase_callback(callback_func):
-    global glo_passphrase_callback
-
-    glo_passphrase_callback = callback_func
-
-##
-# Sets a fixed passphrase.
-
-def set_passphrase(passphrase):
-    set_passphrase_callback( lambda k,s,x: passphrase )
-
-##
-# Check to see if a passphrase works for a particular private key string.
-# Intended to be used by passphrase callbacks for input validation.
-
-def test_passphrase(string, passphrase):
-    try:
-        crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))
-        return True
-    except:
-        return False
-
-def convert_public_key(key):
-    keyconvert_path = "/usr/bin/keyconvert.py"
-    if not os.path.isfile(keyconvert_path):
-        raise IOError, "Could not find keyconvert in %s" % keyconvert_path
-
-    # we can only convert rsa keys
-    if "ssh-dss" in key:
-        return None
-
-    (ssh_f, ssh_fn) = tempfile.mkstemp()
-    ssl_fn = tempfile.mktemp()
-    os.write(ssh_f, key)
-    os.close(ssh_f)
-
-    cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
-    os.system(cmd)
-
-    # this check leaves the temporary file containing the public key so
-    # that it can be expected to see why it failed.
-    # TODO: for production, cleanup the temporary files
-    if not os.path.exists(ssl_fn):
-        return None
-
-    k = Keypair()
-    try:
-        k.load_pubkey_from_file(ssl_fn)
-    except:
-        logger.log_exc("convert_public_key caught exception")
-        k = None
-
-    # remove the temporary files
-    os.remove(ssh_fn)
-    os.remove(ssl_fn)
-
-    return k
-
-##
-# Public-private key pairs are implemented by the Keypair class.
-# A Keypair object may represent both a public and private key pair, or it
-# may represent only a public key (this usage is consistent with OpenSSL).
-
-class Keypair:
-    key = None       # public/private keypair
-    m2key = None     # public key (m2crypto format)
-
-    ##
-    # Creates a Keypair object
-    # @param create If create==True, creates a new public/private key and
-    #     stores it in the object
-    # @param string If string!=None, load the keypair from the string (PEM)
-    # @param filename If filename!=None, load the keypair from the file
-
-    def __init__(self, create=False, string=None, filename=None):
-        if create:
-            self.create()
-        if string:
-            self.load_from_string(string)
-        if filename:
-            self.load_from_file(filename)
-
-    ##
-    # Create a RSA public/private key pair and store it inside the keypair object
-
-    def create(self):
-        self.key = crypto.PKey()
-        self.key.generate_key(crypto.TYPE_RSA, 1024)
-
-    ##
-    # Save the private key to a file
-    # @param filename name of file to store the keypair in
-
-    def save_to_file(self, filename):
-        open(filename, 'w').write(self.as_pem())
-        self.filename=filename
-
-    ##
-    # Load the private key from a file. Implicity the private key includes the public key.
-
-    def load_from_file(self, filename):
-        self.filename=filename
-        buffer = open(filename, 'r').read()
-        self.load_from_string(buffer)
-
-    ##
-    # Load the private key from a string. Implicitly the private key includes the public key.
-
-    def load_from_string(self, string):
-        if glo_passphrase_callback:
-            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )
-            self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )
-        else:
-            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
-            self.m2key = M2Crypto.EVP.load_key_string(string)
-
-    ##
-    #  Load the public key from a string. No private key is loaded.
-
-    def load_pubkey_from_file(self, filename):
-        # load the m2 public key
-        m2rsakey = M2Crypto.RSA.load_pub_key(filename)
-        self.m2key = M2Crypto.EVP.PKey()
-        self.m2key.assign_rsa(m2rsakey)
-
-        # create an m2 x509 cert
-        m2name = M2Crypto.X509.X509_Name()
-        m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
-        m2x509 = M2Crypto.X509.X509()
-        m2x509.set_pubkey(self.m2key)
-        m2x509.set_serial_number(0)
-        m2x509.set_issuer_name(m2name)
-        m2x509.set_subject_name(m2name)
-        ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
-        ASN1.set_time(500)
-        m2x509.set_not_before(ASN1)
-        m2x509.set_not_after(ASN1)
-        # x509v3 so it can have extensions
-        # prob not necc since this cert itself is junk but still...
-        m2x509.set_version(2)
-        junk_key = Keypair(create=True)
-        m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
-
-        # convert the m2 x509 cert to a pyopenssl x509
-        m2pem = m2x509.as_pem()
-        pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
-
-        # get the pyopenssl pkey from the pyopenssl x509
-        self.key = pyx509.get_pubkey()
-        self.filename=filename
-
-    ##
-    # Load the public key from a string. No private key is loaded.
-
-    def load_pubkey_from_string(self, string):
-        (f, fn) = tempfile.mkstemp()
-        os.write(f, string)
-        os.close(f)
-        self.load_pubkey_from_file(fn)
-        os.remove(fn)
-
-    ##
-    # Return the private key in PEM format.
-
-    def as_pem(self):
-        return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
-
-    ##
-    # Return an M2Crypto key object
-
-    def get_m2_pkey(self):
-        if not self.m2key:
-            self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
-        return self.m2key
-
-    ##
-    # Returns a string containing the public key represented by this object.
-
-    def get_pubkey_string(self):
-        m2pkey = self.get_m2_pkey()
-        return base64.b64encode(m2pkey.as_der())
-
-    ##
-    # Return an OpenSSL pkey object
-
-    def get_openssl_pkey(self):
-        return self.key
-
-    ##
-    # Given another Keypair object, return TRUE if the two keys are the same.
-
-    def is_same(self, pkey):
-        return self.as_pem() == pkey.as_pem()
-
-    def sign_string(self, data):
-        k = self.get_m2_pkey()
-        k.sign_init()
-        k.sign_update(data)
-        return base64.b64encode(k.sign_final())
-
-    def verify_string(self, data, sig):
-        k = self.get_m2_pkey()
-        k.verify_init()
-        k.verify_update(data)
-        return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
-
-    def compute_hash(self, value):
-        return self.sign_string(str(value))
-
-    # only informative
-    def get_filename(self):
-        return getattr(self,'filename',None)
-
-    def dump (self, *args, **kwargs):
-        print self.dump_string(*args, **kwargs)
-
-    def dump_string (self):
-        result=""
-        result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()
-        filename=self.get_filename()
-        if filename: result += "Filename %s\n"%filename
-        return result
-    
-##
-# The certificate class implements a general purpose X509 certificate, making
-# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
-# several addition features, such as the ability to maintain a chain of
-# parent certificates, and storage of application-specific data.
-#
-# Certificates include the ability to maintain a chain of parents. Each
-# certificate includes a pointer to it's parent certificate. When loaded
-# from a file or a string, the parent chain will be automatically loaded.
-# When saving a certificate to a file or a string, the caller can choose
-# whether to save the parent certificates as well.
-
-class Certificate:
-    digest = "md5"
-
-    cert = None
-    issuerKey = None
-    issuerSubject = None
-    parent = None
-
-    separator="-----parent-----"
-
-    ##
-    # Create a certificate object.
-    #
-    # @param create If create==True, then also create a blank X509 certificate.
-    # @param subject If subject!=None, then create a blank certificate and set
-    #     it's subject name.
-    # @param string If string!=None, load the certficate from the string.
-    # @param filename If filename!=None, load the certficiate from the file.
-
-    def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
-        self.data = {}
-        if create or subject:
-            self.create()
-        if subject:
-            self.set_subject(subject)
-        if string:
-            self.load_from_string(string)
-        if filename:
-            self.load_from_file(filename)
-
-        if intermediate:
-            self.set_intermediate_ca(intermediate)
-
-    # Create a blank X509 certificate and store it in this object.
-
-    def create(self):
-        self.cert = crypto.X509()
-        self.cert.set_serial_number(3)
-        self.cert.gmtime_adj_notBefore(0)
-        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
-        self.cert.set_version(2) # x509v3 so it can have extensions        
-
-
-    ##
-    # Given a pyOpenSSL X509 object, store that object inside of this
-    # certificate object.
-
-    def load_from_pyopenssl_x509(self, x509):
-        self.cert = x509
-
-    ##
-    # Load the certificate from a string
-
-    def load_from_string(self, string):
-        # if it is a chain of multiple certs, then split off the first one and
-        # load it (support for the ---parent--- tag as well as normal chained certs)
-
-        string = string.strip()
-        
-        # If it's not in proper PEM format, wrap it
-        if string.count('-----BEGIN CERTIFICATE') == 0:
-            string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
-
-        # If there is a PEM cert in there, but there is some other text first
-        # such as the text of the certificate, skip the text
-        beg = string.find('-----BEGIN CERTIFICATE')
-        if beg > 0:
-            # skipping over non cert beginning                                                                                                              
-            string = string[beg:]
-
-        parts = []
-
-        if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
-               string.count(Certificate.separator) == 0:
-            parts = string.split('-----END CERTIFICATE-----',1)
-            parts[0] += '-----END CERTIFICATE-----'
-        else:
-            parts = string.split(Certificate.separator, 1)
-
-        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
-
-        # if there are more certs, then create a parent and let the parent load
-        # itself from the remainder of the string
-        if len(parts) > 1 and parts[1] != '':
-            self.parent = self.__class__()
-            self.parent.load_from_string(parts[1])
-
-    ##
-    # Load the certificate from a file
-
-    def load_from_file(self, filename):
-        file = open(filename)
-        string = file.read()
-        self.load_from_string(string)
-        self.filename=filename
-
-    ##
-    # Save the certificate to a string.
-    #
-    # @param save_parents If save_parents==True, then also save the parent certificates.
-
-    def save_to_string(self, save_parents=True):
-        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
-        if save_parents and self.parent:
-            string = string + self.parent.save_to_string(save_parents)
-        return string
-
-    ##
-    # Save the certificate to a file.
-    # @param save_parents If save_parents==True, then also save the parent certificates.
-
-    def save_to_file(self, filename, save_parents=True, filep=None):
-        string = self.save_to_string(save_parents=save_parents)
-        if filep:
-            f = filep
-        else:
-            f = open(filename, 'w')
-        f.write(string)
-        f.close()
-        self.filename=filename
-
-    ##
-    # Save the certificate to a random file in /tmp/
-    # @param save_parents If save_parents==True, then also save the parent certificates.
-    def save_to_random_tmp_file(self, save_parents=True):
-        fp, filename = mkstemp(suffix='cert', text=True)
-        fp = os.fdopen(fp, "w")
-        self.save_to_file(filename, save_parents=True, filep=fp)
-        return filename
-
-    ##
-    # Sets the issuer private key and name
-    # @param key Keypair object containing the private key of the issuer
-    # @param subject String containing the name of the issuer
-    # @param cert (optional) Certificate object containing the name of the issuer
-
-    def set_issuer(self, key, subject=None, cert=None):
-        self.issuerKey = key
-        if subject:
-            # it's a mistake to use subject and cert params at the same time
-            assert(not cert)
-            if isinstance(subject, dict) or isinstance(subject, str):
-                req = crypto.X509Req()
-                reqSubject = req.get_subject()
-                if (isinstance(subject, dict)):
-                    for key in reqSubject.keys():
-                        setattr(reqSubject, key, subject[key])
-                else:
-                    setattr(reqSubject, "CN", subject)
-                subject = reqSubject
-                # subject is not valid once req is out of scope, so save req
-                self.issuerReq = req
-        if cert:
-            # if a cert was supplied, then get the subject from the cert
-            subject = cert.cert.get_subject()
-        assert(subject)
-        self.issuerSubject = subject
-
-    ##
-    # Get the issuer name
-
-    def get_issuer(self, which="CN"):
-        x = self.cert.get_issuer()
-        return getattr(x, which)
-
-    ##
-    # Set the subject name of the certificate
-
-    def set_subject(self, name):
-        req = crypto.X509Req()
-        subj = req.get_subject()
-        if (isinstance(name, dict)):
-            for key in name.keys():
-                setattr(subj, key, name[key])
-        else:
-            setattr(subj, "CN", name)
-        self.cert.set_subject(subj)
-    ##
-    # Get the subject name of the certificate
-
-    def get_subject(self, which="CN"):
-        x = self.cert.get_subject()
-        return getattr(x, which)
-
-    ##
-    # Get the public key of the certificate.
-    #
-    # @param key Keypair object containing the public key
-
-    def set_pubkey(self, key):
-        assert(isinstance(key, Keypair))
-        self.cert.set_pubkey(key.get_openssl_pkey())
-
-    ##
-    # Get the public key of the certificate.
-    # It is returned in the form of a Keypair object.
-
-    def get_pubkey(self):
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        pkey = Keypair()
-        pkey.key = self.cert.get_pubkey()
-        pkey.m2key = m2x509.get_pubkey()
-        return pkey
-
-    def set_intermediate_ca(self, val):
-        self.intermediate = val
-        if val:
-            self.add_extension('basicConstraints', 1, 'CA:TRUE')
-
-
-
-    ##
-    # Add an X509 extension to the certificate. Add_extension can only be called
-    # once for a particular extension name, due to limitations in the underlying
-    # library.
-    #
-    # @param name string containing name of extension
-    # @param value string containing value of the extension
-
-    def add_extension(self, name, critical, value):
-        ext = crypto.X509Extension (name, critical, value)
-        self.cert.add_extensions([ext])
-
-    ##
-    # Get an X509 extension from the certificate
-
-    def get_extension(self, name):
-
-        # pyOpenSSL does not have a way to get extensions
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        value = m2x509.get_ext(name).get_value()
-        
-        return value
-
-    ##
-    # Set_data is a wrapper around add_extension. It stores the parameter str in
-    # the X509 subject_alt_name extension. Set_data can only be called once, due
-    # to limitations in the underlying library.
-
-    def set_data(self, str, field='subjectAltName'):
-        # pyOpenSSL only allows us to add extensions, so if we try to set the
-        # same extension more than once, it will not work
-        if self.data.has_key(field):
-            raise "Cannot set ", field, " more than once"
-        self.data[field] = str
-        self.add_extension(field, 0, str)
-
-    ##
-    # Return the data string that was previously set with set_data
-
-    def get_data(self, field='subjectAltName'):
-        if self.data.has_key(field):
-            return self.data[field]
-
-        try:
-            uri = self.get_extension(field)
-            self.data[field] = uri
-        except LookupError:
-            return None
-
-        return self.data[field]
-
-    ##
-    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
-
-    def sign(self):
-        logger.debug('certificate.sign')
-        assert self.cert != None
-        assert self.issuerSubject != None
-        assert self.issuerKey != None
-        self.cert.set_issuer(self.issuerSubject)
-        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
-
-    ##
-    # Verify the authenticity of a certificate.
-    # @param pkey is a Keypair object representing a public key. If Pkey
-    #     did not sign the certificate, then an exception will be thrown.
-
-    def verify(self, pkey):
-        # pyOpenSSL does not have a way to verify signatures
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        m2pkey = pkey.get_m2_pkey()
-        # verify it
-        return m2x509.verify(m2pkey)
-
-        # XXX alternatively, if openssl has been patched, do the much simpler:
-        # try:
-        #   self.cert.verify(pkey.get_openssl_key())
-        #   return 1
-        # except:
-        #   return 0
-
-    ##
-    # Return True if pkey is identical to the public key that is contained in the certificate.
-    # @param pkey Keypair object
-
-    def is_pubkey(self, pkey):
-        return self.get_pubkey().is_same(pkey)
-
-    ##
-    # Given a certificate cert, verify that this certificate was signed by the
-    # public key contained in cert. Throw an exception otherwise.
-    #
-    # @param cert certificate object
-
-    def is_signed_by_cert(self, cert):
-        k = cert.get_pubkey()
-        result = self.verify(k)
-        return result
-
-    ##
-    # Set the parent certficiate.
-    #
-    # @param p certificate object.
-
-    def set_parent(self, p):
-        self.parent = p
-
-    ##
-    # Return the certificate object of the parent of this certificate.
-
-    def get_parent(self):
-        return self.parent
-
-    ##
-    # Verification examines a chain of certificates to ensure that each parent
-    # signs the child, and that some certificate in the chain is signed by a
-    # trusted certificate.
-    #
-    # Verification is a basic recursion: <pre>
-    #     if this_certificate was signed by trusted_certs:
-    #         return
-    #     else
-    #         return verify_chain(parent, trusted_certs)
-    # </pre>
-    #
-    # At each recursion, the parent is tested to ensure that it did sign the
-    # child. If a parent did not sign a child, then an exception is thrown. If
-    # the bottom of the recursion is reached and the certificate does not match
-    # a trusted root, then an exception is thrown.
-    #
-    # @param Trusted_certs is a list of certificates that are trusted.
-    #
-
-    def verify_chain(self, trusted_certs = None):
-        # Verify a chain of certificates. Each certificate must be signed by
-        # the public key contained in it's parent. The chain is recursed
-        # until a certificate is found that is signed by a trusted root.
-
-        # verify expiration time
-        if self.cert.has_expired():
-            logger.debug("verify_chain: NO our certificate has expired")
-            raise CertExpired(self.get_subject(), "client cert")   
-        
-        # if this cert is signed by a trusted_cert, then we are set
-        for trusted_cert in trusted_certs:
-            if self.is_signed_by_cert(trusted_cert):
-                # verify expiration of trusted_cert ?
-                if not trusted_cert.cert.has_expired():
-                    logger.debug("verify_chain: YES cert %s signed by trusted cert %s"%(
-                            self.get_subject(), trusted_cert.get_subject()))
-                    return trusted_cert
-                else:
-                    logger.debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
-                            self.get_subject(),trusted_cert.get_subject()))
-                    raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
-
-        # if there is no parent, then no way to verify the chain
-        if not self.parent:
-            logger.debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
-            raise CertMissingParent(self.get_subject())
-
-        # if it wasn't signed by the parent...
-        if not self.is_signed_by_cert(self.parent):
-            logger.debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
-            return CertNotSignedByParent(self.get_subject())
-
-        # if the parent isn't verified...
-        logger.debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_subject(),self.parent.get_subject()))
-        self.parent.verify_chain(trusted_certs)
-
-        return
-
-    ### more introspection
-    def get_extensions(self):
-        # pyOpenSSL does not have a way to get extensions
-        triples=[]
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        nb_extensions=m2x509.get_ext_count()
-        logger.debug("X509 had %d extensions"%nb_extensions)
-        for i in range(nb_extensions):
-            ext=m2x509.get_ext_at(i)
-            triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
-        return triples
-
-    def get_data_names(self):
-        return self.data.keys()
-
-    def get_all_datas (self):
-        triples=self.get_extensions()
-        for name in self.get_data_names(): 
-            triples.append( (name,self.get_data(name),'data',) )
-        return triples
-
-    # only informative
-    def get_filename(self):
-        return getattr(self,'filename',None)
-
-    def dump (self, *args, **kwargs):
-        print self.dump_string(*args, **kwargs)
-
-    def dump_string (self,show_extensions=False):
-        result = ""
-        result += "CERTIFICATE for %s\n"%self.get_subject()
-        result += "Issued by %s\n"%self.get_issuer()
-        filename=self.get_filename()
-        if filename: result += "Filename %s\n"%filename
-        if show_extensions:
-            all_datas=self.get_all_datas()
-            result += " has %d extensions/data attached"%len(all_datas)
-            for (n,v,c) in all_datas:
-                if c=='data':
-                    result += "   data: %s=%s\n"%(n,v)
-                else:
-                    result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)
-        return result
+#----------------------------------------------------------------------\r
+# Copyright (c) 2008 Board of Trustees, Princeton University\r
+#\r
+# Permission is hereby granted, free of charge, to any person obtaining\r
+# a copy of this software and/or hardware specification (the "Work") to\r
+# deal in the Work without restriction, including without limitation the\r
+# rights to use, copy, modify, merge, publish, distribute, sublicense,\r
+# and/or sell copies of the Work, and to permit persons to whom the Work\r
+# is furnished to do so, subject to the following conditions:\r
+#\r
+# The above copyright notice and this permission notice shall be\r
+# included in all copies or substantial portions of the Work.\r
+#\r
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS \r
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF \r
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND \r
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT \r
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, \r
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, \r
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS \r
+# IN THE WORK.\r
+#----------------------------------------------------------------------\r
+\r
+##\r
+# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement\r
+# the necessary crypto functionality. Ideally just one of these libraries\r
+# would be used, but unfortunately each of these libraries is independently\r
+# lacking. The pyOpenSSL library is missing many necessary functions, and\r
+# the M2Crypto library has crashed inside of some of the functions. The\r
+# design decision is to use pyOpenSSL whenever possible as it seems more\r
+# stable, and only use M2Crypto for those functions that are not possible\r
+# in pyOpenSSL.\r
+#\r
+# This module exports two classes: Keypair and Certificate.\r
+##\r
+#\r
+\r
+import functools\r
+import os\r
+import tempfile\r
+import base64\r
+import traceback\r
+from tempfile import mkstemp\r
+\r
+from OpenSSL import crypto\r
+import M2Crypto\r
+from M2Crypto import X509\r
+\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.xrn import urn_to_hrn\r
+from sfa.util.faults import *\r
+from sfa.util.sfalogging import logger\r
+\r
+glo_passphrase_callback = None\r
+\r
+##\r
+# A global callback msy be implemented for requesting passphrases from the\r
+# user. The function will be called with three arguments:\r
+#\r
+#    keypair_obj: the keypair object that is calling the passphrase\r
+#    string: the string containing the private key that's being loaded\r
+#    x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto\r
+#\r
+# The callback should return a string containing the passphrase.\r
+\r
+def set_passphrase_callback(callback_func):\r
+    global glo_passphrase_callback\r
+\r
+    glo_passphrase_callback = callback_func\r
+\r
+##\r
+# Sets a fixed passphrase.\r
+\r
+def set_passphrase(passphrase):\r
+    set_passphrase_callback( lambda k,s,x: passphrase )\r
+\r
+##\r
+# Check to see if a passphrase works for a particular private key string.\r
+# Intended to be used by passphrase callbacks for input validation.\r
+\r
+def test_passphrase(string, passphrase):\r
+    try:\r
+        crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))\r
+        return True\r
+    except:\r
+        return False\r
+\r
+def convert_public_key(key):\r
+    keyconvert_path = "/usr/bin/keyconvert.py"\r
+    if not os.path.isfile(keyconvert_path):\r
+        raise IOError, "Could not find keyconvert in %s" % keyconvert_path\r
+\r
+    # we can only convert rsa keys\r
+    if "ssh-dss" in key:\r
+        return None\r
+\r
+    (ssh_f, ssh_fn) = tempfile.mkstemp()\r
+    ssl_fn = tempfile.mktemp()\r
+    os.write(ssh_f, key)\r
+    os.close(ssh_f)\r
+\r
+    cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn\r
+    os.system(cmd)\r
+\r
+    # this check leaves the temporary file containing the public key so\r
+    # that it can be expected to see why it failed.\r
+    # TODO: for production, cleanup the temporary files\r
+    if not os.path.exists(ssl_fn):\r
+        return None\r
+\r
+    k = Keypair()\r
+    try:\r
+        k.load_pubkey_from_file(ssl_fn)\r
+    except:\r
+        logger.log_exc("convert_public_key caught exception")\r
+        k = None\r
+\r
+    # remove the temporary files\r
+    os.remove(ssh_fn)\r
+    os.remove(ssl_fn)\r
+\r
+    return k\r
+\r
+##\r
+# Public-private key pairs are implemented by the Keypair class.\r
+# A Keypair object may represent both a public and private key pair, or it\r
+# may represent only a public key (this usage is consistent with OpenSSL).\r
+\r
+class Keypair:\r
+    key = None       # public/private keypair\r
+    m2key = None     # public key (m2crypto format)\r
+\r
+    ##\r
+    # Creates a Keypair object\r
+    # @param create If create==True, creates a new public/private key and\r
+    #     stores it in the object\r
+    # @param string If string!=None, load the keypair from the string (PEM)\r
+    # @param filename If filename!=None, load the keypair from the file\r
+\r
+    def __init__(self, create=False, string=None, filename=None):\r
+        if create:\r
+            self.create()\r
+        if string:\r
+            self.load_from_string(string)\r
+        if filename:\r
+            self.load_from_file(filename)\r
+\r
+    ##\r
+    # Create a RSA public/private key pair and store it inside the keypair object\r
+\r
+    def create(self):\r
+        self.key = crypto.PKey()\r
+        self.key.generate_key(crypto.TYPE_RSA, 1024)\r
+\r
+    ##\r
+    # Save the private key to a file\r
+    # @param filename name of file to store the keypair in\r
+\r
+    def save_to_file(self, filename):\r
+        open(filename, 'w').write(self.as_pem())\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Load the private key from a file. Implicity the private key includes the public key.\r
+\r
+    def load_from_file(self, filename):\r
+        self.filename=filename\r
+        buffer = open(filename, 'r').read()\r
+        self.load_from_string(buffer)\r
+\r
+    ##\r
+    # Load the private key from a string. Implicitly the private key includes the public key.\r
+\r
+    def load_from_string(self, string):\r
+        if glo_passphrase_callback:\r
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )\r
+            self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )\r
+        else:\r
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)\r
+            self.m2key = M2Crypto.EVP.load_key_string(string)\r
+\r
+    ##\r
+    #  Load the public key from a string. No private key is loaded.\r
+\r
+    def load_pubkey_from_file(self, filename):\r
+        # load the m2 public key\r
+        m2rsakey = M2Crypto.RSA.load_pub_key(filename)\r
+        self.m2key = M2Crypto.EVP.PKey()\r
+        self.m2key.assign_rsa(m2rsakey)\r
+\r
+        # create an m2 x509 cert\r
+        m2name = M2Crypto.X509.X509_Name()\r
+        m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)\r
+        m2x509 = M2Crypto.X509.X509()\r
+        m2x509.set_pubkey(self.m2key)\r
+        m2x509.set_serial_number(0)\r
+        m2x509.set_issuer_name(m2name)\r
+        m2x509.set_subject_name(m2name)\r
+        ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()\r
+        ASN1.set_time(500)\r
+        m2x509.set_not_before(ASN1)\r
+        m2x509.set_not_after(ASN1)\r
+        # x509v3 so it can have extensions\r
+        # prob not necc since this cert itself is junk but still...\r
+        m2x509.set_version(2)\r
+        junk_key = Keypair(create=True)\r
+        m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")\r
+\r
+        # convert the m2 x509 cert to a pyopenssl x509\r
+        m2pem = m2x509.as_pem()\r
+        pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)\r
+\r
+        # get the pyopenssl pkey from the pyopenssl x509\r
+        self.key = pyx509.get_pubkey()\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Load the public key from a string. No private key is loaded.\r
+\r
+    def load_pubkey_from_string(self, string):\r
+        (f, fn) = tempfile.mkstemp()\r
+        os.write(f, string)\r
+        os.close(f)\r
+        self.load_pubkey_from_file(fn)\r
+        os.remove(fn)\r
+\r
+    ##\r
+    # Return the private key in PEM format.\r
+\r
+    def as_pem(self):\r
+        return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)\r
+\r
+    ##\r
+    # Return an M2Crypto key object\r
+\r
+    def get_m2_pkey(self):\r
+        if not self.m2key:\r
+            self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())\r
+        return self.m2key\r
+\r
+    ##\r
+    # Returns a string containing the public key represented by this object.\r
+\r
+    def get_pubkey_string(self):\r
+        m2pkey = self.get_m2_pkey()\r
+        return base64.b64encode(m2pkey.as_der())\r
+\r
+    ##\r
+    # Return an OpenSSL pkey object\r
+\r
+    def get_openssl_pkey(self):\r
+        return self.key\r
+\r
+    ##\r
+    # Given another Keypair object, return TRUE if the two keys are the same.\r
+\r
+    def is_same(self, pkey):\r
+        return self.as_pem() == pkey.as_pem()\r
+\r
+    def sign_string(self, data):\r
+        k = self.get_m2_pkey()\r
+        k.sign_init()\r
+        k.sign_update(data)\r
+        return base64.b64encode(k.sign_final())\r
+\r
+    def verify_string(self, data, sig):\r
+        k = self.get_m2_pkey()\r
+        k.verify_init()\r
+        k.verify_update(data)\r
+        return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)\r
+\r
+    def compute_hash(self, value):\r
+        return self.sign_string(str(value))\r
+\r
+    # only informative\r
+    def get_filename(self):\r
+        return getattr(self,'filename',None)\r
+\r
+    def dump (self, *args, **kwargs):\r
+        print self.dump_string(*args, **kwargs)\r
+\r
+    def dump_string (self):\r
+        result=""\r
+        result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()\r
+        filename=self.get_filename()\r
+        if filename: result += "Filename %s\n"%filename\r
+        return result\r
+\r
+##\r
+# The certificate class implements a general purpose X509 certificate, making\r
+# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds\r
+# several addition features, such as the ability to maintain a chain of\r
+# parent certificates, and storage of application-specific data.\r
+#\r
+# Certificates include the ability to maintain a chain of parents. Each\r
+# certificate includes a pointer to it's parent certificate. When loaded\r
+# from a file or a string, the parent chain will be automatically loaded.\r
+# When saving a certificate to a file or a string, the caller can choose\r
+# whether to save the parent certificates as well.\r
+\r
+class Certificate:\r
+    digest = "md5"\r
+\r
+    cert = None\r
+    issuerKey = None\r
+    issuerSubject = None\r
+    parent = None\r
+    isCA = None # will be a boolean once set\r
+\r
+    separator="-----parent-----"\r
+\r
+    ##\r
+    # Create a certificate object.\r
+    #\r
+    # @param lifeDays life of cert in days - default is 1825==5 years\r
+    # @param create If create==True, then also create a blank X509 certificate.\r
+    # @param subject If subject!=None, then create a blank certificate and set\r
+    #     it's subject name.\r
+    # @param string If string!=None, load the certficate from the string.\r
+    # @param filename If filename!=None, load the certficiate from the file.\r
+    # @param isCA If !=None, set whether this cert is for a CA\r
+\r
+    def __init__(self, lifeDays=1825, create=False, subject=None, string=None, filename=None, isCA=None):\r
+        self.data = {}\r
+        if create or subject:\r
+            self.create(lifeDays)\r
+        if subject:\r
+            self.set_subject(subject)\r
+        if string:\r
+            self.load_from_string(string)\r
+        if filename:\r
+            self.load_from_file(filename)\r
+\r
+        # Set the CA bit if a value was supplied\r
+        if isCA != None:\r
+            self.set_is_ca(isCA)\r
+\r
+    # Create a blank X509 certificate and store it in this object.\r
+\r
+    def create(self, lifeDays=1825):\r
+        self.cert = crypto.X509()\r
+        # FIXME: Use different serial #s\r
+        self.cert.set_serial_number(3)\r
+        self.cert.gmtime_adj_notBefore(0) # 0 means now\r
+        self.cert.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default\r
+        self.cert.set_version(2) # x509v3 so it can have extensions\r
+\r
+\r
+    ##\r
+    # Given a pyOpenSSL X509 object, store that object inside of this\r
+    # certificate object.\r
+\r
+    def load_from_pyopenssl_x509(self, x509):\r
+        self.cert = x509\r
+\r
+    ##\r
+    # Load the certificate from a string\r
+\r
+    def load_from_string(self, string):\r
+        # if it is a chain of multiple certs, then split off the first one and\r
+        # load it (support for the ---parent--- tag as well as normal chained certs)\r
+\r
+        string = string.strip()\r
+        \r
+        # If it's not in proper PEM format, wrap it\r
+        if string.count('-----BEGIN CERTIFICATE') == 0:\r
+            string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string\r
+\r
+        # If there is a PEM cert in there, but there is some other text first\r
+        # such as the text of the certificate, skip the text\r
+        beg = string.find('-----BEGIN CERTIFICATE')\r
+        if beg > 0:\r
+            # skipping over non cert beginning                                                                                                              \r
+            string = string[beg:]\r
+\r
+        parts = []\r
+\r
+        if string.count('-----BEGIN CERTIFICATE-----') > 1 and \\r
+               string.count(Certificate.separator) == 0:\r
+            parts = string.split('-----END CERTIFICATE-----',1)\r
+            parts[0] += '-----END CERTIFICATE-----'\r
+        else:\r
+            parts = string.split(Certificate.separator, 1)\r
+\r
+        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])\r
+\r
+        # if there are more certs, then create a parent and let the parent load\r
+        # itself from the remainder of the string\r
+        if len(parts) > 1 and parts[1] != '':\r
+            self.parent = self.__class__()\r
+            self.parent.load_from_string(parts[1])\r
+\r
+    ##\r
+    # Load the certificate from a file\r
+\r
+    def load_from_file(self, filename):\r
+        file = open(filename)\r
+        string = file.read()\r
+        self.load_from_string(string)\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Save the certificate to a string.\r
+    #\r
+    # @param save_parents If save_parents==True, then also save the parent certificates.\r
+\r
+    def save_to_string(self, save_parents=True):\r
+        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)\r
+        if save_parents and self.parent:\r
+            string = string + self.parent.save_to_string(save_parents)\r
+        return string\r
+\r
+    ##\r
+    # Save the certificate to a file.\r
+    # @param save_parents If save_parents==True, then also save the parent certificates.\r
+\r
+    def save_to_file(self, filename, save_parents=True, filep=None):\r
+        string = self.save_to_string(save_parents=save_parents)\r
+        if filep:\r
+            f = filep\r
+        else:\r
+            f = open(filename, 'w')\r
+        f.write(string)\r
+        f.close()\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Save the certificate to a random file in /tmp/\r
+    # @param save_parents If save_parents==True, then also save the parent certificates.\r
+    def save_to_random_tmp_file(self, save_parents=True):\r
+        fp, filename = mkstemp(suffix='cert', text=True)\r
+        fp = os.fdopen(fp, "w")\r
+        self.save_to_file(filename, save_parents=True, filep=fp)\r
+        return filename\r
+\r
+    ##\r
+    # Sets the issuer private key and name\r
+    # @param key Keypair object containing the private key of the issuer\r
+    # @param subject String containing the name of the issuer\r
+    # @param cert (optional) Certificate object containing the name of the issuer\r
+\r
+    def set_issuer(self, key, subject=None, cert=None):\r
+        self.issuerKey = key\r
+        if subject:\r
+            # it's a mistake to use subject and cert params at the same time\r
+            assert(not cert)\r
+            if isinstance(subject, dict) or isinstance(subject, str):\r
+                req = crypto.X509Req()\r
+                reqSubject = req.get_subject()\r
+                if (isinstance(subject, dict)):\r
+                    for key in reqSubject.keys():\r
+                        setattr(reqSubject, key, subject[key])\r
+                else:\r
+                    setattr(reqSubject, "CN", subject)\r
+                subject = reqSubject\r
+                # subject is not valid once req is out of scope, so save req\r
+                self.issuerReq = req\r
+        if cert:\r
+            # if a cert was supplied, then get the subject from the cert\r
+            subject = cert.cert.get_subject()\r
+        assert(subject)\r
+        self.issuerSubject = subject\r
+\r
+    ##\r
+    # Get the issuer name\r
+\r
+    def get_issuer(self, which="CN"):\r
+        x = self.cert.get_issuer()\r
+        return getattr(x, which)\r
+\r
+    ##\r
+    # Set the subject name of the certificate\r
+\r
+    def set_subject(self, name):\r
+        req = crypto.X509Req()\r
+        subj = req.get_subject()\r
+        if (isinstance(name, dict)):\r
+            for key in name.keys():\r
+                setattr(subj, key, name[key])\r
+        else:\r
+            setattr(subj, "CN", name)\r
+        self.cert.set_subject(subj)\r
+\r
+    ##\r
+    # Get the subject name of the certificate\r
+\r
+    def get_subject(self, which="CN"):\r
+        x = self.cert.get_subject()\r
+        return getattr(x, which)\r
+\r
+    ##\r
+    # Get a pretty-print subject name of the certificate\r
+\r
+    def get_printable_subject(self):\r
+        x = self.cert.get_subject()\r
+        return "[ OU: %s, CN: %s, SubjectAltName: %s ]" % (getattr(x, "OU"), getattr(x, "CN"), self.get_data())\r
+\r
+    ##\r
+    # Get the public key of the certificate.\r
+    #\r
+    # @param key Keypair object containing the public key\r
+\r
+    def set_pubkey(self, key):\r
+        assert(isinstance(key, Keypair))\r
+        self.cert.set_pubkey(key.get_openssl_pkey())\r
+\r
+    ##\r
+    # Get the public key of the certificate.\r
+    # It is returned in the form of a Keypair object.\r
+\r
+    def get_pubkey(self):\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        pkey = Keypair()\r
+        pkey.key = self.cert.get_pubkey()\r
+        pkey.m2key = m2x509.get_pubkey()\r
+        return pkey\r
+\r
+    def set_intermediate_ca(self, val):\r
+        return self.set_is_ca(val)\r
+\r
+    # Set whether this cert is for a CA. All signers and only signers should be CAs.\r
+    # The local member starts unset, letting us check that you only set it once\r
+    # @param val Boolean indicating whether this cert is for a CA\r
+    def set_is_ca(self, val):\r
+        if val is None:\r
+            return\r
+\r
+        if self.isCA != None:\r
+            # Can't double set properties\r
+            raise "Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val)\r
+\r
+        self.isCA = val\r
+        if val:\r
+            self.add_extension('basicConstraints', 1, 'CA:TRUE')\r
+        else:\r
+            self.add_extension('basicConstraints', 1, 'CA:FALSE')\r
+\r
+\r
+\r
+    ##\r
+    # Add an X509 extension to the certificate. Add_extension can only be called\r
+    # once for a particular extension name, due to limitations in the underlying\r
+    # library.\r
+    #\r
+    # @param name string containing name of extension\r
+    # @param value string containing value of the extension\r
+\r
+    def add_extension(self, name, critical, value):\r
+        oldExtVal = None\r
+        try:\r
+            oldExtVal = self.get_extension(name)\r
+        except:\r
+            # M2Crypto LookupError when the extension isn't there (yet)\r
+            pass\r
+\r
+        # This code limits you from adding the extension with the same value\r
+        # The method comment says you shouldn't do this with the same name\r
+        # But actually it (m2crypto) appears to allow you to do this.\r
+        if oldExtVal and oldExtVal == value:\r
+            # don't add this extension again\r
+            # just do nothing as here\r
+            return\r
+        # FIXME: What if they are trying to set with a different value?\r
+        # Is this ever OK? Or should we raise an exception?\r
+#        elif oldExtVal:\r
+#            raise "Cannot add extension %s which had val %s with new val %s" % (name, oldExtVal, value)\r
+\r
+        ext = crypto.X509Extension (name, critical, value)\r
+        self.cert.add_extensions([ext])\r
+\r
+    ##\r
+    # Get an X509 extension from the certificate\r
+\r
+    def get_extension(self, name):\r
+\r
+        # pyOpenSSL does not have a way to get extensions\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        value = m2x509.get_ext(name).get_value()\r
+\r
+        return value\r
+\r
+    ##\r
+    # Set_data is a wrapper around add_extension. It stores the parameter str in\r
+    # the X509 subject_alt_name extension. Set_data can only be called once, due\r
+    # to limitations in the underlying library.\r
+\r
+    def set_data(self, str, field='subjectAltName'):\r
+        # pyOpenSSL only allows us to add extensions, so if we try to set the\r
+        # same extension more than once, it will not work\r
+        if self.data.has_key(field):\r
+            raise "Cannot set ", field, " more than once"\r
+        self.data[field] = str\r
+        self.add_extension(field, 0, str)\r
+\r
+    ##\r
+    # Return the data string that was previously set with set_data\r
+\r
+    def get_data(self, field='subjectAltName'):\r
+        if self.data.has_key(field):\r
+            return self.data[field]\r
+\r
+        try:\r
+            uri = self.get_extension(field)\r
+            self.data[field] = uri\r
+        except LookupError:\r
+            return None\r
+\r
+        return self.data[field]\r
+\r
+    ##\r
+    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().\r
+\r
+    def sign(self):\r
+        logger.debug('certificate.sign')\r
+        assert self.cert != None\r
+        assert self.issuerSubject != None\r
+        assert self.issuerKey != None\r
+        self.cert.set_issuer(self.issuerSubject)\r
+        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)\r
+\r
+    ##\r
+    # Verify the authenticity of a certificate.\r
+    # @param pkey is a Keypair object representing a public key. If Pkey\r
+    #     did not sign the certificate, then an exception will be thrown.\r
+\r
+    def verify(self, pkey):\r
+        # pyOpenSSL does not have a way to verify signatures\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        m2pkey = pkey.get_m2_pkey()\r
+        # verify it\r
+        return m2x509.verify(m2pkey)\r
+\r
+        # XXX alternatively, if openssl has been patched, do the much simpler:\r
+        # try:\r
+        #   self.cert.verify(pkey.get_openssl_key())\r
+        #   return 1\r
+        # except:\r
+        #   return 0\r
+\r
+    ##\r
+    # Return True if pkey is identical to the public key that is contained in the certificate.\r
+    # @param pkey Keypair object\r
+\r
+    def is_pubkey(self, pkey):\r
+        return self.get_pubkey().is_same(pkey)\r
+\r
+    ##\r
+    # Given a certificate cert, verify that this certificate was signed by the\r
+    # public key contained in cert. Throw an exception otherwise.\r
+    #\r
+    # @param cert certificate object\r
+\r
+    def is_signed_by_cert(self, cert):\r
+        k = cert.get_pubkey()\r
+        result = self.verify(k)\r
+        return result\r
+\r
+    ##\r
+    # Set the parent certficiate.\r
+    #\r
+    # @param p certificate object.\r
+\r
+    def set_parent(self, p):\r
+        self.parent = p\r
+\r
+    ##\r
+    # Return the certificate object of the parent of this certificate.\r
+\r
+    def get_parent(self):\r
+        return self.parent\r
+\r
+    ##\r
+    # Verification examines a chain of certificates to ensure that each parent\r
+    # signs the child, and that some certificate in the chain is signed by a\r
+    # trusted certificate.\r
+    #\r
+    # Verification is a basic recursion: <pre>\r
+    #     if this_certificate was signed by trusted_certs:\r
+    #         return\r
+    #     else\r
+    #         return verify_chain(parent, trusted_certs)\r
+    # </pre>\r
+    #\r
+    # At each recursion, the parent is tested to ensure that it did sign the\r
+    # child. If a parent did not sign a child, then an exception is thrown. If\r
+    # the bottom of the recursion is reached and the certificate does not match\r
+    # a trusted root, then an exception is thrown.\r
+    # Also require that parents are CAs.\r
+    #\r
+    # @param Trusted_certs is a list of certificates that are trusted.\r
+    #\r
+\r
+    def verify_chain(self, trusted_certs = None):\r
+        # Verify a chain of certificates. Each certificate must be signed by\r
+        # the public key contained in it's parent. The chain is recursed\r
+        # until a certificate is found that is signed by a trusted root.\r
+\r
+        # verify expiration time\r
+        if self.cert.has_expired():\r
+            logger.debug("verify_chain: NO our certificate %s has expired" % self.get_printable_subject())\r
+            raise CertExpired(self.get_printable_subject(), "client cert")\r
+\r
+        # if this cert is signed by a trusted_cert, then we are set\r
+        for trusted_cert in trusted_certs:\r
+            if self.is_signed_by_cert(trusted_cert):\r
+                # verify expiration of trusted_cert ?\r
+                if not trusted_cert.cert.has_expired():\r
+                    logger.debug("verify_chain: YES cert %s signed by trusted cert %s"%(\r
+                            self.get_printable_subject(), trusted_cert.get_printable_subject()))\r
+                    return trusted_cert\r
+                else:\r
+                    logger.debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(\r
+                            self.get_printable_subject(),trusted_cert.get_printable_subject()))\r
+                    raise CertExpired(self.get_printable_subject()," signer trusted_cert %s"%trusted_cert.get_printable_subject())\r
+\r
+        # if there is no parent, then no way to verify the chain\r
+        if not self.parent:\r
+            logger.debug("verify_chain: NO %s has no parent and issuer %s is not in %d trusted roots"%self.get_printable_subject(), self.get_issuer(), len(trusted_certs))\r
+            raise CertMissingParent(self.get_printable_subject(), "Non trusted issuer: %s out of %d trusted roots" % (self.get_issuer(), len(trusted_certs)))\r
+\r
+        # if it wasn't signed by the parent...\r
+        if not self.is_signed_by_cert(self.parent):\r
+            logger.debug("verify_chain: NO %s is not signed by parent %s, but by %s"%self.get_printable_subject(), self.parent.get_printable_subject(), self.get_issuer())\r
+            return CertNotSignedByParent(self.get_printable_subject(), "parent %s, issuer %s" % (selr.parent.get_printable_subject(), self.get_issuer()))\r
+\r
+        # Confirm that the parent is a CA. Only CAs can be trusted as\r
+        # signers.\r
+        # Note that trusted roots are not parents, so don't need to be\r
+        # CAs.\r
+        # Ugly - cert objects aren't parsed so we need to read the\r
+        # extension and hope there are no other basicConstraints\r
+        if not self.parent.isCA and not (self.parent.get_extension('basicConstraints') == 'CA:TRUE'):\r
+            logger.warn("verify_chain: cert %s's parent %s is not a CA" % (self.get_printable_subject(), self.parent.get_printable_subject()))\r
+            return CertNotSignedByParent(self.get_printable_subject(), "Parent %s not a CA" % self.parent.get_printable_subject())\r
+\r
+        # if the parent isn't verified...\r
+        logger.debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_printable_subject(),self.parent.get_printable_subject()))\r
+        self.parent.verify_chain(trusted_certs)\r
+\r
+        return\r
+\r
+    ### more introspection\r
+    def get_extensions(self):\r
+        # pyOpenSSL does not have a way to get extensions\r
+        triples=[]\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        nb_extensions=m2x509.get_ext_count()\r
+        logger.debug("X509 had %d extensions"%nb_extensions)\r
+        for i in range(nb_extensions):\r
+            ext=m2x509.get_ext_at(i)\r
+            triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )\r
+        return triples\r
+\r
+    def get_data_names(self):\r
+        return self.data.keys()\r
+\r
+    def get_all_datas (self):\r
+        triples=self.get_extensions()\r
+        for name in self.get_data_names():\r
+            triples.append( (name,self.get_data(name),'data',) )\r
+        return triples\r
+\r
+    # only informative\r
+    def get_filename(self):\r
+        return getattr(self,'filename',None)\r
+\r
+    def dump (self, *args, **kwargs):\r
+        print self.dump_string(*args, **kwargs)\r
+\r
+    def dump_string (self,show_extensions=False):\r
+        result = ""\r
+        result += "CERTIFICATE for %s\n"%self.get_printable_subject()\r
+        result += "Issued by %s\n"%self.get_issuer()\r
+        filename=self.get_filename()\r
+        if filename: result += "Filename %s\n"%filename\r
+        if show_extensions:\r
+            all_datas=self.get_all_datas()\r
+            result += " has %d extensions/data attached"%len(all_datas)\r
+            for (n,v,c) in all_datas:\r
+                if c=='data':\r
+                    result += "   data: %s=%s\n"%(n,v)\r
+                else:\r
+                    result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)\r
+        return result\r
index 6fb8c0c..f112a95 100644 (file)
@@ -47,7 +47,7 @@ from sfa.trust.certificate import Keypair
 from sfa.trust.credential_legacy import CredentialLegacy\r
 from sfa.trust.rights import Right, Rights, determine_rights\r
 from sfa.trust.gid import GID\r
-from sfa.util.xrn import urn_to_hrn\r
+from sfa.util.xrn import urn_to_hrn, hrn_authfor_hrn\r
 \r
 # 2 weeks, in seconds \r
 DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14\r
@@ -269,7 +269,16 @@ class Credential(object):
     def get_subject(self):\r
         if not self.gidObject:\r
             self.decode()\r
-        return self.gidObject.get_subject()   \r
+        return self.gidObject.get_printable_subject()\r
+\r
+    def get_summary_tostring(self):\r
+        if not self.gidObject:\r
+            self.decode()\r
+        obj = self.gidObject.get_printable_subject()\r
+        caller = self.gidCaller.get_printable_subject()\r
+        exp = self.get_expiration()\r
+        # Summarize the rights too? The issuer?\r
+        return "[ Grant %s rights on %s until %s ]" % (caller, obj, exp)\r
 \r
     def get_signature(self):\r
         if not self.signature:\r
@@ -672,13 +681,19 @@ class Credential(object):
 \r
         # Is this a signed-cred or just a cred?\r
         if len(signed_cred) > 0:\r
-            cred = signed_cred[0].getElementsByTagName("credential")[0]\r
+            creds = signed_cred[0].getElementsByTagName("credential")\r
             signatures = signed_cred[0].getElementsByTagName("signatures")\r
             if len(signatures) > 0:\r
                 sigs = signatures[0].getElementsByTagName("Signature")\r
         else:\r
-            cred = doc.getElementsByTagName("credential")[0]\r
+            creds = doc.getElementsByTagName("credential")\r
         \r
+        if creds is None or len(creds) == 0:\r
+            # malformed cred file\r
+            raise CredentialNotVerifiable("Malformed XML: No credential tag found")\r
+\r
+        # Just take the first cred if there are more than one\r
+        cred = creds[0]\r
 \r
         self.set_refid(cred.getAttribute("xml:id"))\r
         self.set_expiration(utcparse(getTextNode(cred, "expires")))\r
@@ -738,6 +753,7 @@ class Credential(object):
     # . That the issuer of the credential is the authority in the target's urn\r
     #    . In the case of a delegated credential, this must be true of the root\r
     # . That all of the gids presented in the credential are valid\r
+    #    . Including verifying GID chains, and includ the issuer\r
     # . The credential is not expired\r
     #\r
     # -- For Delegates (credentials with parents)\r
@@ -764,7 +780,7 @@ class Credential(object):
                 xmlschema = etree.XMLSchema(schema_doc)\r
                 if not xmlschema.validate(tree):\r
                     error = xmlschema.error_log.last_error\r
-                    message = "%s (line %s)" % (error.message, error.line)\r
+                    message = "%s: %s (line %s)" % (self.get_summary_tostring(), error.message, error.line)\r
                     raise CredentialNotVerifiable(message)\r
 \r
         if trusted_certs_required and trusted_certs is None:\r
@@ -797,7 +813,7 @@ class Credential(object):
         \r
         # make sure it is not expired\r
         if self.get_expiration() < datetime.datetime.utcnow():\r
-            raise CredentialNotVerifiable("Credential expired at %s" % self.expiration.isoformat())\r
+            raise CredentialNotVerifiable("Credential %s expired at %s" % (self.get_summary_tostring(), self.expiration.isoformat()))\r
 \r
         # Verify the signatures\r
         filename = self.save_to_random_tmp_file()\r
@@ -805,7 +821,7 @@ class Credential(object):
             cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs])\r
 \r
         # If caller explicitly passed in None that means skip cert chain validation.\r
-        # Strange and not typical\r
+        # Strange and not typical\r
         if trusted_certs is not None:\r
             # Verify the gids of this cred and of its parents\r
             for cur_cred in self.get_credential_list():\r
@@ -837,15 +853,16 @@ class Credential(object):
                     mstart = mstart + 4\r
                     mend = verified.find('\\', mstart)\r
                     msg = verified[mstart:mend]\r
-                raise CredentialNotVerifiable("xmlsec1 error verifying cred using Signature ID %s: %s %s" % (ref, msg, verified.strip()))\r
+                raise CredentialNotVerifiable("xmlsec1 error verifying cred %s using Signature ID %s: %s %s" % (self.get_summary_tostring(), ref, msg, verified.strip()))\r
         os.remove(filename)\r
 \r
         # Verify the parents (delegation)\r
         if self.parent:\r
             self.verify_parent(self.parent)\r
 \r
-        # Make sure the issuer is the target's authority\r
-        self.verify_issuer()\r
+        # Make sure the issuer is the target's authority, and is\r
+        # itself a valid GID\r
+        self.verify_issuer(trusted_cert_objects)\r
         return True\r
 \r
     ##\r
@@ -863,39 +880,61 @@ class Credential(object):
         return list\r
     \r
     ##\r
-    # Make sure the credential's target gid was signed by (or is the same) the entity that signed\r
-    # the original credential or an authority over that namespace.\r
-    def verify_issuer(self):                \r
+    # Make sure the credential's target gid (a) was signed by or (b)\r
+    # is the same as the entity that signed the original credential,\r
+    # or (c) is an authority over the target's namespace.\r
+    # Also ensure that the credential issuer / signer itself has a valid\r
+    # GID signature chain (signed by an authority with namespace rights).\r
+    def verify_issuer(self, trusted_gids):\r
         root_cred = self.get_credential_list()[-1]\r
         root_target_gid = root_cred.get_gid_object()\r
         root_cred_signer = root_cred.get_signature().get_issuer_gid()\r
 \r
+        # Case 1:\r
+        # Allow non authority to sign target and cred about target.\r
+        #\r
+        # Why do we need to allow non authorities to sign?\r
+        # If in the target gid validation step we correctly\r
+        # checked that the target is only signed by an authority,\r
+        # then this is just a special case of case 3.\r
+        # This short-circuit is the common case currently -\r
+        # and cause GID validation doesn't check 'authority',\r
+        # this allows users to generate valid slice credentials.\r
         if root_target_gid.is_signed_by_cert(root_cred_signer):\r
             # cred signer matches target signer, return success\r
             return\r
 \r
-        root_target_gid_str = root_target_gid.save_to_string()\r
-        root_cred_signer_str = root_cred_signer.save_to_string()\r
-        if root_target_gid_str == root_cred_signer_str:\r
-            # cred signer is target, return success\r
-            return\r
+        # Case 2:\r
+        # Allow someone to sign credential about themeselves. Used?\r
+        # If not, remove this.\r
+        #root_target_gid_str = root_target_gid.save_to_string()\r
+        #root_cred_signer_str = root_cred_signer.save_to_string()\r
+        #if root_target_gid_str == root_cred_signer_str:\r
+        #    # cred signer is target, return success\r
+        #    return\r
+\r
+        # Case 3:\r
 \r
         # root_cred_signer is not the target_gid\r
-        # So this is a different gid that we have not verified\r
-        # Did xmlsec1 verify the cert chain on this already?\r
-        # Regardless, it hasn't verified that the gid meets the HRN namespace\r
-        # requirements\r
-# FIXME: Uncomment once we verify this is right\r
-#        root_cred_signer.verify_chain(trusted_cert_objects)\r
-\r
-        # See if it the signer is an authority over the domain of the target\r
+        # So this is a different gid that we have not verified.\r
+        # xmlsec1 verified the cert chain on this already, but\r
+        # it hasn't verified that the gid meets the HRN namespace\r
+        # requirements.\r
+        # Below we'll ensure that it is an authority.\r
+        # But we haven't verified that it is _signed by_ an authority\r
+        # We also don't know if xmlsec1 requires that cert signers\r
+        # are marked as CAs.\r
+        root_cred_signer.verify_chain(trusted_gids)\r
+\r
+        # See if the signer is an authority over the domain of the target.\r
+        # There are multiple types of authority - accept them all here\r
         # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn())\r
         root_cred_signer_type = root_cred_signer.get_type()\r
-        if (root_cred_signer_type == 'authority'):\r
+        if (root_cred_signer_type.find('authority') == 0):\r
             #logger.debug('Cred signer is an authority')\r
             # signer is an authority, see if target is in authority's domain\r
-            hrn = root_cred_signer.get_hrn()\r
-            if root_target_gid.get_hrn().startswith(hrn):\r
+            signerhrn = root_cred_signer.get_hrn()\r
+            if hrn_authfor_hrn(signerhrn, root_target_gid.get_hrn()):\r
                 return\r
 \r
         # We've required that the credential be signed by an authority\r
@@ -920,23 +959,23 @@ class Credential(object):
         # make sure the rights given to the child are a subset of the\r
         # parents rights (and check delegate bits)\r
         if not parent_cred.get_privileges().is_superset(self.get_privileges()):\r
-            raise ChildRightsNotSubsetOfParent(("Parent cred ref %s rights " % self.parent.get_refid()) + \r
-                self.parent.get_privileges().save_to_string() + (" not superset of delegated cred ref %s rights " % self.get_refid()) +\r
+            raise ChildRightsNotSubsetOfParent(("Parent cred ref %s rights " % parent_cred.get_refid()) +\r
+                self.parent.get_privileges().save_to_string() + (" not superset of delegated cred %s ref %s rights " % (self.get_summary_tostring(), self.get_refid())) +\r
                 self.get_privileges().save_to_string())\r
 \r
         # make sure my target gid is the same as the parent's\r
         if not parent_cred.get_gid_object().save_to_string() == \\r
            self.get_gid_object().save_to_string():\r
-            raise CredentialNotVerifiable("Target gid not equal between parent and child")\r
+            raise CredentialNotVerifiable("Delegated cred %s: Target gid not equal between parent and child. Parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
 \r
         # make sure my expiry time is <= my parent's\r
         if not parent_cred.get_expiration() >= self.get_expiration():\r
-            raise CredentialNotVerifiable("Delegated credential expires after parent")\r
+            raise CredentialNotVerifiable("Delegated credential %s expires after parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
 \r
         # make sure my signer is the parent's caller\r
         if not parent_cred.get_gid_caller().save_to_string(False) == \\r
            self.get_signature().get_issuer_gid().save_to_string(False):\r
-            raise CredentialNotVerifiable("Delegated credential not signed by parent caller")\r
+            raise CredentialNotVerifiable("Delegated credential %s not signed by parent %s's caller" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
                 \r
         # Recurse\r
         if parent_cred.parent:\r
index b881a1f..a7b2e71 100644 (file)
-#----------------------------------------------------------------------
-# Copyright (c) 2008 Board of Trustees, Princeton University
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and/or hardware specification (the "Work") to
-# deal in the Work without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Work, and to permit persons to whom the Work
-# is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Work.
-#
-# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
-# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
-# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
-# IN THE WORK.
-#----------------------------------------------------------------------
-##
-# Implements SFA GID. GIDs are based on certificates, and the GID class is a
-# descendant of the certificate class.
-##
-
-import xmlrpclib
-import uuid
-
-from sfa.util.sfalogging import logger 
-from sfa.trust.certificate import Certificate
-from sfa.util.xrn import hrn_to_urn, urn_to_hrn
-
-##
-# Create a new uuid. Returns the UUID as a string.
-
-def create_uuid():
-    return str(uuid.uuid4().int)
-
-##
-# GID is a tuple:
-#    (uuid, urn, public_key)
-#
-# UUID is a unique identifier and is created by the python uuid module
-#    (or the utility function create_uuid() in gid.py).
-#
-# HRN is a human readable name. It is a dotted form similar to a backward domain
-#    name. For example, planetlab.us.arizona.bakers.
-#
-# URN is a human readable identifier of form:
-#   "urn:publicid:IDN+toplevelauthority[:sub-auth.]*[\res. type]\ +object name"
-#   For  example, urn:publicid:IDN+planetlab:us:arizona+user+bakers      
-#
-# PUBLIC_KEY is the public key of the principal identified by the UUID/HRN.
-# It is a Keypair object as defined in the cert.py module.
-#
-# It is expected that there is a one-to-one pairing between UUIDs and HRN,
-# but it is uncertain how this would be inforced or if it needs to be enforced.
-#
-# These fields are encoded using xmlrpc into the subjectAltName field of the
-# x509 certificate. Note: Call encode() once the fields have been filled in
-# to perform this encoding.
-
-
-class GID(Certificate):
-    uuid = None
-    hrn = None
-    urn = None
-
-    ##
-    # Create a new GID object
-    #
-    # @param create If true, create the X509 certificate
-    # @param subject If subject!=None, create the X509 cert and set the subject name
-    # @param string If string!=None, load the GID from a string
-    # @param filename If filename!=None, load the GID from a file
-
-    def __init__(self, create=False, subject=None, string=None, filename=None, uuid=None, hrn=None, urn=None):
-        
-        Certificate.__init__(self, create, subject, string, filename)
-        if subject:
-            logger.debug("Creating GID for subject: %s" % subject)
-        if uuid:
-            self.uuid = int(uuid)
-        if hrn:
-            self.hrn = hrn
-            self.urn = hrn_to_urn(hrn, 'unknown')
-        if urn:
-            self.urn = urn
-            self.hrn, type = urn_to_hrn(urn)
-
-    def set_uuid(self, uuid):
-        if isinstance(uuid, str):
-            self.uuid = int(uuid)
-        else:
-            self.uuid = uuid
-
-    def get_uuid(self):
-        if not self.uuid:
-            self.decode()
-        return self.uuid
-
-    def set_hrn(self, hrn):
-        self.hrn = hrn
-
-    def get_hrn(self):
-        if not self.hrn:
-            self.decode()
-        return self.hrn
-
-    def set_urn(self, urn):
-        self.urn = urn
-        self.hrn, type = urn_to_hrn(urn)
-    def get_urn(self):
-        if not self.urn:
-            self.decode()
-        return self.urn            
-
-    def get_type(self):
-        if not self.urn:
-            self.decode()
-        _, t = urn_to_hrn(self.urn)
-        return t
-    
-    ##
-    # Encode the GID fields and package them into the subject-alt-name field
-    # of the X509 certificate. This must be called prior to signing the
-    # certificate. It may only be called once per certificate.
-
-    def encode(self):
-        if self.urn:
-            urn = self.urn
-        else:
-            urn = hrn_to_urn(self.hrn, None)
-            
-        str = "URI:" + urn
-
-        if self.uuid:
-            str += ", " + "URI:" + uuid.UUID(int=self.uuid).urn
-        
-        self.set_data(str, 'subjectAltName')
-
-        
-
-
-    ##
-    # Decode the subject-alt-name field of the X509 certificate into the
-    # fields of the GID. This is automatically called by the various get_*()
-    # functions in this class.
-
-    def decode(self):
-        data = self.get_data('subjectAltName')
-        dict = {}
-        if data:
-            if data.lower().startswith('uri:http://<params>'):
-                dict = xmlrpclib.loads(data[11:])[0][0]
-            else:
-                spl = data.split(', ')
-                for val in spl:
-                    if val.lower().startswith('uri:urn:uuid:'):
-                        dict['uuid'] = uuid.UUID(val[4:]).int
-                    elif val.lower().startswith('uri:urn:publicid:idn+'):
-                        dict['urn'] = val[4:]
-                    
-        self.uuid = dict.get("uuid", None)
-        self.urn = dict.get("urn", None)
-        self.hrn = dict.get("hrn", None)    
-        if self.urn:
-            self.hrn = urn_to_hrn(self.urn)[0]
-
-    ##
-    # Dump the credential to stdout.
-    #
-    # @param indent specifies a number of spaces to indent the output
-    # @param dump_parents If true, also dump the parents of the GID
-
-    def dump(self, *args, **kwargs):
-        print self.dump_string(*args,**kwargs)
-
-    def dump_string(self, indent=0, dump_parents=False):
-        result=" "*(indent-2) + "GID\n"
-        result += " "*indent + "hrn:" + str(self.get_hrn()) +"\n"
-        result += " "*indent + "urn:" + str(self.get_urn()) +"\n"
-        result += " "*indent + "uuid:" + str(self.get_uuid()) + "\n"
-        filename=self.get_filename()
-        if filename: result += "Filename %s\n"%filename
-
-        if self.parent and dump_parents:
-            result += " "*indent + "parent:\n"
-            result += self.parent.dump_string(indent+4, dump_parents)
-        return result
-
-    ##
-    # Verify the chain of authenticity of the GID. First perform the checks
-    # of the certificate class (verifying that each parent signs the child,
-    # etc). In addition, GIDs also confirm that the parent's HRN is a prefix
-    # of the child's HRN.
-    #
-    # Verifying these prefixes prevents a rogue authority from signing a GID
-    # for a principal that is not a member of that authority. For example,
-    # planetlab.us.arizona cannot sign a GID for planetlab.us.princeton.foo.
-
-    def verify_chain(self, trusted_certs = None):
-        # do the normal certificate verification stuff
-        trusted_root = Certificate.verify_chain(self, trusted_certs)        
-       
-        if self.parent:
-            # make sure the parent's hrn is a prefix of the child's hrn
-            if not self.get_hrn().startswith(self.parent.get_hrn()):
-                raise GidParentHrn("This cert HRN %s doesnt start with parent HRN %s" % (self.get_hrn(), self.parent.get_hrn()))
-        else:
-            # make sure that the trusted root's hrn is a prefix of the child's
-            trusted_gid = GID(string=trusted_root.save_to_string())
-            trusted_type = trusted_gid.get_type()
-            trusted_hrn = trusted_gid.get_hrn()
-            #if trusted_type == 'authority':
-            #    trusted_hrn = trusted_hrn[:trusted_hrn.rindex('.')]
-            cur_hrn = self.get_hrn()
-            if not self.get_hrn().startswith(trusted_hrn):
-                raise GidParentHrn("Trusted roots HRN %s isnt start of this cert %s" % (trusted_hrn, cur_hrn))
-
-        return
+#----------------------------------------------------------------------\r
+# Copyright (c) 2008 Board of Trustees, Princeton University\r
+#\r
+# Permission is hereby granted, free of charge, to any person obtaining\r
+# a copy of this software and/or hardware specification (the "Work") to\r
+# deal in the Work without restriction, including without limitation the\r
+# rights to use, copy, modify, merge, publish, distribute, sublicense,\r
+# and/or sell copies of the Work, and to permit persons to whom the Work\r
+# is furnished to do so, subject to the following conditions:\r
+#\r
+# The above copyright notice and this permission notice shall be\r
+# included in all copies or substantial portions of the Work.\r
+#\r
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS \r
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF \r
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND \r
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT \r
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, \r
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, \r
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS \r
+# IN THE WORK.\r
+#----------------------------------------------------------------------\r
+##\r
+# Implements SFA GID. GIDs are based on certificates, and the GID class is a\r
+# descendant of the certificate class.\r
+##\r
+\r
+import xmlrpclib\r
+import uuid\r
+\r
+from sfa.trust.certificate import Certificate\r
+\r
+from sfa.util.faults import *\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.xrn import hrn_to_urn, urn_to_hrn, hrn_authfor_hrn\r
+\r
+##\r
+# Create a new uuid. Returns the UUID as a string.\r
+\r
+def create_uuid():\r
+    return str(uuid.uuid4().int)\r
+\r
+##\r
+# GID is a tuple:\r
+#    (uuid, urn, public_key)\r
+#\r
+# UUID is a unique identifier and is created by the python uuid module\r
+#    (or the utility function create_uuid() in gid.py).\r
+#\r
+# HRN is a human readable name. It is a dotted form similar to a backward domain\r
+#    name. For example, planetlab.us.arizona.bakers.\r
+#\r
+# URN is a human readable identifier of form:\r
+#   "urn:publicid:IDN+toplevelauthority[:sub-auth.]*[\res. type]\ +object name"\r
+#   For  example, urn:publicid:IDN+planetlab:us:arizona+user+bakers      \r
+#\r
+# PUBLIC_KEY is the public key of the principal identified by the UUID/HRN.\r
+# It is a Keypair object as defined in the cert.py module.\r
+#\r
+# It is expected that there is a one-to-one pairing between UUIDs and HRN,\r
+# but it is uncertain how this would be inforced or if it needs to be enforced.\r
+#\r
+# These fields are encoded using xmlrpc into the subjectAltName field of the\r
+# x509 certificate. Note: Call encode() once the fields have been filled in\r
+# to perform this encoding.\r
+\r
+\r
+class GID(Certificate):\r
+    uuid = None\r
+    hrn = None\r
+    urn = None\r
+\r
+    ##\r
+    # Create a new GID object\r
+    #\r
+    # @param create If true, create the X509 certificate\r
+    # @param subject If subject!=None, create the X509 cert and set the subject name\r
+    # @param string If string!=None, load the GID from a string\r
+    # @param filename If filename!=None, load the GID from a file\r
+    # @param lifeDays life of GID in days - default is 1825==5 years\r
+\r
+    def __init__(self, create=False, subject=None, string=None, filename=None, uuid=None, hrn=None, urn=None, lifeDays=1825):\r
+        \r
+        Certificate.__init__(self, lifeDays, create, subject, string, filename)\r
+        if subject:\r
+            logger.debug("Creating GID for subject: %s" % subject)\r
+        if uuid:\r
+            self.uuid = int(uuid)\r
+        if hrn:\r
+            self.hrn = hrn\r
+            self.urn = hrn_to_urn(hrn, 'unknown')\r
+        if urn:\r
+            self.urn = urn\r
+            self.hrn, type = urn_to_hrn(urn)\r
+\r
+    def set_uuid(self, uuid):\r
+        if isinstance(uuid, str):\r
+            self.uuid = int(uuid)\r
+        else:\r
+            self.uuid = uuid\r
+\r
+    def get_uuid(self):\r
+        if not self.uuid:\r
+            self.decode()\r
+        return self.uuid\r
+\r
+    def set_hrn(self, hrn):\r
+        self.hrn = hrn\r
+\r
+    def get_hrn(self):\r
+        if not self.hrn:\r
+            self.decode()\r
+        return self.hrn\r
+\r
+    def set_urn(self, urn):\r
+        self.urn = urn\r
+        self.hrn, type = urn_to_hrn(urn)\r
\r
+    def get_urn(self):\r
+        if not self.urn:\r
+            self.decode()\r
+        return self.urn            \r
+\r
+    def get_type(self):\r
+        if not self.urn:\r
+            self.decode()\r
+        _, t = urn_to_hrn(self.urn)\r
+        return t\r
+    \r
+    ##\r
+    # Encode the GID fields and package them into the subject-alt-name field\r
+    # of the X509 certificate. This must be called prior to signing the\r
+    # certificate. It may only be called once per certificate.\r
+\r
+    def encode(self):\r
+        if self.urn:\r
+            urn = self.urn\r
+        else:\r
+            urn = hrn_to_urn(self.hrn, None)\r
+            \r
+        str = "URI:" + urn\r
+\r
+        if self.uuid:\r
+            str += ", " + "URI:" + uuid.UUID(int=self.uuid).urn\r
+        \r
+        self.set_data(str, 'subjectAltName')\r
+\r
+        \r
+\r
+\r
+    ##\r
+    # Decode the subject-alt-name field of the X509 certificate into the\r
+    # fields of the GID. This is automatically called by the various get_*()\r
+    # functions in this class.\r
+\r
+    def decode(self):\r
+        data = self.get_data('subjectAltName')\r
+        dict = {}\r
+        if data:\r
+            if data.lower().startswith('uri:http://<params>'):\r
+                dict = xmlrpclib.loads(data[11:])[0][0]\r
+            else:\r
+                spl = data.split(', ')\r
+                for val in spl:\r
+                    if val.lower().startswith('uri:urn:uuid:'):\r
+                        dict['uuid'] = uuid.UUID(val[4:]).int\r
+                    elif val.lower().startswith('uri:urn:publicid:idn+'):\r
+                        dict['urn'] = val[4:]\r
+                    \r
+        self.uuid = dict.get("uuid", None)\r
+        self.urn = dict.get("urn", None)\r
+        self.hrn = dict.get("hrn", None)    \r
+        if self.urn:\r
+            self.hrn = urn_to_hrn(self.urn)[0]\r
+\r
+    ##\r
+    # Dump the credential to stdout.\r
+    #\r
+    # @param indent specifies a number of spaces to indent the output\r
+    # @param dump_parents If true, also dump the parents of the GID\r
+\r
+    def dump(self, *args, **kwargs):\r
+        print self.dump_string(*args,**kwargs)\r
+\r
+    def dump_string(self, indent=0, dump_parents=False):\r
+        result=" "*(indent-2) + "GID\n"\r
+        result += " "*indent + "hrn:" + str(self.get_hrn()) +"\n"\r
+        result += " "*indent + "urn:" + str(self.get_urn()) +"\n"\r
+        result += " "*indent + "uuid:" + str(self.get_uuid()) + "\n"\r
+        filename=self.get_filename()\r
+        if filename: result += "Filename %s\n"%filename\r
+\r
+        if self.parent and dump_parents:\r
+            result += " "*indent + "parent:\n"\r
+            result += self.parent.dump_string(indent+4, dump_parents)\r
+        return result\r
+\r
+    ##\r
+    # Verify the chain of authenticity of the GID. First perform the checks\r
+    # of the certificate class (verifying that each parent signs the child,\r
+    # etc). In addition, GIDs also confirm that the parent's HRN is a prefix\r
+    # of the child's HRN, and the parent is of type 'authority'.\r
+    #\r
+    # Verifying these prefixes prevents a rogue authority from signing a GID\r
+    # for a principal that is not a member of that authority. For example,\r
+    # planetlab.us.arizona cannot sign a GID for planetlab.us.princeton.foo.\r
+\r
+    def verify_chain(self, trusted_certs = None):\r
+        # do the normal certificate verification stuff\r
+        trusted_root = Certificate.verify_chain(self, trusted_certs)        \r
+       \r
+        if self.parent:\r
+            # make sure the parent's hrn is a prefix of the child's hrn\r
+            if not hrn_authfor_hrn(self.parent.get_hrn(), self.get_hrn()):\r
+                raise GidParentHrn("This cert HRN %s isn't in the namespace for parent HRN %s" % (self.get_hrn(), self.parent.get_hrn()))\r
+\r
+            # Parent must also be an authority (of some type) to sign a GID\r
+            # There are multiple types of authority - accept them all here\r
+            if not self.parent.get_type().find('authority') == 0:\r
+                raise GidInvalidParentHrn("This cert %s's parent %s is not an authority (is a %s)" % (self.get_hrn(), self.parent.get_hrn(), self.parent.get_type()))\r
+\r
+            # Then recurse up the chain - ensure the parent is a trusted\r
+            # root or is in the namespace of a trusted root\r
+            self.parent.verify_chain(trusted_certs)\r
+        else:\r
+            # make sure that the trusted root's hrn is a prefix of the child's\r
+            trusted_gid = GID(string=trusted_root.save_to_string())\r
+            trusted_type = trusted_gid.get_type()\r
+            trusted_hrn = trusted_gid.get_hrn()\r
+            #if trusted_type == 'authority':\r
+            #    trusted_hrn = trusted_hrn[:trusted_hrn.rindex('.')]\r
+            cur_hrn = self.get_hrn()\r
+            if not hrn_authfor_hrn(trusted_hrn, cur_hrn):\r
+                raise GidParentHrn("Trusted root with HRN %s isn't a namespace authority for this cert %s" % (trusted_hrn, cur_hrn))\r
+\r
+            # There are multiple types of authority - accept them all here\r
+            if not trusted_type.find('authority') == 0:\r
+                raise GidInvalidParentHrn("This cert %s's trusted root signer %s is not an authority (is a %s)" % (self.get_hrn(), trusted_hrn, trusted_type))\r
+\r
+        return\r
index f28329c..6323436 100644 (file)
@@ -229,15 +229,28 @@ class Hierarchy:
     # @param uuid the unique identifier to store in the GID
     # @param pkey the public key to store in the GID
 
-    def create_gid(self, xrn, uuid, pkey):
+    def create_gid(self, xrn, uuid, pkey, CA=False):
         hrn, type = urn_to_hrn(xrn)
+        parent_hrn = get_authority(hrn)
         # Using hrn_to_urn() here to make sure the urn is in the right format
         # If xrn was a hrn instead of a urn, then the gid's urn will be
         # of type None 
         urn = hrn_to_urn(hrn, type)
         gid = GID(subject=hrn, uuid=uuid, hrn=hrn, urn=urn)
 
-        parent_hrn = get_authority(hrn)
+        # is this a CA cert
+        if hrn == self.config.SFA_INTERFACE_HRN or not parent_hrn:
+            # root or sub authority  
+            gid.set_intermediate_ca(True)
+        elif type and 'authority' in type:
+            # authority type
+            gid.set_intermediate_ca(True)
+        elif CA:
+            gid.set_intermediate_ca(True)
+        else:
+            gid.set_intermediate_ca(False)
+
+        # set issuer
         if not parent_hrn or hrn == self.config.SFA_INTERFACE_HRN:
             # if there is no parent hrn, then it must be self-signed. this
             # is where we terminate the recursion
@@ -247,7 +260,6 @@ class Hierarchy:
             parent_auth_info = self.get_auth_info(parent_hrn)
             gid.set_issuer(parent_auth_info.get_pkey_object(), parent_auth_info.hrn)
             gid.set_parent(parent_auth_info.get_gid_object())
-            gid.set_intermediate_ca(True)
 
         gid.set_pubkey(pkey)
         gid.encode()
index c166400..75c4f32 100644 (file)
-import re
-
-from sfa.util.faults import *
-
-# for convenience and smoother translation - we should get rid of these functions eventually 
-def get_leaf(hrn): return Xrn(hrn).get_leaf()
-def get_authority(hrn): return Xrn(hrn).get_authority_hrn()
-def urn_to_hrn(urn): xrn=Xrn(urn); return (xrn.hrn, xrn.type)
-def hrn_to_urn(hrn,type): return Xrn(hrn, type=type).urn
-
-def urn_to_sliver_id(urn, slice_id, node_id, index=0):
-    urn = urn.replace('+slice+', '+sliver+')    
-    return ":".join([urn, str(slice_id), str(node_id), str(index)])
-
-class Xrn:
-
-    ########## basic tools on HRNs
-    # split a HRN-like string into pieces
-    # this is like split('.') except for escaped (backslashed) dots
-    # e.g. hrn_split ('a\.b.c.d') -> [ 'a\.b','c','d']
-    @staticmethod
-    def hrn_split(hrn):
-        return [ x.replace('--sep--','\\.') for x in hrn.replace('\\.','--sep--').split('.') ]
-
-    # e.g. hrn_leaf ('a\.b.c.d') -> 'd'
-    @staticmethod
-    def hrn_leaf(hrn): return Xrn.hrn_split(hrn)[-1]
-
-    # e.g. hrn_auth_list ('a\.b.c.d') -> ['a\.b', 'c']
-    @staticmethod
-    def hrn_auth_list(hrn): return Xrn.hrn_split(hrn)[0:-1]
-    
-    # e.g. hrn_auth ('a\.b.c.d') -> 'a\.b.c'
-    @staticmethod
-    def hrn_auth(hrn): return '.'.join(Xrn.hrn_auth_list(hrn))
-    
-    # e.g. escape ('a.b') -> 'a\.b'
-    @staticmethod
-    def escape(token): return re.sub(r'([^\\])\.', r'\1\.', token)
-    # e.g. unescape ('a\.b') -> 'a.b'
-    @staticmethod
-    def unescape(token): return token.replace('\\.','.')
-        
-    URN_PREFIX = "urn:publicid:IDN"
-
-    ########## basic tools on URNs
-    @staticmethod
-    def urn_full (urn):
-        if urn.startswith(Xrn.URN_PREFIX): return urn
-        else: return Xrn.URN_PREFIX+URN
-    @staticmethod
-    def urn_meaningful (urn):
-        if urn.startswith(Xrn.URN_PREFIX): return urn[len(Xrn.URN_PREFIX):]
-        else: return urn
-    @staticmethod
-    def urn_split (urn):
-        return Xrn.urn_meaningful(urn).split('+')
-
-    ####################
-    # the local fields that are kept consistent
-    # self.urn
-    # self.hrn
-    # self.type
-    # self.path
-    # provide either urn, or (hrn + type)
-    def __init__ (self, xrn, type=None):
-        if not xrn: xrn = ""
-        # user has specified xrn : guess if urn or hrn
-        if xrn.startswith(Xrn.URN_PREFIX):
-            self.hrn=None
-            self.urn=xrn
-            self.urn_to_hrn()
-        else:
-            self.urn=None
-            self.hrn=xrn
-            self.type=type
-            self.hrn_to_urn()
-# happens all the time ..
-#        if not type:
-#            debug_logger.debug("type-less Xrn's are not safe")
-
-    def get_urn(self): return self.urn
-    def get_hrn(self): return self.hrn
-    def get_type(self): return self.type
-    def get_hrn_type(self): return (self.hrn, self.type)
-
-    def _normalize(self):
-        if self.hrn is None: raise SfaAPIError, "Xrn._normalize"
-        if not hasattr(self,'leaf'): 
-            self.leaf=Xrn.hrn_split(self.hrn)[-1]
-        # self.authority keeps a list
-        if not hasattr(self,'authority'): 
-            self.authority=Xrn.hrn_auth_list(self.hrn)
-
-    def get_leaf(self):
-        self._normalize()
-        return self.leaf
-
-    def get_authority_hrn(self): 
-        self._normalize()
-        return '.'.join( self.authority )
-    
-    def get_authority_urn(self): 
-        self._normalize()
-        return ':'.join( [Xrn.unescape(x) for x in self.authority] )
-    
-    def urn_to_hrn(self):
-        """
-        compute tuple (hrn, type) from urn
-        """
-        
-#        if not self.urn or not self.urn.startswith(Xrn.URN_PREFIX):
-        if not self.urn.startswith(Xrn.URN_PREFIX):
-            raise SfaAPIError, "Xrn.urn_to_hrn"
-
-        parts = Xrn.urn_split(self.urn)
-        type=parts.pop(2)
-        # Remove the authority name (e.g. '.sa')
-        if type == 'authority':
-            name = parts.pop()
-            # Drop the sa. This is a bad hack, but its either this
-            # or completely change how record types are generated/stored   
-            if name != 'sa':
-                type = type + "+" + name
-
-        # convert parts (list) into hrn (str) by doing the following
-        # 1. remove blank parts
-        # 2. escape dots inside parts
-        # 3. replace ':' with '.' inside parts
-        # 3. join parts using '.' 
-        hrn = '.'.join([Xrn.escape(part).replace(':','.') for part in parts if part]) 
-
-        self.hrn=str(hrn)
-        self.type=str(type)
-    
-    def hrn_to_urn(self):
-        """
-        compute urn from (hrn, type)
-        """
-
-#        if not self.hrn or self.hrn.startswith(Xrn.URN_PREFIX):
-        if self.hrn.startswith(Xrn.URN_PREFIX):
-            raise SfaAPIError, "Xrn.hrn_to_urn, hrn=%s"%self.hrn
-
-        if self.type and self.type.startswith('authority'):
-            self.authority = Xrn.hrn_split(self.hrn)
-            type_parts = self.type.split("+")
-            self.type = type_parts[0]
-            name = 'sa'
-            if len(type_parts) > 1:
-                name = type_parts[1]
-        else:
-            self.authority = Xrn.hrn_auth_list(self.hrn)
-            name = Xrn.hrn_leaf(self.hrn)
-
-        authority_string = self.get_authority_urn()
-
-        if self.type == None:
-            urn = "+".join(['',authority_string,Xrn.unescape(name)])
-        else:
-            urn = "+".join(['',authority_string,self.type,Xrn.unescape(name)])
-        
-        self.urn = Xrn.URN_PREFIX + urn
-
-    def dump_string(self):
-        result="-------------------- XRN\n"
-        result += "URN=%s\n"%self.urn
-        result += "HRN=%s\n"%self.hrn
-        result += "TYPE=%s\n"%self.type
-        result += "LEAF=%s\n"%self.get_leaf()
-        result += "AUTH(hrn format)=%s\n"%self.get_authority_hrn()
-        result += "AUTH(urn format)=%s\n"%self.get_authority_urn()
-        return result
-        
+#----------------------------------------------------------------------\r
+# Copyright (c) 2008 Board of Trustees, Princeton University\r
+#\r
+# Permission is hereby granted, free of charge, to any person obtaining\r
+# a copy of this software and/or hardware specification (the "Work") to\r
+# deal in the Work without restriction, including without limitation the\r
+# rights to use, copy, modify, merge, publish, distribute, sublicense,\r
+# and/or sell copies of the Work, and to permit persons to whom the Work\r
+# is furnished to do so, subject to the following conditions:\r
+#\r
+# The above copyright notice and this permission notice shall be\r
+# included in all copies or substantial portions of the Work.\r
+#\r
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS \r
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF \r
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND \r
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT \r
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, \r
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, \r
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS \r
+# IN THE WORK.\r
+#----------------------------------------------------------------------\r
+\r
+import re\r
+\r
+from sfa.util.faults import *\r
+\r
+# for convenience and smoother translation - we should get rid of these functions eventually \r
+def get_leaf(hrn): return Xrn(hrn).get_leaf()\r
+def get_authority(hrn): return Xrn(hrn).get_authority_hrn()\r
+def urn_to_hrn(urn): xrn=Xrn(urn); return (xrn.hrn, xrn.type)\r
+def hrn_to_urn(hrn,type): return Xrn(hrn, type=type).urn\r
+def hrn_authfor_hrn(parenthrn, hrn): return Xrn.hrn_is_auth_for_hrn(parenthrn, hrn)\r
+\r
+def urn_to_sliver_id(urn, slice_id, node_id, index=0):\r
+    return ":".join([urn, slice_id, node_id, index])\r
+\r
+class Xrn:\r
+\r
+    ########## basic tools on HRNs\r
+    # split a HRN-like string into pieces\r
+    # this is like split('.') except for escaped (backslashed) dots\r
+    # e.g. hrn_split ('a\.b.c.d') -> [ 'a\.b','c','d']\r
+    @staticmethod\r
+    def hrn_split(hrn):\r
+        return [ x.replace('--sep--','\\.') for x in hrn.replace('\\.','--sep--').split('.') ]\r
+\r
+    # e.g. hrn_leaf ('a\.b.c.d') -> 'd'\r
+    @staticmethod\r
+    def hrn_leaf(hrn): return Xrn.hrn_split(hrn)[-1]\r
+\r
+    # e.g. hrn_auth_list ('a\.b.c.d') -> ['a\.b', 'c']\r
+    @staticmethod\r
+    def hrn_auth_list(hrn): return Xrn.hrn_split(hrn)[0:-1]\r
+    \r
+    # e.g. hrn_auth ('a\.b.c.d') -> 'a\.b.c'\r
+    @staticmethod\r
+    def hrn_auth(hrn): return '.'.join(Xrn.hrn_auth_list(hrn))\r
+    \r
+    # e.g. escape ('a.b') -> 'a\.b'\r
+    @staticmethod\r
+    def escape(token): return re.sub(r'([^\\])\.', r'\1\.', token)\r
+\r
+    # e.g. unescape ('a\.b') -> 'a.b'\r
+    @staticmethod\r
+    def unescape(token): return token.replace('\\.','.')\r
+\r
+    # Return the HRN authority chain from top to bottom.\r
+    # e.g. hrn_auth_chain('a\.b.c.d') -> ['a\.b', 'a\.b.c']\r
+    @staticmethod\r
+    def hrn_auth_chain(hrn):\r
+        parts = Xrn.hrn_auth_list(hrn)\r
+        chain = []\r
+        for i in range(len(parts)):\r
+            chain.append('.'.join(parts[:i+1]))\r
+        # Include the HRN itself?\r
+        #chain.append(hrn)\r
+        return chain\r
+\r
+    # Is the given HRN a true authority over the namespace of the other\r
+    # child HRN?\r
+    # A better alternative than childHRN.startswith(parentHRN)\r
+    # e.g. hrn_is_auth_for_hrn('a\.b', 'a\.b.c.d') -> True,\r
+    # but hrn_is_auth_for_hrn('a', 'a\.b.c.d') -> False\r
+    # Also hrn_is_uauth_for_hrn('a\.b.c.d', 'a\.b.c.d') -> True\r
+    @staticmethod\r
+    def hrn_is_auth_for_hrn(parenthrn, hrn):\r
+        if parenthrn == hrn:\r
+            return True\r
+        for auth in Xrn.hrn_auth_chain(hrn):\r
+            if parenthrn == auth:\r
+                return True\r
+        return False\r
+\r
+    URN_PREFIX = "urn:publicid:IDN"\r
+\r
+    ########## basic tools on URNs\r
+    @staticmethod\r
+    def urn_full (urn):\r
+        if urn.startswith(Xrn.URN_PREFIX): return urn\r
+        else: return Xrn.URN_PREFIX+URN\r
+    @staticmethod\r
+    def urn_meaningful (urn):\r
+        if urn.startswith(Xrn.URN_PREFIX): return urn[len(Xrn.URN_PREFIX):]\r
+        else: return urn\r
+    @staticmethod\r
+    def urn_split (urn):\r
+        return Xrn.urn_meaningful(urn).split('+')\r
+\r
+    ####################\r
+    # the local fields that are kept consistent\r
+    # self.urn\r
+    # self.hrn\r
+    # self.type\r
+    # self.path\r
+    # provide either urn, or (hrn + type)\r
+    def __init__ (self, xrn, type=None):\r
+        if not xrn: xrn = ""\r
+        # user has specified xrn : guess if urn or hrn\r
+        if xrn.startswith(Xrn.URN_PREFIX):\r
+            self.hrn=None\r
+            self.urn=xrn\r
+            self.urn_to_hrn()\r
+        else:\r
+            self.urn=None\r
+            self.hrn=xrn\r
+            self.type=type\r
+            self.hrn_to_urn()\r
+# happens all the time ..\r
+#        if not type:\r
+#            debug_logger.debug("type-less Xrn's are not safe")\r
+\r
+    def get_urn(self): return self.urn\r
+    def get_hrn(self): return self.hrn\r
+    def get_type(self): return self.type\r
+    def get_hrn_type(self): return (self.hrn, self.type)\r
+\r
+    def _normalize(self):\r
+        if self.hrn is None: raise SfaAPIError, "Xrn._normalize"\r
+        if not hasattr(self,'leaf'): \r
+            self.leaf=Xrn.hrn_split(self.hrn)[-1]\r
+        # self.authority keeps a list\r
+        if not hasattr(self,'authority'): \r
+            self.authority=Xrn.hrn_auth_list(self.hrn)\r
+\r
+    def get_leaf(self):\r
+        self._normalize()\r
+        return self.leaf\r
+\r
+    def get_authority_hrn(self): \r
+        self._normalize()\r
+        return '.'.join( self.authority )\r
+    \r
+    def get_authority_urn(self): \r
+        self._normalize()\r
+        return ':'.join( [Xrn.unescape(x) for x in self.authority] )\r
+    \r
+    def urn_to_hrn(self):\r
+        """\r
+        compute tuple (hrn, type) from urn\r
+        """\r
+        \r
+#        if not self.urn or not self.urn.startswith(Xrn.URN_PREFIX):\r
+        if not self.urn.startswith(Xrn.URN_PREFIX):\r
+            raise SfaAPIError, "Xrn.urn_to_hrn"\r
+\r
+        parts = Xrn.urn_split(self.urn)\r
+        type=parts.pop(2)\r
+        # Remove the authority name (e.g. '.sa')\r
+        if type == 'authority':\r
+            name = parts.pop()\r
+            # Drop the sa. This is a bad hack, but its either this\r
+            # or completely change how record types are generated/stored   \r
+            if name != 'sa':\r
+                type = type + "+" + name\r
+\r
+        # convert parts (list) into hrn (str) by doing the following\r
+        # 1. remove blank parts\r
+        # 2. escape dots inside parts\r
+        # 3. replace ':' with '.' inside parts\r
+        # 3. join parts using '.' \r
+        hrn = '.'.join([Xrn.escape(part).replace(':','.') for part in parts if part]) \r
+\r
+        self.hrn=str(hrn)\r
+        self.type=str(type)\r
+    \r
+    def hrn_to_urn(self):\r
+        """\r
+        compute urn from (hrn, type)\r
+        """\r
+\r
+#        if not self.hrn or self.hrn.startswith(Xrn.URN_PREFIX):\r
+        if self.hrn.startswith(Xrn.URN_PREFIX):\r
+            raise SfaAPIError, "Xrn.hrn_to_urn, hrn=%s"%self.hrn\r
+\r
+        if self.type and self.type.startswith('authority'):\r
+            self.authority = Xrn.hrn_split(self.hrn)\r
+            type_parts = self.type.split("+")\r
+            self.type = type_parts[0]\r
+            name = 'sa'\r
+            if len(type_parts) > 1:\r
+                name = type_parts[1]\r
+        else:\r
+            self.authority = Xrn.hrn_auth_list(self.hrn)\r
+            name = Xrn.hrn_leaf(self.hrn)\r
+\r
+        authority_string = self.get_authority_urn()\r
+\r
+        if self.type == None:\r
+            urn = "+".join(['',authority_string,Xrn.unescape(name)])\r
+        else:\r
+            urn = "+".join(['',authority_string,self.type,Xrn.unescape(name)])\r
+        \r
+        self.urn = Xrn.URN_PREFIX + urn\r
+\r
+    def dump_string(self):\r
+        result="-------------------- XRN\n"\r
+        result += "URN=%s\n"%self.urn\r
+        result += "HRN=%s\n"%self.hrn\r
+        result += "TYPE=%s\n"%self.type\r
+        result += "LEAF=%s\n"%self.get_leaf()\r
+        result += "AUTH(hrn format)=%s\n"%self.get_authority_hrn()\r
+        result += "AUTH(urn format)=%s\n"%self.get_authority_urn()\r
+        return result\r
+        \r