From: Thierry Parmentelat Date: Thu, 3 Nov 2011 14:28:50 +0000 (+0100) Subject: Merge branch 'upstreammaster' X-Git-Tag: sfa-2.1-24~38 X-Git-Url: http://git.onelab.eu/?p=sfa.git;a=commitdiff_plain;h=e39e728991b762ae0b52b52b06655f0c7f1b7421;hp=18317236953434dfa9ea41328fc21731039e55f1 Merge branch 'upstreammaster' --- diff --git a/INSTALL.txt b/INSTALL.txt index d748883f..cc59cebd 100644 --- a/INSTALL.txt +++ b/INSTALL.txt @@ -63,9 +63,9 @@ This will initialize /etc/sfa/authorities/server.key from /etc/sfa/authorities/p This will start Registry, Slice Manager and Aggregate Manager. Your ps command output would look like: # ps -ef | grep python -root 24944 1 0 May11 ? 00:00:00 /usr/bin/python /usr/bin/sfa-server.py -r -d -root 24957 1 0 May11 ? 00:00:00 /usr/bin/python /usr/bin/sfa-server.py -a -d -root 24970 1 0 May11 ? 00:00:00 /usr/bin/python /usr/bin/sfa-server.py -s -d +root 24944 1 0 May11 ? 00:00:00 /usr/bin/python /usr/bin/sfa-start.py -r -d +root 24957 1 0 May11 ? 00:00:00 /usr/bin/python /usr/bin/sfa-start.py -a -d +root 24970 1 0 May11 ? 00:00:00 /usr/bin/python /usr/bin/sfa-start.py -s -d ------- 4) Configure SFA client: diff --git a/Makefile b/Makefile index f162f5f6..04ed4bae 100644 --- a/Makefile +++ b/Makefile @@ -87,7 +87,7 @@ force: ########## tags: - find . -type f | egrep -v '/\.git/|/\.svn/|TAGS|\.py[co]$$|\.doc$$|\.html$$|\.pdf$$|~$$|\.png$$|\.svg$$|\.out$$|\.bak$$|\.xml$$' | xargs etags + find . -type f | egrep -v '/\.git/|/\.svn/|TAGS|~$$|\.(py[co]|doc|html|pdf|png|svg|out|bak|xml|dg)$$' | xargs etags .PHONY: tags signatures: @@ -125,7 +125,7 @@ sfiAddAttribute.py sfiAddSliver.py sfiDeleteAttribute.py sfiDeleteSliver.py sfiL sfiListSlivers.py sfadump.py BINS = ./config/sfa-config-tty ./config/gen-sfa-cm-config.py \ - ./sfa/plc/sfa-import-plc.py ./sfa/plc/sfa-nuke-plc.py ./sfa/server/sfa-server.py \ + ./sfa/plc/sfa-import-plc.py ./sfa/plc/sfa-nuke-plc.py ./sfa/server/sfa-start.py \ $(foreach client,$(CLIENTS),./sfa/client/$(client)) sync: @@ -137,7 +137,9 @@ ifeq (,$(SSHURL)) else +$(RSYNC) ./sfa/ $(SSHURL)/usr/lib\*/python2.\*/site-packages/sfa/ +$(RSYNC) ./tests/ $(SSHURL)/root/tests-sfa - +$(RSYNC) $(BINS) $(SSHURL)/usr/bin + +$(RSYNC) $(BINS) $(SSHURL)/usr/bin/ + +$(RSYNC) ./sfa/init.d/sfa $(SSHURL)/etc/init.d/ + +$(RSYNC) ./config/default_config.xml $(SSHURL)/etc/sfa/ $(SSHCOMMAND) exec service sfa restart endif diff --git a/config/default_config.xml b/config/default_config.xml index 212dee42..670d6f26 100644 --- a/config/default_config.xml +++ b/config/default_config.xml @@ -18,6 +18,15 @@ Thierry Parmentelat Basic system variables. + + Generic Flavour + pl + This string refers to a class located in sfa.generic that describes + which specific implementation needs to be used for api, manager and driver objects. + PlanetLab users do not need to change this setting. + + + Human readable name plc @@ -49,9 +58,10 @@ Thierry Parmentelat it look like the user is the one performing the operation. Doing this requires a valid key pair and credential for the user. This option defines the path where key pairs and credentials are generated and stored. - This functionality is used by the SFA web gui + This functionality is used by the SFA web GUI. + diff --git a/config/topology b/config/topology new file mode 100644 index 00000000..24a8e13a --- /dev/null +++ b/config/topology @@ -0,0 +1,20 @@ +# Links in the physical topology, gleaned from looking at the Internet2 +# topology map. Link (a, b) connects sites with IDs a and b. +# +# 2 12 # I2 Princeton - New York +# 11 13 # I2 Chicago - Wash DC +# 11 15 # I2 Chicago - Atlanta +# 11 16 # I2 Chicago - CESNET +# 11 17 # I2 Chicago - Kansas City +# 12 13 # I2 New York - Wash DC +# 13 15 # I2 Wash DC - Atlanta +# 14 15 # Ga Tech - I2 Atlanta +# 15 19 # I2 Atlanta - Houston +# 17 19 # I2 Kansas City - Houston +# 17 22 # I2 Kansas City - Salt Lake City +# 17 24 # I2 Kansas City - UMKC +# 19 20 # I2 Houston - Los Angeles +# 20 21 # I2 Los Angeles - Seattle +# 20 22 # I2 Los Angeles - Salt Lake City +# 21 22 # I2 Seattle - Salt Lake City + diff --git a/docs/Makefile b/docs/Makefile index 463dbaf6..5f34949f 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ doc: pythondoc.py ../sfa/util/certificate.py ../sfa/util/credential.py ../sfa/util/gid.py \ ../sfa/util/rights.py ../sfa/util/config.py ../sfa/trust/hierarchy.py \ ../sfa/util/record.py ../sfa/util/client.py \ - ../sfa/util/server.py + ../sfa/server/sfaserver.py pythondoc.py ../sfa/registry/registry.py ../sfa/registry/import.py \ ../sfa/registry/nuke.py diff --git a/setup.py b/setup.py index b10be90b..c28e38da 100755 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ bins = [ 'sfa/plc/sfa-import-plc.py', 'sfa/plc/sfa-nuke-plc.py', 'sfa/server/sfa-ca.py', - 'sfa/server/sfa-server.py', + 'sfa/server/sfa-start.py', 'sfa/server/sfa-clean-peer-records.py', 'sfa/server/sfa_component_setup.py', 'sfa/client/sfi.py', @@ -36,14 +36,15 @@ bins = [ package_dirs = [ 'sfa', - 'sfa/client', - 'sfa/methods', - 'sfa/plc', - 'sfa/server', 'sfa/trust', 'sfa/util', + 'sfa/client', + 'sfa/server', + 'sfa/methods', + 'sfa/generic', 'sfa/managers', 'sfa/managers/vini', + 'sfa/plc', 'sfa/rspecs', 'sfa/rspecs/elements', 'sfa/rspecs/elements/versions', @@ -59,6 +60,7 @@ data_files = [('/etc/sfa/', [ 'config/aggregates.xml', 'config/registries.xml', 'config/default_config.xml', 'config/sfi_config', + 'config/topology', 'sfa/managers/pl/pl.rng', 'sfa/trust/credential.xsd', 'sfa/trust/top.xsd', diff --git a/sfa.spec b/sfa.spec index 4f10df47..d2979e44 100644 --- a/sfa.spec +++ b/sfa.spec @@ -1,6 +1,6 @@ %define name sfa -%define version 1.0 -%define taglevel 36 +%define version 1.1 +%define taglevel 1 %define release %{taglevel}%{?pldistro:.%{pldistro}}%{?date:.%{date}} %global python_sitearch %( python -c "from distutils.sysconfig import get_python_lib; print get_python_lib(1)" ) @@ -121,7 +121,7 @@ rm -rf $RPM_BUILD_ROOT %files # sfa and sfatables depend each other. -%{_bindir}/sfa-server.py* +%{_bindir}/sfa-start.py* /etc/sfatables/* %{python_sitelib}/* %{_bindir}/keyconvert.py* @@ -144,6 +144,7 @@ rm -rf $RPM_BUILD_ROOT /etc/sfa/sig.xsd /etc/sfa/xml.xsd /etc/sfa/protogeni-rspec-common.xsd +/etc/sfa/topology %{_bindir}/sfa-config-tty %{_bindir}/sfa-import-plc.py* %{_bindir}/sfa-clean-peer-records.py* @@ -196,6 +197,11 @@ fi [ "$1" -ge "1" ] && service sfa-cm restart || : %changelog +* Fri Oct 28 2011 Thierry Parmentelat - sfa-1.1-1 +- first support for protogeni rspecs is working +- vini no longer needs a specific manager +- refactoring underway towards more flexible/generic architecture + * Thu Sep 15 2011 Tony Mack - sfa-1.0-36 - Unicode-friendliness for user names with accents/special chars. - Fix bug that could cause create the client to fail when calling CreateSliver for a slice that has the same hrn as a user. diff --git a/sfa/client/sfi.py b/sfa/client/sfi.py index 83a66f9e..49392466 100755 --- a/sfa/client/sfi.py +++ b/sfa/client/sfi.py @@ -18,7 +18,7 @@ from sfa.util.sfalogging import sfi_logger from sfa.trust.certificate import Keypair, Certificate from sfa.trust.gid import GID from sfa.trust.credential import Credential -from sfa.util.sfaticket import SfaTicket +from sfa.trust.sfaticket import SfaTicket from sfa.util.record import SfaRecord, UserRecord, SliceRecord, NodeRecord, AuthorityRecord from sfa.rspecs.rspec import RSpec from sfa.rspecs.rspec_converter import RSpecConverter @@ -232,9 +232,9 @@ class Sfi: parser.add_option("-d", "--delegate", dest="delegate", default=None, action="store_true", help="Include a credential delegated to the user's root"+\ - "authority in set of credentials for this call") - - # registy filter option + "authority in set of credentials for this call") + + # registy filter option if command in ("list", "show", "remove"): parser.add_option("-t", "--type", dest="type", type="choice", help="type filter ([all]|user|slice|authority|node|aggregate)", @@ -521,7 +521,7 @@ class Sfi: if args: hrn = args[0] gid = self._get_gid(hrn) - self.logger.debug("Sfi.get_gid-> %s",gid.save_to_string(save_parents=True)) + self.logger.debug("Sfi.get_gid-> %s" % gid.save_to_string(save_parents=True)) return gid def _get_gid(self, hrn=None, type=None): @@ -985,10 +985,15 @@ class Sfi: slice_urn = hrn_to_urn(slice_hrn, 'slice') user_cred = self.get_user_cred() slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True) - # delegate the cred to the callers root authority - delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority)+'.slicemanager') - #delegated_cred = self.delegate_cred(slice_cred, get_authority(slice_hrn)) - #creds.append(delegated_cred) + + if hasattr(opts, 'aggregate') and opts.aggregate: + delegated_cred = None + else: + # delegate the cred to the callers root authority + delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority)+'.slicemanager') + #delegated_cred = self.delegate_cred(slice_cred, get_authority(slice_hrn)) + #creds.append(delegated_cred) + rspec_file = self.get_rspec_file(args[1]) rspec = open(rspec_file).read() @@ -1013,11 +1018,13 @@ class Sfi: creds = [slice_cred] else: users = sfa_users_arg(user_records, slice_record) - creds = [slice_cred, delegated_cred] + creds = [slice_cred] + if delegated_cred: + creds.append(delegated_cred) call_args = [slice_urn, creds, rspec, users] if self.server_supports_call_id_arg(server): call_args.append(unique_call_id()) - + result = server.CreateSliver(*call_args) if opts.file is None: print result diff --git a/sfa/client/sfiAddAttribute.py b/sfa/client/sfiAddAttribute.py index 9c2eae5c..f22e63e8 100755 --- a/sfa/client/sfiAddAttribute.py +++ b/sfa/client/sfiAddAttribute.py @@ -1,7 +1,6 @@ #! /usr/bin/env python import sys -from sfa.util.rspecHelper import RSpec, Commands from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec diff --git a/sfa/client/sfiAddLinks.py b/sfa/client/sfiAddLinks.py new file mode 100755 index 00000000..f5b28888 --- /dev/null +++ b/sfa/client/sfiAddLinks.py @@ -0,0 +1,45 @@ +#! /usr/bin/env python + +import sys +from sfa.client.sfi_commands import Commands +from sfa.rspecs.rspec import RSpec +from sfa.rspecs.version_manager import VersionManager + +command = Commands(usage="%prog [options] node1 node2...", + description="Add links to the RSpec. " + + "This command reads in an RSpec and outputs a modified " + + "RSpec. Use this to add links to your slivers") +command.add_linkfile_option() +command.prep() + +if not command.opts.linkfile: + print "Missing link list -- exiting" + command.parser.print_help() + sys.exit(1) + +if command.opts.infile: + infile=file(command.opts.infile) +else: + infile=sys.stdin +if command.opts.outfile: + outfile=file(command.opts.outfile,"w") +else: + outfile=sys.stdout +ad_rspec = RSpec(infile) +links = file(command.opts.linkfile).read().split('\n') +link_tuples = map(lambda x: tuple(x.split()), links) + +version_manager = VersionManager() +try: + type = ad_rspec.version.type + version_num = ad_rspec.version.version + request_version = version_manager._get_version(type, version_num, 'request') + request_rspec = RSpec(version=request_version) + request_rspec.version.merge(ad_rspec) + request_rspec.version.add_link_requests(link_tuples) +except: + print >> sys.stderr, "FAILED: %s" % links + raise + sys.exit(1) +print >>outfile, request_rspec.toxml() +sys.exit(0) diff --git a/sfa/client/sfiListLinks.py b/sfa/client/sfiListLinks.py new file mode 100755 index 00000000..a4720ca1 --- /dev/null +++ b/sfa/client/sfiListLinks.py @@ -0,0 +1,26 @@ +#! /usr/bin/env python + +import sys +from sfa.client.sfi_commands import Commands +from sfa.rspecs.rspec import RSpec +from sfa.util.xrn import Xrn + +command = Commands(usage="%prog [options]", + description="List all links in the RSpec. " + + "Use this to display the list of available links. " ) +command.prep() + +if command.opts.infile: + rspec = RSpec(command.opts.infile) + links = rspec.version.get_links() + if command.opts.outfile: + sys.stdout = open(command.opts.outfile, 'w') + + for link in links: + ifname1 = Xrn(link['interface1']['component_id']).get_leaf() + ifname2 = Xrn(link['interface2']['component_id']).get_leaf() + print "%s %s" % (ifname1, ifname2) + + + + diff --git a/sfa/client/sfi_commands.py b/sfa/client/sfi_commands.py index bdcc16d4..80897cd2 100755 --- a/sfa/client/sfi_commands.py +++ b/sfa/client/sfi_commands.py @@ -12,6 +12,7 @@ class Commands: self.parser.add_option("-o", "", dest="outfile", metavar="FILE", help="write output to FILE (default is stdout)") self.nodefile = False + self.linkfile = False self.attributes = {} def add_nodefile_option(self): @@ -20,6 +21,12 @@ class Commands: metavar="FILE", help="read node list from FILE"), + def add_linkfile_option(self): + self.linkfile = True + self.parser.add_option("-l", "", dest="linkfile", + metavar="FILE", + help="read link list from FILE") + def add_show_attributes_option(self): self.parser.add_option("-s", "--show-attributes", action="store_true", dest="showatt", default=False, diff --git a/sfa/generic/__init__.py b/sfa/generic/__init__.py new file mode 100644 index 00000000..843cd7b0 --- /dev/null +++ b/sfa/generic/__init__.py @@ -0,0 +1,104 @@ +from sfa.util.sfalogging import logger +from sfa.util.config import Config + +from sfa.managers.managerwrapper import ManagerWrapper + +# a bundle is the combination of +# (*) an api that reacts on the incoming requests to trigger the API methods +# (*) a manager that implements the function of the service, +# either aggregate, registry, or slicemgr +# (*) a driver that controls the underlying testbed +# +# +# The Generic class is a utility that uses the configuration to figure out +# which combination of these pieces need to be put together +# from config. +# this extra indirection is needed to adapt to the current naming scheme +# where we have 'pl' and 'plc' and components and the like, that does not +# yet follow a sensible scheme + +# needs refinements to cache more efficiently, esp. wrt the config + +class Generic: + + def __init__ (self, flavour, config): + self.flavour=flavour + self.config=config + + # proof of concept + # example flavour='pl' -> sfa.generic.pl.pl() + @staticmethod + def the_flavour (flavour=None, config=None): + if config is None: config=Config() + if flavour is None: flavour=config.SFA_GENERIC_FLAVOUR + flavour = flavour.lower() + #mixed = flavour.capitalize() + module_path="sfa.generic.%s"%flavour + classname="%s"%flavour + logger.info("Generic.the_flavour with flavour=%s"%flavour) + try: + module = __import__ (module_path, globals(), locals(), [classname]) + return getattr(module, classname)(flavour,config) + except: + logger.log_exc("Cannot locate generic instance with flavour=%s"%flavour) + + # in the simplest case these can be redefined to the class/module objects to be used + # see pl.py for an example + # some descendant of SfaApi + def api_class (self) : pass + # in practical terms these are modules for now + def registry_class (self) : pass + def slicemgr_class (self) : pass + def aggregate_class (self) : pass + def component_class (self) : pass + + + # build an API object + # insert a manager instance + def make_api (self, *args, **kwargs): + # interface is a required arg + if not 'interface' in kwargs: + logger.fatal("Generic.make_api: no interface found") + api = self.api_class()(*args, **kwargs) + interface=kwargs['interface'] + # or simpler, interface=api.interface + manager = self.make_manager(interface) + api.manager = ManagerWrapper(manager,interface) + return api + + def make_manager (self, interface): + """ + interface expected in ['registry', 'aggregate', 'slice', 'component'] + flavour is e.g. 'pl' or 'max' or whatever + """ + flavour = self.flavour + message="Generic.make_manager for interface=%s and flavour=%s"%(interface,flavour) + + classname = "%s_class"%interface + try: + module = getattr(self,classname)() + logger.info("%s : %s"%(message,module)) + return module + except: + logger.log_exc(message) + logger.fatal("Aborting") + +# former logic was +# basepath = 'sfa.managers' +# qualified = "%s.%s_manager_%s"%(basepath,interface,flavour) +# generic = "%s.%s_manager"%(basepath,interface) +# +# try: +# manager = __import__(qualified, fromlist=[basepath]) +# logger.info ("%s: loaded %s"%(message,qualified)) +# except: +# try: +# manager = __import__ (generic, fromlist=[basepath]) +# if flavour != 'pl' : +# logger.warn ("%s: using generic with flavour!='pl'"%(message)) +# logger.info("%s: loaded %s"%(message,generic)) +# except: +# logger.log_exc("%s: unable to import either %s or %s"%(message,qualified,generic)) +# logger.fatal("Aborted") +# return manager + diff --git a/sfa/generic/pl.py b/sfa/generic/pl.py new file mode 100644 index 00000000..853053d1 --- /dev/null +++ b/sfa/generic/pl.py @@ -0,0 +1,19 @@ +from sfa.generic import Generic +import sfa.plc.plcsfaapi +import sfa.managers.registry_manager +import sfa.managers.slice_manager +import sfa.managers.aggregate_manager + +class pl (Generic): + + def api_class (self): + return sfa.plc.plcsfaapi.PlcSfaApi + + def registry_class (self) : + return sfa.managers.registry_manager + def slicemgr_class (self) : + return sfa.managers.slice_manager + def aggregate_class (self) : + return sfa.managers.aggregate_manager + + diff --git a/sfa/generic/plcm.py b/sfa/generic/plcm.py new file mode 100644 index 00000000..dd24d3c2 --- /dev/null +++ b/sfa/generic/plcm.py @@ -0,0 +1,11 @@ +from sfa.generic.pl import pl +import sfa.plc.plccomponentapi +import sfa.managers.component_manager_pl + +class plcm (pl): + + def api_class (self): + return sfa.plc.plccomponentapi.PlcComponentApi + + def component_class (self): + return sfa.managers.component_manager_pl diff --git a/sfa/init.d/sfa b/sfa/init.d/sfa index e2fdb108..08975dc7 100755 --- a/sfa/init.d/sfa +++ b/sfa/init.d/sfa @@ -62,18 +62,18 @@ start() { reload # install peer certs - action $"SFA installing peer certs" daemon /usr/bin/sfa-server.py -t -d $OPTIONS + action $"SFA installing peer certs" daemon /usr/bin/sfa-start.py -t -d $OPTIONS if [ "$SFA_REGISTRY_ENABLED" -eq 1 ]; then - action $"SFA Registry" daemon /usr/bin/sfa-server.py -r -d $OPTIONS + action $"SFA Registry" daemon /usr/bin/sfa-start.py -r -d $OPTIONS fi if [ "$SFA_AGGREGATE_ENABLED" -eq 1 ]; then - action $"SFA Aggregate" daemon /usr/bin/sfa-server.py -a -d $OPTIONS + action $"SFA Aggregate" daemon /usr/bin/sfa-start.py -a -d $OPTIONS fi if [ "$SFA_SM_ENABLED" -eq 1 ]; then - action "SFA SliceMgr" daemon /usr/bin/sfa-server.py -s -d $OPTIONS + action "SFA SliceMgr" daemon /usr/bin/sfa-start.py -s -d $OPTIONS fi if [ "$SFA_FLASHPOLICY_ENABLED" -eq 1 ]; then @@ -81,15 +81,15 @@ start() { fi RETVAL=$? - [ $RETVAL -eq 0 ] && touch /var/lock/subsys/sfa-server.py + [ $RETVAL -eq 0 ] && touch /var/lock/subsys/sfa-start.py } stop() { - action $"Shutting down SFA" killproc sfa-server.py + action $"Shutting down SFA" killproc sfa-start.py RETVAL=$? - [ $RETVAL -eq 0 ] && rm -f /var/lock/subsys/sfa-server.py + [ $RETVAL -eq 0 ] && rm -f /var/lock/subsys/sfa-start.py } @@ -99,13 +99,13 @@ case "$1" in reload) reload force ;; restart) stop; start ;; condrestart) - if [ -f /var/lock/subsys/sfa-server.py ]; then + if [ -f /var/lock/subsys/sfa-start.py ]; then stop start fi ;; status) - status sfa-server.py + status sfa-start.py RETVAL=$? ;; *) diff --git a/sfa/init.d/sfa-cm b/sfa/init.d/sfa-cm index eea507cf..cdddf8b1 100755 --- a/sfa/init.d/sfa-cm +++ b/sfa/init.d/sfa-cm @@ -27,7 +27,7 @@ start() { echo "Component Mgr" # make sure server key (nodes private key) exists first init_key - /usr/bin/sfa-server.py -c -d $OPTIONS + /usr/bin/sfa-start.py -c -d $OPTIONS fi RETVAL=$? @@ -38,7 +38,7 @@ start() { stop() { echo -n $"Shutting down SFA: " - killproc sfa-server.py + killproc sfa-start.py RETVAL=$? echo diff --git a/sfa/managers/aggregate_manager.py b/sfa/managers/aggregate_manager.py index 77af0723..702fe7d8 100644 --- a/sfa/managers/aggregate_manager.py +++ b/sfa/managers/aggregate_manager.py @@ -1,31 +1,24 @@ import datetime import time -import traceback import sys -import re -from types import StringTypes -from sfa.util.faults import * +from sfa.util.faults import RecordNotFound, SliverDoesNotExist from sfa.util.xrn import get_authority, hrn_to_urn, urn_to_hrn, Xrn, urn_to_sliver_id -from sfa.util.plxrn import slicename_to_hrn, hrn_to_pl_slicename, hostname_to_urn -from sfa.util.specdict import * -from sfa.util.record import SfaRecord -from sfa.util.policy import Policy -from sfa.util.record import * -from sfa.util.sfaticket import SfaTicket -from sfa.plc.slices import Slices -from sfa.trust.credential import Credential -import sfa.plc.peers as peers -from sfa.plc.network import * -from sfa.plc.api import SfaAPI -from sfa.plc.aggregate import Aggregate -from sfa.plc.slices import * +from sfa.util.plxrn import slicename_to_hrn, hrn_to_pl_slicename from sfa.util.version import version_core -from sfa.rspecs.version_manager import VersionManager -from sfa.rspecs.rspec import RSpec from sfa.util.sfatime import utcparse from sfa.util.callids import Callids +from sfa.trust.sfaticket import SfaTicket +from sfa.trust.credential import Credential +from sfa.rspecs.version_manager import VersionManager +from sfa.rspecs.rspec import RSpec + +import sfa.plc.peers as peers +from sfa.plc.plcsfaapi import PlcSfaApi +from sfa.plc.aggregate import Aggregate +from sfa.plc.slices import Slices + def GetVersion(api): version_manager = VersionManager() @@ -51,7 +44,7 @@ def __get_registry_objects(slice_xrn, creds, users): """ """ - hrn, type = urn_to_hrn(slice_xrn) + hrn, _ = urn_to_hrn(slice_xrn) hrn_auth = get_authority(hrn) @@ -115,7 +108,7 @@ def __get_hostnames(nodes): def SliverStatus(api, slice_xrn, creds, call_id): if Callids().already_handled(call_id): return {} - (hrn, type) = urn_to_hrn(slice_xrn) + (hrn, _) = urn_to_hrn(slice_xrn) # find out where this slice is currently running slicename = hrn_to_pl_slicename(hrn) @@ -128,8 +121,6 @@ def SliverStatus(api, slice_xrn, creds, call_id): nodes = api.plshell.GetNodes(api.plauth, {'node_id':slice['node_ids'],'peer_id':None}, ['node_id', 'hostname', 'site_id', 'boot_state', 'last_contact']) site_ids = [node['site_id'] for node in nodes] - sites = api.plshell.GetSites(api.plauth, site_ids, ['site_id', 'login_base']) - sites_dict = dict ( [ (site['site_id'],site['login_base'] ) for site in sites ] ) result = {} top_level_status = 'unknown' @@ -154,7 +145,7 @@ def SliverStatus(api, slice_xrn, creds, call_id): res['geni_status'] = 'ready' else: res['geni_status'] = 'failed' - top_level_staus = 'failed' + top_level_status = 'failed' res['geni_error'] = '' @@ -173,7 +164,7 @@ def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id): aggregate = Aggregate(api) slices = Slices(api) - (hrn, type) = urn_to_hrn(slice_xrn) + (hrn, _) = urn_to_hrn(slice_xrn) peer = slices.get_peer(hrn) sfa_peer = slices.get_sfa_peer(hrn) slice_record=None @@ -197,7 +188,11 @@ def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id): requested_slivers = [str(host) for host in rspec.version.get_nodes_with_slivers()] slices.verify_slice_nodes(slice, requested_slivers, peer) - # hanlde MyPLC peer association. + aggregate.prepare_nodes({'hostname': requested_slivers}) + aggregate.prepare_interfaces({'node_id': aggregate.nodes.keys()}) + slices.verify_slice_links(slice, rspec.version.get_link_requests(), aggregate) + + # handle MyPLC peer association. # only used by plc and ple. slices.handle_peer(site, slice, persons, peer) @@ -206,7 +201,7 @@ def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id): def RenewSliver(api, xrn, creds, expiration_time, call_id): if Callids().already_handled(call_id): return True - (hrn, type) = urn_to_hrn(xrn) + (hrn, _) = urn_to_hrn(xrn) slicename = hrn_to_pl_slicename(hrn) slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id']) if not slices: @@ -221,7 +216,7 @@ def RenewSliver(api, xrn, creds, expiration_time, call_id): return False def start_slice(api, xrn, creds): - hrn, type = urn_to_hrn(xrn) + (hrn, _) = urn_to_hrn(xrn) slicename = hrn_to_pl_slicename(hrn) slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id']) if not slices: @@ -235,7 +230,7 @@ def start_slice(api, xrn, creds): return 1 def stop_slice(api, xrn, creds): - hrn, type = urn_to_hrn(xrn) + hrn, _ = urn_to_hrn(xrn) slicename = hrn_to_pl_slicename(hrn) slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id']) if not slices: @@ -245,7 +240,7 @@ def stop_slice(api, xrn, creds): if not slice_tags: api.plshell.AddSliceTag(api.plauth, slice_id, 'enabled', '0') elif slice_tags[0]['value'] != "0": - tag_id = attributes[0]['slice_tag_id'] + tag_id = slice_tags[0]['slice_tag_id'] api.plshell.UpdateSliceTag(api.plauth, tag_id, '0') return 1 @@ -255,7 +250,7 @@ def reset_slice(api, xrn): def DeleteSliver(api, xrn, creds, call_id): if Callids().already_handled(call_id): return "" - (hrn, type) = urn_to_hrn(xrn) + (hrn, _) = urn_to_hrn(xrn) slicename = hrn_to_pl_slicename(hrn) slices = api.plshell.GetSlices(api.plauth, {'name': slicename}) if not slices: @@ -299,7 +294,7 @@ def ListResources(api, creds, options, call_id): if Callids().already_handled(call_id): return "" # get slice's hrn from options xrn = options.get('geni_slice_urn', None) - (hrn, type) = urn_to_hrn(xrn) + (hrn, _) = urn_to_hrn(xrn) version_manager = VersionManager() # get the rspec's return format from options @@ -331,36 +326,45 @@ def ListResources(api, creds, options, call_id): def get_ticket(api, xrn, creds, rspec, users): - reg_objects = __get_registry_objects(xrn, creds, users) - - slice_hrn, type = urn_to_hrn(xrn) + (slice_hrn, _) = urn_to_hrn(xrn) slices = Slices(api) peer = slices.get_peer(slice_hrn) sfa_peer = slices.get_sfa_peer(slice_hrn) # get the slice record - registry = api.registries[api.hrn] credential = api.getCredential() + interface = api.registries[api.hrn] + registry = api.get_server(interface, credential) records = registry.Resolve(xrn, credential) - # similar to CreateSliver, we must verify that the required records exist - # at this aggregate before we can issue a ticket - site_id, remote_site_id = slices.verify_site(registry, credential, slice_hrn, - peer, sfa_peer, reg_objects) - slice = slices.verify_slice(registry, credential, slice_hrn, site_id, - remote_site_id, peer, sfa_peer, reg_objects) - # make sure we get a local slice record record = None for tmp_record in records: if tmp_record['type'] == 'slice' and \ not tmp_record['peer_authority']: +#Error (E0602, get_ticket): Undefined variable 'SliceRecord' record = SliceRecord(dict=tmp_record) if not record: raise RecordNotFound(slice_hrn) + + # similar to CreateSliver, we must verify that the required records exist + # at this aggregate before we can issue a ticket + # parse rspec + rspec = RSpec(rspec_string) + requested_attributes = rspec.version.get_slice_attributes() + # ensure site record exists + site = slices.verify_site(hrn, slice_record, peer, sfa_peer) + # ensure slice record exists + slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer) + # ensure person records exists + persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer) + # ensure slice attributes exists + slices.verify_slice_attributes(slice, requested_attributes) + # get sliver info - slivers = Slices(api).get_slivers(slice_hrn) + slivers = slices.get_slivers(slice_hrn) + if not slivers: raise SliverDoesNotExist(slice_hrn) @@ -390,16 +394,18 @@ def get_ticket(api, xrn, creds, rspec, users): def main(): - api = SfaAPI() """ rspec = ListResources(api, "plc.princeton.sapan", None, 'pl_test_sapan') #rspec = ListResources(api, "plc.princeton.coblitz", None, 'pl_test_coblitz') #rspec = ListResources(api, "plc.pl.sirius", None, 'pl_test_sirius') print rspec """ + api = PlcSfaApi() f = open(sys.argv[1]) xml = f.read() f.close() +#Error (E1120, main): No value passed for parameter 'users' in function call +#Error (E1120, main): No value passed for parameter 'call_id' in function call CreateSliver(api, "plc.princeton.sapan", xml, 'CreateSliver_sapan') if __name__ == "__main__": diff --git a/sfa/managers/aggregate_manager_eucalyptus.py b/sfa/managers/aggregate_manager_eucalyptus.py index 6c7c1f4c..6c681183 100644 --- a/sfa/managers/aggregate_manager_eucalyptus.py +++ b/sfa/managers/aggregate_manager_eucalyptus.py @@ -17,7 +17,7 @@ from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn, Xrn from sfa.server.registry import Registries from sfa.trust.credential import Credential -from sfa.plc.api import SfaAPI +from sfa.plc.plcsfaapi import PlcSfaApi from sfa.plc.aggregate import Aggregate from sfa.plc.slices import * from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn @@ -41,7 +41,7 @@ cloud = {} # EUCALYPTUS_RSPEC_SCHEMA='/etc/sfa/eucalyptus.rng' -api = SfaAPI() +api = PlcSfaApi() ## # Meta data of an instance. @@ -735,7 +735,7 @@ def main(): server_key_file = '/var/lib/sfa/authorities/server.key' server_cert_file = '/var/lib/sfa/authorities/server.cert' - api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file, interface='aggregate') + api = PlcSfaApi(key_file = server_key_file, cert_file = server_cert_file, interface='aggregate') print getKeysForSlice(api, 'gc.gc.test1') if __name__ == "__main__": diff --git a/sfa/managers/aggregate_manager_max.py b/sfa/managers/aggregate_manager_max.py index d7d37761..ff6ea73b 100644 --- a/sfa/managers/aggregate_manager_max.py +++ b/sfa/managers/aggregate_manager_max.py @@ -259,6 +259,6 @@ return the basic information needed in a dict. def fetch_context(slice_hrn, user_hrn, contexts): base_context = {'sfa':{'user':{'hrn':user_hrn}}} return base_context - api = SfaAPI() + api = PlcSfaApi() create_slice(api, "plc.maxpl.test000", None, rspec_xml, None) diff --git a/sfa/managers/aggregate_manager_vini.py b/sfa/managers/aggregate_manager_vini.py deleted file mode 100644 index eadcbfd0..00000000 --- a/sfa/managers/aggregate_manager_vini.py +++ /dev/null @@ -1,135 +0,0 @@ -import datetime -import time -import traceback -import sys - -from types import StringTypes -from sfa.util.xrn import urn_to_hrn, Xrn -from sfa.util.plxrn import hrn_to_pl_slicename -from sfa.util.specdict import * -from sfa.util.faults import * -from sfa.util.record import SfaRecord -from sfa.util.policy import Policy -from sfa.util.record import * -from sfa.util.sfaticket import SfaTicket -from sfa.server.registry import Registries -from sfa.plc.slices import Slices -import sfa.plc.peers as peers -from sfa.managers.vini.vini_network import * -from sfa.plc.vini_aggregate import ViniAggregate -from sfa.rspecs.version_manager import VersionManager -from sfa.plc.api import SfaAPI -from sfa.plc.slices import * -from sfa.managers.aggregate_manager_pl import __get_registry_objects, __get_hostnames -from sfa.util.version import version_core -from sfa.util.callids import Callids - -# VINI aggregate is almost identical to PLC aggregate for many operations, -# so lets just import the methods form the PLC manager -from sfa.managers.aggregate_manager_pl import ( -start_slice, stop_slice, RenewSliver, reset_slice, ListSlices, get_ticket, SliverStatus) - - -def GetVersion(api): - xrn=Xrn(api.hrn) - return version_core({'interface':'aggregate', - 'testbed':'myplc.vini', - 'hrn':xrn.get_hrn(), - }) - -def DeleteSliver(api, xrn, creds, call_id): - if Callids().already_handled(call_id): return "" - (hrn, type) = urn_to_hrn(xrn) - slicename = hrn_to_pl_slicename(hrn) - slices = api.plshell.GetSlices(api.plauth, {'name': slicename}) - if not slices: - return 1 - slice = slices[0] - - api.plshell.DeleteSliceFromNodes(api.plauth, slicename, slice['node_ids']) - return 1 - -def CreateSliver(api, xrn, creds, xml, users, call_id): - """ - Verify HRN and initialize the slice record in PLC if necessary. - """ - - if Callids().already_handled(call_id): return "" - - hrn, type = urn_to_hrn(xrn) - peer = None - reg_objects = __get_registry_objects(xrn, creds, users) - slices = Slices(api) - peer = slices.get_peer(hrn) - sfa_peer = slices.get_sfa_peer(hrn) - registries = Registries(api) - registry = registries[api.hrn] - credential = api.getCredential() - site_id, remote_site_id = slices.verify_site(registry, credential, hrn, - peer, sfa_peer, reg_objects) - slice = slices.verify_slice(registry, credential, hrn, site_id, - remote_site_id, peer, sfa_peer, reg_objects) - - network = ViniNetwork(api) - - slice = network.get_slice(api, hrn) - current = __get_hostnames(slice.get_nodes()) - - network.addRSpec(xml, "/var/www/html/schemas/vini.rng") - #network.addRSpec(xml, "/root/SVN/sfa/trunk/sfa/managers/vini/vini.rng") - request = __get_hostnames(network.nodesWithSlivers()) - - # remove nodes not in rspec - deleted_nodes = list(set(current).difference(request)) - - # add nodes from rspec - added_nodes = list(set(request).difference(current)) - - api.plshell.AddSliceToNodes(api.plauth, slice.name, added_nodes) - api.plshell.DeleteSliceFromNodes(api.plauth, slice.name, deleted_nodes) - network.updateSliceTags() - - # xxx - check this holds enough data for the client to understand what's happened - return network.toxml() - -def ListResources(api, creds, options,call_id): - if Callids().already_handled(call_id): return "" - # get slice's hrn from options - xrn = options.get('geni_slice_urn', '') - hrn, type = urn_to_hrn(xrn) - - version_manager = VersionManager() - # get the rspec's return format from options - rspec_version = version_manager.get_version(options.get('rspec_version')) - version_string = "rspec_%s" % (rspec_version.to_string()) - - # look in cache first - if api.cache and not xrn: - rspec = api.cache.get(version_string) - if rspec: - api.logger.info("aggregate.ListResources: returning cached value for hrn %s"%hrn) - return rspec - - aggregate = ViniAggregate(api, options) - rspec = aggregate.get_rspec(slice_xrn=xrn, version=rspec_version) - - # cache the result - if api.cache and not xrn: - api.cache.add('nodes', rspec) - - return rspec - -def main(): - api = SfaAPI() - """ - #rspec = ListResources(api, None, None,) - rspec = ListResources(api, "plc.princeton.iias", None, 'vini_test') - print rspec - """ - f = open(sys.argv[1]) - xml = f.read() - f.close() - CreateSliver(api, "plc.princeton.iias", xml, 'call-id-iias') - -if __name__ == "__main__": - main() diff --git a/sfa/managers/component_manager_pl.py b/sfa/managers/component_manager_pl.py index 6100e763..8aca53c8 100644 --- a/sfa/managers/component_manager_pl.py +++ b/sfa/managers/component_manager_pl.py @@ -1,9 +1,8 @@ -import os import xmlrpclib -from sfa.util.faults import * +from sfa.util.faults import SliverDoesNotExist from sfa.util.plxrn import PlXrn -from sfa.util.sfaticket import SfaTicket +from sfa.trust.sfaticket import SfaTicket from sfa.util.version import version_core def GetVersion(api): diff --git a/sfa/managers/import_manager.py b/sfa/managers/import_manager.py deleted file mode 100644 index f5f30c44..00000000 --- a/sfa/managers/import_manager.py +++ /dev/null @@ -1,26 +0,0 @@ -from sfa.util.sfalogging import logger - -def import_manager(kind, type): - """ - kind expected in ['registry', 'aggregate', 'slice', 'component'] - type is e.g. 'pl' or 'max' or whatever - """ - basepath = 'sfa.managers' - qualified = "%s.%s_manager_%s"%(basepath,kind,type) - generic = "%s.%s_manager"%(basepath,kind) - - message="import_manager for kind=%s and type=%s"%(kind,type) - try: - manager = __import__(qualified, fromlist=[basepath]) - logger.info ("%s: loaded %s"%(message,qualified)) - except: - try: - manager = __import__ (generic, fromlist=[basepath]) - if type != 'pl' : - logger.warn ("%s: using generic with type!='pl'"%(message)) - logger.info("%s: loaded %s"%(message,generic)) - except: - manager=None - logger.log_exc("%s: unable to import either %s or %s"%(message,qualified,generic)) - return manager - diff --git a/sfa/managers/managerwrapper.py b/sfa/managers/managerwrapper.py new file mode 100644 index 00000000..5231c2aa --- /dev/null +++ b/sfa/managers/managerwrapper.py @@ -0,0 +1,24 @@ +from sfa.util.faults import SfaNotImplemented +from sfa.util.sfalogging import logger + +#################### +class ManagerWrapper: + """ + This class acts as a wrapper around an SFA interface manager module, but + can be used with any python module. The purpose of this class is raise a + SfaNotImplemented exception if someone attempts to use an attribute + (could be a callable) thats not available in the library by checking the + library using hasattr. This helps to communicate better errors messages + to the users and developers in the event that a specifiec operation + is not implemented by a libarary and will generally be more helpful than + the standard AttributeError + """ + def __init__(self, manager, interface): + self.manager = manager + self.interface = interface + + def __getattr__(self, method): + if not hasattr(self.manager, method): + raise SfaNotImplemented(method, self.interface) + return getattr(self.manager, method) + diff --git a/sfa/managers/registry_manager.py b/sfa/managers/registry_manager.py index 6052eee3..085bc39f 100644 --- a/sfa/managers/registry_manager.py +++ b/sfa/managers/registry_manager.py @@ -1,18 +1,19 @@ import types import time -from sfa.util.faults import * +from sfa.util.faults import RecordNotFound, AccountNotEnabled, PermissionError, MissingAuthority, \ + UnknownSfaType, ExistingRecord from sfa.util.prefixTree import prefixTree from sfa.util.record import SfaRecord from sfa.util.table import SfaTable -from sfa.util.record import SfaRecord -from sfa.trust.gid import GID from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn, urn_to_hrn from sfa.util.plxrn import hrn_to_pl_login_base +from sfa.util.version import version_core + +from sfa.trust.gid import GID from sfa.trust.credential import Credential -from sfa.trust.certificate import Certificate, Keypair +from sfa.trust.certificate import Certificate, Keypair, convert_public_key from sfa.trust.gid import create_uuid -from sfa.util.version import version_core # The GENI GetVersion call def GetVersion(api): @@ -118,7 +119,9 @@ def resolve(api, xrns, type=None, full=True): xrns = xrn_dict[registry_hrn] if registry_hrn != api.hrn: credential = api.getCredential() - peer_records = registries[registry_hrn].Resolve(xrns, credential) + interface = api.registries[registry_hrn] + server = api.get_server(interface, credential) + peer_records = server.Resolve(xrns, credential) records.extend([SfaRecord(dict=record).as_dict() for record in peer_records]) # try resolving the remaining unfound records at the local registry @@ -154,13 +157,14 @@ def list(api, xrn, origin_hrn=None): #if there was no match then this record belongs to an unknow registry if not registry_hrn: raise MissingAuthority(xrn) - # if the best match (longest matching hrn) is not the local registry, # forward the request records = [] if registry_hrn != api.hrn: credential = api.getCredential() - record_list = registries[registry_hrn].List(xrn, credential) + interface = api.registries[registry_hrn] + server = api.get_server(interface, credential) + record_list = server.List(xrn, credential) records = [SfaRecord(dict=record).as_dict() for record in record_list] # if we still have not found the record yet, try the local registry diff --git a/sfa/managers/slice_manager.py b/sfa/managers/slice_manager.py index 31753106..29113138 100644 --- a/sfa/managers/slice_manager.py +++ b/sfa/managers/slice_manager.py @@ -1,33 +1,22 @@ -# import sys -import time,datetime +import time from StringIO import StringIO -from types import StringTypes -from copy import deepcopy from copy import copy from lxml import etree -from sfa.util.sfalogging import logger -from sfa.util.rspecHelper import merge_rspecs -from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn -from sfa.util.plxrn import hrn_to_pl_slicename -from sfa.util.specdict import * -from sfa.util.faults import * -from sfa.util.record import SfaRecord -from sfa.rspecs.rspec_converter import RSpecConverter -from sfa.client.client_helper import sfa_to_pg_users_arg -from sfa.rspecs.version_manager import VersionManager -from sfa.rspecs.rspec import RSpec -from sfa.util.policy import Policy -from sfa.util.prefixTree import prefixTree -from sfa.util.sfaticket import * +from sfa.trust.sfaticket import SfaTicket from sfa.trust.credential import Credential + +from sfa.util.sfalogging import logger +from sfa.util.xrn import Xrn, urn_to_hrn from sfa.util.threadmanager import ThreadManager -import sfa.util.xmlrpcprotocol as xmlrpcprotocol -import sfa.plc.peers as peers from sfa.util.version import version_core from sfa.util.callids import Callids +from sfa.rspecs.rspec_converter import RSpecConverter +from sfa.rspecs.version_manager import VersionManager +from sfa.rspecs.rspec import RSpec +from sfa.client.client_helper import sfa_to_pg_users_arg def _call_id_supported(api, server): """ @@ -91,7 +80,7 @@ def drop_slicemgr_stats(rspec): for node in stats_elements: node.getparent().remove(node) except Exception, e: - api.logger.warn("drop_slicemgr_stats failed: %s " % (str(e))) + logger.warn("drop_slicemgr_stats failed: %s " % (str(e))) def add_slicemgr_stat(rspec, callname, aggname, elapsed, status): try: @@ -103,7 +92,7 @@ def add_slicemgr_stat(rspec, callname, aggname, elapsed, status): etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status)) except Exception, e: - api.logger.warn("add_slicemgr_stat failed on %s: %s" %(aggname, str(e))) + logger.warn("add_slicemgr_stat failed on %s: %s" %(aggname, str(e))) def ListResources(api, creds, options, call_id): version_manager = VersionManager() @@ -218,9 +207,9 @@ def CreateSliver(api, xrn, creds, rspec_str, users, call_id): # The schema used here needs to aggregate the PL and VINI schemas # schema = "/var/www/html/schemas/pl.rng" rspec = RSpec(rspec_str) - schema = None - if schema: - rspec.validate(schema) +# schema = None +# if schema: +# rspec.validate(schema) # if there is a section, the aggregates don't care about it, # so delete it. @@ -444,7 +433,7 @@ def get_ticket(api, xrn, creds, rspec, users): results = threads.get_results() # gather information from each ticket - rspecs = [] + rspec = None initscripts = [] slivers = [] object_gid = None @@ -453,15 +442,17 @@ def get_ticket(api, xrn, creds, rspec, users): attrs = agg_ticket.get_attributes() if not object_gid: object_gid = agg_ticket.get_gid_object() - rspecs.append(agg_ticket.get_rspec()) + if not rspec: + rspec = RSpec(agg_ticket.get_rspec()) + else: + rspec.version.merge(agg_ticket.get_rspec()) initscripts.extend(attrs.get('initscripts', [])) slivers.extend(attrs.get('slivers', [])) # merge info attributes = {'initscripts': initscripts, 'slivers': slivers} - merged_rspec = merge_rspecs(rspecs) - + # create a new ticket ticket = SfaTicket(subject = slice_hrn) ticket.set_gid_caller(api.auth.client_gid) @@ -470,7 +461,7 @@ def get_ticket(api, xrn, creds, rspec, users): ticket.set_pubkey(object_gid.get_pubkey()) #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn)) ticket.set_attributes(attributes) - ticket.set_rspec(merged_rspec) + ticket.set_rspec(rspec.toxml()) ticket.encode() ticket.sign() return ticket.save_to_string(save_parents=True) @@ -539,11 +530,12 @@ def status(api, xrn, creds): """ return 1 -def main(): - r = RSpec() - r.parseFile(sys.argv[1]) - rspec = r.toDict() - CreateSliver(None,'plc.princeton.tmacktestslice',rspec,'create-slice-tmacktestslice') +# this is plain broken +#def main(): +# r = RSpec() +# r.parseFile(sys.argv[1]) +# rspec = r.toDict() +# CreateSliver(None,'plc.princeton.tmacktestslice',rspec,'create-slice-tmacktestslice') if __name__ == "__main__": main() diff --git a/sfa/methods/CreateSliver.py b/sfa/methods/CreateSliver.py index 7895de3c..bb0051a9 100644 --- a/sfa/methods/CreateSliver.py +++ b/sfa/methods/CreateSliver.py @@ -1,9 +1,10 @@ -from sfa.util.faults import * +from sfa.util.faults import SfaInvalidArgument from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed from sfa.util.sfatablesRuntime import run_sfatables from sfa.trust.credential import Credential +from sfa.rspecs.rspec import RSpec class CreateSliver(Method): """ @@ -51,5 +52,9 @@ class CreateSliver(Method): chain_name = 'FORWARD-INCOMING' self.api.logger.debug("CreateSliver: sfatables on chain %s"%chain_name) rspec = run_sfatables(chain_name, hrn, origin_hrn, rspec) - - return manager.CreateSliver(self.api, slice_xrn, creds, rspec, users, call_id) + slivers = RSpec(rspec).version.get_nodes_with_slivers() + if slivers: + result = manager.CreateSliver(self.api, slice_xrn, creds, rspec, users, call_id) + else: + result = rspec + return result diff --git a/sfa/methods/DeleteSliver.py b/sfa/methods/DeleteSliver.py index ae301777..f766cb12 100644 --- a/sfa/methods/DeleteSliver.py +++ b/sfa/methods/DeleteSliver.py @@ -1,4 +1,3 @@ -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed diff --git a/sfa/methods/GetCredential.py b/sfa/methods/GetCredential.py index 34a4cb93..da3e97b7 100644 --- a/sfa/methods/GetCredential.py +++ b/sfa/methods/GetCredential.py @@ -1,6 +1,3 @@ -# -from sfa.trust.rights import * -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed diff --git a/sfa/methods/GetGids.py b/sfa/methods/GetGids.py index 37ad7967..e50f9404 100644 --- a/sfa/methods/GetGids.py +++ b/sfa/methods/GetGids.py @@ -1,9 +1,6 @@ -from sfa.util.faults import * +from sfa.util.faults import RecordNotFound from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth -from sfa.trust.gid import GID -from sfa.trust.certificate import Certificate from sfa.trust.credential import Credential class GetGids(Method): @@ -30,15 +27,15 @@ class GetGids(Method): def call(self, xrns, creds): # validate the credential valid_creds = self.api.auth.checkCredentials(creds, 'getgids') + # xxxpylintxxx origin_hrn is unused.. origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() # resolve the record manager = self.api.get_interface_manager() records = manager.resolve(self.api, xrns, full = False) if not records: - raise RecordNotFound(hrns) + raise RecordNotFound(xrns) - gids = [] allowed_fields = ['hrn', 'type', 'gid'] for record in records: for key in record.keys(): diff --git a/sfa/methods/GetSelfCredential.py b/sfa/methods/GetSelfCredential.py index 6a8261c9..de21ab57 100644 --- a/sfa/methods/GetSelfCredential.py +++ b/sfa/methods/GetSelfCredential.py @@ -1,12 +1,10 @@ -from sfa.util.faults import * +from sfa.util.faults import RecordNotFound, ConnectionKeyGIDMismatch from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed from sfa.util.record import SfaRecord -from sfa.trust.credential import Credential from sfa.trust.certificate import Certificate -from sfa.trust.rights import Right, Rights class GetSelfCredential(Method): """ diff --git a/sfa/methods/GetTicket.py b/sfa/methods/GetTicket.py index 3a250d57..14696931 100644 --- a/sfa/methods/GetTicket.py +++ b/sfa/methods/GetTicket.py @@ -1,13 +1,10 @@ -import time -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth -from sfa.util.config import Config -from sfa.trust.credential import Credential from sfa.util.sfatablesRuntime import run_sfatables +from sfa.trust.credential import Credential + class GetTicket(Method): """ Retrieve a ticket. This operation is currently implemented on PLC diff --git a/sfa/methods/GetVersion.py b/sfa/methods/GetVersion.py index 2a8f6b21..72fe8061 100644 --- a/sfa/methods/GetVersion.py +++ b/sfa/methods/GetVersion.py @@ -1,4 +1,3 @@ -from sfa.util.faults import * from sfa.util.method import Method from sfa.util.parameter import Parameter diff --git a/sfa/methods/List.py b/sfa/methods/List.py index 8b4fcbea..a5d11235 100644 --- a/sfa/methods/List.py +++ b/sfa/methods/List.py @@ -1,5 +1,4 @@ -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed diff --git a/sfa/methods/ListResources.py b/sfa/methods/ListResources.py index fb831170..b8d7e2d0 100644 --- a/sfa/methods/ListResources.py +++ b/sfa/methods/ListResources.py @@ -1,7 +1,5 @@ -import sys import zlib -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed diff --git a/sfa/methods/ListSlices.py b/sfa/methods/ListSlices.py index fa65b074..7fc85130 100644 --- a/sfa/methods/ListSlices.py +++ b/sfa/methods/ListSlices.py @@ -1,7 +1,5 @@ -from sfa.util.faults import * from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth from sfa.trust.credential import Credential class ListSlices(Method): diff --git a/sfa/methods/RedeemTicket.py b/sfa/methods/RedeemTicket.py index 3aff1e75..cab0e931 100644 --- a/sfa/methods/RedeemTicket.py +++ b/sfa/methods/RedeemTicket.py @@ -1,5 +1,3 @@ -import xmlrpclib -from sfa.util.faults import * from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed diff --git a/sfa/methods/Register.py b/sfa/methods/Register.py index f4b78018..619ed00b 100644 --- a/sfa/methods/Register.py +++ b/sfa/methods/Register.py @@ -1,12 +1,5 @@ - -from sfa.trust.certificate import Keypair, convert_public_key -from sfa.trust.gid import * -from sfa.util.faults import * from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.util.record import SfaRecord -from sfa.trust.auth import Auth -from sfa.trust.gid import create_uuid from sfa.trust.credential import Credential class Register(Method): diff --git a/sfa/methods/RegisterPeerObject.py b/sfa/methods/RegisterPeerObject.py index 11fd1bdf..2eec2f5d 100644 --- a/sfa/methods/RegisterPeerObject.py +++ b/sfa/methods/RegisterPeerObject.py @@ -1,14 +1,9 @@ -from sfa.trust.certificate import Keypair, convert_public_key -from sfa.trust.gid import * - -from sfa.util.faults import * +from sfa.util.faults import SfaInvalidArgument from sfa.util.xrn import get_authority from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed from sfa.util.record import SfaRecord from sfa.util.table import SfaTable -from sfa.trust.auth import Auth -from sfa.trust.gid import create_uuid from sfa.trust.credential import Credential class RegisterPeerObject(Method): diff --git a/sfa/methods/Remove.py b/sfa/methods/Remove.py index 73437a39..c547c262 100644 --- a/sfa/methods/Remove.py +++ b/sfa/methods/Remove.py @@ -1,4 +1,3 @@ -from sfa.util.faults import * from sfa.util.xrn import Xrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed diff --git a/sfa/methods/RemovePeerObject.py b/sfa/methods/RemovePeerObject.py index 460aa98c..fa30e882 100644 --- a/sfa/methods/RemovePeerObject.py +++ b/sfa/methods/RemovePeerObject.py @@ -1,11 +1,8 @@ -from sfa.util.faults import * +from sfa.util.faults import UnknownSfaType, SfaInvalidArgument from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth -from sfa.util.record import SfaRecord from sfa.util.table import SfaTable from sfa.trust.credential import Credential -from types import StringTypes class RemovePeerObject(Method): """ diff --git a/sfa/methods/RenewSliver.py b/sfa/methods/RenewSliver.py index 1669517a..4a0e8361 100644 --- a/sfa/methods/RenewSliver.py +++ b/sfa/methods/RenewSliver.py @@ -1,10 +1,11 @@ -from sfa.util.faults import * +import datetime + +from sfa.util.faults import InsufficientRights from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter from sfa.trust.credential import Credential from sfa.util.sfatime import utcparse -import datetime class RenewSliver(Method): """ diff --git a/sfa/methods/Resolve.py b/sfa/methods/Resolve.py index 36b2bde0..74972cc9 100644 --- a/sfa/methods/Resolve.py +++ b/sfa/methods/Resolve.py @@ -1,6 +1,5 @@ -import traceback import types -from sfa.util.faults import * + from sfa.util.xrn import Xrn, urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed @@ -43,6 +42,4 @@ class Resolve(Method): # send the call to the right manager manager = self.api.get_interface_manager() return manager.resolve(self.api, xrns, type) - - diff --git a/sfa/methods/ResolveGENI.py b/sfa/methods/ResolveGENI.py index c6e52726..d9781971 100644 --- a/sfa/methods/ResolveGENI.py +++ b/sfa/methods/ResolveGENI.py @@ -1,7 +1,5 @@ -from sfa.util.faults import * from sfa.util.method import Method from sfa.util.parameter import Parameter -from sfa.trust.credential import Credential class ResolveGENI(Method): """ diff --git a/sfa/methods/Shutdown.py b/sfa/methods/Shutdown.py index 00142b64..9788608a 100644 --- a/sfa/methods/Shutdown.py +++ b/sfa/methods/Shutdown.py @@ -1,5 +1,3 @@ -from sfa.util.faults import * -from sfa.util.method import Method from sfa.util.parameter import Parameter from sfa.methods.Stop import Stop diff --git a/sfa/methods/SliverStatus.py b/sfa/methods/SliverStatus.py index 231bec56..18613b2a 100644 --- a/sfa/methods/SliverStatus.py +++ b/sfa/methods/SliverStatus.py @@ -1,4 +1,3 @@ -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed diff --git a/sfa/methods/Start.py b/sfa/methods/Start.py index e119f3cf..6882a37b 100644 --- a/sfa/methods/Start.py +++ b/sfa/methods/Start.py @@ -1,8 +1,6 @@ -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth from sfa.trust.credential import Credential class Start(Method): diff --git a/sfa/methods/Stop.py b/sfa/methods/Stop.py index cdae0fcf..e8d3397e 100644 --- a/sfa/methods/Stop.py +++ b/sfa/methods/Stop.py @@ -1,8 +1,6 @@ -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth from sfa.trust.credential import Credential class Stop(Method): diff --git a/sfa/methods/Update.py b/sfa/methods/Update.py index aa881ea2..31b17e99 100644 --- a/sfa/methods/Update.py +++ b/sfa/methods/Update.py @@ -1,7 +1,5 @@ -import time -from sfa.util.faults import * from sfa.util.method import Method -from sfa.util.parameter import Parameter, Mixed +from sfa.util.parameter import Parameter from sfa.trust.credential import Credential class Update(Method): diff --git a/sfa/methods/UpdateSliver.py b/sfa/methods/UpdateSliver.py index 83075726..f9baae4a 100644 --- a/sfa/methods/UpdateSliver.py +++ b/sfa/methods/UpdateSliver.py @@ -1,5 +1,3 @@ -from sfa.util.faults import * -from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed from sfa.methods.CreateSliver import CreateSliver diff --git a/sfa/methods/get_aggregates.py b/sfa/methods/get_aggregates.py index 59d6001a..23c8d609 100644 --- a/sfa/methods/get_aggregates.py +++ b/sfa/methods/get_aggregates.py @@ -1,9 +1,6 @@ -from types import StringTypes -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth from sfa.server.aggregate import Aggregates class get_aggregates(Method): diff --git a/sfa/methods/get_key.py b/sfa/methods/get_key.py index 9cec0ec5..4bb65874 100644 --- a/sfa/methods/get_key.py +++ b/sfa/methods/get_key.py @@ -1,11 +1,10 @@ import os import tempfile import commands -from sfa.util.faults import * +from sfa.util.faults import NonExistingRecord, RecordNotFound from sfa.util.xrn import hrn_to_urn from sfa.util.method import Method -from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth +from sfa.util.parameter import Parameter from sfa.util.table import SfaTable from sfa.trust.certificate import Keypair from sfa.trust.gid import create_uuid diff --git a/sfa/methods/get_registries.py b/sfa/methods/get_registries.py index b404bb97..65d94449 100644 --- a/sfa/methods/get_registries.py +++ b/sfa/methods/get_registries.py @@ -1,9 +1,6 @@ -from types import StringTypes -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth from sfa.server.registry import Registries class get_registries(Method): @@ -28,7 +25,7 @@ class get_registries(Method): def call(self, cred, xrn = None): hrn, type = urn_to_hrn(xrn) self.api.auth.check(cred, 'list') - registries = Registries(self.api).interfaces.values() + registries = Registries(self.api).values() if hrn: registries = [reg for reg in registries if reg['hrn'] == hrn] return registries diff --git a/sfa/methods/get_trusted_certs.py b/sfa/methods/get_trusted_certs.py index 704fd429..460ab4dc 100644 --- a/sfa/methods/get_trusted_certs.py +++ b/sfa/methods/get_trusted_certs.py @@ -1,4 +1,4 @@ -from sfa.util.faults import * +#from sfa.util.faults import * from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed from sfa.trust.auth import Auth diff --git a/sfa/methods/register_peer_object.py b/sfa/methods/register_peer_object.py index 484fae50..42ef2408 100644 --- a/sfa/methods/register_peer_object.py +++ b/sfa/methods/register_peer_object.py @@ -1,14 +1,10 @@ -from sfa.trust.certificate import Keypair, convert_public_key -from sfa.trust.gid import * -from sfa.util.faults import * +from sfa.util.faults import SfaInvalidArgument from sfa.util.xrn import get_authority from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed from sfa.util.record import SfaRecord from sfa.util.table import SfaTable -from sfa.trust.auth import Auth -from sfa.trust.gid import create_uuid from sfa.trust.credential import Credential class register_peer_object(Method): diff --git a/sfa/methods/remove_peer_object.py b/sfa/methods/remove_peer_object.py index 41d74dc1..a5101895 100644 --- a/sfa/methods/remove_peer_object.py +++ b/sfa/methods/remove_peer_object.py @@ -1,11 +1,8 @@ -from sfa.util.faults import * +from sfa.util.faults import UnknownSfaType, SfaInvalidArgument from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth -from sfa.util.record import SfaRecord from sfa.util.table import SfaTable from sfa.trust.credential import Credential -from types import StringTypes class remove_peer_object(Method): """ diff --git a/sfa/methods/reset_slice.py b/sfa/methods/reset_slice.py index a4e9c5e2..15fb4a5b 100644 --- a/sfa/methods/reset_slice.py +++ b/sfa/methods/reset_slice.py @@ -1,9 +1,6 @@ -from sfa.util.faults import * from sfa.util.xrn import urn_to_hrn from sfa.util.method import Method from sfa.util.parameter import Parameter, Mixed -from sfa.trust.auth import Auth -from sfa.plc.slices import Slices class reset_slice(Method): """ diff --git a/sfa/plc/aggregate.py b/sfa/plc/aggregate.py index 5ad2bfff..4c398f99 100644 --- a/sfa/plc/aggregate.py +++ b/sfa/plc/aggregate.py @@ -1,10 +1,12 @@ #!/usr/bin/python -from sfa.util.xrn import * -from sfa.util.plxrn import * -#from sfa.rspecs.sfa_rspec import SfaRSpec -#from sfa.rspecs.pg_rspec import PGRSpec -#from sfa.rspecs.rspec_version import RSpecVersion +from sfa.util.xrn import hrn_to_urn, urn_to_hrn +from sfa.util.plxrn import PlXrn, hostname_to_urn, hrn_to_pl_slicename + from sfa.rspecs.rspec import RSpec +from sfa.rspecs.elements.link import Link +from sfa.rspecs.elements.interface import Interface + +from sfa.util.topology import Topology from sfa.rspecs.version_manager import VersionManager from sfa.plc.vlink import get_tc_rate @@ -25,14 +27,26 @@ class Aggregate: self.api = api self.user_options = user_options - def prepare_sites(self, force=False): + def prepare_sites(self, filter={}, force=False): if not self.sites or force: - for site in self.api.plshell.GetSites(self.api.plauth): + for site in self.api.plshell.GetSites(self.api.plauth, filter): self.sites[site['site_id']] = site - def prepare_nodes(self, force=False): + def prepare_nodes(self, filter={}, force=False): if not self.nodes or force: - for node in self.api.plshell.GetNodes(self.api.plauth, {'peer_id': None}): + filter.update({'peer_id': None}) + nodes = self.api.plshell.GetNodes(self.api.plauth, filter) + site_ids = [] + interface_ids = [] + tag_ids = [] + for node in nodes: + site_ids.append(node['site_id']) + interface_ids.extend(node['interface_ids']) + tag_ids.extend(node['node_tag_ids']) + self.prepare_sites({'site_id': site_ids}) + self.prepare_interfaces({'interface_id': interface_ids}) + self.prepare_node_tags({'node_tag_id': tag_ids}) + for node in nodes: # add site/interface info to nodes. # assumes that sites, interfaces and tags have already been prepared. site = self.sites[node['site_id']] @@ -47,37 +61,80 @@ class Aggregate: node['tags'] = tags self.nodes[node['node_id']] = node - def prepare_interfaces(self, force=False): + def prepare_interfaces(self, filter={}, force=False): if not self.interfaces or force: - for interface in self.api.plshell.GetInterfaces(self.api.plauth): + for interface in self.api.plshell.GetInterfaces(self.api.plauth, filter): self.interfaces[interface['interface_id']] = interface - def prepare_links(self, force=False): + def prepare_links(self, filter={}, force=False): if not self.links or force: - pass - - def prepare_node_tags(self, force=False): + if not self.api.config.SFA_AGGREGATE_TYPE.lower() == 'vini': + return + + topology = Topology() + for (site_id1, site_id2) in topology: + link = Link() + if not site_id1 in self.sites or site_id2 not in self.sites: + continue + site1 = self.sites[site_id1] + site2 = self.sites[site_id2] + # get hrns + site1_hrn = self.api.hrn + '.' + site1['login_base'] + site2_hrn = self.api.hrn + '.' + site2['login_base'] + # get the first node + node1 = self.nodes[site1['node_ids'][0]] + node2 = self.nodes[site2['node_ids'][0]] + + # set interfaces + # just get first interface of the first node + if1_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node1['node_id'])) + if1_ipv4 = self.interfaces[node1['interface_ids'][0]]['ip'] + if2_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node2['node_id'])) + if2_ipv4 = self.interfaces[node2['interface_ids'][0]]['ip'] + + if1 = Interface({'component_id': if1_xrn.urn, 'ipv4': if1_ipv4} ) + if2 = Interface({'component_id': if2_xrn.urn, 'ipv4': if2_ipv4} ) + + # set link + link = Link({'capacity': '1000000', 'latency': '0', 'packet_loss': '0', 'type': 'ipv4'}) + link['interface1'] = if1 + link['interface2'] = if2 + link['component_name'] = "%s:%s" % (site1['login_base'], site2['login_base']) + link['component_id'] = PlXrn(auth=self.api.hrn, interface=link['component_name']).get_urn() + link['component_manager_id'] = hrn_to_urn(self.api.hrn, 'authority+am') + self.links[link['component_name']] = link + + + def prepare_node_tags(self, filter={}, force=False): if not self.node_tags or force: - for node_tag in self.api.plshell.GetNodeTags(self.api.plauth): + for node_tag in self.api.plshell.GetNodeTags(self.api.plauth, filter): self.node_tags[node_tag['node_tag_id']] = node_tag - def prepare_pl_initscripts(self, force=False): + def prepare_pl_initscripts(self, filter={}, force=False): if not self.pl_initscripts or force: - for initscript in self.api.plshell.GetInitScripts(self.api.plauth, {'enabled': True}): + filter.update({'enabled': True}) + for initscript in self.api.plshell.GetInitScripts(self.api.plauth, filter): self.pl_initscripts[initscript['initscript_id']] = initscript - def prepare(self, force=False): - if not self.prepared or force: - self.prepare_sites(force) - self.prepare_interfaces(force) - self.prepare_node_tags(force) - self.prepare_nodes(force) - self.prepare_links(force) - self.prepare_pl_initscripts() - self.prepared = True + def prepare(self, slice = None, force=False): + if not self.prepared or force or slice: + if not slice: + self.prepare_sites(force=force) + self.prepare_interfaces(force=force) + self.prepare_node_tags(force=force) + self.prepare_nodes(force=force) + self.prepare_links(force=force) + self.prepare_pl_initscripts(force=force) + else: + self.prepare_sites({'site_id': slice['site_id']}) + self.prepare_interfaces({'node_id': slice['node_ids']}) + self.prepare_node_tags({'node_id': slice['node_ids']}) + self.prepare_nodes({'node_id': slice['node_ids']}) + self.prepare_links({'slice_id': slice['slice_id']}) + self.prepare_pl_initscripts() + self.prepared = True def get_rspec(self, slice_xrn=None, version = None): - self.prepare() version_manager = VersionManager() version = version_manager.get_version(version) if not slice_xrn: @@ -93,8 +150,11 @@ class Aggregate: slice_name = hrn_to_pl_slicename(slice_hrn) slices = self.api.plshell.GetSlices(self.api.plauth, slice_name) if slices: - slice = slices[0] - + slice = slices[0] + self.prepare(slice=slice) + else: + self.prepare() + # filter out nodes with a whitelist: valid_nodes = [] for node in self.nodes.values(): diff --git a/sfa/plc/network.py b/sfa/plc/network.py index 9276fb0f..5b2983b7 100644 --- a/sfa/plc/network.py +++ b/sfa/plc/network.py @@ -6,7 +6,7 @@ from StringIO import StringIO from lxml import etree from xmlbuilder import XMLBuilder -from sfa.util.faults import * +from sfa.util.faults import InvalidRSpec from sfa.util.xrn import get_authority from sfa.util.plxrn import hrn_to_pl_slicename, hostname_to_urn diff --git a/sfa/plc/plccomponentapi.py b/sfa/plc/plccomponentapi.py new file mode 100644 index 00000000..d3264827 --- /dev/null +++ b/sfa/plc/plccomponentapi.py @@ -0,0 +1,112 @@ +import os +import tempfile + +import sfa.util.xmlrpcprotocol as xmlrpcprotocol +from sfa.util.nodemanager import NodeManager + +from sfa.trust.credential import Credential +from sfa.trust.certificate import Certificate, Keypair +from sfa.trust.gid import GID + +from sfa.server.sfaapi import SfaApi + +#################### +class PlcComponentApi(SfaApi): + """ + This class is the type for the toplevel 'api' object + when running the component manager inside a planetlab node. + As such it runs an SFA-compliant interface and thus inherits SfaApi + However the fact that we run inside a planetlab nodes requires + some tweaks as compared with a service running in the infrastructure. + """ + + def __init__ (self, encoding="utf-8", methods='sfa.methods', + config = "/etc/sfa/sfa_config.py", + peer_cert = None, interface = None, + key_file = None, cert_file = None, cache = None): + SfaApi.__init__(self, encoding=encoding, methods=methods, + config=config, + peer_cert=peer_cert, interface=interface, + key_file=key_file, + cert_file=cert_file, cache=cache) + + self.nodemanager = NodeManager(self.config) + + def sliver_exists(self): + sliver_dict = self.nodemanager.GetXIDs() + ### xxx slicename is undefined + if slicename in sliver_dict.keys(): + return True + else: + return False + + def get_registry(self): + addr, port = self.config.SFA_REGISTRY_HOST, self.config.SFA_REGISTRY_PORT + url = "http://%(addr)s:%(port)s" % locals() + server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file) + return server + + def get_node_key(self): + # this call requires no authentication, + # so we can generate a random keypair here + subject="component" + (kfd, keyfile) = tempfile.mkstemp() + (cfd, certfile) = tempfile.mkstemp() + key = Keypair(create=True) + key.save_to_file(keyfile) + cert = Certificate(subject=subject) + cert.set_issuer(key=key, subject=subject) + cert.set_pubkey(key) + cert.sign() + cert.save_to_file(certfile) + registry = self.get_registry() + # the registry will scp the key onto the node + registry.get_key() + + # override the method in SfaApi + def getCredential(self): + """ + Get our credential from a remote registry + """ + path = self.config.SFA_DATA_DIR + config_dir = self.config.config_path + cred_filename = path + os.sep + 'node.cred' + try: + credential = Credential(filename = cred_filename) + return credential.save_to_string(save_parents=True) + except IOError: + node_pkey_file = config_dir + os.sep + "node.key" + node_gid_file = config_dir + os.sep + "node.gid" + cert_filename = path + os.sep + 'server.cert' + if not os.path.exists(node_pkey_file) or \ + not os.path.exists(node_gid_file): + self.get_node_key() + + # get node's hrn + gid = GID(filename=node_gid_file) + hrn = gid.get_hrn() + # get credential from registry + cert_str = Certificate(filename=cert_filename).save_to_string(save_parents=True) + registry = self.get_registry() + cred = registry.GetSelfCredential(cert_str, hrn, 'node') + # xxx credfile is undefined + Credential(string=cred).save_to_file(credfile, save_parents=True) + + return cred + + def clean_key_cred(self): + """ + remove the existing keypair and cred and generate new ones + """ + files = ["server.key", "server.cert", "node.cred"] + for f in files: + # xxx KEYDIR is undefined, could be meant to be "/var/lib/sfa/" from sfa_component_setup.py + filepath = KEYDIR + os.sep + f + if os.path.isfile(filepath): + os.unlink(f) + + # install the new key pair + # GetCredential will take care of generating the new keypair + # and credential + self.get_node_key() + self.getCredential() diff --git a/sfa/plc/api.py b/sfa/plc/plcsfaapi.py similarity index 59% rename from sfa/plc/api.py rename to sfa/plc/plcsfaapi.py index cad22671..842df318 100644 --- a/sfa/plc/api.py +++ b/sfa/plc/plcsfaapi.py @@ -1,64 +1,14 @@ +import xmlrpclib # -# SFA XML-RPC and SOAP interfaces -# +from sfa.util.faults import MissingSfaInfo +from sfa.util.sfalogging import logger +from sfa.util.table import SfaTable +from sfa.util.defaultdict import defaultdict -import sys -import os -import traceback -import string -import datetime -import xmlrpclib +from sfa.util.xrn import hrn_to_urn +from sfa.util.plxrn import slicename_to_hrn, hostname_to_hrn, hrn_to_pl_slicename, hrn_to_pl_login_base -from sfa.util.faults import * -from sfa.util.api import * -from sfa.util.config import * -from sfa.util.sfalogging import logger -import sfa.util.xmlrpcprotocol as xmlrpcprotocol -from sfa.trust.auth import Auth -from sfa.trust.rights import Right, Rights, determine_rights -from sfa.trust.credential import Credential,Keypair -from sfa.trust.certificate import Certificate -from sfa.util.xrn import get_authority, hrn_to_urn -from sfa.util.plxrn import hostname_to_hrn, hrn_to_pl_slicename, hrn_to_pl_slicename, slicename_to_hrn -from sfa.util.nodemanager import NodeManager -try: - from collections import defaultdict -except: - class defaultdict(dict): - def __init__(self, default_factory=None, *a, **kw): - if (default_factory is not None and - not hasattr(default_factory, '__call__')): - raise TypeError('first argument must be callable') - dict.__init__(self, *a, **kw) - self.default_factory = default_factory - def __getitem__(self, key): - try: - return dict.__getitem__(self, key) - except KeyError: - return self.__missing__(key) - def __missing__(self, key): - if self.default_factory is None: - raise KeyError(key) - self[key] = value = self.default_factory() - return value - def __reduce__(self): - if self.default_factory is None: - args = tuple() - else: - args = self.default_factory, - return type(self), args, None, None, self.items() - def copy(self): - return self.__copy__() - def __copy__(self): - return type(self)(self.default_factory, self) - def __deepcopy__(self, memo): - import copy - return type(self)(self.default_factory, - copy.deepcopy(self.items())) - def __repr__(self): - return 'defaultdict(%s, %s)' % (self.default_factory, - dict.__repr__(self)) -## end of http://code.activestate.com/recipes/523034/ }}} +from sfa.server.sfaapi import SfaApi def list_to_dict(recs, key): """ @@ -68,35 +18,19 @@ def list_to_dict(recs, key): keys = [rec[key] for rec in recs] return dict(zip(keys, recs)) -class SfaAPI(BaseAPI): - - # flat list of method names - import sfa.methods - methods = sfa.methods.all - - def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", - methods='sfa.methods', peer_cert = None, interface = None, - key_file = None, cert_file = None, cache = None): - BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, \ - peer_cert=peer_cert, interface=interface, key_file=key_file, \ - cert_file=cert_file, cache=cache) +class PlcSfaApi(SfaApi): + + def __init__ (self, encoding="utf-8", methods='sfa.methods', + config = "/etc/sfa/sfa_config.py", + peer_cert = None, interface = None, + key_file = None, cert_file = None, cache = None): + SfaApi.__init__(self, encoding=encoding, methods=methods, + config=config, + peer_cert=peer_cert, interface=interface, + key_file=key_file, + cert_file=cert_file, cache=cache) - self.encoding = encoding - from sfa.util.table import SfaTable self.SfaTable = SfaTable - # Better just be documenting the API - if config is None: - return - - # Load configuration - self.config = Config(config) - self.auth = Auth(peer_cert) - self.interface = interface - self.key_file = key_file - self.key = Keypair(filename=self.key_file) - self.cert_file = cert_file - self.cert = Certificate(filename=self.cert_file) - self.credential = None # Initialize the PLC shell only if SFA wraps a myPLC rspec_type = self.config.get_aggregate_type() if (rspec_type == 'pl' or rspec_type == 'vini' or \ @@ -104,10 +38,6 @@ class SfaAPI(BaseAPI): self.plshell = self.getPLCShell() self.plshell_version = "4.3" - self.hrn = self.config.SFA_INTERFACE_HRN - self.time_format = "%Y-%m-%d %H:%M:%S" - - def getPLCShell(self): self.plauth = {'Username': self.config.SFA_PLC_USER, 'AuthMethod': 'password', @@ -128,132 +58,6 @@ class SfaAPI(BaseAPI): shell = xmlrpclib.Server(url, verbose = 0, allow_none = True) return shell - def get_server(self, interface, cred, timeout=30): - """ - Returns a connection to the specified interface. Use the specified - credential to determine the caller and look for the caller's key/cert - in the registry hierarchy cache. - """ - from sfa.trust.hierarchy import Hierarchy - if not isinstance(cred, Credential): - cred_obj = Credential(string=cred) - else: - cred_obj = cred - caller_gid = cred_obj.get_gid_caller() - hierarchy = Hierarchy() - auth_info = hierarchy.get_auth_info(caller_gid.get_hrn()) - key_file = auth_info.get_privkey_filename() - cert_file = auth_info.get_gid_filename() - server = interface.get_server(key_file, cert_file, timeout) - return server - - - def getCredential(self): - """ - Return a valid credential for this interface. - """ - type = 'authority' - path = self.config.SFA_DATA_DIR - filename = ".".join([self.interface, self.hrn, type, "cred"]) - cred_filename = path + os.sep + filename - cred = None - if os.path.isfile(cred_filename): - cred = Credential(filename = cred_filename) - # make sure cred isnt expired - if not cred.get_expiration or \ - datetime.datetime.utcnow() < cred.get_expiration(): - return cred.save_to_string(save_parents=True) - - # get a new credential - if self.interface in ['registry']: - cred = self.__getCredentialRaw() - else: - cred = self.__getCredential() - cred.save_to_file(cred_filename, save_parents=True) - - return cred.save_to_string(save_parents=True) - - - def getDelegatedCredential(self, creds): - """ - Attempt to find a credential delegated to us in - the specified list of creds. - """ - from sfa.trust.hierarchy import Hierarchy - if creds and not isinstance(creds, list): - creds = [creds] - hierarchy = Hierarchy() - - delegated_cred = None - for cred in creds: - if hierarchy.auth_exists(Credential(string=cred).get_gid_caller().get_hrn()): - delegated_cred = cred - break - return delegated_cred - - def __getCredential(self): - """ - Get our credential from a remote registry - """ - from sfa.server.registry import Registries - registries = Registries() - registry = registries.get_server(self.hrn, self.key_file, self.cert_file) - cert_string=self.cert.save_to_string(save_parents=True) - # get self credential - self_cred = registry.GetSelfCredential(cert_string, self.hrn, 'authority') - # get credential - cred = registry.GetCredential(self_cred, self.hrn, 'authority') - return Credential(string=cred) - - def __getCredentialRaw(self): - """ - Get our current credential directly from the local registry. - """ - - hrn = self.hrn - auth_hrn = self.auth.get_authority(hrn) - - # is this a root or sub authority - if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN: - auth_hrn = hrn - auth_info = self.auth.get_auth_info(auth_hrn) - table = self.SfaTable() - records = table.findObjects({'hrn': hrn, 'type': 'authority+sa'}) - if not records: - raise RecordNotFound - record = records[0] - type = record['type'] - object_gid = record.get_gid_object() - new_cred = Credential(subject = object_gid.get_subject()) - new_cred.set_gid_caller(object_gid) - new_cred.set_gid_object(object_gid) - new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename()) - - r1 = determine_rights(type, hrn) - new_cred.set_privileges(r1) - new_cred.encode() - new_cred.sign() - - return new_cred - - - def loadCredential (self): - """ - Attempt to load credential from file if it exists. If it doesnt get - credential from registry. - """ - - # see if this file exists - # XX This is really the aggregate's credential. Using this is easier than getting - # the registry's credential from iteslf (ssl errors). - ma_cred_filename = self.config.SFA_DATA_DIR + os.sep + self.interface + self.hrn + ".ma.cred" - try: - self.credential = Credential(filename = ma_cred_filename) - except IOError: - self.credential = self.getCredentialFromRegistry() - - - ## # Convert SFA fields to PLC fields for use when registering up updating # registry record in the PLC database @@ -391,7 +195,7 @@ class SfaAPI(BaseAPI): for record in records: if 'site_id' in record: site_ids.append(record['site_id']) - if 'site_ids' in records: + if 'site_ids' in record: site_ids.extend(record['site_ids']) if 'person_ids' in record: person_ids.extend(record['person_ids']) @@ -576,7 +380,7 @@ class SfaAPI(BaseAPI): self.fill_record_sfa_info(records) def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc): - # get a list of the HRNs tht are members of the old and new records + # get a list of the HRNs that are members of the old and new records if oldRecord: oldList = oldRecord.get(listName, []) else: @@ -623,97 +427,3 @@ class SfaAPI(BaseAPI): elif record.type == "authority": # xxx TODO pass - - - -class ComponentAPI(BaseAPI): - - def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", methods='sfa.methods', - peer_cert = None, interface = None, key_file = None, cert_file = None): - - BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, peer_cert=peer_cert, - interface=interface, key_file=key_file, cert_file=cert_file) - self.encoding = encoding - - # Better just be documenting the API - if config is None: - return - - self.nodemanager = NodeManager(self.config) - - def sliver_exists(self): - sliver_dict = self.nodemanager.GetXIDs() - if slicename in sliver_dict.keys(): - return True - else: - return False - - def get_registry(self): - addr, port = self.config.SFA_REGISTRY_HOST, self.config.SFA_REGISTRY_PORT - url = "http://%(addr)s:%(port)s" % locals() - server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file) - return server - - def get_node_key(self): - # this call requires no authentication, - # so we can generate a random keypair here - subject="component" - (kfd, keyfile) = tempfile.mkstemp() - (cfd, certfile) = tempfile.mkstemp() - key = Keypair(create=True) - key.save_to_file(keyfile) - cert = Certificate(subject=subject) - cert.set_issuer(key=key, subject=subject) - cert.set_pubkey(key) - cert.sign() - cert.save_to_file(certfile) - registry = self.get_registry() - # the registry will scp the key onto the node - registry.get_key() - - def getCredential(self): - """ - Get our credential from a remote registry - """ - path = self.config.SFA_DATA_DIR - config_dir = self.config.config_path - cred_filename = path + os.sep + 'node.cred' - try: - credential = Credential(filename = cred_filename) - return credential.save_to_string(save_parents=True) - except IOError: - node_pkey_file = config_dir + os.sep + "node.key" - node_gid_file = config_dir + os.sep + "node.gid" - cert_filename = path + os.sep + 'server.cert' - if not os.path.exists(node_pkey_file) or \ - not os.path.exists(node_gid_file): - self.get_node_key() - - # get node's hrn - gid = GID(filename=node_gid_file) - hrn = gid.get_hrn() - # get credential from registry - cert_str = Certificate(filename=cert_filename).save_to_string(save_parents=True) - registry = self.get_registry() - cred = registry.GetSelfCredential(cert_str, hrn, 'node') - Credential(string=cred).save_to_file(credfile, save_parents=True) - - return cred - - def clean_key_cred(self): - """ - remove the existing keypair and cred and generate new ones - """ - files = ["server.key", "server.cert", "node.cred"] - for f in files: - filepath = KEYDIR + os.sep + f - if os.path.isfile(filepath): - os.unlink(f) - - # install the new key pair - # GetCredential will take care of generating the new keypair - # and credential - self.get_node_key() - self.getCredential() - - diff --git a/sfa/plc/sfa-import-plc.py b/sfa/plc/sfa-import-plc.py index 95793a10..d57b4b86 100755 --- a/sfa/plc/sfa-import-plc.py +++ b/sfa/plc/sfa-import-plc.py @@ -14,22 +14,17 @@ # RSA keys at this time, not DSA keys. ## +import os import getopt import sys -import tempfile -from sfa.util.record import * from sfa.util.table import SfaTable from sfa.util.xrn import get_leaf, get_authority from sfa.util.plxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_pl_slicename from sfa.util.config import Config -from sfa.trust.certificate import convert_public_key, Keypair -from sfa.trust.trustedroots import * -from sfa.trust.hierarchy import * from sfa.util.xrn import Xrn -from sfa.plc.api import * -from sfa.trust.gid import create_uuid -from sfa.plc.sfaImport import sfaImport, _cleanup_string + +from sfa.plc.sfaImport import sfaImport def process_options(): @@ -125,7 +120,8 @@ def main(): sites_dict[site['login_base']] = site # Get all plc users - persons = shell.GetPersons(plc_auth, {'peer_id': None, 'enabled': True}, ['person_id', 'email', 'key_ids', 'site_ids']) + persons = shell.GetPersons(plc_auth, {'peer_id': None, 'enabled': True}, + ['person_id', 'email', 'key_ids', 'site_ids']) persons_dict = {} for person in persons: persons_dict[person['person_id']] = person diff --git a/sfa/plc/sfa-nuke-plc.py b/sfa/plc/sfa-nuke-plc.py index fb84020b..be7b0c10 100755 --- a/sfa/plc/sfa-nuke-plc.py +++ b/sfa/plc/sfa-nuke-plc.py @@ -11,8 +11,6 @@ import sys import os from optparse import OptionParser -from sfa.trust.hierarchy import * -from sfa.util.record import * from sfa.util.table import SfaTable from sfa.util.sfalogging import logger diff --git a/sfa/plc/sfaImport.py b/sfa/plc/sfaImport.py index 1effe718..4de6e1bd 100644 --- a/sfa/plc/sfaImport.py +++ b/sfa/plc/sfaImport.py @@ -8,20 +8,16 @@ # RSA keys at this time, not DSA keys. ## -import getopt -import sys -import tempfile - from sfa.util.sfalogging import _SfaLogger -from sfa.util.record import * +from sfa.util.record import SfaRecord from sfa.util.table import SfaTable from sfa.util.xrn import get_authority, hrn_to_urn from sfa.util.plxrn import email_to_hrn from sfa.util.config import Config from sfa.trust.certificate import convert_public_key, Keypair from sfa.trust.trustedroots import TrustedRoots -from sfa.trust.hierarchy import * +from sfa.trust.hierarchy import Hierarchy from sfa.trust.gid import create_uuid @@ -234,8 +230,6 @@ class sfaImport: def import_site(self, hrn, site): - shell = self.shell - plc_auth = self.plc_auth urn = hrn_to_urn(hrn, 'authority') self.logger.info("Import: site %s"%hrn) diff --git a/sfa/plc/slices.py b/sfa/plc/slices.py index 557fc370..5cead3be 100644 --- a/sfa/plc/slices.py +++ b/sfa/plc/slices.py @@ -1,18 +1,12 @@ -import datetime -import time -import traceback -import sys - from types import StringTypes -from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn, urn_to_hrn -from sfa.util.plxrn import hrn_to_pl_slicename, hrn_to_pl_login_base -from sfa.util.specdict import * -from sfa.util.faults import * -from sfa.util.record import SfaRecord +from collections import defaultdict + +from sfa.util.xrn import get_leaf, get_authority, urn_to_hrn +from sfa.util.plxrn import hrn_to_pl_slicename from sfa.util.policy import Policy +from sfa.rspecs.rspec import RSpec from sfa.plc.vlink import VLink -from sfa.util.prefixTree import prefixTree -from collections import defaultdict +from sfa.util.xrn import Xrn MAXINT = 2L**31-1 @@ -190,11 +184,22 @@ class Slices: except: self.api.logger.log_exc('Failed to add/remove slice from nodes') - def verify_slice_links(self, slice, links, peer=None): - if not links or not nodes: + def verify_slice_links(self, slice, links, aggregate): + # nodes is undefined here + if not links: return + for link in links: - topo_rspec = VLink.get_topo_rspec(link) + # get the ip address of the first node in the link + ifname1 = Xrn(link['interface1']['component_id']).get_leaf() + (node, device) = ifname1.split(':') + node_id = int(node.replace('node', '')) + node = aggregate.nodes[node_id] + if1 = aggregate.interfaces[node['interface_ids'][0]] + ipaddr = if1['ip'] + topo_rspec = VLink.get_topo_rspec(link, ipaddr) + self.api.plshell.AddSliceTag(self.api.plauth, slice['name'], 'topo_rspec', str([topo_rspec]), node_id) + def handle_peer(self, site, slice, persons, peer): @@ -542,8 +547,7 @@ class Slices: # add requested_attributes for attribute in added_slice_attributes: try: - name, value, node_id = attribute['name'], attribute['value'], attribute.get('node_id', None) - self.api.plshell.AddSliceTag(self.api.plauth, slice['name'], name, value, node_id) + self.api.plshell.AddSliceTag(self.api.plauth, slice['name'], attribute['name'], attribute['value'], attribute.get('node_id', None)) except Exception, e: self.api.logger.warn('Failed to add sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\ % (name, value, node_id, str(e))) diff --git a/sfa/plc/vini_aggregate.py b/sfa/plc/vini_aggregate.py deleted file mode 100644 index b5663b2b..00000000 --- a/sfa/plc/vini_aggregate.py +++ /dev/null @@ -1,40 +0,0 @@ -from sfa.plc.aggregate import Aggregate -from sfa.managers.vini.topology import PhysicalLinks -from sfa.rspecs.elements.link import Link -from sfa.util.xrn import hrn_to_urn -from sfa.util.plxrn import PlXrn - -class ViniAggregate(Aggregate): - - def prepare_links(self, force=False): - for (site_id1, site_id2) in PhysicalLinks: - link = Link() - if not site_id1 in self.sites or site_id2 not in self.sites: - continue - site1 = self.sites[site_id1] - site2 = self.sites[site_id2] - # get hrns - site1_hrn = self.api.hrn + '.' + site1['login_base'] - site2_hrn = self.api.hrn + '.' + site2['login_base'] - # get the first node - node1 = self.nodes[site1['node_id'][0]] - node2 = self.nodes[site2['node_id'][0]] - - # set interfaces - # just get first interface of the first node - if1_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node1['node_id'])) - if2_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node2['node_id'])) - - if1 = Interface({'component_id': if1_xrn.urn} ) - if2 = Interface({'component_id': if2_xrn.urn} ) - - # set link - link = Link({'capacity': '1000000', 'latency': '0', 'packet_loss': '0', 'type': 'ipv4'}) - link['interface1'] = if1 - link['interface2'] = if2 - link['component_name'] = "%s:%s" % (site1['login_base'], site2['login_base']) - link['component_id'] = PlXrn(auth=self.api.hrn, link=link['component_name']) - link['component_manager_id'] = hrn_to_urn(self.api.hrn, 'authority+am') - self.links[link['component_name']] = link - - diff --git a/sfa/plc/vlink.py b/sfa/plc/vlink.py index 8aeee498..625963d0 100644 --- a/sfa/plc/vlink.py +++ b/sfa/plc/vlink.py @@ -1,5 +1,5 @@ - -from sfa.util.plxrn import PlXrn +import re +from sfa.util.xrn import Xrn # Taken from bwlimit.py # # See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be @@ -79,34 +79,37 @@ class VLink: @staticmethod def get_virt_ip(if1, if2): - link_id = get_link_id(if1, if2) - iface_id = get_iface_id(if1, if2) + link_id = VLink.get_link_id(if1, if2) + iface_id = VLink.get_iface_id(if1, if2) first = link_id >> 6 second = ((link_id & 0x3f)<<2) + iface_id - return "192.168.%d.%s" % (frist, second) + return "192.168.%d.%s" % (first, second) @staticmethod def get_virt_net(link): - link_id = self.get_link_id(link) + link_id = VLink.get_link_id(link['interface1'], link['interface2']) first = link_id >> 6 second = (link_id & 0x3f)<<2 return "192.168.%d.%d/30" % (first, second) @staticmethod def get_interface_id(interface): - if_name = PlXrn(interface=interface['component_id']).interface_name() + if_name = Xrn(interface['component_id']).get_leaf() node, dev = if_name.split(":") - node_id = int(node.replace("pc", "")) + node_id = int(node.replace("node", "")) return node_id @staticmethod - def get_topo_rspec(link): + def get_topo_rspec(link, ipaddr): link['interface1']['id'] = VLink.get_interface_id(link['interface1']) link['interface2']['id'] = VLink.get_interface_id(link['interface2']) my_ip = VLink.get_virt_ip(link['interface1'], link['interface2']) remote_ip = VLink.get_virt_ip(link['interface2'], link['interface1']) net = VLink.get_virt_net(link) bw = format_tc_rate(long(link['capacity'])) - ipaddr = remote.get_primary_iface().ipv4 - return (link['interface2']['id'], ipaddr, bw, my_ip, remote_ip, net) + return (link['interface2']['id'], ipaddr, bw, my_ip, remote_ip, net) + + @staticmethod + def topo_rspec_to_link(topo_rspec): + pass diff --git a/sfa/rspecs/elements/interface.py b/sfa/rspecs/elements/interface.py index d2022d89..2aadf4db 100644 --- a/sfa/rspecs/elements/interface.py +++ b/sfa/rspecs/elements/interface.py @@ -1,10 +1,12 @@ class Interface(dict): + element = None fields = {'component_id': None, 'role': None, 'client_id': None, - 'ipv4': None + 'ipv4': None, } - def __init__(self, fields={}): + def __init__(self, fields={}, element=None): + self.element = element dict.__init__(self, Interface.fields) self.update(fields) diff --git a/sfa/rspecs/elements/link.py b/sfa/rspecs/elements/link.py index 4722cf83..d916d22f 100644 --- a/sfa/rspecs/elements/link.py +++ b/sfa/rspecs/elements/link.py @@ -1,7 +1,7 @@ from sfa.rspecs.elements.interface import Interface class Link(dict): - + element = None fields = { 'client_id': None, 'component_id': None, @@ -16,7 +16,8 @@ class Link(dict): 'description': None, } - def __init__(self, fields={}): + def __init__(self, fields={}, element=None): + self.element = element dict.__init__(self, Link.fields) self.update(fields) diff --git a/sfa/rspecs/elements/versions/pgv2Link.py b/sfa/rspecs/elements/versions/pgv2Link.py index 1f19a7ad..aeef7602 100644 --- a/sfa/rspecs/elements/versions/pgv2Link.py +++ b/sfa/rspecs/elements/versions/pgv2Link.py @@ -1,4 +1,6 @@ from lxml import etree +from sfa.util.plxrn import PlXrn +from sfa.util.xrn import Xrn from sfa.rspecs.elements.link import Link from sfa.rspecs.elements.interface import Interface from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements @@ -7,7 +9,10 @@ class PGv2Link: elements = { 'link': RSpecElement(RSpecElements.LINK, '//default:link | //link'), - 'component_manager': RSpecElement(RSpecElements.COMPONENT_MANAGER, './default:component_manager | ./component_manager') + 'component_manager': RSpecElement(RSpecElements.COMPONENT_MANAGER, './default:component_manager | ./component_manager'), + 'link_type': RSpecElement(RSpecElements.LINK_TYPE, './default:link_type | ./link_type'), + 'property': RSpecElement(RSpecElements.PROPERTY, './default:property | ./property'), + 'interface_ref': RSpecElement(RSpecElements.INTERFACE_REF, './default:interface_ref | ./interface_ref') } @staticmethod @@ -15,46 +20,45 @@ class PGv2Link: for link in links: link_elem = etree.SubElement(xml, 'link') for attrib in ['component_name', 'component_id', 'client_id']: - if attrib in link and link[attrib]: + if attrib in link and link[attrib] is not None: link_elem.set(attrib, link[attrib]) if 'component_manager' in link and link['component_manager']: - cm_element = etree.SubElement(xml, 'component_manager', name=link['component_manager']) + cm_element = etree.SubElement(link_elem, 'component_manager', name=link['component_manager']) for if_ref in [link['interface1'], link['interface2']]: - if_ref_elem = etree.SubElement(xml, 'interface_ref') + if_ref_elem = etree.SubElement(link_elem, 'interface_ref') for attrib in Interface.fields: if attrib in if_ref and if_ref[attrib]: if_ref_elem.attrib[attrib] = if_ref[attrib] - prop1 = etree.SubElement(xml, 'property', source_id = link['interface1']['component_id'], + prop1 = etree.SubElement(link_elem, 'property', source_id = link['interface1']['component_id'], dest_id = link['interface2']['component_id'], capacity=link['capacity'], latency=link['latency'], packet_loss=link['packet_loss']) - prop2 = etree.SubElement(xml, 'property', source_id = link['interface2']['component_id'], + prop2 = etree.SubElement(link_elem, 'property', source_id = link['interface2']['component_id'], dest_id = link['interface1']['component_id'], capacity=link['capacity'], latency=link['latency'], packet_loss=link['packet_loss']) if 'type' in link and link['type']: - type_elem = etree.SubElement(xml, 'link_type', name=link['type']) - + type_elem = etree.SubElement(link_elem, 'link_type', name=link['type']) @staticmethod - def get_links(xml, namespaces=None): + def get_links(xml): links = [] - link_elems = xml.xpath('//default:link', namespaces=namespaces) + link_elems = xml.xpath(PGv2Link.elements['link'].path, namespaces=xml.namespaces) for link_elem in link_elems: # set client_id, component_id, component_name - link = Link(link_elem.attrib) + link = Link(link_elem.attrib, link_elem) # set component manager - cm = link_elem.xpath('./default:component_manager', namespaces=namespaces) + cm = link_elem.xpath('./default:component_manager', namespaces=xml.namespaces) if len(cm) > 0: cm = cm[0] if 'name' in cm.attrib: link['component_manager'] = cm.attrib['name'] # set link type - link_types = link_elem.xpath('./default:link_type', namespaces=namespaces) + link_types = link_elem.xpath(PGv2Link.elements['link_type'].path, namespaces=xml.namespaces) if len(link_types) > 0: link_type = link_types[0] if 'name' in link_type.attrib: link['type'] = link_type.attrib['name'] # get capacity, latency and packet_loss from first property - props = link_elem.xpath('./default:property', namespaces=namespaces) + props = link_elem.xpath(PGv2Link.elements['property'].path, namespaces=xml.namespaces) if len(props) > 0: prop = props[0] for attrib in ['capacity', 'latency', 'packet_loss']: @@ -62,10 +66,10 @@ class PGv2Link: link[attrib] = prop.attrib[attrib] # get interfaces - if_elems = link_elem.xpath('./default:interface_ref', namespaces=namespaces) + if_elems = link_elem.xpath(PGv2Link.elements['interface_ref'].path, namespaces=xml.namespaces) ifs = [] for if_elem in if_elems: - if_ref = Interface(if_elem.attrib) + if_ref = Interface(if_elem.attrib, if_elem) ifs.append(if_ref) if len(ifs) > 1: link['interface1'] = ifs[0] @@ -73,3 +77,42 @@ class PGv2Link: links.append(link) return links + @staticmethod + def add_link_requests(xml, link_tuples, append=False): + if not isinstance(link_tuples, set): + link_tuples = set(link_tuples) + + available_links = PGv2Link.get_links(xml) + recently_added = [] + for link in available_links: + if_name1 = Xrn(link['interface1']['component_id']).get_leaf() + if_name2 = Xrn(link['interface2']['component_id']).get_leaf() + + requested_link = None + l_tup_1 = (if_name1, if_name2) + l_tup_2 = (if_name2, if_name1) + if link_tuples.issuperset([(if_name1, if_name2)]): + requested_link = (if_name1, if_name2) + elif link_tuples.issuperset([(if_name2, if_name2)]): + requested_link = (if_name2, if_name1) + if requested_link: + # add client id to link ane interface elements + link.element.set('client_id', link['component_name']) + link['interface1'].element.set('client_id', Xrn(link['interface1']['component_id']).get_leaf()) + link['interface2'].element.set('client_id', Xrn(link['interface2']['component_id']).get_leaf()) + recently_added.append(link['component_name']) + + if not append: + # remove all links that don't have a client id + for link in PGv2Link.get_links(xml): + if not link['client_id'] or link['component_name'] not in recently_added: + parent = link.element.getparent() + parent.remove(link.element) + + @staticmethod + def get_link_requests(xml): + link_requests = [] + for link in PGv2Link.get_links(xml): + if link['client_id'] != None: + link_requests.append(link) + return link_requests diff --git a/sfa/rspecs/pg_rspec_converter.py b/sfa/rspecs/pg_rspec_converter.py index 42e7ccdf..1c57d7da 100755 --- a/sfa/rspecs/pg_rspec_converter.py +++ b/sfa/rspecs/pg_rspec_converter.py @@ -1,7 +1,7 @@ #!/usr/bin/python from lxml import etree from StringIO import StringIO -from sfa.util.xrn import * +from sfa.util.xrn import Xrn, urn_to_hrn from sfa.rspecs.rspec import RSpec from sfa.rspecs.version_manager import VersionManager diff --git a/sfa/rspecs/rspec.py b/sfa/rspecs/rspec.py index b86f996a..a04ff288 100755 --- a/sfa/rspecs/rspec.py +++ b/sfa/rspecs/rspec.py @@ -1,11 +1,11 @@ #!/usr/bin/python from datetime import datetime, timedelta + from sfa.util.xml import XML, XpathFilter -from sfa.rspecs.version_manager import VersionManager -from sfa.util.xrn import * -from sfa.util.plxrn import hostname_to_urn +from sfa.util.faults import InvalidRSpecElement + from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements -from sfa.util.faults import SfaNotImplemented, InvalidRSpec, InvalidRSpecElement +from sfa.rspecs.version_manager import VersionManager class RSpec: @@ -70,7 +70,7 @@ class RSpec: def get(self, element_type, filter={}, depth=0): elements = self.get_elements(element_type, filter) - elements = [self.get_element_attributes(element, depth=depth) for element in elements] + elements = [self.xml.get_element_attributes(elem, depth=depth) for elem in elements] return elements def get_elements(self, element_type, filter={}): @@ -83,7 +83,7 @@ class RSpec: raise InvalidRSpecElement(element_type, extra=msg) rspec_element = self.get_rspec_element(element_type) xpath = rspec_element.path + XpathFilter.xpath(filter) - return self.xpath(xpath) + return self.xml.xpath(xpath) def merge(self, in_rspec): self.version.merge(in_rspec) diff --git a/sfa/rspecs/rspec_elements.py b/sfa/rspecs/rspec_elements.py index 42091397..3226f589 100644 --- a/sfa/rspecs/rspec_elements.py +++ b/sfa/rspecs/rspec_elements.py @@ -4,10 +4,14 @@ from sfa.util.enumeration import Enum RSpecElements = Enum(NETWORK='NETWORK', COMPONENT_MANAGER='COMPONENT_MANAGER', SLIVER='SLIVER', + SLIVER_TYPE='SLIVER_TYPE', NODE='NODE', INTERFACE='INTERFACE', + INTERFACE_REF='INTERFACE_REF', LINK='LINK', - SERVICE='SERVICE' + LINK_TYPE='LINK_TYPE', + SERVICE='SERVICE', + PROPERTY='PROPERTY' ) class RSpecElement: diff --git a/sfa/rspecs/sfa_rspec_converter.py b/sfa/rspecs/sfa_rspec_converter.py index 6ba56c1d..7bcc7878 100755 --- a/sfa/rspecs/sfa_rspec_converter.py +++ b/sfa/rspecs/sfa_rspec_converter.py @@ -1,8 +1,6 @@ #!/usr/bin/python -from lxml import etree -from StringIO import StringIO -from sfa.util.xrn import * +from sfa.util.xrn import hrn_to_urn from sfa.rspecs.rspec import RSpec from sfa.rspecs.version_manager import VersionManager diff --git a/sfa/rspecs/versions/pgv2.py b/sfa/rspecs/versions/pgv2.py index 3e995ad9..b57cd9bb 100644 --- a/sfa/rspecs/versions/pgv2.py +++ b/sfa/rspecs/versions/pgv2.py @@ -1,7 +1,7 @@ from lxml import etree from copy import deepcopy from StringIO import StringIO -from sfa.util.xrn import * +from sfa.util.xrn import urn_to_sliver_id from sfa.util.plxrn import hostname_to_urn, xrn_to_hostname from sfa.rspecs.rspec_version import BaseVersion from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements @@ -93,12 +93,17 @@ class PGv2(BaseVersion): return slice_attributes def get_links(self, network=None): - links = PGv2Link.get_links(self.xml.root, self.namespaces) - return links + return PGv2Link.get_links(self.xml) + + def get_link_requests(self): + return PGv2Link.get_link_requests(self.xml) def add_links(self, links): PGv2Link.add_links(self.xml.root, links) + def add_link_requests(self, link_tuples, append=False): + PGv2Link.add_link_requests(self.xml.root, link_tuples, append) + def attributes_list(self, elem): opts = [] if elem is not None: diff --git a/sfa/rspecs/versions/sfav1.py b/sfa/rspecs/versions/sfav1.py index 06f277a1..3917b39b 100644 --- a/sfa/rspecs/versions/sfav1.py +++ b/sfa/rspecs/versions/sfav1.py @@ -1,5 +1,7 @@ +from copy import deepcopy from lxml import etree from sfa.util.xrn import hrn_to_urn, urn_to_hrn +from sfa.util.plxrn import PlXrn from sfa.rspecs.rspec_version import BaseVersion from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements from sfa.rspecs.elements.versions.pgv2Link import PGv2Link @@ -113,8 +115,10 @@ class SFAv1(BaseVersion): return nodes def get_links(self, network=None): - links = PGv2Link.get_links(self.xml, self.namespaces) - return links + return PGv2Link.get_links(self.xml) + + def get_link_requests(self): + return PGv2Link.get_link_requests(self.xml) def get_link(self, fromnode, tonode, network=None): fromsite = fromnode.getparent() @@ -213,8 +217,9 @@ class SFAv1(BaseVersion): for interface in node['interfaces']: if 'bwlimit' in interface and interface['bwlimit']: bwlimit = etree.SubElement(node_tag, 'bw_limit', units='kbps').text = str(interface['bwlimit']/1000) - comp_id = hrn_to_urn(network, 'pc%s:eth%s' % (node['node_id'], i)) - interface_tag = etree.SubElement(node_tag, 'interface', component_id=comp_id) + comp_id = PlXrn(auth=network, interface='node%s:eth%s' % (node['node_id'], i)).get_urn() + ipaddr = interface['ip'] + interface_tag = etree.SubElement(node_tag, 'interface', component_id=comp_id, ipv4=ipaddr) i+=1 if 'bw_unallocated' in node: bw_unallocated = etree.SubElement(node_tag, 'bw_unallocated', units='kbps').text = str(node['bw_unallocated']/1000) @@ -242,7 +247,15 @@ class SFAv1(BaseVersion): pass def add_links(self, links): - PGv2Link.add_links(self.xml, links) + networks = self.get_network_elements() + if len(networks) > 0: + xml = networks[0] + else: + xml = self.xml + PGv2Link.add_links(xml, links) + + def add_link_requests(self, links): + PGv2Link.add_link_requests(self.xml, links) def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False, append=False): # add slice name to network tag diff --git a/sfa/server/aggregate.py b/sfa/server/aggregate.py index 59a3e6b8..e7340e10 100644 --- a/sfa/server/aggregate.py +++ b/sfa/server/aggregate.py @@ -1,5 +1,4 @@ -from sfa.util.faults import * -from sfa.util.server import SfaServer +from sfa.server.sfaserver import SfaServer from sfa.util.xrn import hrn_to_urn from sfa.server.interface import Interfaces, Interface from sfa.util.config import Config diff --git a/sfa/server/component.py b/sfa/server/component.py index 1dc66523..3958c5fb 100644 --- a/sfa/server/component.py +++ b/sfa/server/component.py @@ -6,7 +6,7 @@ import os import time import sys -from sfa.util.componentserver import ComponentServer +from sfa.server.sfaserver import SfaServer # GeniLight client support is optional try: @@ -17,7 +17,8 @@ except ImportError: ## # Component is a SfaServer that serves component operations. -class Component(ComponentServer): +# set SFA_GENERIC_FLAVOUR=plcm to get a PlcComponentApi instance in the request handler +class Component(SfaServer): ## # Create a new registry object. # @@ -27,5 +28,5 @@ class Component(ComponentServer): # @param cert_file certificate filename containing public key (could be a GID file) def __init__(self, ip, port, key_file, cert_file): - ComponentServer.__init__(self, ip, port, key_file, cert_file) + SfaServer.__init__(self, ip, port, key_file, cert_file) self.server.interface = 'component' diff --git a/sfa/server/interface.py b/sfa/server/interface.py index dbc8ef20..94302ecf 100644 --- a/sfa/server/interface.py +++ b/sfa/server/interface.py @@ -1,13 +1,6 @@ -import traceback -import os.path - -from sfa.util.faults import * -from sfa.util.storage import XmlStorage -from sfa.util.xrn import get_authority, hrn_to_urn -from sfa.util.record import SfaRecord +#from sfa.util.faults import * import sfa.util.xmlrpcprotocol as xmlrpcprotocol -import sfa.util.soapprotocol as soapprotocol -from sfa.trust.gid import GID +from sfa.util.xml import XML # GeniLight client support is optional try: @@ -15,10 +8,11 @@ try: except ImportError: GeniClientLight = None - - class Interface: - + """ + Interface to another SFA service, typically a peer, or the local aggregate + can retrieve a xmlrpclib.ServerProxy object for issuing calls there + """ def __init__(self, hrn, addr, port, client_type='sfa'): self.hrn = hrn self.addr = addr @@ -34,6 +28,7 @@ class Interface: def get_server(self, key_file, cert_file, timeout=30): server = None if self.client_type == 'geniclientlight' and GeniClientLight: + # xxx url and self.api are undefined server = GeniClientLight(url, self.api.key_file, self.api.cert_file) else: server = xmlrpcprotocol.get_server(self.get_url(), key_file, cert_file, timeout) @@ -62,21 +57,20 @@ class Interfaces(dict): def __init__(self, conf_file): dict.__init__(self, {}) # load config file - self.interface_info = XmlStorage(conf_file, self.default_dict) - self.interface_info.load() - records = self.interface_info.values()[0] - if not isinstance(records, list): - records = [records] - - required_fields = self.default_fields.keys() - for record in records: - if not record or not set(required_fields).issubset(record.keys()): - continue - # port is appended onto the domain, before the path. Should look like: - # http://domain:port/path - hrn, address, port = record['hrn'], record['addr'], record['port'] - interface = Interface(hrn, address, port) - self[hrn] = interface + required_fields = set(self.default_fields.keys()) + self.interface_info = XML(conf_file).todict() + for value in self.interface_info.values(): + if isinstance(value, list): + for record in value: + if isinstance(record, dict) and \ + required_fields.issubset(record.keys()): + hrn, address, port = record['hrn'], record['addr'], record['port'] + # sometime this is called at a very early stage with no config loaded + # avoid to remember this instance in such a case + if not address or not port: + continue + interface = Interface(hrn, address, port) + self[hrn] = interface def get_server(self, hrn, key_file, cert_file, timeout=30): return self[hrn].get_server(key_file, cert_file, timeout) diff --git a/sfa/server/modpython/SfaAggregateModPython.py b/sfa/server/modpython/SfaAggregateModPython.py index deaf89f9..8638ce9c 100755 --- a/sfa/server/modpython/SfaAggregateModPython.py +++ b/sfa/server/modpython/SfaAggregateModPython.py @@ -12,10 +12,10 @@ import traceback import xmlrpclib from mod_python import apache -from sfa.plc.api import SfaAPI +from sfa.plc.plcsfaapi import PlcSfaApi from sfa.util.sfalogging import logger -api = SfaAPI(interface='aggregate') +api = PlcSfaApi(interface='aggregate') def handler(req): try: diff --git a/sfa/server/modpython/SfaRegistryModPython.py b/sfa/server/modpython/SfaRegistryModPython.py index 8879813a..115fcbaf 100755 --- a/sfa/server/modpython/SfaRegistryModPython.py +++ b/sfa/server/modpython/SfaRegistryModPython.py @@ -12,10 +12,10 @@ import traceback import xmlrpclib from mod_python import apache -from sfa.plc.api import SfaAPI +from sfa.plc.plcsfaapi import PlcSfaApi from sfa.util.sfalogging import logger -api = SfaAPI(interface='registry') +api = PlcSfaApi(interface='registry') def handler(req): try: diff --git a/sfa/server/modpython/SfaSliceMgrModPython.py b/sfa/server/modpython/SfaSliceMgrModPython.py index e0f2b923..3de4519d 100755 --- a/sfa/server/modpython/SfaSliceMgrModPython.py +++ b/sfa/server/modpython/SfaSliceMgrModPython.py @@ -12,10 +12,10 @@ import traceback import xmlrpclib from mod_python import apache -from sfa.plc.api import SfaAPI +from sfa.plc.plcsfaapi import PlcSfaApi from sfa.util.sfalogging import logger -api = SfaAPI(interface='slicemgr') +api = PlcSfaApi(interface='slicemgr') def handler(req): try: diff --git a/sfa/server/registry.py b/sfa/server/registry.py index b2548113..2a37c22a 100644 --- a/sfa/server/registry.py +++ b/sfa/server/registry.py @@ -1,9 +1,7 @@ # # Registry is a SfaServer that implements the Registry interface # -from sfa.util.server import SfaServer -from sfa.util.faults import * -from sfa.util.xrn import hrn_to_urn +from sfa.server.sfaserver import SfaServer from sfa.server.interface import Interfaces, Interface from sfa.util.config import Config diff --git a/sfa/server/sfa-ca.py b/sfa/server/sfa-ca.py index 8297b2dd..0fbe140f 100755 --- a/sfa/server/sfa-ca.py +++ b/sfa/server/sfa-ca.py @@ -21,11 +21,12 @@ import os import sys from optparse import OptionParser -from sfa.trust.certificate import Keypair, Certificate + +from sfa.util.config import Config +from sfa.util.table import SfaTable + from sfa.trust.gid import GID, create_uuid from sfa.trust.hierarchy import Hierarchy -from sfa.util.config import Config -from collections import defaultdict def main(): args = sys.argv @@ -110,7 +111,6 @@ def sign(options): def export_gid(options): - from sfa.util.table import SfaTable # lookup the record for the specified hrn hrn = options.export type = options.type @@ -124,7 +124,7 @@ def export_gid(options): # check the authorities hierarchy hierarchy = Hierarchy() try: - auth_info = hierarchy.get_auth_info() + auth_info = hierarchy.get_auth_info(hrn) gid = auth_info.gid_object except: print "Record: %s not found" % hrn @@ -148,8 +148,6 @@ def import_gid(options): Import the specified gid into the registry (db and authorities hierarchy) overwriting any previous gid. """ - from sfa.util.table import SfaTable - from sfa.util.record import SfaRecord # load the gid gidfile = os.path.abspath(options.importgid) if not gidfile or not os.path.isfile(gidfile): @@ -167,7 +165,7 @@ def import_gid(options): table = SfaTable() records = table.find({'hrn': gid.get_hrn(), 'type': 'authority'}) if not records: - print "%s not found in record database" % get.get_hrn() + print "%s not found in record database" % gid.get_hrn() sys.exit(1) # update the database record diff --git a/sfa/server/sfa-clean-peer-records.py b/sfa/server/sfa-clean-peer-records.py index f821f4ce..93fef143 100644 --- a/sfa/server/sfa-clean-peer-records.py +++ b/sfa/server/sfa-clean-peer-records.py @@ -5,7 +5,7 @@ import os import traceback from sfa.util.table import SfaTable from sfa.util.prefixTree import prefixTree -from sfa.plc.api import SfaAPI +from sfa.plc.plcsfaapi import PlcSfaApi from sfa.util.config import Config from sfa.trust.certificate import Keypair from sfa.trust.hierarchy import Hierarchy @@ -31,7 +31,7 @@ def main(): authority = config.SFA_INTERFACE_HRN url = 'http://%s:%s/' %(config.SFA_REGISTRY_HOST, config.SFA_REGISTRY_PORT) registry = xmlrpcprotocol.get_server(url, key_file, cert_file) - sfa_api = SfaAPI(key_file = key_file, cert_file = cert_file, interface='registry') + sfa_api = PlcSfaApi(key_file = key_file, cert_file = cert_file, interface='registry') credential = sfa_api.getCredential() # get peer registries diff --git a/sfa/server/sfa-server.py b/sfa/server/sfa-start.py similarity index 89% rename from sfa/server/sfa-server.py rename to sfa/server/sfa-start.py index fadb1d3d..966a13e0 100755 --- a/sfa/server/sfa-server.py +++ b/sfa/server/sfa-start.py @@ -42,14 +42,12 @@ from sfa.trust.certificate import Keypair, Certificate from sfa.trust.hierarchy import Hierarchy from sfa.trust.gid import GID from sfa.util.config import Config -from sfa.plc.api import SfaAPI +from sfa.plc.plcsfaapi import PlcSfaApi from sfa.server.registry import Registries from sfa.server.aggregate import Aggregates from sfa.util.xrn import get_authority, hrn_to_urn from sfa.util.sfalogging import logger -from sfa.managers.import_manager import import_manager - # after http://www.erlenstar.demon.co.uk/unix/faq_2.html def daemon(): """Daemonize the current process.""" @@ -136,28 +134,6 @@ def init_self_signed_cert(hrn, key, server_cert_file): cert.sign() cert.save_to_file(server_cert_file) -def init_server(options, config): - """ - Locate the manager based on config.*TYPE - Execute the init_server method (well in fact function, sigh) if defined in that module - In order to migrate to a more generic approach: - * search for <>_manager_.py - * if not found, try <>_manager.py (and issue a warning if !='pl') - """ - if options.registry: - manager=import_manager ("registry", config.SFA_REGISTRY_TYPE) - if manager and hasattr(manager, 'init_server'): manager.init_server() - if options.am: - manager=import_manager ("aggregate", config.SFA_AGGREGATE_TYPE) - if manager and hasattr(manager, 'init_server'): manager.init_server() - if options.sm: - manager=import_manager ("slice", config.SFA_SM_TYPE) - if manager and hasattr(manager, 'init_server'): manager.init_server() - if options.cm: - manager=import_manager ("component", config.SFA_CM_TYPE) - if manager and hasattr(manager, 'init_server'): manager.init_server() - - def install_peer_certs(server_key_file, server_cert_file): """ Attempt to install missing trusted gids and db records for @@ -167,7 +143,7 @@ def install_peer_certs(server_key_file, server_cert_file): # There should be a gid file in /etc/sfa/trusted_roots for every # peer registry found in in the registries.xml config file. If there # are any missing gids, request a new one from the peer registry. - api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file) + api = PlcSfaApi(key_file = server_key_file, cert_file = server_cert_file) registries = Registries() aggregates = Aggregates() interfaces = dict(registries.items() + aggregates.items()) @@ -224,7 +200,7 @@ def update_cert_records(gids): Make sure there is a record in the registry for the specified gids. Removes old records from the db. """ - # import SfaTable here so this module can be loaded by ComponentAPI + # import SfaTable here so this module can be loaded by PlcComponentApi from sfa.util.table import SfaTable from sfa.util.record import SfaRecord if not gids: @@ -256,7 +232,7 @@ def update_cert_records(gids): def main(): # Generate command line parser - parser = OptionParser(usage="sfa-server [options]") + parser = OptionParser(usage="sfa-start.py [options]") parser.add_option("-r", "--registry", dest="registry", action="store_true", help="run registry server", default=False) parser.add_option("-s", "--slicemgr", dest="sm", action="store_true", @@ -280,8 +256,7 @@ def main(): server_cert_file = os.path.join(hierarchy.basedir, "server.cert") init_server_key(server_key_file, server_cert_file, config, hierarchy) - init_server(options, config) - + if (options.daemon): daemon() if options.trusted_certs: diff --git a/sfa/server/sfa_component_setup.py b/sfa/server/sfa_component_setup.py index b0a5a478..954ff6d0 100755 --- a/sfa/server/sfa_component_setup.py +++ b/sfa/server/sfa_component_setup.py @@ -4,10 +4,11 @@ import os import tempfile from optparse import OptionParser -from sfa.util.faults import * +from sfa.util.faults import ConnectionKeyGIDMismatch from sfa.util.config import Config import sfa.util.xmlrpcprotocol as xmlrpcprotocol from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn + from sfa.trust.certificate import Keypair, Certificate from sfa.trust.credential import Credential from sfa.trust.gid import GID @@ -221,8 +222,8 @@ def get_gids(registry=None, verbose=False): if verbose: print "Getting current slices on this node" # get a list of slices on this node - from sfa.plc.api import ComponentAPI - api = ComponentAPI() + from sfa.plc.plcsfaapi import PlcComponentApi + api = PlcComponentApi() xids_tuple = api.nodemanager.GetXIDs() slices = eval(xids_tuple[1]) slicenames = slices.keys() diff --git a/sfa/server/sfaapi.py b/sfa/server/sfaapi.py new file mode 100644 index 00000000..9d22e7f2 --- /dev/null +++ b/sfa/server/sfaapi.py @@ -0,0 +1,209 @@ +import os.path +import datetime + +from sfa.util.faults import SfaAPIError +from sfa.util.config import Config +from sfa.util.cache import Cache +from sfa.trust.auth import Auth +from sfa.trust.certificate import Keypair, Certificate +from sfa.trust.credential import Credential +from sfa.trust.rights import determine_rights + +# this is wrong all right, but temporary, will use generic +from sfa.server.xmlrpcapi import XmlrpcApi +import os +import datetime + +#################### +class SfaApi (XmlrpcApi): + + """ + An SfaApi instance is a basic xmlrcp service + augmented with the local cryptographic material and hrn + + It also has the notion of its own interface (a string describing + whether we run a registry, aggregate or slicemgr) and has + the notion of neighbour sfa services as defined + in /etc/sfa/{aggregates,registries}.xml + + Finally it contains a cache instance + + It gets augmented by the generic layer with + (*) an instance of manager (actually a manager module for now) + (*) which in turn holds an instance of a testbed driver + For convenience api.manager.driver == api.driver + """ + + def __init__ (self, encoding="utf-8", methods='sfa.methods', + config = "/etc/sfa/sfa_config.py", + peer_cert = None, interface = None, + key_file = None, cert_file = None, cache = None): + + XmlrpcApi.__init__ (self, encoding) + + # we may be just be documenting the API + if config is None: + return + # Load configuration + self.config = Config(config) + self.credential = None + self.auth = Auth(peer_cert) + self.interface = interface + self.hrn = self.config.SFA_INTERFACE_HRN + self.key_file = key_file + self.key = Keypair(filename=self.key_file) + self.cert_file = cert_file + self.cert = Certificate(filename=self.cert_file) + self.cache = cache + if self.cache is None: + self.cache = Cache() + + # load registries + from sfa.server.registry import Registries + self.registries = Registries() + + # load aggregates + from sfa.server.aggregate import Aggregates + self.aggregates = Aggregates() + + # filled later on by generic/Generic + self.manager=None + + # tmp + def get_interface_manager(self, manager_base = 'sfa.managers'): + return self.manager + + def get_server(self, interface, cred, timeout=30): + """ + Returns a connection to the specified interface. Use the specified + credential to determine the caller and look for the caller's key/cert + in the registry hierarchy cache. + """ + from sfa.trust.hierarchy import Hierarchy + if not isinstance(cred, Credential): + cred_obj = Credential(string=cred) + else: + cred_obj = cred + caller_gid = cred_obj.get_gid_caller() + hierarchy = Hierarchy() + auth_info = hierarchy.get_auth_info(caller_gid.get_hrn()) + key_file = auth_info.get_privkey_filename() + cert_file = auth_info.get_gid_filename() + server = interface.get_server(key_file, cert_file, timeout) + return server + + + def getCredential(self): + """ + Return a valid credential for this interface. + """ + type = 'authority' + path = self.config.SFA_DATA_DIR + filename = ".".join([self.interface, self.hrn, type, "cred"]) + cred_filename = os.path.join(path,filename) + cred = None + if os.path.isfile(cred_filename): + cred = Credential(filename = cred_filename) + # make sure cred isnt expired + if not cred.get_expiration or \ + datetime.datetime.utcnow() < cred.get_expiration(): + return cred.save_to_string(save_parents=True) + + # get a new credential + if self.interface in ['registry']: + cred = self.__getCredentialRaw() + else: + cred = self.__getCredential() + cred.save_to_file(cred_filename, save_parents=True) + + return cred.save_to_string(save_parents=True) + + + def getDelegatedCredential(self, creds): + """ + Attempt to find a credential delegated to us in + the specified list of creds. + """ + from sfa.trust.hierarchy import Hierarchy + if creds and not isinstance(creds, list): + creds = [creds] + hierarchy = Hierarchy() + + delegated_cred = None + for cred in creds: + if hierarchy.auth_exists(Credential(string=cred).get_gid_caller().get_hrn()): + delegated_cred = cred + break + return delegated_cred + + def __getCredential(self): + """ + Get our credential from a remote registry + """ + from sfa.server.registry import Registries + registries = Registries() + registry = registries.get_server(self.hrn, self.key_file, self.cert_file) + cert_string=self.cert.save_to_string(save_parents=True) + # get self credential + self_cred = registry.GetSelfCredential(cert_string, self.hrn, 'authority') + # get credential + cred = registry.GetCredential(self_cred, self.hrn, 'authority') + return Credential(string=cred) + + def __getCredentialRaw(self): + """ + Get our current credential directly from the local registry. + """ + + hrn = self.hrn + auth_hrn = self.auth.get_authority(hrn) + + # is this a root or sub authority + if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN: + auth_hrn = hrn + auth_info = self.auth.get_auth_info(auth_hrn) + table = self.SfaTable() + records = table.findObjects({'hrn': hrn, 'type': 'authority+sa'}) + if not records: + raise RecordNotFound + record = records[0] + type = record['type'] + object_gid = record.get_gid_object() + new_cred = Credential(subject = object_gid.get_subject()) + new_cred.set_gid_caller(object_gid) + new_cred.set_gid_object(object_gid) + new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename()) + + r1 = determine_rights(type, hrn) + new_cred.set_privileges(r1) + new_cred.encode() + new_cred.sign() + + return new_cred + + def loadCredential (self): + """ + Attempt to load credential from file if it exists. If it doesnt get + credential from registry. + """ + + # see if this file exists + # XX This is really the aggregate's credential. Using this is easier than getting + # the registry's credential from iteslf (ssl errors). + filename = self.interface + self.hrn + ".ma.cred" + ma_cred_path = os.path.join(self.config.SFA_DATA_DIR,filename) + try: + self.credential = Credential(filename = ma_cred_path) + except IOError: + self.credential = self.getCredentialFromRegistry() + + def get_cached_server_version(self, server): + cache_key = server.url + "-version" + server_version = None + if self.cache: + server_version = self.cache.get(cache_key) + if not server_version: + server_version = server.GetVersion() + # cache version for 24 hours + self.cache.add(cache_key, server_version, ttl= 60*60*24) + return server_version diff --git a/sfa/server/sfaserver.py b/sfa/server/sfaserver.py new file mode 100644 index 00000000..f392b785 --- /dev/null +++ b/sfa/server/sfaserver.py @@ -0,0 +1,65 @@ +## +# This module implements a general-purpose server layer for sfa. +# The same basic server should be usable on the registry, component, or +# other interfaces. +# +# TODO: investigate ways to combine this with existing PLC server? +## + +import threading + +from sfa.server.threadedserver import ThreadedServer, SecureXMLRpcRequestHandler + +from sfa.util.sfalogging import logger +from sfa.trust.certificate import Keypair, Certificate + +## +# Implements an HTTPS XML-RPC server. Generally it is expected that SFA +# functions will take a credential string, which is passed to +# decode_authentication. Decode_authentication() will verify the validity of +# the credential, and verify that the user is using the key that matches the +# GID supplied in the credential. + +class SfaServer(threading.Thread): + + ## + # Create a new SfaServer object. + # + # @param ip the ip address to listen on + # @param port the port to listen on + # @param key_file private key filename of registry + # @param cert_file certificate filename containing public key + # (could be a GID file) + + def __init__(self, ip, port, key_file, cert_file,interface): + threading.Thread.__init__(self) + self.key = Keypair(filename = key_file) + self.cert = Certificate(filename = cert_file) + #self.server = SecureXMLRPCServer((ip, port), SecureXMLRpcRequestHandler, key_file, cert_file) + self.server = ThreadedServer((ip, port), SecureXMLRpcRequestHandler, key_file, cert_file) + self.server.interface=interface + self.trusted_cert_list = None + self.register_functions() + logger.info("Starting SfaServer, interface=%s"%interface) + + ## + # Register functions that will be served by the XMLRPC server. This + # function should be overridden by each descendant class. + + def register_functions(self): + self.server.register_function(self.noop) + + ## + # Sample no-op server function. The no-op function decodes the credential + # that was passed to it. + + def noop(self, cred, anything): + return anything + + ## + # Execute the server, serving requests forever. + + def run(self): + self.server.serve_forever() + + diff --git a/sfa/server/slicemgr.py b/sfa/server/slicemgr.py index c0fbd6a1..9a7fa4a0 100644 --- a/sfa/server/slicemgr.py +++ b/sfa/server/slicemgr.py @@ -2,7 +2,7 @@ import os import sys import datetime import time -from sfa.util.server import * +from sfa.server.sfaserver import SfaServer class SliceMgr(SfaServer): diff --git a/sfa/util/server.py b/sfa/server/threadedserver.py similarity index 80% rename from sfa/util/server.py rename to sfa/server/threadedserver.py index c3ae7184..7a9c368b 100644 --- a/sfa/util/server.py +++ b/sfa/server/threadedserver.py @@ -7,24 +7,23 @@ ## import sys -import socket, os +import socket import traceback import threading from Queue import Queue import SocketServer import BaseHTTPServer -import SimpleHTTPServer import SimpleXMLRPCServer from OpenSSL import SSL -from sfa.trust.certificate import Keypair, Certificate -from sfa.trust.trustedroots import TrustedRoots +from sfa.util.sfalogging import logger from sfa.util.config import Config -from sfa.trust.credential import * -from sfa.util.faults import * -from sfa.plc.api import SfaAPI from sfa.util.cache import Cache -from sfa.util.sfalogging import logger +from sfa.trust.certificate import Certificate +from sfa.trust.trustedroots import TrustedRoots + +# don't hard code an api class anymore here +from sfa.generic import Generic ## # Verification callback for pyOpenSSL. We do our own checking of keys because @@ -97,11 +96,18 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): try: peer_cert = Certificate() peer_cert.load_from_pyopenssl_x509(self.connection.get_peer_certificate()) - self.api = SfaAPI(peer_cert = peer_cert, - interface = self.server.interface, - key_file = self.server.key_file, - cert_file = self.server.cert_file, - cache = self.cache) + generic=Generic.the_flavour() + self.api = generic.make_api (peer_cert = peer_cert, + interface = self.server.interface, + key_file = self.server.key_file, + cert_file = self.server.cert_file, + cache = self.cache) + #logger.info("SecureXMLRpcRequestHandler.do_POST:") + #logger.info("interface=%s"%self.server.interface) + #logger.info("key_file=%s"%self.server.key_file) + #logger.info("api=%s"%self.api) + #logger.info("server=%s"%self.server) + #logger.info("handler=%s"%self) # get arguments request = self.rfile.read(int(self.headers["content-length"])) remote_addr = (remote_ip, remote_port) = self.connection.getpeername() @@ -129,6 +135,7 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): ## # Taken from the web (XXX find reference). Implements an HTTPS xmlrpc server class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLRPCDispatcher): + def __init__(self, server_address, HandlerClass, key_file, cert_file, logRequests=True): """Secure XML-RPC server. @@ -260,54 +267,3 @@ class ThreadPoolMixIn(SocketServer.ThreadingMixIn): class ThreadedServer(ThreadPoolMixIn, SecureXMLRPCServer): pass -## -# Implements an HTTPS XML-RPC server. Generally it is expected that SFA -# functions will take a credential string, which is passed to -# decode_authentication. Decode_authentication() will verify the validity of -# the credential, and verify that the user is using the key that matches the -# GID supplied in the credential. - -class SfaServer(threading.Thread): - - ## - # Create a new SfaServer object. - # - # @param ip the ip address to listen on - # @param port the port to listen on - # @param key_file private key filename of registry - # @param cert_file certificate filename containing public key - # (could be a GID file) - - def __init__(self, ip, port, key_file, cert_file,interface): - threading.Thread.__init__(self) - self.key = Keypair(filename = key_file) - self.cert = Certificate(filename = cert_file) - #self.server = SecureXMLRPCServer((ip, port), SecureXMLRpcRequestHandler, key_file, cert_file) - self.server = ThreadedServer((ip, port), SecureXMLRpcRequestHandler, key_file, cert_file) - self.server.interface=interface - self.trusted_cert_list = None - self.register_functions() - logger.info("Starting SfaServer, interface=%s"%interface) - - ## - # Register functions that will be served by the XMLRPC server. This - # function should be overridden by each descendant class. - - def register_functions(self): - self.server.register_function(self.noop) - - ## - # Sample no-op server function. The no-op function decodes the credential - # that was passed to it. - - def noop(self, cred, anything): - self.decode_authentication(cred) - return anything - - ## - # Execute the server, serving requests forever. - - def run(self): - self.server.serve_forever() - - diff --git a/sfa/util/api.py b/sfa/server/xmlrpcapi.py similarity index 56% rename from sfa/util/api.py rename to sfa/server/xmlrpcapi.py index 39d65382..456cd42d 100644 --- a/sfa/util/api.py +++ b/sfa/server/xmlrpcapi.py @@ -2,24 +2,25 @@ # SFA XML-RPC and SOAP interfaces # -import sys -import os -import traceback import string import xmlrpclib -from sfa.util.faults import * -from sfa.util.config import * -import sfa.util.xmlrpcprotocol as xmlrpcprotocol -from sfa.util.sfalogging import logger -from sfa.trust.auth import Auth -from sfa.util.cache import Cache -from sfa.trust.credential import * -from sfa.trust.certificate import * +# SOAP support is optional +try: + import SOAPpy + from SOAPpy.Parser import parseSOAPRPC + from SOAPpy.Types import faultType + from SOAPpy.NS import NS + from SOAPpy.SOAPBuilder import buildSOAP +except ImportError: + SOAPpy = None -# this is wrong all right, but temporary -from sfa.managers.import_manager import import_manager +#################### +#from sfa.util.faults import SfaNotImplemented, SfaAPIError, SfaInvalidAPIMethod, SfaFault +from sfa.util.faults import SfaInvalidAPIMethod, SfaAPIError, SfaFault +from sfa.util.sfalogging import logger +#################### # See "2.2 Characters" in the XML specification: # # #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] @@ -77,110 +78,24 @@ def xmlrpclib_dump(self, value, write): # You can't hide from me! xmlrpclib.Marshaller._Marshaller__dump = xmlrpclib_dump -# SOAP support is optional -try: - import SOAPpy - from SOAPpy.Parser import parseSOAPRPC - from SOAPpy.Types import faultType - from SOAPpy.NS import NS - from SOAPpy.SOAPBuilder import buildSOAP -except ImportError: - SOAPpy = None - - -def import_deep(name): - mod = __import__(name) - components = name.split('.') - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - -class ManagerWrapper: +class XmlrpcApi: """ - This class acts as a wrapper around an SFA interface manager module, but - can be used with any python module. The purpose of this class is raise a - SfaNotImplemented exception if someone attempts to use an attribute - (could be a callable) thats not available in the library by checking the - library using hasattr. This helps to communicate better errors messages - to the users and developers in the event that a specifiec operation - is not implemented by a libarary and will generally be more helpful than - the standard AttributeError + The XmlrpcApi class implements a basic xmlrpc (or soap) service """ - def __init__(self, manager, interface): - self.manager = manager - self.interface = interface - - def __getattr__(self, method): - if not hasattr(self.manager, method): - raise SfaNotImplemented(method, self.interface) - return getattr(self.manager, method) - -class BaseAPI: protocol = None - def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", - methods='sfa.methods', peer_cert = None, interface = None, - key_file = None, cert_file = None, cache = None): + def __init__ (self, encoding="utf-8", methods='sfa.methods'): self.encoding = encoding + self.source = None # flat list of method names self.methods_module = methods_module = __import__(methods, fromlist=[methods]) self.methods = methods_module.all - # Better just be documenting the API - if config is None: - return - # Load configuration - self.config = Config(config) - self.auth = Auth(peer_cert) - self.hrn = self.config.SFA_INTERFACE_HRN - self.interface = interface - self.key_file = key_file - self.key = Keypair(filename=self.key_file) - self.cert_file = cert_file - self.cert = Certificate(filename=self.cert_file) - self.cache = cache - if self.cache is None: - self.cache = Cache() - self.credential = None - self.source = None - self.time_format = "%Y-%m-%d %H:%M:%S" self.logger = logger - # load registries - from sfa.server.registry import Registries - self.registries = Registries() - - # load aggregates - from sfa.server.aggregate import Aggregates - self.aggregates = Aggregates() - - - def get_interface_manager(self, manager_base = 'sfa.managers'): - """ - Returns the appropriate manager module for this interface. - Modules are usually found in sfa/managers/ - """ - manager=None - if self.interface in ['registry']: - manager=import_manager ("registry", self.config.SFA_REGISTRY_TYPE) - elif self.interface in ['aggregate']: - manager=import_manager ("aggregate", self.config.SFA_AGGREGATE_TYPE) - elif self.interface in ['slicemgr', 'sm']: - manager=import_manager ("slice", self.config.SFA_SM_TYPE) - elif self.interface in ['component', 'cm']: - manager=import_manager ("component", self.config.SFA_CM_TYPE) - if not manager: - raise SfaAPIError("No manager for interface: %s" % self.interface) - - # this isnt necessary but will help to produce better error messages - # if someone tries to access an operation this manager doesn't implement - manager = ManagerWrapper(manager, self.interface) - - return manager - def callable(self, method): """ Return a new instance of the specified method. @@ -195,7 +110,7 @@ class BaseAPI: module = __import__(self.methods_module.__name__ + "." + method, globals(), locals(), [classname]) callablemethod = getattr(module, classname)(self) return getattr(module, classname)(self) - except ImportError, AttributeError: + except (ImportError, AttributeError): raise SfaInvalidAPIMethod, method def call(self, source, method, *args): @@ -238,7 +153,7 @@ class BaseAPI: except SfaFault, fault: result = fault except Exception, fault: - logger.log_exc("BaseAPI.handle has caught Exception") + self.logger.log_exc("XmlrpcApi.handle has caught Exception") result = SfaAPIError(fault) @@ -267,13 +182,3 @@ class BaseAPI: return response - def get_cached_server_version(self, server): - cache_key = server.url + "-version" - server_version = None - if self.cache: - server_version = self.cache.get(cache_key) - if not server_version: - server_version = server.GetVersion() - # cache version for 24 hours - self.cache.add(cache_key, server_version, ttl= 60*60*24) - return server_version diff --git a/sfa/trust/auth.py b/sfa/trust/auth.py index 41c71cfd..f6269b31 100644 --- a/sfa/trust/auth.py +++ b/sfa/trust/auth.py @@ -3,16 +3,20 @@ # import sys +from sfa.util.faults import InsufficientRights, MissingCallerGID, MissingTrustedRoots, PermissionError, \ + BadRequestHash, ConnectionKeyGIDMismatch, SfaPermissionDenied +from sfa.util.sfalogging import logger +from sfa.util.config import Config +from sfa.util.xrn import get_authority + +from sfa.trust.gid import GID +from sfa.trust.rights import Rights from sfa.trust.certificate import Keypair, Certificate from sfa.trust.credential import Credential from sfa.trust.trustedroots import TrustedRoots -from sfa.util.faults import * from sfa.trust.hierarchy import Hierarchy -from sfa.util.config import * -from sfa.util.xrn import get_authority -from sfa.util.sfaticket import * +from sfa.trust.sfaticket import SfaTicket -from sfa.util.sfalogging import logger class Auth: """ @@ -145,7 +149,8 @@ class Auth: def authenticateCert(self, certStr, requestHash): cert = Certificate(string=certStr) - self.validateCert(self, cert) + # xxx should be validateCred ?? + self.validateCred(cert) def gidNoop(self, gidStr, value, requestHash): self.authenticateGid(gidStr, [gidStr, value], requestHash) @@ -311,7 +316,7 @@ class Auth: if not isinstance(creds, list): creds = [creds] creds = [] - if not isinistance(caller_hrn_list, list): + if not isinstance(caller_hrn_list, list): caller_hrn_list = [caller_hrn_list] for cred in creds: try: diff --git a/sfa/trust/certificate.py b/sfa/trust/certificate.py index bcec9d61..f0a2d71c 100644 --- a/sfa/trust/certificate.py +++ b/sfa/trust/certificate.py @@ -1,784 +1,790 @@ -#---------------------------------------------------------------------- -# Copyright (c) 2008 Board of Trustees, Princeton University -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and/or hardware specification (the "Work") to -# deal in the Work without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Work, and to permit persons to whom the Work -# is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Work. -# -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS -# IN THE WORK. -#---------------------------------------------------------------------- - -## -# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement -# the necessary crypto functionality. Ideally just one of these libraries -# would be used, but unfortunately each of these libraries is independently -# lacking. The pyOpenSSL library is missing many necessary functions, and -# the M2Crypto library has crashed inside of some of the functions. The -# design decision is to use pyOpenSSL whenever possible as it seems more -# stable, and only use M2Crypto for those functions that are not possible -# in pyOpenSSL. -# -# This module exports two classes: Keypair and Certificate. -## -# - -import functools -import os -import tempfile -import base64 -import traceback -from tempfile import mkstemp - -from OpenSSL import crypto -import M2Crypto -from M2Crypto import X509 - -from sfa.util.sfalogging import logger -from sfa.util.xrn import urn_to_hrn -from sfa.util.faults import * -from sfa.util.sfalogging import logger - -glo_passphrase_callback = None - -## -# A global callback msy be implemented for requesting passphrases from the -# user. The function will be called with three arguments: -# -# keypair_obj: the keypair object that is calling the passphrase -# string: the string containing the private key that's being loaded -# x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto -# -# The callback should return a string containing the passphrase. - -def set_passphrase_callback(callback_func): - global glo_passphrase_callback - - glo_passphrase_callback = callback_func - -## -# Sets a fixed passphrase. - -def set_passphrase(passphrase): - set_passphrase_callback( lambda k,s,x: passphrase ) - -## -# Check to see if a passphrase works for a particular private key string. -# Intended to be used by passphrase callbacks for input validation. - -def test_passphrase(string, passphrase): - try: - crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase)) - return True - except: - return False - -def convert_public_key(key): - keyconvert_path = "/usr/bin/keyconvert.py" - if not os.path.isfile(keyconvert_path): - raise IOError, "Could not find keyconvert in %s" % keyconvert_path - - # we can only convert rsa keys - if "ssh-dss" in key: - return None - - (ssh_f, ssh_fn) = tempfile.mkstemp() - ssl_fn = tempfile.mktemp() - os.write(ssh_f, key) - os.close(ssh_f) - - cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn - os.system(cmd) - - # this check leaves the temporary file containing the public key so - # that it can be expected to see why it failed. - # TODO: for production, cleanup the temporary files - if not os.path.exists(ssl_fn): - return None - - k = Keypair() - try: - k.load_pubkey_from_file(ssl_fn) - except: - logger.log_exc("convert_public_key caught exception") - k = None - - # remove the temporary files - os.remove(ssh_fn) - os.remove(ssl_fn) - - return k - -## -# Public-private key pairs are implemented by the Keypair class. -# A Keypair object may represent both a public and private key pair, or it -# may represent only a public key (this usage is consistent with OpenSSL). - -class Keypair: - key = None # public/private keypair - m2key = None # public key (m2crypto format) - - ## - # Creates a Keypair object - # @param create If create==True, creates a new public/private key and - # stores it in the object - # @param string If string!=None, load the keypair from the string (PEM) - # @param filename If filename!=None, load the keypair from the file - - def __init__(self, create=False, string=None, filename=None): - if create: - self.create() - if string: - self.load_from_string(string) - if filename: - self.load_from_file(filename) - - ## - # Create a RSA public/private key pair and store it inside the keypair object - - def create(self): - self.key = crypto.PKey() - self.key.generate_key(crypto.TYPE_RSA, 1024) - - ## - # Save the private key to a file - # @param filename name of file to store the keypair in - - def save_to_file(self, filename): - open(filename, 'w').write(self.as_pem()) - self.filename=filename - - ## - # Load the private key from a file. Implicity the private key includes the public key. - - def load_from_file(self, filename): - self.filename=filename - buffer = open(filename, 'r').read() - self.load_from_string(buffer) - - ## - # Load the private key from a string. Implicitly the private key includes the public key. - - def load_from_string(self, string): - if glo_passphrase_callback: - self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) ) - self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) ) - else: - self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string) - self.m2key = M2Crypto.EVP.load_key_string(string) - - ## - # Load the public key from a string. No private key is loaded. - - def load_pubkey_from_file(self, filename): - # load the m2 public key - m2rsakey = M2Crypto.RSA.load_pub_key(filename) - self.m2key = M2Crypto.EVP.PKey() - self.m2key.assign_rsa(m2rsakey) - - # create an m2 x509 cert - m2name = M2Crypto.X509.X509_Name() - m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0) - m2x509 = M2Crypto.X509.X509() - m2x509.set_pubkey(self.m2key) - m2x509.set_serial_number(0) - m2x509.set_issuer_name(m2name) - m2x509.set_subject_name(m2name) - ASN1 = M2Crypto.ASN1.ASN1_UTCTIME() - ASN1.set_time(500) - m2x509.set_not_before(ASN1) - m2x509.set_not_after(ASN1) - # x509v3 so it can have extensions - # prob not necc since this cert itself is junk but still... - m2x509.set_version(2) - junk_key = Keypair(create=True) - m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1") - - # convert the m2 x509 cert to a pyopenssl x509 - m2pem = m2x509.as_pem() - pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem) - - # get the pyopenssl pkey from the pyopenssl x509 - self.key = pyx509.get_pubkey() - self.filename=filename - - ## - # Load the public key from a string. No private key is loaded. - - def load_pubkey_from_string(self, string): - (f, fn) = tempfile.mkstemp() - os.write(f, string) - os.close(f) - self.load_pubkey_from_file(fn) - os.remove(fn) - - ## - # Return the private key in PEM format. - - def as_pem(self): - return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key) - - ## - # Return an M2Crypto key object - - def get_m2_pkey(self): - if not self.m2key: - self.m2key = M2Crypto.EVP.load_key_string(self.as_pem()) - return self.m2key - - ## - # Returns a string containing the public key represented by this object. - - def get_pubkey_string(self): - m2pkey = self.get_m2_pkey() - return base64.b64encode(m2pkey.as_der()) - - ## - # Return an OpenSSL pkey object - - def get_openssl_pkey(self): - return self.key - - ## - # Given another Keypair object, return TRUE if the two keys are the same. - - def is_same(self, pkey): - return self.as_pem() == pkey.as_pem() - - def sign_string(self, data): - k = self.get_m2_pkey() - k.sign_init() - k.sign_update(data) - return base64.b64encode(k.sign_final()) - - def verify_string(self, data, sig): - k = self.get_m2_pkey() - k.verify_init() - k.verify_update(data) - return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey) - - def compute_hash(self, value): - return self.sign_string(str(value)) - - # only informative - def get_filename(self): - return getattr(self,'filename',None) - - def dump (self, *args, **kwargs): - print self.dump_string(*args, **kwargs) - - def dump_string (self): - result="" - result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string() - filename=self.get_filename() - if filename: result += "Filename %s\n"%filename - return result - -## -# The certificate class implements a general purpose X509 certificate, making -# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds -# several addition features, such as the ability to maintain a chain of -# parent certificates, and storage of application-specific data. -# -# Certificates include the ability to maintain a chain of parents. Each -# certificate includes a pointer to it's parent certificate. When loaded -# from a file or a string, the parent chain will be automatically loaded. -# When saving a certificate to a file or a string, the caller can choose -# whether to save the parent certificates as well. - -class Certificate: - digest = "md5" - - cert = None - issuerKey = None - issuerSubject = None - parent = None - isCA = None # will be a boolean once set - - separator="-----parent-----" - - ## - # Create a certificate object. - # - # @param lifeDays life of cert in days - default is 1825==5 years - # @param create If create==True, then also create a blank X509 certificate. - # @param subject If subject!=None, then create a blank certificate and set - # it's subject name. - # @param string If string!=None, load the certficate from the string. - # @param filename If filename!=None, load the certficiate from the file. - # @param isCA If !=None, set whether this cert is for a CA - - def __init__(self, lifeDays=1825, create=False, subject=None, string=None, filename=None, isCA=None): - self.data = {} - if create or subject: - self.create(lifeDays) - if subject: - self.set_subject(subject) - if string: - self.load_from_string(string) - if filename: - self.load_from_file(filename) - - # Set the CA bit if a value was supplied - if isCA != None: - self.set_is_ca(isCA) - - # Create a blank X509 certificate and store it in this object. - - def create(self, lifeDays=1825): - self.cert = crypto.X509() - # FIXME: Use different serial #s - self.cert.set_serial_number(3) - self.cert.gmtime_adj_notBefore(0) # 0 means now - self.cert.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default - self.cert.set_version(2) # x509v3 so it can have extensions - - - ## - # Given a pyOpenSSL X509 object, store that object inside of this - # certificate object. - - def load_from_pyopenssl_x509(self, x509): - self.cert = x509 - - ## - # Load the certificate from a string - - def load_from_string(self, string): - # if it is a chain of multiple certs, then split off the first one and - # load it (support for the ---parent--- tag as well as normal chained certs) - - string = string.strip() - - # If it's not in proper PEM format, wrap it - if string.count('-----BEGIN CERTIFICATE') == 0: - string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string - - # If there is a PEM cert in there, but there is some other text first - # such as the text of the certificate, skip the text - beg = string.find('-----BEGIN CERTIFICATE') - if beg > 0: - # skipping over non cert beginning - string = string[beg:] - - parts = [] - - if string.count('-----BEGIN CERTIFICATE-----') > 1 and \ - string.count(Certificate.separator) == 0: - parts = string.split('-----END CERTIFICATE-----',1) - parts[0] += '-----END CERTIFICATE-----' - else: - parts = string.split(Certificate.separator, 1) - - self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0]) - - # if there are more certs, then create a parent and let the parent load - # itself from the remainder of the string - if len(parts) > 1 and parts[1] != '': - self.parent = self.__class__() - self.parent.load_from_string(parts[1]) - - ## - # Load the certificate from a file - - def load_from_file(self, filename): - file = open(filename) - string = file.read() - self.load_from_string(string) - self.filename=filename - - ## - # Save the certificate to a string. - # - # @param save_parents If save_parents==True, then also save the parent certificates. - - def save_to_string(self, save_parents=True): - string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert) - if save_parents and self.parent: - string = string + self.parent.save_to_string(save_parents) - return string - - ## - # Save the certificate to a file. - # @param save_parents If save_parents==True, then also save the parent certificates. - - def save_to_file(self, filename, save_parents=True, filep=None): - string = self.save_to_string(save_parents=save_parents) - if filep: - f = filep - else: - f = open(filename, 'w') - f.write(string) - f.close() - self.filename=filename - - ## - # Save the certificate to a random file in /tmp/ - # @param save_parents If save_parents==True, then also save the parent certificates. - def save_to_random_tmp_file(self, save_parents=True): - fp, filename = mkstemp(suffix='cert', text=True) - fp = os.fdopen(fp, "w") - self.save_to_file(filename, save_parents=True, filep=fp) - return filename - - ## - # Sets the issuer private key and name - # @param key Keypair object containing the private key of the issuer - # @param subject String containing the name of the issuer - # @param cert (optional) Certificate object containing the name of the issuer - - def set_issuer(self, key, subject=None, cert=None): - self.issuerKey = key - if subject: - # it's a mistake to use subject and cert params at the same time - assert(not cert) - if isinstance(subject, dict) or isinstance(subject, str): - req = crypto.X509Req() - reqSubject = req.get_subject() - if (isinstance(subject, dict)): - for key in reqSubject.keys(): - setattr(reqSubject, key, subject[key]) - else: - setattr(reqSubject, "CN", subject) - subject = reqSubject - # subject is not valid once req is out of scope, so save req - self.issuerReq = req - if cert: - # if a cert was supplied, then get the subject from the cert - subject = cert.cert.get_subject() - assert(subject) - self.issuerSubject = subject - - ## - # Get the issuer name - - def get_issuer(self, which="CN"): - x = self.cert.get_issuer() - return getattr(x, which) - - ## - # Set the subject name of the certificate - - def set_subject(self, name): - req = crypto.X509Req() - subj = req.get_subject() - if (isinstance(name, dict)): - for key in name.keys(): - setattr(subj, key, name[key]) - else: - setattr(subj, "CN", name) - self.cert.set_subject(subj) - - ## - # Get the subject name of the certificate - - def get_subject(self, which="CN"): - x = self.cert.get_subject() - return getattr(x, which) - - ## - # Get a pretty-print subject name of the certificate - - def get_printable_subject(self): - x = self.cert.get_subject() - return "[ OU: %s, CN: %s, SubjectAltName: %s ]" % (getattr(x, "OU"), getattr(x, "CN"), self.get_data()) - - ## - # Get the public key of the certificate. - # - # @param key Keypair object containing the public key - - def set_pubkey(self, key): - assert(isinstance(key, Keypair)) - self.cert.set_pubkey(key.get_openssl_pkey()) - - ## - # Get the public key of the certificate. - # It is returned in the form of a Keypair object. - - def get_pubkey(self): - m2x509 = X509.load_cert_string(self.save_to_string()) - pkey = Keypair() - pkey.key = self.cert.get_pubkey() - pkey.m2key = m2x509.get_pubkey() - return pkey - - def set_intermediate_ca(self, val): - return self.set_is_ca(val) - - # Set whether this cert is for a CA. All signers and only signers should be CAs. - # The local member starts unset, letting us check that you only set it once - # @param val Boolean indicating whether this cert is for a CA - def set_is_ca(self, val): - if val is None: - return - - if self.isCA != None: - # Can't double set properties - raise "Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val) - - self.isCA = val - if val: - self.add_extension('basicConstraints', 1, 'CA:TRUE') - else: - self.add_extension('basicConstraints', 1, 'CA:FALSE') - - - - ## - # Add an X509 extension to the certificate. Add_extension can only be called - # once for a particular extension name, due to limitations in the underlying - # library. - # - # @param name string containing name of extension - # @param value string containing value of the extension - - def add_extension(self, name, critical, value): - oldExtVal = None - try: - oldExtVal = self.get_extension(name) - except: - # M2Crypto LookupError when the extension isn't there (yet) - pass - - # This code limits you from adding the extension with the same value - # The method comment says you shouldn't do this with the same name - # But actually it (m2crypto) appears to allow you to do this. - if oldExtVal and oldExtVal == value: - # don't add this extension again - # just do nothing as here - return - # FIXME: What if they are trying to set with a different value? - # Is this ever OK? Or should we raise an exception? -# elif oldExtVal: -# raise "Cannot add extension %s which had val %s with new val %s" % (name, oldExtVal, value) - - ext = crypto.X509Extension (name, critical, value) - self.cert.add_extensions([ext]) - - ## - # Get an X509 extension from the certificate - - def get_extension(self, name): - - # pyOpenSSL does not have a way to get extensions - m2x509 = X509.load_cert_string(self.save_to_string()) - value = m2x509.get_ext(name).get_value() - - return value - - ## - # Set_data is a wrapper around add_extension. It stores the parameter str in - # the X509 subject_alt_name extension. Set_data can only be called once, due - # to limitations in the underlying library. - - def set_data(self, str, field='subjectAltName'): - # pyOpenSSL only allows us to add extensions, so if we try to set the - # same extension more than once, it will not work - if self.data.has_key(field): - raise "Cannot set ", field, " more than once" - self.data[field] = str - self.add_extension(field, 0, str) - - ## - # Return the data string that was previously set with set_data - - def get_data(self, field='subjectAltName'): - if self.data.has_key(field): - return self.data[field] - - try: - uri = self.get_extension(field) - self.data[field] = uri - except LookupError: - return None - - return self.data[field] - - ## - # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer(). - - def sign(self): - logger.debug('certificate.sign') - assert self.cert != None - assert self.issuerSubject != None - assert self.issuerKey != None - self.cert.set_issuer(self.issuerSubject) - self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest) - - ## - # Verify the authenticity of a certificate. - # @param pkey is a Keypair object representing a public key. If Pkey - # did not sign the certificate, then an exception will be thrown. - - def verify(self, pkey): - # pyOpenSSL does not have a way to verify signatures - m2x509 = X509.load_cert_string(self.save_to_string()) - m2pkey = pkey.get_m2_pkey() - # verify it - return m2x509.verify(m2pkey) - - # XXX alternatively, if openssl has been patched, do the much simpler: - # try: - # self.cert.verify(pkey.get_openssl_key()) - # return 1 - # except: - # return 0 - - ## - # Return True if pkey is identical to the public key that is contained in the certificate. - # @param pkey Keypair object - - def is_pubkey(self, pkey): - return self.get_pubkey().is_same(pkey) - - ## - # Given a certificate cert, verify that this certificate was signed by the - # public key contained in cert. Throw an exception otherwise. - # - # @param cert certificate object - - def is_signed_by_cert(self, cert): - k = cert.get_pubkey() - result = self.verify(k) - return result - - ## - # Set the parent certficiate. - # - # @param p certificate object. - - def set_parent(self, p): - self.parent = p - - ## - # Return the certificate object of the parent of this certificate. - - def get_parent(self): - return self.parent - - ## - # Verification examines a chain of certificates to ensure that each parent - # signs the child, and that some certificate in the chain is signed by a - # trusted certificate. - # - # Verification is a basic recursion:
-    #     if this_certificate was signed by trusted_certs:
-    #         return
-    #     else
-    #         return verify_chain(parent, trusted_certs)
-    # 
- # - # At each recursion, the parent is tested to ensure that it did sign the - # child. If a parent did not sign a child, then an exception is thrown. If - # the bottom of the recursion is reached and the certificate does not match - # a trusted root, then an exception is thrown. - # Also require that parents are CAs. - # - # @param Trusted_certs is a list of certificates that are trusted. - # - - def verify_chain(self, trusted_certs = None): - # Verify a chain of certificates. Each certificate must be signed by - # the public key contained in it's parent. The chain is recursed - # until a certificate is found that is signed by a trusted root. - - # verify expiration time - if self.cert.has_expired(): - logger.debug("verify_chain: NO, Certificate %s has expired" % self.get_printable_subject()) - raise CertExpired(self.get_printable_subject(), "client cert") - - # if this cert is signed by a trusted_cert, then we are set - for trusted_cert in trusted_certs: - if self.is_signed_by_cert(trusted_cert): - # verify expiration of trusted_cert ? - if not trusted_cert.cert.has_expired(): - logger.debug("verify_chain: YES. Cert %s signed by trusted cert %s"%( - self.get_printable_subject(), trusted_cert.get_printable_subject())) - return trusted_cert - else: - logger.debug("verify_chain: NO. Cert %s is signed by trusted_cert %s, but that signer is expired..."%( - self.get_printable_subject(),trusted_cert.get_printable_subject())) - raise CertExpired(self.get_printable_subject()," signer trusted_cert %s"%trusted_cert.get_printable_subject()) - - # if there is no parent, then no way to verify the chain - if not self.parent: - logger.debug("verify_chain: NO. %s has no parent and issuer %s is not in %d trusted roots"%(self.get_printable_subject(), self.get_issuer(), len(trusted_certs))) - raise CertMissingParent(self.get_printable_subject() + ": Issuer %s not trusted by any of %d trusted roots, and cert has no parent." % (self.get_issuer(), len(trusted_certs))) - - # if it wasn't signed by the parent... - if not self.is_signed_by_cert(self.parent): - logger.debug("verify_chain: NO. %s is not signed by parent %s, but by %s"%self.get_printable_subject(), self.parent.get_printable_subject(), self.get_issuer()) - raise CertNotSignedByParent(self.get_printable_subject() + ": Parent %s, issuer %s" % (self.parent.get_printable_subject(), self.get_issuer())) - - # Confirm that the parent is a CA. Only CAs can be trusted as - # signers. - # Note that trusted roots are not parents, so don't need to be - # CAs. - # Ugly - cert objects aren't parsed so we need to read the - # extension and hope there are no other basicConstraints - if not self.parent.isCA and not (self.parent.get_extension('basicConstraints') == 'CA:TRUE'): - logger.warn("verify_chain: cert %s's parent %s is not a CA" % (self.get_printable_subject(), self.parent.get_printable_subject())) - raise CertNotSignedByParent(self.get_printable_subject() + ": Parent %s not a CA" % self.parent.get_printable_subject()) - - # if the parent isn't verified... - logger.debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_printable_subject(),self.parent.get_printable_subject())) - self.parent.verify_chain(trusted_certs) - - return - - ### more introspection - def get_extensions(self): - # pyOpenSSL does not have a way to get extensions - triples=[] - m2x509 = X509.load_cert_string(self.save_to_string()) - nb_extensions=m2x509.get_ext_count() - logger.debug("X509 had %d extensions"%nb_extensions) - for i in range(nb_extensions): - ext=m2x509.get_ext_at(i) - triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) ) - return triples - - def get_data_names(self): - return self.data.keys() - - def get_all_datas (self): - triples=self.get_extensions() - for name in self.get_data_names(): - triples.append( (name,self.get_data(name),'data',) ) - return triples - - # only informative - def get_filename(self): - return getattr(self,'filename',None) - - def dump (self, *args, **kwargs): - print self.dump_string(*args, **kwargs) - - def dump_string (self,show_extensions=False): - result = "" - result += "CERTIFICATE for %s\n"%self.get_printable_subject() - result += "Issued by %s\n"%self.get_issuer() - filename=self.get_filename() - if filename: result += "Filename %s\n"%filename - if show_extensions: - all_datas=self.get_all_datas() - result += " has %d extensions/data attached"%len(all_datas) - for (n,v,c) in all_datas: - if c=='data': - result += " data: %s=%s\n"%(n,v) - else: - result += " ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v) - return result +#---------------------------------------------------------------------- +# Copyright (c) 2008 Board of Trustees, Princeton University +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and/or hardware specification (the "Work") to +# deal in the Work without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Work, and to permit persons to whom the Work +# is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Work. +# +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# IN THE WORK. +#---------------------------------------------------------------------- + +## +# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement +# the necessary crypto functionality. Ideally just one of these libraries +# would be used, but unfortunately each of these libraries is independently +# lacking. The pyOpenSSL library is missing many necessary functions, and +# the M2Crypto library has crashed inside of some of the functions. The +# design decision is to use pyOpenSSL whenever possible as it seems more +# stable, and only use M2Crypto for those functions that are not possible +# in pyOpenSSL. +# +# This module exports two classes: Keypair and Certificate. +## +# + +import functools +import os +import tempfile +import base64 +from tempfile import mkstemp + +from OpenSSL import crypto +import M2Crypto +from M2Crypto import X509 + +from sfa.util.faults import CertExpired, CertMissingParent, CertNotSignedByParent +from sfa.util.sfalogging import logger + +glo_passphrase_callback = None + +## +# A global callback msy be implemented for requesting passphrases from the +# user. The function will be called with three arguments: +# +# keypair_obj: the keypair object that is calling the passphrase +# string: the string containing the private key that's being loaded +# x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto +# +# The callback should return a string containing the passphrase. + +def set_passphrase_callback(callback_func): + global glo_passphrase_callback + + glo_passphrase_callback = callback_func + +## +# Sets a fixed passphrase. + +def set_passphrase(passphrase): + set_passphrase_callback( lambda k,s,x: passphrase ) + +## +# Check to see if a passphrase works for a particular private key string. +# Intended to be used by passphrase callbacks for input validation. + +def test_passphrase(string, passphrase): + try: + crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase)) + return True + except: + return False + +def convert_public_key(key): + keyconvert_path = "/usr/bin/keyconvert.py" + if not os.path.isfile(keyconvert_path): + raise IOError, "Could not find keyconvert in %s" % keyconvert_path + + # we can only convert rsa keys + if "ssh-dss" in key: + return None + + (ssh_f, ssh_fn) = tempfile.mkstemp() + ssl_fn = tempfile.mktemp() + os.write(ssh_f, key) + os.close(ssh_f) + + cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn + os.system(cmd) + + # this check leaves the temporary file containing the public key so + # that it can be expected to see why it failed. + # TODO: for production, cleanup the temporary files + if not os.path.exists(ssl_fn): + return None + + k = Keypair() + try: + k.load_pubkey_from_file(ssl_fn) + except: + logger.log_exc("convert_public_key caught exception") + k = None + + # remove the temporary files + os.remove(ssh_fn) + os.remove(ssl_fn) + + return k + +## +# Public-private key pairs are implemented by the Keypair class. +# A Keypair object may represent both a public and private key pair, or it +# may represent only a public key (this usage is consistent with OpenSSL). + +class Keypair: + key = None # public/private keypair + m2key = None # public key (m2crypto format) + + ## + # Creates a Keypair object + # @param create If create==True, creates a new public/private key and + # stores it in the object + # @param string If string!=None, load the keypair from the string (PEM) + # @param filename If filename!=None, load the keypair from the file + + def __init__(self, create=False, string=None, filename=None): + if create: + self.create() + if string: + self.load_from_string(string) + if filename: + self.load_from_file(filename) + + ## + # Create a RSA public/private key pair and store it inside the keypair object + + def create(self): + self.key = crypto.PKey() + self.key.generate_key(crypto.TYPE_RSA, 1024) + + ## + # Save the private key to a file + # @param filename name of file to store the keypair in + + def save_to_file(self, filename): + open(filename, 'w').write(self.as_pem()) + self.filename=filename + + ## + # Load the private key from a file. Implicity the private key includes the public key. + + def load_from_file(self, filename): + self.filename=filename + buffer = open(filename, 'r').read() + self.load_from_string(buffer) + + ## + # Load the private key from a string. Implicitly the private key includes the public key. + + def load_from_string(self, string): + if glo_passphrase_callback: + self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) ) + self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) ) + else: + self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string) + self.m2key = M2Crypto.EVP.load_key_string(string) + + ## + # Load the public key from a string. No private key is loaded. + + def load_pubkey_from_file(self, filename): + # load the m2 public key + m2rsakey = M2Crypto.RSA.load_pub_key(filename) + self.m2key = M2Crypto.EVP.PKey() + self.m2key.assign_rsa(m2rsakey) + + # create an m2 x509 cert + m2name = M2Crypto.X509.X509_Name() + m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0) + m2x509 = M2Crypto.X509.X509() + m2x509.set_pubkey(self.m2key) + m2x509.set_serial_number(0) + m2x509.set_issuer_name(m2name) + m2x509.set_subject_name(m2name) + ASN1 = M2Crypto.ASN1.ASN1_UTCTIME() + ASN1.set_time(500) + m2x509.set_not_before(ASN1) + m2x509.set_not_after(ASN1) + # x509v3 so it can have extensions + # prob not necc since this cert itself is junk but still... + m2x509.set_version(2) + junk_key = Keypair(create=True) + m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1") + + # convert the m2 x509 cert to a pyopenssl x509 + m2pem = m2x509.as_pem() + pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem) + + # get the pyopenssl pkey from the pyopenssl x509 + self.key = pyx509.get_pubkey() + self.filename=filename + + ## + # Load the public key from a string. No private key is loaded. + + def load_pubkey_from_string(self, string): + (f, fn) = tempfile.mkstemp() + os.write(f, string) + os.close(f) + self.load_pubkey_from_file(fn) + os.remove(fn) + + ## + # Return the private key in PEM format. + + def as_pem(self): + return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key) + + ## + # Return an M2Crypto key object + + def get_m2_pkey(self): + if not self.m2key: + self.m2key = M2Crypto.EVP.load_key_string(self.as_pem()) + return self.m2key + + ## + # Returns a string containing the public key represented by this object. + + def get_pubkey_string(self): + m2pkey = self.get_m2_pkey() + return base64.b64encode(m2pkey.as_der()) + + ## + # Return an OpenSSL pkey object + + def get_openssl_pkey(self): + return self.key + + ## + # Given another Keypair object, return TRUE if the two keys are the same. + + def is_same(self, pkey): + return self.as_pem() == pkey.as_pem() + + def sign_string(self, data): + k = self.get_m2_pkey() + k.sign_init() + k.sign_update(data) + return base64.b64encode(k.sign_final()) + + def verify_string(self, data, sig): + k = self.get_m2_pkey() + k.verify_init() + k.verify_update(data) + return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey) + + def compute_hash(self, value): + return self.sign_string(str(value)) + + # only informative + def get_filename(self): + return getattr(self,'filename',None) + + def dump (self, *args, **kwargs): + print self.dump_string(*args, **kwargs) + + def dump_string (self): + result="" + result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string() + filename=self.get_filename() + if filename: result += "Filename %s\n"%filename + return result + +## +# The certificate class implements a general purpose X509 certificate, making +# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds +# several addition features, such as the ability to maintain a chain of +# parent certificates, and storage of application-specific data. +# +# Certificates include the ability to maintain a chain of parents. Each +# certificate includes a pointer to it's parent certificate. When loaded +# from a file or a string, the parent chain will be automatically loaded. +# When saving a certificate to a file or a string, the caller can choose +# whether to save the parent certificates as well. + +class Certificate: + digest = "md5" + + cert = None + issuerKey = None + issuerSubject = None + parent = None + isCA = None # will be a boolean once set + + separator="-----parent-----" + + ## + # Create a certificate object. + # + # @param lifeDays life of cert in days - default is 1825==5 years + # @param create If create==True, then also create a blank X509 certificate. + # @param subject If subject!=None, then create a blank certificate and set + # it's subject name. + # @param string If string!=None, load the certficate from the string. + # @param filename If filename!=None, load the certficiate from the file. + # @param isCA If !=None, set whether this cert is for a CA + + def __init__(self, lifeDays=1825, create=False, subject=None, string=None, filename=None, isCA=None): + self.data = {} + if create or subject: + self.create(lifeDays) + if subject: + self.set_subject(subject) + if string: + self.load_from_string(string) + if filename: + self.load_from_file(filename) + + # Set the CA bit if a value was supplied + if isCA != None: + self.set_is_ca(isCA) + + # Create a blank X509 certificate and store it in this object. + + def create(self, lifeDays=1825): + self.cert = crypto.X509() + # FIXME: Use different serial #s + self.cert.set_serial_number(3) + self.cert.gmtime_adj_notBefore(0) # 0 means now + self.cert.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default + self.cert.set_version(2) # x509v3 so it can have extensions + + + ## + # Given a pyOpenSSL X509 object, store that object inside of this + # certificate object. + + def load_from_pyopenssl_x509(self, x509): + self.cert = x509 + + ## + # Load the certificate from a string + + def load_from_string(self, string): + # if it is a chain of multiple certs, then split off the first one and + # load it (support for the ---parent--- tag as well as normal chained certs) + + string = string.strip() + + # If it's not in proper PEM format, wrap it + if string.count('-----BEGIN CERTIFICATE') == 0: + string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string + + # If there is a PEM cert in there, but there is some other text first + # such as the text of the certificate, skip the text + beg = string.find('-----BEGIN CERTIFICATE') + if beg > 0: + # skipping over non cert beginning + string = string[beg:] + + parts = [] + + if string.count('-----BEGIN CERTIFICATE-----') > 1 and \ + string.count(Certificate.separator) == 0: + parts = string.split('-----END CERTIFICATE-----',1) + parts[0] += '-----END CERTIFICATE-----' + else: + parts = string.split(Certificate.separator, 1) + + self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0]) + + # if there are more certs, then create a parent and let the parent load + # itself from the remainder of the string + if len(parts) > 1 and parts[1] != '': + self.parent = self.__class__() + self.parent.load_from_string(parts[1]) + + ## + # Load the certificate from a file + + def load_from_file(self, filename): + file = open(filename) + string = file.read() + self.load_from_string(string) + self.filename=filename + + ## + # Save the certificate to a string. + # + # @param save_parents If save_parents==True, then also save the parent certificates. + + def save_to_string(self, save_parents=True): + string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert) + if save_parents and self.parent: + string = string + self.parent.save_to_string(save_parents) + return string + + ## + # Save the certificate to a file. + # @param save_parents If save_parents==True, then also save the parent certificates. + + def save_to_file(self, filename, save_parents=True, filep=None): + string = self.save_to_string(save_parents=save_parents) + if filep: + f = filep + else: + f = open(filename, 'w') + f.write(string) + f.close() + self.filename=filename + + ## + # Save the certificate to a random file in /tmp/ + # @param save_parents If save_parents==True, then also save the parent certificates. + def save_to_random_tmp_file(self, save_parents=True): + fp, filename = mkstemp(suffix='cert', text=True) + fp = os.fdopen(fp, "w") + self.save_to_file(filename, save_parents=True, filep=fp) + return filename + + ## + # Sets the issuer private key and name + # @param key Keypair object containing the private key of the issuer + # @param subject String containing the name of the issuer + # @param cert (optional) Certificate object containing the name of the issuer + + def set_issuer(self, key, subject=None, cert=None): + self.issuerKey = key + if subject: + # it's a mistake to use subject and cert params at the same time + assert(not cert) + if isinstance(subject, dict) or isinstance(subject, str): + req = crypto.X509Req() + reqSubject = req.get_subject() + if (isinstance(subject, dict)): + for key in reqSubject.keys(): + setattr(reqSubject, key, subject[key]) + else: + setattr(reqSubject, "CN", subject) + subject = reqSubject + # subject is not valid once req is out of scope, so save req + self.issuerReq = req + if cert: + # if a cert was supplied, then get the subject from the cert + subject = cert.cert.get_subject() + assert(subject) + self.issuerSubject = subject + + ## + # Get the issuer name + + def get_issuer(self, which="CN"): + x = self.cert.get_issuer() + return getattr(x, which) + + ## + # Set the subject name of the certificate + + def set_subject(self, name): + req = crypto.X509Req() + subj = req.get_subject() + if (isinstance(name, dict)): + for key in name.keys(): + setattr(subj, key, name[key]) + else: + setattr(subj, "CN", name) + self.cert.set_subject(subj) + + ## + # Get the subject name of the certificate + + def get_subject(self, which="CN"): + x = self.cert.get_subject() + return getattr(x, which) + + ## + # Get a pretty-print subject name of the certificate + + def get_printable_subject(self): + x = self.cert.get_subject() + return "[ OU: %s, CN: %s, SubjectAltName: %s ]" % (getattr(x, "OU"), getattr(x, "CN"), self.get_data()) + + ## + # Get the public key of the certificate. + # + # @param key Keypair object containing the public key + + def set_pubkey(self, key): + assert(isinstance(key, Keypair)) + self.cert.set_pubkey(key.get_openssl_pkey()) + + ## + # Get the public key of the certificate. + # It is returned in the form of a Keypair object. + + def get_pubkey(self): + m2x509 = X509.load_cert_string(self.save_to_string()) + pkey = Keypair() + pkey.key = self.cert.get_pubkey() + pkey.m2key = m2x509.get_pubkey() + return pkey + + def set_intermediate_ca(self, val): + return self.set_is_ca(val) + + # Set whether this cert is for a CA. All signers and only signers should be CAs. + # The local member starts unset, letting us check that you only set it once + # @param val Boolean indicating whether this cert is for a CA + def set_is_ca(self, val): + if val is None: + return + + if self.isCA != None: + # Can't double set properties + raise Exception, "Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val) + + self.isCA = val + if val: + self.add_extension('basicConstraints', 1, 'CA:TRUE') + else: + self.add_extension('basicConstraints', 1, 'CA:FALSE') + + + + ## + # Add an X509 extension to the certificate. Add_extension can only be called + # once for a particular extension name, due to limitations in the underlying + # library. + # + # @param name string containing name of extension + # @param value string containing value of the extension + + def add_extension(self, name, critical, value): + oldExtVal = None + try: + oldExtVal = self.get_extension(name) + except: + # M2Crypto LookupError when the extension isn't there (yet) + pass + + # This code limits you from adding the extension with the same value + # The method comment says you shouldn't do this with the same name + # But actually it (m2crypto) appears to allow you to do this. + if oldExtVal and oldExtVal == value: + # don't add this extension again + # just do nothing as here + return + # FIXME: What if they are trying to set with a different value? + # Is this ever OK? Or should we raise an exception? +# elif oldExtVal: +# raise "Cannot add extension %s which had val %s with new val %s" % (name, oldExtVal, value) + + ext = crypto.X509Extension (name, critical, value) + self.cert.add_extensions([ext]) + + ## + # Get an X509 extension from the certificate + + def get_extension(self, name): + + # pyOpenSSL does not have a way to get extensions + m2x509 = X509.load_cert_string(self.save_to_string()) + value = m2x509.get_ext(name).get_value() + + return value + + ## + # Set_data is a wrapper around add_extension. It stores the parameter str in + # the X509 subject_alt_name extension. Set_data can only be called once, due + # to limitations in the underlying library. + + def set_data(self, str, field='subjectAltName'): + # pyOpenSSL only allows us to add extensions, so if we try to set the + # same extension more than once, it will not work + if self.data.has_key(field): + raise "Cannot set ", field, " more than once" + self.data[field] = str + self.add_extension(field, 0, str) + + ## + # Return the data string that was previously set with set_data + + def get_data(self, field='subjectAltName'): + if self.data.has_key(field): + return self.data[field] + + try: + uri = self.get_extension(field) + self.data[field] = uri + except LookupError: + return None + + return self.data[field] + + ## + # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer(). + + def sign(self): + logger.debug('certificate.sign') + assert self.cert != None + assert self.issuerSubject != None + assert self.issuerKey != None + self.cert.set_issuer(self.issuerSubject) + self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest) + + ## + # Verify the authenticity of a certificate. + # @param pkey is a Keypair object representing a public key. If Pkey + # did not sign the certificate, then an exception will be thrown. + + def verify(self, pkey): + # pyOpenSSL does not have a way to verify signatures + m2x509 = X509.load_cert_string(self.save_to_string()) + m2pkey = pkey.get_m2_pkey() + # verify it + return m2x509.verify(m2pkey) + + # XXX alternatively, if openssl has been patched, do the much simpler: + # try: + # self.cert.verify(pkey.get_openssl_key()) + # return 1 + # except: + # return 0 + + ## + # Return True if pkey is identical to the public key that is contained in the certificate. + # @param pkey Keypair object + + def is_pubkey(self, pkey): + return self.get_pubkey().is_same(pkey) + + ## + # Given a certificate cert, verify that this certificate was signed by the + # public key contained in cert. Throw an exception otherwise. + # + # @param cert certificate object + + def is_signed_by_cert(self, cert): + k = cert.get_pubkey() + result = self.verify(k) + return result + + ## + # Set the parent certficiate. + # + # @param p certificate object. + + def set_parent(self, p): + self.parent = p + + ## + # Return the certificate object of the parent of this certificate. + + def get_parent(self): + return self.parent + + ## + # Verification examines a chain of certificates to ensure that each parent + # signs the child, and that some certificate in the chain is signed by a + # trusted certificate. + # + # Verification is a basic recursion:
+    #     if this_certificate was signed by trusted_certs:
+    #         return
+    #     else
+    #         return verify_chain(parent, trusted_certs)
+    # 
+ # + # At each recursion, the parent is tested to ensure that it did sign the + # child. If a parent did not sign a child, then an exception is thrown. If + # the bottom of the recursion is reached and the certificate does not match + # a trusted root, then an exception is thrown. + # Also require that parents are CAs. + # + # @param Trusted_certs is a list of certificates that are trusted. + # + + def verify_chain(self, trusted_certs = None): + # Verify a chain of certificates. Each certificate must be signed by + # the public key contained in it's parent. The chain is recursed + # until a certificate is found that is signed by a trusted root. + + # verify expiration time + if self.cert.has_expired(): + logger.debug("verify_chain: NO, Certificate %s has expired" % self.get_printable_subject()) + raise CertExpired(self.get_printable_subject(), "client cert") + + # if this cert is signed by a trusted_cert, then we are set + for trusted_cert in trusted_certs: + if self.is_signed_by_cert(trusted_cert): + # verify expiration of trusted_cert ? + if not trusted_cert.cert.has_expired(): + logger.debug("verify_chain: YES. Cert %s signed by trusted cert %s"%( + self.get_printable_subject(), trusted_cert.get_printable_subject())) + return trusted_cert + else: + logger.debug("verify_chain: NO. Cert %s is signed by trusted_cert %s, but that signer is expired..."%( + self.get_printable_subject(),trusted_cert.get_printable_subject())) + raise CertExpired(self.get_printable_subject()," signer trusted_cert %s"%trusted_cert.get_printable_subject()) + + # if there is no parent, then no way to verify the chain + if not self.parent: + logger.debug("verify_chain: NO. %s has no parent and issuer %s is not in %d trusted roots"%(self.get_printable_subject(), self.get_issuer(), len(trusted_certs))) + raise CertMissingParent(self.get_printable_subject() + ": Issuer %s not trusted by any of %d trusted roots, and cert has no parent." % (self.get_issuer(), len(trusted_certs))) + + # if it wasn't signed by the parent... + if not self.is_signed_by_cert(self.parent): + logger.debug("verify_chain: NO. %s is not signed by parent %s, but by %s"%\ + (self.get_printable_subject(), + self.parent.get_printable_subject(), + self.get_issuer())) + raise CertNotSignedByParent("%s: Parent %s, issuer %s"\ + % (self.get_printable_subject(), + self.parent.get_printable_subject(), + self.get_issuer())) + + # Confirm that the parent is a CA. Only CAs can be trusted as + # signers. + # Note that trusted roots are not parents, so don't need to be + # CAs. + # Ugly - cert objects aren't parsed so we need to read the + # extension and hope there are no other basicConstraints + if not self.parent.isCA and not (self.parent.get_extension('basicConstraints') == 'CA:TRUE'): + logger.warn("verify_chain: cert %s's parent %s is not a CA" % \ + (self.get_printable_subject(), self.parent.get_printable_subject())) + raise CertNotSignedByParent("%s: Parent %s not a CA" % (self.get_printable_subject(), + self.parent.get_printable_subject())) + + # if the parent isn't verified... + logger.debug("verify_chain: .. %s, -> verifying parent %s"%\ + (self.get_printable_subject(),self.parent.get_printable_subject())) + self.parent.verify_chain(trusted_certs) + + return + + ### more introspection + def get_extensions(self): + # pyOpenSSL does not have a way to get extensions + triples=[] + m2x509 = X509.load_cert_string(self.save_to_string()) + nb_extensions=m2x509.get_ext_count() + logger.debug("X509 had %d extensions"%nb_extensions) + for i in range(nb_extensions): + ext=m2x509.get_ext_at(i) + triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) ) + return triples + + def get_data_names(self): + return self.data.keys() + + def get_all_datas (self): + triples=self.get_extensions() + for name in self.get_data_names(): + triples.append( (name,self.get_data(name),'data',) ) + return triples + + # only informative + def get_filename(self): + return getattr(self,'filename',None) + + def dump (self, *args, **kwargs): + print self.dump_string(*args, **kwargs) + + def dump_string (self,show_extensions=False): + result = "" + result += "CERTIFICATE for %s\n"%self.get_printable_subject() + result += "Issued by %s\n"%self.get_issuer() + filename=self.get_filename() + if filename: result += "Filename %s\n"%filename + if show_extensions: + all_datas=self.get_all_datas() + result += " has %d extensions/data attached"%len(all_datas) + for (n,v,c) in all_datas: + if c=='data': + result += " data: %s=%s\n"%(n,v) + else: + result += " ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v) + return result diff --git a/sfa/trust/credential.py b/sfa/trust/credential.py index a18019d8..8fd11e8e 100644 --- a/sfa/trust/credential.py +++ b/sfa/trust/credential.py @@ -1,1063 +1,1064 @@ -#---------------------------------------------------------------------- -# Copyright (c) 2008 Board of Trustees, Princeton University -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and/or hardware specification (the "Work") to -# deal in the Work without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Work, and to permit persons to whom the Work -# is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Work. -# -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS -# IN THE WORK. -#---------------------------------------------------------------------- -## -# Implements SFA Credentials -# -# Credentials are signed XML files that assign a subject gid privileges to an object gid -## - -import os -from types import StringTypes -import datetime -from StringIO import StringIO -from tempfile import mkstemp -from xml.dom.minidom import Document, parseString - -HAVELXML = False -try: - from lxml import etree - HAVELXML = True -except: - pass - -from sfa.util.faults import * -from sfa.util.sfalogging import logger -from sfa.util.sfatime import utcparse -from sfa.trust.certificate import Keypair -from sfa.trust.credential_legacy import CredentialLegacy -from sfa.trust.rights import Right, Rights, determine_rights -from sfa.trust.gid import GID -from sfa.util.xrn import urn_to_hrn, hrn_authfor_hrn - -# 2 weeks, in seconds -DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14 - - -# TODO: -# . make privs match between PG and PL -# . Need to add support for other types of credentials, e.g. tickets -# . add namespaces to signed-credential element? - -signature_template = \ -''' - - - - - - - - - - - - - - - - - - - - - - -''' - -# PG formats the template (whitespace) slightly differently. -# Note that they don't include the xmlns in the template, but add it later. -# Otherwise the two are equivalent. -#signature_template_as_in_pg = \ -#''' -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -#''' - -## -# Convert a string into a bool -# used to convert an xsd:boolean to a Python boolean -def str2bool(str): - if str.lower() in ['true','1']: - return True - return False - - -## -# Utility function to get the text of an XML element - -def getTextNode(element, subele): - sub = element.getElementsByTagName(subele)[0] - if len(sub.childNodes) > 0: - return sub.childNodes[0].nodeValue - else: - return None - -## -# Utility function to set the text of an XML element -# It creates the element, adds the text to it, -# and then appends it to the parent. - -def append_sub(doc, parent, element, text): - ele = doc.createElement(element) - ele.appendChild(doc.createTextNode(text)) - parent.appendChild(ele) - -## -# Signature contains information about an xmlsec1 signature -# for a signed-credential -# - -class Signature(object): - - def __init__(self, string=None): - self.refid = None - self.issuer_gid = None - self.xml = None - if string: - self.xml = string - self.decode() - - - def get_refid(self): - if not self.refid: - self.decode() - return self.refid - - def get_xml(self): - if not self.xml: - self.encode() - return self.xml - - def set_refid(self, id): - self.refid = id - - def get_issuer_gid(self): - if not self.gid: - self.decode() - return self.gid - - def set_issuer_gid(self, gid): - self.gid = gid - - def decode(self): - try: - doc = parseString(self.xml) - except ExpatError,e: - logger.log_exc ("Failed to parse credential, %s"%self.xml) - raise - sig = doc.getElementsByTagName("Signature")[0] - self.set_refid(sig.getAttribute("xml:id").strip("Sig_")) - keyinfo = sig.getElementsByTagName("X509Data")[0] - szgid = getTextNode(keyinfo, "X509Certificate") - szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid - self.set_issuer_gid(GID(string=szgid)) - - def encode(self): - self.xml = signature_template % (self.get_refid(), self.get_refid()) - - -## -# A credential provides a caller gid with privileges to an object gid. -# A signed credential is signed by the object's authority. -# -# Credentials are encoded in one of two ways. The legacy style places -# it in the subjectAltName of an X509 certificate. The new credentials -# are placed in signed XML. -# -# WARNING: -# In general, a signed credential obtained externally should -# not be changed else the signature is no longer valid. So, once -# you have loaded an existing signed credential, do not call encode() or sign() on it. - -def filter_creds_by_caller(creds, caller_hrn_list): - """ - Returns a list of creds who's gid caller matches the - specified caller hrn - """ - if not isinstance(creds, list): creds = [creds] - if not isinstance(caller_hrn_list, list): - caller_hrn_list = [caller_hrn_list] - caller_creds = [] - for cred in creds: - try: - tmp_cred = Credential(string=cred) - if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list: - caller_creds.append(cred) - except: pass - return caller_creds - -class Credential(object): - - ## - # Create a Credential object - # - # @param create If true, create a blank x509 certificate - # @param subject If subject!=None, create an x509 cert with the subject name - # @param string If string!=None, load the credential from the string - # @param filename If filename!=None, load the credential from the file - # FIXME: create and subject are ignored! - def __init__(self, create=False, subject=None, string=None, filename=None): - self.gidCaller = None - self.gidObject = None - self.expiration = None - self.privileges = None - self.issuer_privkey = None - self.issuer_gid = None - self.issuer_pubkey = None - self.parent = None - self.signature = None - self.xml = None - self.refid = None - self.legacy = None - - # Check if this is a legacy credential, translate it if so - if string or filename: - if string: - str = string - elif filename: - str = file(filename).read() - - if str.strip().startswith("-----"): - self.legacy = CredentialLegacy(False,string=str) - self.translate_legacy(str) - else: - self.xml = str - self.decode() - - # Find an xmlsec1 path - self.xmlsec_path = '' - paths = ['/usr/bin','/usr/local/bin','/bin','/opt/bin','/opt/local/bin'] - for path in paths: - if os.path.isfile(path + '/' + 'xmlsec1'): - self.xmlsec_path = path + '/' + 'xmlsec1' - break - - def get_subject(self): - if not self.gidObject: - self.decode() - return self.gidObject.get_printable_subject() - - def get_summary_tostring(self): - if not self.gidObject: - self.decode() - obj = self.gidObject.get_printable_subject() - caller = self.gidCaller.get_printable_subject() - exp = self.get_expiration() - # Summarize the rights too? The issuer? - return "[ Grant %s rights on %s until %s ]" % (caller, obj, exp) - - def get_signature(self): - if not self.signature: - self.decode() - return self.signature - - def set_signature(self, sig): - self.signature = sig - - - ## - # Translate a legacy credential into a new one - # - # @param String of the legacy credential - - def translate_legacy(self, str): - legacy = CredentialLegacy(False,string=str) - self.gidCaller = legacy.get_gid_caller() - self.gidObject = legacy.get_gid_object() - lifetime = legacy.get_lifetime() - if not lifetime: - self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) - else: - self.set_expiration(int(lifetime)) - self.lifeTime = legacy.get_lifetime() - self.set_privileges(legacy.get_privileges()) - self.get_privileges().delegate_all_privileges(legacy.get_delegate()) - - ## - # Need the issuer's private key and name - # @param key Keypair object containing the private key of the issuer - # @param gid GID of the issuing authority - - def set_issuer_keys(self, privkey, gid): - self.issuer_privkey = privkey - self.issuer_gid = gid - - - ## - # Set this credential's parent - def set_parent(self, cred): - self.parent = cred - self.updateRefID() - - ## - # set the GID of the caller - # - # @param gid GID object of the caller - - def set_gid_caller(self, gid): - self.gidCaller = gid - # gid origin caller is the caller's gid by default - self.gidOriginCaller = gid - - ## - # get the GID of the object - - def get_gid_caller(self): - if not self.gidCaller: - self.decode() - return self.gidCaller - - ## - # set the GID of the object - # - # @param gid GID object of the object - - def set_gid_object(self, gid): - self.gidObject = gid - - ## - # get the GID of the object - - def get_gid_object(self): - if not self.gidObject: - self.decode() - return self.gidObject - - - - ## - # Expiration: an absolute UTC time of expiration (as either an int or string or datetime) - # - def set_expiration(self, expiration): - if isinstance(expiration, (int, float)): - self.expiration = datetime.datetime.fromtimestamp(expiration) - elif isinstance (expiration, datetime.datetime): - self.expiration = expiration - elif isinstance (expiration, StringTypes): - self.expiration = utcparse (expiration) - else: - logger.error ("unexpected input type in Credential.set_expiration") - - - ## - # get the lifetime of the credential (always in datetime format) - - def get_expiration(self): - if not self.expiration: - self.decode() - # at this point self.expiration is normalized as a datetime - DON'T call utcparse again - return self.expiration - - ## - # For legacy sake - def get_lifetime(self): - return self.get_expiration() - - ## - # set the privileges - # - # @param privs either a comma-separated list of privileges of a Rights object - - def set_privileges(self, privs): - if isinstance(privs, str): - self.privileges = Rights(string = privs) - else: - self.privileges = privs - - - ## - # return the privileges as a Rights object - - def get_privileges(self): - if not self.privileges: - self.decode() - return self.privileges - - ## - # determine whether the credential allows a particular operation to be - # performed - # - # @param op_name string specifying name of operation ("lookup", "update", etc) - - def can_perform(self, op_name): - rights = self.get_privileges() - - if not rights: - return False - - return rights.can_perform(op_name) - - - ## - # Encode the attributes of the credential into an XML string - # This should be done immediately before signing the credential. - # WARNING: - # In general, a signed credential obtained externally should - # not be changed else the signature is no longer valid. So, once - # you have loaded an existing signed credential, do not call encode() or sign() on it. - - def encode(self): - # Create the XML document - doc = Document() - signed_cred = doc.createElement("signed-credential") - -# Declare namespaces -# Note that credential/policy.xsd are really the PG schemas -# in a PL namespace. -# Note that delegation of credentials between the 2 only really works -# cause those schemas are identical. -# Also note these PG schemas talk about PG tickets and CM policies. - signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") - signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.planet-lab.org/resources/sfa/credential.xsd") - signed_cred.setAttribute("xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd") - -# PG says for those last 2: -# signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd") -# signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd") - - doc.appendChild(signed_cred) - - # Fill in the bit - cred = doc.createElement("credential") - cred.setAttribute("xml:id", self.get_refid()) - signed_cred.appendChild(cred) - append_sub(doc, cred, "type", "privilege") - append_sub(doc, cred, "serial", "8") - append_sub(doc, cred, "owner_gid", self.gidCaller.save_to_string()) - append_sub(doc, cred, "owner_urn", self.gidCaller.get_urn()) - append_sub(doc, cred, "target_gid", self.gidObject.save_to_string()) - append_sub(doc, cred, "target_urn", self.gidObject.get_urn()) - append_sub(doc, cred, "uuid", "") - if not self.expiration: - self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) - self.expiration = self.expiration.replace(microsecond=0) - append_sub(doc, cred, "expires", self.expiration.isoformat()) - privileges = doc.createElement("privileges") - cred.appendChild(privileges) - - if self.privileges: - rights = self.get_privileges() - for right in rights.rights: - priv = doc.createElement("privilege") - append_sub(doc, priv, "name", right.kind) - append_sub(doc, priv, "can_delegate", str(right.delegate).lower()) - privileges.appendChild(priv) - - # Add the parent credential if it exists - if self.parent: - sdoc = parseString(self.parent.get_xml()) - # If the root node is a signed-credential (it should be), then - # get all its attributes and attach those to our signed_cred - # node. - # Specifically, PG and PLadd attributes for namespaces (which is reasonable), - # and we need to include those again here or else their signature - # no longer matches on the credential. - # We expect three of these, but here we copy them all: -# signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") -# and from PG (PL is equivalent, as shown above): -# signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd") -# signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd") - - # HOWEVER! - # PL now also declares these, with different URLs, so - # the code notices those attributes already existed with - # different values, and complains. - # This happens regularly on delegation now that PG and - # PL both declare the namespace with different URLs. - # If the content ever differs this is a problem, - # but for now it works - different URLs (values in the attributes) - # but the same actual schema, so using the PG schema - # on delegated-to-PL credentials works fine. - - # Note: you could also not copy attributes - # which already exist. It appears that both PG and PL - # will actually validate a slicecred with a parent - # signed using PG namespaces and a child signed with PL - # namespaces over the whole thing. But I don't know - # if that is a bug in xmlsec1, an accident since - # the contents of the schemas are the same, - # or something else, but it seems odd. And this works. - parentRoot = sdoc.documentElement - if parentRoot.tagName == "signed-credential" and parentRoot.hasAttributes(): - for attrIx in range(0, parentRoot.attributes.length): - attr = parentRoot.attributes.item(attrIx) - # returns the old attribute of same name that was - # on the credential - # Below throws InUse exception if we forgot to clone the attribute first - oldAttr = signed_cred.setAttributeNode(attr.cloneNode(True)) - if oldAttr and oldAttr.value != attr.value: - msg = "Delegating cred from owner %s to %s over %s replaced attribute %s value '%s' with '%s'" % (self.parent.gidCaller.get_urn(), self.gidCaller.get_urn(), self.gidObject.get_urn(), oldAttr.name, oldAttr.value, attr.value) - logger.warn(msg) - #raise CredentialNotVerifiable("Can't encode new valid delegated credential: %s" % msg) - - p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True) - p = doc.createElement("parent") - p.appendChild(p_cred) - cred.appendChild(p) - # done handling parent credential - - # Create the tag - signatures = doc.createElement("signatures") - signed_cred.appendChild(signatures) - - # Add any parent signatures - if self.parent: - for cur_cred in self.get_credential_list()[1:]: - sdoc = parseString(cur_cred.get_signature().get_xml()) - ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True) - signatures.appendChild(ele) - - # Get the finished product - self.xml = doc.toxml() - - - def save_to_random_tmp_file(self): - fp, filename = mkstemp(suffix='cred', text=True) - fp = os.fdopen(fp, "w") - self.save_to_file(filename, save_parents=True, filep=fp) - return filename - - def save_to_file(self, filename, save_parents=True, filep=None): - if not self.xml: - self.encode() - if filep: - f = filep - else: - f = open(filename, "w") - f.write(self.xml) - f.close() - - def save_to_string(self, save_parents=True): - if not self.xml: - self.encode() - return self.xml - - def get_refid(self): - if not self.refid: - self.refid = 'ref0' - return self.refid - - def set_refid(self, rid): - self.refid = rid - - ## - # Figure out what refids exist, and update this credential's id - # so that it doesn't clobber the others. Returns the refids of - # the parents. - - def updateRefID(self): - if not self.parent: - self.set_refid('ref0') - return [] - - refs = [] - - next_cred = self.parent - while next_cred: - refs.append(next_cred.get_refid()) - if next_cred.parent: - next_cred = next_cred.parent - else: - next_cred = None - - - # Find a unique refid for this credential - rid = self.get_refid() - while rid in refs: - val = int(rid[3:]) - rid = "ref%d" % (val + 1) - - # Set the new refid - self.set_refid(rid) - - # Return the set of parent credential ref ids - return refs - - def get_xml(self): - if not self.xml: - self.encode() - return self.xml - - ## - # Sign the XML file created by encode() - # - # WARNING: - # In general, a signed credential obtained externally should - # not be changed else the signature is no longer valid. So, once - # you have loaded an existing signed credential, do not call encode() or sign() on it. - - def sign(self): - if not self.issuer_privkey or not self.issuer_gid: - return - doc = parseString(self.get_xml()) - sigs = doc.getElementsByTagName("signatures")[0] - - # Create the signature template to be signed - signature = Signature() - signature.set_refid(self.get_refid()) - sdoc = parseString(signature.get_xml()) - sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True) - sigs.appendChild(sig_ele) - - self.xml = doc.toxml() - - - # Split the issuer GID into multiple certificates if it's a chain - chain = GID(filename=self.issuer_gid) - gid_files = [] - while chain: - gid_files.append(chain.save_to_random_tmp_file(False)) - if chain.get_parent(): - chain = chain.get_parent() - else: - chain = None - - - # Call out to xmlsec1 to sign it - ref = 'Sig_%s' % self.get_refid() - filename = self.save_to_random_tmp_file() - signed = os.popen('%s --sign --node-id "%s" --privkey-pem %s,%s %s' \ - % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)).read() - os.remove(filename) - - for gid_file in gid_files: - os.remove(gid_file) - - self.xml = signed - - # This is no longer a legacy credential - if self.legacy: - self.legacy = None - - # Update signatures - self.decode() - - - ## - # Retrieve the attributes of the credential from the XML. - # This is automatically called by the various get_* methods of - # this class and should not need to be called explicitly. - - def decode(self): - if not self.xml: - return - doc = parseString(self.xml) - sigs = [] - signed_cred = doc.getElementsByTagName("signed-credential") - - # Is this a signed-cred or just a cred? - if len(signed_cred) > 0: - creds = signed_cred[0].getElementsByTagName("credential") - signatures = signed_cred[0].getElementsByTagName("signatures") - if len(signatures) > 0: - sigs = signatures[0].getElementsByTagName("Signature") - else: - creds = doc.getElementsByTagName("credential") - - if creds is None or len(creds) == 0: - # malformed cred file - raise CredentialNotVerifiable("Malformed XML: No credential tag found") - - # Just take the first cred if there are more than one - cred = creds[0] - - self.set_refid(cred.getAttribute("xml:id")) - self.set_expiration(utcparse(getTextNode(cred, "expires"))) - self.gidCaller = GID(string=getTextNode(cred, "owner_gid")) - self.gidObject = GID(string=getTextNode(cred, "target_gid")) - - - # Process privileges - privs = cred.getElementsByTagName("privileges")[0] - rlist = Rights() - for priv in privs.getElementsByTagName("privilege"): - kind = getTextNode(priv, "name") - deleg = str2bool(getTextNode(priv, "can_delegate")) - if kind == '*': - # Convert * into the default privileges for the credential's type - # Each inherits the delegatability from the * above - _ , type = urn_to_hrn(self.gidObject.get_urn()) - rl = determine_rights(type, self.gidObject.get_urn()) - for r in rl.rights: - r.delegate = deleg - rlist.add(r) - else: - rlist.add(Right(kind.strip(), deleg)) - self.set_privileges(rlist) - - - # Is there a parent? - parent = cred.getElementsByTagName("parent") - if len(parent) > 0: - parent_doc = parent[0].getElementsByTagName("credential")[0] - parent_xml = parent_doc.toxml() - self.parent = Credential(string=parent_xml) - self.updateRefID() - - # Assign the signatures to the credentials - for sig in sigs: - Sig = Signature(string=sig.toxml()) - - for cur_cred in self.get_credential_list(): - if cur_cred.get_refid() == Sig.get_refid(): - cur_cred.set_signature(Sig) - - - ## - # Verify - # trusted_certs: A list of trusted GID filenames (not GID objects!) - # Chaining is not supported within the GIDs by xmlsec1. - # - # trusted_certs_required: Should usually be true. Set False means an - # empty list of trusted_certs would still let this method pass. - # It just skips xmlsec1 verification et al. Only used by some utils - # - # Verify that: - # . All of the signatures are valid and that the issuers trace back - # to trusted roots (performed by xmlsec1) - # . The XML matches the credential schema - # . That the issuer of the credential is the authority in the target's urn - # . In the case of a delegated credential, this must be true of the root - # . That all of the gids presented in the credential are valid - # . Including verifying GID chains, and includ the issuer - # . The credential is not expired - # - # -- For Delegates (credentials with parents) - # . The privileges must be a subset of the parent credentials - # . The privileges must have "can_delegate" set for each delegated privilege - # . The target gid must be the same between child and parents - # . The expiry time on the child must be no later than the parent - # . The signer of the child must be the owner of the parent - # - # -- Verify does *NOT* - # . ensure that an xmlrpc client's gid matches a credential gid, that - # must be done elsewhere - # - # @param trusted_certs: The certificates of trusted CA certificates - def verify(self, trusted_certs=None, schema=None, trusted_certs_required=True): - if not self.xml: - self.decode() - - # validate against RelaxNG schema - if HAVELXML and not self.legacy: - if schema and os.path.exists(schema): - tree = etree.parse(StringIO(self.xml)) - schema_doc = etree.parse(schema) - xmlschema = etree.XMLSchema(schema_doc) - if not xmlschema.validate(tree): - error = xmlschema.error_log.last_error - message = "%s: %s (line %s)" % (self.get_summary_tostring(), error.message, error.line) - raise CredentialNotVerifiable(message) - - if trusted_certs_required and trusted_certs is None: - trusted_certs = [] - -# trusted_cert_objects = [GID(filename=f) for f in trusted_certs] - trusted_cert_objects = [] - ok_trusted_certs = [] - # If caller explicitly passed in None that means skip cert chain validation. - # Strange and not typical - if trusted_certs is not None: - for f in trusted_certs: - try: - # Failures here include unreadable files - # or non PEM files - trusted_cert_objects.append(GID(filename=f)) - ok_trusted_certs.append(f) - except Exception, exc: - logger.error("Failed to load trusted cert from %s: %r", f, exc) - trusted_certs = ok_trusted_certs - - # Use legacy verification if this is a legacy credential - if self.legacy: - self.legacy.verify_chain(trusted_cert_objects) - if self.legacy.client_gid: - self.legacy.client_gid.verify_chain(trusted_cert_objects) - if self.legacy.object_gid: - self.legacy.object_gid.verify_chain(trusted_cert_objects) - return True - - # make sure it is not expired - if self.get_expiration() < datetime.datetime.utcnow(): - raise CredentialNotVerifiable("Credential %s expired at %s" % (self.get_summary_tostring(), self.expiration.isoformat())) - - # Verify the signatures - filename = self.save_to_random_tmp_file() - if trusted_certs is not None: - cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs]) - - # If caller explicitly passed in None that means skip cert chain validation. - # - Strange and not typical - if trusted_certs is not None: - # Verify the gids of this cred and of its parents - for cur_cred in self.get_credential_list(): - cur_cred.get_gid_object().verify_chain(trusted_cert_objects) - cur_cred.get_gid_caller().verify_chain(trusted_cert_objects) - - refs = [] - refs.append("Sig_%s" % self.get_refid()) - - parentRefs = self.updateRefID() - for ref in parentRefs: - refs.append("Sig_%s" % ref) - - for ref in refs: - # If caller explicitly passed in None that means skip xmlsec1 validation. - # Strange and not typical - if trusted_certs is None: - break - -# print "Doing %s --verify --node-id '%s' %s %s 2>&1" % \ -# (self.xmlsec_path, ref, cert_args, filename) - verified = os.popen('%s --verify --node-id "%s" %s %s 2>&1' \ - % (self.xmlsec_path, ref, cert_args, filename)).read() - if not verified.strip().startswith("OK"): - # xmlsec errors have a msg= which is the interesting bit. - mstart = verified.find("msg=") - msg = "" - if mstart > -1 and len(verified) > 4: - mstart = mstart + 4 - mend = verified.find('\\', mstart) - msg = verified[mstart:mend] - raise CredentialNotVerifiable("xmlsec1 error verifying cred %s using Signature ID %s: %s %s" % (self.get_summary_tostring(), ref, msg, verified.strip())) - os.remove(filename) - - # Verify the parents (delegation) - if self.parent: - self.verify_parent(self.parent) - - # Make sure the issuer is the target's authority, and is - # itself a valid GID - self.verify_issuer(trusted_cert_objects) - return True - - ## - # Creates a list of the credential and its parents, with the root - # (original delegated credential) as the last item in the list - def get_credential_list(self): - cur_cred = self - list = [] - while cur_cred: - list.append(cur_cred) - if cur_cred.parent: - cur_cred = cur_cred.parent - else: - cur_cred = None - return list - - ## - # Make sure the credential's target gid (a) was signed by or (b) - # is the same as the entity that signed the original credential, - # or (c) is an authority over the target's namespace. - # Also ensure that the credential issuer / signer itself has a valid - # GID signature chain (signed by an authority with namespace rights). - def verify_issuer(self, trusted_gids): - root_cred = self.get_credential_list()[-1] - root_target_gid = root_cred.get_gid_object() - root_cred_signer = root_cred.get_signature().get_issuer_gid() - - # Case 1: - # Allow non authority to sign target and cred about target. - # - # Why do we need to allow non authorities to sign? - # If in the target gid validation step we correctly - # checked that the target is only signed by an authority, - # then this is just a special case of case 3. - # This short-circuit is the common case currently - - # and cause GID validation doesn't check 'authority', - # this allows users to generate valid slice credentials. - if root_target_gid.is_signed_by_cert(root_cred_signer): - # cred signer matches target signer, return success - return - - # Case 2: - # Allow someone to sign credential about themeselves. Used? - # If not, remove this. - #root_target_gid_str = root_target_gid.save_to_string() - #root_cred_signer_str = root_cred_signer.save_to_string() - #if root_target_gid_str == root_cred_signer_str: - # # cred signer is target, return success - # return - - # Case 3: - - # root_cred_signer is not the target_gid - # So this is a different gid that we have not verified. - # xmlsec1 verified the cert chain on this already, but - # it hasn't verified that the gid meets the HRN namespace - # requirements. - # Below we'll ensure that it is an authority. - # But we haven't verified that it is _signed by_ an authority - # We also don't know if xmlsec1 requires that cert signers - # are marked as CAs. - - # Note that if verify() gave us no trusted_gids then this - # call will fail. So skip it if we have no trusted_gids - if trusted_gids and len(trusted_gids) > 0: - root_cred_signer.verify_chain(trusted_gids) - else: - logger.debug("No trusted gids. Cannot verify that cred signer is signed by a trusted authority. Skipping that check.") - - # See if the signer is an authority over the domain of the target. - # There are multiple types of authority - accept them all here - # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn()) - root_cred_signer_type = root_cred_signer.get_type() - if (root_cred_signer_type.find('authority') == 0): - #logger.debug('Cred signer is an authority') - # signer is an authority, see if target is in authority's domain - signerhrn = root_cred_signer.get_hrn() - if hrn_authfor_hrn(signerhrn, root_target_gid.get_hrn()): - return - - # We've required that the credential be signed by an authority - # for that domain. Reasonable and probably correct. - # A looser model would also allow the signer to be an authority - # in my control framework - eg My CA or CH. Even if it is not - # the CH that issued these, eg, user credentials. - - # Give up, credential does not pass issuer verification - - raise CredentialNotVerifiable("Could not verify credential owned by %s for object %s. Cred signer %s not the trusted authority for Cred target %s" % (self.gidCaller.get_urn(), self.gidObject.get_urn(), root_cred_signer.get_hrn(), root_target_gid.get_hrn())) - - - ## - # -- For Delegates (credentials with parents) verify that: - # . The privileges must be a subset of the parent credentials - # . The privileges must have "can_delegate" set for each delegated privilege - # . The target gid must be the same between child and parents - # . The expiry time on the child must be no later than the parent - # . The signer of the child must be the owner of the parent - def verify_parent(self, parent_cred): - # make sure the rights given to the child are a subset of the - # parents rights (and check delegate bits) - if not parent_cred.get_privileges().is_superset(self.get_privileges()): - raise ChildRightsNotSubsetOfParent(("Parent cred ref %s rights " % parent_cred.get_refid()) + - self.parent.get_privileges().save_to_string() + (" not superset of delegated cred %s ref %s rights " % (self.get_summary_tostring(), self.get_refid())) + - self.get_privileges().save_to_string()) - - # make sure my target gid is the same as the parent's - if not parent_cred.get_gid_object().save_to_string() == \ - self.get_gid_object().save_to_string(): - raise CredentialNotVerifiable("Delegated cred %s: Target gid not equal between parent and child. Parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring())) - - # make sure my expiry time is <= my parent's - if not parent_cred.get_expiration() >= self.get_expiration(): - raise CredentialNotVerifiable("Delegated credential %s expires after parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring())) - - # make sure my signer is the parent's caller - if not parent_cred.get_gid_caller().save_to_string(False) == \ - self.get_signature().get_issuer_gid().save_to_string(False): - raise CredentialNotVerifiable("Delegated credential %s not signed by parent %s's caller" % (self.get_summary_tostring(), parent_cred.get_summary_tostring())) - - # Recurse - if parent_cred.parent: - parent_cred.verify_parent(parent_cred.parent) - - - def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile): - """ - Return a delegated copy of this credential, delegated to the - specified gid's user. - """ - # get the gid of the object we are delegating - object_gid = self.get_gid_object() - object_hrn = object_gid.get_hrn() - - # the hrn of the user who will be delegated to - delegee_gid = GID(filename=delegee_gidfile) - delegee_hrn = delegee_gid.get_hrn() - - #user_key = Keypair(filename=keyfile) - #user_hrn = self.get_gid_caller().get_hrn() - subject_string = "%s delegated to %s" % (object_hrn, delegee_hrn) - dcred = Credential(subject=subject_string) - dcred.set_gid_caller(delegee_gid) - dcred.set_gid_object(object_gid) - dcred.set_parent(self) - dcred.set_expiration(self.get_expiration()) - dcred.set_privileges(self.get_privileges()) - dcred.get_privileges().delegate_all_privileges(True) - #dcred.set_issuer_keys(keyfile, delegee_gidfile) - dcred.set_issuer_keys(caller_keyfile, caller_gidfile) - dcred.encode() - dcred.sign() - - return dcred - - # only informative - def get_filename(self): - return getattr(self,'filename',None) - - ## - # Dump the contents of a credential to stdout in human-readable format - # - # @param dump_parents If true, also dump the parent certificates - def dump (self, *args, **kwargs): - print self.dump_string(*args, **kwargs) - - - def dump_string(self, dump_parents=False): - result="" - result += "CREDENTIAL %s\n" % self.get_subject() - filename=self.get_filename() - if filename: result += "Filename %s\n"%filename - result += " privs: %s\n" % self.get_privileges().save_to_string() - gidCaller = self.get_gid_caller() - if gidCaller: - result += " gidCaller:\n" - result += gidCaller.dump_string(8, dump_parents) - - if self.get_signature(): - print " gidIssuer:" - self.get_signature().get_issuer_gid().dump(8, dump_parents) - - gidObject = self.get_gid_object() - if gidObject: - result += " gidObject:\n" - result += gidObject.dump_string(8, dump_parents) - - if self.parent and dump_parents: - result += "\nPARENT" - result += self.parent.dump_string(True) - - return result +#---------------------------------------------------------------------- +# Copyright (c) 2008 Board of Trustees, Princeton University +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and/or hardware specification (the "Work") to +# deal in the Work without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Work, and to permit persons to whom the Work +# is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Work. +# +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# IN THE WORK. +#---------------------------------------------------------------------- +## +# Implements SFA Credentials +# +# Credentials are signed XML files that assign a subject gid privileges to an object gid +## + +import os +from types import StringTypes +import datetime +from StringIO import StringIO +from tempfile import mkstemp +from xml.dom.minidom import Document, parseString + +HAVELXML = False +try: + from lxml import etree + HAVELXML = True +except: + pass + +from xml.parsers.expat import ExpatError + +from sfa.util.faults import CredentialNotVerifiable, ChildRightsNotSubsetOfParent +from sfa.util.sfalogging import logger +from sfa.util.sfatime import utcparse +from sfa.trust.credential_legacy import CredentialLegacy +from sfa.trust.rights import Right, Rights, determine_rights +from sfa.trust.gid import GID +from sfa.util.xrn import urn_to_hrn, hrn_authfor_hrn + +# 2 weeks, in seconds +DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14 + + +# TODO: +# . make privs match between PG and PL +# . Need to add support for other types of credentials, e.g. tickets +# . add namespaces to signed-credential element? + +signature_template = \ +''' + + + + + + + + + + + + + + + + + + + + + + +''' + +# PG formats the template (whitespace) slightly differently. +# Note that they don't include the xmlns in the template, but add it later. +# Otherwise the two are equivalent. +#signature_template_as_in_pg = \ +#''' +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +# +#''' + +## +# Convert a string into a bool +# used to convert an xsd:boolean to a Python boolean +def str2bool(str): + if str.lower() in ['true','1']: + return True + return False + + +## +# Utility function to get the text of an XML element + +def getTextNode(element, subele): + sub = element.getElementsByTagName(subele)[0] + if len(sub.childNodes) > 0: + return sub.childNodes[0].nodeValue + else: + return None + +## +# Utility function to set the text of an XML element +# It creates the element, adds the text to it, +# and then appends it to the parent. + +def append_sub(doc, parent, element, text): + ele = doc.createElement(element) + ele.appendChild(doc.createTextNode(text)) + parent.appendChild(ele) + +## +# Signature contains information about an xmlsec1 signature +# for a signed-credential +# + +class Signature(object): + + def __init__(self, string=None): + self.refid = None + self.issuer_gid = None + self.xml = None + if string: + self.xml = string + self.decode() + + + def get_refid(self): + if not self.refid: + self.decode() + return self.refid + + def get_xml(self): + if not self.xml: + self.encode() + return self.xml + + def set_refid(self, id): + self.refid = id + + def get_issuer_gid(self): + if not self.gid: + self.decode() + return self.gid + + def set_issuer_gid(self, gid): + self.gid = gid + + def decode(self): + try: + doc = parseString(self.xml) + except ExpatError,e: + logger.log_exc ("Failed to parse credential, %s"%self.xml) + raise + sig = doc.getElementsByTagName("Signature")[0] + self.set_refid(sig.getAttribute("xml:id").strip("Sig_")) + keyinfo = sig.getElementsByTagName("X509Data")[0] + szgid = getTextNode(keyinfo, "X509Certificate") + szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid + self.set_issuer_gid(GID(string=szgid)) + + def encode(self): + self.xml = signature_template % (self.get_refid(), self.get_refid()) + + +## +# A credential provides a caller gid with privileges to an object gid. +# A signed credential is signed by the object's authority. +# +# Credentials are encoded in one of two ways. The legacy style places +# it in the subjectAltName of an X509 certificate. The new credentials +# are placed in signed XML. +# +# WARNING: +# In general, a signed credential obtained externally should +# not be changed else the signature is no longer valid. So, once +# you have loaded an existing signed credential, do not call encode() or sign() on it. + +def filter_creds_by_caller(creds, caller_hrn_list): + """ + Returns a list of creds who's gid caller matches the + specified caller hrn + """ + if not isinstance(creds, list): creds = [creds] + if not isinstance(caller_hrn_list, list): + caller_hrn_list = [caller_hrn_list] + caller_creds = [] + for cred in creds: + try: + tmp_cred = Credential(string=cred) + if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list: + caller_creds.append(cred) + except: pass + return caller_creds + +class Credential(object): + + ## + # Create a Credential object + # + # @param create If true, create a blank x509 certificate + # @param subject If subject!=None, create an x509 cert with the subject name + # @param string If string!=None, load the credential from the string + # @param filename If filename!=None, load the credential from the file + # FIXME: create and subject are ignored! + def __init__(self, create=False, subject=None, string=None, filename=None): + self.gidCaller = None + self.gidObject = None + self.expiration = None + self.privileges = None + self.issuer_privkey = None + self.issuer_gid = None + self.issuer_pubkey = None + self.parent = None + self.signature = None + self.xml = None + self.refid = None + self.legacy = None + + # Check if this is a legacy credential, translate it if so + if string or filename: + if string: + str = string + elif filename: + str = file(filename).read() + + if str.strip().startswith("-----"): + self.legacy = CredentialLegacy(False,string=str) + self.translate_legacy(str) + else: + self.xml = str + self.decode() + + # Find an xmlsec1 path + self.xmlsec_path = '' + paths = ['/usr/bin','/usr/local/bin','/bin','/opt/bin','/opt/local/bin'] + for path in paths: + if os.path.isfile(path + '/' + 'xmlsec1'): + self.xmlsec_path = path + '/' + 'xmlsec1' + break + + def get_subject(self): + if not self.gidObject: + self.decode() + return self.gidObject.get_printable_subject() + + def get_summary_tostring(self): + if not self.gidObject: + self.decode() + obj = self.gidObject.get_printable_subject() + caller = self.gidCaller.get_printable_subject() + exp = self.get_expiration() + # Summarize the rights too? The issuer? + return "[ Grant %s rights on %s until %s ]" % (caller, obj, exp) + + def get_signature(self): + if not self.signature: + self.decode() + return self.signature + + def set_signature(self, sig): + self.signature = sig + + + ## + # Translate a legacy credential into a new one + # + # @param String of the legacy credential + + def translate_legacy(self, str): + legacy = CredentialLegacy(False,string=str) + self.gidCaller = legacy.get_gid_caller() + self.gidObject = legacy.get_gid_object() + lifetime = legacy.get_lifetime() + if not lifetime: + self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) + else: + self.set_expiration(int(lifetime)) + self.lifeTime = legacy.get_lifetime() + self.set_privileges(legacy.get_privileges()) + self.get_privileges().delegate_all_privileges(legacy.get_delegate()) + + ## + # Need the issuer's private key and name + # @param key Keypair object containing the private key of the issuer + # @param gid GID of the issuing authority + + def set_issuer_keys(self, privkey, gid): + self.issuer_privkey = privkey + self.issuer_gid = gid + + + ## + # Set this credential's parent + def set_parent(self, cred): + self.parent = cred + self.updateRefID() + + ## + # set the GID of the caller + # + # @param gid GID object of the caller + + def set_gid_caller(self, gid): + self.gidCaller = gid + # gid origin caller is the caller's gid by default + self.gidOriginCaller = gid + + ## + # get the GID of the object + + def get_gid_caller(self): + if not self.gidCaller: + self.decode() + return self.gidCaller + + ## + # set the GID of the object + # + # @param gid GID object of the object + + def set_gid_object(self, gid): + self.gidObject = gid + + ## + # get the GID of the object + + def get_gid_object(self): + if not self.gidObject: + self.decode() + return self.gidObject + + + + ## + # Expiration: an absolute UTC time of expiration (as either an int or string or datetime) + # + def set_expiration(self, expiration): + if isinstance(expiration, (int, float)): + self.expiration = datetime.datetime.fromtimestamp(expiration) + elif isinstance (expiration, datetime.datetime): + self.expiration = expiration + elif isinstance (expiration, StringTypes): + self.expiration = utcparse (expiration) + else: + logger.error ("unexpected input type in Credential.set_expiration") + + + ## + # get the lifetime of the credential (always in datetime format) + + def get_expiration(self): + if not self.expiration: + self.decode() + # at this point self.expiration is normalized as a datetime - DON'T call utcparse again + return self.expiration + + ## + # For legacy sake + def get_lifetime(self): + return self.get_expiration() + + ## + # set the privileges + # + # @param privs either a comma-separated list of privileges of a Rights object + + def set_privileges(self, privs): + if isinstance(privs, str): + self.privileges = Rights(string = privs) + else: + self.privileges = privs + + + ## + # return the privileges as a Rights object + + def get_privileges(self): + if not self.privileges: + self.decode() + return self.privileges + + ## + # determine whether the credential allows a particular operation to be + # performed + # + # @param op_name string specifying name of operation ("lookup", "update", etc) + + def can_perform(self, op_name): + rights = self.get_privileges() + + if not rights: + return False + + return rights.can_perform(op_name) + + + ## + # Encode the attributes of the credential into an XML string + # This should be done immediately before signing the credential. + # WARNING: + # In general, a signed credential obtained externally should + # not be changed else the signature is no longer valid. So, once + # you have loaded an existing signed credential, do not call encode() or sign() on it. + + def encode(self): + # Create the XML document + doc = Document() + signed_cred = doc.createElement("signed-credential") + +# Declare namespaces +# Note that credential/policy.xsd are really the PG schemas +# in a PL namespace. +# Note that delegation of credentials between the 2 only really works +# cause those schemas are identical. +# Also note these PG schemas talk about PG tickets and CM policies. + signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") + signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.planet-lab.org/resources/sfa/credential.xsd") + signed_cred.setAttribute("xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd") + +# PG says for those last 2: +# signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd") +# signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd") + + doc.appendChild(signed_cred) + + # Fill in the bit + cred = doc.createElement("credential") + cred.setAttribute("xml:id", self.get_refid()) + signed_cred.appendChild(cred) + append_sub(doc, cred, "type", "privilege") + append_sub(doc, cred, "serial", "8") + append_sub(doc, cred, "owner_gid", self.gidCaller.save_to_string()) + append_sub(doc, cred, "owner_urn", self.gidCaller.get_urn()) + append_sub(doc, cred, "target_gid", self.gidObject.save_to_string()) + append_sub(doc, cred, "target_urn", self.gidObject.get_urn()) + append_sub(doc, cred, "uuid", "") + if not self.expiration: + self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) + self.expiration = self.expiration.replace(microsecond=0) + append_sub(doc, cred, "expires", self.expiration.isoformat()) + privileges = doc.createElement("privileges") + cred.appendChild(privileges) + + if self.privileges: + rights = self.get_privileges() + for right in rights.rights: + priv = doc.createElement("privilege") + append_sub(doc, priv, "name", right.kind) + append_sub(doc, priv, "can_delegate", str(right.delegate).lower()) + privileges.appendChild(priv) + + # Add the parent credential if it exists + if self.parent: + sdoc = parseString(self.parent.get_xml()) + # If the root node is a signed-credential (it should be), then + # get all its attributes and attach those to our signed_cred + # node. + # Specifically, PG and PLadd attributes for namespaces (which is reasonable), + # and we need to include those again here or else their signature + # no longer matches on the credential. + # We expect three of these, but here we copy them all: +# signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") +# and from PG (PL is equivalent, as shown above): +# signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd") +# signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd") + + # HOWEVER! + # PL now also declares these, with different URLs, so + # the code notices those attributes already existed with + # different values, and complains. + # This happens regularly on delegation now that PG and + # PL both declare the namespace with different URLs. + # If the content ever differs this is a problem, + # but for now it works - different URLs (values in the attributes) + # but the same actual schema, so using the PG schema + # on delegated-to-PL credentials works fine. + + # Note: you could also not copy attributes + # which already exist. It appears that both PG and PL + # will actually validate a slicecred with a parent + # signed using PG namespaces and a child signed with PL + # namespaces over the whole thing. But I don't know + # if that is a bug in xmlsec1, an accident since + # the contents of the schemas are the same, + # or something else, but it seems odd. And this works. + parentRoot = sdoc.documentElement + if parentRoot.tagName == "signed-credential" and parentRoot.hasAttributes(): + for attrIx in range(0, parentRoot.attributes.length): + attr = parentRoot.attributes.item(attrIx) + # returns the old attribute of same name that was + # on the credential + # Below throws InUse exception if we forgot to clone the attribute first + oldAttr = signed_cred.setAttributeNode(attr.cloneNode(True)) + if oldAttr and oldAttr.value != attr.value: + msg = "Delegating cred from owner %s to %s over %s replaced attribute %s value '%s' with '%s'" % (self.parent.gidCaller.get_urn(), self.gidCaller.get_urn(), self.gidObject.get_urn(), oldAttr.name, oldAttr.value, attr.value) + logger.warn(msg) + #raise CredentialNotVerifiable("Can't encode new valid delegated credential: %s" % msg) + + p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True) + p = doc.createElement("parent") + p.appendChild(p_cred) + cred.appendChild(p) + # done handling parent credential + + # Create the tag + signatures = doc.createElement("signatures") + signed_cred.appendChild(signatures) + + # Add any parent signatures + if self.parent: + for cur_cred in self.get_credential_list()[1:]: + sdoc = parseString(cur_cred.get_signature().get_xml()) + ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True) + signatures.appendChild(ele) + + # Get the finished product + self.xml = doc.toxml() + + + def save_to_random_tmp_file(self): + fp, filename = mkstemp(suffix='cred', text=True) + fp = os.fdopen(fp, "w") + self.save_to_file(filename, save_parents=True, filep=fp) + return filename + + def save_to_file(self, filename, save_parents=True, filep=None): + if not self.xml: + self.encode() + if filep: + f = filep + else: + f = open(filename, "w") + f.write(self.xml) + f.close() + + def save_to_string(self, save_parents=True): + if not self.xml: + self.encode() + return self.xml + + def get_refid(self): + if not self.refid: + self.refid = 'ref0' + return self.refid + + def set_refid(self, rid): + self.refid = rid + + ## + # Figure out what refids exist, and update this credential's id + # so that it doesn't clobber the others. Returns the refids of + # the parents. + + def updateRefID(self): + if not self.parent: + self.set_refid('ref0') + return [] + + refs = [] + + next_cred = self.parent + while next_cred: + refs.append(next_cred.get_refid()) + if next_cred.parent: + next_cred = next_cred.parent + else: + next_cred = None + + + # Find a unique refid for this credential + rid = self.get_refid() + while rid in refs: + val = int(rid[3:]) + rid = "ref%d" % (val + 1) + + # Set the new refid + self.set_refid(rid) + + # Return the set of parent credential ref ids + return refs + + def get_xml(self): + if not self.xml: + self.encode() + return self.xml + + ## + # Sign the XML file created by encode() + # + # WARNING: + # In general, a signed credential obtained externally should + # not be changed else the signature is no longer valid. So, once + # you have loaded an existing signed credential, do not call encode() or sign() on it. + + def sign(self): + if not self.issuer_privkey or not self.issuer_gid: + return + doc = parseString(self.get_xml()) + sigs = doc.getElementsByTagName("signatures")[0] + + # Create the signature template to be signed + signature = Signature() + signature.set_refid(self.get_refid()) + sdoc = parseString(signature.get_xml()) + sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True) + sigs.appendChild(sig_ele) + + self.xml = doc.toxml() + + + # Split the issuer GID into multiple certificates if it's a chain + chain = GID(filename=self.issuer_gid) + gid_files = [] + while chain: + gid_files.append(chain.save_to_random_tmp_file(False)) + if chain.get_parent(): + chain = chain.get_parent() + else: + chain = None + + + # Call out to xmlsec1 to sign it + ref = 'Sig_%s' % self.get_refid() + filename = self.save_to_random_tmp_file() + signed = os.popen('%s --sign --node-id "%s" --privkey-pem %s,%s %s' \ + % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)).read() + os.remove(filename) + + for gid_file in gid_files: + os.remove(gid_file) + + self.xml = signed + + # This is no longer a legacy credential + if self.legacy: + self.legacy = None + + # Update signatures + self.decode() + + + ## + # Retrieve the attributes of the credential from the XML. + # This is automatically called by the various get_* methods of + # this class and should not need to be called explicitly. + + def decode(self): + if not self.xml: + return + doc = parseString(self.xml) + sigs = [] + signed_cred = doc.getElementsByTagName("signed-credential") + + # Is this a signed-cred or just a cred? + if len(signed_cred) > 0: + creds = signed_cred[0].getElementsByTagName("credential") + signatures = signed_cred[0].getElementsByTagName("signatures") + if len(signatures) > 0: + sigs = signatures[0].getElementsByTagName("Signature") + else: + creds = doc.getElementsByTagName("credential") + + if creds is None or len(creds) == 0: + # malformed cred file + raise CredentialNotVerifiable("Malformed XML: No credential tag found") + + # Just take the first cred if there are more than one + cred = creds[0] + + self.set_refid(cred.getAttribute("xml:id")) + self.set_expiration(utcparse(getTextNode(cred, "expires"))) + self.gidCaller = GID(string=getTextNode(cred, "owner_gid")) + self.gidObject = GID(string=getTextNode(cred, "target_gid")) + + + # Process privileges + privs = cred.getElementsByTagName("privileges")[0] + rlist = Rights() + for priv in privs.getElementsByTagName("privilege"): + kind = getTextNode(priv, "name") + deleg = str2bool(getTextNode(priv, "can_delegate")) + if kind == '*': + # Convert * into the default privileges for the credential's type + # Each inherits the delegatability from the * above + _ , type = urn_to_hrn(self.gidObject.get_urn()) + rl = determine_rights(type, self.gidObject.get_urn()) + for r in rl.rights: + r.delegate = deleg + rlist.add(r) + else: + rlist.add(Right(kind.strip(), deleg)) + self.set_privileges(rlist) + + + # Is there a parent? + parent = cred.getElementsByTagName("parent") + if len(parent) > 0: + parent_doc = parent[0].getElementsByTagName("credential")[0] + parent_xml = parent_doc.toxml() + self.parent = Credential(string=parent_xml) + self.updateRefID() + + # Assign the signatures to the credentials + for sig in sigs: + Sig = Signature(string=sig.toxml()) + + for cur_cred in self.get_credential_list(): + if cur_cred.get_refid() == Sig.get_refid(): + cur_cred.set_signature(Sig) + + + ## + # Verify + # trusted_certs: A list of trusted GID filenames (not GID objects!) + # Chaining is not supported within the GIDs by xmlsec1. + # + # trusted_certs_required: Should usually be true. Set False means an + # empty list of trusted_certs would still let this method pass. + # It just skips xmlsec1 verification et al. Only used by some utils + # + # Verify that: + # . All of the signatures are valid and that the issuers trace back + # to trusted roots (performed by xmlsec1) + # . The XML matches the credential schema + # . That the issuer of the credential is the authority in the target's urn + # . In the case of a delegated credential, this must be true of the root + # . That all of the gids presented in the credential are valid + # . Including verifying GID chains, and includ the issuer + # . The credential is not expired + # + # -- For Delegates (credentials with parents) + # . The privileges must be a subset of the parent credentials + # . The privileges must have "can_delegate" set for each delegated privilege + # . The target gid must be the same between child and parents + # . The expiry time on the child must be no later than the parent + # . The signer of the child must be the owner of the parent + # + # -- Verify does *NOT* + # . ensure that an xmlrpc client's gid matches a credential gid, that + # must be done elsewhere + # + # @param trusted_certs: The certificates of trusted CA certificates + def verify(self, trusted_certs=None, schema=None, trusted_certs_required=True): + if not self.xml: + self.decode() + + # validate against RelaxNG schema + if HAVELXML and not self.legacy: + if schema and os.path.exists(schema): + tree = etree.parse(StringIO(self.xml)) + schema_doc = etree.parse(schema) + xmlschema = etree.XMLSchema(schema_doc) + if not xmlschema.validate(tree): + error = xmlschema.error_log.last_error + message = "%s: %s (line %s)" % (self.get_summary_tostring(), error.message, error.line) + raise CredentialNotVerifiable(message) + + if trusted_certs_required and trusted_certs is None: + trusted_certs = [] + +# trusted_cert_objects = [GID(filename=f) for f in trusted_certs] + trusted_cert_objects = [] + ok_trusted_certs = [] + # If caller explicitly passed in None that means skip cert chain validation. + # Strange and not typical + if trusted_certs is not None: + for f in trusted_certs: + try: + # Failures here include unreadable files + # or non PEM files + trusted_cert_objects.append(GID(filename=f)) + ok_trusted_certs.append(f) + except Exception, exc: + logger.error("Failed to load trusted cert from %s: %r", f, exc) + trusted_certs = ok_trusted_certs + + # Use legacy verification if this is a legacy credential + if self.legacy: + self.legacy.verify_chain(trusted_cert_objects) + if self.legacy.client_gid: + self.legacy.client_gid.verify_chain(trusted_cert_objects) + if self.legacy.object_gid: + self.legacy.object_gid.verify_chain(trusted_cert_objects) + return True + + # make sure it is not expired + if self.get_expiration() < datetime.datetime.utcnow(): + raise CredentialNotVerifiable("Credential %s expired at %s" % (self.get_summary_tostring(), self.expiration.isoformat())) + + # Verify the signatures + filename = self.save_to_random_tmp_file() + if trusted_certs is not None: + cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs]) + + # If caller explicitly passed in None that means skip cert chain validation. + # - Strange and not typical + if trusted_certs is not None: + # Verify the gids of this cred and of its parents + for cur_cred in self.get_credential_list(): + cur_cred.get_gid_object().verify_chain(trusted_cert_objects) + cur_cred.get_gid_caller().verify_chain(trusted_cert_objects) + + refs = [] + refs.append("Sig_%s" % self.get_refid()) + + parentRefs = self.updateRefID() + for ref in parentRefs: + refs.append("Sig_%s" % ref) + + for ref in refs: + # If caller explicitly passed in None that means skip xmlsec1 validation. + # Strange and not typical + if trusted_certs is None: + break + +# print "Doing %s --verify --node-id '%s' %s %s 2>&1" % \ +# (self.xmlsec_path, ref, cert_args, filename) + verified = os.popen('%s --verify --node-id "%s" %s %s 2>&1' \ + % (self.xmlsec_path, ref, cert_args, filename)).read() + if not verified.strip().startswith("OK"): + # xmlsec errors have a msg= which is the interesting bit. + mstart = verified.find("msg=") + msg = "" + if mstart > -1 and len(verified) > 4: + mstart = mstart + 4 + mend = verified.find('\\', mstart) + msg = verified[mstart:mend] + raise CredentialNotVerifiable("xmlsec1 error verifying cred %s using Signature ID %s: %s %s" % (self.get_summary_tostring(), ref, msg, verified.strip())) + os.remove(filename) + + # Verify the parents (delegation) + if self.parent: + self.verify_parent(self.parent) + + # Make sure the issuer is the target's authority, and is + # itself a valid GID + self.verify_issuer(trusted_cert_objects) + return True + + ## + # Creates a list of the credential and its parents, with the root + # (original delegated credential) as the last item in the list + def get_credential_list(self): + cur_cred = self + list = [] + while cur_cred: + list.append(cur_cred) + if cur_cred.parent: + cur_cred = cur_cred.parent + else: + cur_cred = None + return list + + ## + # Make sure the credential's target gid (a) was signed by or (b) + # is the same as the entity that signed the original credential, + # or (c) is an authority over the target's namespace. + # Also ensure that the credential issuer / signer itself has a valid + # GID signature chain (signed by an authority with namespace rights). + def verify_issuer(self, trusted_gids): + root_cred = self.get_credential_list()[-1] + root_target_gid = root_cred.get_gid_object() + root_cred_signer = root_cred.get_signature().get_issuer_gid() + + # Case 1: + # Allow non authority to sign target and cred about target. + # + # Why do we need to allow non authorities to sign? + # If in the target gid validation step we correctly + # checked that the target is only signed by an authority, + # then this is just a special case of case 3. + # This short-circuit is the common case currently - + # and cause GID validation doesn't check 'authority', + # this allows users to generate valid slice credentials. + if root_target_gid.is_signed_by_cert(root_cred_signer): + # cred signer matches target signer, return success + return + + # Case 2: + # Allow someone to sign credential about themeselves. Used? + # If not, remove this. + #root_target_gid_str = root_target_gid.save_to_string() + #root_cred_signer_str = root_cred_signer.save_to_string() + #if root_target_gid_str == root_cred_signer_str: + # # cred signer is target, return success + # return + + # Case 3: + + # root_cred_signer is not the target_gid + # So this is a different gid that we have not verified. + # xmlsec1 verified the cert chain on this already, but + # it hasn't verified that the gid meets the HRN namespace + # requirements. + # Below we'll ensure that it is an authority. + # But we haven't verified that it is _signed by_ an authority + # We also don't know if xmlsec1 requires that cert signers + # are marked as CAs. + + # Note that if verify() gave us no trusted_gids then this + # call will fail. So skip it if we have no trusted_gids + if trusted_gids and len(trusted_gids) > 0: + root_cred_signer.verify_chain(trusted_gids) + else: + logger.debug("No trusted gids. Cannot verify that cred signer is signed by a trusted authority. Skipping that check.") + + # See if the signer is an authority over the domain of the target. + # There are multiple types of authority - accept them all here + # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn()) + root_cred_signer_type = root_cred_signer.get_type() + if (root_cred_signer_type.find('authority') == 0): + #logger.debug('Cred signer is an authority') + # signer is an authority, see if target is in authority's domain + signerhrn = root_cred_signer.get_hrn() + if hrn_authfor_hrn(signerhrn, root_target_gid.get_hrn()): + return + + # We've required that the credential be signed by an authority + # for that domain. Reasonable and probably correct. + # A looser model would also allow the signer to be an authority + # in my control framework - eg My CA or CH. Even if it is not + # the CH that issued these, eg, user credentials. + + # Give up, credential does not pass issuer verification + + raise CredentialNotVerifiable("Could not verify credential owned by %s for object %s. Cred signer %s not the trusted authority for Cred target %s" % (self.gidCaller.get_urn(), self.gidObject.get_urn(), root_cred_signer.get_hrn(), root_target_gid.get_hrn())) + + + ## + # -- For Delegates (credentials with parents) verify that: + # . The privileges must be a subset of the parent credentials + # . The privileges must have "can_delegate" set for each delegated privilege + # . The target gid must be the same between child and parents + # . The expiry time on the child must be no later than the parent + # . The signer of the child must be the owner of the parent + def verify_parent(self, parent_cred): + # make sure the rights given to the child are a subset of the + # parents rights (and check delegate bits) + if not parent_cred.get_privileges().is_superset(self.get_privileges()): + raise ChildRightsNotSubsetOfParent(("Parent cred ref %s rights " % parent_cred.get_refid()) + + self.parent.get_privileges().save_to_string() + (" not superset of delegated cred %s ref %s rights " % (self.get_summary_tostring(), self.get_refid())) + + self.get_privileges().save_to_string()) + + # make sure my target gid is the same as the parent's + if not parent_cred.get_gid_object().save_to_string() == \ + self.get_gid_object().save_to_string(): + raise CredentialNotVerifiable("Delegated cred %s: Target gid not equal between parent and child. Parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring())) + + # make sure my expiry time is <= my parent's + if not parent_cred.get_expiration() >= self.get_expiration(): + raise CredentialNotVerifiable("Delegated credential %s expires after parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring())) + + # make sure my signer is the parent's caller + if not parent_cred.get_gid_caller().save_to_string(False) == \ + self.get_signature().get_issuer_gid().save_to_string(False): + raise CredentialNotVerifiable("Delegated credential %s not signed by parent %s's caller" % (self.get_summary_tostring(), parent_cred.get_summary_tostring())) + + # Recurse + if parent_cred.parent: + parent_cred.verify_parent(parent_cred.parent) + + + def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile): + """ + Return a delegated copy of this credential, delegated to the + specified gid's user. + """ + # get the gid of the object we are delegating + object_gid = self.get_gid_object() + object_hrn = object_gid.get_hrn() + + # the hrn of the user who will be delegated to + delegee_gid = GID(filename=delegee_gidfile) + delegee_hrn = delegee_gid.get_hrn() + + #user_key = Keypair(filename=keyfile) + #user_hrn = self.get_gid_caller().get_hrn() + subject_string = "%s delegated to %s" % (object_hrn, delegee_hrn) + dcred = Credential(subject=subject_string) + dcred.set_gid_caller(delegee_gid) + dcred.set_gid_object(object_gid) + dcred.set_parent(self) + dcred.set_expiration(self.get_expiration()) + dcred.set_privileges(self.get_privileges()) + dcred.get_privileges().delegate_all_privileges(True) + #dcred.set_issuer_keys(keyfile, delegee_gidfile) + dcred.set_issuer_keys(caller_keyfile, caller_gidfile) + dcred.encode() + dcred.sign() + + return dcred + + # only informative + def get_filename(self): + return getattr(self,'filename',None) + + ## + # Dump the contents of a credential to stdout in human-readable format + # + # @param dump_parents If true, also dump the parent certificates + def dump (self, *args, **kwargs): + print self.dump_string(*args, **kwargs) + + + def dump_string(self, dump_parents=False): + result="" + result += "CREDENTIAL %s\n" % self.get_subject() + filename=self.get_filename() + if filename: result += "Filename %s\n"%filename + result += " privs: %s\n" % self.get_privileges().save_to_string() + gidCaller = self.get_gid_caller() + if gidCaller: + result += " gidCaller:\n" + result += gidCaller.dump_string(8, dump_parents) + + if self.get_signature(): + print " gidIssuer:" + self.get_signature().get_issuer_gid().dump(8, dump_parents) + + gidObject = self.get_gid_object() + if gidObject: + result += " gidObject:\n" + result += gidObject.dump_string(8, dump_parents) + + if self.parent and dump_parents: + result += "\nPARENT" + result += self.parent.dump_string(True) + + return result diff --git a/sfa/trust/credential_legacy.py b/sfa/trust/credential_legacy.py index dda7096f..e66e6993 100644 --- a/sfa/trust/credential_legacy.py +++ b/sfa/trust/credential_legacy.py @@ -7,9 +7,8 @@ import xmlrpclib -from sfa.util.faults import * +from sfa.util.faults import MissingDelegateBit, ChildRightsNotSubsetOfParent from sfa.trust.certificate import Certificate -from sfa.trust.rights import Right,Rights from sfa.trust.gid import GID ## diff --git a/sfa/trust/gid.py b/sfa/trust/gid.py index 15ad6bff..656de4be 100644 --- a/sfa/trust/gid.py +++ b/sfa/trust/gid.py @@ -30,7 +30,7 @@ import uuid from sfa.trust.certificate import Certificate -from sfa.util.faults import * +from sfa.util.faults import GidInvalidParentHrn, GidParentHrn from sfa.util.sfalogging import logger from sfa.util.xrn import hrn_to_urn, urn_to_hrn, hrn_authfor_hrn diff --git a/sfa/trust/hierarchy.py b/sfa/trust/hierarchy.py index 63234363..9648c9d4 100644 --- a/sfa/trust/hierarchy.py +++ b/sfa/trust/hierarchy.py @@ -14,14 +14,14 @@ import os -from sfa.util.faults import * +from sfa.util.faults import MissingAuthority from sfa.util.sfalogging import logger from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn, urn_to_hrn from sfa.trust.certificate import Keypair from sfa.trust.credential import Credential from sfa.trust.gid import GID, create_uuid from sfa.util.config import Config -from sfa.util.sfaticket import SfaTicket +from sfa.trust.sfaticket import SfaTicket ## # The AuthInfo class contains the information for an authority. This information @@ -204,7 +204,7 @@ class Hierarchy: def get_auth_info(self, xrn): hrn, type = urn_to_hrn(xrn) if not self.auth_exists(hrn): - logger.warning("Hierarchy: mising authority - xrn=%s, hrn=%s"%(xrn,hrn)) + logger.warning("Hierarchy: missing authority - xrn=%s, hrn=%s"%(xrn,hrn)) raise MissingAuthority(hrn) (directory, gid_filename, privkey_filename, dbinfo_filename) = \ diff --git a/sfa/util/sfaticket.py b/sfa/trust/sfaticket.py similarity index 98% rename from sfa/util/sfaticket.py rename to sfa/trust/sfaticket.py index 0be5d933..018d929e 100644 --- a/sfa/util/sfaticket.py +++ b/sfa/trust/sfaticket.py @@ -5,8 +5,7 @@ import xmlrpclib from sfa.trust.certificate import Certificate -from sfa.trust.rights import * -from sfa.trust.gid import * +from sfa.trust.gid import GID # Ticket is tuple: # (gidCaller, gidObject, attributes, rspec, delegate) diff --git a/sfa/util/PostgreSQL.py b/sfa/util/PostgreSQL.py index 19cd4d0e..f16f48d7 100644 --- a/sfa/util/PostgreSQL.py +++ b/sfa/util/PostgreSQL.py @@ -21,7 +21,7 @@ import sys try: import pgdb except: print >> sys.stderr, "WARNING, could not import pgdb" -from sfa.util.faults import * +from sfa.util.faults import SfaDBError from sfa.util.sfalogging import logger if not psycopg2: diff --git a/sfa/util/cache.py b/sfa/util/cache.py index 0383ccce..a2ded4a4 100644 --- a/sfa/util/cache.py +++ b/sfa/util/cache.py @@ -82,9 +82,18 @@ class Cache: def get(self, key): data = self.cache.get(key) - if not data or data.is_expired(): - return None - return data.get_data() + if not data: + data = None + elif data.is_expired(): + self.pop(key) + data = None + else: + data = data.get_data() + return data + + def pop(self, key): + if key in self.cache: + self.cache.pop(key) def dump(self): result = {} diff --git a/sfa/util/componentserver.py b/sfa/util/componentserver.py deleted file mode 100644 index 98373ec3..00000000 --- a/sfa/util/componentserver.py +++ /dev/null @@ -1,136 +0,0 @@ -## -# This module implements a general-purpose server layer for sfa. -# The same basic server should be usable on the registry, component, or -# other interfaces. -# -# TODO: investigate ways to combine this with existing PLC server? -## - -import sys -import traceback -import threading -import socket, os -import SocketServer -import BaseHTTPServer -import SimpleHTTPServer -import SimpleXMLRPCServer -from OpenSSL import SSL - -from sfa.util.sfalogging import logger -from sfa.trust.certificate import Keypair, Certificate -from sfa.trust.credential import * -from sfa.util.faults import * -from sfa.plc.api import ComponentAPI -from sfa.util.server import verify_callback, ThreadedServer - - -## -# taken from the web (XXX find reference). Implents HTTPS xmlrpc request handler - -class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): - """Secure XML-RPC request handler class. - - It it very similar to SimpleXMLRPCRequestHandler but it uses HTTPS for transporting XML data. - """ - def setup(self): - self.connection = self.request - self.rfile = socket._fileobject(self.request, "rb", self.rbufsize) - self.wfile = socket._fileobject(self.request, "wb", self.wbufsize) - - def do_POST(self): - """Handles the HTTPS POST request. - - It was copied out from SimpleXMLRPCServer.py and modified to shutdown the socket cleanly. - """ - try: - peer_cert = Certificate() - peer_cert.load_from_pyopenssl_x509(self.connection.get_peer_certificate()) - self.api = ComponentAPI(peer_cert = peer_cert, - interface = self.server.interface, - key_file = self.server.key_file, - cert_file = self.server.cert_file) - # get arguments - request = self.rfile.read(int(self.headers["content-length"])) - # In previous versions of SimpleXMLRPCServer, _dispatch - # could be overridden in this class, instead of in - # SimpleXMLRPCDispatcher. To maintain backwards compatibility, - # check to see if a subclass implements _dispatch and dispatch - # using that method if present. - #response = self.server._marshaled_dispatch(request, getattr(self, '_dispatch', None)) - # XX TODO: Need to get the real remote address - remote_addr = (remote_ip, remote_port) = self.connection.getpeername() - self.api.remote_addr = remote_addr - #remote_addr = (self.rfile.connection.remote_ip, remote_port) - #self.api.remote_addr = remote_addr - response = self.api.handle(remote_addr, request) - - - except Exception, fault: - raise - # This should only happen if the module is buggy - # internal error, report as HTTP server error - self.send_response(500) - self.end_headers() - logger.log_exc("componentserver.SecureXMLRpcRequestHandler.do_POST") - else: - # got a valid XML RPC response - self.send_response(200) - self.send_header("Content-type", "text/xml") - self.send_header("Content-length", str(len(response))) - self.end_headers() - self.wfile.write(response) - - # shut down the connection - self.wfile.flush() - self.connection.shutdown() # Modified here! - -## -# Implements an HTTPS XML-RPC server. Generally it is expected that SFA -# functions will take a credential string, which is passed to -# decode_authentication. Decode_authentication() will verify the validity of -# the credential, and verify that the user is using the key that matches the -# GID supplied in the credential. - -class ComponentServer(threading.Thread): - - ## - # Create a new SfaServer object. - # - # @param ip the ip address to listen on - # @param port the port to listen on - # @param key_file private key filename of registry - # @param cert_file certificate filename containing public key - # (could be a GID file) - - def __init__(self, ip, port, key_file, cert_file, api=None): - threading.Thread.__init__(self) - self.key = Keypair(filename = key_file) - self.cert = Certificate(filename = cert_file) - self.server = ThreadedServer((ip, port), SecureXMLRpcRequestHandler, key_file, cert_file) - self.trusted_cert_list = None - self.register_functions() - - - ## - # Register functions that will be served by the XMLRPC server. This - # function should be overrided by each descendant class. - - def register_functions(self): - self.server.register_function(self.noop) - - ## - # Sample no-op server function. The no-op function decodes the credential - # that was passed to it. - - def noop(self, cred, anything): - self.decode_authentication(cred) - - return anything - - ## - # Execute the server, serving requests forever. - - def run(self): - self.server.serve_forever() - - diff --git a/sfa/util/config.py b/sfa/util/config.py index cf2cca9b..14669160 100644 --- a/sfa/util/config.py +++ b/sfa/util/config.py @@ -76,7 +76,7 @@ class Config: except: pass except IOError, e: - raise IOError, "Could not find the configuration file: %s" % config_file + raise IOError, "Could not find or load the configuration file: %s" % config_file def get_trustedroots_dir(self): return self.config_path + os.sep + 'trusted_roots' diff --git a/sfa/util/defaultdict.py b/sfa/util/defaultdict.py new file mode 100644 index 00000000..e0dd1450 --- /dev/null +++ b/sfa/util/defaultdict.py @@ -0,0 +1,38 @@ +# from http://code.activestate.com/recipes/523034/ +try: + from collections import defaultdict +except: + class defaultdict(dict): + def __init__(self, default_factory=None, *a, **kw): + if (default_factory is not None and + not hasattr(default_factory, '__call__')): + raise TypeError('first argument must be callable') + dict.__init__(self, *a, **kw) + self.default_factory = default_factory + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + return self.__missing__(key) + def __missing__(self, key): + if self.default_factory is None: + raise KeyError(key) + self[key] = value = self.default_factory() + return value + def __reduce__(self): + if self.default_factory is None: + args = tuple() + else: + args = self.default_factory, + return type(self), args, None, None, self.items() + def copy(self): + return self.__copy__() + def __copy__(self): + return type(self)(self.default_factory, self) + def __deepcopy__(self, memo): + import copy + return type(self)(self.default_factory, + copy.deepcopy(self.items())) + def __repr__(self): + return 'defaultdict(%s, %s)' % (self.default_factory, + dict.__repr__(self)) diff --git a/sfa/util/filter.py b/sfa/util/filter.py index ada44ba5..8f037cab 100644 --- a/sfa/util/filter.py +++ b/sfa/util/filter.py @@ -5,11 +5,10 @@ except NameError: from sets import Set set = Set -import time try: import pgdb except: pass -from sfa.util.faults import * +from sfa.util.faults import SfaInvalidArgument from sfa.util.parameter import Parameter, Mixed, python_type @@ -128,7 +127,7 @@ class Filter(Parameter, dict): for char in modifiers.keys(): if field[0] == char: - modifiers[char]=True; + modifiers[char]=True field = field[1:] break diff --git a/sfa/util/httpsProtocol.py b/sfa/util/httpsProtocol.py deleted file mode 100644 index e6c6be1b..00000000 --- a/sfa/util/httpsProtocol.py +++ /dev/null @@ -1,51 +0,0 @@ -import httplib -import socket -import sys - - -def is_python26(): - return False - #return sys.version_info[0] == 2 and sys.version_info[1] == 6 - -# wrapper around standartd https modules. Properly supports timeouts. - -class HTTPSConnection(httplib.HTTPSConnection): - def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout = None): - httplib.HTTPSConnection.__init__(self, host, port, key_file, cert_file, strict) - if timeout: - timeout = float(timeout) - self.timeout = timeout - - def connect(self): - """Connect to a host on a given (SSL) port.""" - if is_python26(): - from sfa.util.ssl_socket import SSLSocket - sock = socket.create_connection((self.host, self.port), self.timeout) - if self._tunnel_host: - self.sock = sock - self._tunnel() - self.sock = SSLSocket(sock, self.key_file, self.cert_file) - else: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(self.timeout) - sock.connect((self.host, self.port)) - ssl = socket.ssl(sock, self.key_file, self.cert_file) - self.sock = httplib.FakeSocket(sock, ssl) - -class HTTPS(httplib.HTTPS): - def __init__(self, host='', port=None, key_file=None, cert_file=None, - strict=None, timeout = None): - # urf. compensate for bad input. - if port == 0: - port = None - self._setup(HTTPSConnection(host, port, key_file, cert_file, strict, timeout)) - - # we never actually use these for anything, but we keep them - # here for compatibility with post-1.5.2 CVS. - self.key_file = key_file - self.cert_file = cert_file - - def set_timeout(self, timeout): - if is_python26(): - self._conn.timeout = timeout diff --git a/sfa/util/method.py b/sfa/util/method.py index 4c37c676..97ddb399 100644 --- a/sfa/util/method.py +++ b/sfa/util/method.py @@ -3,18 +3,14 @@ # # -import os, time -from types import * -from types import StringTypes -import traceback +import time +from types import IntType, LongType, StringTypes import textwrap -import xmlrpclib from sfa.util.sfalogging import logger -from sfa.util.faults import * +from sfa.util.faults import SfaFault, SfaInvalidAPIMethod, SfaInvalidArgumentCount, SfaInvalidArgument from sfa.util.parameter import Parameter, Mixed, python_type, xmlrpc_type -from sfa.trust.auth import Auth class Method: """ diff --git a/sfa/util/parameter.py b/sfa/util/parameter.py index 7e384192..e520bfc2 100644 --- a/sfa/util/parameter.py +++ b/sfa/util/parameter.py @@ -5,8 +5,8 @@ # Copyright (C) 2006 The Trustees of Princeton University # -from types import * -from sfa.util.faults import * +from types import NoneType, IntType, LongType, FloatType, StringTypes, DictType, TupleType, ListType +from sfa.util.faults import SfaAPIError class Parameter: """ diff --git a/sfa/util/policy.py b/sfa/util/policy.py index 196c71ee..5e43be55 100644 --- a/sfa/util/policy.py +++ b/sfa/util/policy.py @@ -1,6 +1,6 @@ import os -from sfa.util.storage import * +from sfa.util.storage import SimpleStorage class Policy(SimpleStorage): diff --git a/sfa/util/record.py b/sfa/util/record.py index 7cfc752e..7ebf379a 100644 --- a/sfa/util/record.py +++ b/sfa/util/record.py @@ -6,9 +6,9 @@ from types import StringTypes -from sfa.trust.gid import * +from sfa.trust.gid import GID -from sfa.util.parameter import * +from sfa.util.parameter import Parameter from sfa.util.xrn import get_authority from sfa.util.row import Row from sfa.util.xml import XML @@ -279,6 +279,7 @@ class SfaRecord(Row): """ Load the record from a dictionary """ + self.set_name(dict['hrn']) gidstr = dict.get("gid", None) if gidstr: @@ -302,7 +303,7 @@ class SfaRecord(Row): recorddict = self.as_dict() filteredDict = dict([(key, val) for (key, val) in recorddict.iteritems() if key in self.fields.keys()]) record = XML('') - record.root.attrib.update(filteredDict) + record.parse_dict(filteredDict) str = record.toxml() return str @@ -316,7 +317,7 @@ class SfaRecord(Row): representation of the record. """ #dict = xmlrpclib.loads(str)[0][0] - + record = XML(str) self.load_from_dict(record.todict()) diff --git a/sfa/util/rspecHelper.py b/sfa/util/rspecHelper.py index 89f15afd..deaa746a 100755 --- a/sfa/util/rspecHelper.py +++ b/sfa/util/rspecHelper.py @@ -7,7 +7,7 @@ from lxml import etree from StringIO import StringIO from optparse import OptionParser -from sfa.util.faults import * +from sfa.util.faults import InvalidRSpec from sfa.util.sfalogging import logger def merge_rspecs(rspecs): diff --git a/sfa/util/specdict.py b/sfa/util/specdict.py deleted file mode 100644 index 2138e876..00000000 --- a/sfa/util/specdict.py +++ /dev/null @@ -1,131 +0,0 @@ -## -# SpecDict -# -# SpecDict defines a means for converting a dictionary with plc specific keys -# to a dict with rspec specific keys. -# -# SpecDict.fields dict defines all the rspec specific attribute and their -# expected type. -# -# SpecDict.plc_fields defines a one to one mapping of plc attribute to rspec -# attribute - -from types import StringTypes, ListType - -class SpecDict(dict): - """ - Base class of SpecDict objects. - """ - fields = {} - plc_fields = {} - type = None - - def __init__(self, spec_dict): - # convert plc dict and initialize self - sdict = self.plcToSpec(spec_dict) - dict.__init__(self, sdict) - - - def plcToSpec(self, spec_dict): - """ - Defines how to convert a plc dict to rspec dict - """ - spec = {} - for field in self.fields: - value = "" - expected = self.fields[field] - if isinstance(expected, StringTypes): - if self.plc_fields.has_key(field): - plc_field = self.plc_fields[field] - if spec_dict.has_key(plc_field): - value = spec_dict[plc_field] - elif isinstance(expected, ListType): - expected = expected[0] - if self.plc_fields.has_key(field): - plc_field = self.plc_fields[field] - if spec_dict.has_key(plc_field): - value = [expected(value) for value in spec_dict[plc_field]] - spec[field] = value - return {self.type: spec} - -# -# fields = { geni_field: type. Could be class for nested classes, otherwise prob str} -# plc_fields = {geni_field : plc_field} -# - -class IfSpecDict(SpecDict): - type = 'IfSpec' - fields = {'name': '', - 'addr': '', - 'type': '', - 'init_params': '', - 'min_rate': '', - 'max_rate': '', - 'max_kbyte': '', - 'ip_spoof': ''} - plc_fields = {'name': 'is_primary', # XXX needs munging to return name instead of True or False - 'addr': 'ip', - 'type': 'type'} - -class LinkSpecDict(SpecDict): - type = 'LinkSpec' - fields = {'min_alloc': '', - 'max_alloc': '', - 'type': '', - 'start_time': '', - 'bw': '', - 'duration': '', - 'init_params': '', - 'endpoints': [IfSpecDict]} - plc_fields = {'min_alloc': 'min_alloc', - 'max_alloc': 'max_alloc', - 'type': 'type', - 'start_time': 'start_time', - 'bw': 'bw', - 'duration': 'duration', - 'init_params': 'init_params', - 'endpoints': 'endpoints'} - - -class NodeSpecDict(SpecDict): - type = 'NodeSpec' - fields = {'name': '', - 'type': '', - 'init_params': '', - 'cpu_min': '', - 'cpu_share': '', - 'cpu_pct': '', - 'disk_max': '', - 'start_time': '', - 'duration': '', - 'net_if': [IfSpecDict]} - - plc_fields = {'name': 'hostname', - 'net_if': 'interfaces'} - -class NetSpecDict(SpecDict): - type = 'NetSpec' - fields = {'name': '', - 'start_time': '', - 'duration': '', - 'nodes': [NodeSpecDict], - 'links': [LinkSpecDict], - } - plc_fields = {'name': 'name', - 'start_time': 'start_time', - 'duration': 'duration', - 'nodes': 'nodes', - 'links': 'links'} - -class RSpecDict(SpecDict): - type = 'RSpec' - fields = {'start_time': '', - 'duration': '', - 'networks': [NetSpecDict] - } - plc_fields = {'networks': 'networks', - 'start_time': 'start_tim', - 'duration': 'duration' - } - -# vim:ts=4:expandtab diff --git a/sfa/util/table.py b/sfa/util/table.py index 0e162897..065e8abb 100644 --- a/sfa/util/table.py +++ b/sfa/util/table.py @@ -3,11 +3,13 @@ # # TODO: Use existing PLC database methods? or keep this separate? -from sfa.util.PostgreSQL import * -from sfa.trust.gid import * -from sfa.util.record import * -from sfa.util.config import * -from sfa.util.filter import * +from types import StringTypes + +from sfa.util.config import Config +from sfa.util.parameter import Parameter +from sfa.util.filter import Filter +from sfa.util.PostgreSQL import PostgreSQL +from sfa.util.record import SfaRecord, AuthorityRecord, NodeRecord, SliceRecord, UserRecord class SfaTable(list): diff --git a/sfa/util/topology.py b/sfa/util/topology.py new file mode 100644 index 00000000..bd7eb189 --- /dev/null +++ b/sfa/util/topology.py @@ -0,0 +1,40 @@ +## +# SFA Topology Info +# +# This module holds topology configuration for SFA. It is implemnted as a +# list of site_id tuples + +import os.path +import traceback +from sfa.util.sfalogging import logger + +class Topology(set): + """ + Parse the topology configuration file. + """ + + #def __init__(self, config_file = "/etc/sfa/topology"): + def __init__(self, config_file = "/tmp/topology"): + set.__init__(self) + self.config_file = None + self.config_path = None + self.load(config_file) + + def load(self, config_file): + try: + + self.config_file = config_file + # path to configuration data + self.config_path = os.path.dirname(config_file) + # load the links + f = open(config_file, 'r') + for line in f: + ignore = line.find('#') + if ignore > -1: + line = line[0:ignore] + tup = line.split() + if len(tup) > 1: + self.add((tup[0], tup[1])) + except Exception, e: + logger.log_exc("Could not find or load the configuration file: %s" % config_file) + raise diff --git a/sfa/util/xml.py b/sfa/util/xml.py index 91a1d950..25f16562 100755 --- a/sfa/util/xml.py +++ b/sfa/util/xml.py @@ -1,10 +1,9 @@ #!/usr/bin/python +from types import StringTypes from lxml import etree from StringIO import StringIO -from datetime import datetime, timedelta -from sfa.util.xrn import * -from sfa.util.plxrn import hostname_to_urn -from sfa.util.faults import SfaNotImplemented, InvalidXML + +from sfa.util.faults import InvalidXML class XpathFilter: @staticmethod @@ -56,13 +55,14 @@ class XML: self.root = tree.getroot() # set namespaces map self.namespaces = dict(self.root.nsmap) - # If the 'None' exist, then it's pointing to the default namespace. This makes - # it hard for us to write xpath queries for the default naemspace because lxml - # wont understand a None prefix. We will just associate the default namespeace - # with a key named 'default'. - if None in self.namespaces: - default_namespace = self.namespaces.pop(None) - self.namespaces['default'] = default_namespace + if 'default' not in self.namespaces and None in self.namespaces: + # If the 'None' exist, then it's pointing to the default namespace. This makes + # it hard for us to write xpath queries for the default naemspace because lxml + # wont understand a None prefix. We will just associate the default namespeace + # with a key named 'default'. + self.namespaces['default'] = self.namespaces[None] + else: + self.namespaces['default'] = 'default' # set schema for key in self.root.attrib.keys(): @@ -75,7 +75,8 @@ class XML: def parse_dict(self, d, root_tag_name='xml', element = None): if element is None: - self.parse_xml('<%s/>' % root_tag_name) + if self.root is None: + self.parse_xml('<%s/>' % root_tag_name) element = self.root if 'text' in d: @@ -89,8 +90,24 @@ class XML: for val in value: if isinstance(val, dict): child_element = etree.SubElement(element, key) - self.parse_dict(val, key, child_element) - + self.parse_dict(val, key, child_element) + elif isinstance(val, basestring): + child_element = etree.SubElement(element, key).text = val + + elif isinstance(value, int): + d[key] = unicode(d[key]) + elif value is None: + d.pop(key) + + # element.attrib.update will explode if DateTimes are in the + # dcitionary. + d=d.copy() + for (k,v) in d.iteritems(): + if not isinstance(v,StringTypes): del d[k] + for k in d.keys(): + if (type(d[k]) != str) and (type(d[k]) != unicode): + del d[k] + element.attrib.update(d) def validate(self, schema): @@ -200,8 +217,9 @@ class XML: return self.toxml() def toxml(self): - return etree.tostring(self.root, pretty_print=True) + return etree.tostring(self.root, encoding='UTF-8', pretty_print=True) + # XXX smbaker, for record.load_from_string def todict(self, elem=None): if elem is None: elem = self.root @@ -212,14 +230,19 @@ class XML: if child.tag not in d: d[child.tag] = [] d[child.tag].append(self.todict(child)) - return d + + if len(d)==1 and ("text" in d): + d = d["text"] + + return d def save(self, filename): f = open(filename, 'w') f.write(self.toxml()) f.close() - -if __name__ == '__main__': - rspec = RSpec('/tmp/resources.rspec') - print rspec + +# no RSpec in scope +#if __name__ == '__main__': +# rspec = RSpec('/tmp/resources.rspec') +# print rspec diff --git a/sfa/util/xmlrpcprotocol.py b/sfa/util/xmlrpcprotocol.py index 25e7b76d..2263b286 100644 --- a/sfa/util/xmlrpcprotocol.py +++ b/sfa/util/xmlrpcprotocol.py @@ -1,9 +1,10 @@ # XMLRPC-specific code for SFA Client import xmlrpclib -#from sfa.util.httpsProtocol import HTTPS, HTTPSConnection from httplib import HTTPS, HTTPSConnection + from sfa.util.sfalogging import logger + ## # ServerException, ExceptionUnmarshaller # diff --git a/sfa/util/xrn.py b/sfa/util/xrn.py index 3dc87b63..1f506289 100644 --- a/sfa/util/xrn.py +++ b/sfa/util/xrn.py @@ -23,7 +23,7 @@ import re -from sfa.util.faults import * +from sfa.util.faults import SfaAPIError # for convenience and smoother translation - we should get rid of these functions eventually def get_leaf(hrn): return Xrn(hrn).get_leaf() @@ -98,7 +98,7 @@ class Xrn: @staticmethod def urn_full (urn): if urn.startswith(Xrn.URN_PREFIX): return urn - else: return Xrn.URN_PREFIX+URN + else: return Xrn.URN_PREFIX+urn @staticmethod def urn_meaningful (urn): if urn.startswith(Xrn.URN_PREFIX): return urn[len(Xrn.URN_PREFIX):] @@ -173,13 +173,18 @@ class Xrn: # or completely change how record types are generated/stored if name != 'sa': type = type + "+" + name - + name ="" + else: + name = parts.pop(len(parts)-1) # convert parts (list) into hrn (str) by doing the following # 1. remove blank parts # 2. escape dots inside parts # 3. replace ':' with '.' inside parts - # 3. join parts using '.' - hrn = '.'.join([Xrn.escape(part).replace(':','.') for part in parts if part]) + # 3. join parts using '.' + hrn = '.'.join([Xrn.escape(part).replace(':','.') for part in parts if part]) + # dont replace ':' in the name section + if name: + hrn += '.%s' % Xrn.escape(name) self.hrn=str(hrn) self.type=str(type) diff --git a/sfatables/commands/Add.py b/sfatables/commands/Add.py index e7657fff..987cff50 100644 --- a/sfatables/commands/Add.py +++ b/sfatables/commands/Add.py @@ -1,7 +1,7 @@ -import os, time +import os import libxml2 from sfatables.command import Command -from sfatables.globals import * +from sfatables.globals import sfatables_config, target_dir, match_dir class Add(Command): def __init__(self): diff --git a/sfatables/commands/Delete.py b/sfatables/commands/Delete.py index 50b1d626..36469908 100644 --- a/sfatables/commands/Delete.py +++ b/sfatables/commands/Delete.py @@ -1,5 +1,5 @@ import os, time -from sfatables.globals import * +from sfatables.globals import sfatables_config from sfatables.command import Command class Delete(Command): diff --git a/sfatables/commands/Insert.py b/sfatables/commands/Insert.py index d4010920..852985eb 100644 --- a/sfatables/commands/Insert.py +++ b/sfatables/commands/Insert.py @@ -1,7 +1,7 @@ import os, time import libxml2 from sfatables.command import Command -from sfatables.globals import * +from sfatables.globals import sfatables_config, target_dir, match_dir class Insert(Command): def __init__(self): diff --git a/sfatables/commands/List.py b/sfatables/commands/List.py index 70d72064..cea40bb7 100644 --- a/sfatables/commands/List.py +++ b/sfatables/commands/List.py @@ -1,7 +1,7 @@ import os, time import libxml2 -from sfatables.globals import * +from sfatables.globals import sfatables_config from sfatables.pretty import Pretty from sfatables.command import Command diff --git a/sfatables/runtime.py b/sfatables/runtime.py index 99226f4f..e22967c7 100644 --- a/sfatables/runtime.py +++ b/sfatables/runtime.py @@ -2,15 +2,13 @@ import sys import os -import pdb -from optparse import OptionParser import libxml2 +import libxslt -from sfatables import commands -from sfatables.globals import * -from sfatables.commands.List import * -from sfatables.xmlrule import * +from sfatables.globals import sfatables_config +from sfatables.commands.List import List +from sfatables.xmlrule import XMLRule class SFATablesRules: def __init__(self, chain_name): diff --git a/sfatables/sfatables b/sfatables/sfatables index b413ef1a..a06680b6 100755 --- a/sfatables/sfatables +++ b/sfatables/sfatables @@ -8,14 +8,12 @@ import sys import os -import pdb import glob -import libxml2 from optparse import OptionParser from sfatables import commands from sfatables.xmlextension import Xmlextension -from sfatables.globals import * +from sfatables.globals import target_dir, match_dir def load_commands(module, list): command_dict={} diff --git a/sfatables/xmlextension.py b/sfatables/xmlextension.py index 5e298db9..f90e0fb1 100644 --- a/sfatables/xmlextension.py +++ b/sfatables/xmlextension.py @@ -5,7 +5,6 @@ # - The parameters that the processor needs to evaluate the context import libxml2 -from sfatables.globals import * class Xmlextension: def __init__(self, file_path): diff --git a/sfatables/xmlrule.py b/sfatables/xmlrule.py index e21f9d86..46f36018 100644 --- a/sfatables/xmlrule.py +++ b/sfatables/xmlrule.py @@ -1,10 +1,11 @@ +import sys,os + import libxml2 # allow to run sfa2wsdl if this is missing (for mac) -import sys try:import libxslt except: print >>sys.stderr, "WARNING, could not import libxslt" -from sfatables.globals import * +from sfatables.globals import sfatables_config class XMLRule: def apply_processor(self, type, doc, output_xpath_filter=None): @@ -88,14 +89,13 @@ class XMLRule: # then target(target_args, rspec) # else rspec - import pdb if (self.match(rspec)): return (True,self.wrap_up(self.target(rspec))) else: return (False,self.wrap_up(rspec)) - def apply_compiled(rspec): + def apply_compiled(self, rspec): # Not supported yet return None diff --git a/tests/testInterfaces.py b/tests/testInterfaces.py index d25484cf..91606371 100755 --- a/tests/testInterfaces.py +++ b/tests/testInterfaces.py @@ -12,7 +12,7 @@ from sfa.util.xrn import get_authority from sfa.util.config import * from sfa.trust.certificate import * from sfa.trust.credential import * -from sfa.util.sfaticket import * +from sfa.trust.sfaticket import SfaTicket from sfa.client import sfi def random_string(size): diff --git a/tools/Makefile b/tools/Makefile new file mode 100644 index 00000000..010f0194 --- /dev/null +++ b/tools/Makefile @@ -0,0 +1,18 @@ +########## compute dependency graphs +DEPTOOLS=py2depgraph.py depgraph2dot.py + +all:deps + +deps: server.png client.png + +server.dg: $(DEPTOOLS) + py2depgraph.py ../sfa/server/sfa-start.py > $@ + +client.dg: $(DEPTOOLS) + py2depgraph.py ../sfa/client/sfi.py > $@ + +%.png: %.dg + depgraph2dot.py < $*.dg | dot -T png -o $*.png + +clean: + rm -f *png *dg diff --git a/tools/depgraph2dot.py b/tools/depgraph2dot.py new file mode 100755 index 00000000..b8ecbce4 --- /dev/null +++ b/tools/depgraph2dot.py @@ -0,0 +1,197 @@ +#!/usr/bin/python +# Copyright 2004 Toby Dickenson +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject +# to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +import sys, getopt, colorsys, imp, hashlib + +class pydepgraphdot: + + def main(self,argv): + opts,args = getopt.getopt(argv,'',['mono']) + self.colored = 1 + for o,v in opts: + if o=='--mono': + self.colored = 0 + self.render() + + def fix(self,s): + # Convert a module name to a syntactically correct node name + return s.replace('.','_') + + def render(self): + p,t = self.get_data() + + # normalise our input data + for k,d in p.items(): + for v in d.keys(): + if not p.has_key(v): + p[v] = {} + + f = self.get_output_file() + + f.write('digraph G {\n') + #f.write('concentrate = true;\n') + #f.write('ordering = out;\n') + f.write('ranksep=1.0;\n') + f.write('node [style=filled,fontname=Helvetica,fontsize=10];\n') + allkd = p.items() + allkd.sort() + for k,d in allkd: + tk = t.get(k) + if self.use(k,tk): + allv = d.keys() + allv.sort() + for v in allv: + tv = t.get(v) + if self.use(v,tv) and not self.toocommon(v,tv): + f.write('%s -> %s' % ( self.fix(k),self.fix(v) ) ) + self.write_attributes(f,self.edge_attributes(k,v)) + f.write(';\n') + f.write(self.fix(k)) + self.write_attributes(f,self.node_attributes(k,tk)) + f.write(';\n') + f.write('}\n') + + def write_attributes(self,f,a): + if a: + f.write(' [') + f.write(','.join(a)) + f.write(']') + + def node_attributes(self,k,type): + a = [] + a.append('label="%s"' % self.label(k)) + if self.colored: + a.append('fillcolor="%s"' % self.color(k,type)) + else: + a.append('fillcolor=white') + if self.toocommon(k,type): + a.append('peripheries=2') + return a + + def edge_attributes(self,k,v): + a = [] + weight = self.weight(k,v) + if weight!=1: + a.append('weight=%d' % weight) + length = self.alien(k,v) + if length: + a.append('minlen=%d' % length) + return a + + def get_data(self): + t = eval(sys.stdin.read()) + return t['depgraph'],t['types'] + + def get_output_file(self): + return sys.stdout + + def use(self,s,type): + # Return true if this module is interesting and should be drawn. Return false + # if it should be completely omitted. This is a default policy - please override. + if s in ('os','sys','qt','time','__future__','types','re','string'): + # nearly all modules use all of these... more or less. They add nothing to + # our diagram. + return 0 + if s.startswith('encodings.'): + return 0 + if s=='__main__': + return 1 + if self.toocommon(s,type): + # A module where we dont want to draw references _to_. Dot doesnt handle these + # well, so it is probably best to not draw them at all. + return 0 + return 1 + + def toocommon(self,s,type): + # Return true if references to this module are uninteresting. Such references + # do not get drawn. This is a default policy - please override. + # + if s=='__main__': + # references *to* __main__ are never interesting. omitting them means + # that main floats to the top of the page + return 1 + if type==imp.PKG_DIRECTORY: + # dont draw references to packages. + return 1 + return 0 + + def weight(self,a,b): + # Return the weight of the dependency from a to b. Higher weights + # usually have shorter straighter edges. Return 1 if it has normal weight. + # A value of 4 is usually good for ensuring that a related pair of modules + # are drawn next to each other. This is a default policy - please override. + # + if b.split('.')[-1].startswith('_'): + # A module that starts with an underscore. You need a special reason to + # import these (for example random imports _random), so draw them close + # together + return 4 + return 1 + + def alien(self,a,b): + # Return non-zero if references to this module are strange, and should be drawn + # extra-long. the value defines the length, in rank. This is also good for putting some + # vertical space between seperate subsystems. This is a default policy - please override. + # + return 0 + + def label(self,s): + # Convert a module name to a formatted node label. This is a default policy - please override. + # + return '\\.\\n'.join(s.split('.')) + + def color(self,s,type): + # Return the node color for this module name. This is a default policy - please override. + # + # Calculate a color systematically based on the hash of the module name. Modules in the + # same package have the same color. Unpackaged modules are grey + t = self.normalise_module_name_for_hash_coloring(s,type) + return self.color_from_name(t) + + def normalise_module_name_for_hash_coloring(self,s,type): + if type==imp.PKG_DIRECTORY: + return s + else: + i = s.rfind('.') + if i<0: + return '' + else: + return s[:i] + + def color_from_name(self,name): + n = hashlib.md5(name).digest() + hf = float(ord(n[0])+ord(n[1])*0xff)/0xffff + sf = float(ord(n[2]))/0xff + vf = float(ord(n[3]))/0xff + r,g,b = colorsys.hsv_to_rgb(hf, 0.3+0.6*sf, 0.8+0.2*vf) + return '#%02x%02x%02x' % (r*256,g*256,b*256) + + +def main(): + pydepgraphdot().main(sys.argv[1:]) + +if __name__=='__main__': + main() + + + diff --git a/tools/py2depgraph.py b/tools/py2depgraph.py new file mode 100755 index 00000000..022add32 --- /dev/null +++ b/tools/py2depgraph.py @@ -0,0 +1,71 @@ +#!/usr/bin/python +# Copyright 2004,2009 Toby Dickenson +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject +# to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import sys, pprint +import modulefinder + +focus=[ 'sfa' , 'OpenSSL', 'M2Crypto', 'xmlrpclib', 'threading' ] + +class mymf(modulefinder.ModuleFinder): + def __init__(self,*args,**kwargs): + self._depgraph = {} + self._types = {} + self._last_caller = None + modulefinder.ModuleFinder.__init__(self,*args,**kwargs) + + def import_hook(self, name, caller=None, fromlist=None, level=None): + old_last_caller = self._last_caller + try: + self._last_caller = caller + return modulefinder.ModuleFinder.import_hook(self,name,caller,fromlist) + finally: + self._last_caller = old_last_caller + + def import_module(self,partnam,fqname,parent): + keep=False + for start in focus: + if fqname.startswith(start): keep=True + if not keep: + print >> sys.stderr, "Trimmed fqname",fqname + return + r = modulefinder.ModuleFinder.import_module(self,partnam,fqname,parent) + if r is not None: + self._depgraph.setdefault(self._last_caller.__name__,{})[r.__name__] = 1 + return r + + def load_module(self, fqname, fp, pathname, (suffix, mode, type)): + r = modulefinder.ModuleFinder.load_module(self, fqname, fp, pathname, (suffix, mode, type)) + if r is not None: + self._types[r.__name__] = type + return r + + +def main(argv): + path = sys.path[:] + debug = 0 + exclude = [] + mf = mymf(path,debug,exclude) + mf.run_script(argv[0]) + pprint.pprint({'depgraph':mf._depgraph,'types':mf._types}) + +if __name__=='__main__': + main(sys.argv[1:]) diff --git a/tools/readme b/tools/readme new file mode 100644 index 00000000..bb828945 --- /dev/null +++ b/tools/readme @@ -0,0 +1,5 @@ +initial version from + http://www.tarind.com/py2depgraph.py + http://www.tarind.com/depgraph2dot.py + +customized for trimming all non-project dependencies