*.pkey
*.cert
*.cred
+.DS_Store
.PHONY: all install clean uninstall
-VERSIONTAG=should-be-redefined-by-specfile
+VERSIONTAG=0.0-0-should.be-redefined-by-specfile
SCMURL=should-be-redefined-by-specfile
##########
RSpecs
- CreateSlivers should update SliverTags/attributes
+- ProtoGENI rspec integration testing
+- initscripts in the rspec
Registry
- Verify that sub authority certificates still work
<value>false</value>
<description>Flag to turn debug on.</description>
</variable>
+
+ <variable id="max_slice_renew" type="int">
+ <name>Max Slice Renew</name>
+ <value>60</value>
+ <description>Maximum number of days a user can extend/renew their slices</description>
+ </variable>
<variable id="session_key_path" type="string">
<name>User Session Keys Path </name>
'sfa/trust',
'sfa/util',
'sfa/managers',
+ 'sfa/managers/vini',
'sfa/rspecs',
+ 'sfa/rspecs/elements',
+ 'sfa/rspecs/elements/versions',
+ 'sfa/rspecs/versions',
'sfatables',
'sfatables/commands',
'sfatables/processors',
%define name sfa
%define version 1.0
-%define taglevel 25
+%define taglevel 36
%define release %{taglevel}%{?pldistro:.%{pldistro}}%{?date:.%{date}}
%global python_sitearch %( python -c "from distutils.sysconfig import get_python_lib; print get_python_lib(1)" )
#%endif
%package cm
-Summary: the SFA wrapper around MyPLC NodeManager
+Summary: the SFA layer around MyPLC NodeManager
Group: Applications/System
Requires: sfa
Requires: pyOpenSSL >= 0.6
%package plc
-Summary: the SFA wrapper arounf MyPLC
+Summary: the SFA layer around MyPLC
Group: Applications/System
Requires: sfa
Requires: python-psycopg2
fi
%postun cm
-[ "$1" -ge "1" ] && service sfa-cm restart
-
+[ "$1" -ge "1" ] && service sfa-cm restart || :
%changelog
+* Thu Sep 15 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-36
+- Unicode-friendliness for user names with accents/special chars.
+- Fix bug that could cause the client to fail when calling CreateSliver for a slice that has the same hrn as a user.
+- CreateSliver no longer fails for users that have a capital letter in their URN.
+- Fix bug in CreateSliver that generated incorrect login bases and email addresses for ProtoGENI requests.
+- Allow files with .gid, .pem or .crt extension to be loaded into the server's list of trusted certs.
+- Fix bugs and missing imports
+
+
+* Tue Aug 30 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-35
+- new method record.get_field for sface
+
+* Mon Aug 29 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-34
+- new option -c to sfa-nuke-plc.py
+- CreateSliver fixed for admin-only slice tags
+
+* Wed Aug 24 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-32
+- Fixed exploit that allowed authorities to issue certs for objects that don't belong to them.
+- Fixed holes in certificate verification logic.
+- Aggregates no longer try to lookup slice and person records when processing CreateSliver requests. Clients are now required to specify this info in the 'users' argument.
+- Added 'boot_state' as an attribute of the node element in SFA rspec.
+- Non-authority certificates are marked as CA:FALSE.
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-32
+- fix typo in sfa-1.0-31 tag.
+- added CreateGid() Registry interface method.
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-31
+- fix typo in sfa-1.0-30 tag
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-30
+- Declare namespace and schema location in the credential.
+- Fix bug that prevented connections from timing out.
+- Fix slice delegation.
+- Add statistics to slicemanager listresources/createsliver rspec.
+- Added SFA_MAX_SLICE_RENEW which allows operators to configure the max number
+  of days a user can extend their slice expiration.
+- CA certs are only issued to objects of type authority
+
+* Fri Aug 05 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-29
+- tag 1.0-28 was broken due to typo in the changelog
+- new class sfa/util/httpsProtocol.py that supports timeouts
+
+* Thu Aug 4 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-28
+- Resolved issue that caused sfa to hold onto idle db connections.
+- Fix bug that caused the registry to use the wrong type of credential.
+- Support authority+sm type.
+- Fix rspec merging bugs.
+- Only load certs that have .gid extension from /etc/sfa/trusted_roots/
+- Created a 'planetlab' extension to the ProtoGENI v2 rspec for supporting
+ planetlab hosted initscripts using the <planetlab:initscript> tag
+- Can now handle extraneous whitespace in the rspec without failing.
+
+* Fri Jul 8 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-27
+- ProtoGENI v2 RSpec updates.
+- Convert expiration timestamps with timezone info in credentials to UTC.
+- Fixed redundant logging issue.
+- Improved SliceManager and SFI client logging.
+- Support aggregates that don't support the optional 'call_id' argument.
+- Only call get_trusted_certs() at aggregate interfaces that support the call.
+- CreateSliver() now handles MyPLC slice attributes/tags.
+- Cache now supports persistence.
+- Hide whitelisted nodes.
+
+* Tue Jun 21 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-26
+- fixed issues with sub authority signing
+- fixed bugs in remove_slivers and SliverStatus
+
* Thu Jun 16 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-25
- fix typo that prevented aggregates from operating properly
the api handler on every new server request, making it easier to access the
cache and use in more general ways.
-%changelog
-* Thu Jun 16 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-25
-- fix typo that prevented aggregates from operating properly
-
-* Tue Jun 14 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-24
-- load trusted certs into ssl context prior to handshake
-- client's logfile lives in ~/.sfi/sfi.log
-
-* Fri Jun 10 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-23
-- includes a change on passphrases that was intended in 1.0-22
-
-* Wed Mar 16 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-21
-- stable sfascan
-- fix in initscript, *ENABLED tags in config now taken into account
-
-* Fri Mar 11 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-20
-- some commits had not been pushed in tag 19
-
-* Fri Mar 11 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-19
-- GetVersion should now report full URLs with path
-- scansfa has nicer output and new syntax (entry URLs as args and not options)
-- dos2unix'ed flash policy pill
-
-* Wed Mar 09 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-18
-- fix packaging again for f8
-
-* Wed Mar 09 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-17
-- fix packaging (apparently broken in 1.0-16)
-- first working version of sfascan
-- tweaks in GetVersion for exposing hrn(AM) and full set of aggregates(SM)
-- deprecated the sfa_geni_aggregate config category
-
-* Tue Mar 08 2011 Andy Bavier <acb@cs.princeton.edu> - sfa-1.0-16
-- Fix build problem
-- First version of SFA scanner
-
-* Mon Mar 07 2011 Andy Bavier <acb@cs.princeton.edu> - sfa-1.0-15
-- Add support for Flash clients using flashpolicy
-- Fix problems with tag handling in RSpec
-
-* Wed Mar 02 2011 Andy Bavier <acb@cs.princeton.edu> - sfa-1.0-14
-- Modifications to the Eucalyptus Aggregate Manager
-- Fixes for VINI RSpec
-- Fix tag handling for PL RSpec
-- Fix XML Schema ordering for <urn> element
-
-* Tue Feb 01 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-13
-- just set x509 version to 2
-
-* Wed Jan 26 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-12
-- added urn to the node area in rspecs
-- conversion to urn now exports fqdn
-- sfa-import-plc.py now creates a unique registry record for each SFA interface
-
-* Thu Dec 16 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-11
-- undo broken attempt for python-2.7
-
-* Wed Dec 15 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-10
-- SMs avoid duplicates for when call graph has dags;
-- just based on network's name, when a duplicate occurs, one is just dropped
-- does not try to merge/aggregate 2 networks
-- also reviewed logging with the hope to fix the sfa startup msg:
-- TypeError: not all arguments converted during string formatting
-
-* Tue Dec 07 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-9
-- verify credentials against xsd schema
-- Fix SM to SM communication
-- Fix bug in sfa.util.sfalogging, sfa-import.py now logs to sfa_import.log
-- new setting session_key_path
-
-* Tue Nov 09 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-8
-- fix registry credential regeneration and handle expiration
-- support for setting slice tags (min_role=user)
-- client can display its own version: sfi.py version --local
-- GetVersion to provide urn in addition to hrn
-- more code uses plxrn vs previous helper functions
-- import replaces '+' in email addresses with '_'
-
-* Fri Oct 22 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-7
-- fix GetVersion code_tag and add code_url
-
-* Fri Oct 22 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-6
-- extend GetVersion towards minimum federation introspection, and expose local tag
-
-* Wed Oct 20 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-5
-- fixed some legacy issues (list vs List)
-- deprecated sfa.util.namespace for xrn and plxrn
-- unit tests ship as the sfa-tests rpm
-
-* Mon Oct 11 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-2
-- deprecated old methods (e.g. List/list, and GetCredential/get_credential)
-- NOTE: get_(self_)credential both have type and hrn swapped when moving to Get(Self)Credential
-- hrn-urn translations tweaked
-- fixed 'service sfa status'
-- sfa-nuke-plc has a -f/--file-system option to clean up /var/lib/authorities (exp.)
-- started to repair sfadump - although not usable yet
-- trust objects now have dump_string method that dump() actually prints
-- unit tests under review
-- logging cleanup ongoing (always safe to use sfalogging.sfa_logger())
-- binaries now support -v or -vv to increase loglevel
-- trashed obsolete sfa.util.client
-
-* Mon Oct 04 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-1
-- various bugfixes and cleanup, improved/harmonized logging
-
* Thu May 11 2010 Tony Mack <tmack@cs.princeton.edu> - sfa-0.9-11
- SfaServer now uses a pool of threads to handle requests concurrently
- sfa.util.rspec no longer used to process/manage rspecs (deprecated). This is now handled by sfa.plc.network and is not backwards compatible
+++ /dev/null
-tags:
- find . -name '*.py' | grep -v '/\.svn/' | xargs etags
-.PHONY: tags
-
-
# recompute the SFA graphs from different locations
+SFASCAN = ./sfascan.py -v
+
# AMs, at least MyPLC AMs, are boring
#BUNDLES += http://planet-lab.eu:12346/@auto-ple-am
BUNDLES += http://planet-lab.eu:12345/@auto-ple-reg
BUNDLES-LR += http://www.emanicslab.org:12345/@auto-elc-reg
BUNDLES-LR += http://www.emanicslab.org:12347/@auto-elc-sa
-EXTENSIONS := png svg
+#EXTENSIONS := png svg
+EXTENSIONS := png
####################
ALL += $(foreach bundle,$(BUNDLES),$(word 2,$(subst @, ,$(bundle))))
all: $(ALL)
+ple: auto-ple-reg auto-ple-sa-lr.out
+
####################
define bundle_scan_target
$(word 2,$(subst @, ,$(1))):
- ./sfascan.py $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1))).$(extension)) $(word 1,$(subst @, ,$(1))) >& $(word 2,$(subst @, ,$(1))).out
+ $(SFASCAN) $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1))).$(extension)) $(word 1,$(subst @, ,$(1))) >& .$(word 2,$(subst @, ,$(1))).out
.PHONY: $(word 2,$(subst @, ,$(1)))
endef
#################### same but left-to-right
define bundle_scan_target_lr
$(word 2,$(subst @, ,$(1)))-lr:
- ./sfascan.py -l $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1)))-lr.$(extension)) $(word 1,$(subst @, ,$(1))) >& $(word 2,$(subst @, ,$(1)))-lr.out
+ $(SFASCAN) -l $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1)))-lr.$(extension)) $(word 1,$(subst @, ,$(1))) >& .$(word 2,$(subst @, ,$(1)))-lr.out
.PHONY: $(word 2,$(subst @, ,$(1)))-lr
endef
rm -f auto-*.{out,version}
$(foreach extension,$(EXTENSIONS),rm -rf auto-*.$(extension);)
+DATE=$(shell date '+%Y-%m-%d')
PUBEXTENSIONS=png
publish:
- $(foreach extension,$(PUBEXTENSIONS),rsync -av auto-*.$(extension) tparment@srv-planete.inria.fr:/proj/planete/www/Thierry.Parmentelat/sfascan/ ;)
+ echo $(DATE)
+ ssh tparment@srv-planete.inria.fr mkdir /proj/planete/www/Thierry.Parmentelat/sfascan/$(DATE)
+ $(foreach extension,$(PUBEXTENSIONS),rsync -av auto-*.$(extension) tparment@srv-planete.inria.fr:/proj/planete/www/Thierry.Parmentelat/sfascan/$(DATE) ;)
#################### convenience, for debugging only
# make +foo : prints the value of $(foo)
--- /dev/null
+
+def pg_users_arg(records):
+ users = []
+ for record in records:
+ if record['type'] != 'user':
+ continue
+ user = {'urn': record['geni_urn'],
+ 'keys': record['keys']}
+ users.append(user)
+ return users
+
+def sfa_users_arg(records, slice_record):
+ users = []
+ for record in records:
+ if record['type'] != 'user':
+ continue
+ user = {'urn': record['geni_urn'], #
+ 'keys': record['keys'],
+ 'email': record['email'], # needed for MyPLC
+ 'person_id': record['person_id'], # needed for MyPLC
+ 'first_name': record['first_name'], # needed for MyPLC
+ 'last_name': record['last_name'], # needed for MyPLC
+ 'slice_record': slice_record, # needed for legacy refresh peer
+ 'key_ids': record['key_ids'] # needed for legacy refresh peer
+ }
+ users.append(user)
+ return users
+
+def sfa_to_pg_users_arg(users):
+
+ new_users = []
+ fields = ['urn', 'keys']
+ for user in users:
+ new_user = dict([item for item in user.items() \
+ if item[0] in fields])
+ new_users.append(new_user)
+ return new_users
from pprint import pprint
from types import StringTypes
-from sfa.util.rspec import RSpec
-
def create_parser():
command = sys.argv[0]
argv = sys.argv[1:]
from optparse import OptionParser
from pprint import pprint
from xml.parsers.expat import ExpatError
-
-from sfa.util.rspec import RecordSpec
-
+from sfa.util.xml import XML
def create_parser():
command = sys.argv[0]
return parser
-def printRec(record, filters, options):
+def printRec(record_dict, filters, options):
line = ""
if len(filters):
for filter in filters:
if options.DEBUG: print "Filtering on %s" %filter
line += "%s: %s\n" % (filter,
- printVal(record.dict["record"].get(filter, None)))
+ printVal(record_dict.get(filter, None)))
print line
else:
# print the whole thing
- for (key, value) in record.dict["record"].iteritems():
+ for (key, value) in record_dict.iteritems():
if (not options.withkey and key in ('gid', 'keys')) or\
(not options.plinfo and key == 'pl_info'):
continue
stdin = sys.stdin.read()
- record = RecordSpec(xml = stdin)
+ record = XML(stdin)
+ record_dict = record.todict()
- if not record.dict.has_key("record"):
- raise "RecordError", "Input record does not have 'record' tag."
-
if options.DEBUG:
- record.pprint()
+ pprint(record.toxml())
print "#####################################################"
- printRec(record, args, options)
+ printRec(record_dict, args, options)
if __name__ == '__main__':
try: main()
import os
from optparse import OptionParser
from pprint import pprint
-
-from sfa.util.rspec import RecordSpec
-
+from sfa.util.xml import XML
def create_parser():
command = sys.argv[0]
parser = create_parser();
(options, args) = parser.parse_args()
- record = RecordSpec(xml = sys.stdin.read())
-
+ record = XML(sys.stdin.read())
+ record_dict = record.todict()
if args:
- editDict(args, record.dict["record"], options)
+ editDict(args, record_dict, options)
if options.DEBUG:
- print "New Record:\n%s" % record.dict
- record.pprint()
-
- record.parseDict(record.dict)
+ print "New Record:\n%s" % record_dict
+
+ record.parse_dict(record_dict)
s = record.toxml()
sys.stdout.write(s)
from sfa.trust.credential import Credential
from sfa.trust.gid import GID
from sfa.util.record import SfaRecord
-from sfa.util.rspec import RSpec
-from sfa.util.sfalogging import sfa_logger, sfa_logger_goes_to_console
+from sfa.util.sfalogging import logger
def determine_sfa_filekind(fn):
print "%s: unknown filekind '%s'"% (filename,kind)
def main():
- sfa_logger_goes_to_console()
usage = """%prog file1 [ .. filen]
display info on input files"""
parser = OptionParser(usage=usage)
parser.add_option("-v", "--verbose", action='count', dest='verbose', default=0)
(options, args) = parser.parse_args()
- sfa_logger().setLevelFromOptVerbose(options.verbose)
+ logger.setLevelFromOptVerbose(options.verbose)
if len(args) <= 0:
parser.print_help()
sys.exit(1)
-#!/usr/bin/python
+#!/usr/bin/env python
import sys
import socket
from optparse import OptionParser
from sfa.client.sfi import Sfi
-from sfa.util.sfalogging import sfa_logger,sfa_logger_goes_to_console
+from sfa.util.sfalogging import logger, DEBUG
import sfa.util.xmlrpcprotocol as xmlrpcprotocol
def url_hostname_port (url):
self.ip=socket.gethostbyname(self.hostname)
self.probed=False
except:
-# traceback.print_exc()
self.hostname="unknown"
self.ip='0.0.0.0'
self.port="???"
pass
options=DummyOptions()
options.verbose=False
+ options.timeout=10
try:
client=Sfi(options)
client.read_config()
key_file = client.get_key_file()
cert_file = client.get_cert_file(key_file)
url=self.url()
- sfa_logger().info('issuing get version at %s'%url)
- server=xmlrpcprotocol.get_server(url, key_file, cert_file, options)
+ logger.info('issuing get version at %s'%url)
+ logger.debug("GetVersion, using timeout=%d"%options.timeout)
+ server=xmlrpcprotocol.get_server(url, key_file, cert_file, timeout=options.timeout, verbose=options.verbose)
self._version=server.GetVersion()
except:
-# traceback.print_exc()
self._version={}
self.probed=True
return self._version
result='<<TABLE BORDER="0" CELLBORDER="0"><TR><TD>' + \
'</TD></TR><TR><TD>'.join(lines) + \
'</TD></TR></TABLE>>'
-# print 'multilines=',result
return result
# default is for when we can't determine the type of the service
# performing xmlrpc call
version=interface.get_version()
if self.verbose:
- sfa_logger().info("GetVersion at interface %s"%interface.url())
+ logger.info("GetVersion at interface %s"%interface.url())
if not version:
- sfa_logger().info("<EMPTY GetVersion(); offline or cannot authenticate>")
+ logger.info("<EMPTY GetVersion(); offline or cannot authenticate>")
else:
for (k,v) in version.iteritems():
if not isinstance(v,dict):
- sfa_logger().info("\r\t%s:%s"%(k,v))
+ logger.info("\r\t%s:%s"%(k,v))
else:
- sfa_logger().info(k)
+ logger.info(k)
for (k1,v1) in v.iteritems():
- sfa_logger().info("\r\t\t%s:%s"%(k1,v1))
+ logger.info("\r\t\t%s:%s"%(k1,v1))
# 'geni_api' is expected if the call succeeded at all
# 'peers' is needed as well as AMs typically don't have peers
if 'geni_api' in version and 'peers' in version:
for (k,v) in interface.get_layout().iteritems():
node.attr[k]=v
else:
- sfa_logger().error("MISSED interface with node %s"%node)
+ logger.error("MISSED interface with node %s"%node)
default_outfiles=['sfa.png','sfa.svg','sfa.dot']
def main():
- sfa_logger_goes_to_console()
usage="%prog [options] url-entry-point(s)"
parser=OptionParser(usage=usage)
parser.add_option("-o","--output",action='append',dest='outfiles',default=[],
help="instead of top-to-bottom")
parser.add_option("-v","--verbose",action='store_true',dest='verbose',default=False,
help="verbose")
+ parser.add_option("-d","--debug",action='store_true',dest='debug',default=False,
+ help="debug")
(options,args)=parser.parse_args()
if not args:
parser.print_help()
sys.exit(1)
if not options.outfiles:
options.outfiles=default_outfiles
+ logger.enable_console()
+ if options.debug:
+ options.verbose=True
+ logger.setLevel(DEBUG)
scanner=SfaScan(left_to_right=options.left_to_right, verbose=options.verbose)
entries = [ Interface(entry) for entry in args ]
g=scanner.graph(entries)
- sfa_logger().info("creating layout")
+ logger.info("creating layout")
g.layout(prog='dot')
for outfile in options.outfiles:
- sfa_logger().info("drawing in %s"%outfile)
+ logger.info("drawing in %s"%outfile)
g.draw(outfile)
- sfa_logger().info("done")
+ logger.info("done")
if __name__ == '__main__':
main()
sys.path.append('.')
import os, os.path
import tempfile
-import traceback
import socket
-import random
import datetime
-import zlib
+import codecs
+import pickle
from lxml import etree
from StringIO import StringIO
-from types import StringTypes, ListType
from optparse import OptionParser
-from sfa.util.sfalogging import _SfaLogger, logging
+from sfa.client.client_helper import pg_users_arg, sfa_users_arg
+from sfa.util.sfalogging import sfi_logger
from sfa.trust.certificate import Keypair, Certificate
from sfa.trust.gid import GID
from sfa.trust.credential import Credential
from sfa.util.sfaticket import SfaTicket
from sfa.util.record import SfaRecord, UserRecord, SliceRecord, NodeRecord, AuthorityRecord
-from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.rspec_converter import RSpecConverter
+from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn
import sfa.util.xmlrpcprotocol as xmlrpcprotocol
from sfa.util.config import Config
from sfa.util.version import version_core
from sfa.util.cache import Cache
+from sfa.rspecs.version_manager import VersionManager
AGGREGATE_PORT=12346
CM_PORT=12346
# save methods
+def save_variable_to_file(var, filename, format="text"):
+ f = open(filename, "w")
+ if format == "text":
+ f.write(str(var))
+ elif format == "pickled":
+ f.write(pickle.dumps(var))
+ else:
+ # this should never happen
+ print "unknown output format", format
+
+
def save_rspec_to_file(rspec, filename):
if not filename.endswith(".rspec"):
filename = filename + ".rspec"
-
f = open(filename, 'w')
f.write(rspec)
f.close()
return
-def save_records_to_file(filename, recordList):
- index = 0
- for record in recordList:
- if index > 0:
- save_record_to_file(filename + "." + str(index), record)
- else:
- save_record_to_file(filename, record)
- index = index + 1
+def save_records_to_file(filename, recordList, format="xml"):
+ if format == "xml":
+ index = 0
+ for record in recordList:
+ if index > 0:
+ save_record_to_file(filename + "." + str(index), record)
+ else:
+ save_record_to_file(filename, record)
+ index = index + 1
+ elif format == "xmllist":
+ f = open(filename, "w")
+ f.write("<recordlist>\n")
+ for record in recordList:
+ record = SfaRecord(dict=record)
+ f.write('<record hrn="' + record.get_name() + '" type="' + record.get_type() + '" />\n')
+ f.write("</recordlist>\n")
+ f.close()
+ elif format == "hrnlist":
+ f = open(filename, "w")
+ for record in recordList:
+ record = SfaRecord(dict=record)
+ f.write(record.get_name() + "\n")
+ f.close()
+ else:
+ # this should never happen
+ print "unknown output format", format
def save_record_to_file(filename, record):
if record['type'] in ['user']:
else:
record = SfaRecord(dict=record)
str = record.save_to_string()
- file(filename, "w").write(str)
+ f=codecs.open(filename, encoding='utf-8',mode="w")
+ f.write(str)
+ f.close()
return
# load methods
def load_record_from_file(filename):
- str = file(filename, "r").read()
+ f=codecs.open(filename, encoding="utf-8", mode="r")
+ str = f.read()
+ f.close()
record = SfaRecord(string=str)
return record
for opt in Sfi.required_options:
if not hasattr(options,opt): setattr(options,opt,None)
if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
+ # xxx oops, this is dangerous, sounds like we sometimes have a discrepancy
+ # would be safer to remove self.sfi_dir altogether
self.sfi_dir = options.sfi_dir
self.options = options
self.slicemgr = None
self.user = None
self.authority = None
self.hashrequest = False
- self.logger = _SfaLogger(self.sfi_dir + 'sfi.log', level = logging.INFO)
+ self.logger = sfi_logger
+ self.logger.enable_console()
def create_cmd_parser(self, command, additional_cmdargs=None):
cmdargs = {"list": "authority",
"update": "record",
"aggregates": "[name]",
"registries": "[name]",
+ "create_gid": "[name]",
"get_gid": [],
"get_trusted_certs": "cred",
"slices": "",
help="optional component information", default=None)
- if command in ("resources", "show", "list"):
+ # 'create' does return the new rspec, makes sense to save that too
+ if command in ("resources", "show", "list", "create_gid", 'create'):
parser.add_option("-o", "--output", dest="file",
help="output XML to file", metavar="FILE", default=None)
-
+
if command in ("show", "list"):
parser.add_option("-f", "--format", dest="format", type="choice",
help="display format ([text]|xml)", default="text",
choices=("text", "xml"))
+ parser.add_option("-F", "--fileformat", dest="fileformat", type="choice",
+ help="output file format ([xml]|xmllist|hrnlist)", default="xml",
+ choices=("xml", "xmllist", "hrnlist"))
+
+ if command in ("status", "version"):
+ parser.add_option("-o", "--output", dest="file",
+ help="output dictionary to file", metavar="FILE", default=None)
+ parser.add_option("-F", "--fileformat", dest="fileformat", type="choice",
+ help="output file format ([text]|pickled)", default="text",
+ choices=("text","pickled"))
+
if command in ("delegate"):
parser.add_option("-u", "--user",
action="store_true", dest="delegate_user", default=False,
parser.add_option("-k", "--hashrequest",
action="store_true", dest="hashrequest", default=False,
help="Create a hash of the request that will be authenticated on the server")
+ parser.add_option("-t", "--timeout", dest="timeout", default=None,
+ help="Amount of time to wait before timing out the request")
parser.disable_interspersed_args()
return parser
def read_config(self):
- config_file = self.options.sfi_dir + os.sep + "sfi_config"
+ config_file = os.path.join(self.options.sfi_dir,"sfi_config")
try:
config = Config (config_file)
except:
self.cert_file = cert_file
self.cert = GID(filename=cert_file)
self.logger.info("Contacting Registry at: %s"%self.reg_url)
- self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)
+ self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
self.logger.info("Contacting Slice Manager at: %s"%self.sm_url)
- self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, self.options)
+ self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
return
def get_cached_server_version(self, server):
# check local cache first
cache = None
version = None
- cache_file = self.sfi_dir + os.path.sep + 'sfi_cache.dat'
+ cache_file = os.path.join(self.options.sfi_dir,'sfi_cache.dat')
cache_key = server.url + "-version"
try:
cache = Cache(cache_file)
version = server.GetVersion()
# cache version for 24 hours
cache.add(cache_key, version, ttl= 60*60*24)
+ self.logger.info("Updating cache file %s" % cache_file)
+ cache.save_to_file(cache_file)
return version
Returns true if server support the optional call_id arg, false otherwise.
"""
server_version = self.get_cached_server_version(server)
- if 'sfa' in server_version:
+ if 'sfa' in server_version and 'code_tag' in server_version:
code_tag = server_version['code_tag']
code_tag_parts = code_tag.split("-")
self.logger.info("Getting Registry issued cert")
self.read_config()
# *hack. need to set registry before _get_gid() is called
- self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)
+ self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
gid = self._get_gid(type='user')
self.registry = None
self.logger.info("Writing certificate to %s"%cert_file)
hrn = self.user
gidfile = os.path.join(self.options.sfi_dir, hrn + ".gid")
- print gidfile
gid = self.get_cached_gid(gidfile)
if not gid:
user_cred = self.get_user_cred()
host_parts = host.split('/')
host_parts[0] = host_parts[0] + ":" + str(port)
url = "http://%s" % "/".join(host_parts)
- return xmlrpcprotocol.get_server(url, keyfile, certfile, self.options)
+ return xmlrpcprotocol.get_server(url, keyfile, certfile, timeout=self.options.timeout, verbose=self.options.debug)
# xxx opts could be retrieved in self.options
def get_server_from_opts(self, opts):
def dispatch(self, command, cmd_opts, cmd_args):
return getattr(self, command)(cmd_opts, cmd_args)
-
+
+ def create_gid(self, opts, args):
+ if len(args) < 1:
+ self.print_help()
+ sys.exit(1)
+ target_hrn = args[0]
+ user_cred = self.get_user_cred().save_to_string(save_parents=True)
+ gid = self.registry.CreateGid(user_cred, target_hrn, self.cert.save_to_string())
+ if opts.file:
+ filename = opts.file
+ else:
+ filename = os.sep.join([self.sfi_dir, '%s.gid' % target_hrn])
+ self.logger.info("writing %s gid to %s" % (target_hrn, filename))
+ GID(string=gid).save_to_file(filename)
+
+
+ # list entries in named authority registry
def list(self, opts, args):
if len(args)!= 1:
list = self.registry.List(hrn, user_cred)
except IndexError:
raise Exception, "Not enough parameters for the 'list' command"
-
- # filter on person, slice, site, node, etc.
+
+ # filter on person, slice, site, node, etc.
# This really should be in the self.filter_records function definition comment...
list = filter_records(opts.type, list)
for record in list:
- print "%s (%s)" % (record['hrn'], record['type'])
+ print "%s (%s)" % (record['hrn'], record['type'])
if opts.file:
- file = opts.file
- if not file.startswith(os.sep):
- file = os.path.join(self.options.sfi_dir, file)
- save_records_to_file(file, list)
+ save_records_to_file(opts.file, list, opts.fileformat)
return
# show named registry record
record.dump()
else:
print record.save_to_string()
-
if opts.file:
- file = opts.file
- if not file.startswith(os.sep):
- file = os.path.join(self.options.sfi_dir, file)
- save_records_to_file(file, records)
+ save_records_to_file(opts.file, records, opts.fileformat)
return
def delegate(self, opts, args):
version=server.GetVersion()
for (k,v) in version.iteritems():
print "%-20s: %s"%(k,v)
+ if opts.file:
+ save_variable_to_file(version, opts.file, opts.fileformat)
# list instantiated slices
def slices(self, opts, args):
delegated_cred = self.delegate_cred(cred, get_authority(self.authority))
creds.append(delegated_cred)
if opts.rspec_version:
- call_options['rspec_version'] = opts.rspec_version
+ version_manager = VersionManager()
+ server_version = self.get_cached_server_version(server)
+ if 'sfa' in server_version:
+ # just request the version the client wants
+ call_options['rspec_version'] = version_manager.get_version(opts.rspec_version).to_dict()
+ else:
+ # this must be a protogeni aggregate. We should request a v2 ad rspec
+ # regardless of what the client user requested
+ call_options['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
#panos add info options
if opts.info:
call_options['info'] = opts.info
if self.server_supports_call_id_arg(server):
call_args.append(unique_call_id())
result = server.ListResources(*call_args)
- format = opts.format
if opts.file is None:
- display_rspec(result, format)
+ display_rspec(result, opts.format)
else:
- file = opts.file
- if not file.startswith(os.sep):
- file = os.path.join(self.options.sfi_dir, file)
- save_rspec_to_file(result, file)
+ save_rspec_to_file(result, opts.file)
return
-
+
# created named slice with given rspec
def create(self, opts, args):
+ server = self.get_server_from_opts(opts)
+ server_version = self.get_cached_server_version(server)
slice_hrn = args[0]
- slice_urn = hrn_to_urn(slice_hrn, 'slice')
+ slice_urn = hrn_to_urn(slice_hrn, 'slice')
user_cred = self.get_user_cred()
slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True)
- creds = [slice_cred]
- if opts.delegate:
- delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority))
- creds.append(delegated_cred)
+ # delegate the cred to the caller's root authority
+ delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority)+'.slicemanager')
+ #delegated_cred = self.delegate_cred(slice_cred, get_authority(slice_hrn))
+ #creds.append(delegated_cred)
rspec_file = self.get_rspec_file(args[1])
rspec = open(rspec_file).read()
+ # need to pass along user keys to the aggregate.
# users = [
# { urn: urn:publicid:IDN+emulab.net+user+alice
- # keys: [<ssh key A>, <ssh key B>]
+ # keys: [<ssh key A>, <ssh key B>]
# }]
users = []
- server = self.get_server_from_opts(opts)
- version = server.GetVersion()
- if 'sfa' not in version:
- # need to pass along user keys if this request is going to a ProtoGENI aggregate
- # ProtoGeni Aggregates will only install the keys of the user that is issuing the
- # request. So we will only pass in one user that contains the keys for all
- # users of the slice
- user = {'urn': user_cred.get_gid_caller().get_urn(),
- 'keys': []}
- slice_record = self.registry.Resolve(slice_urn, creds)
- if slice_record and 'researchers' in slice_record:
- user_hrns = slice_record['researchers']
- user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns]
- user_records = self.registry.Resolve(user_urns, creds)
- for user_record in user_records:
- if 'keys' in user_record:
- user['keys'].extend(user_record['keys'])
- users.append(user)
-
+ slice_records = self.registry.Resolve(slice_urn, [user_cred.save_to_string(save_parents=True)])
+ if slice_records and 'researcher' in slice_records[0] and slice_records[0]['researcher']!=[]:
+ slice_record = slice_records[0]
+ user_hrns = slice_record['researcher']
+ user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns]
+ user_records = self.registry.Resolve(user_urns, [user_cred.save_to_string(save_parents=True)])
+
+ if 'sfa' not in server_version:
+ users = pg_users_arg(user_records)
+ rspec = RSpec(rspec)
+ rspec.filter({'component_manager_id': server_version['urn']})
+ rspec = RSpecConverter.to_pg_rspec(rspec.toxml(), content_type='request')
+ creds = [slice_cred]
+ else:
+ users = sfa_users_arg(user_records, slice_record)
+ creds = [slice_cred, delegated_cred]
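The users/creds branch above can be condensed into one sketch (hypothetical; `pg_users` and `sfa_users` stand in for `pg_users_arg` and `sfa_users_arg`): ProtoGENI-style servers get a pg-formatted users list and only the slice credential, while SFA servers get sfa-formatted users plus the delegated credential.

```python
def build_call_creds(server_version, slice_cred, delegated_cred,
                     user_records, slice_record, pg_users, sfa_users):
    if 'sfa' not in server_version:
        # ProtoGENI aggregate: pg-style users, slice credential only
        return pg_users(user_records), [slice_cred]
    # SFA server: sfa-style users plus the delegated credential
    return sfa_users(user_records, slice_record), [slice_cred, delegated_cred]
```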
call_args = [slice_urn, creds, rspec, users]
if self.server_supports_call_id_arg(server):
call_args.append(unique_call_id())
-
- result = server.CreateSliver(*call_args)
- print result
+
+ result = server.CreateSliver(*call_args)
+ if opts.file is None:
+ print result
+ else:
+ save_rspec_to_file (result, opts.file)
return result
# get a ticket for the specified slice
call_args = [slice_urn, creds]
if self.server_supports_call_id_arg(server):
call_args.append(unique_call_id())
- print server.SliverStatus(*call_args)
+ result = server.SliverStatus(*call_args)
+ print result
+ if opts.file:
+ save_variable_to_file(result, opts.file, opts.fileformat)
def shutdown(self, opts, args):
self.logger.debug("resources cmd_opts %s" % cmd_opts.format)
elif command in ("list", "show", "remove"):
self.logger.debug("cmd_opts.type %s" % cmd_opts.type)
- self.logger.debug('cmd_args %s',cmd_args)
+ self.logger.debug('cmd_args %s' % cmd_args)
try:
self.dispatch(command, cmd_opts, cmd_args)
except KeyError:
self.logger.critical ("Unknown command %s"%command)
+ raise
sys.exit(1)
return
import sys
from sfa.util.rspecHelper import RSpec, Commands
from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
command = Commands(usage="%prog [options] [node1 node2...]",
description="Add sliver attributes to the RSpec. " +
if command.opts.infile:
attrs = command.get_attribute_dict()
- rspec = parse_rspec(command.opts.infile)
+ rspec = RSpec(command.opts.infile)
nodes = []
if command.opts.nodefile:
f = open(command.opts.nodefile, "r")
for value in attrs[name]:
if not nodes:
try:
- rspec.add_default_sliver_attribute(name, value)
+ rspec.version.add_default_sliver_attribute(name, value)
except:
print >> sys.stderr, "FAILED: on all nodes: %s=%s" % (name, value)
else:
for node in nodes:
try:
- rspec.add_sliver_attribute(node, name, value)
+ rspec.version.add_sliver_attribute(node, name, value)
except:
print >> sys.stderr, "FAILED: on node %s: %s=%s" % (node, name, value)
import sys
from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
command = Commands(usage="%prog [options] node1 node2...",
description="Add slivers to the RSpec. " +
outfile=file(command.opts.outfile,"w")
else:
outfile=sys.stdout
-
-rspec = parse_rspec(infile)
-rspec.type = 'request'
+ad_rspec = RSpec(infile)
nodes = file(command.opts.nodefile).read().split()
+version_manager = VersionManager()
try:
- if rspec.version['type'].lower() == 'protogeni':
- rspec.xml.set('type', 'request')
+ type = ad_rspec.version.type
+ version_num = ad_rspec.version.version
+ request_version = version_manager._get_version(type, version_num, 'request')
+ request_rspec = RSpec(version=request_version)
slivers = [{'hostname': node} for node in nodes]
- rspec.add_slivers(slivers)
+ request_rspec.version.merge(ad_rspec)
+ request_rspec.version.add_slivers(slivers)
except:
print >> sys.stderr, "FAILED: %s" % nodes
+ raise
sys.exit(1)
-print >>outfile, rspec.toxml(cleanup=True)
+print >>outfile, request_rspec.toxml()
sys.exit(0)
import sys
from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
command = Commands(usage="%prog [options] [node1 node2...]",
description="Delete sliver attributes from the RSpec. " +
if command.opts.infile:
attrs = command.get_attribute_dict()
- rspec = parse_rspec(command.opts.infile)
+ rspec = RSpec(command.opts.infile)
nodes = []
if command.opts.nodefile:
f = open(command.opts.nodefile, "r")
for value in attrs[name]:
if not nodes:
try:
- rspec.remove_default_sliver_attribute(name, value)
+ rspec.version.remove_default_sliver_attribute(name, value)
except:
print >> sys.stderr, "FAILED: on all nodes: %s=%s" % (name, value)
else:
for node in nodes:
try:
- rspec.remove_sliver_attribute(node, name, value)
+ rspec.version.remove_sliver_attribute(node, name, value)
except:
print >> sys.stderr, "FAILED: on node %s: %s=%s" % (node, name, value)
import sys
from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
command = Commands(usage="%prog [options] node1 node2...",
description="Delete slivers from the RSpec. " +
command.prep()
if command.opts.infile:
- rspec = parse_rspec(command.opts.infile)
+ rspec = RSpec(command.opts.infile)
nodes = []
if command.opts.nodefile:
f = open(command.opts.nodefile, "r")
try:
slivers = [{'hostname': node} for node in nodes]
- rspec.remove_slivers(slivers)
+ rspec.version.remove_slivers(slivers)
+ print rspec.toxml()
except:
print >> sys.stderr, "FAILED: %s" % nodes
- print rspec.toxml()
import sys
from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
command = Commands(usage="%prog [options]",
description="List all nodes in the RSpec. " +
command.prep()
if command.opts.infile:
- rspec = parse_rspec(command.opts.infile)
- nodes = rspec.get_nodes()
+ rspec = RSpec(command.opts.infile)
+ nodes = rspec.version.get_nodes()
if command.opts.outfile:
sys.stdout = open(command.opts.outfile, 'w')
import sys
from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
command = Commands(usage="%prog [options]",
description="List all slivers in the RSpec. " +
command.prep()
if command.opts.infile:
- rspec = parse_rspec(command.opts.infile)
- nodes = rspec.get_nodes_with_slivers()
+ rspec = RSpec(command.opts.infile)
+ nodes = rspec.version.get_nodes_with_slivers()
if command.opts.showatt:
- defaults = rspec.get_default_sliver_attributes()
+ defaults = rspec.version.get_default_sliver_attributes()
if defaults:
print "ALL NODES"
for (name, value) in defaults:
for node in nodes:
print node
if command.opts.showatt:
- atts = rspec.get_sliver_attributes(node)
+ atts = rspec.version.get_sliver_attributes(node)
for (name, value) in atts:
print " %s: %s" % (name, value)
reload
+ # install peer certs
+ action $"SFA installing peer certs" daemon /usr/bin/sfa-server.py -t -d $OPTIONS
+
if [ "$SFA_REGISTRY_ENABLED" -eq 1 ]; then
action $"SFA Registry" daemon /usr/bin/sfa-server.py -r -d $OPTIONS
fi
#
# description: Wraps PLCAPI into the SFA compliant API
#
-# $Id: sfa 14304 2009-07-06 20:19:51Z thierry $
-# $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/init.d/sfa $
-#
# Source config
. /etc/sfa/sfa_config
-
-
import datetime
import time
import traceback
from types import StringTypes
from sfa.util.faults import *
-from sfa.util.xrn import get_authority, hrn_to_urn, urn_to_hrn, Xrn
+from sfa.util.xrn import get_authority, hrn_to_urn, urn_to_hrn, Xrn, urn_to_sliver_id
from sfa.util.plxrn import slicename_to_hrn, hrn_to_pl_slicename, hostname_to_urn
-from sfa.util.rspec import *
from sfa.util.specdict import *
from sfa.util.record import SfaRecord
from sfa.util.policy import Policy
from sfa.plc.aggregate import Aggregate
from sfa.plc.slices import *
from sfa.util.version import version_core
-from sfa.rspecs.rspec_version import RSpecVersion
-from sfa.rspecs.sfa_rspec import sfa_rspec_version
-from sfa.rspecs.pg_rspec import pg_rspec_ad_version, pg_rspec_request_version
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.version_manager import VersionManager
+from sfa.rspecs.rspec import RSpec
from sfa.util.sfatime import utcparse
from sfa.util.callids import Callids
def GetVersion(api):
+
+ version_manager = VersionManager()
+ ad_rspec_versions = []
+ request_rspec_versions = []
+ for rspec_version in version_manager.versions:
+ if rspec_version.content_type in ['*', 'ad']:
+ ad_rspec_versions.append(rspec_version.to_dict())
+ if rspec_version.content_type in ['*', 'request']:
+ request_rspec_versions.append(rspec_version.to_dict())
+ default_rspec_version = version_manager.get_version("sfa 1").to_dict()
xrn=Xrn(api.hrn)
- request_rspec_versions = [dict(pg_rspec_request_version), dict(sfa_rspec_version)]
- ad_rspec_versions = [dict(pg_rspec_ad_version), dict(sfa_rspec_version)]
version_more = {'interface':'aggregate',
'testbed':'myplc',
'hrn':xrn.get_hrn(),
'request_rspec_versions': request_rspec_versions,
'ad_rspec_versions': ad_rspec_versions,
- 'default_ad_rspec': dict(sfa_rspec_version)
+ 'default_ad_rspec': default_rspec_version
}
return version_core(version_more)
slice = {}
+ # get_expiration always returns a normalized datetime - no need to utcparse
extime = Credential(string=creds[0]).get_expiration()
# If the expiration time is > 60 days from now, set the expiration time to 60 days from now
if extime > datetime.datetime.utcnow() + datetime.timedelta(days=60):
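The 60-day cap described in the comment above amounts to the following sketch (the function name and explicit `now` parameter are illustrative):

```python
import datetime

MAX_RENEW_DAYS = 60  # assumed constant; mirrors the hard-coded 60 above

def clamp_expiration(extime, now):
    """Return extime, capped at now + 60 days."""
    limit = now + datetime.timedelta(days=MAX_RENEW_DAYS)
    return min(extime, limit)
```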
(hrn, type) = urn_to_hrn(slice_xrn)
# find out where this slice is currently running
- api.logger.info(hrn)
slicename = hrn_to_pl_slicename(hrn)
- slices = api.plshell.GetSlices(api.plauth, [slicename], ['node_ids','person_ids','name','expires'])
+ slices = api.plshell.GetSlices(api.plauth, [slicename], ['slice_id', 'node_ids','person_ids','name','expires'])
if len(slices) == 0:
- raise Exception("Slice %s not found (used %s as slicename internally)" % slice_xrn, slicename)
+ raise Exception("Slice %s not found (used %s as slicename internally)" % (slice_xrn, slicename))
slice = slices[0]
# report about the local nodes only
nodes = api.plshell.GetNodes(api.plauth, {'node_id':slice['node_ids'],'peer_id':None},
- ['hostname', 'site_id', 'boot_state', 'last_contact'])
+ ['node_id', 'hostname', 'site_id', 'boot_state', 'last_contact'])
site_ids = [node['site_id'] for node in nodes]
sites = api.plshell.GetSites(api.plauth, site_ids, ['site_id', 'login_base'])
sites_dict = dict ( [ (site['site_id'],site['login_base'] ) for site in sites ] )
top_level_status = 'unknown'
if nodes:
top_level_status = 'ready'
- result['geni_urn'] = Xrn(slice_xrn, 'slice').get_urn()
+ slice_urn = Xrn(slice_xrn, 'slice').get_urn()
+ result['geni_urn'] = slice_urn
result['pl_login'] = slice['name']
result['pl_expires'] = datetime.datetime.fromtimestamp(slice['expires']).ctime()
res['pl_last_contact'] = node['last_contact']
if node['last_contact'] is not None:
res['pl_last_contact'] = datetime.datetime.fromtimestamp(node['last_contact']).ctime()
- res['geni_urn'] = hostname_to_urn(api.hrn, sites_dict[node['site_id']], node['hostname'])
+ sliver_id = urn_to_sliver_id(slice_urn, slice['slice_id'], node['node_id'])
+ res['geni_urn'] = sliver_id
if node['boot_state'] == 'boot':
res['geni_status'] = 'ready'
else:
result['geni_status'] = top_level_status
result['geni_resources'] = resources
- # XX remove me
- #api.logger.info(result)
- # XX remove me
return result
def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):
"""
if Callids().already_handled(call_id): return ""
- reg_objects = __get_registry_objects(slice_xrn, creds, users)
-
- (hrn, type) = urn_to_hrn(slice_xrn)
- peer = None
aggregate = Aggregate(api)
slices = Slices(api)
+ (hrn, type) = urn_to_hrn(slice_xrn)
peer = slices.get_peer(hrn)
sfa_peer = slices.get_sfa_peer(hrn)
- registry = api.registries[api.hrn]
- credential = api.getCredential()
- (site_id, remote_site_id) = slices.verify_site(registry, credential, hrn,
- peer, sfa_peer, reg_objects)
-
- slice = slices.verify_slice(registry, credential, hrn, site_id,
- remote_site_id, peer, sfa_peer, reg_objects)
-
- nodes = api.plshell.GetNodes(api.plauth, slice['node_ids'], ['hostname'])
- current_slivers = [node['hostname'] for node in nodes]
- rspec = parse_rspec(rspec_string)
- requested_slivers = [str(host) for host in rspec.get_nodes_with_slivers()]
- # remove nodes not in rspec
- deleted_nodes = list(set(current_slivers).difference(requested_slivers))
-
- # add nodes from rspec
- added_nodes = list(set(requested_slivers).difference(current_slivers))
-
- try:
- if peer:
- api.plshell.UnBindObjectFromPeer(api.plauth, 'slice', slice['slice_id'], peer)
-
- api.plshell.AddSliceToNodes(api.plauth, slice['name'], added_nodes)
- api.plshell.DeleteSliceFromNodes(api.plauth, slice['name'], deleted_nodes)
-
- # TODO: update slice tags
- #network.updateSliceTags()
+ slice_record=None
+ if users:
+ slice_record = users[0].get('slice_record', {})
- finally:
- if peer:
- api.plshell.BindObjectToPeer(api.plauth, 'slice', slice.id, peer,
- slice.peer_id)
+ # parse rspec
+ rspec = RSpec(rspec_string)
+ requested_attributes = rspec.version.get_slice_attributes()
+
+ # ensure site record exists
+ site = slices.verify_site(hrn, slice_record, peer, sfa_peer)
+ # ensure slice record exists
+ slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)
+ # ensure person records exist
+ persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)
+ # ensure slice attributes exist
+ slices.verify_slice_attributes(slice, requested_attributes)
+
+ # add/remove slice from nodes
+ requested_slivers = [str(host) for host in rspec.version.get_nodes_with_slivers()]
+ slices.verify_slice_nodes(slice, requested_slivers, peer)
+ # handle MyPLC peer association.
+ # only used by plc and ple.
+ slices.handle_peer(site, slice, persons, peer)
+
return aggregate.get_rspec(slice_xrn=slice_xrn, version=rspec.version)
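The verify_* pipeline in the rewritten CreateSliver above follows a fixed order: site, then slice, then persons, then attributes, then nodes, then peer handling. A condensed sketch with a stub `slices` helper (stand-ins, not the real Slices class):

```python
def create_sliver_pipeline(slices, rspec_attrs, rspec_hosts,
                           hrn, slice_record, users, peer, sfa_peer):
    # each verify_* step ensures the registry-side object exists
    # before nodes are attached to the slice
    site = slices.verify_site(hrn, slice_record, peer, sfa_peer)
    slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)
    persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)
    slices.verify_slice_attributes(slice, rspec_attrs)
    slices.verify_slice_nodes(slice, rspec_hosts, peer)
    slices.handle_peer(site, slice, persons, peer)
    return slice
```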
return slice_urns
-def ListResources(api, creds, options,call_id):
+def ListResources(api, creds, options, call_id):
if Callids().already_handled(call_id): return ""
# get slice's hrn from options
- xrn = options.get('geni_slice_urn', '')
+ xrn = options.get('geni_slice_urn', None)
(hrn, type) = urn_to_hrn(xrn)
+ version_manager = VersionManager()
# get the rspec's return format from options
- rspec_version = RSpecVersion(options.get('rspec_version'))
- version_string = "rspec_%s" % (rspec_version.get_version_name())
+ rspec_version = version_manager.get_version(options.get('rspec_version'))
+ version_string = "rspec_%s" % (rspec_version.to_string())
#panos adding the info option to the caching key (can be improved)
if options.get('info'):
- version_string = version_string + "_"+options.get('info', 'default')
+ version_string = version_string + "_"+options.get('info', 'default')
# look in cache first
if caching and api.cache and not xrn:
api.logger.info("aggregate.ListResources: returning cached value for hrn %s"%hrn)
return rspec
- #aggregate = Aggregate(api)
#panos: passing user-defined options
#print "manager options = ",options
aggregate = Aggregate(api, options)
-
rspec = aggregate.get_rspec(slice_xrn=xrn, version=rspec_version)
# cache the result
from __future__ import with_statement
import sys
-import os
+import os, errno
+import logging
+import datetime
import boto
from boto.ec2.regioninfo import RegionInfo
from sqlobject import *
from sfa.util.faults import *
-from sfa.util.xrn import urn_to_hrn
-from sfa.util.rspec import RSpec
+from sfa.util.xrn import urn_to_hrn, Xrn
from sfa.server.registry import Registries
from sfa.trust.credential import Credential
from sfa.plc.api import SfaAPI
+from sfa.plc.aggregate import Aggregate
+from sfa.plc.slices import *
from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn
from sfa.util.callids import Callids
+from sfa.util.sfalogging import logger
+from sfa.rspecs.sfa_rspec import sfa_rspec_version
+from sfa.util.version import version_core
+
+from multiprocessing import Process
+from time import sleep
##
# The data structure used to represent a cloud.
#
EUCALYPTUS_RSPEC_SCHEMA='/etc/sfa/eucalyptus.rng'
-# Quick hack
-sys.stderr = file('/var/log/euca_agg.log', 'a+')
api = SfaAPI()
+##
+# Meta data of an instance.
+#
+class Meta(SQLObject):
+ instance = SingleJoin('EucaInstance')
+ state = StringCol(default = 'new')
+ pub_addr = StringCol(default = None)
+ pri_addr = StringCol(default = None)
+ start_time = DateTimeCol(default = None)
+
##
# A representation of an Eucalyptus instance. This is a support class
# for instance <-> slice mapping.
ramdisk_id = StringCol()
inst_type = StringCol()
key_pair = StringCol()
- slice = ForeignKey('Slice')
+ slice = ForeignKey('Slice')
+ meta = ForeignKey('Meta')
##
# Contacts Eucalyptus and tries to reserve this instance.
# @param pubKeys A list of public keys for the instance.
#
def reserveInstance(self, botoConn, pubKeys):
- print >>sys.stderr, 'Reserving an instance: image: %s, kernel: ' \
- '%s, ramdisk: %s, type: %s, key: %s' % \
- (self.image_id, self.kernel_id, self.ramdisk_id,
- self.inst_type, self.key_pair)
+ logger = logging.getLogger('EucaAggregate')
+ logger.info('Reserving an instance: image: %s, kernel: ' \
+ '%s, ramdisk: %s, type: %s, key: %s' % \
+ (self.image_id, self.kernel_id, self.ramdisk_id,
+ self.inst_type, self.key_pair))
# XXX The return statement is for testing. REMOVE in production
#return
except EC2ResponseError, ec2RespErr:
errTree = ET.fromstring(ec2RespErr.body)
msg = errTree.find('.//Message')
- print >>sys.stderr, msg.text
+ logger.error(msg.text)
self.destroySelf()
##
# Initialize the aggregate manager by reading a configuration file.
#
def init_server():
+ logger = logging.getLogger('EucaAggregate')
+ fileHandler = logging.FileHandler('/var/log/euca.log')
+ fileHandler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+ logger.addHandler(fileHandler)
+ fileHandler.setLevel(logging.DEBUG)
+ logger.setLevel(logging.DEBUG)
+
configParser = ConfigParser()
configParser.read(['/etc/sfa/eucalyptus_aggregate.conf', 'eucalyptus_aggregate.conf'])
if len(configParser.sections()) < 1:
- print >>sys.stderr, 'No cloud defined in the config file'
+ logger.error('No cloud defined in the config file')
raise Exception('Cannot find cloud definition in configuration file.')
# Only read the first section.
detail = {'imageID' : i.id, 'kernelID' : i.kernel_id, 'ramdiskID' : i.ramdisk_id}
cloud['imageBundles'][name] = detail
- # Initialize sqlite3 database.
+ # Initialize sqlite3 database and tables.
dbPath = '/etc/sfa/db'
dbName = 'euca_aggregate.db'
if not os.path.isdir(dbPath):
- print >>sys.stderr, '%s not found. Creating directory ...' % dbPath
+ logger.info('%s not found. Creating directory ...' % dbPath)
os.mkdir(dbPath)
conn = connectionForURI('sqlite://%s/%s' % (dbPath, dbName))
sqlhub.processConnection = conn
Slice.createTable(ifNotExists=True)
EucaInstance.createTable(ifNotExists=True)
+ Meta.createTable(ifNotExists=True)
+
+ # Start the update process to keep track of the meta data
+ # about Eucalyptus instance.
+ Process(target=updateMeta).start()
# Make sure the schema exists.
if not os.path.exists(EUCALYPTUS_RSPEC_SCHEMA):
err = 'Cannot locate schema at %s' % EUCALYPTUS_RSPEC_SCHEMA
- print >>sys.stderr, err
+ logger.error(err)
raise Exception(err)
##
useSSL = False
srvPath = '/'
eucaPort = 8773
+ logger = logging.getLogger('EucaAggregate')
if not accessKey or not secretKey or not eucaURL:
- print >>sys.stderr, 'Please set ALL of the required environment ' \
- 'variables by sourcing the eucarc file.'
+ logger.error('Please set ALL of the required environment ' \
+ 'variables by sourcing the eucarc file.')
return None
# Split the url into parts
# @param sliceHRN The human readable name of the slice.
# @return string
#
-def getKeysForSlice(sliceHRN):
- try:
- # convert hrn to slice name
- plSliceName = hrn_to_pl_slicename(sliceHRN)
- except IndexError, e:
- print >>sys.stderr, 'Invalid slice name (%s)' % sliceHRN
- return []
-
- # Get the slice's information
- sliceData = api.plshell.GetSlices(api.plauth, {'name':plSliceName})
- if not sliceData:
- print >>sys.stderr, 'Cannot get any data for slice %s' % plSliceName
+# This method is no longer needed because the user keys are passed into
+# CreateSliver
+#
+def getKeysForSlice(api, sliceHRN):
+ logger = logging.getLogger('EucaAggregate')
+ cred = api.getCredential()
+ registry = api.registries[api.hrn]
+ keys = []
+
+ # Get the slice record
+ records = registry.Resolve(sliceHRN, cred)
+ if not records:
+ logger.warn('Cannot find any record for slice %s' % sliceHRN)
return []
- # It should only return a list with len = 1
- sliceData = sliceData[0]
+ # Find who can log into this slice
+ persons = records[0]['persons']
- keys = []
- person_ids = sliceData['person_ids']
- if not person_ids:
- print >>sys.stderr, 'No users in slice %s' % sliceHRN
- return []
+ # Extract the keys from persons records
+ for p in persons:
+ sliceUser = registry.Resolve(p, cred)
+ userKeys = sliceUser[0]['keys']
+ keys += userKeys
- persons = api.plshell.GetPersons(api.plauth, person_ids)
- for person in persons:
- pkeys = api.plshell.GetKeys(api.plauth, person['key_ids'])
- for key in pkeys:
- keys.append(key['key'])
-
- return ''.join(keys)
+ return '\n'.join(keys)
##
# A class that builds the RSpec for Eucalyptus.
# Generates the RSpec.
#
def toXML(self):
+ logger = logging.getLogger('EucaAggregate')
if not self.cloudInfo:
- print >>sys.stderr, 'No cloud information'
+ logger.error('No cloud information')
return ''
xml = self.eucaRSpec
cloud = self.cloudInfo
with xml.RSpec(type='eucalyptus'):
- with xml.cloud(id=cloud['name']):
+ with xml.network(name=cloud['name']):
with xml.ipv4:
xml << cloud['ip']
#self.__keyPairsXML(cloud['keypairs'])
# get slice's hrn from options
xrn = options.get('geni_slice_urn', '')
hrn, type = urn_to_hrn(xrn)
+ logger = logging.getLogger('EucaAggregate')
# get hrn of the original caller
origin_hrn = options.get('origin_hrn', None)
conn = getEucaConnection()
if not conn:
- print >>sys.stderr, 'Error: Cannot create a connection to Eucalyptus'
+ logger.error('Cannot create a connection to Eucalyptus')
return 'Cannot create a connection to Eucalyptus'
try:
except EC2ResponseError, ec2RespErr:
errTree = ET.fromstring(ec2RespErr.body)
errMsgE = errTree.find('.//Message')
- print >>sys.stderr, errMsgE.text
+ logger.error(errMsgE.text)
rspec = EucaRSpecBuilder(cloud).toXML()
"""
Hook called via 'sfi.py create'
"""
-def CreateSliver(api, xrn, creds, xml, users, call_id):
+def CreateSliver(api, slice_xrn, creds, xml, users, call_id):
if Callids().already_handled(call_id): return ""
global cloud
- hrn = urn_to_hrn(xrn)[0]
+ logger = logging.getLogger('EucaAggregate')
+ logger.debug("In CreateSliver")
+
+ aggregate = Aggregate(api)
+ slices = Slices(api)
+ (hrn, type) = urn_to_hrn(slice_xrn)
+ peer = slices.get_peer(hrn)
+ sfa_peer = slices.get_sfa_peer(hrn)
+ slice_record=None
+ if users:
+ slice_record = users[0].get('slice_record', {})
conn = getEucaConnection()
if not conn:
- print >>sys.stderr, 'Error: Cannot create a connection to Eucalyptus'
+ logger.error('Cannot create a connection to Eucalyptus')
return ""
# Validate RSpec
schemaXML = ET.parse(EUCALYPTUS_RSPEC_SCHEMA)
rspecValidator = ET.RelaxNG(schemaXML)
rspecXML = ET.XML(xml)
+ for network in rspecXML.iterfind("./network"):
+ if network.get('name') != cloud['name']:
+ # Throw away everything except my own RSpec
+ # sfa_logger().error("CreateSliver: deleting %s from rspec"%network.get('id'))
+ network.getparent().remove(network)
if not rspecValidator(rspecXML):
error = rspecValidator.error_log.last_error
message = '%s (line %s)' % (error.message, error.line)
- # XXX: InvalidRSpec is new. Currently, I am not working with Trunk code.
- #raise InvalidRSpec(message)
- raise Exception(message)
+ raise InvalidRSpec(message)
+
+ """
+ Create the sliver[s] (slice) at this aggregate.
+ Verify HRN and initialize the slice record in PLC if necessary.
+ """
+
+ # ensure site record exists
+ site = slices.verify_site(hrn, slice_record, peer, sfa_peer)
+ # ensure slice record exists
+ slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)
+ # ensure person records exist
+ persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)
# Get the slice from db or create one.
s = Slice.select(Slice.q.slice_hrn == hrn).getOne(None)
pendingRmInst = []
for sliceInst in s.instances:
pendingRmInst.append(sliceInst.instance_id)
- existingInstGroup = rspecXML.findall('.//euca_instances')
+ existingInstGroup = rspecXML.findall(".//euca_instances")
for instGroup in existingInstGroup:
for existingInst in instGroup:
if existingInst.get('id') in pendingRmInst:
pendingRmInst.remove(existingInst.get('id'))
for inst in pendingRmInst:
- print >>sys.stderr, 'Instance %s will be terminated' % inst
dbInst = EucaInstance.select(EucaInstance.q.instance_id == inst).getOne(None)
- dbInst.destroySelf()
- conn.terminate_instances(pendingRmInst)
+ if dbInst.meta.state != 'deleted':
+ logger.debug('Instance %s will be terminated' % inst)
+ # Terminate instances one at a time for robustness
+ conn.terminate_instances([inst])
+ # Only change the state but do not remove the entry from the DB.
+ dbInst.meta.state = 'deleted'
+ #dbInst.destroySelf()
# Process new instance requests
- requests = rspecXML.findall('.//request')
+ requests = rspecXML.findall(".//request")
if requests:
# Get all the public keys associate with slice.
- pubKeys = getKeysForSlice(s.slice_hrn)
- print >>sys.stderr, "Passing the following keys to the instance:\n%s" % pubKeys
- sys.stderr.flush()
+ keys = []
+ for user in users:
+ keys += user['keys']
+ logger.debug("Keys: %s" % user['keys'])
+ pubKeys = '\n'.join(keys)
+ logger.debug('Passing the following keys to the instance:\n%s' % pubKeys)
for req in requests:
vmTypeElement = req.getparent()
instType = vmTypeElement.get('name')
bundleName = req.find('bundle').text
if not cloud['imageBundles'][bundleName]:
- print >>sys.stderr, 'Cannot find bundle %s' % bundleName
+ logger.error('Cannot find bundle %s' % bundleName)
bundleInfo = cloud['imageBundles'][bundleName]
instKernel = bundleInfo['kernelID']
instDiskImg = bundleInfo['imageID']
# Create the instances
for i in range(0, numInst):
- eucaInst = EucaInstance(slice = s,
- kernel_id = instKernel,
- image_id = instDiskImg,
+ eucaInst = EucaInstance(slice = s,
+ kernel_id = instKernel,
+ image_id = instDiskImg,
ramdisk_id = instRamDisk,
- key_pair = instKey,
- inst_type = instType)
+ key_pair = instKey,
+ inst_type = instType,
+ meta = Meta(start_time=datetime.datetime.now()))
eucaInst.reserveInstance(conn, pubKeys)
# xxx - should return altered rspec
# with enough data for the client to understand what's happened
return xml
+##
+# Return information on the IP addresses bound to each slice's instances
+#
+def dumpInstanceInfo():
+ logger = logging.getLogger('EucaMeta')
+ outdir = "/var/www/html/euca/"
+ outfile = outdir + "instances.txt"
+
+ try:
+ os.makedirs(outdir)
+ except OSError, e:
+ if e.errno != errno.EEXIST:
+ raise
+
+ dbResults = Meta.select(
+ AND(Meta.q.pri_addr != None,
+ Meta.q.state == 'running')
+ )
+ dbResults = list(dbResults)
+ f = open(outfile, "w")
+ for r in dbResults:
+ instId = r.instance.instance_id
+ ipaddr = r.pri_addr
+ hrn = r.instance.slice.slice_hrn
+ logger.debug('[dumpInstanceInfo] %s %s %s' % (instId, ipaddr, hrn))
+ f.write("%s %s %s\n" % (instId, ipaddr, hrn))
+ f.close()
+
+##
+# A separate process that will update the meta data.
+#
+def updateMeta():
+ logger = logging.getLogger('EucaMeta')
+ fileHandler = logging.FileHandler('/var/log/euca_meta.log')
+ fileHandler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+ logger.addHandler(fileHandler)
+ fileHandler.setLevel(logging.DEBUG)
+ logger.setLevel(logging.DEBUG)
+
+ while True:
+ sleep(30)
+
+ # Get IDs of the instances that don't have IPs yet.
+ dbResults = Meta.select(
+ AND(Meta.q.pri_addr == None,
+ Meta.q.state != 'deleted')
+ )
+ dbResults = list(dbResults)
+ logger.debug('[update process] dbResults: %s' % dbResults)
+ instids = []
+ for r in dbResults:
+ if not r.instance:
+ continue
+ instids.append(r.instance.instance_id)
+ logger.debug('[update process] Instance Id: %s' % ', '.join(instids))
+
+ # Get instance information from Eucalyptus
+ conn = getEucaConnection()
+ vmInstances = []
+ reservations = conn.get_all_instances(instids)
+ for reservation in reservations:
+ vmInstances += reservation.instances
+
+ # Check the IPs
+ instIPs = [ {'id':i.id, 'pri_addr':i.private_dns_name, 'pub_addr':i.public_dns_name}
+ for i in vmInstances if i.private_dns_name != '0.0.0.0' ]
+ logger.debug('[update process] IP dict: %s' % str(instIPs))
+
+ # Update the local DB
+ for ipData in instIPs:
+ dbInst = EucaInstance.select(EucaInstance.q.instance_id == ipData['id']).getOne(None)
+ if not dbInst:
+ logger.info('[update process] Could not find %s in DB' % ipData['id'])
+ continue
+ dbInst.meta.pri_addr = ipData['pri_addr']
+ dbInst.meta.pub_addr = ipData['pub_addr']
+ dbInst.meta.state = 'running'
+
+ dumpInstanceInfo()
+
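One iteration of the updateMeta polling loop above can be factored into a pure function over plain dicts (stand-ins for the Meta/EucaInstance rows): instances that report a real private address get their meta row marked `'running'`.

```python
def apply_instance_ips(metas_by_id, instances):
    """metas_by_id: {instance_id: meta dict}; instances: Eucalyptus info dicts."""
    updated = []
    for inst in instances:
        if inst['private_dns_name'] == '0.0.0.0':
            continue  # no address assigned yet; keep polling
        meta = metas_by_id.get(inst['id'])
        if meta is None:
            continue  # instance not tracked locally
        meta['pri_addr'] = inst['private_dns_name']
        meta['pub_addr'] = inst['public_dns_name']
        meta['state'] = 'running'
        updated.append(inst['id'])
    return updated
```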
+def GetVersion(api):
+ xrn=Xrn(api.hrn)
+ request_rspec_versions = [dict(sfa_rspec_version)]
+ ad_rspec_versions = [dict(sfa_rspec_version)]
+ version_more = {'interface':'aggregate',
+ 'testbed':'myplc',
+ 'hrn':xrn.get_hrn(),
+ 'request_rspec_versions': request_rspec_versions,
+ 'ad_rspec_versions': ad_rspec_versions,
+ 'default_ad_rspec': dict(sfa_rspec_version)
+ }
+ return version_core(version_more)
+
def main():
init_server()
#rspec = ListResources('euca', 'planetcloud.pc.test', 'planetcloud.pc.marcoy', 'test_euca')
#print rspec
- print getKeysForSlice('gc.gc.test1')
+
+ server_key_file = '/var/lib/sfa/authorities/server.key'
+ server_cert_file = '/var/lib/sfa/authorities/server.cert'
+ api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file, interface='aggregate')
+ print getKeysForSlice(api, 'gc.gc.test1')
if __name__ == "__main__":
main()
-from sfa.util.xrn import urn_to_hrn, hrn_to_urn, get_authority
-from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import RSpec
-from sfa.util.sfalogging import sfa_logger
-from sfa.util.config import Config
-from sfa.managers.aggregate_manager_pl import GetVersion, __get_registry_objects
-from sfa.plc.slices import Slices
-import os
-import time
-
-RSPEC_TMP_FILE_PREFIX = "/tmp/max_rspec"
-
-# execute shell command and return both exit code and text output
-def shell_execute(cmd, timeout):
- pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')
- pipe = os.popen(cmd + ' 2>&1', 'r')
- text = ''
- while timeout:
- line = pipe.read()
- text += line
- time.sleep(1)
- timeout = timeout-1
- code = pipe.close()
- if code is None: code = 0
- if text[-1:] == '\n': text = text[:-1]
- return code, text
-
-"""
- call AM API client with command like in the following example:
- cd aggregate_client; java -classpath AggregateWS-client-api.jar:lib/* \
- net.geni.aggregate.client.examples.CreateSliceNetworkClient \
- ./repo https://geni:8443/axis2/services/AggregateGENI \
- ... params ...
-"""
-
-def call_am_apiclient(client_app, params, timeout):
- (client_path, am_url) = Config().get_max_aggrMgr_info()
- sys_cmd = "cd " + client_path + "; java -classpath AggregateWS-client-api.jar:lib/* net.geni.aggregate.client.examples." + client_app + " ./repo " + am_url + " " + ' '.join(params)
- ret = shell_execute(sys_cmd, timeout)
- sfa_logger().debug("shell_execute cmd: %s returns %s" % (sys_cmd, ret))
-# save request RSpec xml content to a tmp file
-def save_rspec_to_file(rspec):
- path = RSPEC_TMP_FILE_PREFIX + "_" + time.strftime('%Y%m%dT%H:%M:%S', time.gmtime(time.time())) +".xml"
- file = open(path, "w")
- file.write(rspec)
- file.close()
- return path
-
-# get stripped down slice id/name plc:maxpl:xi_slice1 --> xi_slice1
-def get_short_slice_id(cred, hrn):
- if hrn == None:
- return None
- slice_id = hrn[hrn.rfind('+')+1:]
- if slice_id == None:
- slice_id = hrn[hrn.rfind(':')+1:]
- if slice_id == None:
- return hrn
- pass
- return str(slice_id)
-
-# extract xml
-def get_xml_by_tag(text, tag):
- indx1 = text.find('<'+tag)
- indx2 = text.find('/'+tag+'>')
- xml = None
- if indx1!=-1 and indx2>indx1:
- xml = text[indx1:indx2+len(tag)+2]
- return xml
-
-def prepare_slice(api, xrn, users):
- reg_objects = __get_registry_objects(slice_xrn, creds, users)
- (hrn, type) = urn_to_hrn(slice_xrn)
- slices = Slices(api)
- peer = slices.get_peer(hrn)
- sfa_peer = slices.get_sfa_peer(hrn)
- registry = api.registries[api.hrn]
- credential = api.getCredential()
- (site_id, remote_site_id) = slices.verify_site(registry, credential, hrn, peer, sfa_peer, reg_objects)
- slices.verify_slice(registry, credential, hrn, site_id, remote_site_id, peer, sfa_peer, reg_objects)
-
-def create_slice(api, xrn, cred, rspec, users):
- indx1 = rspec.find("<RSpec")
- indx2 = rspec.find("</RSpec>")
- if indx1 > -1 and indx2 > indx1:
- rspec = rspec[indx1+len("<RSpec type=\"SFA\">"):indx2-1]
- rspec_path = save_rspec_to_file(rspec)
- prepare_slice(api, xrn, users)
- (ret, output) = call_am_apiclient("CreateSliceNetworkClient", [rspec_path,], 3)
- # parse output ?
- rspec = "<RSpec type=\"SFA\"> Done! </RSpec>"
-def delete_slice(api, xrn, cred):
- slice_id = get_short_slice_id(cred, xrn)
- (ret, output) = call_am_apiclient("DeleteSliceNetworkClient", [slice_id,], 3)
- # parse output ?
-def get_rspec(api, cred, options):
- #geni_slice_urn: urn:publicid:IDN+plc:maxpl+slice+xi_rspec_test1
- urn = options.get('geni_slice_urn')
- slice_id = get_short_slice_id(cred, urn)
- if slice_id == None:
- (ret, output) = call_am_apiclient("GetResourceTopology", ['all', '\"\"'], 5)
- (ret, output) = call_am_apiclient("GetResourceTopology", ['all', slice_id,], 5)
- # parse output into rspec XML
- if output.find("No resouce found") > 0:
- rspec = "<RSpec type=\"SFA\"> <Fault>No resource found</Fault> </RSpec>"
- else:
- comp_rspec = get_xml_by_tag(output, 'computeResource')
- sfa_logger().debug("#### computeResource %s" % comp_rspec)
- topo_rspec = get_xml_by_tag(output, 'topology')
- sfa_logger().debug("#### topology %s" % topo_rspec)
- rspec = "<RSpec type=\"SFA\"> <network name=\"" + Config().get_interface_hrn() + "\">";
- if comp_rspec != None:
- rspec = rspec + get_xml_by_tag(output, 'computeResource')
- if topo_rspec != None:
- rspec = rspec + get_xml_by_tag(output, 'topology')
- rspec = rspec + "</network> </RSpec>"
-
- return (rspec)
-
-def start_slice(api, xrn, cred):
- # service not supported
- return None
-
-def stop_slice(api, xrn, cred):
- # service not supported
- return None
-
-def reset_slices(api, xrn):
- # service not supported
- return None
-
-"""
-Returns the request context required by sfatables. At some point, this mechanism should be changed
-to refer to "contexts", which is the information that sfatables is requesting. But for now, we just
-return the basic information needed in a dict.
-"""
-def fetch_context(slice_hrn, user_hrn, contexts):
- base_context = {'sfa':{'user':{'hrn':user_hrn}}}
- return base_context
- api = SfaAPI()
- create_slice(api, "plc.maxpl.test000", None, rspec_xml, None)
-
+from sfa.plc.slices import Slices\r
+from sfa.server.registry import Registries\r
+from sfa.util.xrn import urn_to_hrn, hrn_to_urn, get_authority, Xrn\r
+from sfa.util.plxrn import hrn_to_pl_slicename\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.faults import *\r
+from sfa.util.config import Config\r
+from sfa.util.sfatime import utcparse\r
+from sfa.util.callids import Callids\r
+from sfa.util.version import version_core\r
+from sfa.rspecs.rspec_version import RSpecVersion\r
+from sfa.rspecs.sfa_rspec import sfa_rspec_version\r
+from sfa.rspecs.rspec_parser import parse_rspec\r
+from sfa.managers.aggregate_manager_pl import __get_registry_objects, ListSlices\r
+import os\r
+import time\r
+import re\r
+\r
+RSPEC_TMP_FILE_PREFIX = "/tmp/max_rspec"\r
+\r
+# execute shell command and return both exit code and text output\r
+def shell_execute(cmd, timeout):\r
+ pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')\r
+ text = ''\r
+ while timeout:\r
+ line = pipe.read()\r
+ text += line\r
+ time.sleep(1)\r
+ timeout = timeout-1\r
+ code = pipe.close()\r
+ if code is None: code = 0\r
+ if text[-1:] == '\n': text = text[:-1]\r
+ return code, text\r
+\r
+"""\r
+ call AM API client with command like in the following example:\r
+ cd aggregate_client; java -classpath AggregateWS-client-api.jar:lib/* \\r
+ net.geni.aggregate.client.examples.CreateSliceNetworkClient \\r
+ ./repo https://geni:8443/axis2/services/AggregateGENI \\r
+ ... params ...\r
+"""\r
+\r
+def call_am_apiclient(client_app, params, timeout):\r
+ (client_path, am_url) = Config().get_max_aggrMgr_info()\r
+ sys_cmd = "cd " + client_path + "; java -classpath AggregateWS-client-api.jar:lib/* net.geni.aggregate.client.examples." + client_app + " ./repo " + am_url + " " + ' '.join(params)\r
+ ret = shell_execute(sys_cmd, timeout)\r
+ logger.debug("shell_execute cmd: %s returns %s" % (sys_cmd, ret))\r
+ return ret\r
+\r
+# save request RSpec xml content to a tmp file\r
+def save_rspec_to_file(rspec):\r
+ path = RSPEC_TMP_FILE_PREFIX + "_" + time.strftime('%Y%m%dT%H:%M:%S', time.gmtime(time.time())) +".xml"\r
+ file = open(path, "w")\r
+ file.write(rspec)\r
+ file.close()\r
+ return path\r
+\r
+# get stripped down slice id/name plc.maxpl.xislice1 --> maxpl_xislice1\r
+def get_plc_slice_id(cred, xrn):\r
+ (hrn, type) = urn_to_hrn(xrn)\r
+ if hrn.find(':') != -1:\r
+ sep=':'\r
+ elif hrn.find('+') != -1:\r
+ sep='+'\r
+ else:\r
+ sep='.'\r
+ slice_id = hrn.split(sep)[-2] + '_' + hrn.split(sep)[-1]\r
+ return slice_id\r
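The separator handling above is easy to misread, so here is a minimal standalone sketch of the same mangling (the helper name and sample hrns are illustrative, not part of the patch):

```python
# Mirrors get_plc_slice_id above: pick the hrn separator (':', '+', or
# '.') and join the last two components with '_'.
def plc_slice_id(hrn):
    if ':' in hrn:
        sep = ':'
    elif '+' in hrn:
        sep = '+'
    else:
        sep = '.'
    parts = hrn.split(sep)
    return parts[-2] + '_' + parts[-1]

# plc.maxpl.xislice1 --> maxpl_xislice1, as the comment above describes
print(plc_slice_id('plc.maxpl.xislice1'))
```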
+\r
+# extract xml \r
+def get_xml_by_tag(text, tag):\r
+ indx1 = text.find('<'+tag)\r
+ indx2 = text.find('/'+tag+'>')\r
+ xml = None\r
+ if indx1!=-1 and indx2>indx1:\r
+ xml = text[indx1:indx2+len(tag)+2]\r
+ return xml\r
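For reference, the tag extraction above slices from `<tag` through the closing `/tag>`; a quick self-contained illustration (the sample input is made up):

```python
def get_xml_by_tag(text, tag):
    # Same string-slicing logic as the function above
    indx1 = text.find('<' + tag)
    indx2 = text.find('/' + tag + '>')
    xml = None
    if indx1 != -1 and indx2 > indx1:
        xml = text[indx1:indx2 + len(tag) + 2]
    return xml

sample = '<out><topology>abc</topology></out>'
```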
+\r
+def prepare_slice(api, slice_xrn, creds, users):\r
+ reg_objects = __get_registry_objects(slice_xrn, creds, users)\r
+ (hrn, type) = urn_to_hrn(slice_xrn)\r
+ slices = Slices(api)\r
+ peer = slices.get_peer(hrn)\r
+ sfa_peer = slices.get_sfa_peer(hrn)\r
+ slice_record=None\r
+ if users:\r
+ slice_record = users[0].get('slice_record', {})\r
+ registry = api.registries[api.hrn]\r
+ credential = api.getCredential()\r
+ # ensure site record exists\r
+ site = slices.verify_site(hrn, slice_record, peer, sfa_peer)\r
+ # ensure slice record exists\r
+ slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)\r
+ # ensure person records exists\r
+ persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)\r
+\r
+def parse_resources(text, slice_xrn):\r
+ resources = []\r
+ urn = hrn_to_urn(slice_xrn, 'sliver')\r
+ plc_slice = re.search("Slice Status => ([^\n]+)", text)\r
+ if plc_slice.group(1) != 'NONE':\r
+ res = {}\r
+ res['geni_urn'] = urn + '_plc_slice'\r
+ res['geni_error'] = ''\r
+ res['geni_status'] = 'unknown'\r
+ if plc_slice.group(1) == 'CREATED':\r
+ res['geni_status'] = 'ready'\r
+ resources.append(res)\r
+ vlans = re.findall("GRI => ([^\n]+)\n\t Status => ([^\n]+)", text)\r
+ for vlan in vlans:\r
+ res = {}\r
+ res['geni_error'] = ''\r
+ res['geni_urn'] = urn + '_vlan_' + vlan[0]\r
+ if vlan[1] == 'ACTIVE':\r
+ res['geni_status'] = 'ready'\r
+ elif vlan[1] == 'FAILED':\r
+ res['geni_status'] = 'failed'\r
+ else:\r
+ res['geni_status'] = 'configuring'\r
+ resources.append(res)\r
+ return resources\r
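The GRI/Status pairs are pulled out with a plain regex; a sketch of what parse_resources matches (the sample text is invented, mimicking the AM client's output format):

```python
import re

# Same pattern parse_resources uses on the QuerySliceNetworkClient output;
# each (GRI, status) tuple becomes one geni resource entry whose
# geni_status maps to ready/failed/configuring.
sample = "GRI => vlan-101\n\t Status => ACTIVE\nGRI => vlan-102\n\t Status => FAILED"
vlans = re.findall("GRI => ([^\n]+)\n\t Status => ([^\n]+)", sample)
```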
+\r
+def slice_status(api, slice_xrn, creds):\r
+ urn = hrn_to_urn(slice_xrn, 'slice')\r
+ result = {}\r
+ top_level_status = 'unknown'\r
+ slice_id = get_plc_slice_id(creds, urn)\r
+ (ret, output) = call_am_apiclient("QuerySliceNetworkClient", [slice_id,], 5)\r
+ # parse output into rspec XML\r
+ if output.find("Unkown Rspec:") > 0:\r
+ top_level_status = 'failed'\r
+ result['geni_resources'] = ''\r
+ else:\r
+ has_failure = 0\r
+ all_active = 0\r
+ if output.find("Status => FAILED") > 0:\r
+ top_level_status = 'failed'\r
+ elif ( output.find("Status => ACCEPTED") > 0 or output.find("Status => PENDING") > 0\r
+ or output.find("Status => INSETUP") > 0 or output.find("Status => INCREATE") > 0\r
+ ):\r
+ top_level_status = 'configuring'\r
+ else:\r
+ top_level_status = 'ready'\r
+ result['geni_resources'] = parse_resources(output, slice_xrn)\r
+ result['geni_urn'] = urn\r
+ result['geni_status'] = top_level_status\r
+ return result\r
+\r
+def create_slice(api, xrn, cred, rspec, users):\r
+ indx1 = rspec.find("<RSpec")\r
+ indx2 = rspec.find("</RSpec>")\r
+ if indx1 > -1 and indx2 > indx1:\r
+ rspec = rspec[indx1+len("<RSpec type=\"SFA\">"):indx2-1]\r
+ rspec_path = save_rspec_to_file(rspec)\r
+ prepare_slice(api, xrn, cred, users)\r
+ slice_id = get_plc_slice_id(cred, xrn)\r
+ sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" +slice_id+ "/g\" " + rspec_path + ";sed -i \"s/:rspec=[^:'<\\\" ]*/:rspec=" +slice_id+ "/g\" " + rspec_path\r
+ ret = shell_execute(sys_cmd, 1)\r
+ sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" + rspec_path + "/g\""\r
+ ret = shell_execute(sys_cmd, 1)\r
+ (ret, output) = call_am_apiclient("CreateSliceNetworkClient", [rspec_path,], 3)\r
+ # parse output ?\r
+ rspec = "<RSpec type=\"SFA\"> Done! </RSpec>"\r
+ return True\r
+\r
+def delete_slice(api, xrn, cred):\r
+ slice_id = get_plc_slice_id(cred, xrn)\r
+ (ret, output) = call_am_apiclient("DeleteSliceNetworkClient", [slice_id,], 3)\r
+ # parse output ?\r
+ return 1\r
+\r
+\r
+def get_rspec(api, cred, slice_urn):\r
+ logger.debug("#### called max-get_rspec")\r
+ #geni_slice_urn: urn:publicid:IDN+plc:maxpl+slice+xi_rspec_test1\r
+ if slice_urn == None:\r
+ (ret, output) = call_am_apiclient("GetResourceTopology", ['all', '\"\"'], 5)\r
+ else:\r
+ slice_id = get_plc_slice_id(cred, slice_urn)\r
+ (ret, output) = call_am_apiclient("GetResourceTopology", ['all', slice_id,], 5)\r
+ # parse output into rspec XML\r
+ if output.find("No resouce found") > 0:\r
+ rspec = "<RSpec type=\"SFA\"> <Fault>No resource found</Fault> </RSpec>"\r
+ else:\r
+ comp_rspec = get_xml_by_tag(output, 'computeResource')\r
+ logger.debug("#### computeResource %s" % comp_rspec)\r
+ topo_rspec = get_xml_by_tag(output, 'topology')\r
+ logger.debug("#### topology %s" % topo_rspec)\r
+ rspec = "<RSpec type=\"SFA\"> <network name=\"" + Config().get_interface_hrn() + "\">";\r
+ if comp_rspec != None:\r
+ rspec = rspec + get_xml_by_tag(output, 'computeResource')\r
+ if topo_rspec != None:\r
+ rspec = rspec + get_xml_by_tag(output, 'topology')\r
+ rspec = rspec + "</network> </RSpec>"\r
+ return (rspec)\r
+\r
+def start_slice(api, xrn, cred):\r
+ # service not supported\r
+ return None\r
+\r
+def stop_slice(api, xrn, cred):\r
+ # service not supported\r
+ return None\r
+\r
+def reset_slices(api, xrn):\r
+ # service not supported\r
+ return None\r
+\r
+"""\r
+ GENI AM API Methods\r
+"""\r
+\r
+def GetVersion(api):\r
+ xrn=Xrn(api.hrn)\r
+ request_rspec_versions = [dict(sfa_rspec_version)]\r
+ ad_rspec_versions = [dict(sfa_rspec_version)]\r
+ #TODO: MAX-AM specific\r
+ version_more = {'interface':'aggregate',\r
+ 'testbed':'myplc',\r
+ 'hrn':xrn.get_hrn(),\r
+ 'request_rspec_versions': request_rspec_versions,\r
+ 'ad_rspec_versions': ad_rspec_versions,\r
+ 'default_ad_rspec': dict(sfa_rspec_version)\r
+ }\r
+ return version_core(version_more)\r
+\r
+def SliverStatus(api, slice_xrn, creds, call_id):\r
+ if Callids().already_handled(call_id): return {}\r
+ return slice_status(api, slice_xrn, creds)\r
+\r
+def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):\r
+ if Callids().already_handled(call_id): return ""\r
+ #TODO: create real CreateSliver response rspec\r
+ ret = create_slice(api, slice_xrn, creds, rspec_string, users)\r
+ if ret:\r
+ return get_rspec(api, creds, slice_xrn)\r
+ else:\r
+ return "<?xml version=\"1.0\" ?> <RSpec type=\"SFA\"> Error! </RSpec>"\r
+\r
+def DeleteSliver(api, xrn, creds, call_id):\r
+ if Callids().already_handled(call_id): return ""\r
+ return delete_slice(api, xrn, creds)\r
+\r
+# no caching\r
+def ListResources(api, creds, options,call_id):\r
+ if Callids().already_handled(call_id): return ""\r
+ # version_string = "rspec_%s" % (rspec_version.get_version_name())\r
+ slice_urn = options.get('geni_slice_urn')\r
+ return get_rspec(api, creds, slice_urn)\r
+\r
+"""\r
+Returns the request context required by sfatables. At some point, this mechanism should be changed\r
+to refer to "contexts", which is the information that sfatables is requesting. But for now, we just\r
+return the basic information needed in a dict.\r
+"""\r
+def fetch_context(slice_hrn, user_hrn, contexts):\r
+ base_context = {'sfa':{'user':{'hrn':user_hrn}}}\r
+ return base_context\r
+\r
from sfa.util.faults import *
from sfa.util.xrn import urn_to_hrn
-from sfa.util.rspec import RSpec
from sfa.server.registry import Registries
from sfa.util.config import Config
from sfa.plc.nodes import *
from types import StringTypes
from sfa.util.xrn import urn_to_hrn, Xrn
from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import *
from sfa.util.specdict import *
from sfa.util.faults import *
from sfa.util.record import SfaRecord
from sfa.plc.slices import Slices
import sfa.plc.peers as peers
from sfa.managers.vini.vini_network import *
+from sfa.plc.vini_aggregate import ViniAggregate
+from sfa.rspecs.version_manager import VersionManager
from sfa.plc.api import SfaAPI
from sfa.plc.slices import *
from sfa.managers.aggregate_manager_pl import __get_registry_objects, __get_hostnames
# get slice's hrn from options
xrn = options.get('geni_slice_urn', '')
hrn, type = urn_to_hrn(xrn)
+
+ version_manager = VersionManager()
+ # get the rspec's return format from options
+ rspec_version = version_manager.get_version(options.get('rspec_version'))
+ version_string = "rspec_%s" % (rspec_version.to_string())
# look in cache first
if api.cache and not xrn:
- rspec = api.cache.get('nodes')
+ rspec = api.cache.get(version_string)
if rspec:
+ api.logger.info("aggregate.ListResources: returning cached value for hrn %s"%hrn)
return rspec
- network = ViniNetwork(api)
- if (hrn):
- if network.get_slice(api, hrn):
- network.addSlice()
-
- rspec = network.toxml()
-
+ aggregate = ViniAggregate(api, options)
+ rspec = aggregate.get_rspec(slice_xrn=xrn, version=rspec_version)
+
# cache the result
if api.cache and not xrn:
-        api.cache.add('nodes', rspec)
+        api.cache.add(version_string, rspec)
start = RSpec
RSpec = element RSpec {
+ attribute expires { xsd:NMTOKEN },
+ attribute generated { xsd:NMTOKEN },
attribute type { xsd:NMTOKEN },
- cloud
+ network
}
-cloud = element cloud {
- attribute id { xsd:NMTOKEN },
+network = element network {
+ attribute name { xsd:NMTOKEN },
user_info?,
ipv4,
bundles,
</start>
<define name="RSpec">
<element name="RSpec">
+ <attribute name="expires">
+ <data type="NMTOKEN"/>
+ </attribute>
+ <attribute name="generated">
+ <data type="NMTOKEN"/>
+ </attribute>
<attribute name="type">
<data type="NMTOKEN"/>
</attribute>
- <ref name="cloud"/>
+ <oneOrMore>
+ <ref name="network"/>
+ </oneOrMore>
</element>
</define>
- <define name="cloud">
- <element name="cloud">
- <attribute name="id">
+ <define name="network">
+ <element name="network">
+ <attribute name="name">
<data type="NMTOKEN"/>
</attribute>
<optional>
-<RSpec type="eucalyptus">
- <cloud id="OpenCirrus">
- <ipv4>198.55.32.86</ipv4>
+<?xml version="1.0"?>
+<RSpec expires="2011-09-26T21:03:16Z" generated="2011-09-26T20:03:16Z" type="SFA">
+ <statistics call="ListResources">
+ <aggregate status="success" name="genicloud.hplabs" elapsed="0.697860002518"/>
+ <aggregate status="success" name="genicloud.ucsd" elapsed="0.901086091995"/>
+ </statistics>
+ <network name="HpLabs-Cloud">
+ <ipv4>198.55.32.75</ipv4>
<bundles>
- <bundle id="fc12" />
- <bundle id="f12-plab" />
- <bundle id="f12-planetlab" />
- <bundle id="fc11" />
+ <bundle id="ubuntu904"/>
</bundles>
- <cluster id="euca-oc">
- <ipv4>198.55.32.86</ipv4>
+ <cluster id="hplabs">
+ <ipv4>198.55.32.75</ipv4>
<vm_types>
<vm_type name="m1.small">
- <free_slots>41</free_slots>
- <max_instances>44</max_instances>
+ <free_slots>39</free_slots>
+ <max_instances>40</max_instances>
<cores>1</cores>
- <memory unit="MB">192</memory>
+ <memory unit="MB">128</memory>
<disk_space unit="GB">2</disk_space>
</vm_type>
<vm_type name="c1.medium">
- <free_slots>41</free_slots>
- <max_instances>44</max_instances>
+ <free_slots>39</free_slots>
+ <max_instances>40</max_instances>
<cores>1</cores>
<memory unit="MB">256</memory>
<disk_space unit="GB">5</disk_space>
- <request>
+ <request>
<instances>1</instances>
- <bundle>f12-plab</bundle>
+ <bundle>ubuntu904</bundle>
</request>
- </vm_type>
+ </vm_type>
<vm_type name="m1.large">
<free_slots>19</free_slots>
- <max_instances>22</max_instances>
+ <max_instances>20</max_instances>
<cores>2</cores>
<memory unit="MB">512</memory>
<disk_space unit="GB">10</disk_space>
</vm_type>
<vm_type name="m1.xlarge">
<free_slots>19</free_slots>
- <max_instances>22</max_instances>
+ <max_instances>20</max_instances>
+ <cores>2</cores>
+ <memory unit="MB">1024</memory>
+ <disk_space unit="GB">20</disk_space>
+ </vm_type>
+ <vm_type name="c1.xlarge">
+ <free_slots>9</free_slots>
+ <max_instances>10</max_instances>
+ <cores>4</cores>
+ <memory unit="MB">2048</memory>
+ <disk_space unit="GB">20</disk_space>
+ </vm_type>
+ </vm_types>
+ </cluster>
+ </network>
+ <network name="UCSD-Cloud">
+ <ipv4>169.228.66.144</ipv4>
+ <bundles>
+ <bundle id="ubuntu904"/>
+ </bundles>
+ <cluster id="ucsd">
+ <ipv4>169.228.66.144</ipv4>
+ <vm_types>
+ <vm_type name="m1.small">
+ <free_slots>15</free_slots>
+ <max_instances>16</max_instances>
+ <cores>1</cores>
+ <memory unit="MB">128</memory>
+ <disk_space unit="GB">2</disk_space>
+ </vm_type>
+ <vm_type name="c1.medium">
+ <free_slots>15</free_slots>
+ <max_instances>16</max_instances>
+ <cores>1</cores>
+ <memory unit="MB">256</memory>
+ <disk_space unit="GB">5</disk_space>
+ </vm_type>
+ <vm_type name="m1.large">
+ <free_slots>7</free_slots>
+ <max_instances>8</max_instances>
+ <cores>2</cores>
+ <memory unit="MB">512</memory>
+ <disk_space unit="GB">10</disk_space>
+ </vm_type>
+ <vm_type name="m1.xlarge">
+ <free_slots>7</free_slots>
+ <max_instances>8</max_instances>
<cores>2</cores>
<memory unit="MB">1024</memory>
<disk_space unit="GB">20</disk_space>
</vm_type>
<vm_type name="c1.xlarge">
- <free_slots>8</free_slots>
- <max_instances>11</max_instances>
+ <free_slots>3</free_slots>
+ <max_instances>4</max_instances>
<cores>4</cores>
<memory unit="MB">2048</memory>
<disk_space unit="GB">20</disk_space>
</vm_type>
</vm_types>
</cluster>
- </cloud>
+ </network>
</RSpec>
--- /dev/null
+from sfa.util.sfalogging import logger
+
+def import_manager(kind, type):
+ """
+ kind expected in ['registry', 'aggregate', 'slice', 'component']
+ type is e.g. 'pl' or 'max' or whatever
+ """
+ basepath = 'sfa.managers'
+ qualified = "%s.%s_manager_%s"%(basepath,kind,type)
+ generic = "%s.%s_manager"%(basepath,kind)
+
+ message="import_manager for kind=%s and type=%s"%(kind,type)
+ try:
+ manager = __import__(qualified, fromlist=[basepath])
+ logger.info ("%s: loaded %s"%(message,qualified))
+ except:
+ try:
+ manager = __import__ (generic, fromlist=[basepath])
+ if type != 'pl' :
+ logger.warn ("%s: using generic with type!='pl'"%(message))
+ logger.info("%s: loaded %s"%(message,generic))
+ except:
+ manager=None
+ logger.log_exc("%s: unable to import either %s or %s"%(message,qualified,generic))
+ return manager
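The fallback scheme above just builds two dotted module names and tries them in order; a sketch of the name construction only (no import attempted, values illustrative):

```python
# Name construction used by import_manager: a manager-specific module is
# preferred, with a generic per-kind module as fallback.
basepath = 'sfa.managers'
kind, mgr_type = 'aggregate', 'max'  # kind in registry/aggregate/slice/component
qualified = "%s.%s_manager_%s" % (basepath, kind, mgr_type)
generic = "%s.%s_manager" % (basepath, kind)
```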
+
return records
+def create_gid(api, xrn, cert):
+ # get the authority
+ authority = Xrn(xrn=xrn).get_authority_hrn()
+ auth_info = api.auth.get_auth_info(authority)
+ if not cert:
+ pkey = Keypair(create=True)
+ else:
+ certificate = Certificate(string=cert)
+ pkey = certificate.get_pubkey()
+ gid = api.auth.hierarchy.create_gid(xrn, create_uuid(), pkey)
+ return gid.save_to_string(save_parents=True)
+
def register(api, record):
hrn, type = record['hrn'], record['type']
record['authority'] = get_authority(record['hrn'])
type = record['type']
hrn = record['hrn']
- api.auth.verify_object_permission(hrn)
auth_info = api.auth.get_auth_info(record['authority'])
pub_key = None
# make sure record has a gid
type = new_record['type']
hrn = new_record['hrn']
urn = hrn_to_urn(hrn,type)
- api.auth.verify_object_permission(hrn)
table = SfaTable()
# make sure the record exists
records = table.findObjects({'type': type, 'hrn': hrn})
-#
+#
import sys
import time,datetime
from StringIO import StringIO
from copy import copy
from lxml import etree
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
from sfa.util.rspecHelper import merge_rspecs
from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import *
from sfa.util.specdict import *
from sfa.util.faults import *
from sfa.util.record import SfaRecord
-from sfa.rspecs.pg_rspec import PGRSpec
-from sfa.rspecs.sfa_rspec import SfaRSpec
from sfa.rspecs.rspec_converter import RSpecConverter
-from sfa.rspecs.rspec_parser import parse_rspec
-from sfa.rspecs.rspec_version import RSpecVersion
-from sfa.rspecs.sfa_rspec import sfa_rspec_version
-from sfa.rspecs.pg_rspec import pg_rspec_ad_version, pg_rspec_request_version
+from sfa.client.client_helper import sfa_to_pg_users_arg
+from sfa.rspecs.version_manager import VersionManager
+from sfa.rspecs.rspec import RSpec
from sfa.util.policy import Policy
from sfa.util.prefixTree import prefixTree
from sfa.util.sfaticket import *
from sfa.util.version import version_core
from sfa.util.callids import Callids
+
+def _call_id_supported(api, server):
+ """
+ Returns true if server support the optional call_id arg, false otherwise.
+ """
+ server_version = api.get_cached_server_version(server)
+
+ if 'sfa' in server_version:
+ code_tag = server_version['code_tag']
+ code_tag_parts = code_tag.split("-")
+
+ version_parts = code_tag_parts[0].split(".")
+ major, minor = version_parts[0:2]
+ rev = code_tag_parts[1]
+ if int(major) > 1:
+ return True
+ if int(major) == 1 and (int(minor) > 0 or int(rev) > 20):
+ return True
+ return False
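The check above parses code_tag strings of the form "major.minor-rev"; a standalone sketch of that parse (the helper name is mine, not part of the patch):

```python
def parse_code_tag(code_tag):
    # "1.0-21" -> (1, 0, 21), mirroring the splits in _call_id_supported
    code_tag_parts = code_tag.split("-")
    major, minor = code_tag_parts[0].split(".")[0:2]
    rev = code_tag_parts[1]
    return (int(major), int(minor), int(rev))
```

call_id is considered supported for releases strictly newer than 1.0-20.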
+
# we have specialized xmlrpclib.ServerProxy to remember the input url
# OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances
def get_serverproxy_url (server):
try:
- return server.url
+ return server.get_url()
except:
- sfa_logger().warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
- return server._ServerProxy__host + server._ServerProxy__handler
+ logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
+ return server._ServerProxy__host + server._ServerProxy__handler
def GetVersion(api):
# peers explicitly in aggregates.xml
- peers =dict ([ (peername,get_serverproxy_url(v)) for (peername,v) in api.aggregates.iteritems()
+ peers =dict ([ (peername,get_serverproxy_url(v)) for (peername,v) in api.aggregates.iteritems()
if peername != api.hrn])
- xrn=Xrn (api.hrn)
- request_rspec_versions = [dict(pg_rspec_request_version), dict(sfa_rspec_version)]
- ad_rspec_versions = [dict(pg_rspec_ad_version), dict(sfa_rspec_version)]
+ version_manager = VersionManager()
+ ad_rspec_versions = []
+ request_rspec_versions = []
+ for rspec_version in version_manager.versions:
+ if rspec_version.content_type in ['*', 'ad']:
+ ad_rspec_versions.append(rspec_version.to_dict())
+ if rspec_version.content_type in ['*', 'request']:
+ request_rspec_versions.append(rspec_version.to_dict())
+ default_rspec_version = version_manager.get_version("sfa 1").to_dict()
+ xrn=Xrn(api.hrn, 'authority+sa')
version_more = {'interface':'slicemgr',
'hrn' : xrn.get_hrn(),
'urn' : xrn.get_urn(),
'peers': peers,
'request_rspec_versions': request_rspec_versions,
'ad_rspec_versions': ad_rspec_versions,
- 'default_ad_rspec': dict(sfa_rspec_version)
+ 'default_ad_rspec': default_rspec_version
}
sm_version=version_core(version_more)
# local aggregate if present needs to have localhost resolved
sm_version['peers'][api.hrn]=local_am_url.replace('localhost',sm_version['hostname'])
return sm_version
-def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
+def drop_slicemgr_stats(rspec):
+ try:
+ stats_elements = rspec.xml.xpath('//statistics')
+ for node in stats_elements:
+ node.getparent().remove(node)
+ except Exception, e:
+ api.logger.warn("drop_slicemgr_stats failed: %s " % (str(e)))
- def _CreateSliver(aggregate, xrn, credential, rspec, users, call_id):
- # Need to call GetVersion at an aggregate to determine the supported
- # rspec type/format beofre calling CreateSliver at an Aggregate.
- # The Aggregate's verion info is cached
- server = api.aggregates[aggregate]
- # get cached aggregate version
- aggregate_version_key = 'version_'+ aggregate
- aggregate_version = api.cache.get(aggregate_version_key)
- if not aggregate_version:
- # get current aggregate version anc cache it for 24 hours
- aggregate_version = server.GetVersion()
- api.cache.add(aggregate_version_key, aggregate_version, 60 * 60 * 24)
-
- if 'sfa' not in aggregate_version and 'geni_api' in aggregate_version:
- # sfa aggregtes support both sfa and pg rspecs, no need to convert
- # if aggregate supports sfa rspecs. othewise convert to pg rspec
- rspec = RSpecConverter.to_pg_rspec(rspec)
+def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
+ try:
+ stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
+ if stats_tags:
+ stats_tag = stats_tags[0]
+ else:
+ stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)
- return server.CreateSliver(xrn, credential, rspec, users, call_id)
-
+ etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status))
+ except Exception, e:
+ api.logger.warn("add_slicemgr_stat failed on %s: %s" %(aggname, str(e)))
+
+def ListResources(api, creds, options, call_id):
+ version_manager = VersionManager()
+ def _ListResources(aggregate, server, credential, opts, call_id):
+
+ my_opts = copy(opts)
+ args = [credential, my_opts]
+ tStart = time.time()
+ try:
+ if _call_id_supported(api, server):
+ args.append(call_id)
+ version = api.get_cached_server_version(server)
+ # force ProtoGENI aggregates to give us a v2 RSpec
+ if 'sfa' not in version.keys():
+ my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
+ rspec = server.ListResources(*args)
+ return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
+ except Exception, e:
+ api.logger.log_exc("ListResources failed at %s" %(server.url))
+ return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}
if Callids().already_handled(call_id): return ""
+ # get slice's hrn from options
+ xrn = options.get('geni_slice_urn', '')
+ (hrn, type) = urn_to_hrn(xrn)
+ if 'geni_compressed' in options:
+ del(options['geni_compressed'])
+
+ # get the rspec's return format from options
+ rspec_version = version_manager.get_version(options.get('rspec_version'))
+ version_string = "rspec_%s" % (rspec_version.to_string())
+
+ # look in cache first
+ if caching and api.cache and not xrn:
+ rspec = api.cache.get(version_string)
+ if rspec:
+ return rspec
+
+ # get the callers hrn
+ valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
+ caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
+
+ # attempt to use delegated credential first
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
+ threads = ThreadManager()
+ for aggregate in api.aggregates:
+ # prevent infinite loop. Dont send request back to caller
+ # unless the caller is the aggregate's SM
+ if caller_hrn == aggregate and aggregate != api.hrn:
+ continue
+
+ # get the rspec from the aggregate
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run(_ListResources, aggregate, server, [cred], options, call_id)
+
+
+ results = threads.get_results()
+ rspec_version = version_manager.get_version(options.get('rspec_version'))
+ if xrn:
+ result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'manifest')
+ else:
+ result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'ad')
+ rspec = RSpec(version=result_version)
+ for result in results:
+ add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
+ if result["status"]=="success":
+ try:
+ rspec.version.merge(result["rspec"])
+ except:
+ api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")
+
+ # cache the result
+ if caching and api.cache and not xrn:
+ api.cache.add(version_string, rspec.toxml())
+
+ return rspec.toxml()
+
+
+def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
+
+ version_manager = VersionManager()
+ def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
+ tStart = time.time()
+ try:
+ # Need to call GetVersion at an aggregate to determine the supported
+ # rspec type/format before calling CreateSliver at an Aggregate.
+ server_version = api.get_cached_server_version(server)
+ requested_users = users
+ if 'sfa' not in server_version and 'geni_api' in server_version:
+ # sfa aggregates support both sfa and pg rspecs, no need to convert
+ # if aggregate supports sfa rspecs. otherwise convert to pg rspec
+ rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
+ filter = {'component_manager_id': server_version['urn']}
+ rspec.filter(filter)
+ rspec = rspec.toxml()
+ requested_users = sfa_to_pg_users_arg(users)
+ args = [xrn, credential, rspec, requested_users]
+ if _call_id_supported(api, server):
+ args.append(call_id)
+ rspec = server.CreateSliver(*args)
+ return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
+ except:
+ logger.log_exc('Something wrong in _CreateSliver with URL %s'%server.url)
+ return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}
+
+ if Callids().already_handled(call_id): return ""
# Validate the RSpec against PlanetLab's schema --disabled for now
# The schema used here needs to aggregate the PL and VINI schemas
# schema = "/var/www/html/schemas/pl.rng"
- rspec = parse_rspec(rspec_str)
+ rspec = RSpec(rspec_str)
schema = None
if schema:
rspec.validate(schema)
+ # if there is a <statistics> section, the aggregates don't care about it,
+ # so delete it.
+ drop_slicemgr_stats(rspec)
+
# attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
# get the callers hrn
hrn, type = urn_to_hrn(xrn)
# unless the caller is the aggregate's SM
if caller_hrn == aggregate and aggregate != api.hrn:
continue
-
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
# Just send entire RSpec to each aggregate
- threads.run(_CreateSliver, aggregate, xrn, credential, rspec.toxml(), users, call_id)
+ threads.run(_CreateSliver, aggregate, server, xrn, [cred], rspec.toxml(), users, call_id)
results = threads.get_results()
- rspec = SfaRSpec()
+ manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
+ result_rspec = RSpec(version=manifest_version)
for result in results:
- rspec.merge(result)
- return rspec.toxml()
+ add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
+ if result["status"]=="success":
+ try:
+ result_rspec.version.merge(result["rspec"])
+ except:
+ api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
+ return result_rspec.toxml()
def RenewSliver(api, xrn, creds, expiration_time, call_id):
+ def _RenewSliver(server, xrn, creds, expiration_time, call_id):
+ server_version = api.get_cached_server_version(server)
+        args = [xrn, creds, expiration_time]
+ if _call_id_supported(api, server):
+ args.append(call_id)
+ return server.RenewSliver(*args)
+
if Callids().already_handled(call_id): return True
(hrn, type) = urn_to_hrn(xrn)
caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
# attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
threads = ThreadManager()
for aggregate in api.aggregates:
# prevent infinite loop. Dont send request back to caller
# unless the caller is the aggregate's SM
if caller_hrn == aggregate and aggregate != api.hrn:
continue
-
- server = api.aggregates[aggregate]
- threads.run(server.RenewSliver, xrn, [credential], expiration_time, call_id)
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)
# 'and' the results
return reduce (lambda x,y: x and y, threads.get_results() , True)
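The `reduce` at the end of `RenewSliver` folds the per-aggregate booleans with `and`, so the slice manager reports success only if every aggregate succeeded. A standalone sketch with hypothetical results:

```python
from functools import reduce  # built in on Python 2, which this code targets

# Hypothetical per-aggregate RenewSliver outcomes
results = [True, True, False]

# Fold with 'and'; the initializer True makes an empty result list count as success
overall = reduce(lambda x, y: x and y, results, True)
assert overall is False  # one aggregate failed, so the whole call fails
```

An empty aggregate list therefore yields `True`, which matches the convention that a no-op renewal is not an error.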
+def DeleteSliver(api, xrn, creds, call_id):
+ def _DeleteSliver(server, xrn, creds, call_id):
+ server_version = api.get_cached_server_version(server)
+ args = [xrn, creds]
+ if _call_id_supported(api, server):
+ args.append(call_id)
+ return server.DeleteSliver(*args)
+
+ if Callids().already_handled(call_id): return ""
+ (hrn, type) = urn_to_hrn(xrn)
+    # get the caller's hrn
+ valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
+ caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
+
+ # attempt to use delegated credential first
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
+ threads = ThreadManager()
+ for aggregate in api.aggregates:
+        # prevent infinite loop. Don't send request back to caller
+ # unless the caller is the aggregate's SM
+ if caller_hrn == aggregate and aggregate != api.hrn:
+ continue
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run(_DeleteSliver, server, xrn, [cred], call_id)
+ threads.get_results()
+ return 1
+
+
+# first draft at a merging SliverStatus
+def SliverStatus(api, slice_xrn, creds, call_id):
+ def _SliverStatus(server, xrn, creds, call_id):
+ server_version = api.get_cached_server_version(server)
+ args = [xrn, creds]
+ if _call_id_supported(api, server):
+ args.append(call_id)
+ return server.SliverStatus(*args)
+
+ if Callids().already_handled(call_id): return {}
+ # attempt to use delegated credential first
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
+ threads = ThreadManager()
+ for aggregate in api.aggregates:
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run (_SliverStatus, server, slice_xrn, [cred], call_id)
+ results = threads.get_results()
+
+    # get rid of any void result - e.g. when the call_id was already handled, where by convention we return {}
+ results = [ result for result in results if result and result['geni_resources']]
+
+ # do not try to combine if there's no result
+ if not results : return {}
+
+ # otherwise let's merge stuff
+ overall = {}
+
+    # it is expected that all results carry the same urn
+ overall['geni_urn'] = results[0]['geni_urn']
+ overall['pl_login'] = results[0]['pl_login']
+ # append all geni_resources
+ overall['geni_resources'] = \
+ reduce (lambda x,y: x+y, [ result['geni_resources'] for result in results] , [])
+ overall['status'] = 'unknown'
+ if overall['geni_resources']:
+ overall['status'] = 'ready'
+
+ return overall
+
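The merge in `SliverStatus` above keeps the first result's urn and login and concatenates every `geni_resources` list. A standalone sketch of that merge, with hypothetical sample results:

```python
from functools import reduce

# Hypothetical per-aggregate SliverStatus results (keys as assumed above)
results = [
    {'geni_urn': 'urn:publicid:IDN+plc:princeton+slice+demo',
     'pl_login': 'princeton_demo',
     'geni_resources': [{'geni_urn': 'node1'}]},
    {'geni_urn': 'urn:publicid:IDN+plc:princeton+slice+demo',
     'pl_login': 'princeton_demo',
     'geni_resources': [{'geni_urn': 'node2'}]},
]

# Take urn/login from the first result, concatenate all resource lists
overall = {
    'geni_urn': results[0]['geni_urn'],
    'pl_login': results[0]['pl_login'],
    'geni_resources': reduce(lambda x, y: x + y,
                             [r['geni_resources'] for r in results], []),
}
overall['status'] = 'ready' if overall['geni_resources'] else 'unknown'
```

As the "first draft" comment says, this assumes all aggregates agree on the urn; a disagreeing aggregate would be silently shadowed by the first result.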
+caching=True
+#caching=False
+def ListSlices(api, creds, call_id):
+ def _ListSlices(server, creds, call_id):
+ server_version = api.get_cached_server_version(server)
+ args = [creds]
+ if _call_id_supported(api, server):
+ args.append(call_id)
+ return server.ListSlices(*args)
+
+ if Callids().already_handled(call_id): return []
+
+ # look in cache first
+ if caching and api.cache:
+ slices = api.cache.get('slices')
+ if slices:
+ return slices
+
+    # get the caller's hrn
+ valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
+ caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
+
+ # attempt to use delegated credential first
+ cred= api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
+ threads = ThreadManager()
+ # fetch from aggregates
+ for aggregate in api.aggregates:
+        # prevent infinite loop. Don't send request back to caller
+ # unless the caller is the aggregate's SM
+ if caller_hrn == aggregate and aggregate != api.hrn:
+ continue
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run(_ListSlices, server, [cred], call_id)
+
+    # combine results
+ results = threads.get_results()
+ slices = []
+ for result in results:
+ slices.extend(result)
+
+ # cache the result
+ if caching and api.cache:
+ api.cache.add('slices', slices)
+
+ return slices
+
+
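`ListSlices` above consults `api.cache` before fanning out to the aggregates and repopulates it afterwards. A minimal sketch of a cache with the `get`/`add` interface the code assumes (`SimpleCache` is a hypothetical stand-in; the real `api.cache` may also handle expiry):

```python
# Hypothetical dict-backed stand-in for api.cache
class SimpleCache:
    def __init__(self):
        self._store = {}

    def get(self, key):
        # return None on a miss, so callers fall through to the aggregates
        return self._store.get(key)

    def add(self, key, value):
        self._store[key] = value

cache = SimpleCache()
assert cache.get('slices') is None          # miss: query every aggregate
cache.add('slices', ['plc.princeton.s1'])   # cache the combined result
assert cache.get('slices') == ['plc.princeton.s1']  # later calls hit the cache
```

Because the cache key is the constant `'slices'`, every caller shares one cached listing, which is why the module-level `caching` flag exists as a global kill switch.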
def get_ticket(api, xrn, creds, rspec, users):
slice_hrn, type = urn_to_hrn(xrn)
# get the netspecs contained within the clients rspec
caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
# attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
threads = ThreadManager()
for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
# prevent infinite loop. Dont send request back to caller
# unless the caller is the aggregate's SM
if caller_hrn == aggregate and aggregate != api.hrn:
continue
- server = None
- if aggregate in api.aggregates:
- server = api.aggregates[aggregate]
- else:
- net_urn = hrn_to_urn(aggregate, 'authority')
- # we may have a peer that knows about this aggregate
- for agg in api.aggregates:
- target_aggs = api.aggregates[agg].get_aggregates(credential, net_urn)
- if not target_aggs or not 'hrn' in target_aggs[0]:
- continue
- # send the request to this address
- url = target_aggs[0]['url']
- server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file)
- # aggregate found, no need to keep looping
- break
- if server is None:
- continue
- threads.run(server.GetTicket, xrn, credential, aggregate_rspec, users)
+
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users)
results = threads.get_results()
ticket.sign()
return ticket.save_to_string(save_parents=True)
-
-def DeleteSliver(api, xrn, creds, call_id):
- if Callids().already_handled(call_id): return ""
- (hrn, type) = urn_to_hrn(xrn)
- # get the callers hrn
- valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
- caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
-
- # attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
- threads = ThreadManager()
- for aggregate in api.aggregates:
- # prevent infinite loop. Dont send request back to caller
- # unless the caller is the aggregate's SM
- if caller_hrn == aggregate and aggregate != api.hrn:
- continue
- server = api.aggregates[aggregate]
- threads.run(server.DeleteSliver, xrn, credential, call_id)
- threads.get_results()
- return 1
-
def start_slice(api, xrn, creds):
hrn, type = urn_to_hrn(xrn)
caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
# attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
threads = ThreadManager()
for aggregate in api.aggregates:
# prevent infinite loop. Dont send request back to caller
# unless the caller is the aggregate's SM
if caller_hrn == aggregate and aggregate != api.hrn:
continue
- server = api.aggregates[aggregate]
- threads.run(server.Start, xrn, credential)
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run(server.Start, xrn, cred)
threads.get_results()
return 1
caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
# attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
+ cred = api.getDelegatedCredential(creds)
+ if not cred:
+ cred = api.getCredential()
threads = ThreadManager()
for aggregate in api.aggregates:
# prevent infinite loop. Dont send request back to caller
# unless the caller is the aggregate's SM
if caller_hrn == aggregate and aggregate != api.hrn:
continue
- server = api.aggregates[aggregate]
- threads.run(server.Stop, xrn, credential)
+ interface = api.aggregates[aggregate]
+ server = api.get_server(interface, cred)
+ threads.run(server.Stop, xrn, cred)
threads.get_results()
return 1
"""
return 1
-# Thierry : caching at the slicemgr level makes sense to some extent
-caching=True
-#caching=False
-def ListSlices(api, creds, call_id):
-
- if Callids().already_handled(call_id): return []
-
- # look in cache first
- if caching and api.cache:
- slices = api.cache.get('slices')
- if slices:
- return slices
-
- # get the callers hrn
- valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
- caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
-
- # attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
- threads = ThreadManager()
- # fetch from aggregates
- for aggregate in api.aggregates:
- # prevent infinite loop. Dont send request back to caller
- # unless the caller is the aggregate's SM
- if caller_hrn == aggregate and aggregate != api.hrn:
- continue
- server = api.aggregates[aggregate]
- threads.run(server.ListSlices, credential, call_id)
-
- # combime results
- results = threads.get_results()
- slices = []
- for result in results:
- slices.extend(result)
-
- # cache the result
- if caching and api.cache:
- api.cache.add('slices', slices)
-
- return slices
-
-
-def ListResources(api, creds, options, call_id):
-
- if Callids().already_handled(call_id): return ""
-
- # get slice's hrn from options
- xrn = options.get('geni_slice_urn', '')
- (hrn, type) = urn_to_hrn(xrn)
-
- # get the rspec's return format from options
- rspec_version = RSpecVersion(options.get('rspec_version'))
- version_string = "rspec_%s" % (rspec_version.get_version_name())
-
- # look in cache first
- if caching and api.cache and not xrn:
- rspec = api.cache.get(version_string)
- if rspec:
- return rspec
-
- # get the callers hrn
- valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
- caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
-
- # attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
- threads = ThreadManager()
- for aggregate in api.aggregates:
- # prevent infinite loop. Dont send request back to caller
- # unless the caller is the aggregate's SM
- if caller_hrn == aggregate and aggregate != api.hrn:
- continue
- # get the rspec from the aggregate
- server = api.aggregates[aggregate]
- my_opts = copy(options)
- my_opts['geni_compressed'] = False
- threads.run(server.ListResources, credential, my_opts, call_id)
-
- results = threads.get_results()
- rspec_version = RSpecVersion(my_opts.get('rspec_version'))
- if rspec_version['type'] == pg_rspec_ad_version['type']:
- rspec = PGRSpec()
- else:
- rspec = SfaRSpec()
-
- for result in results:
- try:
- rspec.merge(result)
- except:
- api.logger.info("SM.ListResources: Failed to merge aggregate rspec")
-
- # cache the result
- if caching and api.cache and not xrn:
- api.cache.add(version_string, rspec.toxml())
-
- return rspec.toxml()
-
-# first draft at a merging SliverStatus
-def SliverStatus(api, slice_xrn, creds, call_id):
- if Callids().already_handled(call_id): return {}
- # attempt to use delegated credential first
- credential = api.getDelegatedCredential(creds)
- if not credential:
- credential = api.getCredential()
- threads = ThreadManager()
- for aggregate in api.aggregates:
- server = api.aggregates[aggregate]
- threads.run (server.SliverStatus, slice_xrn, credential, call_id)
- results = threads.get_results()
-
- # get rid of any void result - e.g. when call_id was hit where by convention we return {}
- results = [ result for result in results if result and result['geni_resources']]
-
- # do not try to combine if there's no result
- if not results : return {}
-
- # otherwise let's merge stuff
- overall = {}
-
- # mmh, it is expected that all results carry the same urn
- overall['geni_urn'] = results[0]['geni_urn']
- overall['pl_login'] = results[0]['pl_login']
- # append all geni_resources
- overall['geni_resources'] = \
- reduce (lambda x,y: x+y, [ result['geni_resources'] for result in results] , [])
- overall['status'] = 'unknown'
- if overall['geni_resources']:
- overall['status'] = 'ready'
-
- return overall
-
def main():
r = RSpec()
r.parseFile(sys.argv[1])
if __name__ == "__main__":
main()
-
+
#!/usr/bin/python
-
-# $Id: topology.py 14181 2009-07-01 19:46:07Z acb $
-# $URL: https://svn.planet-lab.org/svn/NodeManager-topo/trunk/topology.py $
-
#
# Links in the physical topology, gleaned from looking at the Internet2
# topology map. Link (a, b) connects sites with IDs a and b.
start = RSpec
RSpec = element RSpec {
+ attribute expires { xsd:NMTOKEN },
+ attribute generated { xsd:NMTOKEN },
attribute type { xsd:NMTOKEN },
( network | request )
}
</start>
<define name="RSpec">
<element name="RSpec">
+ <attribute name="expires">
+ <data type="NMTOKEN"/>
+ </attribute>
+ <attribute name="generated">
+ <data type="NMTOKEN"/>
+ </attribute>
<attribute name="type">
<data type="NMTOKEN"/>
</attribute>
--- /dev/null
+
+from sfa.util.xrn import urn_to_hrn
+from sfa.util.method import Method
+from sfa.util.parameter import Parameter, Mixed
+from sfa.trust.credential import Credential
+
+class CreateGid(Method):
+ """
+    Create a signed GID for the specified object and register it with the
+    registry. In addition to being stored in the SFA database, the
+    appropriate records will also be created in the PLC databases.
+
+ @param xrn urn or hrn of certificate owner
+ @param cert caller's certificate
+ @param cred credential string
+
+ @return gid string representation
+ """
+
+ interfaces = ['registry']
+
+ accepts = [
+ Mixed(Parameter(str, "Credential string"),
+ Parameter(type([str]), "List of credentials")),
+ Parameter(str, "URN or HRN of certificate owner"),
+ Parameter(str, "Certificate string"),
+ ]
+
+    returns = Parameter(str, "String representation of gid object")
+
+ def call(self, creds, xrn, cert=None):
+        # TODO: is there a better right to check for, or is 'update' good enough?
+ valid_creds = self.api.auth.checkCredentials(creds, 'update')
+
+ # verify permissions
+ hrn, type = urn_to_hrn(xrn)
+ self.api.auth.verify_object_permission(hrn)
+
+        # log the call
+        origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
+ self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrn, self.name))
+
+ manager = self.api.get_interface_manager()
+
+ return manager.create_gid(self.api, xrn, cert)
valid_creds = self.api.auth.checkCredentials(creds, 'createsliver', hrn)
origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
+ # make sure users info is specified
+ if not users:
+        msg = "'users' must be specified and cannot be null. You may need to update your client."
+ raise SfaInvalidArgument(name='users', extra=msg)
+
manager = self.api.get_interface_manager()
# flter rspec through sfatables
-### $Id: get_ticket.py 17732 2010-04-19 21:10:45Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/get_ticket.py $
import time
from sfa.util.faults import *
from sfa.util.xrn import urn_to_hrn
-### $Id: reset_slice.py 15428 2009-10-23 15:28:03Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfacomponent/methods/reset_slice.py $
import xmlrpclib
from sfa.util.faults import *
from sfa.util.method import Method
-### $Id: register.py 16477 2010-01-05 16:31:37Z thierry $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/register.py $
from sfa.trust.certificate import Keypair, convert_public_key
from sfa.trust.gid import *
returns = Parameter(int, "String representation of gid object")
def call(self, record, creds):
-
+ # validate cred
valid_creds = self.api.auth.checkCredentials(creds, 'register')
+
+ # verify permissions
+ hrn = record.get('hrn', '')
+ self.api.auth.verify_object_permission(hrn)
#log the call
origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
-
- hrn = None
- if 'hrn' in record:
- hrn = record['hrn']
self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name))
manager = self.api.get_interface_manager()
-### $Id: register.py 15001 2009-09-11 20:18:54Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/register.py $
-
from sfa.trust.certificate import Keypair, convert_public_key
from sfa.trust.gid import *
# Validate that the time does not go beyond the credential's expiration time
requested_time = utcparse(expiration_time)
+ max_renew_days = int(self.api.config.SFA_MAX_SLICE_RENEW)
if requested_time > Credential(string=valid_creds[0]).get_expiration():
raise InsufficientRights('Renewsliver: Credential expires before requested expiration time')
- if requested_time > datetime.datetime.utcnow() + datetime.timedelta(days=60):
- raise Exception('Cannot renew > 60 days from now')
+ if requested_time > datetime.datetime.utcnow() + datetime.timedelta(days=max_renew_days):
+ raise Exception('Cannot renew > %s days from now' % max_renew_days)
manager = self.api.get_interface_manager()
return manager.RenewSliver(self.api, slice_xrn, valid_creds, expiration_time, call_id)
-### $Id: resolve.py 17157 2010-02-21 04:19:34Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/resolve.py $
import traceback
import types
from sfa.util.faults import *
-### $Id: stop_slice.py 17732 2010-04-19 21:10:45Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/stop_slice.py $
-
from sfa.util.faults import *
from sfa.util.xrn import urn_to_hrn
from sfa.util.method import Method
-### $Id: stop_slice.py 17732 2010-04-19 21:10:45Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/stop_slice.py $
-
from sfa.util.faults import *
from sfa.util.xrn import urn_to_hrn
from sfa.util.method import Method
-### $Id: update.py 16477 2010-01-05 16:31:37Z thierry $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/update.py $
-
import time
from sfa.util.faults import *
from sfa.util.method import Method
def call(self, record_dict, creds):
# validate the cred
valid_creds = self.api.auth.checkCredentials(creds, "update")
+
+ # verify permissions
+ hrn = record_dict.get('hrn', '')
+ self.api.auth.verify_object_permission(hrn)
+
+ # log
origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
- self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, None, self.name))
+ self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name))
manager = self.api.get_interface_manager()
## Please use make index to update this file
all = """
CreateSliver
+CreateGid
DeleteSliver
GetCredential
GetGids
-### $Id: register.py 15001 2009-09-11 20:18:54Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/register.py $
-
from sfa.trust.certificate import Keypair, convert_public_key
from sfa.trust.gid import *
-### $Id: reset_slices.py 15428 2009-10-23 15:28:03Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/reset_slices.py $
-
from sfa.util.faults import *
from sfa.util.xrn import urn_to_hrn
from sfa.util.method import Method
#!/usr/bin/python
from sfa.util.xrn import *
from sfa.util.plxrn import *
-from sfa.rspecs.sfa_rspec import SfaRSpec
-from sfa.rspecs.pg_rspec import PGRSpec
-from sfa.rspecs.rspec_version import RSpecVersion
+#from sfa.rspecs.sfa_rspec import SfaRSpec
+#from sfa.rspecs.pg_rspec import PGRSpec
+#from sfa.rspecs.rspec_version import RSpecVersion
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
+from sfa.plc.vlink import get_tc_rate
class Aggregate:
interfaces = {}
links = {}
node_tags = {}
+ pl_initscripts = {}
prepared=False
#panos new user options variable
user_options = {}
def prepare_nodes(self, force=False):
if not self.nodes or force:
- for node in self.api.plshell.GetNodes(self.api.plauth):
+ for node in self.api.plshell.GetNodes(self.api.plauth, {'peer_id': None}):
+ # add site/interface info to nodes.
+ # assumes that sites, interfaces and tags have already been prepared.
+ site = self.sites[node['site_id']]
+ interfaces = [self.interfaces[interface_id] for interface_id in node['interface_ids']]
+ tags = [self.node_tags[tag_id] for tag_id in node['node_tag_ids']]
+ node['network'] = self.api.hrn
+ node['network_urn'] = hrn_to_urn(self.api.hrn, 'authority+am')
+ node['urn'] = hostname_to_urn(self.api.hrn, site['login_base'], node['hostname'])
+ node['site_urn'] = hrn_to_urn(PlXrn.site_hrn(self.api.hrn, site['login_base']), 'authority+sa')
+ node['site'] = site
+ node['interfaces'] = interfaces
+ node['tags'] = tags
self.nodes[node['node_id']] = node
def prepare_interfaces(self, force=False):
for node_tag in self.api.plshell.GetNodeTags(self.api.plauth):
self.node_tags[node_tag['node_tag_id']] = node_tag
+ def prepare_pl_initscripts(self, force=False):
+ if not self.pl_initscripts or force:
+ for initscript in self.api.plshell.GetInitScripts(self.api.plauth, {'enabled': True}):
+ self.pl_initscripts[initscript['initscript_id']] = initscript
+
def prepare(self, force=False):
if not self.prepared or force:
self.prepare_sites(force)
- self.prepare_nodes(force)
self.prepare_interfaces(force)
- self.prepare_links(force)
self.prepare_node_tags(force)
- # add site/interface info to nodes
- for node_id in self.nodes:
- node = self.nodes[node_id]
- site = self.sites[node['site_id']]
- interfaces = [self.interfaces[interface_id] for interface_id in node['interface_ids']]
- tags = [self.node_tags[tag_id] for tag_id in node['node_tag_ids']]
- node['network'] = self.api.hrn
- node['network_urn'] = hrn_to_urn(self.api.hrn, 'authority+am')
- node['urn'] = hostname_to_urn(self.api.hrn, site['login_base'], node['hostname'])
- node['site_urn'] = hrn_to_urn(PlXrn.site_hrn(self.api.hrn, site['login_base']), 'authority+sa')
- node['site'] = site
- node['interfaces'] = interfaces
- node['tags'] = tags
-
+ self.prepare_nodes(force)
+ self.prepare_links(force)
+        self.prepare_pl_initscripts(force)
self.prepared = True
def get_rspec(self, slice_xrn=None, version = None):
self.prepare()
- rspec = None
- rspec_version = RSpecVersion(version)
- if slice_xrn:
- type = 'manifest'
+ version_manager = VersionManager()
+ version = version_manager.get_version(version)
+ if not slice_xrn:
+ rspec_version = version_manager._get_version(version.type, version.version, 'ad')
else:
- type = 'advertisement'
- if rspec_version['type'].lower() == 'protogeni':
- rspec = PGRSpec(type=type)
- elif rspec_version['type'].lower() == 'sfa':
- rspec = SfaRSpec(type=type, user_options=self.user_options)
- else:
- rspec = SfaRSpec(type=type, user_options=self.user_options)
-
-
- rspec.add_nodes(self.nodes.values())
- rspec.add_interfaces(self.interfaces.values())
- rspec.add_links(self.links.values())
-
+ rspec_version = version_manager._get_version(version.type, version.version, 'manifest')
+
+ rspec = RSpec(version=rspec_version, user_options=self.user_options)
+ # get slice details if specified
+ slice = None
if slice_xrn:
- # If slicename is specified then resulting rspec is a manifest.
- # Add sliver details to rspec and remove 'advertisement' elements
slice_hrn, _ = urn_to_hrn(slice_xrn)
slice_name = hrn_to_pl_slicename(slice_hrn)
slices = self.api.plshell.GetSlices(self.api.plauth, slice_name)
if slices:
- slice = slices[0]
- slivers = []
- tags = self.api.plshell.GetSliceTags(self.api.plauth, slice['slice_tag_ids'])
- for node_id in slice['node_ids']:
+ slice = slices[0]
+
+ # filter out nodes with a whitelist:
+ valid_nodes = []
+ for node in self.nodes.values():
+ # only doing this because protogeni rspec needs
+ # to advertise available initscripts
+ node['pl_initscripts'] = self.pl_initscripts
+
+ if slice and node['node_id'] in slice['node_ids']:
+ valid_nodes.append(node)
+ elif slice and slice['slice_id'] in node['slice_ids_whitelist']:
+ valid_nodes.append(node)
+ elif not slice and not node['slice_ids_whitelist']:
+ valid_nodes.append(node)
+
+ rspec.version.add_nodes(valid_nodes)
+ rspec.version.add_interfaces(self.interfaces.values())
+ rspec.version.add_links(self.links.values())
+
+ # add slivers
+ if slice_xrn and slice:
+ slivers = []
+ tags = self.api.plshell.GetSliceTags(self.api.plauth, slice['slice_tag_ids'])
+
+ # add default tags
+ for tag in tags:
+ # if tag isn't bound to a node then it applies to all slivers
+ # and belongs in the <sliver_defaults> tag
+ if not tag['node_id']:
+ rspec.version.add_default_sliver_attribute(tag['tagname'], tag['value'], self.api.hrn)
+ if tag['tagname'] == 'topo_rspec' and tag['node_id']:
+ node = self.nodes[tag['node_id']]
+ value = eval(tag['value'])
+ for (id, realip, bw, lvip, rvip, vnet) in value:
+ bps = get_tc_rate(bw)
+ remote = self.nodes[id]
+ site1 = self.sites[node['site_id']]
+ site2 = self.sites[remote['site_id']]
+ link1_name = '%s:%s' % (site1['login_base'], site2['login_base'])
+ link2_name = '%s:%s' % (site2['login_base'], site1['login_base'])
+ p_link = None
+ if link1_name in self.links:
+ link = self.links[link1_name]
+ elif link2_name in self.links:
+ link = self.links[link2_name]
+ v_link = Link()
+
+ link.capacity = bps
+ for node_id in slice['node_ids']:
+ try:
sliver = {}
sliver['hostname'] = self.nodes[node_id]['hostname']
+ sliver['node_id'] = node_id
+ sliver['slice_id'] = slice['slice_id']
sliver['tags'] = []
slivers.append(sliver)
+
+ # add tags for this node only
for tag in tags:
- # if tag isn't bound to a node then it applies to all slivers
- if not tag['node_id']:
+ if tag['node_id'] and (tag['node_id'] == node_id):
sliver['tags'].append(tag)
- else:
- tag_host = self.nodes[tag['node_id']]['hostname']
- if tag_host == sliver['hostname']:
- sliver['tags'].append(tag)
- rspec.add_slivers(slivers, sliver_urn=slice_xrn)
+ except:
+ self.api.logger.log_exc('unable to add sliver %s to node %s' % (slice['name'], node_id))
+ rspec.version.add_slivers(slivers, sliver_urn=slice_xrn)
- return rspec.toxml(cleanup=True)
+ return rspec.toxml()
from sfa.util.faults import *
from sfa.util.api import *
from sfa.util.config import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
import sfa.util.xmlrpcprotocol as xmlrpcprotocol
from sfa.trust.auth import Auth
from sfa.trust.rights import Right, Rights, determine_rights
self.hrn = self.config.SFA_INTERFACE_HRN
self.time_format = "%Y-%m-%d %H:%M:%S"
- self.logger=sfa_logger()
+
def getPLCShell(self):
self.plauth = {'Username': self.config.SFA_PLC_USER,
'AuthMethod': 'password',
'AuthString': self.config.SFA_PLC_PASSWORD}
- try:
- sys.path.append(os.path.dirname(os.path.realpath("/usr/bin/plcsh")))
- self.plshell_type = 'direct'
- import PLC.Shell
- shell = PLC.Shell.Shell(globals = globals())
- except:
- self.plshell_type = 'xmlrpc'
- url = self.config.SFA_PLC_URL
- shell = xmlrpclib.Server(url, verbose = 0, allow_none = True)
+
+ # The native shell (PLC.Shell.Shell) is more efficient than xmlrpc,
+ # but it leaves idle db connections open. use xmlrpc until we can figure
+ # out why PLC.Shell.Shell doesn't close db connection properly
+ #try:
+ # sys.path.append(os.path.dirname(os.path.realpath("/usr/bin/plcsh")))
+ # self.plshell_type = 'direct'
+ # import PLC.Shell
+ # shell = PLC.Shell.Shell(globals = globals())
+ #except:
+ self.plshell_type = 'xmlrpc'
+ url = self.config.SFA_PLC_URL
+ shell = xmlrpclib.Server(url, verbose = 0, allow_none = True)
return shell
+ def get_server(self, interface, cred, timeout=30):
+ """
+ Returns a connection to the specified interface. Use the specified
+ credential to determine the caller and look for the caller's key/cert
+ in the registry hierarchy cache.
+ """
+ from sfa.trust.hierarchy import Hierarchy
+ if not isinstance(cred, Credential):
+ cred_obj = Credential(string=cred)
+ else:
+ cred_obj = cred
+ caller_gid = cred_obj.get_gid_caller()
+ hierarchy = Hierarchy()
+ auth_info = hierarchy.get_auth_info(caller_gid.get_hrn())
+ key_file = auth_info.get_privkey_filename()
+ cert_file = auth_info.get_gid_filename()
+ server = interface.get_server(key_file, cert_file, timeout)
+ return server
+
+
def getCredential(self):
"""
Return a valid credential for this interface.
cred = Credential(filename = cred_filename)
# make sure cred isnt expired
if not cred.get_expiration or \
- datetime.datetime.today() < cred.get_expiration():
+ datetime.datetime.utcnow() < cred.get_expiration():
return cred.save_to_string(save_parents=True)
# get a new credential
Attempt to find a credential delegated to us in
the specified list of creds.
"""
+ from sfa.trust.hierarchy import Hierarchy
if creds and not isinstance(creds, list):
creds = [creds]
- delegated_creds = filter_creds_by_caller(creds,self.hrn)
- if not delegated_creds:
- return None
- return delegated_creds[0]
+ hierarchy = Hierarchy()
+
+ delegated_cred = None
+ for cred in creds:
+ if hierarchy.auth_exists(Credential(string=cred).get_gid_caller().get_hrn()):
+ delegated_cred = cred
+ break
+ return delegated_cred
def __getCredential(self):
"""
Get our credential from a remote registry
"""
from sfa.server.registry import Registries
- registries = Registries(self)
- registry = registries[self.hrn]
+ registries = Registries()
+ registry = registries.get_server(self.hrn, self.key_file, self.cert_file)
cert_string=self.cert.save_to_string(save_parents=True)
# get self credential
self_cred = registry.GetSelfCredential(cert_string, self.hrn, 'authority')
auth_hrn = hrn
auth_info = self.auth.get_auth_info(auth_hrn)
table = self.SfaTable()
- records = table.findObjects(hrn)
+ records = table.findObjects({'hrn': hrn, 'type': 'authority+sa'})
if not records:
raise RecordNotFound
record = records[0]
except IOError:
self.credential = self.getCredentialFromRegistry()
+
+
##
# Convert SFA fields to PLC fields for use when registering or updating
# registry record in the PLC database
# fill in key info
if record['type'] == 'user':
if 'key_ids' not in record:
- self.logger.info("user record has no 'key_ids' - need to import from myplc ?")
+ logger.info("user record has no 'key_ids' - need to import from myplc ?")
else:
pubkeys = [keys[key_id]['key'] for key_id in record['key_ids'] if key_id in keys]
record['keys'] = pubkeys
elif (type.startswith("authority")):
record['url'] = None
if record['hrn'] in self.aggregates:
- record['url'] = self.aggregates[record['hrn']].url
+
+ record['url'] = self.aggregates[record['hrn']].get_url()
if record['pointer'] != -1:
record['PI'] = []
from xmlbuilder import XMLBuilder
from sfa.util.faults import *
-#from sfa.util.sfalogging import sfa_logger
from sfa.util.xrn import get_authority
from sfa.util.plxrn import hrn_to_pl_slicename, hostname_to_urn
from sfa.util.plxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_pl_slicename
from sfa.util.config import Config
from sfa.trust.certificate import convert_public_key, Keypair
-from sfa.trust.trustedroot import *
+from sfa.trust.trustedroots import *
from sfa.trust.hierarchy import *
from sfa.util.xrn import Xrn
from sfa.plc.api import *
from sfa.trust.gid import create_uuid
-from sfa.plc.sfaImport import sfaImport
+from sfa.plc.sfaImport import sfaImport, _cleanup_string
def process_options():
f.write("keys = %s" % str(keys))
f.close()
+def _get_site_hrn(interface_hrn, site):
+ # Hardcode 'internet2' into the hrn for sites hosting
+ # internet2 nodes. This is a special operation for some vini
+ # sites only
+ hrn = ".".join([interface_hrn, site['login_base']])
+ if ".vini" in interface_hrn and interface_hrn.endswith('vini'):
+ if site['login_base'].startswith("i2") or site['login_base'].startswith("nlr"):
+ hrn = ".".join([interface_hrn, "internet2", site['login_base']])
+ return hrn
+
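The `_get_site_hrn` helper above can be checked in isolation; a self-contained mirror of its logic (reimplemented here for illustration rather than imported from the importer script):

```python
def get_site_hrn(interface_hrn, site):
    # Default: interface hrn dotted with the site's login_base.
    hrn = ".".join([interface_hrn, site['login_base']])
    # VINI interfaces inject an 'internet2' component for sites
    # hosting internet2 (i2*) or nlr* nodes.
    if ".vini" in interface_hrn and interface_hrn.endswith('vini'):
        if site['login_base'].startswith("i2") or site['login_base'].startswith("nlr"):
            hrn = ".".join([interface_hrn, "internet2", site['login_base']])
    return hrn
```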
def main():
process_options()
if not root_auth == interface_hrn:
sfaImporter.create_top_level_auth_records(interface_hrn)
+ # create a user record for the slice manager
+ sfaImporter.create_sm_client_record()
+
# create interface records
sfaImporter.logger.info("Import: creating interface records")
sfaImporter.create_interface_records()
slices_dict[slice['slice_id']] = slice
# start importing
for site in sites:
- site_hrn = interface_hrn + "." + site['login_base']
- sfa_logger().info("Importing site: %s" % site_hrn)
+ site_hrn = _get_site_hrn(interface_hrn, site)
+ sfaImporter.logger.info("Importing site: %s" % site_hrn)
# import if hrn is not in list of existing hrns or if the hrn exists
# but it's not a site record
if site_hrn not in existing_hrns or \
(site_hrn, 'authority') not in existing_records:
- site_hrn = sfaImporter.import_site(interface_hrn, site)
+ sfaImporter.import_site(site_hrn, site)
# import node records
for node_id in site['node_ids']:
if node_id not in nodes_dict:
continue
node = nodes_dict[node_id]
- hrn = hostname_to_hrn(interface_hrn, site['login_base'], node['hostname'])
+ site_auth = get_authority(site_hrn)
+ site_name = get_leaf(site_hrn)
+ hrn = hostname_to_hrn(site_auth, site_name, node['hostname'])
if hrn not in existing_hrns or \
(hrn, 'node') not in existing_records:
sfaImporter.import_node(hrn, node)
(hrn, 'user') not in existing_records or update_record:
sfaImporter.import_person(site_hrn, person)
+
# remove stale records
+ system_records = [interface_hrn, root_auth, interface_hrn + '.slicemanager']
for (record_hrn, type) in existing_records.keys():
+ if record_hrn in system_records:
+ continue
+
record = existing_records[(record_hrn, type)]
- # if this is the interface name dont do anything
- if record_hrn == interface_hrn or \
- record_hrn == root_auth or \
- record['peer_authority']:
+ if record['peer_authority']:
continue
+
# don't delete vini's internet2 placeholder record
# normally this would be deleted because it does not have a plc record
if ".vini" in interface_hrn and interface_hrn.endswith('vini') and \
from sfa.trust.hierarchy import *
from sfa.util.record import *
from sfa.util.table import SfaTable
-from sfa.util.sfalogging import sfa_logger_goes_to_import,sfa_logger
+from sfa.util.sfalogging import logger
def main():
usage="%prog: trash the registry DB (the 'sfa' table in the 'planetlab5' database)"
parser = OptionParser(usage=usage)
parser.add_option('-f','--file-system',dest='clean_fs',action='store_true',default=False,
help='Clean up the /var/lib/sfa/authorities area as well')
+ parser.add_option('-c','--certs',dest='clean_certs',action='store_true',default=False,
+ help='Remove all cached certs/gids found in /var/lib/sfa/authorities area as well')
(options,args)=parser.parse_args()
if args:
parser.print_help()
sys.exit(1)
- sfa_logger_goes_to_import()
- sfa_logger().info("Purging SFA records from database")
+ logger.info("Purging SFA records from database")
table = SfaTable()
table.sfa_records_purge()
+
+ if options.clean_certs:
+ # remove the server certificate and all gids found in /var/lib/sfa/authorities
+ logger.info("Purging cached certificates")
+ for (dir, _, files) in os.walk('/var/lib/sfa/authorities'):
+ for file in files:
+ if file.endswith('.gid') or file == 'server.cert':
+ path=dir+os.sep+file
+ os.unlink(path)
+ if not os.path.exists(path):
+ logger.info("Unlinked file %s"%path)
+ else:
+ logger.error("Could not unlink file %s"%path)
+
if options.clean_fs:
# just remove all files that do not match 'server.key' or 'server.cert'
+ logger.info("Purging registry filesystem cache")
preserved_files = [ 'server.key', 'server.cert']
for (dir,_,files) in os.walk('/var/lib/sfa/authorities'):
for file in files:
path=dir+os.sep+file
os.unlink(path)
if not os.path.exists(path):
- sfa_logger().info("Unlinked file %s"%path)
+ logger.info("Unlinked file %s"%path)
else:
- sfa_logger().error("Could not unlink file %s"%path)
+ logger.error("Could not unlink file %s"%path)
if __name__ == "__main__":
main()
import sys
import tempfile
-from sfa.util.sfalogging import sfa_logger_goes_to_import,sfa_logger
+from sfa.util.sfalogging import _SfaLogger
from sfa.util.record import *
from sfa.util.table import SfaTable
from sfa.util.plxrn import email_to_hrn
from sfa.util.config import Config
from sfa.trust.certificate import convert_public_key, Keypair
-from sfa.trust.trustedroot import *
+from sfa.trust.trustedroots import TrustedRoots
from sfa.trust.hierarchy import *
from sfa.trust.gid import create_uuid
class sfaImport:
def __init__(self):
- sfa_logger_goes_to_import()
- self.logger = sfa_logger()
+ self.logger = _SfaLogger(logfile='/var/log/sfa_import.log', loggername='importlog')
self.AuthHierarchy = Hierarchy()
self.config = Config()
- self.TrustedRoots = TrustedRootList(Config.get_trustedroots_dir(self.config))
+ self.TrustedRoots = TrustedRoots(Config.get_trustedroots_dir(self.config))
self.plc_auth = self.config.get_plc_auth()
self.root_auth = self.config.SFA_REGISTRY_ROOT_AUTH
self.logger.info("Import: inserting authority record for %s"%hrn)
table.insert(auth_record)
+ def create_sm_client_record(self):
+ """
+ Create a user record for the Slicemanager service.
+ """
+ hrn = self.config.SFA_INTERFACE_HRN + '.slicemanager'
+ urn = hrn_to_urn(hrn, 'user')
+ if not self.AuthHierarchy.auth_exists(urn):
+ self.logger.info("Import: creating Slice Manager user")
+ self.AuthHierarchy.create_auth(urn)
+
+ auth_info = self.AuthHierarchy.get_auth_info(hrn)
+ table = SfaTable()
+ sm_user_record = table.find({'type': 'user', 'hrn': hrn})
+ if not sm_user_record:
+ record = SfaRecord(hrn=hrn, gid=auth_info.get_gid_object(), type="user", pointer=-1)
+ record['authority'] = get_authority(record['hrn'])
+ table.insert(record)
+
def create_interface_records(self):
"""
Create a record for each SFA interface
record = SfaRecord(hrn=interface_hrn, gid=gid, type=interface, pointer=-1)
record['authority'] = get_authority(interface_hrn)
table.insert(record)
+
+
def import_person(self, parent_hrn, person):
"""
Register a user record
# to planetlab
keys = self.shell.GetKeys(self.plc_auth, key_ids)
key = keys[0]['key']
- pkey = convert_public_key(key)
+ pkey = None
+ try:
+ pkey = convert_public_key(key)
+ except:
+ self.logger.warn('unable to convert public key for %s' % hrn)
if not pkey:
pkey = Keypair(create=True)
else:
# the user has no keys
- self.logger.warning("Import: person %s does not have a PL public key"%hrn)
+ self.logger.warn("Import: person %s does not have a PL public key"%hrn)
# if a key is unavailable, then we still need to put something in the
# user's GID. So make one up.
pkey = Keypair(create=True)
table.update(node_record)
- def import_site(self, parent_hrn, site):
+ def import_site(self, hrn, site):
shell = self.shell
plc_auth = self.plc_auth
- sitename = site['login_base']
- sitename = _cleanup_string(sitename)
- hrn = parent_hrn + "." + sitename
- # Hardcode 'internet2' into the hrn for sites hosting
- # internet2 nodes. This is a special operation for some vini
- # sites only
- if ".vini" in parent_hrn and parent_hrn.endswith('vini'):
- if sitename.startswith("i2"):
- #sitename = sitename.replace("ii", "")
- hrn = ".".join([parent_hrn, "internet2", sitename])
- elif sitename.startswith("nlr"):
- #sitename = sitename.replace("nlr", "")
- hrn = ".".join([parent_hrn, "internet2", sitename])
-
urn = hrn_to_urn(hrn, 'authority')
self.logger.info("Import: site %s"%hrn)
import sys
from types import StringTypes
-from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn, urn_to_hrn
-from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import *
+from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn, urn_to_hrn
+from sfa.util.plxrn import hrn_to_pl_slicename, hrn_to_pl_login_base
from sfa.util.specdict import *
from sfa.util.faults import *
from sfa.util.record import SfaRecord
from sfa.util.policy import Policy
+from sfa.plc.vlink import VLink
from sfa.util.prefixTree import prefixTree
+from collections import defaultdict
MAXINT = 2L**31-1
#filepath = path + os.sep + filename
self.policy = Policy(self.api)
self.origin_hrn = origin_hrn
+ self.registry = api.registries[api.hrn]
+ self.credential = api.getCredential()
+ self.nodes = []
+ self.persons = []
def get_slivers(self, xrn, node=None):
hrn, type = urn_to_hrn(xrn)
for peer_record in peers:
names = [name.lower() for name in peer_record.values() if isinstance(name, StringTypes)]
if site_authority in names:
- peer = peer_record['shortname']
+ peer = peer_record
return peer
if site_authority != self.api.hrn:
sfa_peer = site_authority
- return sfa_peer
+ return sfa_peer
- def verify_site(self, registry, credential, slice_hrn, peer, sfa_peer, reg_objects=None):
- authority = get_authority(slice_hrn)
- authority_urn = hrn_to_urn(authority, 'authority')
- login_base = None
- if reg_objects:
- site = reg_objects['site']
- login_base = site['login_base']
- else:
- site_records = registry.Resolve(authority_urn, [credential])
- site = {}
- for site_record in site_records:
- if site_record['type'] == 'authority':
- site = site_record
- if not site:
- raise RecordNotFound(authority)
-
- remote_site_id = site.pop('site_id')
+ def verify_slice_nodes(self, slice, requested_slivers, peer):
- if login_base is None:
- login_base = get_leaf(authority)
- sites = self.api.plshell.GetSites(self.api.plauth, login_base)
+ nodes = self.api.plshell.GetNodes(self.api.plauth, slice['node_ids'], ['hostname'])
+ current_slivers = [node['hostname'] for node in nodes]
- if not sites:
- site_id = self.api.plshell.AddSite(self.api.plauth, site)
- if peer:
- try:
- self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)
- except Exception,e:
- self.api.plshell.DeleteSite(self.api.plauth, site_id)
- raise e
- # mark this site as an sfa peer record
- if sfa_peer and not reg_objects:
- peer_dict = {'type': 'authority', 'hrn': authority, 'peer_authority': sfa_peer, 'pointer': site_id}
- registry.register_peer_object(credential, peer_dict)
+ # remove nodes not in rspec
+ deleted_nodes = list(set(current_slivers).difference(requested_slivers))
- # exempt federated sites from monitor policies
- self.api.plshell.AddSiteTag(site_id, 'exempt_site_until', "20200101")
-
- else:
- site_id = sites[0]['site_id']
- remote_site_id = sites[0]['peer_site_id']
- old_site = sites[0]
- #the site is already on the remote agg. Let us update(e.g. max_slices field) it with the latest info.
- self.sync_site(old_site, site, peer)
+ # add nodes from rspec
+ added_nodes = list(set(requested_slivers).difference(current_slivers))
+ try:
+ if peer:
+ self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', slice['slice_id'], peer['shortname'])
+ self.api.plshell.AddSliceToNodes(self.api.plauth, slice['name'], added_nodes)
+ self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slice['name'], deleted_nodes)
+
+ except:
+ self.api.logger.log_exc('Failed to add/remove slice from nodes')
+
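`verify_slice_nodes` reconciles current against requested hostnames with two set differences; the core bookkeeping, isolated as a pure helper (hypothetical name, not part of the SFA API):

```python
def diff_slivers(current_slivers, requested_slivers):
    # Nodes to delete: currently attached but absent from the rspec.
    deleted = sorted(set(current_slivers).difference(requested_slivers))
    # Nodes to add: requested in the rspec but not yet attached.
    added = sorted(set(requested_slivers).difference(current_slivers))
    return added, deleted
```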
+ def verify_slice_links(self, slice, links, peer=None):
+ if not links:
+ return
+ for link in links:
+ topo_rspec = VLink.get_topo_rspec(link)
+
- return (site_id, remote_site_id)
+ def handle_peer(self, site, slice, persons, peer):
+ if peer:
+ # bind site
+ try:
+ if site:
+ self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', \
+ site['site_id'], peer['shortname'], slice['site_id'])
+ except Exception,e:
+ self.api.plshell.DeleteSite(self.api.plauth, site['site_id'])
+ raise e
+
+ # bind slice
+ try:
+ if slice:
+ self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', \
+ slice['slice_id'], peer['shortname'], slice['slice_id'])
+ except Exception,e:
+ self.api.plshell.DeleteSlice(self.api.plauth, slice['slice_id'])
+ raise e
+
+ # bind persons
+ for person in persons:
+ try:
+ self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', \
+ person['person_id'], peer['shortname'], person['peer_person_id'])
+
+ for (key, remote_key_id) in zip(person['keys'], person['key_ids']):
+ try:
+ self.api.plshell.BindObjectToPeer(self.api.plauth, 'key',\
+ key['key_id'], peer['shortname'], remote_key_id)
+ except:
+ self.api.plshell.DeleteKey(self.api.plauth, key['key_id'])
+ self.api.logger("failed to bind key: %s to peer: %s " % (key['key_id'], peer['shortname']))
+ except Exception,e:
+ self.api.plshell.DeletePerson(self.api.plauth, person['person_id'])
+ raise e
- def verify_slice(self, registry, credential, slice_hrn, site_id, remote_site_id, peer, sfa_peer, reg_objects=None):
- slice = {}
- slice_record = None
- authority = get_authority(slice_hrn)
+ return slice
- if reg_objects:
- slice_record = reg_objects['slice_record']
- else:
- slice_records = registry.Resolve(slice_hrn, [credential])
-
- for record in slice_records:
- if record['type'] in ['slice']:
- slice_record = record
- if not slice_record:
- raise RecordNotFound(hrn)
+ def verify_site(self, slice_xrn, slice_record={}, peer=None, sfa_peer=None):
+ (slice_hrn, type) = urn_to_hrn(slice_xrn)
+ site_hrn = get_authority(slice_hrn)
+ # login base can't be longer than 20 characters
+ slicename = hrn_to_pl_slicename(slice_hrn)
+ authority_name = slicename.split('_')[0]
+ login_base = authority_name[:20]
+ sites = self.api.plshell.GetSites(self.api.plauth, login_base)
+ if not sites:
+ # create new site record
+ site = {'name': 'geni.%s' % authority_name,
+ 'abbreviated_name': authority_name,
+ 'login_base': login_base,
+ 'max_slices': 100,
+ 'max_slivers': 1000,
+ 'enabled': True,
+ 'peer_site_id': None}
+ if peer:
+ site['peer_site_id'] = slice_record.get('site_id', None)
+ site['site_id'] = self.api.plshell.AddSite(self.api.plauth, site)
+ # exempt federated sites from monitor policies
+ self.api.plshell.AddSiteTag(self.api.plauth, site['site_id'], 'exempt_site_until', "20200101")
+ # is this still necessary?
+ # add record to the local registry
+ if sfa_peer and slice_record:
+ peer_dict = {'type': 'authority', 'hrn': site_hrn, \
+ 'peer_authority': sfa_peer, 'pointer': site['site_id']}
+ self.registry.register_peer_object(self.credential, peer_dict)
+ else:
+ site = sites[0]
+ if peer:
+ # unbind from peer so we can modify if necessary. Will bind back later
+ self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site['site_id'], peer['shortname'])
+ return site
+
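`verify_site` derives a login_base from the slice hrn, capping it at PLCAPI's 20-character limit; the derivation, isolated for illustration (hypothetical helper name):

```python
def derive_login_base(slicename, max_len=20):
    # PlanetLab slice names are '<authority>_<slice>'; the authority
    # prefix doubles as the site's login_base, which can't be longer
    # than 20 characters.
    authority_name = slicename.split('_')[0]
    return authority_name[:max_len]
```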
+ def verify_slice(self, slice_hrn, slice_record, peer, sfa_peer):
slicename = hrn_to_pl_slicename(slice_hrn)
parts = slicename.split("_")
login_base = parts[0]
slices = self.api.plshell.GetSlices(self.api.plauth, [slicename])
if not slices:
- slice_fields = {}
- slice_keys = ['name', 'url', 'description']
- for key in slice_keys:
- if key in slice_record and slice_record[key]:
- slice_fields[key] = slice_record[key]
+ slice = {'name': slicename,
+ 'url': slice_record.get('url', slice_hrn),
+ 'description': slice_record.get('description', slice_hrn)}
# add the slice
- slice_id = self.api.plshell.AddSlice(self.api.plauth, slice_fields)
- slice = slice_fields
- slice['slice_id'] = slice_id
-
+ slice['slice_id'] = self.api.plshell.AddSlice(self.api.plauth, slice)
+ slice['node_ids'] = []
+ slice['person_ids'] = []
+ if peer:
+ slice['peer_slice_id'] = slice_record.get('slice_id', None)
# mark this slice as an sfa peer record
if sfa_peer:
- peer_dict = {'type': 'slice', 'hrn': slice_hrn, 'peer_authority': sfa_peer, 'pointer': slice_id}
- registry.register_peer_object(credential, peer_dict)
-
- #this belongs to a peer
- if peer:
- try:
- self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice_id, peer, slice_record['pointer'])
- except Exception,e:
- self.api.plshell.DeleteSlice(self.api.plauth,slice_id)
- raise e
- slice['node_ids'] = []
+ peer_dict = {'type': 'slice', 'hrn': slice_hrn,
+ 'peer_authority': sfa_peer, 'pointer': slice['slice_id']}
+ self.registry.register_peer_object(self.credential, peer_dict)
else:
slice = slices[0]
- slice_id = slice['slice_id']
- site_id = slice['site_id']
- #the slice is alredy on the remote agg. Let us update(e.g. expires field) it with the latest info.
- self.sync_slice(slice, slice_record, peer)
-
- slice['peer_slice_id'] = slice_record['pointer']
- self.verify_persons(registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer, reg_objects)
-
- return slice
-
- def verify_persons(self, registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer, reg_objects=None):
- # get the list of valid slice users from the registry and make
- # sure they are added to the slice
- slicename = hrn_to_pl_slicename(slice_record['hrn'])
- if reg_objects:
- researchers = reg_objects['users'].keys()
- else:
- researchers = slice_record.get('researcher', [])
- for researcher in researchers:
- if reg_objects:
- person_dict = reg_objects['users'][researcher]
- else:
- person_records = registry.Resolve(researcher, [credential])
- for record in person_records:
- if record['type'] in ['user'] and record['enabled']:
- person_record = record
- if not person_record:
- return 1
- person_dict = person_record
-
- local_person=False
if peer:
- peer_id = self.api.plshell.GetPeers(self.api.plauth, {'shortname': peer}, ['peer_id'])[0]['peer_id']
- persons = self.api.plshell.GetPersons(self.api.plauth, {'email': [person_dict['email']], 'peer_id': peer_id}, ['person_id', 'key_ids'])
- if not persons:
- persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])
- if persons:
- local_person=True
-
- else:
- persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])
-
- if not persons:
- person_id=self.api.plshell.AddPerson(self.api.plauth, person_dict)
- self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled' : True})
+ slice['peer_slice_id'] = slice_record.get('slice_id', None)
+ # unbind from peer so we can modify if necessary. Will bind back later
+ self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice',\
+ slice['slice_id'], peer['shortname'])
+ # Update the existing record (e.g. the expires field) with the latest info.
+ if slice_record and slice['expires'] != slice_record['expires']:
+ self.api.plshell.UpdateSlice(self.api.plauth, slice['slice_id'],\
+ {'expires' : slice_record['expires']})
+
+ return slice
+
+ #def get_existing_persons(self, users):
+ def verify_persons(self, slice_hrn, slice_record, users, peer, sfa_peer, append=True):
+ users_by_email = {}
+ users_by_site = defaultdict(list)
+
+ users_dict = {}
+ for user in users:
+ if 'append' in user and user['append'] == False:
+ append = False
+ if 'email' in user:
+ users_by_email[user['email']] = user
+ users_dict[user['email']] = user
+ elif 'urn' in user:
+ hrn, type = urn_to_hrn(user['urn'])
+ username = get_leaf(hrn)
+ login_base = get_leaf(get_authority(user['urn']))
+ user['username'] = username
+ users_by_site[login_base].append(user)
+
+ existing_user_ids = []
+ if users_by_email:
+ # get existing users by email
+ existing_users = self.api.plshell.GetPersons(self.api.plauth, \
+ {'email': users_by_email.keys()}, ['person_id', 'key_ids', 'email'])
+ existing_user_ids.extend([user['email'] for user in existing_users])
+
+ if users_by_site:
+ # get a list of user sites (based on requested user urns)
+ site_list = self.api.plshell.GetSites(self.api.plauth, users_by_site.keys(), \
+ ['site_id', 'login_base', 'person_ids'])
+ sites = {}
+ site_user_ids = []
+
+ # get all existing users at these sites
+ for site in site_list:
+ sites[site['site_id']] = site
+ site_user_ids.extend(site['person_ids'])
+
+ existing_site_persons_list = self.api.plshell.GetPersons(self.api.plauth, \
+ site_user_ids, ['person_id', 'key_ids', 'email', 'site_ids'])
+
+ # all requested users are either existing users or new (added) users
+ for login_base in users_by_site:
+ requested_site_users = users_by_site[login_base]
+ for requested_user in requested_site_users:
+ user_found = False
+ for existing_user in existing_site_persons_list:
+ for site_id in existing_user['site_ids']:
+ site = sites[site_id]
+ if login_base == site['login_base'] and \
+ existing_user['email'].startswith(requested_user['username']):
+ existing_user_ids.append(existing_user['email'])
+ users_dict[existing_user['email']] = requested_user
+ user_found = True
+ break
+ if user_found:
+ break
+
+ if user_found == False:
+ fake_email = requested_user['username'] + '@geni.net'
+ users_dict[fake_email] = requested_user
- # mark this person as an sfa peer record
- if sfa_peer:
- peer_dict = {'type': 'user', 'hrn': researcher, 'peer_authority': sfa_peer, 'pointer': person_id}
- registry.register_peer_object(credential, peer_dict)
- if peer:
- try:
- self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
- except Exception,e:
- self.api.plshell.DeletePerson(self.api.plauth,person_id)
- raise e
- key_ids = []
- else:
- person_id = persons[0]['person_id']
- key_ids = persons[0]['key_ids']
+ # requested slice users
+ requested_user_ids = users_dict.keys()
+ # existing slice users
+ existing_slice_users_filter = {'person_id': slice_record.get('person_ids', [])}
+ existing_slice_users = self.api.plshell.GetPersons(self.api.plauth, \
+ existing_slice_users_filter, ['person_id', 'key_ids', 'email'])
+ existing_slice_user_ids = [user['email'] for user in existing_slice_users]
+
+ # users to be added, removed or updated
+ added_user_ids = set(requested_user_ids).difference(existing_user_ids)
+ added_slice_user_ids = set(requested_user_ids).difference(existing_slice_user_ids)
+ removed_user_ids = set(existing_slice_user_ids).difference(requested_user_ids)
+ updated_user_ids = set(existing_slice_user_ids).intersection(requested_user_ids)
+
+ # Remove stale users (only if we are not appending).
+ if append == False:
+ for removed_user_id in removed_user_ids:
+ self.api.plshell.DeletePersonFromSlice(self.api.plauth, removed_user_id, slice_record['name'])
+ # update existing users
+ updated_users_list = [user for user in existing_slice_users if user['email'] in \
+ updated_user_ids]
+ self.verify_keys(existing_slice_users, updated_users_list, peer, append)
+
+ added_persons = []
+ # add new users
+ for added_user_id in added_user_ids:
+ added_user = users_dict[added_user_id]
+ hrn, type = urn_to_hrn(added_user['urn'])
+ person = {
+ 'first_name': added_user.get('first_name', hrn),
+ 'last_name': added_user.get('last_name', hrn),
+ 'email': added_user_id,
+ 'peer_person_id': None,
+ 'keys': [],
+ 'key_ids': added_user.get('key_ids', []),
+ }
+ person['person_id'] = self.api.plshell.AddPerson(self.api.plauth, person)
+ if peer:
+ person['peer_person_id'] = added_user['person_id']
+ added_persons.append(person)
+
+ # enable the account
+ self.api.plshell.UpdatePerson(self.api.plauth, person['person_id'], {'enabled': True})
+
+ # add person to site
+ self.api.plshell.AddPersonToSite(self.api.plauth, added_user_id, login_base)
+ for key_string in added_user.get('keys', []):
+ key = {'key':key_string, 'key_type':'ssh'}
+ key['key_id'] = self.api.plshell.AddPersonKey(self.api.plauth, person['person_id'], key)
+ person['keys'].append(key)
- # if this is a peer person, we must unbind them from the peer or PLCAPI will throw
- # an error
- try:
- if peer and not local_person:
- self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
- if peer:
- self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site_id, peer)
-
- self.api.plshell.AddPersonToSlice(self.api.plauth, person_dict['email'], slicename)
- self.api.plshell.AddPersonToSite(self.api.plauth, person_dict['email'], site_id)
- finally:
- if peer:
- try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)
- except: pass
- if peer and not local_person:
- try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
- except: pass
+ # add the registry record
+ if sfa_peer:
+ peer_dict = {'type': 'user', 'hrn': hrn, 'peer_authority': sfa_peer, \
+ 'pointer': person['person_id']}
+ self.registry.register_peer_object(self.credential, peer_dict)
+
+ for added_slice_user_id in added_slice_user_ids.union(added_user_ids):
+ # add person to the slice
+ self.api.plshell.AddPersonToSlice(self.api.plauth, added_slice_user_id, slice_record['name'])
+ # if this is a peer record then it should already be bound to a peer.
+ # no need to worry about it getting bound later
+
+ return added_persons
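The four user sets computed in `verify_persons` follow a standard reconcile pattern; isolated as a pure function (hypothetical name) the bookkeeping reads:

```python
def classify_users(requested, existing_anywhere, existing_in_slice):
    # Users that must be created in the database at all.
    added = set(requested).difference(existing_anywhere)
    # Users that must be attached to this slice (may already exist).
    added_to_slice = set(requested).difference(existing_in_slice)
    # Slice members no longer requested (removed only when not appending).
    removed = set(existing_in_slice).difference(requested)
    # Members kept: their keys get re-verified.
    updated = set(existing_in_slice).intersection(requested)
    return added, added_to_slice, removed, updated
```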
- self.verify_keys(registry, credential, person_dict, key_ids, person_id, peer, local_person)
- def verify_keys(self, registry, credential, person_dict, key_ids, person_id, peer, local_person):
- keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
- keys = [key['key'] for key in keylist]
+ def verify_keys(self, persons, users, peer, append=True):
+ # existing keys
+ key_ids = []
+ for person in persons:
+ key_ids.extend(person['key_ids'])
+ keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key_id', 'key'])
+ keydict = {}
+ for key in keylist:
+ keydict[key['key']] = key['key_id']
+ existing_keys = keydict.keys()
+ persondict = {}
+ for person in persons:
+ persondict[person['email']] = person
+
+ # add new keys
+ requested_keys = []
+ updated_persons = []
+ for user in users:
+ user_keys = user.get('keys', [])
+ updated_persons.append(user)
+ for key_string in user_keys:
+ requested_keys.append(key_string)
+ if key_string not in existing_keys:
+ key = {'key': key_string, 'key_type': 'ssh'}
+ try:
+ if peer:
+ person = persondict[user['email']]
+ self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person['person_id'], peer['shortname'])
+ key['key_id'] = self.api.plshell.AddPersonKey(self.api.plauth, user['email'], key)
+ if peer:
+ key_index = user_keys.index(key['key'])
+ remote_key_id = user['key_ids'][key_index]
+ self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key['key_id'], peer['shortname'], remote_key_id)
+
+ finally:
+ if peer:
+ self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person['person_id'], peer['shortname'], user['person_id'])
- #add keys that arent already there
- key_ids = person_dict['key_ids']
- for personkey in person_dict['keys']:
- if personkey not in keys:
- key = {'key_type': 'ssh', 'key': personkey}
- try:
- if peer:
- self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
- key_id = self.api.plshell.AddPersonKey(self.api.plauth, person_dict['email'], key)
- finally:
- if peer and not local_person:
- self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
- if peer:
- # xxx - thierry how are we getting the peer_key_id in here ?
- try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key_id, peer, key_ids.pop(0))
- except: pass
+ # remove old keys (only if we are not appending)
+ if append == False:
+ removed_keys = set(existing_keys).difference(requested_keys)
+ for existing_key_id in keydict:
+ if keydict[existing_key_id] in removed_keys:
+ try:
+ if peer:
+ self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'key', existing_key_id, peer['shortname'])
+ self.api.plshell.DeleteKey(self.api.plauth, existing_key_id)
+ except:
+ pass
+
+ def verify_slice_attributes(self, slice, requested_slice_attributes):
+ # get the list of attributes users are able to manage
+ slice_attributes = self.api.plshell.GetTagTypes(self.api.plauth, {'category': '*slice*', '|roles': ['user']})
+ valid_slice_attribute_names = [attribute['tagname'] for attribute in slice_attributes]
+
+ # get sliver attributes
+ added_slice_attributes = []
+ removed_slice_attributes = []
+ ignored_slice_attribute_names = []
+ existing_slice_attributes = self.api.plshell.GetSliceTags(self.api.plauth, {'slice_id': slice['slice_id']})
+
+ # get attributes that should be removed
+ for slice_tag in existing_slice_attributes:
+ if slice_tag['tagname'] not in valid_slice_attribute_names:
+ # If a slice already has an admin-only tag it was probably given to them by an
+ # admin, so we should ignore it.
+ ignored_slice_attribute_names.append(slice_tag['tagname'])
+ else:
+ # If an existing slice attribute was not found in the request it should
+ # be removed
+ attribute_found=False
+ for requested_attribute in requested_slice_attributes:
+ if requested_attribute['name'] == slice_tag['tagname'] and \
+ requested_attribute['value'] == slice_tag['value']:
+ attribute_found=True
+ break
+
+ if not attribute_found:
+ removed_slice_attributes.append(slice_tag)
+
+ # get attributes that should be added:
+ for requested_attribute in requested_slice_attributes:
+ # if the requested attribute wasn't found we should add it
+ if requested_attribute['name'] in valid_slice_attribute_names:
+ attribute_found = False
+ for existing_attribute in existing_slice_attributes:
+ if requested_attribute['name'] == existing_attribute['tagname'] and \
+ requested_attribute['value'] == existing_attribute['value']:
+ attribute_found=True
+ break
+ if not attribute_found:
+ added_slice_attributes.append(requested_attribute)
+
+
+ # remove stale attributes
+ for attribute in removed_slice_attributes:
+ try:
+ self.api.plshell.DeleteSliceTag(self.api.plauth, attribute['slice_tag_id'])
+ except Exception, e:
+ self.api.logger.warn('Failed to remove sliver attribute. name: %s, value: %s\nCause:%s'\
+ % (attribute['tagname'], attribute['value'], str(e)))
+
+ # add requested_attributes
+ for attribute in added_slice_attributes:
+ try:
+ name, value, node_id = attribute['name'], attribute['value'], attribute.get('node_id', None)
+ self.api.plshell.AddSliceTag(self.api.plauth, slice['name'], name, value, node_id)
+ except Exception, e:
+ self.api.logger.warn('Failed to add sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
+ % (name, value, node_id, str(e)))
def create_slice_aggregate(self, xrn, rspec):
hrn, type = urn_to_hrn(xrn)
return 1
- def sync_site(self, old_record, new_record, peer):
- if old_record['max_slices'] != new_record['max_slices'] or old_record['max_slivers'] != new_record['max_slivers']:
- try:
- if peer:
- self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', old_record['site_id'], peer)
- if old_record['max_slices'] != new_record['max_slices']:
- self.api.plshell.UpdateSite(self.api.plauth, old_record['site_id'], {'max_slices' : new_record['max_slices']})
- if old_record['max_slivers'] != new_record['max_slivers']:
- self.api.plshell.UpdateSite(self.api.plauth, old_record['site_id'], {'max_slivers' : new_record['max_slivers']})
- finally:
- if peer:
- self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', old_record['site_id'], peer, old_record['peer_site_id'])
- return 1
-
- def sync_slice(self, old_record, new_record, peer):
- if old_record['expires'] != new_record['expires']:
- try:
- if peer:
- self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', old_record['slice_id'], peer)
- self.api.plshell.UpdateSlice(self.api.plauth, old_record['slice_id'], {'expires' : new_record['expires']})
- finally:
- if peer:
- self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', old_record['slice_id'], peer, old_record['peer_slice_id'])
- return 1
--- /dev/null
+from sfa.plc.aggregate import Aggregate
+from sfa.managers.vini.topology import PhysicalLinks
+from sfa.rspecs.elements.link import Link
+from sfa.rspecs.elements.interface import Interface
+from sfa.util.xrn import hrn_to_urn
+from sfa.util.plxrn import PlXrn
+
+class ViniAggregate(Aggregate):
+
+ def prepare_links(self, force=False):
+ for (site_id1, site_id2) in PhysicalLinks:
+            if site_id1 not in self.sites or site_id2 not in self.sites:
+ continue
+ site1 = self.sites[site_id1]
+ site2 = self.sites[site_id2]
+ # get hrns
+ site1_hrn = self.api.hrn + '.' + site1['login_base']
+ site2_hrn = self.api.hrn + '.' + site2['login_base']
+ # get the first node
+            node1 = self.nodes[site1['node_ids'][0]]
+            node2 = self.nodes[site2['node_ids'][0]]
+
+ # set interfaces
+ # just get first interface of the first node
+ if1_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node1['node_id']))
+ if2_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node2['node_id']))
+
+ if1 = Interface({'component_id': if1_xrn.urn} )
+ if2 = Interface({'component_id': if2_xrn.urn} )
+
+ # set link
+ link = Link({'capacity': '1000000', 'latency': '0', 'packet_loss': '0', 'type': 'ipv4'})
+ link['interface1'] = if1
+ link['interface2'] = if2
+ link['component_name'] = "%s:%s" % (site1['login_base'], site2['login_base'])
+            link['component_id'] = PlXrn(auth=self.api.hrn, link=link['component_name']).urn
+ link['component_manager_id'] = hrn_to_urn(self.api.hrn, 'authority+am')
+ self.links[link['component_name']] = link
+
+
--- /dev/null
+
+import re
+
+from sfa.util.plxrn import PlXrn
+# Taken from bwlimit.py
+#
+# See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be
+# warned that older versions of tc interpret "kbps", "mbps", "mbit",
+# and "kbit" to mean (in this system) "kibps", "mibps", "mibit", and
+# "kibit" and that if an older version is installed, all rates will
+# be off by a small fraction.
+suffixes = {
+ "": 1,
+ "bit": 1,
+ "kibit": 1024,
+ "kbit": 1000,
+ "mibit": 1024*1024,
+ "mbit": 1000000,
+ "gibit": 1024*1024*1024,
+ "gbit": 1000000000,
+ "tibit": 1024*1024*1024*1024,
+ "tbit": 1000000000000,
+ "bps": 8,
+ "kibps": 8*1024,
+ "kbps": 8000,
+ "mibps": 8*1024*1024,
+ "mbps": 8000000,
+ "gibps": 8*1024*1024*1024,
+ "gbps": 8000000000,
+ "tibps": 8*1024*1024*1024*1024,
+ "tbps": 8000000000000
+}
+
+
+def get_tc_rate(s):
+ """
+ Parses an integer or a tc rate string (e.g., 1.5mbit) into bits/second
+ """
+
+ if type(s) == int:
+ return s
+ m = re.match(r"([0-9.]+)(\D*)", s)
+ if m is None:
+ return -1
+ suffix = m.group(2).lower()
+    if suffix in suffixes:
+ return int(float(m.group(1)) * suffixes[suffix])
+ else:
+ return -1
+
+def format_tc_rate(rate):
+ """
+ Formats a bits/second rate into a tc rate string
+ """
+
+ if rate >= 1000000000 and (rate % 1000000000) == 0:
+ return "%.0fgbit" % (rate / 1000000000.)
+ elif rate >= 1000000 and (rate % 1000000) == 0:
+ return "%.0fmbit" % (rate / 1000000.)
+ elif rate >= 1000:
+ return "%.0fkbit" % (rate / 1000.)
+ else:
+ return "%.0fbit" % rate
+
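The two helpers above are inverses up to rounding: `get_tc_rate` parses a tc rate string into bits/second and `format_tc_rate` renders it back. A self-contained sketch of the parsing half (Python 3 syntax and a reduced suffix table for brevity; the full table is in the patch above):

```python
import re

# Reduced suffix table for illustration; the patch defines the full set,
# including the kibi/mibi binary variants.
suffixes = {"": 1, "kbit": 1000, "mbit": 1000000, "gbit": 1000000000}

def get_tc_rate(s):
    """Parse an integer or a tc rate string (e.g. '1.5mbit') into bits/second."""
    if isinstance(s, int):
        return s
    m = re.match(r"([0-9.]+)(\D*)", s)
    if m is None:
        return -1  # no leading number at all
    suffix = m.group(2).lower()
    if suffix in suffixes:
        return int(float(m.group(1)) * suffixes[suffix])
    return -1  # unrecognized suffix

print(get_tc_rate("1.5mbit"))  # 1500000
print(get_tc_rate("1gbit"))    # 1000000000
print(get_tc_rate("oops"))     # -1
```

Note that `format_tc_rate(1500000)` falls through to the kbit branch and yields `1500kbit`, so a parse/format round trip is not always string-identical.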
+class VLink:
+ @staticmethod
+ def get_link_id(if1, if2):
+ if if1['id'] < if2['id']:
+ link = (if1['id']<<7) + if2['id']
+ else:
+ link = (if2['id']<<7) + if1['id']
+ return link
+
+ @staticmethod
+ def get_iface_id(if1, if2):
+ if if1['id'] < if2['id']:
+ iface = 1
+ else:
+ iface = 2
+ return iface
+
+    @staticmethod
+    def get_virt_ip(if1, if2):
+        link_id = VLink.get_link_id(if1, if2)
+        iface_id = VLink.get_iface_id(if1, if2)
+        first = link_id >> 6
+        second = ((link_id & 0x3f)<<2) + iface_id
+        return "192.168.%d.%s" % (first, second)
+
+    @staticmethod
+    def get_virt_net(link):
+        link_id = VLink.get_link_id(link['interface1'], link['interface2'])
+        first = link_id >> 6
+        second = (link_id & 0x3f)<<2
+        return "192.168.%d.%d/30" % (first, second)
+
+ @staticmethod
+ def get_interface_id(interface):
+ if_name = PlXrn(interface=interface['component_id']).interface_name()
+ node, dev = if_name.split(":")
+ node_id = int(node.replace("pc", ""))
+ return node_id
+
+
+ @staticmethod
+ def get_topo_rspec(link):
+ link['interface1']['id'] = VLink.get_interface_id(link['interface1'])
+ link['interface2']['id'] = VLink.get_interface_id(link['interface2'])
+ my_ip = VLink.get_virt_ip(link['interface1'], link['interface2'])
+ remote_ip = VLink.get_virt_ip(link['interface2'], link['interface1'])
+ net = VLink.get_virt_net(link)
+ bw = format_tc_rate(long(link['capacity']))
+        ipaddr = link['interface2']['ipv4']
+ return (link['interface2']['id'], ipaddr, bw, my_ip, remote_ip, net)
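The VLink addressing scheme packs the two interface ids into one integer (smaller id shifted left 7 bits, larger id in the low bits) and then derives a private /30 from it: the upper bits of the link id become the third octet, the low 6 bits shifted left by 2 pick the subnet, and the interface number (1 or 2) picks the host. A standalone sketch of the arithmetic (function names mirror the class but this is illustrative, not the patch code):

```python
def get_link_id(id1, id2):
    # pack two 7-bit interface ids into one integer, smaller id in the high bits
    lo, hi = sorted((id1, id2))
    return (lo << 7) + hi

def get_virt_ip(id1, id2):
    # the interface with the smaller id is interface 1 of the link
    iface_id = 1 if id1 < id2 else 2
    link_id = get_link_id(id1, id2)
    first = link_id >> 6                          # upper bits -> third octet
    second = ((link_id & 0x3f) << 2) + iface_id   # low 6 bits -> /30 base + host
    return "192.168.%d.%d" % (first, second)

# interfaces 1 and 2 share link id 130 -> the /30 at 192.168.2.8
print(get_virt_ip(1, 2))  # 192.168.2.9
print(get_virt_ip(2, 1))  # 192.168.2.10
```

Because the /30 base leaves two host addresses, the two endpoints of each virtual link get .9/.10, .13/.14, and so on.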
--- /dev/null
+from lxml import etree
+
+class Element:
+ def __init__(self, root_node, namespaces = None):
+ self.root_node = root_node
+ self.namespaces = namespaces
+
+ def xpath(self, xpath):
+        return self.root_node.xpath(xpath, namespaces=self.namespaces)
+
+ def add_element(self, name, attrs={}, parent=None, text=""):
+ """
+        Generic wrapper around etree.SubElement(). Adds an element to the
+        specified parent node. Adds the element to the root node if parent
+        is not specified.
+        """
+        if parent is None:
+ parent = self.root_node
+ element = etree.SubElement(parent, name)
+ if text:
+ element.text = text
+ if isinstance(attrs, dict):
+ for attr in attrs:
+ element.set(attr, attrs[attr])
+ return element
+
+ def remove_element(self, element_name, root_node = None):
+ """
+        Removes all occurrences of an element from the tree. Starts at
+        root_node if specified, otherwise at the tree's root.
+ """
+ if not root_node:
+ root_node = self.root_node
+
+ if not element_name.startswith('//'):
+ element_name = '//' + element_name
+
+        elements = root_node.xpath(element_name, namespaces=self.namespaces)
+ for element in elements:
+ parent = element.getparent()
+ parent.remove(element)
+
+
+ def add_attribute(self, elem, name, value):
+ """
+ Add attribute to specified etree element
+ """
+ opt = etree.SubElement(elem, name)
+ opt.text = value
+
+ def remove_attribute(self, elem, name, value):
+ """
+ Removes an attribute from an element
+ """
+        if elem is not None:
+ opts = elem.iterfind(name)
+ if opts is not None:
+ for opt in opts:
+ if opt.text == value:
+ elem.remove(opt)
+
+ def get_attributes(self, elem=None, depth=None):
+        if elem is None:
+ elem = self.root_node
+ attrs = dict(elem.attrib)
+ attrs['text'] = str(elem.text).strip()
+        if depth is None or (isinstance(depth, int) and depth > 0):
+            child_depth = depth if depth is None else depth - 1
+            for child_elem in list(elem):
+                key = str(child_elem.tag)
+                if key not in attrs:
+                    attrs[key] = [self.get_attributes(child_elem, child_depth)]
+                else:
+                    attrs[key].append(self.get_attributes(child_elem, child_depth))
+ return attrs
+
+ def attributes_list(self, elem):
+        # convert a list of attribute tags into a list of tuples
+        # (tagname, text_value)
+        opts = []
+        if elem is not None:
+ for e in elem:
+ opts.append((e.tag, e.text))
+ return opts
+
+
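`get_attributes` above flattens an element into a dict of its XML attributes, its stripped text, and (up to `depth` levels) lists of its children flattened the same way. A stdlib sketch of the same idea, using `xml.etree` instead of lxml and a free function instead of the class method:

```python
import xml.etree.ElementTree as etree

def get_attributes(elem, depth=None):
    # flatten one element: its XML attributes, stripped text, then children
    attrs = dict(elem.attrib)
    attrs['text'] = str(elem.text or '').strip()
    # depth=None means unlimited recursion; depth=0 stops here
    if depth is None or depth > 0:
        child_depth = None if depth is None else depth - 1
        for child in list(elem):
            attrs.setdefault(child.tag, []).append(get_attributes(child, child_depth))
    return attrs

root = etree.fromstring('<node id="n1"><hostname>pc1</hostname></node>')
print(get_attributes(root, depth=1))
# {'id': 'n1', 'text': '', 'hostname': [{'text': 'pc1'}]}
```

Children with the same tag accumulate into one list under that tag, which is why repeated elements don't clobber each other.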
--- /dev/null
+class Interface(dict):
+ fields = {'component_id': None,
+ 'role': None,
+ 'client_id': None,
+ 'ipv4': None
+ }
+ def __init__(self, fields={}):
+ dict.__init__(self, Interface.fields)
+ self.update(fields)
+
+
--- /dev/null
+from sfa.rspecs.elements.interface import Interface
+
+class Link(dict):
+
+ fields = {
+ 'client_id': None,
+ 'component_id': None,
+ 'component_name': None,
+ 'component_manager': None,
+ 'type': None,
+ 'interface1': None,
+ 'interface2': None,
+ 'capacity': None,
+ 'latency': None,
+ 'packet_loss': None,
+ 'description': None,
+ }
+
+ def __init__(self, fields={}):
+ dict.__init__(self, Link.fields)
+ self.update(fields)
+
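Both `Interface` and `Link` follow the same pattern: a dict subclass seeded from a class-level `fields` template, so every instance carries the full key set with `None` defaults and callers can overlay only the keys they have. A minimal sketch of the pattern with a reduced field set:

```python
class Link(dict):
    # class-level template: every instance starts with the full key set
    fields = {'component_id': None, 'capacity': None, 'latency': None}

    def __init__(self, fields={}):
        dict.__init__(self, Link.fields)  # seed the defaults
        self.update(fields)               # overlay caller-supplied values

link = Link({'capacity': '1000000'})
print(link['capacity'])     # prints 1000000 (kept as the string it was given)
print(link['latency'])      # None
print(sorted(link.keys()))  # ['capacity', 'component_id', 'latency']
```

The upshot is that downstream code can index any declared field without `KeyError` checks, at the cost of not catching typo'd keys at lookup time.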
--- /dev/null
+from sfa.rspecs.elements.element import Element
+from sfa.util.sfalogging import logger
+
+class Network(Element):
+
+ def get_networks(*args, **kwds):
+ logger.info("sfa.rspecs.networks: get_networks not implemented")
+
+ def add_networks(*args, **kwds):
+        logger.info("sfa.rspecs.networks: add_networks not implemented")
+
--- /dev/null
+from sfa.rspecs.elements.element import Element
+from sfa.util.faults import SfaNotImplemented
+from sfa.util.sfalogging import logger
+
+class Node(Element):
+
+ def get_nodes(*args):
+ logger.info("sfa.rspecs.nodes: get_nodes not implemented")
+
+ def add_nodes(*args):
+ logger.info("sfa.rspecs.nodes: add_nodes not implemented")
+
+
--- /dev/null
+from sfa.rspecs.elements.element import Element
+from sfa.util.sfalogging import logger
+
+class Slivers(Element):
+
+ def get_slivers(*args, **kwds):
+ logger.debug("sfa.rspecs.slivers: get_slivers not implemented")
+
+ def add_slivers(*args, **kwds):
+ logger.debug("sfa.rspecs.slivers: add_slivers not implemented")
+
+ def remove_slivers(*args, **kwds):
+ logger.debug("sfa.rspecs.slivers: remove_slivers not implemented")
+
+ def get_sliver_defaults(*args, **kwds):
+ logger.debug("sfa.rspecs.slivers: get_sliver_defaults not implemented")
+
+ def add_default_sliver_attribute(*args, **kwds):
+ logger.debug("sfa.rspecs.slivers: add_default_sliver_attributes not implemented")
+
+ def add_sliver_attribute(*args, **kwds):
+ logger.debug("sfa.rspecs.slivers: add_sliver_attribute not implemented")
+
+ def remove_default_sliver_attribute(*args, **kwds):
+ logger.debug("sfa.rspecs.slivers: remove_default_sliver_attributes not implemented")
+
+ def remove_sliver_attribute(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: remove_sliver_attribute not implemented")
+
--- /dev/null
+from lxml import etree
+from sfa.rspecs.elements.link import Link
+from sfa.rspecs.elements.interface import Interface
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements
+
+class PGv2Link:
+
+ elements = {
+ 'link': RSpecElement(RSpecElements.LINK, '//default:link | //link'),
+ 'component_manager': RSpecElement(RSpecElements.COMPONENT_MANAGER, './default:component_manager | ./component_manager')
+ }
+
+ @staticmethod
+ def add_links(xml, links):
+ for link in links:
+ link_elem = etree.SubElement(xml, 'link')
+ for attrib in ['component_name', 'component_id', 'client_id']:
+ if attrib in link and link[attrib]:
+ link_elem.set(attrib, link[attrib])
+            if 'component_manager' in link and link['component_manager']:
+                cm_element = etree.SubElement(link_elem, 'component_manager', name=link['component_manager'])
+            for if_ref in [link['interface1'], link['interface2']]:
+                if_ref_elem = etree.SubElement(link_elem, 'interface_ref')
+                for attrib in Interface.fields:
+                    if attrib in if_ref and if_ref[attrib]:
+                        if_ref_elem.attrib[attrib] = if_ref[attrib]
+            prop1 = etree.SubElement(link_elem, 'property', source_id = link['interface1']['component_id'],
+                dest_id = link['interface2']['component_id'], capacity=link['capacity'],
+                latency=link['latency'], packet_loss=link['packet_loss'])
+            prop2 = etree.SubElement(link_elem, 'property', source_id = link['interface2']['component_id'],
+                dest_id = link['interface1']['component_id'], capacity=link['capacity'],
+                latency=link['latency'], packet_loss=link['packet_loss'])
+            if 'type' in link and link['type']:
+                type_elem = etree.SubElement(link_elem, 'link_type', name=link['type'])
+
+ @staticmethod
+ def get_links(xml, namespaces=None):
+ links = []
+ link_elems = xml.xpath('//default:link', namespaces=namespaces)
+ for link_elem in link_elems:
+ # set client_id, component_id, component_name
+ link = Link(link_elem.attrib)
+ # set component manager
+ cm = link_elem.xpath('./default:component_manager', namespaces=namespaces)
+ if len(cm) > 0:
+ cm = cm[0]
+ if 'name' in cm.attrib:
+ link['component_manager'] = cm.attrib['name']
+ # set link type
+ link_types = link_elem.xpath('./default:link_type', namespaces=namespaces)
+ if len(link_types) > 0:
+ link_type = link_types[0]
+ if 'name' in link_type.attrib:
+ link['type'] = link_type.attrib['name']
+
+ # get capacity, latency and packet_loss from first property
+ props = link_elem.xpath('./default:property', namespaces=namespaces)
+ if len(props) > 0:
+ prop = props[0]
+ for attrib in ['capacity', 'latency', 'packet_loss']:
+ if attrib in prop.attrib:
+ link[attrib] = prop.attrib[attrib]
+
+ # get interfaces
+ if_elems = link_elem.xpath('./default:interface_ref', namespaces=namespaces)
+ ifs = []
+ for if_elem in if_elems:
+ if_ref = Interface(if_elem.attrib)
+ ifs.append(if_ref)
+ if len(ifs) > 1:
+ link['interface1'] = ifs[0]
+ link['interface2'] = ifs[1]
+ links.append(link)
+ return links
+
+++ /dev/null
-#!/usr/bin/python
-from lxml import etree
-from StringIO import StringIO
-from sfa.rspecs.rspec import RSpec
-from sfa.util.xrn import *
-from sfa.util.plxrn import hostname_to_urn
-from sfa.util.config import Config
-from sfa.rspecs.rspec_version import RSpecVersion
-
-_ad_version = {'type': 'ProtoGENI',
- 'version': '2',
- 'schema': 'http://www.protogeni.net/resources/rspec/2/ad.xsd',
- 'namespace': 'http://www.protogeni.net/resources/rspec/2',
- 'extensions': [
- 'http://www.protogeni.net/resources/rspec/ext/gre-tunnel/1',
- 'http://www.protogeni.net/resources/rspec/ext/other-ext/3'
- ]
-}
-
-_request_version = {'type': 'ProtoGENI',
- 'version': '2',
- 'schema': 'http://www.protogeni.net/resources/rspec/2/request.xsd',
- 'namespace': 'http://www.protogeni.net/resources/rspec/2',
- 'extensions': [
- 'http://www.protogeni.net/resources/rspec/ext/gre-tunnel/1',
- 'http://www.protogeni.net/resources/rspec/ext/other-ext/3'
- ]
-}
-pg_rspec_ad_version = RSpecVersion(_ad_version)
-pg_rspec_request_version = RSpecVersion(_request_version)
-
-class PGRSpec(RSpec):
- xml = None
- header = '<?xml version="1.0"?>\n'
- template = """<rspec xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.protogeni.net/resources/rspec/2" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/%(rspec_type)s.xsd"></rspec>"""
-
- def __init__(self, rspec="", namespaces={}, type=None):
- if not type:
- type = 'advertisement'
- self.type = type
-
- if type == 'advertisement':
- self.version = pg_rspec_ad_version
- rspec_type = 'ad'
- else:
- self.version = pg_rspec_request_version
- rspec_type = type
-
- self.template = self.template % {'rspec_type': rspec_type}
-
- if not namespaces:
- self.namespaces = {'rspecv2': self.version['namespace']}
- else:
- self.namespaces = namespaces
-
- if rspec:
- self.parse_rspec(rspec, self.namespaces)
- else:
- self.create()
-
- def create(self):
- RSpec.create(self)
- if self.type:
- self.xml.set('type', self.type)
-
- def get_network(self):
- network = None
- nodes = self.xml.xpath('//rspecv2:node[@component_manager_uuid][1]', namespaces=self.namespaces)
- if nodes:
- network = nodes[0].get('component_manager_uuid')
- return network
-
- def get_networks(self):
- networks = self.xml.xpath('//rspecv2:node[@component_manager_uuid]/@component_manager_uuid', namespaces=self.namespaces)
- return set(networks)
-
- def get_node_elements(self):
- nodes = self.xml.xpath('//rspecv2:node | //node', namespaces=self.namespaces)
- return nodes
-
- def get_nodes(self, network=None):
- xpath = '//rspecv2:node[@component_name]/@component_name | //node[@component_name]/@component_name'
- return self.xml.xpath(xpath, namespaces=self.namespaces)
-
- def get_nodes_with_slivers(self, network=None):
- if network:
- return self.xml.xpath('//rspecv2:node[@component_manager_id="%s"][sliver_type]/@component_name' % network, namespaces=self.namespaces)
- else:
- return self.xml.xpath('//rspecv2:node[rspecv2:sliver_type]/@component_name', namespaces=self.namespaces)
-
- def get_nodes_without_slivers(self, network=None):
- pass
-
- def add_nodes(self, nodes, check_for_dupes=False):
- if not isinstance(nodes, list):
- nodes = [nodes]
- for node in nodes:
- urn = ""
- if check_for_dupes and \
- self.xml.xpath('//rspecv2:node[@component_uuid="%s"]' % urn, namespaces=self.namespaces):
- # node already exists
- continue
-
- node_tag = etree.SubElement(self.xml, 'node', exclusive='false')
- if 'network_urn' in node:
- node_tag.set('component_manager_id', node['network_urn'])
- if 'urn' in node:
- node_tag.set('component_id', node['urn'])
- if 'hostname' in node:
- node_tag.set('component_name', node['hostname'])
- # TODO: should replace plab-pc with pc model
- node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='plab-pc')
- node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='pc')
- available_tag = etree.SubElement(node_tag, 'available', now='true')
- location_tag = etree.SubElement(node_tag, 'location', country="us")
- if 'site' in node:
- if 'longitude' in node['site']:
- location_tag.set('longitude', str(node['site']['longitude']))
- if 'latitude' in node['site']:
- location_tag.set('latitude', str(node['site']['latitude']))
- #if 'interfaces' in node:
-
-
- def add_slivers(self, slivers, sliver_urn=None, no_dupes=False):
- slivers = self._process_slivers(slivers)
- nodes_with_slivers = self.get_nodes_with_slivers()
- for sliver in slivers:
- hostname = sliver['hostname']
- if hostname in nodes_with_slivers:
- continue
- nodes = self.xml.xpath('//rspecv2:node[@component_name="%s"] | //node[@component_name="%s"]' % (hostname, hostname), namespaces=self.namespaces)
- if nodes:
- node = nodes[0]
- node.set('client_id', hostname)
- if sliver_urn:
- node.set('sliver_id', sliver_urn)
- etree.SubElement(node, 'sliver_type', name='plab-vnode')
-
- def add_interfaces(self, interfaces, no_dupes=False):
- pass
-
- def add_links(self, links, no_dupes=False):
- pass
-
-
- def merge(self, in_rspec):
- """
- Merge contents for specified rspec with current rspec
- """
-
- # just copy over all the child elements under the root element
- tree = etree.parse(StringIO(in_rspec))
- root = tree.getroot()
- for child in root.getchildren():
- self.xml.append(child)
-
- def cleanup(self):
- # remove unncecessary elements, attributes
- if self.type in ['request', 'manifest']:
- # remove nodes without slivers
- nodes = self.get_node_elements()
- for node in nodes:
- delete = True
- hostname = node.get('component_name')
- parent = node.getparent()
- children = node.getchildren()
- for child in children:
- if child.tag.endswith('sliver_type'):
- delete = False
- if delete:
- parent.remove(node)
-
- # remove 'available' element from remaining node elements
- self.remove_element('//rspecv2:available | //available')
-
-if __name__ == '__main__':
- rspec = PGRSpec()
- rspec.add_nodes([1])
- print rspec
from lxml import etree
from StringIO import StringIO
from sfa.util.xrn import *
-from sfa.rspecs.pg_rspec import PGRSpec
-from sfa.rspecs.sfa_rspec import SfaRSpec
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
xslt='''<xsl:stylesheet version="1.0" xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:output method="xml" indent="no"/>
class PGRSpecConverter:
@staticmethod
- def to_sfa_rspec(rspec):
- if isinstance(rspec, PGRSpec):
+ def to_sfa_rspec(rspec, content_type = None):
+ if not isinstance(rspec, RSpec):
+ pg_rspec = RSpec(rspec)
+ else:
pg_rspec = rspec
- else:
- pg_rspec = PGRSpec(rspec=rspec)
- sfa_rspec = SfaRSpec()
+
+ version_manager = VersionManager()
+ sfa_version = version_manager._get_version('sfa', '1')
+ sfa_rspec = RSpec(version=sfa_version)
# get network
- network_urn = pg_rspec.get_network()
+ network_urn = pg_rspec.version.get_network()
network, _ = urn_to_hrn(network_urn)
- network_element = sfa_rspec.add_element('network', {'name': network, 'id': network})
+ network_element = sfa_rspec.xml.add_element('network', {'name': network, 'id': network})
# get nodes
- pg_nodes_elements = pg_rspec.get_node_elements()
- nodes_with_slivers = pg_rspec.get_nodes_with_slivers()
+ pg_nodes_elements = pg_rspec.version.get_node_elements()
+ nodes_with_slivers = pg_rspec.version.get_nodes_with_slivers()
i = 1
for pg_node_element in pg_nodes_elements:
- node_element = sfa_rspec.add_element('node', {'id': 'n'+str(i)}, parent=network_element)
- urn = pg_node_element.xpath('@component_uuid | @component_id')
+ attribs = dict(pg_node_element.attrib.items())
+ attribs['id'] = 'n'+str(i)
+
+ node_element = sfa_rspec.xml.add_element('node', attribs, parent=network_element)
+ urn = pg_node_element.xpath('@component_id', namespaces=pg_rspec.namespaces)
if urn:
urn = urn[0]
hostname = Xrn.urn_split(urn)[-1]
- hostname_element = sfa_rspec.add_element('hostname', parent=node_element, text=hostname)
+ hostname_element = sfa_rspec.xml.add_element('hostname', parent=node_element, text=hostname)
if hostname in nodes_with_slivers:
- sfa_rspec.add_element('sliver', parent=node_element)
+ sfa_rspec.xml.add_element('sliver', parent=node_element)
- urn_element = sfa_rspec.add_element('urn', parent=node_element, text=urn)
+ urn_element = sfa_rspec.xml.add_element('urn', parent=node_element, text=urn)
# just copy over remaining child elements
--- /dev/null
+#
+## Extension for the "initscript" type for RSpecV2 on PlanetLab
+## Version 1
+##
+
+default namespace = "http://www.planet-lab.org/resources/ext/initscript/1"
+
+Node = element initscript {
+ attribute name { text }
+}
+
+start = Node
+
--- /dev/null
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+
+ Extension for the "initscript" type for RSpecV2 on PlanetLab
+ Version 1
+
+-->
+<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema"
+elementFormDefault="qualified"
+targetNamespace="http://www.planet-lab.org/resources/sfa/ext/planetlab/1"
+xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1">
+ <xs:element name="initscript">
+ <xs:complexType>
+ <xs:attribute name="name" use="required"/>
+ </xs:complexType>
+ </xs:element>
+</xs:schema>
#!/usr/bin/python
-from lxml import etree
-from StringIO import StringIO
from datetime import datetime, timedelta
+from sfa.util.xml import XML, XpathFilter
+from sfa.rspecs.version_manager import VersionManager
from sfa.util.xrn import *
from sfa.util.plxrn import hostname_to_urn
-from sfa.util.faults import SfaNotImplemented, InvalidRSpec
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements
+from sfa.util.faults import SfaNotImplemented, InvalidRSpec, InvalidRSpecElement
class RSpec:
- header = '<?xml version="1.0"?>\n'
- template = """<RSpec></RSpec>"""
- xml = None
- type = None
- version = None
- namespaces = None
- user_options = {}
-
- def __init__(self, rspec="", namespaces={}, type=None, user_options={}):
- self.type = type
+
+ def __init__(self, rspec="", version=None, user_options={}):
+ self.header = '<?xml version="1.0"?>\n'
+ self.template = """<RSpec></RSpec>"""
+ self.version = None
+ self.xml = XML()
+ self.version_manager = VersionManager()
self.user_options = user_options
+ self.elements = {}
if rspec:
- self.parse_rspec(rspec, namespaces)
+ self.parse_xml(rspec)
else:
- self.create()
+ self.create(version)
- def create(self):
+ def create(self, version=None):
"""
Create root element
"""
+ self.version = self.version_manager.get_version(version)
+ self.namespaces = self.version.namespaces
+ self.parse_xml(self.version.template)
# eg. 2011-03-23T19:53:28Z
date_format = '%Y-%m-%dT%H:%M:%SZ'
now = datetime.utcnow()
generated_ts = now.strftime(date_format)
expires_ts = (now + timedelta(hours=1)).strftime(date_format)
- self.parse_rspec(self.template, self.namespaces)
self.xml.set('expires', expires_ts)
self.xml.set('generated', generated_ts)
-
- def parse_rspec(self, rspec, namespaces={}):
- """
- parse rspec into etree
- """
- parser = etree.XMLParser(remove_blank_text=True)
- try:
- tree = etree.parse(rspec, parser)
- except IOError:
- # 'rspec' file doesnt exist. 'rspec' is proably an xml string
- try:
- tree = etree.parse(StringIO(rspec), parser)
- except:
- raise InvalidRSpec('Must specify a xml file or xml string. Received: ' + rspec )
- self.xml = tree.getroot()
- if namespaces:
- self.namespaces = namespaces
- def xpath(self, xpath):
- return this.xml.xpath(xpath, namespaces=self.namespaces)
- def add_attribute(self, elem, name, value):
- """
- Add attribute to specified etree element
- """
- opt = etree.SubElement(elem, name)
- opt.text = value
+ def parse_xml(self, xml):
+ self.xml.parse_xml(xml)
+ self.version = None
+ if self.xml.schema:
+ self.version = self.version_manager.get_version_by_schema(self.xml.schema)
+ else:
+ #raise InvalidRSpec('unknown rspec schema: %s' % schema)
+ # TODO: Should start raising an exception once SFA defines a schema.
+ # for now we just use the default
+ self.version = self.version_manager.get_version()
+ self.version.xml = self.xml
+ self.namespaces = self.xml.namespaces
+
+ def load_rspec_elements(self, rspec_elements):
+ self.elements = {}
+ for rspec_element in rspec_elements:
+ if isinstance(rspec_element, RSpecElement):
+ self.elements[rspec_element.type] = rspec_element
- def add_element(self, name, attrs={}, parent=None, text=""):
- """
- Generic wrapper around etree.SubElement(). Adds an element to
- specified parent node. Adds element to root node is parent is
- not specified.
- """
- if parent == None:
- parent = self.xml
- element = etree.SubElement(parent, name)
- if text:
- element.text = text
- if isinstance(attrs, dict):
- for attr in attrs:
- element.set(attr, attrs[attr])
- return element
+ def register_rspec_element(self, element_type, element_name, element_path):
+ if element_type not in RSpecElements:
+ raise InvalidRSpecElement(element_type, extra="no such element type: %s. Must specify a valid RSpecElement" % element_type)
+        self.elements[element_type] = RSpecElement(element_type, element_path)
- def remove_attribute(self, elem, name, value):
- """
- Removes an attribute from an element
- """
- if elem is not None:
- opts = elem.iterfind(name)
- if opts is not None:
- for opt in opts:
- if opt.text == value:
- elem.remove(opt)
+ def get_rspec_element(self, element_type):
+ if element_type not in self.elements:
+            msg = "ElementType %s not registered for this rspec" % element_type
+ raise InvalidRSpecElement(element_type, extra=msg)
+ return self.elements[element_type]
- def remove_element(self, element_name, root_node = None):
+ def get(self, element_type, filter={}, depth=0):
+ elements = self.get_elements(element_type, filter)
+ elements = [self.get_element_attributes(element, depth=depth) for element in elements]
+ return elements
+
+ def get_elements(self, element_type, filter={}):
"""
- Removes all occurences of an element from the tree. Start at
- specified root_node if specified, otherwise start at tree's root.
+ search for a registered element
"""
- if not root_node:
- root_node = self.xml
-
- if not element_name.startswith('//'):
- element_name = '//' + element_name
-
- elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
- for element in elements:
- parent = element.getparent()
- parent.remove(element)
-
+ if element_type not in self.elements:
+            msg = "Unable to search for element %s in rspec, xpath expression not found." % \
+                  element_type
+ raise InvalidRSpecElement(element_type, extra=msg)
+ rspec_element = self.get_rspec_element(element_type)
+ xpath = rspec_element.path + XpathFilter.xpath(filter)
+ return self.xpath(xpath)
def merge(self, in_rspec):
- pass
+ self.version.merge(in_rspec)
- def validate(self, schema):
- """
- Validate against rng schema
- """
+ def filter(self, filter):
+ if 'component_manager_id' in filter:
+ nodes = self.version.get_node_elements()
+ for node in nodes:
+ if 'component_manager_id' not in node.attrib or \
+ node.attrib['component_manager_id'] != filter['component_manager_id']:
+ parent = node.getparent()
+ parent.remove(node)
- relaxng_doc = etree.parse(schema)
- relaxng = etree.RelaxNG(relaxng_doc)
- if not relaxng(self.xml):
- error = relaxng.error_log.last_error
- message = "%s (line %s)" % (error.message, error.line)
- raise InvalidRSpec(message)
- return True
-
- def cleanup(self):
- """
- Optional method which inheriting classes can choose to implent.
- """
- pass
- def _process_slivers(self, slivers):
- """
- Creates a dict of sliver details for each sliver host
-
- @param slivers a single hostname, list of hostanmes or list of dicts keys on hostname,
- Returns a list of dicts
- """
- if not isinstance(slivers, list):
- slivers = [slivers]
- dicts = []
- for sliver in slivers:
- if isinstance(sliver, dict):
- dicts.append(sliver)
- elif isinstance(sliver, basestring):
- dicts.append({'hostname': sliver})
- return dicts
-
- def __str__(self):
- return self.toxml()
+ def toxml(self, header=True):
+ if header:
+ return self.header + self.xml.toxml()
+ else:
+ return self.xml.toxml()
+
- def toxml(self, cleanup=False):
- if cleanup:
- self.cleanup()
- return self.header + etree.tostring(self.xml, pretty_print=True)
-
def save(self, filename):
- f = open(filename, 'w')
- f.write(self.toxml())
- f.close()
-
+ return self.xml.save(filename)
+
if __name__ == '__main__':
- rspec = RSpec()
+ rspec = RSpec('/tmp/resources.rspec')
print rspec
+ rspec.register_rspec_element(RSpecElements.NETWORK, 'network', '//network')
+ rspec.register_rspec_element(RSpecElements.NODE, 'node', '//node')
+ print rspec.get(RSpecElements.NODE)[0]
+ print rspec.get(RSpecElements.NODE, depth=1)[0]
+
from sfa.rspecs.pg_rspec_converter import PGRSpecConverter
from sfa.rspecs.sfa_rspec_converter import SfaRSpecConverter
-from sfa.rspecs.sfa_rspec import sfa_rspec_version
-from sfa.rspecs.pg_rspec import pg_rspec_ad_version, pg_rspec_request_version
-from sfa.rspecs.rspec_parser import parse_rspec
-
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
class RSpecConverter:
@staticmethod
- def to_sfa_rspec(in_rspec):
- rspec = parse_rspec(in_rspec)
- if rspec.version['type'] == sfa_rspec_version['type']:
+ def to_sfa_rspec(in_rspec, content_type=None):
+ rspec = RSpec(in_rspec)
+ version_manager = VersionManager()
+ sfa_version = version_manager._get_version('sfa', '1')
+ pg_version = version_manager._get_version('protogeni', '2')
+ if rspec.version.type.lower() == sfa_version.type.lower():
return in_rspec
- elif rspec.version['type'] == pg_rspec_ad_version['type']:
- return PGRSpecConverter.to_sfa_rspec(in_rspec)
+ elif rspec.version.type.lower() == pg_version.type.lower():
+ return PGRSpecConverter.to_sfa_rspec(in_rspec, content_type)
else:
- return in_rspec
+ return in_rspec
@staticmethod
- def to_pg_rspec(in_rspec):
- rspec = parse_rspec(in_rspec)
- if rspec.version['type'] == pg_rspec_ad_version['type']:
+ def to_pg_rspec(in_rspec, content_type=None):
+ rspec = RSpec(in_rspec)
+ version_manager = VersionManager()
+ sfa_version = version_manager._get_version('sfa', '1')
+ pg_version = version_manager._get_version('protogeni', '2')
+
+ if rspec.version.type.lower() == pg_version.type.lower():
return in_rspec
- elif rspec.version['type'] == sfa_rspec_version['type']:
- return SfaRSpecConverter.to_pg_rspec(in_rspec)
+ elif rspec.version.type.lower() == sfa_version.type.lower():
+ return SfaRSpecConverter.to_pg_rspec(in_rspec, content_type)
else:
return in_rspec
--- /dev/null
+from sfa.util.enumeration import Enum
+# assumed to be defined alongside the other fault classes
+from sfa.util.faults import InvalidRSpecElement
+
+# recognized top level rspec elements
+RSpecElements = Enum(NETWORK='NETWORK',
+ COMPONENT_MANAGER='COMPONENT_MANAGER',
+ SLIVER='SLIVER',
+ NODE='NODE',
+ INTERFACE='INTERFACE',
+ LINK='LINK',
+ SERVICE='SERVICE'
+ )
+
+class RSpecElement:
+ def __init__(self, element_type, path):
+        if element_type not in RSpecElements:
+ raise InvalidRSpecElement(element_type)
+ self.type = element_type
+ self.path = path
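For context: `sfa.util.enumeration.Enum` is not part of this diff. A hypothetical minimal stand-in, consistent with how `rspec_elements.py` uses it (attribute access plus `in` membership tests), might look like:

```python
# Hypothetical stand-in for sfa.util.enumeration.Enum (not shown in this
# diff): keyword names become dict keys and instance attributes, so both
# RSpecElements.NODE and "'NODE' in RSpecElements" work.
class Enum(dict):
    def __init__(self, **kwargs):
        dict.__init__(self, **kwargs)
        self.__dict__.update(kwargs)

RSpecElements = Enum(NETWORK='NETWORK', NODE='NODE', LINK='LINK')
```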
+++ /dev/null
-#!/usr/bin/python
-from sfa.rspecs.sfa_rspec import SfaRSpec
-from sfa.rspecs.pg_rspec import PGRSpec
-from sfa.rspecs.rspec import RSpec
-from lxml import etree
-
-def parse_rspec(in_rspec):
- rspec = RSpec(rspec=in_rspec)
- # really simple check
- # TODO: check against schema instead
- out_rspec = None
- if rspec.xml.xpath('//network'):
- #out_rspec = SfaRSpec(in_rspec)
- out_rspec = SfaRSpec()
- out_rspec.xml = rspec.xml
- else:
- #out_rspec = PGRSpec(in_rspec)
- # TODO: determine if this is an ad or request
- out_rspec = PGRSpec()
- out_rspec.xml = rspec.xml
- return out_rspec
-
-if __name__ == '__main__':
-
- print "Parsing SFA RSpec:",
- rspec = parse_rspec('nodes.rspec')
- print rspec.version
- rspec = parse_rspec('protogeni.rspec')
- print "Parsing ProtoGENI RSpec:",
- print rspec.version
-
-
-
#!/usr/bin/python
-from sfa.util.sfalogging import _SfaLogger
+from sfa.util.sfalogging import logger
-class RSpecVersion(dict):
+class BaseVersion:
+ type = None
+ content_type = None
+ version = None
+ schema = None
+ namespace = None
+ extensions = {}
+ namespaces = dict(extensions.items() + [('default', namespace)])
+ elements = []
+ enabled = False
- fields = {'type': None,
- 'version': None,
- 'schema': None,
- 'namespace': None,
- 'extensions': []
- }
- def __init__(self, version={}):
-
- self.logger = _SfaLogger('/var/log/sfa.log')
- dict.__init__(self, self.fields)
-
- if not version:
- from sfa.rspecs.sfa_rspec import sfa_rspec_version
- self.update(sfa_rspec_version)
- elif isinstance(version, dict):
- self.update(version)
- elif isinstance(version, basestring):
- version_parts = version.split(' ')
- num_parts = len(version_parts)
- self['type'] = version_parts[0]
- if num_parts > 1:
- self['version'] = version_parts[1]
- else:
- logger.info("Unable to parse rspec version, using default")
+ def __init__(self, xml=None):
+ self.xml = xml
- def get_version_name(self):
- return "%s %s" % (str(self['type']), str(self['version']))
-
-if __name__ == '__main__':
+ def to_dict(self):
+ return {
+ 'type': self.type,
+ 'version': self.version,
+ 'schema': self.schema,
+ 'namespace': self.namespace,
+ 'extensions': self.extensions
+ }
- from sfa.rspecs.pl_rspec_version import ad_rspec_versions
- for version in [RSpecVersion(),
- RSpecVersion("SFA"),
- RSpecVersion("SFA 1"),
- RSpecVersion(ad_rspec_versions[0])]:
- print version.get_version_name() + ": " + str(version)
+ def to_string(self):
+ return "%s %s" % (self.type, self.version)
+
from lxml import etree
from StringIO import StringIO
from sfa.util.xrn import *
-from sfa.rspecs.sfa_rspec import SfaRSpec
-from sfa.rspecs.pg_rspec import PGRSpec
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
class SfaRSpecConverter:
@staticmethod
- def to_pg_rspec(rspec):
- if isinstance(rspec, SfaRSpec):
- sfa_rspec = rspec
+ def to_pg_rspec(rspec, content_type = None):
+ if not isinstance(rspec, RSpec):
+ sfa_rspec = RSpec(rspec)
else:
- sfa_rspec = SfaRSpec(rspec=rspec)
- pg_rspec = PGRSpec()
-
+ sfa_rspec = rspec
+
+ if not content_type or content_type not in \
+ ['ad', 'request', 'manifest']:
+ content_type = sfa_rspec.version.content_type
+
+
+ version_manager = VersionManager()
+ pg_version = version_manager._get_version('protogeni', '2', 'request')
+ pg_rspec = RSpec(version=pg_version)
+
# get networks
- networks = sfa_rspec.get_networks()
+ networks = sfa_rspec.version.get_networks()
for network in networks:
# get nodes
- sfa_node_elements = sfa_rspec.get_node_elements(network=network)
+ sfa_node_elements = sfa_rspec.version.get_node_elements(network=network)
for sfa_node_element in sfa_node_elements:
# create node element
node_attrs = {}
node_attrs['exclusive'] = 'false'
- node_attrs['component_manager_id'] = network
- if sfa_node_element.find('hostname') != None:
- node_attrs['component_name'] = sfa_node_element.find('hostname').text
- if sfa_node_element.find('urn') != None:
- node_attrs['component_id'] = sfa_node_element.find('urn').text
- node_element = pg_rspec.add_element('node', node_attrs)
+ if 'component_manager_id' in sfa_node_element.attrib:
+ node_attrs['component_manager_id'] = sfa_node_element.attrib['component_manager_id']
+ else:
+ node_attrs['component_manager_id'] = hrn_to_urn(network, 'authority+cm')
- # create node_type element
- for hw_type in ['plab-pc', 'pc']:
- hdware_type_element = pg_rspec.add_element('hardware_type', {'name': hw_type}, parent=node_element)
- # create available element
- pg_rspec.add_element('available', {'now': 'true'}, parent=node_element)
- # create locaiton element
- # We don't actually associate nodes with a country.
- # Set country to "unknown" until we figure out how to make
- # sure this value is always accurate.
- location = sfa_node_element.find('location')
- if location != None:
- location_attrs = {}
- location_attrs['country'] = locatiton.get('country', 'unknown')
- location_attrs['latitude'] = location.get('latitiue', 'None')
- location_attrs['longitude'] = location.get('longitude', 'None')
- pg_rspec.add_element('location', location_attrs, parent=node_element)
+ if 'component_id' in sfa_node_element.attrib:
+                node_attrs['component_id'] = sfa_node_element.attrib['component_id']
- sliver_element = sfa_node_element.find('sliver')
- if sliver_element != None:
- pg_rspec.add_element('sliver_type', {'name': 'planetlab-vnode'}, parent=node_element)
+ if sfa_node_element.find('hostname') != None:
+ hostname = sfa_node_element.find('hostname').text
+ node_attrs['component_name'] = hostname
+ node_attrs['client_id'] = hostname
+ node_element = pg_rspec.xml.add_element('node', node_attrs)
+
+ if content_type == 'request':
+ sliver_element = sfa_node_element.find('sliver')
+ sliver_type_elements = sfa_node_element.xpath('./sliver_type', namespaces=sfa_rspec.namespaces)
+ available_sliver_types = [element.attrib['name'] for element in sliver_type_elements]
+ valid_sliver_types = ['emulab-openvz', 'raw-pc']
+
+ # determine sliver type
+ requested_sliver_type = 'emulab-openvz'
+ for available_sliver_type in available_sliver_types:
+ if available_sliver_type in valid_sliver_types:
+ requested_sliver_type = available_sliver_type
+
+ if sliver_element != None:
+ pg_rspec.xml.add_element('sliver_type', {'name': requested_sliver_type}, parent=node_element)
+ else:
+ # create node_type element
+ for hw_type in ['plab-pc', 'pc']:
+ hdware_type_element = pg_rspec.xml.add_element('hardware_type', {'name': hw_type}, parent=node_element)
+ # create available element
+ pg_rspec.xml.add_element('available', {'now': 'true'}, parent=node_element)
+                # create location element
+ # We don't actually associate nodes with a country.
+ # Set country to "unknown" until we figure out how to make
+ # sure this value is always accurate.
+ location = sfa_node_element.find('location')
+ if location != None:
+ location_attrs = {}
+ location_attrs['country'] = location.get('country', 'unknown')
+ location_attrs['latitude'] = location.get('latitude', 'None')
+ location_attrs['longitude'] = location.get('longitude', 'None')
+ pg_rspec.xml.add_element('location', location_attrs, parent=node_element)
return pg_rspec.toxml()
--- /dev/null
+import os
+from sfa.util.faults import InvalidRSpec
+from sfa.rspecs.rspec_version import BaseVersion
+from sfa.util.sfalogging import logger
+
+class VersionManager:
+ default_type = 'SFA'
+ default_version_num = '1'
+
+ def __init__(self):
+ self.versions = []
+ self.load_versions()
+
+ def load_versions(self):
+ path = os.path.dirname(os.path.abspath( __file__ ))
+ versions_path = path + os.sep + 'versions'
+ versions_module_path = 'sfa.rspecs.versions'
+ valid_module = lambda x: os.path.isfile(os.sep.join([versions_path, x])) \
+ and not x.endswith('.pyc') and x not in ['__init__.py']
+ files = [f for f in os.listdir(versions_path) if valid_module(f)]
+ for filename in files:
+ basename = filename.split('.')[0]
+ module_path = versions_module_path +'.'+basename
+ module = __import__(module_path, fromlist=module_path)
+ for attr_name in dir(module):
+ attr = getattr(module, attr_name)
+ if hasattr(attr, 'version') and hasattr(attr, 'enabled') and attr.enabled == True:
+ self.versions.append(attr())
+
+ def _get_version(self, type, version_num=None, content_type=None):
+ retval = None
+ for version in self.versions:
+ if type is None or type.lower() == version.type.lower():
+ if version_num is None or version_num == version.version:
+ if content_type is None or content_type.lower() == version.content_type.lower() \
+ or version.content_type == '*':
+ retval = version
+ if not retval:
+            raise InvalidRSpec("No such rspec version (type: %s, version: %s, content type: %s)" % (type, version_num, content_type))
+ return retval
+
+ def get_version(self, version=None):
+ retval = None
+ if isinstance(version, dict):
+ retval = self._get_version(version.get('type'), version.get('version'), version.get('content_type'))
+ elif isinstance(version, basestring):
+ version_parts = version.split(' ')
+ num_parts = len(version_parts)
+ type = version_parts[0]
+ version_num = None
+ content_type = None
+ if num_parts > 1:
+ version_num = version_parts[1]
+ if num_parts > 2:
+ content_type = version_parts[2]
+ retval = self._get_version(type, version_num, content_type)
+ elif isinstance(version, BaseVersion):
+ retval = version
+ else:
+ retval = self._get_version(self.default_type, self.default_version_num)
+
+ return retval
+
+ def get_version_by_schema(self, schema):
+ retval = None
+ for version in self.versions:
+ if schema == version.schema:
+ retval = version
+ if not retval:
+            raise InvalidRSpec("Unknown RSpec schema: %s" % schema)
+ return retval
+
+if __name__ == '__main__':
+ v = VersionManager()
+ print v.versions
+ print v.get_version('sfa 1')
+ print v.get_version('protogeni 2')
+ print v.get_version('protogeni 2 advertisement')
+ print v.get_version_by_schema('http://www.protogeni.net/resources/rspec/2/ad.xsd')
+
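The three nested checks in `_get_version` above can be exercised in isolation. The sketch below (plain Python, with a hypothetical `Version` stand-in) mirrors the matching rules: a `None` argument matches anything, type and content type compare case-insensitively, and `'*'` acts as a content-type wildcard.

```python
# Simplified stand-in mirroring VersionManager._get_version's matching:
# None matches anything, type/content_type compare case-insensitively,
# and a registered content_type of '*' matches any requested content type.
class Version(object):
    def __init__(self, type, version, content_type):
        self.type = type
        self.version = version
        self.content_type = content_type

def get_version(versions, type=None, version_num=None, content_type=None):
    match = None
    for v in versions:
        if type is None or type.lower() == v.type.lower():
            if version_num is None or version_num == v.version:
                if content_type is None or v.content_type == '*' \
                        or content_type.lower() == v.content_type.lower():
                    match = v
    if match is None:
        raise ValueError("no such version: %s %s %s" % (type, version_num, content_type))
    return match

versions = [Version('SFA', '1', '*'), Version('ProtoGENI', '2', 'ad')]
```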
--- /dev/null
+from lxml import etree
+from copy import deepcopy
+from StringIO import StringIO
+from sfa.util.xrn import *
+from sfa.util.plxrn import hostname_to_urn, xrn_to_hostname
+from sfa.rspecs.rspec_version import BaseVersion
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements
+from sfa.rspecs.elements.versions.pgv2Link import PGv2Link
+
+class PGv2(BaseVersion):
+ type = 'ProtoGENI'
+ content_type = 'ad'
+ version = '2'
+ schema = 'http://www.protogeni.net/resources/rspec/2/ad.xsd'
+ namespace = 'http://www.protogeni.net/resources/rspec/2'
+ extensions = {
+ 'flack': "http://www.protogeni.net/resources/rspec/ext/flack/1",
+ 'planetlab': "http://www.planet-lab.org/resources/sfa/ext/planetlab/1",
+ }
+ namespaces = dict(extensions.items() + [('default', namespace)])
+ elements = []
+
+ def get_network(self):
+ network = None
+ nodes = self.xml.xpath('//default:node[@component_manager_id][1]', namespaces=self.namespaces)
+ if nodes:
+ network = nodes[0].get('component_manager_id')
+ return network
+
+ def get_networks(self):
+ networks = self.xml.xpath('//default:node[@component_manager_id]/@component_manager_id', namespaces=self.namespaces)
+ return set(networks)
+
+ def get_node_element(self, hostname, network=None):
+ nodes = self.xml.xpath('//default:node[@component_id[contains(., "%s")]] | node[@component_id[contains(., "%s")]]' % (hostname, hostname), namespaces=self.namespaces)
+ if isinstance(nodes,list) and nodes:
+ return nodes[0]
+ else:
+ return None
+
+ def get_node_elements(self, network=None):
+ nodes = self.xml.xpath('//default:node | //node', namespaces=self.namespaces)
+ return nodes
+
+
+ def get_nodes(self, network=None):
+ xpath = '//default:node[@component_name]/@component_id | //node[@component_name]/@component_id'
+ nodes = self.xml.xpath(xpath, namespaces=self.namespaces)
+ nodes = [xrn_to_hostname(node) for node in nodes]
+ return nodes
+
+ def get_nodes_with_slivers(self, network=None):
+ if network:
+ nodes = self.xml.xpath('//default:node[@component_manager_id="%s"][sliver_type]/@component_id' % network, namespaces=self.namespaces)
+ else:
+ nodes = self.xml.xpath('//default:node[default:sliver_type]/@component_id', namespaces=self.namespaces)
+ nodes = [xrn_to_hostname(node) for node in nodes]
+ return nodes
+
+ def get_nodes_without_slivers(self, network=None):
+ return []
+
+ def get_sliver_attributes(self, hostname, network=None):
+ node = self.get_node_element(hostname, network)
+ sliver = node.xpath('./default:sliver_type', namespaces=self.namespaces)
+ if sliver is not None and isinstance(sliver, list):
+ sliver = sliver[0]
+ return self.attributes_list(sliver)
+
+ def get_slice_attributes(self, network=None):
+ slice_attributes = []
+ nodes_with_slivers = self.get_nodes_with_slivers(network)
+ # TODO: default sliver attributes in the PG rspec?
+ default_ns_prefix = self.namespaces['default']
+ for node in nodes_with_slivers:
+ sliver_attributes = self.get_sliver_attributes(node, network)
+ for sliver_attribute in sliver_attributes:
+ name=str(sliver_attribute[0])
+ text =str(sliver_attribute[1])
+ attribs = sliver_attribute[2]
+            # we currently only support the <initscript> and <flack> attributes
+ if 'info' in name:
+ attribute = {'name': 'flack_info', 'value': str(attribs), 'node_id': node}
+ slice_attributes.append(attribute)
+ elif 'initscript' in name:
+ if attribs is not None and 'name' in attribs:
+ value = attribs['name']
+ else:
+ value = text
+ attribute = {'name': 'initscript', 'value': value, 'node_id': node}
+ slice_attributes.append(attribute)
+
+ return slice_attributes
+
+ def get_links(self, network=None):
+ links = PGv2Link.get_links(self.xml.root, self.namespaces)
+ return links
+
+ def add_links(self, links):
+ PGv2Link.add_links(self.xml.root, links)
+
+ def attributes_list(self, elem):
+ opts = []
+ if elem is not None:
+ for e in elem:
+ opts.append((e.tag, str(e.text).strip(), e.attrib))
+ return opts
+
+ def get_default_sliver_attributes(self, network=None):
+ return []
+
+ def add_default_sliver_attribute(self, name, value, network=None):
+ pass
+
+ def add_nodes(self, nodes, check_for_dupes=False):
+ if not isinstance(nodes, list):
+ nodes = [nodes]
+ for node in nodes:
+ urn = ""
+ if check_for_dupes and \
+ self.xml.xpath('//default:node[@component_uuid="%s"]' % urn, namespaces=self.namespaces):
+ # node already exists
+ continue
+
+ node_tag = etree.SubElement(self.xml.root, 'node', exclusive='false')
+ if 'network_urn' in node:
+ node_tag.set('component_manager_id', node['network_urn'])
+ if 'urn' in node:
+ node_tag.set('component_id', node['urn'])
+ if 'hostname' in node:
+ node_tag.set('component_name', node['hostname'])
+ # TODO: should replace plab-pc with pc model
+ node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='plab-pc')
+ node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='pc')
+ available_tag = etree.SubElement(node_tag, 'available', now='true')
+ sliver_type_tag = etree.SubElement(node_tag, 'sliver_type', name='plab-vserver')
+
+ pl_initscripts = node.get('pl_initscripts', {})
+ for pl_initscript in pl_initscripts.values():
+ etree.SubElement(sliver_type_tag, '{%s}initscript' % self.namespaces['planetlab'], name=pl_initscript['name'])
+
+ # protogeni uses the <sliver_type> tag to identify the types of
+ # vms available at the node.
+ # only add location tag if longitude and latitude are not null
+ if 'site' in node:
+ longitude = node['site'].get('longitude', None)
+ latitude = node['site'].get('latitude', None)
+ if longitude and latitude:
+ location_tag = etree.SubElement(node_tag, 'location', country="us", \
+ longitude=str(longitude), latitude=str(latitude))
+
+ def merge_node(self, source_node_tag):
+ # this is untested
+ self.xml.root.append(deepcopy(source_node_tag))
+
+ def add_slivers(self, slivers, sliver_urn=None, no_dupes=False, append=False):
+
+        # All nodes should already be present in the rspec; remove any
+        # nodes that don't have slivers.
+ slivers_dict = {}
+ for sliver in slivers:
+ if isinstance(sliver, basestring):
+ slivers_dict[sliver] = {'hostname': sliver}
+ elif isinstance(sliver, dict):
+ slivers_dict[sliver['hostname']] = sliver
+
+ nodes = self.get_node_elements()
+ for node in nodes:
+ urn = node.get('component_id')
+ hostname = xrn_to_hostname(urn)
+ if hostname not in slivers_dict and not append:
+ parent = node.getparent()
+ parent.remove(node)
+ else:
+ sliver_info = slivers_dict[hostname]
+ sliver_type_elements = node.xpath('./default:sliver_type', namespaces=self.namespaces)
+ available_sliver_types = [element.attrib['name'] for element in sliver_type_elements]
+ valid_sliver_types = ['emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode']
+ requested_sliver_type = None
+ for valid_sliver_type in valid_sliver_types:
+ if valid_sliver_type in available_sliver_types:
+ requested_sliver_type = valid_sliver_type
+ if requested_sliver_type:
+                # remove existing sliver_type tags; they need to be recreated
+ sliver_elem = node.xpath('./default:sliver_type | ./sliver_type', namespaces=self.namespaces)
+ if sliver_elem and isinstance(sliver_elem, list):
+ sliver_elem = sliver_elem[0]
+ node.remove(sliver_elem)
+ # set the client id
+ node.set('client_id', hostname)
+ if sliver_urn:
+ # set the sliver id
+ slice_id = sliver_info.get('slice_id', -1)
+ node_id = sliver_info.get('node_id', -1)
+ sliver_id = urn_to_sliver_id(sliver_urn, slice_id, node_id)
+ node.set('sliver_id', sliver_id)
+
+ # add the sliver element
+ sliver_elem = etree.SubElement(node, 'sliver_type', name=requested_sliver_type)
+ for tag in sliver_info.get('tags', []):
+ if tag['tagname'] == 'flack_info':
+ e = etree.SubElement(sliver_elem, '{%s}info' % self.namespaces['flack'], attrib=eval(tag['value']))
+ elif tag['tagname'] == 'initscript':
+ e = etree.SubElement(sliver_elem, '{%s}initscript' % self.namespaces['planetlab'], attrib={'name': tag['value']})
+ else:
+ # node isn't usable. just remove it from the request
+ parent = node.getparent()
+ parent.remove(node)
+
+
+
+ def remove_slivers(self, slivers, network=None, no_dupes=False):
+ for sliver in slivers:
+ node_elem = self.get_node_element(sliver['hostname'])
+            sliver_elem = node_elem.xpath('./default:sliver_type', namespaces=self.namespaces)
+ if sliver_elem != None and sliver_elem != []:
+ node_elem.remove(sliver_elem[0])
+
+ def add_default_sliver_attribute(self, name, value, network=None):
+ pass
+
+ def add_interfaces(self, interfaces, no_dupes=False):
+ pass
+
+ def merge(self, in_rspec):
+ """
+ Merge contents for specified rspec with current rspec
+ """
+ from sfa.rspecs.rspec import RSpec
+ # just copy over all the child elements under the root element
+ if isinstance(in_rspec, RSpec):
+ in_rspec = in_rspec.toxml()
+ tree = etree.parse(StringIO(in_rspec))
+ root = tree.getroot()
+ for child in root.getchildren():
+ self.xml.root.append(child)
+
+ def cleanup(self):
+        # remove unnecessary elements and attributes
+        if self.content_type in ['request', 'manifest']:
+ # remove 'available' element from remaining node elements
+ self.xml.remove_element('//default:available | //available')
+
+class PGv2Ad(PGv2):
+ enabled = True
+ content_type = 'ad'
+ schema = 'http://www.protogeni.net/resources/rspec/2/ad.xsd'
+ template = '<rspec type="advertisement" xmlns="http://www.protogeni.net/resources/rspec/2" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/ad.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv2Request(PGv2):
+ enabled = True
+ content_type = 'request'
+ schema = 'http://www.protogeni.net/resources/rspec/2/request.xsd'
+ template = '<rspec type="request" xmlns="http://www.protogeni.net/resources/rspec/2" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/request.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv2Manifest(PGv2):
+ enabled = True
+ content_type = 'manifest'
+ schema = 'http://www.protogeni.net/resources/rspec/2/manifest.xsd'
+ template = '<rspec type="manifest" xmlns="http://www.protogeni.net/resources/rspec/2" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/manifest.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+
+
+
+if __name__ == '__main__':
+ from sfa.rspecs.rspec import RSpec
+ from sfa.rspecs.rspec_elements import *
+ r = RSpec('/tmp/pg.rspec')
+ r.load_rspec_elements(PGv2.elements)
+ r.namespaces = PGv2.namespaces
+ print r.get(RSpecElements.NODE)
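The `'default'` prefix in `namespaces` exists because XPath cannot address elements that sit in an unnamed default namespace: the query must use an explicit prefix mapped to the namespace URI. The diff uses lxml; the same behavior can be shown with the standard library's ElementTree (a sketch, not the code above):

```python
# Elements in a default XML namespace are invisible to an unprefixed
# query; a prefix -> URI mapping (here 'default') is required, which is
# exactly why PGv2 maps 'default' to its rspec namespace.
import xml.etree.ElementTree as ET

doc = ('<rspec xmlns="http://www.protogeni.net/resources/rspec/2">'
       '<node component_manager_id="urn:publicid:IDN+example+authority+cm"/>'
       '</rspec>')
root = ET.fromstring(doc)
ns = {'default': 'http://www.protogeni.net/resources/rspec/2'}

unprefixed = root.findall('.//node')            # misses namespaced nodes
prefixed = root.findall('.//default:node', ns)  # finds them
```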
--- /dev/null
+from sfa.rspecs.versions.pgv2 import PGv2
+
+class PGv3(PGv2):
+ type = 'GENI'
+ content_type = 'ad'
+ version = '3'
+ schema = 'http://www.geni.net/resources/rspec/3/ad.xsd'
+ namespace = 'http://www.geni.net/resources/rspec/3'
+ extensions = {
+ 'flack': "http://www.protogeni.net/resources/rspec/ext/flack/1",
+ 'planetlab': "http://www.planet-lab.org/resources/sfa/ext/planetlab/1",
+ }
+ namespaces = dict(extensions.items() + [('default', namespace)])
+ elements = []
+
+
+class PGv3Ad(PGv3):
+ enabled = True
+ content_type = 'ad'
+ schema = 'http://www.geni.net/resources/rspec/3/ad.xsd'
+ template = '<rspec type="advertisement" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.geni.net/resources/rspec/3" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xsi:schemaLocation="http://www.geni.net/resources/rspec/3 http://www.geni.net/resources/rspec/3/ad.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv3Request(PGv3):
+ enabled = True
+ content_type = 'request'
+ schema = 'http://www.geni.net/resources/rspec/3/request.xsd'
+ template = '<rspec type="request" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.geni.net/resources/rspec/3" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xsi:schemaLocation="http://www.geni.net/resources/rspec/3 http://www.geni.net/resources/rspec/3/request.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv3Manifest(PGv3):
+    enabled = True
+    content_type = 'manifest'
+    schema = 'http://www.geni.net/resources/rspec/3/manifest.xsd'
+    template = '<rspec type="manifest" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.geni.net/resources/rspec/3" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xsi:schemaLocation="http://www.geni.net/resources/rspec/3 http://www.geni.net/resources/rspec/3/manifest.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
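The version classes carry no behavior of their own: each concrete format subclasses its parent, overrides only the class attributes that differ, and flips `enabled` so the `VersionManager` directory scan registers it. A stripped-down sketch of that pattern (hypothetical, not the classes above):

```python
# Sketch of the declare-by-subclassing pattern used by PGv2/PGv3:
# subclasses inherit everything and override only what differs.
class BaseVersion(object):
    enabled = False          # only concrete, registerable formats set True
    type = None
    version = None
    content_type = None

class PGv3(BaseVersion):
    type = 'GENI'
    version = '3'

class PGv3Request(PGv3):
    enabled = True
    content_type = 'request'
```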
-#!/usr/bin/python
from lxml import etree
-from StringIO import StringIO
-from sfa.rspecs.rspec import RSpec
-from sfa.util.xrn import *
-from sfa.util.plxrn import hostname_to_urn
-from sfa.util.config import Config
-from sfa.rspecs.rspec_version import RSpecVersion
+from sfa.util.xrn import hrn_to_urn, urn_to_hrn
+from sfa.rspecs.rspec_version import BaseVersion
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements
+from sfa.rspecs.elements.versions.pgv2Link import PGv2Link
+
+class SFAv1(BaseVersion):
+ enabled = True
+ type = 'SFA'
+ content_type = '*'
+ version = '1'
+ schema = None
+ namespace = None
+ extensions = {}
+ namespaces = None
+ elements = []
+ template = '<RSpec type="%s"></RSpec>' % type
-
-_version = { 'type': 'SFA',
- 'version': '1'
-}
-
-sfa_rspec_version = RSpecVersion(_version)
-
-class SfaRSpec(RSpec):
- xml = None
- header = '<?xml version="1.0"?>\n'
- version = sfa_rspec_version
-
- def create(self):
- RSpec.create(self)
- self.xml.set('type', 'SFA')
-
- ###################
- # Parser
- ###################
def get_network_elements(self):
return self.xml.xpath('//network')
else:
names = self.xml.xpath('//node/hostname')
for name in names:
- if name.text == hostname:
+ if str(name.text).strip() == hostname:
return name.getparent()
return None
-
+
def get_node_elements(self, network=None):
if network:
return self.xml.xpath('//network[@name="%s"]//node' % network)
nodes = self.xml.xpath('//node/hostname/text()')
else:
nodes = self.xml.xpath('//network[@name="%s"]//node/hostname/text()' % network)
+
+ nodes = [node.strip() for node in nodes]
return nodes
def get_nodes_with_slivers(self, network = None):
if network:
- return self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)
+ nodes = self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)
else:
- return self.xml.xpath('//node[sliver]/hostname/text()')
+ nodes = self.xml.xpath('//node[sliver]/hostname/text()')
- def get_nodes_without_slivers(self, network=None):
+ nodes = [node.strip() for node in nodes]
+ return nodes
+
+ def get_nodes_without_slivers(self, network=None):
xpath_nodes_without_slivers = '//node[not(sliver)]/hostname/text()'
- xpath_nodes_without_slivers_in_network = '//network[@name="%s"]//node[not(sliver)]/hostname/text()'
+ xpath_nodes_without_slivers_in_network = '//network[@name="%s"]//node[not(sliver)]/hostname/text()'
if network:
return self.xml.xpath('//network[@name="%s"]//node[not(sliver)]/hostname/text()' % network)
else:
- return self.xml.xpath('//node[not(sliver)]/hostname/text()')
-
+ return self.xml.xpath('//node[not(sliver)]/hostname/text()')
def attributes_list(self, elem):
# convert a list of attribute tags into list of tuples
- # (tagnme, text_value)
+        # (tagname, text_value)
opts = []
if elem is not None:
for e in elem:
- opts.append((e.tag, e.text))
+ opts.append((e.tag, str(e.text).strip()))
return opts
def get_default_sliver_attributes(self, network=None):
if network:
- defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
+ defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
else:
- defaults = self.xml.xpath("//network/sliver_defaults" % network)
+ defaults = self.xml.xpath("//sliver_defaults")
+ if isinstance(defaults, list) and defaults:
+ defaults = defaults[0]
return self.attributes_list(defaults)
def get_sliver_attributes(self, hostname, network=None):
+ attributes = []
node = self.get_node_element(hostname, network)
- sliver = node.find("sliver")
- return self.attributes_list(sliver)
+ #sliver = node.find("sliver")
+ slivers = node.xpath('./sliver')
+ if isinstance(slivers, list) and slivers:
+ attributes = self.attributes_list(slivers[0])
+ return attributes
+
+ def get_slice_attributes(self, network=None):
+ slice_attributes = []
+ nodes_with_slivers = self.get_nodes_with_slivers(network)
+ for default_attribute in self.get_default_sliver_attributes(network):
+ attribute = {'name': str(default_attribute[0]), 'value': str(default_attribute[1]), 'node_id': None}
+ slice_attributes.append(attribute)
+ for node in nodes_with_slivers:
+ sliver_attributes = self.get_sliver_attributes(node, network)
+ for sliver_attribute in sliver_attributes:
+ attribute = {'name': str(sliver_attribute[0]), 'value': str(sliver_attribute[1]), 'node_id': node}
+ slice_attributes.append(attribute)
+ return slice_attributes
def get_site_nodes(self, siteid, network=None):
if network:
else:
nodes = self.xml.xpath('//site[@id="%s"]/node/hostname/text()' % siteid)
return nodes
-
+
def get_links(self, network=None):
- if network:
- links = self.xml.xpath('//network[@name="%s"]/link' % network)
- else:
- links = self.xml.xpath('//link')
- linklist = []
- for link in links:
- (end1, end2) = link.get("endpoints").split()
- name = link.find("description")
- linklist.append((name.text,
- self.get_site_nodes(end1, network),
- self.get_site_nodes(end2, network)))
- return linklist
+ links = PGv2Link.get_links(self.xml, self.namespaces)
+ return links
def get_link(self, fromnode, tonode, network=None):
fromsite = fromnode.getparent()
def get_vlinks(self, network=None):
vlinklist = []
- if network:
+ if network:
vlinks = self.xml.xpath("//network[@name='%s']//vlink" % network)
else:
- vlinks = self.xml.xpath("//vlink")
+ vlinks = self.xml.xpath("//vlink")
for vlink in vlinks:
endpoints = vlink.get("endpoints")
(end1, end2) = endpoints.split()
- if network:
+ if network:
node1 = self.xml.xpath('//network[@name="%s"]//node[@id="%s"]/hostname/text()' % \
(network, end1))[0]
node2 = self.xml.xpath('//network[@name="%s"]//node[@id="%s"]/hostname/text()' % \
(network, end2))[0]
- else:
+ else:
node1 = self.xml.xpath('//node[@id="%s"]/hostname/text()' % end1)[0]
node2 = self.xml.xpath('//node[@id="%s"]/hostname/text()' % end2)[0]
desc = "%s <--> %s" % (node1, node2)
        query = "//vlink[@endpoints = '%s']" % endpoints
results = self.rspec.xpath(query)
return results
-
+
def query_vlinks(self, endpoints, network=None):
return get_vlink(endpoints,network)
+
##################
# Builder
##################
def add_network(self, network):
- network_tag = etree.SubElement(self.xml, 'network', id=network)
+ network_tags = self.xml.xpath('//network[@name="%s"]' % network)
+ if not network_tags:
+ network_tag = etree.SubElement(self.xml.root, 'network', name=network)
+ else:
+ network_tag = network_tags[0]
+ return network_tag
def add_nodes(self, nodes, network = None, no_dupes=False):
if not isinstance(nodes, list):
# node already exists
continue
- network_tag = self.xml
+ network_tag = self.xml.root
if 'network' in node:
network = node['network']
- network_tags = self.xml.xpath('//network[@name="%s"]' % network)
- if not network_tags:
- network_tag = etree.SubElement(self.xml, 'network', name=network)
- else:
- network_tag = network_tags[0]
-
+ network_tag = self.add_network(network)
+
node_tag = etree.SubElement(network_tag, 'node')
if 'network' in node:
- node_tag.set('component_manager_id', network)
+ node_tag.set('component_manager_id', hrn_to_urn(network, 'authority+sa'))
if 'urn' in node:
- node_tag.set('component_id', node['urn'])
+ node_tag.set('component_id', node['urn'])
if 'site_urn' in node:
node_tag.set('site_id', node['site_urn'])
- if 'node_id' in node:
+ if 'node_id' in node:
node_tag.set('node_id', 'n'+str(node['node_id']))
+ if 'boot_state' in node:
+ node_tag.set('boot_state', node['boot_state'])
if 'hostname' in node:
+ node_tag.set('component_name', node['hostname'])
hostname_tag = etree.SubElement(node_tag, 'hostname').text = node['hostname']
if 'interfaces' in node:
+ i = 0
for interface in node['interfaces']:
if 'bwlimit' in interface and interface['bwlimit']:
bwlimit = etree.SubElement(node_tag, 'bw_limit', units='kbps').text = str(interface['bwlimit']/1000)
+ comp_id = hrn_to_urn(network, 'pc%s:eth%s' % (node['node_id'], i))
+ interface_tag = etree.SubElement(node_tag, 'interface', component_id=comp_id)
+ i+=1
+ if 'bw_unallocated' in node:
+ bw_unallocated = etree.SubElement(node_tag, 'bw_unallocated', units='kbps').text = str(node['bw_unallocated']/1000)
if 'tags' in node:
for tag in node['tags']:
- # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category
+ # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category
if tag['tagname'] in ['fcdistro', 'arch'] or 'sfa' in tag['category'].split('/'):
- tag_element = etree.SubElement(node_tag, tag['tagname'], value=tag['value'])
+ tag_element = etree.SubElement(node_tag, tag['tagname']).text=tag['value']
if 'site' in node:
longitude = str(node['site']['longitude'])
latitude = str(node['site']['latitude'])
location = etree.SubElement(node_tag, 'location', country='unknown', \
- longitude=longitude, latitude=latitude)
+ longitude=longitude, latitude=latitude)
+
+    def merge_node(self, source_node_tag, network, no_dupes=False):
+        hostname = source_node_tag.findtext('hostname')
+        if no_dupes and hostname and self.get_node_element(hostname):
+            # node already exists
+            return
+
+ network_tag = self.add_network(network)
+ network_tag.append(deepcopy(source_node_tag))
def add_interfaces(self, interfaces):
- pass
+ pass
def add_links(self, links):
- pass
-
- def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False):
- slivers = self._process_slivers(slivers)
- nodes_with_slivers = self.get_nodes_with_slivers(network)
+ PGv2Link.add_links(self.xml, links)
+
+ def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False, append=False):
+ # add slice name to network tag
+ network_tags = self.xml.xpath('//network')
+ if network_tags:
+ network_tag = network_tags[0]
+ network_tag.set('slice', urn_to_hrn(sliver_urn)[0])
+
+ all_nodes = self.get_nodes()
+ nodes_with_slivers = [sliver['hostname'] for sliver in slivers]
+ nodes_without_slivers = set(all_nodes).difference(nodes_with_slivers)
+
+ # add slivers
for sliver in slivers:
- if sliver['hostname'] in nodes_with_slivers:
- continue
node_elem = self.get_node_element(sliver['hostname'], network)
+ if not node_elem: continue
sliver_elem = etree.SubElement(node_elem, 'sliver')
if 'tags' in sliver:
for tag in sliver['tags']:
- etree.SubElement(sliver_elem, tag['tagname'], value=tag['value'])
+                    etree.SubElement(sliver_elem, tag['tagname']).text = tag['value']
+
+ # remove all nodes without slivers
+ if not append:
+ for node in nodes_without_slivers:
+ node_elem = self.get_node_element(node)
+ parent = node_elem.getparent()
+ parent.remove(node_elem)
def remove_slivers(self, slivers, network=None, no_dupes=False):
- if not isinstance(slivers, list):
- slivers = [slivers]
for sliver in slivers:
node_elem = self.get_node_element(sliver['hostname'], network)
- sliver_elem = node.find('sliver')
+ sliver_elem = node_elem.find('sliver')
if sliver_elem != None:
- node_elem.remove(sliver)
-
+ node_elem.remove(sliver_elem)
+
def add_default_sliver_attribute(self, name, value, network=None):
if network:
defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
else:
-            defaults = self.xml.xpath("//sliver_defaults" % network)
+            defaults = self.xml.xpath("//sliver_defaults")
- if defaults is None:
- defaults = etree.Element("sliver_defaults")
- network = self.xml.xpath("//network[@name='%s']" % network)
- network.insert(0, defaults)
- self.add_attribute(defaults, name, value)
+ if not defaults :
+ network_tag = self.xml.xpath("//network[@name='%s']" % network)
+ if isinstance(network_tag, list):
+ network_tag = network_tag[0]
+ defaults = self.xml.add_element('sliver_defaults', attrs={}, parent=network_tag)
+ elif isinstance(defaults, list):
+ defaults = defaults[0]
+ self.xml.add_attribute(defaults, name, value)
def add_sliver_attribute(self, hostname, name, value, network=None):
node = self.get_node_element(hostname, network)
sliver = node.find("sliver")
- self.add_attribute(sliver, name, value)
+ self.xml.add_attribute(sliver, name, value)
def remove_default_sliver_attribute(self, name, value, network=None):
if network:
defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
else:
-            defaults = self.xml.xpath("//sliver_defaults" % network)
+            defaults = self.xml.xpath("//sliver_defaults")
- self.remove_attribute(defaults, name, value)
+ self.xml.remove_attribute(defaults, name, value)
def remove_sliver_attribute(self, hostname, name, value, network=None):
node = self.get_node_element(hostname, network)
sliver = node.find("sliver")
- self.remove_attribute(sliver, name, value)
+ self.xml.remove_attribute(sliver, name, value)
def add_vlink(self, fromhost, tohost, kbps, network=None):
fromnode = self.get_node_element(fromhost, network)
fromid = fromnode.get("id")
toid = tonode.get("id")
vlink.set("endpoints", "%s %s" % (fromid, toid))
- self.add_attribute(vlink, "kbps", kbps)
+ self.xml.add_attribute(vlink, "kbps", kbps)
def remove_vlink(self, endpoints, network=None):
def merge(self, in_rspec):
"""
- Merge contents for specified rspec with current rspec
+ Merge contents for specified rspec with current rspec
"""
+ from sfa.rspecs.rspec import RSpec
+ if isinstance(in_rspec, RSpec):
+ rspec = in_rspec
+ else:
+ rspec = RSpec(in_rspec)
+ if rspec.version.type.lower() == 'protogeni':
+ from sfa.rspecs.rspec_converter import RSpecConverter
+ in_rspec = RSpecConverter.to_sfa_rspec(rspec.toxml())
+ rspec = RSpec(in_rspec)
+
# just copy over all networks
current_networks = self.get_networks()
- rspec = SfaRSpec(rspec=in_rspec)
- networks = rspec.get_network_elements()
+ networks = rspec.version.get_network_elements()
for network in networks:
current_network = network.get('name')
- if not current_network in current_networks:
- self.xml.append(network)
+ if current_network and current_network not in current_networks:
+ self.xml.root.append(network)
current_networks.append(current_network)
-
-
if __name__ == '__main__':
- rspec = SfaRSpec()
- nodes = [
- {'network': 'plc',
- 'hostname': 'node1.planet-lab.org',
- 'site_urn': 'urn:publicid:IDN+plc+authority+cm',
- 'node_id': 1,
- },
- {'network': 'plc',
- 'hostname': 'node2.planet-lab.org',
- 'site_urn': 'urn:publicid:IDN+plc+authority+cm',
- 'node_id': 1,
- },
- {'network': 'ple',
- 'hostname': 'node1.planet-lab.eu',
- 'site_urn': 'urn:publicid:IDN+plc+authority+cm',
- 'node_id': 1,
- },
- ]
- rspec.add_nodes(nodes)
- print rspec
+ from sfa.rspecs.rspec import RSpec
+ from sfa.rspecs.rspec_elements import *
+ r = RSpec('/tmp/resources.rspec')
+ r.load_rspec_elements(SFAv1.elements)
+ print r.get(RSpecElements.NODE)
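The new `merge` copies over only those networks whose `name` is not already present in the current rspec. The dedup-by-name step can be sketched with plain dicts in place of etree elements (names are illustrative):

```python
def merge_networks(current, incoming):
    # Append incoming networks whose 'name' is not already present,
    # preserving order and skipping duplicates.
    seen = set(net['name'] for net in current)
    for net in incoming:
        if net['name'] not in seen:
            current.append(net)
            seen.add(net['name'])
    return current
```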
from sfa.util.faults import *
from sfa.util.server import SfaServer
from sfa.util.xrn import hrn_to_urn
-from sfa.server.interface import Interfaces
+from sfa.server.interface import Interfaces, Interface
+from sfa.util.config import Config
class Aggregate(SfaServer):
default_dict = {'aggregates': {'aggregate': [Interfaces.default_fields]}}
- def __init__(self, api, conf_file = "/etc/sfa/aggregates.xml"):
- Interfaces.__init__(self, api, conf_file)
+ def __init__(self, conf_file = "/etc/sfa/aggregates.xml"):
+ Interfaces.__init__(self, conf_file)
+ sfa_config = Config()
# set up a connection to the local aggregate
- if self.api.config.SFA_AGGREGATE_ENABLED:
- address = self.api.config.SFA_AGGREGATE_HOST
- port = self.api.config.SFA_AGGREGATE_PORT
- url = 'http://%(address)s:%(port)s' % locals()
- local_aggregate = {'hrn': self.api.hrn,
- 'urn': hrn_to_urn(self.api.hrn, 'authority'),
- 'addr': address,
- 'port': port,
- 'url': url}
- self.interfaces[self.api.hrn] = local_aggregate
-
- # get connections
- self.update(self.get_connections())
+ if sfa_config.SFA_AGGREGATE_ENABLED:
+ addr = sfa_config.SFA_AGGREGATE_HOST
+ port = sfa_config.SFA_AGGREGATE_PORT
+ hrn = sfa_config.SFA_INTERFACE_HRN
+ interface = Interface(hrn, addr, port)
+ self[hrn] = interface
#
# Component is a SfaServer that implements the Component interface
#
-### $Id:
-### $URL:
-#
-
import tempfile
import os
import time
import traceback
import os.path
-from sfa.util.sfalogging import sfa_logger
from sfa.util.faults import *
from sfa.util.storage import XmlStorage
from sfa.util.xrn import get_authority, hrn_to_urn
GeniClientLight = None
+
+class Interface:
+
+ def __init__(self, hrn, addr, port, client_type='sfa'):
+ self.hrn = hrn
+ self.addr = addr
+ self.port = port
+ self.client_type = client_type
+
+ def get_url(self):
+ address_parts = self.addr.split('/')
+ address_parts[0] = address_parts[0] + ":" + str(self.port)
+ url = "http://%s" % "/".join(address_parts)
+ return url
+
+ def get_server(self, key_file, cert_file, timeout=30):
+ server = None
+ if self.client_type == 'geniclientlight' and GeniClientLight:
+            server = GeniClientLight(self.get_url(), key_file, cert_file)
+ else:
+ server = xmlrpcprotocol.get_server(self.get_url(), key_file, cert_file, timeout)
+
+ return server
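`Interface.get_url` above splices the port into the host part of the address, before any path component. A standalone sketch of the same transformation (helper name is illustrative):

```python
def build_interface_url(addr, port):
    # Port is appended onto the domain, before the path:
    # ("plc.example.org/sfa", 12345) -> "http://plc.example.org:12345/sfa"
    parts = addr.split('/')
    parts[0] = "%s:%s" % (parts[0], port)
    return "http://%s" % "/".join(parts)
```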
##
# In is a dictionary of registry connections keyed on the registry
# hrn
class Interfaces(dict):
"""
Interfaces is a base class for managing information on the
- peers we are federated with. It is responsible for the following:
-
- 1) Makes sure a record exist in the local registry for the each
- fedeated peer
- 2) Attempts to fetch and install trusted gids
- 3) Provides connections (xmlrpc or soap) to federated peers
+ peers we are federated with. Provides connections (xmlrpc or soap) to federated peers
"""
# fields that must be specified in the config file
# defined by the class
default_dict = {}
- types = ['authority']
-
- def __init__(self, api, conf_file, type='authority'):
- if type not in self.types:
- raise SfaInfaildArgument('Invalid type %s: must be in %s' % (type, self.types))
+ def __init__(self, conf_file):
dict.__init__(self, {})
- self.api = api
- self.type = type
# load config file
self.interface_info = XmlStorage(conf_file, self.default_dict)
self.interface_info.load()
- interfaces = self.interface_info.values()[0].values()[0]
- if not isinstance(interfaces, list):
- interfaces = [self.interfaces]
- # set the url and urn
- for interface in interfaces:
- # port is appended onto the domain, before the path. Should look like:
- # http://domain:port/path
- hrn, address, port = interface['hrn'], interface['addr'], interface['port']
- address_parts = address.split('/')
- address_parts[0] = address_parts[0] + ":" + str(port)
- url = "http://%s" % "/".join(address_parts)
- interface['url'] = url
- interface['urn'] = hrn_to_urn(hrn, 'authority')
-
- self.interfaces = {}
- required_fields = self.default_fields.keys()
- for interface in interfaces:
- valid = True
- # skp any interface definition that has a null hrn,
- # address or port
- for field in required_fields:
- if field not in interface or not interface[field]:
- valid = False
- break
- if valid:
- self.interfaces[interface['hrn']] = interface
-
-
- def sync_interfaces(self):
- """
- Install missing trusted gids and db records for our federated
- interfaces
- """
- # Attempt to get any missing peer gids
- # There should be a gid file in /etc/sfa/trusted_roots for every
- # peer registry found in in the registries.xml config file. If there
- # are any missing gids, request a new one from the peer registry.
- gids_current = self.api.auth.trusted_cert_list
- hrns_current = [gid.get_hrn() for gid in gids_current]
- hrns_expected = self.interfaces.keys()
- new_hrns = set(hrns_expected).difference(hrns_current)
- gids = self.get_peer_gids(new_hrns) + gids_current
- # make sure there is a record for every gid
- self.update_db_records(self.type, gids)
-
- def get_peer_gids(self, new_hrns):
- """
- Install trusted gids from the specified interfaces.
- """
- peer_gids = []
- if not new_hrns:
- return peer_gids
- trusted_certs_dir = self.api.config.get_trustedroots_dir()
- for new_hrn in new_hrns:
- if not new_hrn:
- continue
- # the gid for this interface should already be installed
- if new_hrn == self.api.config.SFA_INTERFACE_HRN:
- continue
- try:
- # get gid from the registry
- interface_info = self.interfaces[new_hrn]
- interface = self[new_hrn]
- trusted_gids = interface.get_trusted_certs()
- if trusted_gids:
- # the gid we want shoudl be the first one in the list,
- # but lets make sure
- for trusted_gid in trusted_gids:
- # default message
- message = "interface: %s\t" % (self.api.interface)
- message += "unable to install trusted gid for %s" % \
- (new_hrn)
- gid = GID(string=trusted_gids[0])
- peer_gids.append(gid)
- if gid.get_hrn() == new_hrn:
- gid_filename = os.path.join(trusted_certs_dir, '%s.gid' % new_hrn)
- gid.save_to_file(gid_filename, save_parents=True)
- message = "interface: %s\tinstalled trusted gid for %s" % \
- (self.api.interface, new_hrn)
- # log the message
- self.api.logger.info(message)
- except:
- message = "interface: %s\tunable to install trusted gid for %s" % \
- (self.api.interface, new_hrn)
- self.api.logger.log_exc(message)
-
- # reload the trusted certs list
- self.api.auth.load_trusted_certs()
- return peer_gids
-
- def update_db_records(self, type, gids):
- """
- Make sure there is a record in the local db for allowed registries
- defined in the config file (registries.xml). Removes old records from
- the db.
- """
- # import SfaTable here so this module can be loaded by ComponentAPI
- from sfa.util.table import SfaTable
- if not gids:
- return
+ records = self.interface_info.values()[0]
+ if not isinstance(records, list):
+ records = [records]
- # hrns that should have a record
- hrns_expected = [gid.get_hrn() for gid in gids]
-
- # get hrns that actually exist in the db
- table = SfaTable()
- records = table.find({'type': type, 'pointer': -1})
- hrns_found = [record['hrn'] for record in records]
-
- # remove old records
- for record in records:
- if record['hrn'] not in hrns_expected and \
- record['hrn'] != self.api.config.SFA_INTERFACE_HRN:
- table.remove(record)
-
- # add new records
- for gid in gids:
- hrn = gid.get_hrn()
- if hrn not in hrns_found:
- record = {
- 'hrn': hrn,
- 'type': type,
- 'pointer': -1,
- 'authority': get_authority(hrn),
- 'gid': gid.save_to_string(save_parents=True),
- }
- record = SfaRecord(dict=record)
- table.insert(record)
-
- def get_connections(self):
- """
- read connection details for the trusted peer registries from file return
- a dictionary of connections keyed on interface hrn.
- """
- connections = {}
required_fields = self.default_fields.keys()
- for interface in self.interfaces.values():
- url = interface['url']
-# sfa_logger().debug("Interfaces.get_connections - looping on neighbour %s"%url)
- # check which client we should use
- # sfa.util.xmlrpcprotocol is default
- client_type = 'xmlrpcprotocol'
- if interface.has_key('client') and \
- interface['client'] in ['geniclientlight'] and \
- GeniClientLight:
- client_type = 'geniclientlight'
- connections[hrn] = GeniClientLight(url, self.api.key_file, self.api.cert_file)
- else:
- connections[interface['hrn']] = xmlrpcprotocol.get_server(url, self.api.key_file, self.api.cert_file)
+ for record in records:
+ if not record or not set(required_fields).issubset(record.keys()):
+ continue
+ # port is appended onto the domain, before the path. Should look like:
+ # http://domain:port/path
+ hrn, address, port = record['hrn'], record['addr'], record['port']
+ interface = Interface(hrn, address, port)
+ self[hrn] = interface
- return connections
+ def get_server(self, hrn, key_file, cert_file, timeout=30):
+ return self[hrn].get_server(key_file, cert_file, timeout)
from mod_python import apache
from sfa.plc.api import SfaAPI
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
api = SfaAPI(interface='aggregate')
except Exception, err:
# Log error in /var/log/httpd/(ssl_)?error_log
- sfa_logger().log_exc('%r'%err)
+ logger.log_exc('%r'%err)
return apache.HTTP_INTERNAL_SERVER_ERROR
from mod_python import apache
from sfa.plc.api import SfaAPI
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
api = SfaAPI(interface='registry')
except Exception, err:
# Log error in /var/log/httpd/(ssl_)?error_log
- sfa_logger().log_exc('%r'%err)
+ logger.log_exc('%r'%err)
return apache.HTTP_INTERNAL_SERVER_ERROR
from mod_python import apache
from sfa.plc.api import SfaAPI
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
api = SfaAPI(interface='slicemgr')
except Exception, err:
# Log error in /var/log/httpd/(ssl_)?error_log
- sfa_logger().log_exc('%r'%err)
+ logger.log_exc('%r'%err)
return apache.HTTP_INTERNAL_SERVER_ERROR
#
# Registry is a SfaServer that implements the Registry interface
#
-### $Id$
-### $URL$
-#
-
from sfa.util.server import SfaServer
from sfa.util.faults import *
from sfa.util.xrn import hrn_to_urn
-from sfa.server.interface import Interfaces
-import sfa.util.xmlrpcprotocol as xmlrpcprotocol
-import sfa.util.soapprotocol as soapprotocol
-
+from sfa.server.interface import Interfaces, Interface
+from sfa.util.config import Config
##
# Registry is a SfaServer that serves registry and slice operations at PLC.
default_dict = {'registries': {'registry': [Interfaces.default_fields]}}
- def __init__(self, api, conf_file = "/etc/sfa/registries.xml"):
- Interfaces.__init__(self, api, conf_file)
- address = self.api.config.SFA_REGISTRY_HOST
- port = self.api.config.SFA_REGISTRY_PORT
- url = 'http://%(address)s:%(port)s' % locals()
- local_registry = {'hrn': self.api.hrn,
- 'urn': hrn_to_urn(self.api.hrn, 'authority'),
- 'addr': address,
- 'port': port,
- 'url': url}
- self.interfaces[self.api.hrn] = local_registry
-
- # get connections
- self.update(self.get_connections())
+ def __init__(self, conf_file = "/etc/sfa/registries.xml"):
+ Interfaces.__init__(self, conf_file)
+ sfa_config = Config()
+ if sfa_config.SFA_REGISTRY_ENABLED:
+ addr = sfa_config.SFA_REGISTRY_HOST
+ port = sfa_config.SFA_REGISTRY_PORT
+ hrn = sfa_config.SFA_INTERFACE_HRN
+ interface = Interface(hrn, addr, port)
+ self[hrn] = interface
#!/usr/bin/python
#
-# SFA PLC Wrapper
+# PlanetLab SFA implementation
#
-# This wrapper implements the SFA Registry and Slice Interfaces on PLC.
+# This implements the SFA Registry and Slice Interfaces on PLC.
# Depending on command line options, it starts some combination of a
# Registry, an Aggregate Manager, and a Slice Manager.
#
-# There are several items that need to be done before starting the wrapper
-# server.
+# There are several items that need to be done before starting the servers.
#
# NOTE: Many configuration settings, including the PLC maintenance account
# credentials, URI of the PLCAPI, and PLC DB URI and admin credentials are initialized
import os, os.path
import traceback
import sys
+import sfa.util.xmlrpcprotocol as xmlrpcprotocol
from optparse import OptionParser
-from sfa.util.sfalogging import sfa_logger
-from sfa.trust.trustedroot import TrustedRootList
+from sfa.util.sfalogging import logger
from sfa.trust.certificate import Keypair, Certificate
from sfa.trust.hierarchy import Hierarchy
from sfa.trust.gid import GID
from sfa.plc.api import SfaAPI
from sfa.server.registry import Registries
from sfa.server.aggregate import Aggregates
+from sfa.util.xrn import get_authority, hrn_to_urn
+
+from sfa.managers.import_manager import import_manager
# after http://www.erlenstar.demon.co.uk/unix/faq_2.html
def daemon():
if not os.path.exists(key_file):
# if it doesnt exist then this is probably a fresh interface
# with no records. Generate a random keypair for now
- sfa_logger().debug("server's public key not found in %s" % key_file)
- sfa_logger().debug("generating a random server key pair")
+ logger.debug("server's public key not found in %s" % key_file)
+
+ logger.debug("generating a random server key pair")
key = Keypair(create=True)
key.save_to_file(server_key_file)
init_server_cert(hrn, key, server_cert_file, self_signed=True)
else:
try:
# look for gid file
- sfa_logger().debug("generating server cert from gid: %s"% hrn)
+ logger.debug("generating server cert from gid: %s"% hrn)
hierarchy = Hierarchy()
auth_info = hierarchy.get_auth_info(hrn)
gid = GID(filename=auth_info.gid_filename)
gid.save_to_file(filename=server_cert_file)
except:
# fall back to self signed cert
- sfa_logger().debug("gid for %s not found" % hrn)
+ logger.debug("gid for %s not found" % hrn)
init_self_signed_cert(hrn, key, server_cert_file)
def init_self_signed_cert(hrn, key, server_cert_file):
- sfa_logger().debug("generating self signed cert")
+ logger.debug("generating self signed cert")
# generate self signed certificate
cert = Certificate(subject=hrn)
cert.set_issuer(key=key, subject=hrn)
def init_server(options, config):
"""
- Execute the init method defined in the manager file
+ Locate the manager based on config.*TYPE
+ Execute the init_server method (well in fact function, sigh) if defined in that module
+ In order to migrate to a more generic approach:
+ * search for <>_manager_<type>.py
+ * if not found, try <>_manager.py (and issue a warning if <type>!='pl')
"""
- def init_manager(manager_module, manager_base):
- try: manager = __import__(manager_module, fromlist=[manager_base])
- except: manager = None
- if manager and hasattr(manager, 'init_server'):
- manager.init_server()
-
- manager_base = 'sfa.managers'
if options.registry:
- mgr_type = config.SFA_REGISTRY_TYPE
- manager_module = manager_base + ".registry_manager_%s" % mgr_type
- init_manager(manager_module, manager_base)
- if options.am:
- mgr_type = config.SFA_AGGREGATE_TYPE
- manager_module = manager_base + ".aggregate_manager_%s" % mgr_type
- init_manager(manager_module, manager_base)
+        manager = import_manager("registry", config.SFA_REGISTRY_TYPE)
+ if manager and hasattr(manager, 'init_server'): manager.init_server()
+ if options.am:
+        manager = import_manager("aggregate", config.SFA_AGGREGATE_TYPE)
+ if manager and hasattr(manager, 'init_server'): manager.init_server()
if options.sm:
- mgr_type = config.SFA_SM_TYPE
- manager_module = manager_base + ".slice_manager_%s" % mgr_type
- init_manager(manager_module, manager_base)
+        manager = import_manager("slice", config.SFA_SM_TYPE)
+ if manager and hasattr(manager, 'init_server'): manager.init_server()
if options.cm:
- mgr_type = config.SFA_CM_TYPE
- manager_module = manager_base + ".component_manager_%s" % mgr_type
- init_manager(manager_module, manager_base)
+        manager = import_manager("component", config.SFA_CM_TYPE)
+ if manager and hasattr(manager, 'init_server'): manager.init_server()
+
-def sync_interfaces(server_key_file, server_cert_file):
+def install_peer_certs(server_key_file, server_cert_file):
"""
Attempt to install missing trusted gids and db records for
our federated interfaces
"""
+ # Attempt to get any missing peer gids
+ # There should be a gid file in /etc/sfa/trusted_roots for every
+ # peer registry found in in the registries.xml config file. If there
+ # are any missing gids, request a new one from the peer registry.
api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file)
- registries = Registries(api)
- aggregates = Aggregates(api)
- registries.sync_interfaces()
- aggregates.sync_interfaces()
+ registries = Registries()
+ aggregates = Aggregates()
+ interfaces = dict(registries.items() + aggregates.items())
+ gids_current = api.auth.trusted_cert_list
+ hrns_current = [gid.get_hrn() for gid in gids_current]
+ hrns_expected = set([hrn for hrn in interfaces])
+ new_hrns = set(hrns_expected).difference(hrns_current)
+ #gids = self.get_peer_gids(new_hrns) + gids_current
+ peer_gids = []
+ if not new_hrns:
+ return
+
+ trusted_certs_dir = api.config.get_trustedroots_dir()
+ for new_hrn in new_hrns:
+ if not new_hrn: continue
+ # the gid for this interface should already be installed
+ if new_hrn == api.config.SFA_INTERFACE_HRN: continue
+ try:
+ # get gid from the registry
+ url = interfaces[new_hrn].get_url()
+ interface = interfaces[new_hrn].get_server(server_key_file, server_cert_file, timeout=30)
+ # skip non sfa aggregates
+ server_version = api.get_cached_server_version(interface)
+ if 'sfa' not in server_version:
+ logger.info("get_trusted_certs: skipping non sfa aggregate: %s" % new_hrn)
+ continue
+
+ trusted_gids = interface.get_trusted_certs()
+ if trusted_gids:
+ # the gid we want should be the first one in the list,
+ # but lets make sure
+ for trusted_gid in trusted_gids:
+ # default message
+ message = "interface: %s\t" % (api.interface)
+ message += "unable to install trusted gid for %s" % \
+ (new_hrn)
+ gid = GID(string=trusted_gids[0])
+ peer_gids.append(gid)
+ if gid.get_hrn() == new_hrn:
+ gid_filename = os.path.join(trusted_certs_dir, '%s.gid' % new_hrn)
+ gid.save_to_file(gid_filename, save_parents=True)
+ message = "installed trusted cert for %s" % new_hrn
+ # log the message
+ api.logger.info(message)
+ except:
+ message = "interface: %s\tunable to install trusted gid for %s" % \
+ (api.interface, new_hrn)
+ api.logger.log_exc(message)
+    # doesn't matter which one
+ update_cert_records(peer_gids)
+
+def update_cert_records(gids):
+ """
+ Make sure there is a record in the registry for the specified gids.
+ Removes old records from the db.
+ """
+ # import SfaTable here so this module can be loaded by ComponentAPI
+ from sfa.util.table import SfaTable
+ from sfa.util.record import SfaRecord
+ if not gids:
+ return
+ table = SfaTable()
+ # get records that actually exist in the db
+ gid_urns = [gid.get_urn() for gid in gids]
+ hrns_expected = [gid.get_hrn() for gid in gids]
+ records_found = table.find({'hrn': hrns_expected, 'pointer': -1})
+ # remove old records
+ for record in records_found:
+ if record['hrn'] not in hrns_expected and \
+            record['hrn'] != Config().SFA_INTERFACE_HRN:
+ table.remove(record)
+
+ # TODO: store urn in the db so we do this in 1 query
+ for gid in gids:
+ hrn, type = gid.get_hrn(), gid.get_type()
+ record = table.find({'hrn': hrn, 'type': type, 'pointer': -1})
+ if not record:
+ record = {
+ 'hrn': hrn, 'type': type, 'pointer': -1,
+ 'authority': get_authority(hrn),
+ 'gid': gid.save_to_string(save_parents=True),
+ }
+ record = SfaRecord(dict=record)
+ table.insert(record)
+
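`install_peer_certs` above decides which gids to fetch by set-differencing the peers configured in registries.xml/aggregates.xml against the trusted roots already installed, skipping the local interface. The core selection can be sketched as (function name is illustrative):

```python
def missing_peer_hrns(configured_hrns, installed_hrns, local_hrn):
    # Peers listed in the config files that have no trusted gid installed
    # yet under /etc/sfa/trusted_roots; the local interface hrn is excluded
    # because its gid is installed at setup time.
    return set(configured_hrns) - set(installed_hrns) - {local_hrn}
```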
def main():
# Generate command line parser
parser = OptionParser(usage="sfa-server [options]")
help="run aggregate manager", default=False)
parser.add_option("-c", "--component", dest="cm", action="store_true",
help="run component server", default=False)
+ parser.add_option("-t", "--trusted-certs", dest="trusted_certs", action="store_true",
+ help="refresh trusted certs", default=False)
parser.add_option("-v", "--verbose", action="count", dest="verbose", default=0,
help="verbose mode - cumulative")
parser.add_option("-d", "--daemon", dest="daemon", action="store_true",
help="Run as daemon.", default=False)
(options, args) = parser.parse_args()
- sfa_logger().setLevelFromOptVerbose(options.verbose)
-
+    logger.setLevelFromOptVerbose(options.verbose)
+
config = Config()
- if config.SFA_API_DEBUG: sfa_logger().setLevelDebug()
+    if config.SFA_API_DEBUG: logger.setLevelDebug()
hierarchy = Hierarchy()
server_key_file = os.path.join(hierarchy.basedir, "server.key")
server_cert_file = os.path.join(hierarchy.basedir, "server.cert")
init_server_key(server_key_file, server_cert_file, config, hierarchy)
init_server(options, config)
- sync_interfaces(server_key_file, server_cert_file)
if (options.daemon): daemon()
+
+ if options.trusted_certs:
+ install_peer_certs(server_key_file, server_cert_file)
+
# start registry server
if (options.registry):
from sfa.server.registry import Registry
try:
main()
except:
- sfa_logger().log_exc_critical("SFA server is exiting")
+ logger.log_exc_critical("SFA server is exiting")
-### $Id$
-### $URL$
-
import os
import sys
import datetime
from sfa.trust.certificate import Keypair, Certificate
from sfa.trust.credential import Credential
-from sfa.trust.trustedroot import TrustedRootList
+from sfa.trust.trustedroots import TrustedRoots
from sfa.util.faults import *
from sfa.trust.hierarchy import Hierarchy
from sfa.util.config import *
from sfa.util.xrn import get_authority
from sfa.util.sfaticket import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
class Auth:
"""
self.load_trusted_certs()
def load_trusted_certs(self):
- self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
- self.trusted_cert_file_list = TrustedRootList(self.config.get_trustedroots_dir()).get_file_list()
+ self.trusted_cert_list = TrustedRoots(self.config.get_trustedroots_dir()).get_list()
+ self.trusted_cert_file_list = TrustedRoots(self.config.get_trustedroots_dir()).get_file_list()
valid = []
if not isinstance(creds, list):
creds = [creds]
- sfa_logger().debug("Auth.checkCredentials with %d creds"%len(creds))
+ logger.debug("Auth.checkCredentials with %d creds"%len(creds))
for cred in creds:
try:
self.check(cred, operation, hrn)
valid.append(cred)
except:
cred_obj=Credential(string=cred)
- sfa_logger().debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
+ logger.debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
error = sys.exc_info()[:2]
continue
def get_authority(self, hrn):
return get_authority(hrn)
- def filter_creds_by_caller(self, creds, caller_hrn):
+ def filter_creds_by_caller(self, creds, caller_hrn_list):
"""
        Returns a list of creds whose gid caller matches one of the
        specified caller hrns
if not isinstance(creds, list):
creds = [creds]
-        creds = []
+        caller_creds = []
+        if not isinstance(caller_hrn_list, list):
+            caller_hrn_list = [caller_hrn_list]
        for cred in creds:
            try:
                tmp_cred = Credential(string=cred)
-                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
-                    creds.append(cred)
+                if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list:
+                    caller_creds.append(cred)
            except: pass
-        return creds
+        return caller_creds
-#----------------------------------------------------------------------
-# Copyright (c) 2008 Board of Trustees, Princeton University
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and/or hardware specification (the "Work") to
-# deal in the Work without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Work, and to permit persons to whom the Work
-# is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Work.
-#
-# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
-# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS
-# IN THE WORK.
-#----------------------------------------------------------------------
-
-##
-# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
-# the necessary crypto functionality. Ideally just one of these libraries
-# would be used, but unfortunately each of these libraries is independently
-# lacking. The pyOpenSSL library is missing many necessary functions, and
-# the M2Crypto library has crashed inside of some of the functions. The
-# design decision is to use pyOpenSSL whenever possible as it seems more
-# stable, and only use M2Crypto for those functions that are not possible
-# in pyOpenSSL.
-#
-# This module exports two classes: Keypair and Certificate.
-##
-#
-
-import functools
-import os
-import tempfile
-import base64
-import traceback
-from tempfile import mkstemp
-
-from OpenSSL import crypto
-import M2Crypto
-from M2Crypto import X509
-
-from sfa.util.sfalogging import sfa_logger
-from sfa.util.xrn import urn_to_hrn
-from sfa.util.faults import *
-
-glo_passphrase_callback = None
-
-##
-# A global callback msy be implemented for requesting passphrases from the
-# user. The function will be called with three arguments:
-#
-# keypair_obj: the keypair object that is calling the passphrase
-# string: the string containing the private key that's being loaded
-# x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto
-#
-# The callback should return a string containing the passphrase.
-
-def set_passphrase_callback(callback_func):
- global glo_passphrase_callback
-
- glo_passphrase_callback = callback_func
-
-##
-# Sets a fixed passphrase.
-
-def set_passphrase(passphrase):
- set_passphrase_callback( lambda k,s,x: passphrase )
-
-##
-# Check to see if a passphrase works for a particular private key string.
-# Intended to be used by passphrase callbacks for input validation.
-
-def test_passphrase(string, passphrase):
- try:
- crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))
- return True
- except:
- return False
-
-def convert_public_key(key):
- keyconvert_path = "/usr/bin/keyconvert.py"
- if not os.path.isfile(keyconvert_path):
- raise IOError, "Could not find keyconvert in %s" % keyconvert_path
-
- # we can only convert rsa keys
- if "ssh-dss" in key:
- return None
-
- (ssh_f, ssh_fn) = tempfile.mkstemp()
- ssl_fn = tempfile.mktemp()
- os.write(ssh_f, key)
- os.close(ssh_f)
-
- cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
- os.system(cmd)
-
- # this check leaves the temporary file containing the public key so
- # that it can be expected to see why it failed.
- # TODO: for production, cleanup the temporary files
- if not os.path.exists(ssl_fn):
- return None
-
- k = Keypair()
- try:
- k.load_pubkey_from_file(ssl_fn)
- except:
- sfa_logger().log_exc("convert_public_key caught exception")
- k = None
-
- # remove the temporary files
- os.remove(ssh_fn)
- os.remove(ssl_fn)
-
- return k
-
-##
-# Public-private key pairs are implemented by the Keypair class.
-# A Keypair object may represent both a public and private key pair, or it
-# may represent only a public key (this usage is consistent with OpenSSL).
-
-class Keypair:
- key = None # public/private keypair
- m2key = None # public key (m2crypto format)
-
- ##
- # Creates a Keypair object
- # @param create If create==True, creates a new public/private key and
- # stores it in the object
- # @param string If string!=None, load the keypair from the string (PEM)
- # @param filename If filename!=None, load the keypair from the file
-
- def __init__(self, create=False, string=None, filename=None):
- if create:
- self.create()
- if string:
- self.load_from_string(string)
- if filename:
- self.load_from_file(filename)
-
- ##
- # Create a RSA public/private key pair and store it inside the keypair object
-
- def create(self):
- self.key = crypto.PKey()
- self.key.generate_key(crypto.TYPE_RSA, 1024)
-
- ##
- # Save the private key to a file
- # @param filename name of file to store the keypair in
-
- def save_to_file(self, filename):
- open(filename, 'w').write(self.as_pem())
- self.filename=filename
-
- ##
- # Load the private key from a file. Implicity the private key includes the public key.
-
- def load_from_file(self, filename):
- self.filename=filename
- buffer = open(filename, 'r').read()
- self.load_from_string(buffer)
-
- ##
- # Load the private key from a string. Implicitly the private key includes the public key.
-
- def load_from_string(self, string):
- if glo_passphrase_callback:
- self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )
- self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )
- else:
- self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
- self.m2key = M2Crypto.EVP.load_key_string(string)
-
- ##
- # Load the public key from a string. No private key is loaded.
-
- def load_pubkey_from_file(self, filename):
- # load the m2 public key
- m2rsakey = M2Crypto.RSA.load_pub_key(filename)
- self.m2key = M2Crypto.EVP.PKey()
- self.m2key.assign_rsa(m2rsakey)
-
- # create an m2 x509 cert
- m2name = M2Crypto.X509.X509_Name()
- m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
- m2x509 = M2Crypto.X509.X509()
- m2x509.set_pubkey(self.m2key)
- m2x509.set_serial_number(0)
- m2x509.set_issuer_name(m2name)
- m2x509.set_subject_name(m2name)
- ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
- ASN1.set_time(500)
- m2x509.set_not_before(ASN1)
- m2x509.set_not_after(ASN1)
- # x509v3 so it can have extensions
- # prob not necc since this cert itself is junk but still...
- m2x509.set_version(2)
- junk_key = Keypair(create=True)
- m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
-
- # convert the m2 x509 cert to a pyopenssl x509
- m2pem = m2x509.as_pem()
- pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
-
- # get the pyopenssl pkey from the pyopenssl x509
- self.key = pyx509.get_pubkey()
- self.filename=filename
-
- ##
- # Load the public key from a string. No private key is loaded.
-
- def load_pubkey_from_string(self, string):
- (f, fn) = tempfile.mkstemp()
- os.write(f, string)
- os.close(f)
- self.load_pubkey_from_file(fn)
- os.remove(fn)
-
- ##
- # Return the private key in PEM format.
-
- def as_pem(self):
- return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
-
- ##
- # Return an M2Crypto key object
-
- def get_m2_pkey(self):
- if not self.m2key:
- self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
- return self.m2key
-
- ##
- # Returns a string containing the public key represented by this object.
-
- def get_pubkey_string(self):
- m2pkey = self.get_m2_pkey()
- return base64.b64encode(m2pkey.as_der())
-
- ##
- # Return an OpenSSL pkey object
-
- def get_openssl_pkey(self):
- return self.key
-
- ##
- # Given another Keypair object, return TRUE if the two keys are the same.
-
- def is_same(self, pkey):
- return self.as_pem() == pkey.as_pem()
-
- def sign_string(self, data):
- k = self.get_m2_pkey()
- k.sign_init()
- k.sign_update(data)
- return base64.b64encode(k.sign_final())
-
- def verify_string(self, data, sig):
- k = self.get_m2_pkey()
- k.verify_init()
- k.verify_update(data)
- return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
-
- def compute_hash(self, value):
- return self.sign_string(str(value))
-
- # only informative
- def get_filename(self):
- return getattr(self,'filename',None)
-
- def dump (self, *args, **kwargs):
- print self.dump_string(*args, **kwargs)
-
- def dump_string (self):
- result=""
- result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()
- filename=self.get_filename()
- if filename: result += "Filename %s\n"%filename
- return result
-
-##
-# The certificate class implements a general purpose X509 certificate, making
-# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
-# several addition features, such as the ability to maintain a chain of
-# parent certificates, and storage of application-specific data.
-#
-# Certificates include the ability to maintain a chain of parents. Each
-# certificate includes a pointer to it's parent certificate. When loaded
-# from a file or a string, the parent chain will be automatically loaded.
-# When saving a certificate to a file or a string, the caller can choose
-# whether to save the parent certificates as well.
-
-class Certificate:
- digest = "md5"
-
- cert = None
- issuerKey = None
- issuerSubject = None
- parent = None
-
- separator="-----parent-----"
-
- ##
- # Create a certificate object.
- #
- # @param create If create==True, then also create a blank X509 certificate.
- # @param subject If subject!=None, then create a blank certificate and set
- # it's subject name.
- # @param string If string!=None, load the certficate from the string.
- # @param filename If filename!=None, load the certficiate from the file.
-
- def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
- self.data = {}
- if create or subject:
- self.create()
- if subject:
- self.set_subject(subject)
- if string:
- self.load_from_string(string)
- if filename:
- self.load_from_file(filename)
-
- if intermediate:
- self.set_intermediate_ca(intermediate)
-
- ##
- # Create a blank X509 certificate and store it in this object.
-
- def create(self):
- self.cert = crypto.X509()
- self.cert.set_serial_number(3)
- self.cert.gmtime_adj_notBefore(0)
- self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
- self.cert.set_version(2) # x509v3 so it can have extensions
-
-
- ##
- # Given a pyOpenSSL X509 object, store that object inside of this
- # certificate object.
-
- def load_from_pyopenssl_x509(self, x509):
- self.cert = x509
-
- ##
- # Load the certificate from a string
-
- def load_from_string(self, string):
- # if it is a chain of multiple certs, then split off the first one and
- # load it (support for the ---parent--- tag as well as normal chained certs)
-
- string = string.strip()
-
- # If it's not in proper PEM format, wrap it
- if string.count('-----BEGIN CERTIFICATE') == 0:
- string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
-
- # If there is a PEM cert in there, but there is some other text first
- # such as the text of the certificate, skip the text
- beg = string.find('-----BEGIN CERTIFICATE')
- if beg > 0:
- # skipping over non cert beginning
- string = string[beg:]
-
- parts = []
-
- if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
- string.count(Certificate.separator) == 0:
- parts = string.split('-----END CERTIFICATE-----',1)
- parts[0] += '-----END CERTIFICATE-----'
- else:
- parts = string.split(Certificate.separator, 1)
-
- self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
-
- # if there are more certs, then create a parent and let the parent load
- # itself from the remainder of the string
- if len(parts) > 1 and parts[1] != '':
- self.parent = self.__class__()
- self.parent.load_from_string(parts[1])
-
- ##
- # Load the certificate from a file
-
- def load_from_file(self, filename):
- file = open(filename)
- string = file.read()
- self.load_from_string(string)
- self.filename=filename
-
- ##
- # Save the certificate to a string.
- #
- # @param save_parents If save_parents==True, then also save the parent certificates.
-
- def save_to_string(self, save_parents=True):
- string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
- if save_parents and self.parent:
- string = string + self.parent.save_to_string(save_parents)
- return string
-
- ##
- # Save the certificate to a file.
- # @param save_parents If save_parents==True, then also save the parent certificates.
-
- def save_to_file(self, filename, save_parents=True, filep=None):
- string = self.save_to_string(save_parents=save_parents)
- if filep:
- f = filep
- else:
- f = open(filename, 'w')
- f.write(string)
- f.close()
- self.filename=filename
-
- ##
- # Save the certificate to a random file in /tmp/
- # @param save_parents If save_parents==True, then also save the parent certificates.
- def save_to_random_tmp_file(self, save_parents=True):
- fp, filename = mkstemp(suffix='cert', text=True)
- fp = os.fdopen(fp, "w")
- self.save_to_file(filename, save_parents=True, filep=fp)
- return filename
-
- ##
- # Sets the issuer private key and name
- # @param key Keypair object containing the private key of the issuer
- # @param subject String containing the name of the issuer
- # @param cert (optional) Certificate object containing the name of the issuer
-
- def set_issuer(self, key, subject=None, cert=None):
- self.issuerKey = key
- if subject:
- # it's a mistake to use subject and cert params at the same time
- assert(not cert)
- if isinstance(subject, dict) or isinstance(subject, str):
- req = crypto.X509Req()
- reqSubject = req.get_subject()
- if (isinstance(subject, dict)):
- for key in reqSubject.keys():
- setattr(reqSubject, key, subject[key])
- else:
- setattr(reqSubject, "CN", subject)
- subject = reqSubject
- # subject is not valid once req is out of scope, so save req
- self.issuerReq = req
- if cert:
- # if a cert was supplied, then get the subject from the cert
- subject = cert.cert.get_subject()
- assert(subject)
- self.issuerSubject = subject
-
- ##
- # Get the issuer name
-
- def get_issuer(self, which="CN"):
- x = self.cert.get_issuer()
- return getattr(x, which)
-
- ##
- # Set the subject name of the certificate
-
- def set_subject(self, name):
- req = crypto.X509Req()
- subj = req.get_subject()
- if (isinstance(name, dict)):
- for key in name.keys():
- setattr(subj, key, name[key])
- else:
- setattr(subj, "CN", name)
- self.cert.set_subject(subj)
- ##
- # Get the subject name of the certificate
-
- def get_subject(self, which="CN"):
- x = self.cert.get_subject()
- return getattr(x, which)
-
- ##
- # Get the public key of the certificate.
- #
- # @param key Keypair object containing the public key
-
- def set_pubkey(self, key):
- assert(isinstance(key, Keypair))
- self.cert.set_pubkey(key.get_openssl_pkey())
-
- ##
- # Get the public key of the certificate.
- # It is returned in the form of a Keypair object.
-
- def get_pubkey(self):
- m2x509 = X509.load_cert_string(self.save_to_string())
- pkey = Keypair()
- pkey.key = self.cert.get_pubkey()
- pkey.m2key = m2x509.get_pubkey()
- return pkey
-
- def set_intermediate_ca(self, val):
- self.intermediate = val
- if val:
- self.add_extension('basicConstraints', 1, 'CA:TRUE')
-
-
-
- ##
- # Add an X509 extension to the certificate. Add_extension can only be called
- # once for a particular extension name, due to limitations in the underlying
- # library.
- #
- # @param name string containing name of extension
- # @param value string containing value of the extension
-
- def add_extension(self, name, critical, value):
- ext = crypto.X509Extension (name, critical, value)
- self.cert.add_extensions([ext])
-
- ##
- # Get an X509 extension from the certificate
-
- def get_extension(self, name):
-
- # pyOpenSSL does not have a way to get extensions
- m2x509 = X509.load_cert_string(self.save_to_string())
- value = m2x509.get_ext(name).get_value()
-
- return value
-
- ##
- # Set_data is a wrapper around add_extension. It stores the parameter str in
- # the X509 subject_alt_name extension. Set_data can only be called once, due
- # to limitations in the underlying library.
-
- def set_data(self, str, field='subjectAltName'):
- # pyOpenSSL only allows us to add extensions, so if we try to set the
- # same extension more than once, it will not work
- if self.data.has_key(field):
- raise "Cannot set ", field, " more than once"
- self.data[field] = str
- self.add_extension(field, 0, str)
-
- ##
- # Return the data string that was previously set with set_data
-
- def get_data(self, field='subjectAltName'):
- if self.data.has_key(field):
- return self.data[field]
-
- try:
- uri = self.get_extension(field)
- self.data[field] = uri
- except LookupError:
- return None
-
- return self.data[field]
-
- ##
- # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
-
- def sign(self):
- sfa_logger().debug('certificate.sign')
- assert self.cert != None
- assert self.issuerSubject != None
- assert self.issuerKey != None
- self.cert.set_issuer(self.issuerSubject)
- self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
-
- ##
- # Verify the authenticity of a certificate.
- # @param pkey is a Keypair object representing a public key. If Pkey
- # did not sign the certificate, then an exception will be thrown.
-
- def verify(self, pkey):
- # pyOpenSSL does not have a way to verify signatures
- m2x509 = X509.load_cert_string(self.save_to_string())
- m2pkey = pkey.get_m2_pkey()
- # verify it
- return m2x509.verify(m2pkey)
-
- # XXX alternatively, if openssl has been patched, do the much simpler:
- # try:
- # self.cert.verify(pkey.get_openssl_key())
- # return 1
- # except:
- # return 0
-
- ##
- # Return True if pkey is identical to the public key that is contained in the certificate.
- # @param pkey Keypair object
-
- def is_pubkey(self, pkey):
- return self.get_pubkey().is_same(pkey)
-
- ##
- # Given a certificate cert, verify that this certificate was signed by the
- # public key contained in cert. Throw an exception otherwise.
- #
- # @param cert certificate object
-
- def is_signed_by_cert(self, cert):
- k = cert.get_pubkey()
- result = self.verify(k)
- return result
-
- ##
- # Set the parent certficiate.
- #
- # @param p certificate object.
-
- def set_parent(self, p):
- self.parent = p
-
- ##
- # Return the certificate object of the parent of this certificate.
-
- def get_parent(self):
- return self.parent
-
- ##
- # Verification examines a chain of certificates to ensure that each parent
- # signs the child, and that some certificate in the chain is signed by a
- # trusted certificate.
- #
- # Verification is a basic recursion: <pre>
- # if this_certificate was signed by trusted_certs:
- # return
- # else
- # return verify_chain(parent, trusted_certs)
- # </pre>
- #
- # At each recursion, the parent is tested to ensure that it did sign the
- # child. If a parent did not sign a child, then an exception is thrown. If
- # the bottom of the recursion is reached and the certificate does not match
- # a trusted root, then an exception is thrown.
- #
- # @param Trusted_certs is a list of certificates that are trusted.
- #
-
- def verify_chain(self, trusted_certs = None):
- # Verify a chain of certificates. Each certificate must be signed by
- # the public key contained in it's parent. The chain is recursed
- # until a certificate is found that is signed by a trusted root.
-
- # verify expiration time
- if self.cert.has_expired():
- sfa_logger().debug("verify_chain: NO our certificate has expired")
- raise CertExpired(self.get_subject(), "client cert")
-
- # if this cert is signed by a trusted_cert, then we are set
- for trusted_cert in trusted_certs:
- if self.is_signed_by_cert(trusted_cert):
- # verify expiration of trusted_cert ?
- if not trusted_cert.cert.has_expired():
- sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%(
- self.get_subject(), trusted_cert.get_subject()))
- return trusted_cert
- else:
- sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
- self.get_subject(),trusted_cert.get_subject()))
- raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
-
- # if there is no parent, then no way to verify the chain
- if not self.parent:
- sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
- raise CertMissingParent(self.get_subject())
-
- # if it wasn't signed by the parent...
- if not self.is_signed_by_cert(self.parent):
- sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
- return CertNotSignedByParent(self.get_subject())
-
- # if the parent isn't verified...
- sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_subject(),self.parent.get_subject()))
- self.parent.verify_chain(trusted_certs)
-
- return
-
- ### more introspection
- def get_extensions(self):
- # pyOpenSSL does not have a way to get extensions
- triples=[]
- m2x509 = X509.load_cert_string(self.save_to_string())
- nb_extensions=m2x509.get_ext_count()
- sfa_logger().debug("X509 had %d extensions"%nb_extensions)
- for i in range(nb_extensions):
- ext=m2x509.get_ext_at(i)
- triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
- return triples
-
- def get_data_names(self):
- return self.data.keys()
-
- def get_all_datas (self):
- triples=self.get_extensions()
- for name in self.get_data_names():
- triples.append( (name,self.get_data(name),'data',) )
- return triples
-
- # only informative
- def get_filename(self):
- return getattr(self,'filename',None)
-
- def dump (self, *args, **kwargs):
- print self.dump_string(*args, **kwargs)
-
- def dump_string (self,show_extensions=False):
- result = ""
- result += "CERTIFICATE for %s\n"%self.get_subject()
- result += "Issued by %s\n"%self.get_issuer()
- filename=self.get_filename()
- if filename: result += "Filename %s\n"%filename
- if show_extensions:
- all_datas=self.get_all_datas()
- result += " has %d extensions/data attached"%len(all_datas)
- for (n,v,c) in all_datas:
- if c=='data':
- result += " data: %s=%s\n"%(n,v)
- else:
- result += " ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)
- return result
+#----------------------------------------------------------------------\r
+# Copyright (c) 2008 Board of Trustees, Princeton University\r
+#\r
+# Permission is hereby granted, free of charge, to any person obtaining\r
+# a copy of this software and/or hardware specification (the "Work") to\r
+# deal in the Work without restriction, including without limitation the\r
+# rights to use, copy, modify, merge, publish, distribute, sublicense,\r
+# and/or sell copies of the Work, and to permit persons to whom the Work\r
+# is furnished to do so, subject to the following conditions:\r
+#\r
+# The above copyright notice and this permission notice shall be\r
+# included in all copies or substantial portions of the Work.\r
+#\r
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS \r
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF \r
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND \r
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT \r
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, \r
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, \r
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS \r
+# IN THE WORK.\r
+#----------------------------------------------------------------------\r
+\r
+##\r
+# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement\r
+# the necessary crypto functionality. Ideally just one of these libraries\r
+# would be used, but unfortunately each of these libraries is independently\r
+# lacking. The pyOpenSSL library is missing many necessary functions, and\r
+# the M2Crypto library has crashed inside of some of the functions. The\r
+# design decision is to use pyOpenSSL whenever possible as it seems more\r
+# stable, and only use M2Crypto for those functions that are not possible\r
+# in pyOpenSSL.\r
+#\r
+# This module exports two classes: Keypair and Certificate.\r
+##\r
+#\r
+\r
+import functools\r
+import os\r
+import tempfile\r
+import base64\r
+import traceback\r
+from tempfile import mkstemp\r
+\r
+from OpenSSL import crypto\r
+import M2Crypto\r
+from M2Crypto import X509\r
+\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.xrn import urn_to_hrn\r
+from sfa.util.faults import *\r
+\r
+glo_passphrase_callback = None\r
+\r
+##\r
+# A global callback may be implemented for requesting passphrases from the\r
+# user. The function will be called with three arguments:\r
+#\r
+# keypair_obj: the keypair object that is calling the passphrase\r
+# string: the string containing the private key that's being loaded\r
+# x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto\r
+#\r
+# The callback should return a string containing the passphrase.\r
+\r
+def set_passphrase_callback(callback_func):\r
+ global glo_passphrase_callback\r
+\r
+ glo_passphrase_callback = callback_func\r
+\r
+##\r
+# Sets a fixed passphrase.\r
+\r
+def set_passphrase(passphrase):\r
+ set_passphrase_callback( lambda k,s,x: passphrase )\r
+\r
+##\r
+# Check to see if a passphrase works for a particular private key string.\r
+# Intended to be used by passphrase callbacks for input validation.\r
+\r
+def test_passphrase(string, passphrase):\r
+ try:\r
+ crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))\r
+ return True\r
+ except:\r
+ return False\r
+\r
+def convert_public_key(key):\r
+ keyconvert_path = "/usr/bin/keyconvert.py"\r
+ if not os.path.isfile(keyconvert_path):\r
+ raise IOError, "Could not find keyconvert in %s" % keyconvert_path\r
+\r
+ # we can only convert rsa keys\r
+ if "ssh-dss" in key:\r
+ return None\r
+\r
+ (ssh_f, ssh_fn) = tempfile.mkstemp()\r
+ ssl_fn = tempfile.mktemp()\r
+ os.write(ssh_f, key)\r
+ os.close(ssh_f)\r
+\r
+ cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn\r
+ os.system(cmd)\r
+\r
+    # this check leaves the temporary file containing the public key so\r
+    # that it can be inspected to see why the conversion failed.\r
+ # TODO: for production, cleanup the temporary files\r
+ if not os.path.exists(ssl_fn):\r
+ return None\r
+\r
+ k = Keypair()\r
+ try:\r
+ k.load_pubkey_from_file(ssl_fn)\r
+ except:\r
+ logger.log_exc("convert_public_key caught exception")\r
+ k = None\r
+\r
+ # remove the temporary files\r
+ os.remove(ssh_fn)\r
+ os.remove(ssl_fn)\r
+\r
+ return k\r
+\r
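`convert_public_key` shells out to an external converter through temp files. A safer sketch of the same flow, using `subprocess` instead of `os.system` (so failures raise instead of being silently ignored) and guaranteed cleanup; the `converter` argv is a hypothetical stand-in for the keyconvert command:

```python
import os
import subprocess
import tempfile

def run_converter(converter, pubkey_text):
    """Sketch of convert_public_key's temp-file dance. `converter` is any
    argv-style command taking (input_file, output_file); the real module
    uses /usr/bin/keyconvert.py. Temp files are always cleaned up."""
    ssh_fd, ssh_fn = tempfile.mkstemp()
    out_fn = ssh_fn + ".out"  # derived name avoids tempfile.mktemp's race
    try:
        os.write(ssh_fd, pubkey_text.encode())
        os.close(ssh_fd)
        subprocess.check_call(converter + [ssh_fn, out_fn])
        with open(out_fn) as f:
            return f.read()
    finally:
        for fn in (ssh_fn, out_fn):
            if os.path.exists(fn):
                os.remove(fn)
```

With `converter=["cp"]` this simply round-trips the text, which is enough to exercise the plumbing.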
+##\r
+# Public-private key pairs are implemented by the Keypair class.\r
+# A Keypair object may represent both a public and private key pair, or it\r
+# may represent only a public key (this usage is consistent with OpenSSL).\r
+\r
+class Keypair:\r
+ key = None # public/private keypair\r
+ m2key = None # public key (m2crypto format)\r
+\r
+ ##\r
+ # Creates a Keypair object\r
+ # @param create If create==True, creates a new public/private key and\r
+ # stores it in the object\r
+ # @param string If string!=None, load the keypair from the string (PEM)\r
+ # @param filename If filename!=None, load the keypair from the file\r
+\r
+ def __init__(self, create=False, string=None, filename=None):\r
+ if create:\r
+ self.create()\r
+ if string:\r
+ self.load_from_string(string)\r
+ if filename:\r
+ self.load_from_file(filename)\r
+\r
+ ##\r
+    # Create an RSA public/private key pair and store it inside the keypair object\r
+\r
+ def create(self):\r
+ self.key = crypto.PKey()\r
+ self.key.generate_key(crypto.TYPE_RSA, 1024)\r
+\r
+ ##\r
+ # Save the private key to a file\r
+ # @param filename name of file to store the keypair in\r
+\r
+ def save_to_file(self, filename):\r
+ open(filename, 'w').write(self.as_pem())\r
+ self.filename=filename\r
+\r
+ ##\r
+    # Load the private key from a file. Implicitly, the private key includes the public key.\r
+\r
+ def load_from_file(self, filename):\r
+ self.filename=filename\r
+ buffer = open(filename, 'r').read()\r
+ self.load_from_string(buffer)\r
+\r
+ ##\r
+ # Load the private key from a string. Implicitly the private key includes the public key.\r
+\r
+ def load_from_string(self, string):\r
+ if glo_passphrase_callback:\r
+ self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )\r
+ self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )\r
+ else:\r
+ self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)\r
+ self.m2key = M2Crypto.EVP.load_key_string(string)\r
+\r
+ ##\r
+    # Load the public key from a file. No private key is loaded.\r
+\r
+ def load_pubkey_from_file(self, filename):\r
+ # load the m2 public key\r
+ m2rsakey = M2Crypto.RSA.load_pub_key(filename)\r
+ self.m2key = M2Crypto.EVP.PKey()\r
+ self.m2key.assign_rsa(m2rsakey)\r
+\r
+ # create an m2 x509 cert\r
+ m2name = M2Crypto.X509.X509_Name()\r
+ m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)\r
+ m2x509 = M2Crypto.X509.X509()\r
+ m2x509.set_pubkey(self.m2key)\r
+ m2x509.set_serial_number(0)\r
+ m2x509.set_issuer_name(m2name)\r
+ m2x509.set_subject_name(m2name)\r
+ ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()\r
+ ASN1.set_time(500)\r
+ m2x509.set_not_before(ASN1)\r
+ m2x509.set_not_after(ASN1)\r
+ # x509v3 so it can have extensions\r
+ # prob not necc since this cert itself is junk but still...\r
+ m2x509.set_version(2)\r
+ junk_key = Keypair(create=True)\r
+ m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")\r
+\r
+ # convert the m2 x509 cert to a pyopenssl x509\r
+ m2pem = m2x509.as_pem()\r
+ pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)\r
+\r
+ # get the pyopenssl pkey from the pyopenssl x509\r
+ self.key = pyx509.get_pubkey()\r
+ self.filename=filename\r
+\r
+ ##\r
+ # Load the public key from a string. No private key is loaded.\r
+\r
+ def load_pubkey_from_string(self, string):\r
+ (f, fn) = tempfile.mkstemp()\r
+ os.write(f, string)\r
+ os.close(f)\r
+ self.load_pubkey_from_file(fn)\r
+ os.remove(fn)\r
+\r
+ ##\r
+ # Return the private key in PEM format.\r
+\r
+ def as_pem(self):\r
+ return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)\r
+\r
+ ##\r
+ # Return an M2Crypto key object\r
+\r
+ def get_m2_pkey(self):\r
+ if not self.m2key:\r
+ self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())\r
+ return self.m2key\r
+\r
+ ##\r
+ # Returns a string containing the public key represented by this object.\r
+\r
+ def get_pubkey_string(self):\r
+ m2pkey = self.get_m2_pkey()\r
+ return base64.b64encode(m2pkey.as_der())\r
+\r
+ ##\r
+ # Return an OpenSSL pkey object\r
+\r
+ def get_openssl_pkey(self):\r
+ return self.key\r
+\r
+ ##\r
+    # Given another Keypair object, return True if the two keys are the same.\r
+\r
+ def is_same(self, pkey):\r
+ return self.as_pem() == pkey.as_pem()\r
+\r
+ def sign_string(self, data):\r
+ k = self.get_m2_pkey()\r
+ k.sign_init()\r
+ k.sign_update(data)\r
+ return base64.b64encode(k.sign_final())\r
+\r
+ def verify_string(self, data, sig):\r
+ k = self.get_m2_pkey()\r
+ k.verify_init()\r
+ k.verify_update(data)\r
+ return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)\r
+\r
+ def compute_hash(self, value):\r
+ return self.sign_string(str(value))\r
+\r
+ # only informative\r
+ def get_filename(self):\r
+ return getattr(self,'filename',None)\r
+\r
+ def dump (self, *args, **kwargs):\r
+ print self.dump_string(*args, **kwargs)\r
+\r
+ def dump_string (self):\r
+ result=""\r
+ result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()\r
+ filename=self.get_filename()\r
+ if filename: result += "Filename %s\n"%filename\r
+ return result\r
+\r
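`sign_string` and `verify_string` above transport a detached RSA signature as base64. The round-trip pattern can be shown without M2Crypto installed by substituting a keyed digest for the RSA operation; this is a toy stand-in for illustration only, not the module's actual signing:

```python
import base64
import hashlib

# Toy stand-in for Keypair.sign_string/verify_string: the real code signs
# with the RSA private key via M2Crypto's EVP interface. A keyed SHA-1
# digest stands in here so the base64 encode/decode transport can run.
SECRET = b"demo-key"  # hypothetical shared secret for the toy

def sign_string(data):
    digest = hashlib.sha1(SECRET + data).digest()
    return base64.b64encode(digest)          # signature travels as base64

def verify_string(data, sig):
    return base64.b64decode(sig) == hashlib.sha1(SECRET + data).digest()

sig = sign_string(b"hello")
assert verify_string(b"hello", sig)
assert not verify_string(b"tampered", sig)
```

The real `verify_string` additionally drops to `M2Crypto.m2.verify_final` because the high-level M2Crypto wrapper lacks a suitable call.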
+##\r
+# The certificate class implements a general purpose X509 certificate, making\r
+# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds\r
+# several additional features, such as the ability to maintain a chain of\r
+# parent certificates, and storage of application-specific data.\r
+#\r
+# Certificates include the ability to maintain a chain of parents. Each\r
+# certificate includes a pointer to its parent certificate. When loaded\r
+# from a file or a string, the parent chain will be automatically loaded.\r
+# When saving a certificate to a file or a string, the caller can choose\r
+# whether to save the parent certificates as well.\r
+\r
+class Certificate:\r
+ digest = "md5"\r
+\r
+ cert = None\r
+ issuerKey = None\r
+ issuerSubject = None\r
+ parent = None\r
+ isCA = None # will be a boolean once set\r
+\r
+ separator="-----parent-----"\r
+\r
+ ##\r
+ # Create a certificate object.\r
+ #\r
+ # @param lifeDays life of cert in days - default is 1825==5 years\r
+ # @param create If create==True, then also create a blank X509 certificate.\r
+ # @param subject If subject!=None, then create a blank certificate and set\r
+    # its subject name.\r
+    # @param string If string!=None, load the certificate from the string.\r
+    # @param filename If filename!=None, load the certificate from the file.\r
+ # @param isCA If !=None, set whether this cert is for a CA\r
+\r
+ def __init__(self, lifeDays=1825, create=False, subject=None, string=None, filename=None, isCA=None):\r
+ self.data = {}\r
+ if create or subject:\r
+ self.create(lifeDays)\r
+ if subject:\r
+ self.set_subject(subject)\r
+ if string:\r
+ self.load_from_string(string)\r
+ if filename:\r
+ self.load_from_file(filename)\r
+\r
+ # Set the CA bit if a value was supplied\r
+ if isCA != None:\r
+ self.set_is_ca(isCA)\r
+\r
+    ##\r
+    # Create a blank X509 certificate and store it in this object.\r
+\r
+ def create(self, lifeDays=1825):\r
+ self.cert = crypto.X509()\r
+ # FIXME: Use different serial #s\r
+ self.cert.set_serial_number(3)\r
+ self.cert.gmtime_adj_notBefore(0) # 0 means now\r
+ self.cert.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default\r
+ self.cert.set_version(2) # x509v3 so it can have extensions\r
+\r
+\r
+ ##\r
+ # Given a pyOpenSSL X509 object, store that object inside of this\r
+ # certificate object.\r
+\r
+ def load_from_pyopenssl_x509(self, x509):\r
+ self.cert = x509\r
+\r
+ ##\r
+ # Load the certificate from a string\r
+\r
+ def load_from_string(self, string):\r
+ # if it is a chain of multiple certs, then split off the first one and\r
+ # load it (support for the ---parent--- tag as well as normal chained certs)\r
+\r
+ string = string.strip()\r
+ \r
+ # If it's not in proper PEM format, wrap it\r
+ if string.count('-----BEGIN CERTIFICATE') == 0:\r
+ string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string\r
+\r
+ # If there is a PEM cert in there, but there is some other text first\r
+ # such as the text of the certificate, skip the text\r
+ beg = string.find('-----BEGIN CERTIFICATE')\r
+ if beg > 0:\r
+ # skipping over non cert beginning \r
+ string = string[beg:]\r
+\r
+ parts = []\r
+\r
+ if string.count('-----BEGIN CERTIFICATE-----') > 1 and \\r
+ string.count(Certificate.separator) == 0:\r
+ parts = string.split('-----END CERTIFICATE-----',1)\r
+ parts[0] += '-----END CERTIFICATE-----'\r
+ else:\r
+ parts = string.split(Certificate.separator, 1)\r
+\r
+ self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])\r
+\r
+ # if there are more certs, then create a parent and let the parent load\r
+ # itself from the remainder of the string\r
+ if len(parts) > 1 and parts[1] != '':\r
+ self.parent = self.__class__()\r
+ self.parent.load_from_string(parts[1])\r
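The splitting rule in load_from_string above handles two chain encodings: plain concatenated PEM certificates, and the legacy `-----parent-----` separator. A standalone sketch of just that rule (split_chain is a hypothetical helper, not part of the SFA API):

```python
# Hypothetical standalone sketch of the chain-splitting logic in
# load_from_string: return the first certificate and the remainder
# of the chain, supporting both encodings.
SEPARATOR = "-----parent-----"
END = "-----END CERTIFICATE-----"

def split_chain(pem):
    """Return (first_cert, remainder) for a PEM chain string."""
    pem = pem.strip()
    # Plain concatenated PEM certs: cut after the first END marker.
    if pem.count("-----BEGIN CERTIFICATE-----") > 1 and SEPARATOR not in pem:
        head, rest = pem.split(END, 1)
        return head + END, rest.strip()
    # Legacy encoding: certs joined by the -----parent----- separator.
    parts = pem.split(SEPARATOR, 1)
    rest = parts[1].strip() if len(parts) > 1 else ""
    return parts[0].strip(), rest
```

When the remainder is non-empty, load_from_string recurses by building a parent certificate from it, which is how the whole chain gets reconstructed from one string.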
+\r
+ ##\r
+ # Load the certificate from a file\r
+\r
+ def load_from_file(self, filename):\r
+ file = open(filename)\r
+ string = file.read()\r
+ self.load_from_string(string)\r
+ self.filename=filename\r
+\r
+ ##\r
+ # Save the certificate to a string.\r
+ #\r
+ # @param save_parents If save_parents==True, then also save the parent certificates.\r
+\r
+ def save_to_string(self, save_parents=True):\r
+ string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)\r
+ if save_parents and self.parent:\r
+ string = string + self.parent.save_to_string(save_parents)\r
+ return string\r
+\r
+ ##\r
+ # Save the certificate to a file.\r
+ # @param save_parents If save_parents==True, then also save the parent certificates.\r
+\r
+ def save_to_file(self, filename, save_parents=True, filep=None):\r
+ string = self.save_to_string(save_parents=save_parents)\r
+ if filep:\r
+ f = filep\r
+ else:\r
+ f = open(filename, 'w')\r
+ f.write(string)\r
+ f.close()\r
+ self.filename=filename\r
+\r
+ ##\r
+ # Save the certificate to a random file in /tmp/\r
+ # @param save_parents If save_parents==True, then also save the parent certificates.\r
+ def save_to_random_tmp_file(self, save_parents=True):\r
+ fp, filename = mkstemp(suffix='cert', text=True)\r
+ fp = os.fdopen(fp, "w")\r
+ self.save_to_file(filename, save_parents=True, filep=fp)\r
+ return filename\r
+\r
+ ##\r
+ # Sets the issuer private key and name\r
+ # @param key Keypair object containing the private key of the issuer\r
+ # @param subject String containing the name of the issuer\r
+ # @param cert (optional) Certificate object containing the name of the issuer\r
+\r
+ def set_issuer(self, key, subject=None, cert=None):\r
+ self.issuerKey = key\r
+ if subject:\r
+ # it's a mistake to use subject and cert params at the same time\r
+ assert(not cert)\r
+ if isinstance(subject, dict) or isinstance(subject, str):\r
+ req = crypto.X509Req()\r
+ reqSubject = req.get_subject()\r
+ if (isinstance(subject, dict)):\r
+ for key in reqSubject.keys():\r
+ setattr(reqSubject, key, subject[key])\r
+ else:\r
+ setattr(reqSubject, "CN", subject)\r
+ subject = reqSubject\r
+ # subject is not valid once req is out of scope, so save req\r
+ self.issuerReq = req\r
+ if cert:\r
+ # if a cert was supplied, then get the subject from the cert\r
+ subject = cert.cert.get_subject()\r
+ assert(subject)\r
+ self.issuerSubject = subject\r
+\r
+ ##\r
+ # Get the issuer name\r
+\r
+ def get_issuer(self, which="CN"):\r
+ x = self.cert.get_issuer()\r
+ return getattr(x, which)\r
+\r
+ ##\r
+ # Set the subject name of the certificate\r
+\r
+ def set_subject(self, name):\r
+ req = crypto.X509Req()\r
+ subj = req.get_subject()\r
+ if (isinstance(name, dict)):\r
+ for key in name.keys():\r
+ setattr(subj, key, name[key])\r
+ else:\r
+ setattr(subj, "CN", name)\r
+ self.cert.set_subject(subj)\r
+\r
+ ##\r
+ # Get the subject name of the certificate\r
+\r
+ def get_subject(self, which="CN"):\r
+ x = self.cert.get_subject()\r
+ return getattr(x, which)\r
+\r
+ ##\r
+ # Get a pretty-print subject name of the certificate\r
+\r
+ def get_printable_subject(self):\r
+ x = self.cert.get_subject()\r
+ return "[ OU: %s, CN: %s, SubjectAltName: %s ]" % (getattr(x, "OU"), getattr(x, "CN"), self.get_data())\r
+\r
+ ##\r
+    # Set the public key of the certificate.\r
+ #\r
+ # @param key Keypair object containing the public key\r
+\r
+ def set_pubkey(self, key):\r
+ assert(isinstance(key, Keypair))\r
+ self.cert.set_pubkey(key.get_openssl_pkey())\r
+\r
+ ##\r
+ # Get the public key of the certificate.\r
+ # It is returned in the form of a Keypair object.\r
+\r
+ def get_pubkey(self):\r
+ m2x509 = X509.load_cert_string(self.save_to_string())\r
+ pkey = Keypair()\r
+ pkey.key = self.cert.get_pubkey()\r
+ pkey.m2key = m2x509.get_pubkey()\r
+ return pkey\r
+\r
+ def set_intermediate_ca(self, val):\r
+ return self.set_is_ca(val)\r
+\r
+ # Set whether this cert is for a CA. All signers and only signers should be CAs.\r
+ # The local member starts unset, letting us check that you only set it once\r
+ # @param val Boolean indicating whether this cert is for a CA\r
+ def set_is_ca(self, val):\r
+ if val is None:\r
+ return\r
+\r
+ if self.isCA != None:\r
+ # Can't double set properties\r
+            raise Exception("Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val))\r
+\r
+ self.isCA = val\r
+ if val:\r
+ self.add_extension('basicConstraints', 1, 'CA:TRUE')\r
+ else:\r
+ self.add_extension('basicConstraints', 1, 'CA:FALSE')\r
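The one-shot behavior of set_is_ca above (the CA bit maps to a basicConstraints extension and may only be set once) can be modeled in isolation. OneShotCAFlag is a toy class for illustration, not the real Certificate:

```python
# Toy model (hypothetical) of the set_is_ca rule: the CA bit starts
# unset, maps to a basicConstraints value, and cannot be set twice.
class OneShotCAFlag:
    def __init__(self):
        self.is_ca = None  # unset until the first call

    def set_is_ca(self, val):
        if val is None:
            return None
        if self.is_ca is not None:
            # Mirrors the real class: extensions can't be re-added.
            raise ValueError("basicConstraints already set to %s" % self.is_ca)
        self.is_ca = val
        return "CA:TRUE" if val else "CA:FALSE"
```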
+\r
+\r
+\r
+ ##\r
+ # Add an X509 extension to the certificate. Add_extension can only be called\r
+ # once for a particular extension name, due to limitations in the underlying\r
+ # library.\r
+ #\r
+ # @param name string containing name of extension\r
+ # @param value string containing value of the extension\r
+\r
+ def add_extension(self, name, critical, value):\r
+ oldExtVal = None\r
+ try:\r
+ oldExtVal = self.get_extension(name)\r
+ except:\r
+ # M2Crypto LookupError when the extension isn't there (yet)\r
+ pass\r
+\r
+ # This code limits you from adding the extension with the same value\r
+ # The method comment says you shouldn't do this with the same name\r
+ # But actually it (m2crypto) appears to allow you to do this.\r
+ if oldExtVal and oldExtVal == value:\r
+ # don't add this extension again\r
+ # just do nothing as here\r
+ return\r
+ # FIXME: What if they are trying to set with a different value?\r
+ # Is this ever OK? Or should we raise an exception?\r
+# elif oldExtVal:\r
+# raise "Cannot add extension %s which had val %s with new val %s" % (name, oldExtVal, value)\r
+\r
+ ext = crypto.X509Extension (name, critical, value)\r
+ self.cert.add_extensions([ext])\r
+\r
+ ##\r
+ # Get an X509 extension from the certificate\r
+\r
+ def get_extension(self, name):\r
+\r
+ # pyOpenSSL does not have a way to get extensions\r
+ m2x509 = X509.load_cert_string(self.save_to_string())\r
+ value = m2x509.get_ext(name).get_value()\r
+\r
+ return value\r
+\r
+ ##\r
+ # Set_data is a wrapper around add_extension. It stores the parameter str in\r
+ # the X509 subject_alt_name extension. Set_data can only be called once, due\r
+ # to limitations in the underlying library.\r
+\r
+ def set_data(self, str, field='subjectAltName'):\r
+ # pyOpenSSL only allows us to add extensions, so if we try to set the\r
+ # same extension more than once, it will not work\r
+ if self.data.has_key(field):\r
+            raise Exception("Cannot set %s more than once" % field)\r
+ self.data[field] = str\r
+ self.add_extension(field, 0, str)\r
+\r
+ ##\r
+ # Return the data string that was previously set with set_data\r
+\r
+ def get_data(self, field='subjectAltName'):\r
+ if self.data.has_key(field):\r
+ return self.data[field]\r
+\r
+ try:\r
+ uri = self.get_extension(field)\r
+ self.data[field] = uri\r
+ except LookupError:\r
+ return None\r
+\r
+ return self.data[field]\r
+\r
+ ##\r
+    # Sign the certificate using the issuer private key and issuer subject previously set with set_issuer().\r
+\r
+ def sign(self):\r
+ logger.debug('certificate.sign')\r
+ assert self.cert != None\r
+ assert self.issuerSubject != None\r
+ assert self.issuerKey != None\r
+ self.cert.set_issuer(self.issuerSubject)\r
+ self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)\r
+\r
+ ##\r
+ # Verify the authenticity of a certificate.\r
+    # @param pkey is a Keypair object representing a public key. If pkey\r
+ # did not sign the certificate, then an exception will be thrown.\r
+\r
+ def verify(self, pkey):\r
+ # pyOpenSSL does not have a way to verify signatures\r
+ m2x509 = X509.load_cert_string(self.save_to_string())\r
+ m2pkey = pkey.get_m2_pkey()\r
+ # verify it\r
+ return m2x509.verify(m2pkey)\r
+\r
+ # XXX alternatively, if openssl has been patched, do the much simpler:\r
+ # try:\r
+ # self.cert.verify(pkey.get_openssl_key())\r
+ # return 1\r
+ # except:\r
+ # return 0\r
+\r
+ ##\r
+ # Return True if pkey is identical to the public key that is contained in the certificate.\r
+ # @param pkey Keypair object\r
+\r
+ def is_pubkey(self, pkey):\r
+ return self.get_pubkey().is_same(pkey)\r
+\r
+ ##\r
+ # Given a certificate cert, verify that this certificate was signed by the\r
+ # public key contained in cert. Throw an exception otherwise.\r
+ #\r
+ # @param cert certificate object\r
+\r
+ def is_signed_by_cert(self, cert):\r
+ k = cert.get_pubkey()\r
+ result = self.verify(k)\r
+ return result\r
+\r
+ ##\r
+    # Set the parent certificate.\r
+ #\r
+ # @param p certificate object.\r
+\r
+ def set_parent(self, p):\r
+ self.parent = p\r
+\r
+ ##\r
+ # Return the certificate object of the parent of this certificate.\r
+\r
+ def get_parent(self):\r
+ return self.parent\r
+\r
+ ##\r
+ # Verification examines a chain of certificates to ensure that each parent\r
+ # signs the child, and that some certificate in the chain is signed by a\r
+ # trusted certificate.\r
+ #\r
+ # Verification is a basic recursion: <pre>\r
+ # if this_certificate was signed by trusted_certs:\r
+ # return\r
+ # else\r
+ # return verify_chain(parent, trusted_certs)\r
+ # </pre>\r
+ #\r
+ # At each recursion, the parent is tested to ensure that it did sign the\r
+ # child. If a parent did not sign a child, then an exception is thrown. If\r
+ # the bottom of the recursion is reached and the certificate does not match\r
+ # a trusted root, then an exception is thrown.\r
+ # Also require that parents are CAs.\r
+ #\r
+ # @param Trusted_certs is a list of certificates that are trusted.\r
+ #\r
+\r
+ def verify_chain(self, trusted_certs = None):\r
+ # Verify a chain of certificates. Each certificate must be signed by\r
+        # the public key contained in its parent. The chain is recursed\r
+ # until a certificate is found that is signed by a trusted root.\r
+\r
+ # verify expiration time\r
+ if self.cert.has_expired():\r
+ logger.debug("verify_chain: NO, Certificate %s has expired" % self.get_printable_subject())\r
+ raise CertExpired(self.get_printable_subject(), "client cert")\r
+\r
+ # if this cert is signed by a trusted_cert, then we are set\r
+ for trusted_cert in trusted_certs:\r
+ if self.is_signed_by_cert(trusted_cert):\r
+ # verify expiration of trusted_cert ?\r
+ if not trusted_cert.cert.has_expired():\r
+ logger.debug("verify_chain: YES. Cert %s signed by trusted cert %s"%(\r
+ self.get_printable_subject(), trusted_cert.get_printable_subject()))\r
+ return trusted_cert\r
+ else:\r
+ logger.debug("verify_chain: NO. Cert %s is signed by trusted_cert %s, but that signer is expired..."%(\r
+ self.get_printable_subject(),trusted_cert.get_printable_subject()))\r
+ raise CertExpired(self.get_printable_subject()," signer trusted_cert %s"%trusted_cert.get_printable_subject())\r
+\r
+ # if there is no parent, then no way to verify the chain\r
+ if not self.parent:\r
+ logger.debug("verify_chain: NO. %s has no parent and issuer %s is not in %d trusted roots"%(self.get_printable_subject(), self.get_issuer(), len(trusted_certs)))\r
+ raise CertMissingParent(self.get_printable_subject() + ": Issuer %s not trusted by any of %d trusted roots, and cert has no parent." % (self.get_issuer(), len(trusted_certs)))\r
+\r
+ # if it wasn't signed by the parent...\r
+ if not self.is_signed_by_cert(self.parent):\r
+            logger.debug("verify_chain: NO. %s is not signed by parent %s, but by %s"%(self.get_printable_subject(), self.parent.get_printable_subject(), self.get_issuer()))\r
+ raise CertNotSignedByParent(self.get_printable_subject() + ": Parent %s, issuer %s" % (self.parent.get_printable_subject(), self.get_issuer()))\r
+\r
+ # Confirm that the parent is a CA. Only CAs can be trusted as\r
+ # signers.\r
+ # Note that trusted roots are not parents, so don't need to be\r
+ # CAs.\r
+ # Ugly - cert objects aren't parsed so we need to read the\r
+ # extension and hope there are no other basicConstraints\r
+ if not self.parent.isCA and not (self.parent.get_extension('basicConstraints') == 'CA:TRUE'):\r
+ logger.warn("verify_chain: cert %s's parent %s is not a CA" % (self.get_printable_subject(), self.parent.get_printable_subject()))\r
+ raise CertNotSignedByParent(self.get_printable_subject() + ": Parent %s not a CA" % self.parent.get_printable_subject())\r
+\r
+ # if the parent isn't verified...\r
+ logger.debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_printable_subject(),self.parent.get_printable_subject()))\r
+ self.parent.verify_chain(trusted_certs)\r
+\r
+ return\r
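The recursion in verify_chain above can be sketched with toy objects that track only the facts the checks need: who signed the cert, whether the parent is a CA, and expiry. ToyCert and this verify_chain are illustrative stand-ins, not the SFA classes:

```python
# Pure-Python sketch (hypothetical) of the verify_chain recursion:
# walk up the parent chain until a cert signed by a trusted root is found.
class ToyCert:
    def __init__(self, name, signer=None, is_ca=False, expired=False):
        self.name = name
        self.signer = signer    # the ToyCert whose key signed this one
        self.is_ca = is_ca
        self.expired = expired
        self.parent = signer    # parent chain mirrors the signer here

def verify_chain(cert, trusted):
    if cert.expired:
        raise ValueError("%s has expired" % cert.name)
    # Signed directly by a (non-expired) trusted root: done.
    for t in trusted:
        if cert.signer is t and not t.expired:
            return t
    if cert.parent is None:
        raise ValueError("%s: issuer not trusted and no parent" % cert.name)
    if cert.signer is not cert.parent:
        raise ValueError("%s is not signed by its parent" % cert.name)
    if not cert.parent.is_ca:
        raise ValueError("parent of %s is not a CA" % cert.name)
    # Parent signed the child and is a CA: recurse upward.
    return verify_chain(cert.parent, trusted)
```

Note the asymmetry the real code also has: intermediate parents must be CAs, but trusted roots are matched directly as signers, so they never go through the parent-is-a-CA check.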
+\r
+ ### more introspection\r
+ def get_extensions(self):\r
+ # pyOpenSSL does not have a way to get extensions\r
+ triples=[]\r
+ m2x509 = X509.load_cert_string(self.save_to_string())\r
+ nb_extensions=m2x509.get_ext_count()\r
+ logger.debug("X509 had %d extensions"%nb_extensions)\r
+ for i in range(nb_extensions):\r
+ ext=m2x509.get_ext_at(i)\r
+ triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )\r
+ return triples\r
+\r
+ def get_data_names(self):\r
+ return self.data.keys()\r
+\r
+ def get_all_datas (self):\r
+ triples=self.get_extensions()\r
+ for name in self.get_data_names():\r
+ triples.append( (name,self.get_data(name),'data',) )\r
+ return triples\r
+\r
+ # only informative\r
+ def get_filename(self):\r
+ return getattr(self,'filename',None)\r
+\r
+ def dump (self, *args, **kwargs):\r
+ print self.dump_string(*args, **kwargs)\r
+\r
+ def dump_string (self,show_extensions=False):\r
+ result = ""\r
+ result += "CERTIFICATE for %s\n"%self.get_printable_subject()\r
+ result += "Issued by %s\n"%self.get_issuer()\r
+ filename=self.get_filename()\r
+ if filename: result += "Filename %s\n"%filename\r
+ if show_extensions:\r
+ all_datas=self.get_all_datas()\r
+            result += " has %d extensions/data attached\n"%len(all_datas)\r
+ for (n,v,c) in all_datas:\r
+ if c=='data':\r
+ result += " data: %s=%s\n"%(n,v)\r
+ else:\r
+ result += " ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)\r
+ return result\r
-#----------------------------------------------------------------------
-# Copyright (c) 2008 Board of Trustees, Princeton University
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and/or hardware specification (the "Work") to
-# deal in the Work without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Work, and to permit persons to whom the Work
-# is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Work.
-#
-# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
-# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS
-# IN THE WORK.
-#----------------------------------------------------------------------
-##
-# Implements SFA Credentials
-#
-# Credentials are signed XML files that assign a subject gid privileges to an object gid
-##
-
-import os
-import datetime
-from tempfile import mkstemp
-import dateutil.parser
-from StringIO import StringIO
-from xml.dom.minidom import Document, parseString
-from lxml import etree
-
-from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
-from sfa.trust.certificate import Keypair
-from sfa.trust.credential_legacy import CredentialLegacy
-from sfa.trust.rights import Right, Rights
-from sfa.trust.gid import GID
-from sfa.util.xrn import urn_to_hrn
-
-# 2 weeks, in seconds
-DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14
-
-
-# TODO:
-# . make privs match between PG and PL
-# . Need to add support for other types of credentials, e.g. tickets
-
-
-signature_template = \
-'''
-<Signature xml:id="Sig_%s" xmlns="http://www.w3.org/2000/09/xmldsig#">
- <SignedInfo>
- <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>
- <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>
- <Reference URI="#%s">
- <Transforms>
- <Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />
- </Transforms>
- <DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>
- <DigestValue></DigestValue>
- </Reference>
- </SignedInfo>
- <SignatureValue />
- <KeyInfo>
- <X509Data>
- <X509SubjectName/>
- <X509IssuerSerial/>
- <X509Certificate/>
- </X509Data>
- <KeyValue />
- </KeyInfo>
- </Signature>
-'''
-
-##
-# Convert a string into a bool
-
-def str2bool(str):
- if str.lower() in ['yes','true','1']:
- return True
- return False
-
-
-##
-# Utility function to get the text of an XML element
-
-def getTextNode(element, subele):
- sub = element.getElementsByTagName(subele)[0]
- if len(sub.childNodes) > 0:
- return sub.childNodes[0].nodeValue
- else:
- return None
-
-##
-# Utility function to set the text of an XML element
-# It creates the element, adds the text to it,
-# and then appends it to the parent.
-
-def append_sub(doc, parent, element, text):
- ele = doc.createElement(element)
- ele.appendChild(doc.createTextNode(text))
- parent.appendChild(ele)
-
-##
-# Signature contains information about an xmlsec1 signature
-# for a signed-credential
-#
-
-class Signature(object):
-
- def __init__(self, string=None):
- self.refid = None
- self.issuer_gid = None
- self.xml = None
- if string:
- self.xml = string
- self.decode()
-
-
- def get_refid(self):
- if not self.refid:
- self.decode()
- return self.refid
-
- def get_xml(self):
- if not self.xml:
- self.encode()
- return self.xml
-
- def set_refid(self, id):
- self.refid = id
-
- def get_issuer_gid(self):
- if not self.gid:
- self.decode()
- return self.gid
-
- def set_issuer_gid(self, gid):
- self.gid = gid
-
- def decode(self):
- doc = parseString(self.xml)
- sig = doc.getElementsByTagName("Signature")[0]
- self.set_refid(sig.getAttribute("xml:id").strip("Sig_"))
- keyinfo = sig.getElementsByTagName("X509Data")[0]
- szgid = getTextNode(keyinfo, "X509Certificate")
- szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid
- self.set_issuer_gid(GID(string=szgid))
-
- def encode(self):
- self.xml = signature_template % (self.get_refid(), self.get_refid())
-
-
-##
-# A credential provides a caller gid with privileges to an object gid.
-# A signed credential is signed by the object's authority.
-#
-# Credentials are encoded in one of two ways. The legacy style places
-# it in the subjectAltName of an X509 certificate. The new credentials
-# are placed in signed XML.
-#
-# WARNING:
-# In general, a signed credential obtained externally should
-# not be changed else the signature is no longer valid. So, once
-# you have loaded an existing signed credential, do not call encode() or sign() on it.
-
-def filter_creds_by_caller(creds, caller_hrn):
- """
- Returns a list of creds who's gid caller matches the
- specified caller hrn
- """
- if not isinstance(creds, list): creds = [creds]
- caller_creds = []
- for cred in creds:
- try:
- tmp_cred = Credential(string=cred)
- if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
- caller_creds.append(cred)
- except: pass
- return caller_creds
-
-class Credential(object):
-
- ##
- # Create a Credential object
- #
- # @param create If true, create a blank x509 certificate
- # @param subject If subject!=None, create an x509 cert with the subject name
- # @param string If string!=None, load the credential from the string
- # @param filename If filename!=None, load the credential from the file
- # FIXME: create and subject are ignored!
- def __init__(self, create=False, subject=None, string=None, filename=None):
- self.gidCaller = None
- self.gidObject = None
- self.expiration = None
- self.privileges = None
- self.issuer_privkey = None
- self.issuer_gid = None
- self.issuer_pubkey = None
- self.parent = None
- self.signature = None
- self.xml = None
- self.refid = None
- self.legacy = None
-
- # Check if this is a legacy credential, translate it if so
- if string or filename:
- if string:
- str = string
- elif filename:
- str = file(filename).read()
- self.filename=filename
-
- if str.strip().startswith("-----"):
- self.legacy = CredentialLegacy(False,string=str)
- self.translate_legacy(str)
- else:
- self.xml = str
- self.decode()
-
- # Find an xmlsec1 path
- self.xmlsec_path = ''
- paths = ['/usr/bin','/usr/local/bin','/bin','/opt/bin','/opt/local/bin']
- for path in paths:
- if os.path.isfile(path + '/' + 'xmlsec1'):
- self.xmlsec_path = path + '/' + 'xmlsec1'
- break
-
- def get_subject(self):
- if not self.gidObject:
- self.decode()
- return self.gidObject.get_subject()
-
- def get_signature(self):
- if not self.signature:
- self.decode()
- return self.signature
-
- def set_signature(self, sig):
- self.signature = sig
-
-
- ##
- # Translate a legacy credential into a new one
- #
- # @param String of the legacy credential
-
- def translate_legacy(self, str):
- legacy = CredentialLegacy(False,string=str)
- self.gidCaller = legacy.get_gid_caller()
- self.gidObject = legacy.get_gid_object()
- lifetime = legacy.get_lifetime()
- if not lifetime:
- self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))
- else:
- self.set_expiration(int(lifetime))
- self.lifeTime = legacy.get_lifetime()
- self.set_privileges(legacy.get_privileges())
- self.get_privileges().delegate_all_privileges(legacy.get_delegate())
-
- ##
- # Need the issuer's private key and name
- # @param key Keypair object containing the private key of the issuer
- # @param gid GID of the issuing authority
-
- def set_issuer_keys(self, privkey, gid):
- self.issuer_privkey = privkey
- self.issuer_gid = gid
-
-
- ##
- # Set this credential's parent
- def set_parent(self, cred):
- self.parent = cred
- self.updateRefID()
-
- ##
- # set the GID of the caller
- #
- # @param gid GID object of the caller
-
- def set_gid_caller(self, gid):
- self.gidCaller = gid
- # gid origin caller is the caller's gid by default
- self.gidOriginCaller = gid
-
- ##
- # get the GID of the object
-
- def get_gid_caller(self):
- if not self.gidCaller:
- self.decode()
- return self.gidCaller
-
- ##
- # set the GID of the object
- #
- # @param gid GID object of the object
-
- def set_gid_object(self, gid):
- self.gidObject = gid
-
- ##
- # get the GID of the object
-
- def get_gid_object(self):
- if not self.gidObject:
- self.decode()
- return self.gidObject
-
-
-
- ##
- # Expiration: an absolute UTC time of expiration (as either an int or datetime)
- #
- def set_expiration(self, expiration):
- if isinstance(expiration, int):
- self.expiration = datetime.datetime.fromtimestamp(expiration)
- else:
- self.expiration = expiration
-
-
- ##
- # get the lifetime of the credential (in datetime format)
-
- def get_expiration(self):
- if not self.expiration:
- self.decode()
- return self.expiration
-
- ##
- # For legacy sake
- def get_lifetime(self):
- return self.get_expiration()
-
- ##
- # set the privileges
- #
- # @param privs either a comma-separated list of privileges of a Rights object
-
- def set_privileges(self, privs):
- if isinstance(privs, str):
- self.privileges = Rights(string = privs)
- else:
- self.privileges = privs
-
-
- ##
- # return the privileges as a Rights object
-
- def get_privileges(self):
- if not self.privileges:
- self.decode()
- return self.privileges
-
- ##
- # determine whether the credential allows a particular operation to be
- # performed
- #
- # @param op_name string specifying name of operation ("lookup", "update", etc)
-
- def can_perform(self, op_name):
- rights = self.get_privileges()
-
- if not rights:
- return False
-
- return rights.can_perform(op_name)
-
-
- ##
- # Encode the attributes of the credential into an XML string
- # This should be done immediately before signing the credential.
- # WARNING:
- # In general, a signed credential obtained externally should
- # not be changed else the signature is no longer valid. So, once
- # you have loaded an existing signed credential, do not call encode() or sign() on it.
-
- def encode(self):
- # Create the XML document
- doc = Document()
- signed_cred = doc.createElement("signed-credential")
- doc.appendChild(signed_cred)
-
- # Fill in the <credential> bit
- cred = doc.createElement("credential")
- cred.setAttribute("xml:id", self.get_refid())
- signed_cred.appendChild(cred)
- append_sub(doc, cred, "type", "privilege")
- append_sub(doc, cred, "serial", "8")
- append_sub(doc, cred, "owner_gid", self.gidCaller.save_to_string())
- append_sub(doc, cred, "owner_urn", self.gidCaller.get_urn())
- append_sub(doc, cred, "target_gid", self.gidObject.save_to_string())
- append_sub(doc, cred, "target_urn", self.gidObject.get_urn())
- append_sub(doc, cred, "uuid", "")
- if not self.expiration:
- self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))
- self.expiration = self.expiration.replace(microsecond=0)
- append_sub(doc, cred, "expires", self.expiration.isoformat())
- privileges = doc.createElement("privileges")
- cred.appendChild(privileges)
-
- if self.privileges:
- rights = self.get_privileges()
- for right in rights.rights:
- priv = doc.createElement("privilege")
- append_sub(doc, priv, "name", right.kind)
- append_sub(doc, priv, "can_delegate", str(right.delegate).lower())
- privileges.appendChild(priv)
-
- # Add the parent credential if it exists
- if self.parent:
- sdoc = parseString(self.parent.get_xml())
- p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True)
- p = doc.createElement("parent")
- p.appendChild(p_cred)
- cred.appendChild(p)
-
-
- # Create the <signatures> tag
- signatures = doc.createElement("signatures")
- signed_cred.appendChild(signatures)
-
- # Add any parent signatures
- if self.parent:
- for cur_cred in self.get_credential_list()[1:]:
- sdoc = parseString(cur_cred.get_signature().get_xml())
- ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)
- signatures.appendChild(ele)
-
- # Get the finished product
- self.xml = doc.toxml()
-
-
- def save_to_random_tmp_file(self):
- fp, filename = mkstemp(suffix='cred', text=True)
- fp = os.fdopen(fp, "w")
- self.save_to_file(filename, save_parents=True, filep=fp)
- return filename
-
- def save_to_file(self, filename, save_parents=True, filep=None):
- if not self.xml:
- self.encode()
- if filep:
- f = filep
- else:
- f = open(filename, "w")
- f.write(self.xml)
- f.close()
- self.filename=filename
-
- def save_to_string(self, save_parents=True):
- if not self.xml:
- self.encode()
- return self.xml
-
- def get_refid(self):
- if not self.refid:
- self.refid = 'ref0'
- return self.refid
-
- def set_refid(self, rid):
- self.refid = rid
-
- ##
- # Figure out what refids exist, and update this credential's id
- # so that it doesn't clobber the others. Returns the refids of
- # the parents.
-
- def updateRefID(self):
- if not self.parent:
- self.set_refid('ref0')
- return []
-
- refs = []
-
- next_cred = self.parent
- while next_cred:
- refs.append(next_cred.get_refid())
- if next_cred.parent:
- next_cred = next_cred.parent
- else:
- next_cred = None
-
-
- # Find a unique refid for this credential
- rid = self.get_refid()
- while rid in refs:
- val = int(rid[3:])
- rid = "ref%d" % (val + 1)
-
- # Set the new refid
- self.set_refid(rid)
-
- # Return the set of parent credential ref ids
- return refs
-
- def get_xml(self):
- if not self.xml:
- self.encode()
- return self.xml
-
- ##
- # Sign the XML file created by encode()
- #
- # WARNING:
- # In general, a signed credential obtained externally should
- # not be changed else the signature is no longer valid. So, once
- # you have loaded an existing signed credential, do not call encode() or sign() on it.
-
- def sign(self):
- if not self.issuer_privkey or not self.issuer_gid:
- return
- doc = parseString(self.get_xml())
- sigs = doc.getElementsByTagName("signatures")[0]
-
- # Create the signature template to be signed
- signature = Signature()
- signature.set_refid(self.get_refid())
- sdoc = parseString(signature.get_xml())
- sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)
- sigs.appendChild(sig_ele)
-
- self.xml = doc.toxml()
-
-
- # Split the issuer GID into multiple certificates if it's a chain
- chain = GID(filename=self.issuer_gid)
- gid_files = []
- while chain:
- gid_files.append(chain.save_to_random_tmp_file(False))
- if chain.get_parent():
- chain = chain.get_parent()
- else:
- chain = None
-
-
- # Call out to xmlsec1 to sign it
- ref = 'Sig_%s' % self.get_refid()
- filename = self.save_to_random_tmp_file()
- signed = os.popen('%s --sign --node-id "%s" --privkey-pem %s,%s %s' \
- % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)).read()
- os.remove(filename)
-
- for gid_file in gid_files:
- os.remove(gid_file)
-
- self.xml = signed
-
- # This is no longer a legacy credential
- if self.legacy:
- self.legacy = None
-
- # Update signatures
- self.decode()
-
-
- ##
- # Retrieve the attributes of the credential from the XML.
- # This is automatically called by the various get_* methods of
- # this class and should not need to be called explicitly.
-
- def decode(self):
- if not self.xml:
- return
- doc = parseString(self.xml)
- sigs = []
- signed_cred = doc.getElementsByTagName("signed-credential")
-
- # Is this a signed-cred or just a cred?
- if len(signed_cred) > 0:
- cred = signed_cred[0].getElementsByTagName("credential")[0]
- signatures = signed_cred[0].getElementsByTagName("signatures")
- if len(signatures) > 0:
- sigs = signatures[0].getElementsByTagName("Signature")
- else:
- cred = doc.getElementsByTagName("credential")[0]
-
-
- self.set_refid(cred.getAttribute("xml:id"))
- self.set_expiration(dateutil.parser.parse(getTextNode(cred, "expires")))
- self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))
- self.gidObject = GID(string=getTextNode(cred, "target_gid"))
-
-
- # Process privileges
- privs = cred.getElementsByTagName("privileges")[0]
- rlist = Rights()
- for priv in privs.getElementsByTagName("privilege"):
- kind = getTextNode(priv, "name")
- deleg = str2bool(getTextNode(priv, "can_delegate"))
- if kind == '*':
- # Convert * into the default privileges for the credential's type
- _ , type = urn_to_hrn(self.gidObject.get_urn())
- rl = rlist.determine_rights(type, self.gidObject.get_urn())
- for r in rl.rights:
- rlist.add(r)
- else:
- rlist.add(Right(kind.strip(), deleg))
- self.set_privileges(rlist)
-
-
- # Is there a parent?
- parent = cred.getElementsByTagName("parent")
- if len(parent) > 0:
- parent_doc = parent[0].getElementsByTagName("credential")[0]
- parent_xml = parent_doc.toxml()
- self.parent = Credential(string=parent_xml)
- self.updateRefID()
-
- # Assign the signatures to the credentials
- for sig in sigs:
- Sig = Signature(string=sig.toxml())
-
- for cur_cred in self.get_credential_list():
- if cur_cred.get_refid() == Sig.get_refid():
- cur_cred.set_signature(Sig)
-
-
- ##
- # Verify
- # trusted_certs: A list of trusted GID filenames (not GID objects!)
- # Chaining is not supported within the GIDs by xmlsec1.
- #
- # Verify that:
- # . All of the signatures are valid and that the issuers trace back
- # to trusted roots (performed by xmlsec1)
- # . The XML matches the credential schema
- # . That the issuer of the credential is the authority in the target's urn
- # . In the case of a delegated credential, this must be true of the root
- # . That all of the gids presented in the credential are valid
- # . The credential is not expired
- #
- # -- For Delegates (credentials with parents)
- # . The privileges must be a subset of the parent credentials
- # . The privileges must have "can_delegate" set for each delegated privilege
- # . The target gid must be the same between child and parents
- # . The expiry time on the child must be no later than the parent
- # . The signer of the child must be the owner of the parent
- #
- # -- Verify does *NOT*
- # . ensure that an xmlrpc client's gid matches a credential gid, that
- # must be done elsewhere
- #
- # @param trusted_certs: The certificates of trusted CA certificates
- # @param schema: The RelaxNG schema to validate the credential against
- def verify(self, trusted_certs, schema=None):
- if not self.xml:
- self.decode()
-
- # validate against RelaxNG schema
- if not self.legacy:
- if schema and os.path.exists(schema):
- tree = etree.parse(StringIO(self.xml))
- schema_doc = etree.parse(schema)
- xmlschema = etree.XMLSchema(schema_doc)
- if not xmlschema.validate(tree):
- error = xmlschema.error_log.last_error
- message = "%s (line %s)" % (error.message, error.line)
- raise CredentialNotVerifiable(message)
-
-
-# trusted_cert_objects = [GID(filename=f) for f in trusted_certs]
- trusted_cert_objects = []
- ok_trusted_certs = []
- for f in trusted_certs:
- try:
- # Failures here include unreadable files
- # or non PEM files
- trusted_cert_objects.append(GID(filename=f))
- ok_trusted_certs.append(f)
- except Exception, exc:
- sfa_logger().error("Failed to load trusted cert from %s: %r"%( f, exc))
- trusted_certs = ok_trusted_certs
-
- # Use legacy verification if this is a legacy credential
- if self.legacy:
- self.legacy.verify_chain(trusted_cert_objects)
- if self.legacy.client_gid:
- self.legacy.client_gid.verify_chain(trusted_cert_objects)
- if self.legacy.object_gid:
- self.legacy.object_gid.verify_chain(trusted_cert_objects)
- return True
-
-
- # make sure it is not expired
- if self.get_expiration() < datetime.datetime.utcnow():
- raise CredentialNotVerifiable("Credential expired at %s" % self.expiration.isoformat())
-
- # Verify the signatures
- filename = self.save_to_random_tmp_file()
- cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs])
-
- # Verify the gids of this cred and of its parents
- for cur_cred in self.get_credential_list():
- cur_cred.get_gid_object().verify_chain(trusted_cert_objects)
- cur_cred.get_gid_caller().verify_chain(trusted_cert_objects)
-
- refs = []
- refs.append("Sig_%s" % self.get_refid())
-
- parentRefs = self.updateRefID()
- for ref in parentRefs:
- refs.append("Sig_%s" % ref)
-
- for ref in refs:
- verified = os.popen('%s --verify --node-id "%s" %s %s 2>&1' \
- % (self.xmlsec_path, ref, cert_args, filename)).read()
- if not verified.strip().startswith("OK"):
- raise CredentialNotVerifiable("xmlsec1 error verifying cert: " + verified)
- os.remove(filename)
-
- # Verify the parents (delegation)
- if self.parent:
- self.verify_parent(self.parent)
-
- # Make sure the issuer is the target's authority
- self.verify_issuer()
- return True
-
- ##
- # Creates a list of the credential and its parents, with the root
- # (original delegated credential) as the last item in the list
- def get_credential_list(self):
- cur_cred = self
- list = []
- while cur_cred:
- list.append(cur_cred)
- if cur_cred.parent:
- cur_cred = cur_cred.parent
- else:
- cur_cred = None
- return list
-
- ##
- # Make sure the credential's target gid was signed by (or is the same) the entity that signed
- # the original credential or an authority over that namespace.
- def verify_issuer(self):
- root_cred = self.get_credential_list()[-1]
- root_target_gid = root_cred.get_gid_object()
- root_cred_signer = root_cred.get_signature().get_issuer_gid()
-
- if root_target_gid.is_signed_by_cert(root_cred_signer):
- # cred signer matches target signer, return success
- return
-
- root_target_gid_str = root_target_gid.save_to_string()
- root_cred_signer_str = root_cred_signer.save_to_string()
- if root_target_gid_str == root_cred_signer_str:
- # cred signer is target, return success
- return
-
- # See if it the signer is an authority over the domain of the target
- # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn())
- root_cred_signer_type = root_cred_signer.get_type()
- if (root_cred_signer_type == 'authority'):
- #sfa_logger().debug('Cred signer is an authority')
- # signer is an authority, see if target is in authority's domain
- hrn = root_cred_signer.get_hrn()
- if root_target_gid.get_hrn().startswith(hrn):
- return
-
- # We've required that the credential be signed by an authority
- # for that domain. Reasonable and probably correct.
- # A looser model would also allow the signer to be an authority
- # in my control framework - eg My CA or CH. Even if it is not
- # the CH that issued these, eg, user credentials.
-
- # Give up, credential does not pass issuer verification
-
- raise CredentialNotVerifiable("Could not verify credential owned by %s for object %s. Cred signer %s not the trusted authority for Cred target %s" % (self.gidCaller.get_urn(), self.gidObject.get_urn(), root_cred_signer.get_hrn(), root_target_gid.get_hrn()))
-
-
- ##
- # -- For Delegates (credentials with parents) verify that:
- # . The privileges must be a subset of the parent credentials
- # . The privileges must have "can_delegate" set for each delegated privilege
- # . The target gid must be the same between child and parents
- # . The expiry time on the child must be no later than the parent
- # . The signer of the child must be the owner of the parent
- def verify_parent(self, parent_cred):
- # make sure the rights given to the child are a subset of the
- # parents rights (and check delegate bits)
- if not parent_cred.get_privileges().is_superset(self.get_privileges()):
- raise ChildRightsNotSubsetOfParent(
- self.parent.get_privileges().save_to_string() + " " +
- self.get_privileges().save_to_string())
-
- # make sure my target gid is the same as the parent's
- if not parent_cred.get_gid_object().save_to_string() == \
- self.get_gid_object().save_to_string():
- raise CredentialNotVerifiable("Target gid not equal between parent and child")
-
- # make sure my expiry time is <= my parent's
- if not parent_cred.get_expiration() >= self.get_expiration():
- raise CredentialNotVerifiable("Delegated credential expires after parent")
-
- # make sure my signer is the parent's caller
- if not parent_cred.get_gid_caller().save_to_string(False) == \
- self.get_signature().get_issuer_gid().save_to_string(False):
- raise CredentialNotVerifiable("Delegated credential not signed by parent caller")
-
- # Recurse
- if parent_cred.parent:
- parent_cred.verify_parent(parent_cred.parent)
-
-
- def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile):
- """
- Return a delegated copy of this credential, delegated to the
- specified gid's user.
- """
- # get the gid of the object we are delegating
- object_gid = self.get_gid_object()
- object_hrn = object_gid.get_hrn()
-
- # the hrn of the user who will be delegated to
- delegee_gid = GID(filename=delegee_gidfile)
- delegee_hrn = delegee_gid.get_hrn()
-
- #user_key = Keypair(filename=keyfile)
- #user_hrn = self.get_gid_caller().get_hrn()
- subject_string = "%s delegated to %s" % (object_hrn, delegee_hrn)
- dcred = Credential(subject=subject_string)
- dcred.set_gid_caller(delegee_gid)
- dcred.set_gid_object(object_gid)
- dcred.set_parent(self)
- dcred.set_expiration(self.get_expiration())
- dcred.set_privileges(self.get_privileges())
- dcred.get_privileges().delegate_all_privileges(True)
- #dcred.set_issuer_keys(keyfile, delegee_gidfile)
- dcred.set_issuer_keys(caller_keyfile, caller_gidfile)
- dcred.encode()
- dcred.sign()
-
- return dcred
-
- # only informative
- def get_filename(self):
- return getattr(self,'filename',None)
-
- # @param dump_parents If true, also dump the parent certificates
- def dump (self, *args, **kwargs):
- print self.dump_string(*args, **kwargs)
-
- def dump_string(self, dump_parents=False):
- result=""
- result += "CREDENTIAL %s\n" % self.get_subject()
- filename=self.get_filename()
- if filename: result += "Filename %s\n"%filename
- result += " privs: %s\n" % self.get_privileges().save_to_string()
- gidCaller = self.get_gid_caller()
- if gidCaller:
- result += " gidCaller:\n"
- result += gidCaller.dump_string(8, dump_parents)
-
- gidObject = self.get_gid_object()
- if gidObject:
- result += " gidObject:\n"
- result += gidObject.dump_string(8, dump_parents)
-
- if self.parent and dump_parents:
- result += "PARENT"
- result += self.parent.dump_string(dump_parents)
- return result
-
+#----------------------------------------------------------------------\r
+# Copyright (c) 2008 Board of Trustees, Princeton University\r
+#\r
+# Permission is hereby granted, free of charge, to any person obtaining\r
+# a copy of this software and/or hardware specification (the "Work") to\r
+# deal in the Work without restriction, including without limitation the\r
+# rights to use, copy, modify, merge, publish, distribute, sublicense,\r
+# and/or sell copies of the Work, and to permit persons to whom the Work\r
+# is furnished to do so, subject to the following conditions:\r
+#\r
+# The above copyright notice and this permission notice shall be\r
+# included in all copies or substantial portions of the Work.\r
+#\r
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS \r
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF \r
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND \r
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT \r
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, \r
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, \r
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS \r
+# IN THE WORK.\r
+#----------------------------------------------------------------------\r
+##\r
+# Implements SFA Credentials\r
+#\r
+# Credentials are signed XML documents that grant a subject gid privileges on an object gid\r
+##\r
+\r
+import os\r
+from types import StringTypes\r
+import datetime\r
+from StringIO import StringIO\r
+from tempfile import mkstemp\r
+from xml.dom.minidom import Document, parseString\r
+from xml.parsers.expat import ExpatError\r
+\r
+HAVELXML = False\r
+try:\r
+ from lxml import etree\r
+ HAVELXML = True\r
+except:\r
+ pass\r
+\r
+from sfa.util.faults import *\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.sfatime import utcparse\r
+from sfa.trust.certificate import Keypair\r
+from sfa.trust.credential_legacy import CredentialLegacy\r
+from sfa.trust.rights import Right, Rights, determine_rights\r
+from sfa.trust.gid import GID\r
+from sfa.util.xrn import urn_to_hrn, hrn_authfor_hrn\r
+\r
+# 2 weeks, in seconds \r
+DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14\r
+\r
+\r
+# TODO:\r
+# . make privs match between PG and PL\r
+# . Need to add support for other types of credentials, e.g. tickets\r
+# . add namespaces to signed-credential element?\r
+\r
+signature_template = \\r
+'''\r
+<Signature xml:id="Sig_%s" xmlns="http://www.w3.org/2000/09/xmldsig#">\r
+ <SignedInfo>\r
+ <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>\r
+ <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>\r
+ <Reference URI="#%s">\r
+ <Transforms>\r
+ <Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />\r
+ </Transforms>\r
+ <DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>\r
+ <DigestValue></DigestValue>\r
+ </Reference>\r
+ </SignedInfo>\r
+ <SignatureValue />\r
+ <KeyInfo>\r
+ <X509Data>\r
+ <X509SubjectName/>\r
+ <X509IssuerSerial/>\r
+ <X509Certificate/>\r
+ </X509Data>\r
+ <KeyValue />\r
+ </KeyInfo>\r
+</Signature>\r
+'''\r
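The template above is a plain `%`-format string with two slots: the `xml:id` suffix on the `Signature` element and the `Reference` URI, both filled with the credential's refid. A minimal sketch of how `Signature.encode()` uses it (the abbreviated template here is a stand-in for the full one above):

```python
# Abbreviated stand-in for the signature_template defined above;
# both %s slots receive the same refid.
signature_template = (
    '<Signature xml:id="Sig_%s" xmlns="http://www.w3.org/2000/09/xmldsig#">'
    '<Reference URI="#%s"/></Signature>'
)

refid = "ref0"  # the credential's xml:id reference
xml = signature_template % (refid, refid)
```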
+\r
+# PG formats the template (whitespace) slightly differently.\r
+# Note that they don't include the xmlns in the template, but add it later.\r
+# Otherwise the two are equivalent.\r
+#signature_template_as_in_pg = \\r
+#'''\r
+#<Signature xml:id="Sig_%s" >\r
+# <SignedInfo>\r
+# <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>\r
+# <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>\r
+# <Reference URI="#%s">\r
+# <Transforms>\r
+# <Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />\r
+# </Transforms>\r
+# <DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>\r
+# <DigestValue></DigestValue>\r
+# </Reference>\r
+# </SignedInfo>\r
+# <SignatureValue />\r
+# <KeyInfo>\r
+# <X509Data >\r
+# <X509SubjectName/>\r
+# <X509IssuerSerial/>\r
+# <X509Certificate/>\r
+# </X509Data>\r
+# <KeyValue />\r
+# </KeyInfo>\r
+#</Signature>\r
+#'''\r
+\r
+##\r
+# Convert a string into a bool\r
+# used to convert an xsd:boolean to a Python boolean\r
+def str2bool(str):\r
+ if str.lower() in ['true','1']:\r
+ return True\r
+ return False\r
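`str2bool` accepts the two `xsd:boolean` "true" spellings that appear in these credentials (`true` in any case, and `1`) and treats everything else as `False`:

```python
def str2bool(s):
    # Same logic as the helper above: only 'true' (any case) and '1' are True.
    return s.lower() in ['true', '1']

# values as they might appear in <can_delegate> elements
values = ["true", "True", "1", "false", "0", "yes"]
parsed = [str2bool(v) for v in values]
```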
+\r
+\r
+##\r
+# Utility function to get the text of an XML element\r
+\r
+def getTextNode(element, subele):\r
+ sub = element.getElementsByTagName(subele)[0]\r
+ if len(sub.childNodes) > 0: \r
+ return sub.childNodes[0].nodeValue\r
+ else:\r
+ return None\r
+ \r
+##\r
+# Utility function to set the text of an XML element\r
+# It creates the element, adds the text to it,\r
+# and then appends it to the parent.\r
+\r
+def append_sub(doc, parent, element, text):\r
+ ele = doc.createElement(element)\r
+ ele.appendChild(doc.createTextNode(text))\r
+ parent.appendChild(ele)\r
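The two helpers are inverses of each other; a round trip through them, using the same minidom calls as above:

```python
from xml.dom.minidom import Document

def append_sub(doc, parent, element, text):
    # create <element>text</element> and attach it under parent
    ele = doc.createElement(element)
    ele.appendChild(doc.createTextNode(text))
    parent.appendChild(ele)

def getTextNode(element, subele):
    # return the text of the first <subele> child, or None if it is empty
    sub = element.getElementsByTagName(subele)[0]
    if len(sub.childNodes) > 0:
        return sub.childNodes[0].nodeValue
    return None

doc = Document()
cred = doc.createElement("credential")
doc.appendChild(cred)
append_sub(doc, cred, "expires", "2011-09-15T12:00:00")
value = getTextNode(cred, "expires")
```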
+\r
+##\r
+# Signature contains information about an xmlsec1 signature\r
+# for a signed-credential\r
+#\r
+\r
+class Signature(object):\r
+ \r
+ def __init__(self, string=None):\r
+ self.refid = None\r
+        self.gid = None\r
+ self.xml = None\r
+ if string:\r
+ self.xml = string\r
+ self.decode()\r
+\r
+\r
+ def get_refid(self):\r
+ if not self.refid:\r
+ self.decode()\r
+ return self.refid\r
+\r
+ def get_xml(self):\r
+ if not self.xml:\r
+ self.encode()\r
+ return self.xml\r
+\r
+ def set_refid(self, id):\r
+ self.refid = id\r
+\r
+ def get_issuer_gid(self):\r
+ if not self.gid:\r
+ self.decode()\r
+ return self.gid \r
+\r
+ def set_issuer_gid(self, gid):\r
+ self.gid = gid\r
+\r
+ def decode(self):\r
+ try:\r
+ doc = parseString(self.xml)\r
+ except ExpatError,e:\r
+ logger.log_exc ("Failed to parse credential, %s"%self.xml)\r
+ raise\r
+ sig = doc.getElementsByTagName("Signature")[0]\r
+ self.set_refid(sig.getAttribute("xml:id").strip("Sig_"))\r
+ keyinfo = sig.getElementsByTagName("X509Data")[0]\r
+ szgid = getTextNode(keyinfo, "X509Certificate")\r
+ szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid\r
+ self.set_issuer_gid(GID(string=szgid)) \r
+ \r
+ def encode(self):\r
+ self.xml = signature_template % (self.get_refid(), self.get_refid())\r
+\r
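One caveat worth flagging in `decode()` above: `str.strip("Sig_")` strips any run of the *characters* `S`, `i`, `g`, `_` from both ends, not the literal prefix, so a refid that begins or ends with those characters gets mangled. A sketch of the difference (the `strip_prefix` helper is illustrative, not part of this module):

```python
def strip_prefix(s, prefix):
    # literal prefix removal, unlike str.strip() which removes a character set
    if s.startswith(prefix):
        return s[len(prefix):]
    return s

ok = "Sig_ref0".strip("Sig_")          # happens to work: 'ref0'
bad = "Sig_gi0".strip("Sig_")          # also strips the leading 'gi': '0'
safe = strip_prefix("Sig_gi0", "Sig_")  # 'gi0'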
+\r
+##\r
+# A credential provides a caller gid with privileges to an object gid.\r
+# A signed credential is signed by the object's authority.\r
+#\r
+# Credentials are encoded in one of two ways.  The legacy style places\r
+# the credential in the subjectAltName of an X509 certificate.  The new\r
+# credentials are placed in signed XML.\r
+#\r
+# WARNING:\r
+# In general, a signed credential obtained externally should\r
+# not be changed else the signature is no longer valid. So, once\r
+# you have loaded an existing signed credential, do not call encode() or sign() on it.\r
+\r
+def filter_creds_by_caller(creds, caller_hrn_list):\r
+ """\r
+    Returns the subset of creds whose gid caller matches one of\r
+    the specified caller hrns\r
+ """\r
+ if not isinstance(creds, list): creds = [creds]\r
+ if not isinstance(caller_hrn_list, list): \r
+ caller_hrn_list = [caller_hrn_list]\r
+ caller_creds = []\r
+ for cred in creds:\r
+ try:\r
+ tmp_cred = Credential(string=cred)\r
+ if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list:\r
+ caller_creds.append(cred)\r
+ except: pass\r
+ return caller_creds\r
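`filter_creds_by_caller()` reduces to a membership test on each credential's caller hrn, with unparseable credentials silently skipped. The same shape, with the `Credential(string=...).get_gid_caller().get_hrn()` chain stubbed out by a caller-supplied accessor for illustration:

```python
def filter_by_caller(creds, caller_hrn_list, get_caller_hrn):
    # get_caller_hrn stands in for parsing the cred and reading its caller hrn
    if not isinstance(creds, list):
        creds = [creds]
    if not isinstance(caller_hrn_list, list):
        caller_hrn_list = [caller_hrn_list]
    kept = []
    for cred in creds:
        try:
            if get_caller_hrn(cred) in caller_hrn_list:
                kept.append(cred)
        except Exception:
            pass  # creds that fail to parse are skipped, as above
    return kept

creds = [{"caller": "plc.princeton.alice"}, {"caller": "plc.princeton.bob"}]
kept = filter_by_caller(creds, "plc.princeton.alice", lambda c: c["caller"])
```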
+\r
+class Credential(object):\r
+\r
+ ##\r
+ # Create a Credential object\r
+ #\r
+ # @param create If true, create a blank x509 certificate\r
+ # @param subject If subject!=None, create an x509 cert with the subject name\r
+ # @param string If string!=None, load the credential from the string\r
+ # @param filename If filename!=None, load the credential from the file\r
+ # FIXME: create and subject are ignored!\r
+ def __init__(self, create=False, subject=None, string=None, filename=None):\r
+ self.gidCaller = None\r
+ self.gidObject = None\r
+ self.expiration = None\r
+ self.privileges = None\r
+ self.issuer_privkey = None\r
+ self.issuer_gid = None\r
+ self.issuer_pubkey = None\r
+ self.parent = None\r
+ self.signature = None\r
+ self.xml = None\r
+ self.refid = None\r
+ self.legacy = None\r
+\r
+ # Check if this is a legacy credential, translate it if so\r
+ if string or filename:\r
+ if string: \r
+ str = string\r
+ elif filename:\r
+ str = file(filename).read()\r
+ \r
+ if str.strip().startswith("-----"):\r
+ self.legacy = CredentialLegacy(False,string=str)\r
+ self.translate_legacy(str)\r
+ else:\r
+ self.xml = str\r
+ self.decode()\r
+\r
+ # Find an xmlsec1 path\r
+ self.xmlsec_path = ''\r
+ paths = ['/usr/bin','/usr/local/bin','/bin','/opt/bin','/opt/local/bin']\r
+ for path in paths:\r
+ if os.path.isfile(path + '/' + 'xmlsec1'):\r
+ self.xmlsec_path = path + '/' + 'xmlsec1'\r
+ break\r
+\r
+ def get_subject(self):\r
+ if not self.gidObject:\r
+ self.decode()\r
+ return self.gidObject.get_printable_subject()\r
+\r
+ def get_summary_tostring(self):\r
+ if not self.gidObject:\r
+ self.decode()\r
+ obj = self.gidObject.get_printable_subject()\r
+ caller = self.gidCaller.get_printable_subject()\r
+ exp = self.get_expiration()\r
+ # Summarize the rights too? The issuer?\r
+ return "[ Grant %s rights on %s until %s ]" % (caller, obj, exp)\r
+\r
+ def get_signature(self):\r
+ if not self.signature:\r
+ self.decode()\r
+ return self.signature\r
+\r
+ def set_signature(self, sig):\r
+ self.signature = sig\r
+\r
+ \r
+ ##\r
+ # Translate a legacy credential into a new one\r
+ #\r
+ # @param String of the legacy credential\r
+\r
+ def translate_legacy(self, str):\r
+ legacy = CredentialLegacy(False,string=str)\r
+ self.gidCaller = legacy.get_gid_caller()\r
+ self.gidObject = legacy.get_gid_object()\r
+ lifetime = legacy.get_lifetime()\r
+ if not lifetime:\r
+ self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))\r
+ else:\r
+ self.set_expiration(int(lifetime))\r
+ self.lifeTime = legacy.get_lifetime()\r
+ self.set_privileges(legacy.get_privileges())\r
+ self.get_privileges().delegate_all_privileges(legacy.get_delegate())\r
+\r
+ ##\r
+ # Need the issuer's private key and name\r
+ # @param key Keypair object containing the private key of the issuer\r
+ # @param gid GID of the issuing authority\r
+\r
+ def set_issuer_keys(self, privkey, gid):\r
+ self.issuer_privkey = privkey\r
+ self.issuer_gid = gid\r
+\r
+\r
+ ##\r
+ # Set this credential's parent\r
+ def set_parent(self, cred):\r
+ self.parent = cred\r
+ self.updateRefID()\r
+\r
+ ##\r
+ # set the GID of the caller\r
+ #\r
+ # @param gid GID object of the caller\r
+\r
+ def set_gid_caller(self, gid):\r
+ self.gidCaller = gid\r
+ # gid origin caller is the caller's gid by default\r
+ self.gidOriginCaller = gid\r
+\r
+ ##\r
+ # get the GID of the object\r
+\r
+ def get_gid_caller(self):\r
+ if not self.gidCaller:\r
+ self.decode()\r
+ return self.gidCaller\r
+\r
+ ##\r
+ # set the GID of the object\r
+ #\r
+ # @param gid GID object of the object\r
+\r
+ def set_gid_object(self, gid):\r
+ self.gidObject = gid\r
+\r
+ ##\r
+ # get the GID of the object\r
+\r
+ def get_gid_object(self):\r
+ if not self.gidObject:\r
+ self.decode()\r
+ return self.gidObject\r
+\r
+\r
+ \r
+ ##\r
+ # Expiration: an absolute UTC time of expiration (as either an int or string or datetime)\r
+ # \r
+ def set_expiration(self, expiration):\r
+ if isinstance(expiration, (int, float)):\r
+ self.expiration = datetime.datetime.fromtimestamp(expiration)\r
+ elif isinstance (expiration, datetime.datetime):\r
+ self.expiration = expiration\r
+ elif isinstance (expiration, StringTypes):\r
+ self.expiration = utcparse (expiration)\r
+ else:\r
+ logger.error ("unexpected input type in Credential.set_expiration")\r
+\r
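`set_expiration()` normalizes three input types to a single `datetime`. A self-contained version of that dispatch; note it uses `utcfromtimestamp` and a fixed-format parse in place of the project's local-time `fromtimestamp` and `utcparse`, so this is a sketch of the shape rather than the exact behavior:

```python
import datetime

def normalize_expiration(expiration):
    # int/float: seconds since the epoch (UTC here, to keep the sketch
    # timezone-independent; the original uses local fromtimestamp)
    if isinstance(expiration, (int, float)):
        return datetime.datetime.utcfromtimestamp(expiration)
    if isinstance(expiration, datetime.datetime):
        return expiration
    if isinstance(expiration, str):
        # stand-in for sfa.util.sfatime.utcparse
        return datetime.datetime.strptime(expiration, "%Y-%m-%dT%H:%M:%S")
    raise TypeError("unexpected input type in set_expiration")

d = normalize_expiration("2011-09-15T12:00:00")
t = normalize_expiration(0)
```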
+\r
+ ##\r
+    # get the expiration time of the credential (always as a datetime)\r
+\r
+ def get_expiration(self):\r
+ if not self.expiration:\r
+ self.decode()\r
+ # at this point self.expiration is normalized as a datetime - DON'T call utcparse again\r
+ return self.expiration\r
+\r
+ ##\r
+ # For legacy sake\r
+ def get_lifetime(self):\r
+ return self.get_expiration()\r
+ \r
+ ##\r
+ # set the privileges\r
+ #\r
+    # @param privs either a comma-separated list of privileges or a Rights object\r
+\r
+ def set_privileges(self, privs):\r
+ if isinstance(privs, str):\r
+ self.privileges = Rights(string = privs)\r
+ else:\r
+ self.privileges = privs\r
+ \r
+\r
+ ##\r
+ # return the privileges as a Rights object\r
+\r
+ def get_privileges(self):\r
+ if not self.privileges:\r
+ self.decode()\r
+ return self.privileges\r
+\r
+ ##\r
+ # determine whether the credential allows a particular operation to be\r
+ # performed\r
+ #\r
+ # @param op_name string specifying name of operation ("lookup", "update", etc)\r
+\r
+ def can_perform(self, op_name):\r
+ rights = self.get_privileges()\r
+ \r
+ if not rights:\r
+ return False\r
+\r
+ return rights.can_perform(op_name)\r
+\r
+\r
+ ##\r
+ # Encode the attributes of the credential into an XML string \r
+ # This should be done immediately before signing the credential. \r
+ # WARNING:\r
+ # In general, a signed credential obtained externally should\r
+ # not be changed else the signature is no longer valid. So, once\r
+ # you have loaded an existing signed credential, do not call encode() or sign() on it.\r
+\r
+ def encode(self):\r
+ # Create the XML document\r
+ doc = Document()\r
+ signed_cred = doc.createElement("signed-credential")\r
+\r
+# Declare namespaces\r
+# Note that credential/policy.xsd are really the PG schemas\r
+# in a PL namespace.\r
+# Note that delegation of credentials between the 2 only really works\r
+# cause those schemas are identical.\r
+# Also note these PG schemas talk about PG tickets and CM policies.\r
+ signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")\r
+ signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.planet-lab.org/resources/sfa/credential.xsd")\r
+ signed_cred.setAttribute("xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd")\r
+\r
+# PG says for those last 2:\r
+# signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")\r
+# signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")\r
+\r
+ doc.appendChild(signed_cred) \r
+ \r
+ # Fill in the <credential> bit \r
+ cred = doc.createElement("credential")\r
+ cred.setAttribute("xml:id", self.get_refid())\r
+ signed_cred.appendChild(cred)\r
+ append_sub(doc, cred, "type", "privilege")\r
+ append_sub(doc, cred, "serial", "8")\r
+ append_sub(doc, cred, "owner_gid", self.gidCaller.save_to_string())\r
+ append_sub(doc, cred, "owner_urn", self.gidCaller.get_urn())\r
+ append_sub(doc, cred, "target_gid", self.gidObject.save_to_string())\r
+ append_sub(doc, cred, "target_urn", self.gidObject.get_urn())\r
+ append_sub(doc, cred, "uuid", "")\r
+ if not self.expiration:\r
+ self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))\r
+ self.expiration = self.expiration.replace(microsecond=0)\r
+ append_sub(doc, cred, "expires", self.expiration.isoformat())\r
+ privileges = doc.createElement("privileges")\r
+ cred.appendChild(privileges)\r
+\r
+ if self.privileges:\r
+ rights = self.get_privileges()\r
+ for right in rights.rights:\r
+ priv = doc.createElement("privilege")\r
+ append_sub(doc, priv, "name", right.kind)\r
+ append_sub(doc, priv, "can_delegate", str(right.delegate).lower())\r
+ privileges.appendChild(priv)\r
+\r
+ # Add the parent credential if it exists\r
+ if self.parent:\r
+ sdoc = parseString(self.parent.get_xml())\r
+ # If the root node is a signed-credential (it should be), then\r
+ # get all its attributes and attach those to our signed_cred\r
+ # node.\r
+            # Specifically, PG and PL add attributes for namespaces (which is reasonable),\r
+ # and we need to include those again here or else their signature\r
+ # no longer matches on the credential.\r
+ # We expect three of these, but here we copy them all:\r
+# signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")\r
+# and from PG (PL is equivalent, as shown above):\r
+# signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")\r
+# signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")\r
+\r
+ # HOWEVER!\r
+ # PL now also declares these, with different URLs, so\r
+ # the code notices those attributes already existed with\r
+ # different values, and complains.\r
+ # This happens regularly on delegation now that PG and\r
+ # PL both declare the namespace with different URLs.\r
+ # If the content ever differs this is a problem,\r
+ # but for now it works - different URLs (values in the attributes)\r
+ # but the same actual schema, so using the PG schema\r
+ # on delegated-to-PL credentials works fine.\r
+\r
+ # Note: you could also not copy attributes\r
+ # which already exist. It appears that both PG and PL\r
+ # will actually validate a slicecred with a parent\r
+ # signed using PG namespaces and a child signed with PL\r
+ # namespaces over the whole thing. But I don't know\r
+ # if that is a bug in xmlsec1, an accident since\r
+ # the contents of the schemas are the same,\r
+ # or something else, but it seems odd. And this works.\r
+ parentRoot = sdoc.documentElement\r
+ if parentRoot.tagName == "signed-credential" and parentRoot.hasAttributes():\r
+ for attrIx in range(0, parentRoot.attributes.length):\r
+ attr = parentRoot.attributes.item(attrIx)\r
+ # returns the old attribute of same name that was\r
+ # on the credential\r
+ # Below throws InUse exception if we forgot to clone the attribute first\r
+ oldAttr = signed_cred.setAttributeNode(attr.cloneNode(True))\r
+ if oldAttr and oldAttr.value != attr.value:\r
+ msg = "Delegating cred from owner %s to %s over %s replaced attribute %s value '%s' with '%s'" % (self.parent.gidCaller.get_urn(), self.gidCaller.get_urn(), self.gidObject.get_urn(), oldAttr.name, oldAttr.value, attr.value)\r
+ logger.warn(msg)\r
+ #raise CredentialNotVerifiable("Can't encode new valid delegated credential: %s" % msg)\r
+\r
+ p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True)\r
+ p = doc.createElement("parent")\r
+ p.appendChild(p_cred)\r
+ cred.appendChild(p)\r
+ # done handling parent credential\r
+\r
+ # Create the <signatures> tag\r
+ signatures = doc.createElement("signatures")\r
+ signed_cred.appendChild(signatures)\r
+\r
+ # Add any parent signatures\r
+ if self.parent:\r
+ for cur_cred in self.get_credential_list()[1:]:\r
+ sdoc = parseString(cur_cred.get_signature().get_xml())\r
+ ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)\r
+ signatures.appendChild(ele)\r
+ \r
+ # Get the finished product\r
+ self.xml = doc.toxml()\r
+\r
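The `replace(microsecond=0)` in `encode()` matters because `isoformat()` would otherwise emit fractional seconds in the `<expires>` element, presumably an interoperability concern for consumers that expect whole-second timestamps. The effect:

```python
import datetime

raw = datetime.datetime(2011, 9, 15, 12, 0, 0, 123456)
trimmed = raw.replace(microsecond=0)

with_frac = raw.isoformat()    # includes '.123456'
without = trimmed.isoformat()  # whole seconds only
```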
+\r
+ def save_to_random_tmp_file(self): \r
+ fp, filename = mkstemp(suffix='cred', text=True)\r
+ fp = os.fdopen(fp, "w")\r
+ self.save_to_file(filename, save_parents=True, filep=fp)\r
+ return filename\r
+ \r
+ def save_to_file(self, filename, save_parents=True, filep=None):\r
+ if not self.xml:\r
+ self.encode()\r
+ if filep:\r
+ f = filep \r
+ else:\r
+ f = open(filename, "w")\r
+ f.write(self.xml)\r
+ f.close()\r
+\r
+ def save_to_string(self, save_parents=True):\r
+ if not self.xml:\r
+ self.encode()\r
+ return self.xml\r
+\r
+ def get_refid(self):\r
+ if not self.refid:\r
+ self.refid = 'ref0'\r
+ return self.refid\r
+\r
+ def set_refid(self, rid):\r
+ self.refid = rid\r
+\r
+ ##\r
+ # Figure out what refids exist, and update this credential's id\r
+ # so that it doesn't clobber the others. Returns the refids of\r
+ # the parents.\r
+ \r
+ def updateRefID(self):\r
+ if not self.parent:\r
+ self.set_refid('ref0')\r
+ return []\r
+ \r
+ refs = []\r
+\r
+ next_cred = self.parent\r
+ while next_cred:\r
+ refs.append(next_cred.get_refid())\r
+ if next_cred.parent:\r
+ next_cred = next_cred.parent\r
+ else:\r
+ next_cred = None\r
+\r
+ \r
+ # Find a unique refid for this credential\r
+ rid = self.get_refid()\r
+ while rid in refs:\r
+ val = int(rid[3:])\r
+ rid = "ref%d" % (val + 1)\r
+\r
+ # Set the new refid\r
+ self.set_refid(rid)\r
+\r
+ # Return the set of parent credential ref ids\r
+ return refs\r
+\r
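The uniquing loop in `updateRefID()` bumps the numeric suffix of the current refid until it no longer collides with any parent's. The same loop, standalone:

```python
def unique_refid(rid, parent_refs):
    # same loop as updateRefID: increment the 'refN' suffix until unique
    while rid in parent_refs:
        val = int(rid[3:])
        rid = "ref%d" % (val + 1)
    return rid

# a twice-delegated credential whose parents hold ref0 and ref1
rid = unique_refid("ref0", ["ref0", "ref1"])
```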
+ def get_xml(self):\r
+ if not self.xml:\r
+ self.encode()\r
+ return self.xml\r
+\r
+ ##\r
+ # Sign the XML file created by encode()\r
+ #\r
+ # WARNING:\r
+ # In general, a signed credential obtained externally should\r
+ # not be changed else the signature is no longer valid. So, once\r
+ # you have loaded an existing signed credential, do not call encode() or sign() on it.\r
+\r
+ def sign(self):\r
+ if not self.issuer_privkey or not self.issuer_gid:\r
+ return\r
+ doc = parseString(self.get_xml())\r
+ sigs = doc.getElementsByTagName("signatures")[0]\r
+\r
+ # Create the signature template to be signed\r
+ signature = Signature()\r
+ signature.set_refid(self.get_refid())\r
+ sdoc = parseString(signature.get_xml()) \r
+ sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)\r
+ sigs.appendChild(sig_ele)\r
+\r
+ self.xml = doc.toxml()\r
+\r
+\r
+ # Split the issuer GID into multiple certificates if it's a chain\r
+ chain = GID(filename=self.issuer_gid)\r
+ gid_files = []\r
+ while chain:\r
+ gid_files.append(chain.save_to_random_tmp_file(False))\r
+ if chain.get_parent():\r
+ chain = chain.get_parent()\r
+ else:\r
+ chain = None\r
+\r
+\r
+ # Call out to xmlsec1 to sign it\r
+ ref = 'Sig_%s' % self.get_refid()\r
+ filename = self.save_to_random_tmp_file()\r
+ signed = os.popen('%s --sign --node-id "%s" --privkey-pem %s,%s %s' \\r
+ % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)).read()\r
+ os.remove(filename)\r
+\r
+ for gid_file in gid_files:\r
+ os.remove(gid_file)\r
+\r
+ self.xml = signed\r
+\r
+ # This is no longer a legacy credential\r
+ if self.legacy:\r
+ self.legacy = None\r
+\r
+ # Update signatures\r
+ self.decode() \r
+\r
+ \r
+ ##\r
+ # Retrieve the attributes of the credential from the XML.\r
+ # This is automatically called by the various get_* methods of\r
+ # this class and should not need to be called explicitly.\r
+\r
+ def decode(self):\r
+ if not self.xml:\r
+ return\r
+ doc = parseString(self.xml)\r
+ sigs = []\r
+ signed_cred = doc.getElementsByTagName("signed-credential")\r
+\r
+ # Is this a signed-cred or just a cred?\r
+ if len(signed_cred) > 0:\r
+ creds = signed_cred[0].getElementsByTagName("credential")\r
+ signatures = signed_cred[0].getElementsByTagName("signatures")\r
+ if len(signatures) > 0:\r
+ sigs = signatures[0].getElementsByTagName("Signature")\r
+ else:\r
+ creds = doc.getElementsByTagName("credential")\r
+ \r
+ if creds is None or len(creds) == 0:\r
+ # malformed cred file\r
+ raise CredentialNotVerifiable("Malformed XML: No credential tag found")\r
+\r
+ # Just take the first cred if there are more than one\r
+ cred = creds[0]\r
+\r
+ self.set_refid(cred.getAttribute("xml:id"))\r
+ self.set_expiration(utcparse(getTextNode(cred, "expires")))\r
+ self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))\r
+ self.gidObject = GID(string=getTextNode(cred, "target_gid")) \r
+\r
+\r
+ # Process privileges\r
+ privs = cred.getElementsByTagName("privileges")[0]\r
+ rlist = Rights()\r
+ for priv in privs.getElementsByTagName("privilege"):\r
+ kind = getTextNode(priv, "name")\r
+ deleg = str2bool(getTextNode(priv, "can_delegate"))\r
+ if kind == '*':\r
+ # Convert * into the default privileges for the credential's type\r
+ # Each inherits the delegatability from the * above\r
+ _ , type = urn_to_hrn(self.gidObject.get_urn())\r
+ rl = determine_rights(type, self.gidObject.get_urn())\r
+ for r in rl.rights:\r
+ r.delegate = deleg\r
+ rlist.add(r)\r
+ else:\r
+ rlist.add(Right(kind.strip(), deleg))\r
+ self.set_privileges(rlist)\r
+\r
+\r
+ # Is there a parent?\r
+ parent = cred.getElementsByTagName("parent")\r
+ if len(parent) > 0:\r
+ parent_doc = parent[0].getElementsByTagName("credential")[0]\r
+ parent_xml = parent_doc.toxml()\r
+ self.parent = Credential(string=parent_xml)\r
+ self.updateRefID()\r
+\r
+ # Assign the signatures to the credentials\r
+ for sig in sigs:\r
+ Sig = Signature(string=sig.toxml())\r
+\r
+ for cur_cred in self.get_credential_list():\r
+ if cur_cred.get_refid() == Sig.get_refid():\r
+ cur_cred.set_signature(Sig)\r
+ \r
+ \r
+ ##\r
+ # Verify\r
+ # trusted_certs: A list of trusted GID filenames (not GID objects!).\r
+ # xmlsec1 does not support certificate chaining within the GIDs themselves.\r
+ #\r
+ # trusted_certs_required: Should usually be true. Setting it False means an\r
+ # empty list of trusted_certs will still let this method pass;\r
+ # it just skips xmlsec1 verification and the related checks. Only used by some utilities.\r
+ # \r
+ # Verify that:\r
+ # . All of the signatures are valid and that the issuers trace back\r
+ # to trusted roots (performed by xmlsec1)\r
+ # . The XML matches the credential schema\r
+ # . That the issuer of the credential is the authority in the target's urn\r
+ # . In the case of a delegated credential, this must be true of the root\r
+ # . That all of the gids presented in the credential are valid\r
+ # . Including verifying GID chains, and including the issuer\r
+ # . The credential is not expired\r
+ #\r
+ # -- For Delegates (credentials with parents)\r
+ # . The privileges must be a subset of the parent credential's\r
+ # . The privileges must have "can_delegate" set for each delegated privilege\r
+ # . The target gid must be the same between child and parents\r
+ # . The expiry time on the child must be no later than the parent\r
+ # . The signer of the child must be the owner of the parent\r
+ #\r
+ # -- Verify does *NOT*\r
+ # . ensure that an xmlrpc client's gid matches a credential gid, that\r
+ # must be done elsewhere\r
+ #\r
+ # @param trusted_certs: The certificates of trusted CA certificates\r
+ def verify(self, trusted_certs=None, schema=None, trusted_certs_required=True):\r
+ if not self.xml:\r
+ self.decode()\r
+\r
+ # validate against the XML schema (XSD)\r
+ if HAVELXML and not self.legacy:\r
+ if schema and os.path.exists(schema):\r
+ tree = etree.parse(StringIO(self.xml))\r
+ schema_doc = etree.parse(schema)\r
+ xmlschema = etree.XMLSchema(schema_doc)\r
+ if not xmlschema.validate(tree):\r
+ error = xmlschema.error_log.last_error\r
+ message = "%s: %s (line %s)" % (self.get_summary_tostring(), error.message, error.line)\r
+ raise CredentialNotVerifiable(message)\r
+\r
+ if trusted_certs_required and trusted_certs is None:\r
+ trusted_certs = []\r
+\r
+# trusted_cert_objects = [GID(filename=f) for f in trusted_certs]\r
+ trusted_cert_objects = []\r
+ ok_trusted_certs = []\r
+ # If caller explicitly passed in None that means skip cert chain validation.\r
+ # Strange and not typical\r
+ if trusted_certs is not None:\r
+ for f in trusted_certs:\r
+ try:\r
+ # Failures here include unreadable files\r
+ # or non PEM files\r
+ trusted_cert_objects.append(GID(filename=f))\r
+ ok_trusted_certs.append(f)\r
+ except Exception, exc:\r
+ logger.error("Failed to load trusted cert from %s: %r", f, exc)\r
+ trusted_certs = ok_trusted_certs\r
+\r
+ # Use legacy verification if this is a legacy credential\r
+ if self.legacy:\r
+ self.legacy.verify_chain(trusted_cert_objects)\r
+ if self.legacy.client_gid:\r
+ self.legacy.client_gid.verify_chain(trusted_cert_objects)\r
+ if self.legacy.object_gid:\r
+ self.legacy.object_gid.verify_chain(trusted_cert_objects)\r
+ return True\r
+ \r
+ # make sure it is not expired\r
+ if self.get_expiration() < datetime.datetime.utcnow():\r
+ raise CredentialNotVerifiable("Credential %s expired at %s" % (self.get_summary_tostring(), self.expiration.isoformat()))\r
+\r
+ # Verify the signatures\r
+ filename = self.save_to_random_tmp_file()\r
+ if trusted_certs is not None:\r
+ cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs])\r
+\r
+ # If caller explicitly passed in None that means skip cert chain validation.\r
+ # - Strange and not typical\r
+ if trusted_certs is not None:\r
+ # Verify the gids of this cred and of its parents\r
+ for cur_cred in self.get_credential_list():\r
+ cur_cred.get_gid_object().verify_chain(trusted_cert_objects)\r
+ cur_cred.get_gid_caller().verify_chain(trusted_cert_objects)\r
+\r
+ refs = []\r
+ refs.append("Sig_%s" % self.get_refid())\r
+\r
+ parentRefs = self.updateRefID()\r
+ for ref in parentRefs:\r
+ refs.append("Sig_%s" % ref)\r
+\r
+ for ref in refs:\r
+ # If caller explicitly passed in None that means skip xmlsec1 validation.\r
+ # Strange and not typical\r
+ if trusted_certs is None:\r
+ break\r
+\r
+# print "Doing %s --verify --node-id '%s' %s %s 2>&1" % \\r
+# (self.xmlsec_path, ref, cert_args, filename)\r
+ verified = os.popen('%s --verify --node-id "%s" %s %s 2>&1' \\r
+ % (self.xmlsec_path, ref, cert_args, filename)).read()\r
+ if not verified.strip().startswith("OK"):\r
+ # xmlsec errors have a msg= which is the interesting bit.\r
+ mstart = verified.find("msg=")\r
+ msg = ""\r
+ if mstart > -1 and len(verified) > 4:\r
+ mstart = mstart + 4\r
+ mend = verified.find('\\', mstart)\r
+ msg = verified[mstart:mend]\r
+ raise CredentialNotVerifiable("xmlsec1 error verifying cred %s using Signature ID %s: %s %s" % (self.get_summary_tostring(), ref, msg, verified.strip()))\r
+ os.remove(filename)\r
+\r
+ # Verify the parents (delegation)\r
+ if self.parent:\r
+ self.verify_parent(self.parent)\r
+\r
+ # Make sure the issuer is the target's authority, and is\r
+ # itself a valid GID\r
+ self.verify_issuer(trusted_cert_objects)\r
+ return True\r
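The error-reporting path in verify() pulls the `msg=` field out of xmlsec1's diagnostic output. That parsing can be factored into a small helper; this is a sketch under the assumption that xmlsec1 delimits fields the way the code above expects (`parse_xmlsec_msg` is a hypothetical name, not an SFA function):

```python
def parse_xmlsec_msg(output, delimiter='\\'):
    """Extract the text after 'msg=' up to the next delimiter.

    Returns '' when no msg= field is present, matching the
    empty-message fallback in verify() above.
    """
    mstart = output.find("msg=")
    if mstart == -1:
        return ""
    mstart += 4
    mend = output.find(delimiter, mstart)
    if mend == -1:
        return output[mstart:]
    return output[mstart:mend]
```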
+\r
+ ##\r
+ # Creates a list of the credential and its parents, with the root \r
+ # (original delegated credential) as the last item in the list\r
+ def get_credential_list(self): \r
+ cur_cred = self\r
+ list = []\r
+ while cur_cred:\r
+ list.append(cur_cred)\r
+ if cur_cred.parent:\r
+ cur_cred = cur_cred.parent\r
+ else:\r
+ cur_cred = None\r
+ return list\r
+ \r
+ ##\r
+ # Make sure the credential's target gid (a) was signed by or (b)\r
+ # is the same as the entity that signed the original credential,\r
+ # or (c) is an authority over the target's namespace.\r
+ # Also ensure that the credential issuer / signer itself has a valid\r
+ # GID signature chain (signed by an authority with namespace rights).\r
+ def verify_issuer(self, trusted_gids):\r
+ root_cred = self.get_credential_list()[-1]\r
+ root_target_gid = root_cred.get_gid_object()\r
+ root_cred_signer = root_cred.get_signature().get_issuer_gid()\r
+\r
+ # Case 1:\r
+ # Allow non authority to sign target and cred about target.\r
+ #\r
+ # Why do we need to allow non authorities to sign?\r
+ # If in the target gid validation step we correctly\r
+ # checked that the target is only signed by an authority,\r
+ # then this is just a special case of case 3.\r
+ # This short-circuit is the common case currently -\r
+ # and because GID validation doesn't check 'authority',\r
+ # this allows users to generate valid slice credentials.\r
+ if root_target_gid.is_signed_by_cert(root_cred_signer):\r
+ # cred signer matches target signer, return success\r
+ return\r
+\r
+ # Case 2:\r
+ # Allow someone to sign a credential about themselves. Is this used?\r
+ # If not, remove this.\r
+ #root_target_gid_str = root_target_gid.save_to_string()\r
+ #root_cred_signer_str = root_cred_signer.save_to_string()\r
+ #if root_target_gid_str == root_cred_signer_str:\r
+ # # cred signer is target, return success\r
+ # return\r
+\r
+ # Case 3:\r
+\r
+ # root_cred_signer is not the target_gid\r
+ # So this is a different gid that we have not verified.\r
+ # xmlsec1 verified the cert chain on this already, but\r
+ # it hasn't verified that the gid meets the HRN namespace\r
+ # requirements.\r
+ # Below we'll ensure that it is an authority.\r
+ # But we haven't verified that it is _signed by_ an authority\r
+ # We also don't know if xmlsec1 requires that cert signers\r
+ # are marked as CAs.\r
+\r
+ # Note that if verify() gave us no trusted_gids then this\r
+ # call will fail. So skip it if we have no trusted_gids\r
+ if trusted_gids and len(trusted_gids) > 0:\r
+ root_cred_signer.verify_chain(trusted_gids)\r
+ else:\r
+ logger.debug("No trusted gids. Cannot verify that cred signer is signed by a trusted authority. Skipping that check.")\r
+\r
+ # See if the signer is an authority over the domain of the target.\r
+ # There are multiple types of authority - accept them all here\r
+ # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn())\r
+ root_cred_signer_type = root_cred_signer.get_type()\r
+ if (root_cred_signer_type.find('authority') == 0):\r
+ #logger.debug('Cred signer is an authority')\r
+ # signer is an authority, see if target is in authority's domain\r
+ signerhrn = root_cred_signer.get_hrn()\r
+ if hrn_authfor_hrn(signerhrn, root_target_gid.get_hrn()):\r
+ return\r
+\r
+ # We've required that the credential be signed by an authority\r
+ # for that domain. Reasonable and probably correct.\r
+ # A looser model would also allow the signer to be an authority\r
+ # in my control framework - eg My CA or CH. Even if it is not\r
+ # the CH that issued these, eg, user credentials.\r
+\r
+ # Give up, credential does not pass issuer verification\r
+\r
+ raise CredentialNotVerifiable("Could not verify credential owned by %s for object %s. Cred signer %s not the trusted authority for Cred target %s" % (self.gidCaller.get_urn(), self.gidObject.get_urn(), root_cred_signer.get_hrn(), root_target_gid.get_hrn()))\r
+\r
+\r
+ ##\r
+ # -- For Delegates (credentials with parents) verify that:\r
+ # . The privileges must be a subset of the parent credential's\r
+ # . The privileges must have "can_delegate" set for each delegated privilege\r
+ # . The target gid must be the same between child and parents\r
+ # . The expiry time on the child must be no later than the parent\r
+ # . The signer of the child must be the owner of the parent \r
+ def verify_parent(self, parent_cred):\r
+ # make sure the rights given to the child are a subset of the\r
+ # parents rights (and check delegate bits)\r
+ if not parent_cred.get_privileges().is_superset(self.get_privileges()):\r
+ raise ChildRightsNotSubsetOfParent(("Parent cred ref %s rights " % parent_cred.get_refid()) +\r
+ parent_cred.get_privileges().save_to_string() + (" not superset of delegated cred %s ref %s rights " % (self.get_summary_tostring(), self.get_refid())) +\r
+ self.get_privileges().save_to_string())\r
+\r
+ # make sure my target gid is the same as the parent's\r
+ if not parent_cred.get_gid_object().save_to_string() == \\r
+ self.get_gid_object().save_to_string():\r
+ raise CredentialNotVerifiable("Delegated cred %s: Target gid not equal between parent and child. Parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
+\r
+ # make sure my expiry time is <= my parent's\r
+ if not parent_cred.get_expiration() >= self.get_expiration():\r
+ raise CredentialNotVerifiable("Delegated credential %s expires after parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
+\r
+ # make sure my signer is the parent's caller\r
+ if not parent_cred.get_gid_caller().save_to_string(False) == \\r
+ self.get_signature().get_issuer_gid().save_to_string(False):\r
+ raise CredentialNotVerifiable("Delegated credential %s not signed by parent %s's caller" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
+ \r
+ # Recurse\r
+ if parent_cred.parent:\r
+ parent_cred.verify_parent(parent_cred.parent)\r
+\r
+\r
+ def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile):\r
+ """\r
+ Return a delegated copy of this credential, delegated to the \r
+ specified gid's user. \r
+ """\r
+ # get the gid of the object we are delegating\r
+ object_gid = self.get_gid_object()\r
+ object_hrn = object_gid.get_hrn() \r
+ \r
+ # the hrn of the user who will be delegated to\r
+ delegee_gid = GID(filename=delegee_gidfile)\r
+ delegee_hrn = delegee_gid.get_hrn()\r
+ \r
+ #user_key = Keypair(filename=keyfile)\r
+ #user_hrn = self.get_gid_caller().get_hrn()\r
+ subject_string = "%s delegated to %s" % (object_hrn, delegee_hrn)\r
+ dcred = Credential(subject=subject_string)\r
+ dcred.set_gid_caller(delegee_gid)\r
+ dcred.set_gid_object(object_gid)\r
+ dcred.set_parent(self)\r
+ dcred.set_expiration(self.get_expiration())\r
+ dcred.set_privileges(self.get_privileges())\r
+ dcred.get_privileges().delegate_all_privileges(True)\r
+ #dcred.set_issuer_keys(keyfile, delegee_gidfile)\r
+ dcred.set_issuer_keys(caller_keyfile, caller_gidfile)\r
+ dcred.encode()\r
+ dcred.sign()\r
+\r
+ return dcred\r
+\r
+ # only informative\r
+ def get_filename(self):\r
+ return getattr(self,'filename',None)\r
+\r
+ ##\r
+ # Dump the contents of a credential to stdout in human-readable format\r
+ #\r
+ # @param dump_parents If true, also dump the parent certificates\r
+ def dump (self, *args, **kwargs):\r
+ print self.dump_string(*args, **kwargs)\r
+\r
+\r
+ def dump_string(self, dump_parents=False):\r
+ result=""\r
+ result += "CREDENTIAL %s\n" % self.get_subject()\r
+ filename=self.get_filename()\r
+ if filename: result += "Filename %s\n"%filename\r
+ result += " privs: %s\n" % self.get_privileges().save_to_string()\r
+ gidCaller = self.get_gid_caller()\r
+ if gidCaller:\r
+ result += " gidCaller:\n"\r
+ result += gidCaller.dump_string(8, dump_parents)\r
+\r
+ if self.get_signature():\r
+ result += " gidIssuer:\n"\r
+ result += self.get_signature().get_issuer_gid().dump_string(8, dump_parents)\r
+\r
+ gidObject = self.get_gid_object()\r
+ if gidObject:\r
+ result += " gidObject:\n"\r
+ result += gidObject.dump_string(8, dump_parents)\r
+\r
+ if self.parent and dump_parents:\r
+ result += "\nPARENT "\r
+ result += self.parent.dump_string(True)\r
+\r
+ return result\r
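The delegation rules that verify_parent() enforces (privileges a subset of the parent's, expiry no later than the parent's, checked recursively up the chain) can be modelled without the surrounding trust machinery. A minimal sketch with a toy `Cred` class — an illustration, not the SFA Credential API:

```python
class Cred(object):
    """Toy stand-in for Credential: just privileges, an expiry
    timestamp, and an optional parent credential."""

    def __init__(self, privs, expires, parent=None):
        self.privs = set(privs)
        self.expires = expires
        self.parent = parent

    def verify_parent(self):
        """Walk up the delegation chain applying the two checks above."""
        cred = self
        while cred.parent:
            # child rights must be a subset of the parent's rights
            if not cred.privs <= cred.parent.privs:
                raise ValueError("child rights not a subset of parent's")
            # child must expire no later than the parent
            if cred.expires > cred.parent.expires:
                raise ValueError("child expires after parent")
            cred = cred.parent
        return True
```

For example, a child holding only 'info' that expires before its root verifies; a child asking for a right the root never had, or outliving the root, raises.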
-->
<!--
- ProtoGENI credential and privilege specification. The key points:
+ PlanetLab credential specification. The key points:
* A credential is a set of privileges or a Ticket, each with a flag
to indicate delegation is permitted.
blob will be signed. So, there will be multiple signatures in the
document, each with a reference to the credential it signs.
- default namespace = "http://www.protogeni.net/resources/credential/0.1"
+ default namespace = "http://www.planet-lab.org/resources/ext/credential/1"
-->
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema" elementFormDefault="qualified" xmlns:sig="http://www.w3.org/2000/09/xmldsig#">
<xs:include schemaLocation="protogeni-rspec-common.xsd"/>
# certificate that stores a tuple of parameters.
##
-### $Id: credential.py 17477 2010-03-25 16:49:34Z jkarlin $
-### $URL: svn+ssh://svn.planet-lab.org/svn/sfa/branches/geni-api/sfa/trust/credential.py $
-
import xmlrpclib
from sfa.util.faults import *
import xmlrpclib
import uuid
-from sfa.util.sfalogging import sfa_logger
from sfa.trust.certificate import Certificate
-from sfa.util.xrn import hrn_to_urn, urn_to_hrn
+
+from sfa.util.faults import *
+from sfa.util.sfalogging import logger
+from sfa.util.xrn import hrn_to_urn, urn_to_hrn, hrn_authfor_hrn
##
# Create a new uuid. Returns the UUID as a string.
# @param subject If subject!=None, create the X509 cert and set the subject name
# @param string If string!=None, load the GID from a string
# @param filename If filename!=None, load the GID from a file
+ # @param lifeDays life of GID in days - default is 1825==5 years
- def __init__(self, create=False, subject=None, string=None, filename=None, uuid=None, hrn=None, urn=None):
+ def __init__(self, create=False, subject=None, string=None, filename=None, uuid=None, hrn=None, urn=None, lifeDays=1825):
- Certificate.__init__(self, create, subject, string, filename)
+ Certificate.__init__(self, lifeDays, create, subject, string, filename)
if subject:
- sfa_logger().debug("Creating GID for subject: %s" % subject)
+ logger.debug("Creating GID for subject: %s" % subject)
if uuid:
self.uuid = int(uuid)
if hrn:
print self.dump_string(*args,**kwargs)
def dump_string(self, indent=0, dump_parents=False):
- result="GID\n"
+ result=" "*(indent-2) + "GID\n"
result += " "*indent + "hrn:" + str(self.get_hrn()) +"\n"
result += " "*indent + "urn:" + str(self.get_urn()) +"\n"
result += " "*indent + "uuid:" + str(self.get_uuid()) + "\n"
# Verify the chain of authenticity of the GID. First perform the checks
# of the certificate class (verifying that each parent signs the child,
# etc). In addition, GIDs also confirm that the parent's HRN is a prefix
- # of the child's HRN.
+ # of the child's HRN, and the parent is of type 'authority'.
#
# Verifying these prefixes prevents a rogue authority from signing a GID
# for a principal that is not a member of that authority. For example,
if self.parent:
# make sure the parent's hrn is a prefix of the child's hrn
- if not self.get_hrn().startswith(self.parent.get_hrn()):
- raise GidParentHrn("This cert HRN %s doesnt start with parent HRN %s" % (self.get_hrn(), self.parent.get_hrn()))
+ if not hrn_authfor_hrn(self.parent.get_hrn(), self.get_hrn()):
+ raise GidParentHrn("This cert HRN %s isn't in the namespace for parent HRN %s" % (self.get_hrn(), self.parent.get_hrn()))
+
+ # Parent must also be an authority (of some type) to sign a GID
+ # There are multiple types of authority - accept them all here
+ if not self.parent.get_type().find('authority') == 0:
+ raise GidInvalidParentHrn("This cert %s's parent %s is not an authority (is a %s)" % (self.get_hrn(), self.parent.get_hrn(), self.parent.get_type()))
+
+ # Then recurse up the chain - ensure the parent is a trusted
+ # root or is in the namespace of a trusted root
+ self.parent.verify_chain(trusted_certs)
else:
# make sure that the trusted root's hrn is a prefix of the child's
trusted_gid = GID(string=trusted_root.save_to_string())
#if trusted_type == 'authority':
# trusted_hrn = trusted_hrn[:trusted_hrn.rindex('.')]
cur_hrn = self.get_hrn()
- if not self.get_hrn().startswith(trusted_hrn):
- raise GidParentHrn("Trusted roots HRN %s isnt start of this cert %s" % (trusted_hrn, cur_hrn))
+ if not hrn_authfor_hrn(trusted_hrn, cur_hrn):
+ raise GidParentHrn("Trusted root with HRN %s isn't a namespace authority for this cert %s" % (trusted_hrn, cur_hrn))
+
+ # There are multiple types of authority - accept them all here
+ if not trusted_type.find('authority') == 0:
+ raise GidInvalidParentHrn("This cert %s's trusted root signer %s is not an authority (is a %s)" % (self.get_hrn(), trusted_hrn, trusted_type))
return
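The switch from plain startswith() to hrn_authfor_hrn() in verify_chain() matters because HRNs are dot-separated: the string 'plc.abc' starts with 'plc.a', yet 'plc.a' is not an authority over it. A sketch of the dotted-prefix check — an illustrative reimplementation, not the actual sfa.util.xrn function:

```python
def dotted_authfor(parent_hrn, child_hrn):
    """True when parent_hrn is the child itself or a dot-separated
    ancestor of child_hrn - i.e. an authority over its namespace."""
    if parent_hrn == child_hrn:
        return True
    # require a full '.'-delimited component boundary, not a raw
    # string prefix, so 'plc.a' does not claim 'plc.abc'
    return child_hrn.startswith(parent_hrn + '.')
```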
import os
from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn, urn_to_hrn
from sfa.trust.certificate import Keypair
from sfa.trust.credential import Credential
gid_filename = None
privkey_filename = None
dbinfo_filename = None
-
##
# Initialize and authority object.
#
def create_auth(self, xrn, create_parents=False):
hrn, type = urn_to_hrn(xrn)
- sfa_logger().debug("Hierarchy: creating authority: %s"% hrn)
+ logger.debug("Hierarchy: creating authority: %s"% hrn)
# create the parent authority if necessary
parent_hrn = get_authority(hrn)
pass
if os.path.exists(privkey_filename):
- sfa_logger().debug("using existing key %r for authority %r"%(privkey_filename,hrn))
+ logger.debug("using existing key %r for authority %r"%(privkey_filename,hrn))
pkey = Keypair(filename = privkey_filename)
else:
pkey = Keypair(create = True)
def get_auth_info(self, xrn):
hrn, type = urn_to_hrn(xrn)
if not self.auth_exists(hrn):
- sfa_logger().warning("Hierarchy: mising authority - xrn=%s, hrn=%s"%(xrn,hrn))
+ logger.warning("Hierarchy: missing authority - xrn=%s, hrn=%s"%(xrn,hrn))
raise MissingAuthority(hrn)
(directory, gid_filename, privkey_filename, dbinfo_filename) = \
# @param uuid the unique identifier to store in the GID
# @param pkey the public key to store in the GID
- def create_gid(self, xrn, uuid, pkey):
+ def create_gid(self, xrn, uuid, pkey, CA=False):
hrn, type = urn_to_hrn(xrn)
+ parent_hrn = get_authority(hrn)
# Using hrn_to_urn() here to make sure the urn is in the right format
# If xrn was a hrn instead of a urn, then the gid's urn will be
# of type None
urn = hrn_to_urn(hrn, type)
gid = GID(subject=hrn, uuid=uuid, hrn=hrn, urn=urn)
- parent_hrn = get_authority(hrn)
+ # is this a CA cert
+ if hrn == self.config.SFA_INTERFACE_HRN or not parent_hrn:
+ # root or sub authority
+ gid.set_intermediate_ca(True)
+ elif type and 'authority' in type:
+ # authority type
+ gid.set_intermediate_ca(True)
+ elif CA:
+ gid.set_intermediate_ca(True)
+ else:
+ gid.set_intermediate_ca(False)
+
+ # set issuer
if not parent_hrn or hrn == self.config.SFA_INTERFACE_HRN:
# if there is no parent hrn, then it must be self-signed. this
# is where we terminate the recursion
parent_auth_info = self.get_auth_info(parent_hrn)
gid.set_issuer(parent_auth_info.get_pkey_object(), parent_auth_info.hrn)
gid.set_parent(parent_auth_info.get_gid_object())
- gid.set_intermediate_ca(True)
gid.set_pubkey(pkey)
gid.encode()
elif type in ["sa", "authority+sa"]:
rl.add("authority")
rl.add("sa")
- elif type in ["ma", "authority+ma", "cm", "authority+cm"]:
+ elif type in ["ma", "authority+ma", "cm", "authority+cm", "sm", "authority+sm"]:
rl.add("authority")
rl.add("ma")
elif type == "authority":
for my_right in self.rights:
if my_right.is_superset(child_right):
allowed = True
+ break
if not allowed:
return False
return True
return False
return True
-
-
- ##
- # Determine the rights that an object should have. The rights are entirely
- # dependent on the type of the object. For example, users automatically
- # get "refresh", "resolve", and "info".
- #
- # @param type the type of the object (user | sa | ma | slice | node)
- # @param name human readable name of the object (not used at this time)
- #
- # @return Rights object containing rights
-
- def determine_rights(self, type, name):
- rl = Rights()
-
- # rights seem to be somewhat redundant with the type of the credential.
- # For example, a "sa" credential implies the authority right, because
- # a sa credential cannot be issued to a user who is not an owner of
- # the authority
-
- if type == "user":
- rl.add("refresh")
- rl.add("resolve")
- rl.add("info")
- elif type in ["sa", "authority+sa"]:
- rl.add("authority")
- rl.add("sa")
- elif type in ["ma", "authority+ma", "cm", "authority+cm"]:
- rl.add("authority")
- rl.add("ma")
- elif type == "authority":
- rl.add("authority")
- rl.add("sa")
- rl.add("ma")
- elif type == "slice":
- rl.add("refresh")
- rl.add("embed")
- rl.add("bind")
- rl.add("control")
- rl.add("info")
- elif type == "component":
- rl.add("operator")
-
- return rl
+++ /dev/null
-### $Id$
-### $URL$
-
-import os
-
-from sfa.trust.gid import *
-
-class TrustedRootList:
- def __init__(self, dir):
- self.basedir = dir
-
- # create the directory to hold the files
- try:
- os.makedirs(self.basedir)
- # if the path already exists then pass
- except OSError, (errno, strerr):
- if errno == 17:
- pass
-
- def add_gid(self, gid):
- fn = os.path.join(self.basedir, gid.get_hrn() + ".gid")
-
- gid.save_to_file(fn)
-
- def get_list(self):
- gid_list = []
- file_list = os.listdir(self.basedir)
- for gid_file in file_list:
- fn = os.path.join(self.basedir, gid_file)
- if os.path.isfile(fn):
- gid = GID(filename = fn)
- gid_list.append(gid)
- return gid_list
-
- def get_file_list(self):
- gid_file_list = []
-
- file_list = os.listdir(self.basedir)
- for gid_file in file_list:
- fn = os.path.join(self.basedir, gid_file)
- if os.path.isfile(fn):
- gid_file_list.append(fn)
-
- return gid_file_list
--- /dev/null
+import os.path
+import glob
+
+from sfa.trust.gid import GID
+from sfa.util.sfalogging import logger
+
+class TrustedRoots:
+
+ # we want to avoid reading all files in the directory
+ # this is because it's common to have backups of all kinds
+ # e.g. *~, *.hide, *-00, *.bak and the like
+ supported_extensions= [ 'gid', 'cert', 'pem' ]
+
+ def __init__(self, dir):
+ self.basedir = dir
+ # create the directory to hold the files, if not existing
+ if not os.path.isdir (self.basedir):
+ os.makedirs(self.basedir)
+
+ def add_gid(self, gid):
+ fn = os.path.join(self.basedir, gid.get_hrn() + ".gid")
+ gid.save_to_file(fn)
+
+ def get_list(self):
+ gid_list = [GID(filename=cert_file) for cert_file in self.get_file_list()]
+ return gid_list
+
+ def get_file_list(self):
+ file_list = []
+ pattern=os.path.join(self.basedir,"*")
+ for cert_file in glob.glob(pattern):
+ if os.path.isfile(cert_file):
+ if self.has_supported_extension(cert_file):
+ file_list.append(cert_file)
+ else:
+ logger.warning("File %s ignored - supported extensions are %r"%\
+ (cert_file,TrustedRoots.supported_extensions))
+ return file_list
+
+ def has_supported_extension (self,path):
+ (_,ext)=os.path.splitext(path)
+ ext=ext.replace('.','').lower()
+ return ext in TrustedRoots.supported_extensions
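The extension filter above is what keeps backup files (`*~`, `*.bak`, and the like) out of the trusted-roots set. The same splitext-based test can be checked in isolation; a stand-alone version for illustration (paths in the examples are hypothetical):

```python
import os.path

# same whitelist as TrustedRoots.supported_extensions above
SUPPORTED = ['gid', 'cert', 'pem']

def is_trusted_root_file(path):
    """Replicates has_supported_extension(): normalize the extension
    (drop the dot, lowercase) and accept only whitelisted suffixes."""
    _, ext = os.path.splitext(path)
    return ext.replace('.', '').lower() in SUPPORTED
```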
except: print >> sys.stderr, "WARNING, could not import pgdb"
from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
if not psycopg2:
is8bit = re.compile("[\x80-\xff]").search
if not params:
if self.debug:
- sfa_logger().debug('execute0 %r'%query)
+ logger.debug('execute0 %r'%query)
cursor.execute(query)
elif isinstance(params,dict):
if self.debug:
- sfa_logger().debug('execute-dict: params=[%r] query=[%r]'%(params,query%params))
+ logger.debug('execute-dict: params=[%r] query=[%r]'%(params,query%params))
cursor.execute(query,params)
elif isinstance(params,tuple) and len(params)==1:
if self.debug:
- sfa_logger().debug('execute-tuple %r'%(query%params[0]))
+ logger.debug('execute-tuple %r'%(query%params[0]))
cursor.execute(query,params[0])
else:
param_seq=(params,)
if self.debug:
for params in param_seq:
- sfa_logger().debug('executemany %r'%(query%params))
+ logger.debug('executemany %r'%(query%params))
cursor.executemany(query, param_seq)
(self.rowcount, self.description, self.lastrowid) = \
(cursor.rowcount, cursor.description, cursor.lastrowid)
except:
pass
uuid = commands.getoutput("uuidgen")
- sfa_logger().error("Database error %s:" % uuid)
- sfa_logger().error("Exception=%r"%e)
- sfa_logger().error("Query=%r"%query)
- sfa_logger().error("Params=%r"%pformat(params))
- sfa_logger().log_exc("PostgreSQL.execute caught exception")
+ logger.error("Database error %s:" % uuid)
+ logger.error("Exception=%r"%e)
+ logger.error("Query=%r"%query)
+ logger.error("Params=%r"%pformat(params))
+ logger.log_exc("PostgreSQL.execute caught exception")
raise SfaDBError("Please contact support: %s" % str(e))
return cursor
import string
import xmlrpclib
-from sfa.util.sfalogging import sfa_logger
-from sfa.trust.auth import Auth
-from sfa.util.config import *
from sfa.util.faults import *
+from sfa.util.config import *
+import sfa.util.xmlrpcprotocol as xmlrpcprotocol
+from sfa.util.sfalogging import logger
+from sfa.trust.auth import Auth
+from sfa.util.cache import Cache
from sfa.trust.credential import *
from sfa.trust.certificate import *
+# this is wrong all right, but temporary
+from sfa.managers.import_manager import import_manager
+
# See "2.2 Characters" in the XML specification:
#
# #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD]
"""
This class acts as a wrapper around an SFA interface manager module, but
    can be used with any python module. The purpose of this class is to raise a
-    SfaNotImplemented exception if the a someone attepmts to use an attribute
+    SfaNotImplemented exception if someone attempts to use an attribute
    (could be a callable) that's not available in the library, by checking the
    library using hasattr. This helps to communicate better error messages
    to the users and developers in the event that a specific operation
self.interface = interface
def __getattr__(self, method):
-
if not hasattr(self.manager, method):
raise SfaNotImplemented(method, self.interface)
return getattr(self.manager, method)
class BaseAPI:
- cache = None
protocol = None
def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8",
methods='sfa.methods', peer_cert = None, interface = None,
- key_file = None, cert_file = None, cache = cache):
+ key_file = None, cert_file = None, cache = None):
self.encoding = encoding
# Better just be documenting the API
if config is None:
return
-
# Load configuration
self.config = Config(config)
self.auth = Auth(peer_cert)
self.cert_file = cert_file
self.cert = Certificate(filename=self.cert_file)
self.cache = cache
+ if self.cache is None:
+ self.cache = Cache()
self.credential = None
self.source = None
self.time_format = "%Y-%m-%d %H:%M:%S"
- self.logger=sfa_logger
-
+ self.logger = logger
+
# load registries
from sfa.server.registry import Registries
- self.registries = Registries(self)
+ self.registries = Registries()
# load aggregates
from sfa.server.aggregate import Aggregates
- self.aggregates = Aggregates(self)
+ self.aggregates = Aggregates()
def get_interface_manager(self, manager_base = 'sfa.managers'):
Returns the appropriate manager module for this interface.
Modules are usually found in sfa/managers/
"""
-
+ manager=None
if self.interface in ['registry']:
- mgr_type = self.config.SFA_REGISTRY_TYPE
- manager_module = manager_base + ".registry_manager_%s" % mgr_type
+ manager=import_manager ("registry", self.config.SFA_REGISTRY_TYPE)
elif self.interface in ['aggregate']:
- mgr_type = self.config.SFA_AGGREGATE_TYPE
- manager_module = manager_base + ".aggregate_manager_%s" % mgr_type
+ manager=import_manager ("aggregate", self.config.SFA_AGGREGATE_TYPE)
elif self.interface in ['slicemgr', 'sm']:
- mgr_type = self.config.SFA_SM_TYPE
- manager_module = manager_base + ".slice_manager_%s" % mgr_type
+ manager=import_manager ("slice", self.config.SFA_SM_TYPE)
elif self.interface in ['component', 'cm']:
- mgr_type = self.config.SFA_CM_TYPE
- manager_module = manager_base + ".component_manager_%s" % mgr_type
- else:
+ manager=import_manager ("component", self.config.SFA_CM_TYPE)
+ if not manager:
raise SfaAPIError("No manager for interface: %s" % self.interface)
- manager = __import__(manager_module, fromlist=[manager_base])
- # this isnt necessary but will hlep to produce better error messages
+
+    # this isn't necessary but will help to produce better error messages
# if someone tries to access an operation this manager doesn't implement
manager = ManagerWrapper(manager, self.interface)
except SfaFault, fault:
result = fault
except Exception, fault:
- sfa_logger().log_exc("BaseAPI.handle has caught Exception")
+ logger.log_exc("BaseAPI.handle has caught Exception")
result = SfaAPIError(fault)
raise result
return response
+
+ def get_cached_server_version(self, server):
+ cache_key = server.url + "-version"
+ server_version = None
+ if self.cache:
+ server_version = self.cache.get(cache_key)
+ if not server_version:
+ server_version = server.GetVersion()
+ # cache version for 24 hours
+            self.cache.add(cache_key, server_version, ttl=60*60*24)
+ return server_version
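The 24-hour caching above follows a simple TTL pattern: look up by key, and on a miss call the server and store the result with an expiry. As an illustration, here is a minimal standalone sketch of that pattern; the `TTLCache` class name and the example key/value are hypothetical, and the real `sfa.util.cache.Cache` may differ beyond the `get`/`add(key, value, ttl=...)` calls used above.

```python
import time

class TTLCache:
    """Minimal sketch of the get/add(ttl=...) cache used by get_cached_server_version."""
    def __init__(self):
        self._store = {}  # key -> (value, absolute expiry time)

    def add(self, key, value, ttl=60 * 60 * 24):
        self._store[key] = (value, time.time() + ttl)

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expiry = entry
        if time.time() > expiry:
            del self._store[key]  # expired entry: drop it and report a miss
            return None
        return value

cache = TTLCache()
# hypothetical cache key, mirroring the server.url + "-version" scheme above
cache.add("https://plc.example.org:12346-version", {"interface": "aggregate"})
print(cache.get("https://plc.example.org:12346-version"))
```

A miss (`None`) is what triggers the real `GetVersion()` call in the method above.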
import threading
import time
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
"""
Callids: a simple mechanism to remember the call ids served so far
if not call_id: return False
has_lock=False
for attempt in range(_call_ids_impl.retries):
- if debug: sfa_logger().debug("Waiting for lock (%d)"%attempt)
+ if debug: logger.debug("Waiting for lock (%d)"%attempt)
if self._lock.acquire(False):
has_lock=True
- if debug: sfa_logger().debug("got lock (%d)"%attempt)
+ if debug: logger.debug("got lock (%d)"%attempt)
break
time.sleep(float(_call_ids_impl.wait_ms)/1000)
# in the unlikely event where we can't get the lock
if not has_lock:
- sfa_logger().warning("_call_ids_impl.should_handle_call_id: could not acquire lock")
+ logger.warning("_call_ids_impl.should_handle_call_id: could not acquire lock")
return False
# we're good to go
if self.has_key(call_id):
self[call_id]=time.time()
self._purge()
self._lock.release()
- if debug: sfa_logger().debug("released lock")
+ if debug: logger.debug("released lock")
return False
def _purge(self):
for (k,v) in self.iteritems():
if (now-v) >= _call_ids_impl.purge_timeout: o_keys.append(k)
for k in o_keys:
- if debug: sfa_logger().debug("Purging call_id %r (%s)"%(k,time.strftime("%H:%M:%S",time.localtime(self[k]))))
+ if debug: logger.debug("Purging call_id %r (%s)"%(k,time.strftime("%H:%M:%S",time.localtime(self[k]))))
del self[k]
if debug:
- sfa_logger().debug("AFTER PURGE")
- for (k,v) in self.iteritems(): sfa_logger().debug("%s -> %s"%(k,time.strftime("%H:%M:%S",time.localtime(v))))
+ logger.debug("AFTER PURGE")
+ for (k,v) in self.iteritems(): logger.debug("%s -> %s"%(k,time.strftime("%H:%M:%S",time.localtime(v))))
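The lock handling above tries a non-blocking acquire a bounded number of times and gives up rather than blocking forever. A self-contained sketch of that retry pattern; the `acquire_with_retries` name and the `retries`/`wait_ms` defaults are illustrative, not the module's actual settings.

```python
import threading
import time

def acquire_with_retries(lock, retries=10, wait_ms=100):
    """Bounded non-blocking acquire, as should_handle_call_id does above.

    Returns True once the lock is held (the caller must release it),
    False if all attempts fail.
    """
    for attempt in range(retries):
        if lock.acquire(False):        # non-blocking attempt
            return True
        time.sleep(wait_ms / 1000.0)   # back off briefly before retrying
    return False                       # give up rather than block forever

lock = threading.Lock()
assert acquire_with_retries(lock)      # uncontended lock: acquired immediately
lock.release()
```

Giving up after a bounded wait is what lets the caller return `False` (handle the call) instead of deadlocking when the lock is stuck.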
def Callids ():
if not _call_ids_impl._instance:
# TODO: investigate ways to combine this with existing PLC server?
##
-### $Id$
-### $URL$
-
import sys
import traceback
import threading
import SimpleXMLRPCServer
from OpenSSL import SSL
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
from sfa.trust.certificate import Keypair, Certificate
from sfa.trust.credential import *
from sfa.util.faults import *
# internal error, report as HTTP server error
self.send_response(500)
self.end_headers()
- sfa_logger().log_exc("componentserver.SecureXMLRpcRequestHandler.do_POST")
+ logger.log_exc("componentserver.SecureXMLRpcRequestHandler.do_POST")
else:
# got a valid XML RPC response
self.send_response(200)
# Note that SFA does not access any of the PLC databases directly via
# a mysql connection; all PLC databases are accessed via PLCAPI.
-### $Id$
-### $URL$
-
import os.path
import traceback
if not hasattr(self, 'SFA_CM_TYPE'):
self.SFA_COMPONENT_TYPE='pl'
+ if not hasattr(self, 'SFA_MAX_SLICE_RENEW'):
+ self.SFA_MAX_SLICE_RENEW=60
+
# create the data directory if it doesnt exist
if not os.path.isdir(self.SFA_DATA_DIR):
try:
--- /dev/null
+
+class Enum(set):
+ def __init__(self, *args, **kwds):
+ set.__init__(self)
+ enums = dict(zip(args, [object() for i in range(len(args))]), **kwds)
+ for (key, value) in enums.items():
+ setattr(self, key, value)
+ self.add(eval('self.%s' % key))
+
+
+#def Enum2(*args, **kwds):
+# enums = dict(zip(sequential, range(len(sequential))), **named)
+# return type('Enum', (), enums)
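The `Enum` helper above doubles as a set of sentinel values and an attribute namespace. A brief usage sketch, restating the class so the snippet is self-contained; `self.add(value)` is used here in place of the patch's `eval('self.%s' % key)`, which has the same effect, and the `Status` value set is made up for illustration.

```python
class Enum(set):
    """Set-backed enum: each name becomes an attribute bound to a unique sentinel."""
    def __init__(self, *args, **kwds):
        set.__init__(self)
        enums = dict(zip(args, [object() for _ in args]), **kwds)
        for key, value in enums.items():
            setattr(self, key, value)
            self.add(value)  # equivalent to the patch's eval('self.%s' % key)

Status = Enum('pending', 'active', 'deleted')
assert Status.active in Status               # membership test via the set base
assert Status.active is not Status.pending   # sentinels are distinct objects
```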
def __str__(self):
return repr(self.value)
+class InvalidRSpecElement(SfaFault):
+ def __init__(self, value, extra = None):
+ self.value = value
+ faultString = "Invalid RSpec Element: %(value)s" % locals()
+ SfaFault.__init__(self, 108, faultString, extra)
+ def __str__(self):
+ return repr(self.value)
+
+class InvalidXML(SfaFault):
+ def __init__(self, value, extra = None):
+ self.value = value
+ faultString = "Invalid XML Document: %(value)s" % locals()
+ SfaFault.__init__(self, 108, faultString, extra)
+ def __str__(self):
+ return repr(self.value)
+
+class InvalidXMLElement(SfaFault):
+ def __init__(self, value, extra = None):
+ self.value = value
+ faultString = "Invalid XML Element: %(value)s" % locals()
+ SfaFault.__init__(self, 108, faultString, extra)
+ def __str__(self):
+ return repr(self.value)
+
class AccountNotEnabled(SfaFault):
def __init__(self, extra = None):
faultString = "Account Disabled"
--- /dev/null
+import httplib
+import socket
+import sys
+
+
+def is_python26():
+ return False
+ #return sys.version_info[0] == 2 and sys.version_info[1] == 6
+
+# wrapper around the standard https modules. Properly supports timeouts.
+
+class HTTPSConnection(httplib.HTTPSConnection):
+ def __init__(self, host, port=None, key_file=None, cert_file=None,
+ strict=None, timeout = None):
+ httplib.HTTPSConnection.__init__(self, host, port, key_file, cert_file, strict)
+ if timeout:
+ timeout = float(timeout)
+ self.timeout = timeout
+
+ def connect(self):
+ """Connect to a host on a given (SSL) port."""
+ if is_python26():
+ from sfa.util.ssl_socket import SSLSocket
+ sock = socket.create_connection((self.host, self.port), self.timeout)
+ if self._tunnel_host:
+ self.sock = sock
+ self._tunnel()
+ self.sock = SSLSocket(sock, self.key_file, self.cert_file)
+ else:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(self.timeout)
+ sock.connect((self.host, self.port))
+ ssl = socket.ssl(sock, self.key_file, self.cert_file)
+ self.sock = httplib.FakeSocket(sock, ssl)
+
+class HTTPS(httplib.HTTPS):
+ def __init__(self, host='', port=None, key_file=None, cert_file=None,
+ strict=None, timeout = None):
+ # urf. compensate for bad input.
+ if port == 0:
+ port = None
+ self._setup(HTTPSConnection(host, port, key_file, cert_file, strict, timeout))
+
+ # we never actually use these for anything, but we keep them
+ # here for compatibility with post-1.5.2 CVS.
+ self.key_file = key_file
+ self.cert_file = cert_file
+
+ def set_timeout(self, timeout):
+ if is_python26():
+ self._conn.timeout = timeout
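For context, the wrapper above backports a connect timeout to pre-2.6 interpreters; Python 2.6+ (and all of Python 3) accepts `timeout` natively on `HTTPSConnection`. A minimal sketch using the modern stdlib name (`http.client` is Python 3's `httplib`); the host below is hypothetical and never contacted, since constructing the object does not open a socket.

```python
import http.client

# timeout is accepted directly by the stdlib class; no network access
# happens here because connect() is only called on the first request
conn = http.client.HTTPSConnection("aggregate.example.invalid", 443, timeout=5.0)
print(conn.timeout)  # 5.0
```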
import xmlrpclib
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
from sfa.util.faults import *
from sfa.util.parameter import Parameter, Mixed, python_type, xmlrpc_type
from sfa.trust.auth import Auth
self.type_check(name, value, expected, args)
if self.api.config.SFA_API_DEBUG:
- sfa_logger().debug("method.__call__ [%s] : BEG %s"%(self.api.interface,methodname))
+ logger.debug("method.__call__ [%s] : BEG %s"%(self.api.interface,methodname))
result = self.call(*args, **kwds)
runtime = time.time() - start
if self.api.config.SFA_API_DEBUG or hasattr(self, 'message'):
- sfa_logger().debug("method.__call__ [%s] : END %s in %02f s (%s)"%\
+ logger.debug("method.__call__ [%s] : END %s in %02f s (%s)"%\
(self.api.interface,methodname,runtime,getattr(self,'message',"[no-msg]")))
return result
# Prepend caller and method name to expected faults
fault.faultString = caller + ": " + self.name + ": " + fault.faultString
runtime = time.time() - start
- sfa_logger().log_exc("Method %s raised an exception"%self.name)
+ logger.log_exc("Method %s raised an exception"%self.name)
raise fault
# Mark Huang <mlhuang@cs.princeton.edu>
# Copyright (C) 2006 The Trustees of Princeton University
#
-# $Id$
-#
-
-### $Id$
-### $URL$
from types import *
from sfa.util.faults import *
return PlXrn(xrn=hrn,type='slice').pl_login_base()
def hrn_to_pl_authname (hrn):
return PlXrn(xrn=hrn,type='any').pl_authname()
-
+def xrn_to_hostname(hrn):
+ return Xrn.unescape(PlXrn(xrn=hrn, type='node').get_leaf())
class PlXrn (Xrn):
def site_hrn (auth, login_base):
return '.'.join([auth,login_base])
- def __init__ (self, auth=None, hostname=None, slicename=None, email=None, **kwargs):
+ def __init__ (self, auth=None, hostname=None, slicename=None, email=None, interface=None, **kwargs):
#def hostname_to_hrn(auth_hrn, login_base, hostname):
if hostname is not None:
self.type='node'
# keep only the part before '@' and replace special chars into _
self.hrn='.'.join([auth,email.split('@')[0].replace(".", "_").replace("+", "_")])
self.hrn_to_urn()
+ elif interface is not None:
+ self.type = 'interface'
+ self.hrn = auth + '.' + interface
+ self.hrn_to_urn()
else:
Xrn.__init__ (self,**kwargs)
self._normalize()
return self.authority[-1]
+ def interface_name(self):
+ self._normalize()
+ return self.leaf
+
#def hrn_to_pl_login_base(hrn):
def pl_login_base (self):
self._normalize()
-### $Id$
-### $URL$
-
import os
from sfa.util.storage import *
# TODO: Use existing PLC database methods? or keep this separate?
##
-### $Id$
-### $URL$
-
from types import StringTypes
from sfa.trust.gid import *
-from sfa.util.rspec import *
from sfa.util.parameter import *
from sfa.util.xrn import get_authority
from sfa.util.row import Row
+from sfa.util.xml import XML
class SfaRecord(Row):
"""
"""
return GID(string=self.gid)
+ ##
+ # Returns the value of a field
+
+ def get_field(self, fieldname, default=None):
+ # sometimes records act like classes, and sometimes they act like dicts
+ try:
+ return getattr(self, fieldname)
+ except AttributeError:
+ try:
+ return self[fieldname]
+ except KeyError:
+ if default != None:
+ return default
+ else:
+ raise
+
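`get_field` papers over records that behave like objects in some code paths and like dicts in others: attribute lookup first, then dict lookup, then the caller-supplied default. A self-contained sketch of that fallback chain; the `Record` class and field names here are made up for illustration, and `is not None` stands in for the patch's `default != None`.

```python
class Record(dict):
    """Toy record: some fields live as attributes, others as dict keys."""
    def get_field(self, fieldname, default=None):
        try:
            return getattr(self, fieldname)      # class-like access
        except AttributeError:
            try:
                return self[fieldname]           # dict-like access
            except KeyError:
                if default is not None:
                    return default               # caller-supplied fallback
                raise

r = Record(hrn='plc.princeton.alice')   # dict-style field
r.type = 'user'                         # attribute-style field
assert r.get_field('type') == 'user'
assert r.get_field('hrn') == 'plc.princeton.alice'
assert r.get_field('gid', default='n/a') == 'n/a'
```

Note that a missing field with no default still raises `KeyError`, so callers can distinguish "absent" from "absent but tolerable".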
##
# Returns a list of field names in this record.
"""
recorddict = self.as_dict()
filteredDict = dict([(key, val) for (key, val) in recorddict.iteritems() if key in self.fields.keys()])
- record = RecordSpec()
- record.parseDict(filteredDict)
+ record = XML('<record/>')
+ record.root.attrib.update(filteredDict)
str = record.toxml()
- #str = xmlrpclib.dumps((dict,), allow_none=True)
return str
##
"""
#dict = xmlrpclib.loads(str)[0][0]
- record = RecordSpec()
- record.parseString(str)
- record_dict = record.toDict()
- sfa_dict = record_dict['record']
- self.load_from_dict(sfa_dict)
+ record = XML(str)
+ self.load_from_dict(record.todict())
##
# Dump the record to stdout
+++ /dev/null
-import sys
-import pprint
-import os
-from StringIO import StringIO
-from types import StringTypes, ListType
-import httplib
-from xml.dom import minidom
-from lxml import etree
-
-from sfa.util.sfalogging import sfa_logger
-
-class RSpec:
-
- def __init__(self, xml = None, xsd = None, NSURL = None):
- '''
- Class to manipulate RSpecs. Reads and parses rspec xml into python dicts
- and reads python dicts and writes rspec xml
-
- self.xsd = # Schema. Can be local or remote file.
- self.NSURL = # If schema is remote, Name Space URL to query (full path minus filename)
- self.rootNode = # root of the DOM
- self.dict = # dict of the RSpec.
- self.schemaDict = {} # dict of the Schema
- '''
-
- self.xsd = xsd
- self.rootNode = None
- self.dict = {}
- self.schemaDict = {}
- self.NSURL = NSURL
- if xml:
- if type(xml) == file:
- self.parseFile(xml)
- if type(xml) in StringTypes:
- self.parseString(xml)
- self.dict = self.toDict()
- if xsd:
- self._parseXSD(self.NSURL + self.xsd)
-
-
- def _getText(self, nodelist):
- rc = ""
- for node in nodelist:
- if node.nodeType == node.TEXT_NODE:
- rc = rc + node.data
- return rc
-
- # The rspec is comprised of 2 parts, and 1 reference:
- # attributes/elements describe individual resources
- # complexTypes are used to describe a set of attributes/elements
- # complexTypes can include a reference to other complexTypes.
-
-
- def _getName(self, node):
- '''Gets name of node. If tag has no name, then return tag's localName'''
- name = None
- if not node.nodeName.startswith("#"):
- if node.localName:
- name = node.localName
- elif node.attributes.has_key("name"):
- name = node.attributes.get("name").value
- return name
-
-
- # Attribute. {name : nameofattribute, {items: values})
- def _attributeDict(self, attributeDom):
- '''Traverse single attribute node. Create a dict {attributename : {name: value,}]}'''
- node = {} # parsed dict
- for attr in attributeDom.attributes.keys():
- node[attr] = attributeDom.attributes.get(attr).value
- return node
-
-
- def appendToDictOrCreate(self, dict, key, value):
- if (dict.has_key(key)):
- dict[key].append(value)
- else:
- dict[key]=[value]
- return dict
-
- def toGenDict(self, nodeDom=None, parentdict=None, siblingdict={}, parent=None):
- """
- convert an XML to a nested dict:
- * Non-terminal nodes (elements with string children and attributes) are simple dictionaries
- * Terminal nodes (the rest) are nested dictionaries
- """
-
- if (not nodeDom):
- nodeDom=self.rootNode
-
- curNodeName = nodeDom.localName
-
- if (nodeDom.hasChildNodes()):
- childdict={}
- for attribute in nodeDom.attributes.keys():
- childdict = self.appendToDictOrCreate(childdict, attribute, nodeDom.getAttribute(attribute))
- for child in nodeDom.childNodes[:-1]:
- if (child.nodeValue):
- siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, child.nodeValue)
- else:
- childdict = self.toGenDict(child, None, childdict, curNodeName)
-
- child = nodeDom.childNodes[-1]
- if (child.nodeValue):
- siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, child.nodeValue)
- if (childdict):
- siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, childdict)
- else:
- siblingdict = self.toGenDict(child, siblingdict, childdict, curNodeName)
- else:
- childdict={}
- for attribute in nodeDom.attributes.keys():
- childdict = self.appendToDictOrCreate(childdict, attribute, nodeDom.getAttribute(attribute))
-
- self.appendToDictOrCreate(siblingdict, curNodeName, childdict)
-
- if (parentdict is not None):
- parentdict = self.appendToDictOrCreate(parentdict, parent, siblingdict)
- return parentdict
- else:
- return siblingdict
-
-
-
- def toDict(self, nodeDom = None):
- """
- convert this rspec to a dict and return it.
- """
- node = {}
- if not nodeDom:
- nodeDom = self.rootNode
-
- elementName = nodeDom.nodeName
- if elementName and not elementName.startswith("#"):
- # attributes have tags and values. get {tag: value}, else {type: value}
- node[elementName] = self._attributeDict(nodeDom)
- # resolve the child nodes.
- if nodeDom.hasChildNodes():
- for child in nodeDom.childNodes:
- childName = self._getName(child)
-
- # skip null children
- if not childName: continue
-
- # initialize the possible array of children
- if not node[elementName].has_key(childName): node[elementName][childName] = []
-
- if isinstance(child, minidom.Text):
- # add if data is not empty
- if child.data.strip():
- node[elementName][childName].append(nextchild.data)
- elif child.hasChildNodes() and isinstance(child.childNodes[0], minidom.Text):
- for nextchild in child.childNodes:
- node[elementName][childName].append(nextchild.data)
- else:
- childdict = self.toDict(child)
- for value in childdict.values():
- node[elementName][childName].append(value)
-
- return node
-
-
- def toxml(self):
- """
- convert this rspec to an xml string and return it.
- """
- return self.rootNode.toxml()
-
-
- def toprettyxml(self):
- """
- print this rspec in xml in a pretty format.
- """
- return self.rootNode.toprettyxml()
-
-
- def __removeWhitespaceNodes(self, parent):
- for child in list(parent.childNodes):
- if child.nodeType == minidom.Node.TEXT_NODE and child.data.strip() == '':
- parent.removeChild(child)
- else:
- self.__removeWhitespaceNodes(child)
-
- def parseFile(self, filename):
- """
- read a local xml file and store it as a dom object.
- """
- dom = minidom.parse(filename)
- self.__removeWhitespaceNodes(dom)
- self.rootNode = dom.childNodes[0]
-
-
- def parseString(self, xml):
- """
- read an xml string and store it as a dom object.
- """
- dom = minidom.parseString(xml)
- self.__removeWhitespaceNodes(dom)
- self.rootNode = dom.childNodes[0]
-
-
- def _httpGetXSD(self, xsdURI):
- # split the URI into relevant parts
- host = xsdURI.split("/")[2]
- if xsdURI.startswith("https"):
- conn = httplib.HTTPSConnection(host,
- httplib.HTTPSConnection.default_port)
- elif xsdURI.startswith("http"):
- conn = httplib.HTTPConnection(host,
- httplib.HTTPConnection.default_port)
- conn.request("GET", xsdURI)
- # If we can't download the schema, raise an exception
- r1 = conn.getresponse()
- if r1.status != 200:
- raise Exception
- return r1.read().replace('\n', '').replace('\t', '').strip()
-
-
- def _parseXSD(self, xsdURI):
- """
- Download XSD from URL, or if file, read local xsd file and set
- schemaDict.
-
- Since the schema definiton is a global namespace shared by and
- agreed upon by others, this should probably be a URL. Check
- for URL, download xsd, parse, or if local file, use that.
- """
- schemaDom = None
- if xsdURI.startswith("http"):
- try:
- schemaDom = minidom.parseString(self._httpGetXSD(xsdURI))
- except Exception, e:
- # logging.debug("%s: web file not found" % xsdURI)
- # logging.debug("Using local file %s" % self.xsd")
- sfa_logger().log_exc("rspec.parseXSD: can't find %s on the web. Continuing." % xsdURI)
- if not schemaDom:
- if os.path.exists(xsdURI):
- # logging.debug("using local copy.")
- sfa_logger().debug("rspec.parseXSD: Using local %s" % xsdURI)
- schemaDom = minidom.parse(xsdURI)
- else:
- raise Exception("rspec.parseXSD: can't find xsd locally")
- self.schemaDict = self.toDict(schemaDom.childNodes[0])
-
-
- def dict2dom(self, rdict, include_doc = False):
- """
- convert a dict object into a dom object.
- """
-
- def elementNode(tagname, rd):
- element = minidom.Element(tagname)
- for key in rd.keys():
- if isinstance(rd[key], StringTypes) or isinstance(rd[key], int):
- element.setAttribute(key, str(rd[key]))
- elif isinstance(rd[key], dict):
- child = elementNode(key, rd[key])
- element.appendChild(child)
- elif isinstance(rd[key], list):
- for item in rd[key]:
- if isinstance(item, dict):
- child = elementNode(key, item)
- element.appendChild(child)
- elif isinstance(item, StringTypes) or isinstance(item, int):
- child = minidom.Element(key)
- text = minidom.Text()
- text.data = item
- child.appendChild(text)
- element.appendChild(child)
- return element
-
- # Minidom does not allow documents to have more then one
- # child, but elements may have many children. Because of
- # this, the document's root node will be the first key/value
- # pair in the dictionary.
- node = elementNode(rdict.keys()[0], rdict.values()[0])
- if include_doc:
- rootNode = minidom.Document()
- rootNode.appendChild(node)
- else:
- rootNode = node
- return rootNode
-
-
- def parseDict(self, rdict, include_doc = True):
- """
- Convert a dictionary into a dom object and store it.
- """
- self.rootNode = self.dict2dom(rdict, include_doc).childNodes[0]
-
-
- def getDictsByTagName(self, tagname, dom = None):
- """
- Search the dom for all elements with the specified tagname
- and return them as a list of dicts
- """
- if not dom:
- dom = self.rootNode
- dicts = []
- doms = dom.getElementsByTagName(tagname)
- dictlist = [self.toDict(d) for d in doms]
- for item in dictlist:
- for value in item.values():
- dicts.append(value)
- return dicts
-
- def getDictByTagNameValue(self, tagname, value, dom = None):
- """
- Search the dom for the first element with the specified tagname
- and value and return it as a dict.
- """
- tempdict = {}
- if not dom:
- dom = self.rootNode
- dicts = self.getDictsByTagName(tagname, dom)
-
- for rdict in dicts:
- if rdict.has_key('name') and rdict['name'] in [value]:
- return rdict
-
- return tempdict
-
-
- def filter(self, tagname, attribute, blacklist = [], whitelist = [], dom = None):
- """
- Removes all elements where:
- 1. tagname matches the element tag
- 2. attribute matches the element attribte
- 3. attribute value is in valuelist
- """
-
- tempdict = {}
- if not dom:
- dom = self.rootNode
-
- if dom.localName in [tagname] and dom.attributes.has_key(attribute):
- if whitelist and dom.attributes.get(attribute).value not in whitelist:
- dom.parentNode.removeChild(dom)
- if blacklist and dom.attributes.get(attribute).value in blacklist:
- dom.parentNode.removeChild(dom)
-
- if dom.hasChildNodes():
- for child in dom.childNodes:
- self.filter(tagname, attribute, blacklist, whitelist, child)
-
-
- def merge(self, rspecs, tagname, dom=None):
- """
- Merge this rspec with the requested rspec based on the specified
- starting tag name. The start tag (and all of its children) will be merged
- """
- tempdict = {}
- if not dom:
- dom = self.rootNode
-
- whitelist = []
- blacklist = []
-
- if dom.localName in [tagname] and dom.attributes.has_key(attribute):
- if whitelist and dom.attributes.get(attribute).value not in whitelist:
- dom.parentNode.removeChild(dom)
- if blacklist and dom.attributes.get(attribute).value in blacklist:
- dom.parentNode.removeChild(dom)
-
- if dom.hasChildNodes():
- for child in dom.childNodes:
- self.filter(tagname, attribute, blacklist, whitelist, child)
-
- def validateDicts(self):
- types = {
- 'EInt' : int,
- 'EString' : str,
- 'EByteArray' : list,
- 'EBoolean' : bool,
- 'EFloat' : float,
- 'EDate' : date}
-
-
- def pprint(self, r = None, depth = 0):
- """
- Pretty print the dict
- """
- line = ""
- if r == None: r = self.dict
- # Set the dept
- for tab in range(0,depth): line += " "
- # check if it's nested
- if type(r) == dict:
- for i in r.keys():
- print line + "%s:" % i
- self.pprint(r[i], depth + 1)
- elif type(r) in (tuple, list):
- for j in r: self.pprint(j, depth + 1)
- # not nested so just print.
- else:
- print line + "%s" % r
-
-
-
-class RecordSpec(RSpec):
-
- root_tag = 'record'
- def parseDict(self, rdict, include_doc = False):
- """
- Convert a dictionary into a dom object and store it.
- """
- self.rootNode = self.dict2dom(rdict, include_doc)
-
- def dict2dom(self, rdict, include_doc = False):
- record_dict = rdict
- if not len(rdict.keys()) == 1:
- record_dict = {self.root_tag : rdict}
- return RSpec.dict2dom(self, record_dict, include_doc)
-
-
-# vim:ts=4:expandtab
-
from optparse import OptionParser
from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
def merge_rspecs(rspecs):
"""
try:
known_networks[network.get('name')]=True
except:
- sfa_logger().error("merge_rspecs: cannot register network with no name in rspec")
+ logger.error("merge_rspecs: cannot register network with no name in rspec")
pass
def is_registered_network (network):
try:
return network.get('name') in known_networks
except:
- sfa_logger().error("merge_rspecs: cannot retrieve network with no name in rspec")
+ logger.error("merge_rspecs: cannot retrieve network with no name in rspec")
return False
# the resulting tree
tree = etree.parse(StringIO(input_rspec))
except etree.XMLSyntaxError:
# consider failing silently here
- sfa_logger().log_exc("merge_rspecs, parse error")
+ logger.log_exc("merge_rspecs, parse error")
message = str(sys.exc_info()[1]) + ' with ' + input_rspec
raise InvalidRSpec(message)
root = tree.getroot()
if not root.get("type") in ["SFA"]:
- sfa_logger().error("merge_rspecs: unexpected type for rspec root, %s"%root.get('type'))
+ logger.error("merge_rspecs: unexpected type for rspec root, %s"%root.get('type'))
continue
if rspec == None:
# we scan the first input, register all networks
from OpenSSL import SSL
from sfa.trust.certificate import Keypair, Certificate
-from sfa.trust.trustedroot import TrustedRootList
+from sfa.trust.trustedroots import TrustedRoots
from sfa.util.config import Config
from sfa.trust.credential import *
from sfa.util.faults import *
from sfa.plc.api import SfaAPI
from sfa.util.cache import Cache
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
##
# Verification callback for pyOpenSSL. We do our own checking of keys because
except Exception, fault:
# This should only happen if the module is buggy
# internal error, report as HTTP server error
- sfa_logger().log_exc("server.do_POST")
+ logger.log_exc("server.do_POST")
response = self.api.prepare_response(fault)
#self.send_response(500)
#self.end_headers()
    It is very similar to SimpleXMLRPCServer but it uses HTTPS for transporting XML data.
"""
- sfa_logger().debug("SecureXMLRPCServer.__init__, server_address=%s, cert_file=%s"%(server_address,cert_file))
+ logger.debug("SecureXMLRPCServer.__init__, server_address=%s, cert_file=%s"%(server_address,cert_file))
self.logRequests = logRequests
self.interface = None
self.key_file = key_file
# If you wanted to verify certs against known CAs.. this is how you would do it
#ctx.load_verify_locations('/etc/sfa/trusted_roots/plc.gpo.gid')
config = Config()
- trusted_cert_files = TrustedRootList(config.get_trustedroots_dir()).get_file_list()
+ trusted_cert_files = TrustedRoots(config.get_trustedroots_dir()).get_file_list()
for cert_file in trusted_cert_files:
ctx.load_verify_locations(cert_file)
ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback)
# the client.
def _dispatch(self, method, params):
- sfa_logger().debug("SecureXMLRPCServer._dispatch, method=%s"%method)
+ logger.debug("SecureXMLRPCServer._dispatch, method=%s"%method)
try:
return SimpleXMLRPCServer.SimpleXMLRPCDispatcher._dispatch(self, method, params)
except:
self.server.interface=interface
self.trusted_cert_list = None
self.register_functions()
- sfa_logger().info("Starting SfaServer, interface=%s"%interface)
+ logger.info("Starting SfaServer, interface=%s"%interface)
##
# Register functions that will be served by the XMLRPC server. This
def __init__ (self,logfile=None,loggername=None,level=logging.INFO):
# default is to locate loggername from the logfile if avail.
if not logfile:
- loggername='console'
- handler=logging.StreamHandler()
- handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
- else:
- if not loggername:
- loggername=os.path.basename(logfile)
- try:
- handler=logging.handlers.RotatingFileHandler(logfile,maxBytes=1000000, backupCount=5)
- except IOError:
- # This is usually a permissions error becaue the file is
- # owned by root, but httpd is trying to access it.
- tmplogfile=os.getenv("TMPDIR", "/tmp") + os.path.sep + os.path.basename(logfile)
+ #loggername='console'
+ #handler=logging.StreamHandler()
+ #handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
+ logfile = "/var/log/sfa.log"
+
+ if not loggername:
+ loggername=os.path.basename(logfile)
+ try:
+ handler=logging.handlers.RotatingFileHandler(logfile,maxBytes=1000000, backupCount=5)
+ except IOError:
+            # This is usually a permissions error because the file is
+ # owned by root, but httpd is trying to access it.
+ tmplogfile=os.getenv("TMPDIR", "/tmp") + os.path.sep + os.path.basename(logfile)
+            # In strange uses, two users on the same machine might run the same code,
+            # meaning they would clobber each other's files.
+ # We could (a) rename the tmplogfile, or (b)
+ # just log to the console in that case.
+ # Here we default to the console.
+ if os.path.exists(tmplogfile) and not os.access(tmplogfile,os.W_OK):
+ loggername = loggername + "-console"
+ handler = logging.StreamHandler()
+ else:
handler=logging.handlers.RotatingFileHandler(tmplogfile,maxBytes=1000000, backupCount=5)
- handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
-
+ handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
self.logger=logging.getLogger(loggername)
self.logger.setLevel(level)
- self.logger.addHandler(handler)
+ # check if logger already has the handler we're about to add
+ handler_exists = False
+ for l_handler in self.logger.handlers:
+ if l_handler.baseFilename == handler.baseFilename and \
+ l_handler.level == handler.level:
+ handler_exists = True
+
+ if not handler_exists:
+ self.logger.addHandler(handler)
+
self.loggername=loggername
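The duplicate-handler check above keeps repeated `_SfaLogger` instantiations from attaching the same handler twice, which would double every log line. A standalone sketch of that guard; `add_handler_once` is a hypothetical name, and unlike the patch it uses `getattr` for `baseFilename` so stream handlers (which lack that attribute) are tolerated.

```python
import logging

def add_handler_once(log, handler):
    """Attach handler only if an equivalent one is not already present.

    Sketch of the guard in the patched __init__ above; comparing by
    (handler class, baseFilename) is an illustrative heuristic.
    """
    for existing in log.handlers:
        if type(existing) is type(handler) and \
           getattr(existing, 'baseFilename', None) == getattr(handler, 'baseFilename', None):
            return  # equivalent handler already attached: avoid double logging
    log.addHandler(handler)

log = logging.getLogger("sfa-demo")
handler = logging.StreamHandler()
add_handler_once(log, handler)
add_handler_once(log, handler)  # second call is a no-op
print(len(log.handlers))        # 1
```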
def setLevel(self,level):
self.logger.setLevel(logging.DEBUG)
####################
- def wrap(fun):
- def wrapped(self,msg,*args,**kwds):
- native=getattr(self.logger,fun.__name__)
- return native(msg,*args,**kwds)
- #wrapped.__doc__=native.__doc__
- return wrapped
-
- @wrap
- def critical(): pass
- @wrap
- def error(): pass
- @wrap
- def warning(): pass
- @wrap
- def info(): pass
- @wrap
- def debug(): pass
-
+ def info(self, msg):
+ self.logger.info(msg)
+
+ def debug(self, msg):
+ self.logger.debug(msg)
+
+ def warn(self, msg):
+ self.logger.warn(msg)
+
+ # some code is using logger.warn(), some is using logger.warning()
+ def warning(self, msg):
+ self.logger.warning(msg)
+
+ def error(self, msg):
+ self.logger.error(msg)
+
+ def critical(self, msg):
+ self.logger.critical(msg)
+
# logs an exception - use in an except statement
def log_exc(self,message):
self.error("%s BEG TRACEBACK"%message+"\n"+traceback.format_exc().strip("\n"))
# for investigation purposes, can be placed anywhere
def log_stack(self,message):
to_log="".join(traceback.format_stack())
- self.debug("%s BEG STACK"%message+"\n"+to_log)
- self.debug("%s END STACK"%message)
-
-####################
-# import-related operations go in this file
-_import_logger=_SfaLogger(logfile='/var/log/sfa_import.log')
-# servers log into /var/log/sfa.log
-_server_logger=_SfaLogger(logfile='/var/log/sfa.log')
-# clients use the console
-_console_logger=_SfaLogger()
-
-# default is to use the server-side logger
-_the_logger=_server_logger
-
-# clients would change the default by issuing one of these call
-def sfa_logger_goes_to_console():
- current_module=sys.modules[globals()['__name__']]
- current_module._the_logger=_console_logger
-
-# clients would change the default by issuing one of these call
-def sfa_logger_goes_to_import():
- current_module=sys.modules[globals()['__name__']]
- current_module._the_logger=_import_logger
-
-# this is how to retrieve the 'right' logger
-def sfa_logger():
- return _the_logger
+ self.info("%s BEG STACK"%message+"\n"+to_log)
+ self.info("%s END STACK"%message)
+
+ def enable_console(self, stream=sys.stdout):
+ formatter = logging.Formatter("%(message)s")
+ handler = logging.StreamHandler(stream)
+ handler.setFormatter(formatter)
+ self.logger.addHandler(handler)
+
+info_logger = _SfaLogger(loggername='info', level=logging.INFO)
+debug_logger = _SfaLogger(loggername='debug', level=logging.DEBUG)
+warn_logger = _SfaLogger(loggername='warning', level=logging.WARNING)
+error_logger = _SfaLogger(loggername='error', level=logging.ERROR)
+critical_logger = _SfaLogger(loggername='critical', level=logging.CRITICAL)
+logger = info_logger
+sfi_logger = _SfaLogger(logfile=os.path.expanduser("~/.sfi/")+'sfi.log',loggername='sfilog', level=logging.DEBUG)
########################################
import time
if __name__ == '__main__':
print 'testing sfalogging into logger.log'
- logger=_SfaLogger('logger.log')
- logger.critical("logger.critical")
- logger.error("logger.error")
- logger.warning("logger.warning")
- logger.info("logger.info")
- logger.debug("logger.debug")
- logger.setLevel(logging.DEBUG)
- logger.debug("logger.debug again")
+ logger1=_SfaLogger('logger.log', loggername='std(info)')
+ logger2=_SfaLogger('logger.log', loggername='error', level=logging.ERROR)
+ logger3=_SfaLogger('logger.log', loggername='debug', level=logging.DEBUG)
+
+ for (logger,msg) in [ (logger1,"std(info)"),(logger2,"error"),(logger3,"debug")]:
+
+ print "====================",msg, logger.logger.handlers
+
+ logger.enable_console()
+ logger.critical("logger.critical")
+ logger.error("logger.error")
+ logger.warn("logger.warning")
+ logger.info("logger.info")
+ logger.debug("logger.debug")
+ logger.setLevel(logging.DEBUG)
+ logger.debug("logger.debug again")
- sfa_logger_goes_to_console()
- my_logger=sfa_logger()
- my_logger.info("redirected to console")
-
- @profile(my_logger)
- def sleep(seconds = 1):
- time.sleep(seconds)
-
- my_logger.info('console.info')
- sleep(0.5)
- my_logger.setLevel(logging.DEBUG)
- sleep(0.25)
+ @profile(logger)
+ def sleep(seconds = 1):
+ time.sleep(seconds)
+
+ logger.info('console.info')
+ sleep(0.5)
+ logger.setLevel(logging.DEBUG)
+ sleep(0.25)
+from types import StringTypes
import dateutil.parser
+import datetime
-def utcparse(str):
+from sfa.util.sfalogging import logger
+
+def utcparse(input):
""" Translate a string into a time using dateutil.parser.parse but make sure it's in UTC time and strip
- the timezone, so that it's compatible with normal datetime.datetime objects"""
+the timezone, so that it's compatible with normal datetime.datetime objects.
+
+For convenience this also accepts inputs that are timestamps or datetime objects.
+"""
- t = dateutil.parser.parse(str)
- if not t.utcoffset() is None:
- t = t.utcoffset() + t.replace(tzinfo=None)
- return t
+ if isinstance (input, datetime.datetime):
+ logger.warn ("argument to utcparse already a datetime - doing nothing")
+ return input
+ elif isinstance (input, StringTypes):
+ t = dateutil.parser.parse(input)
+ if t.utcoffset() is not None:
+ t = t.replace(tzinfo=None) - t.utcoffset()
+ return t
+ elif isinstance (input, (int,float)):
+ return datetime.datetime.fromtimestamp(input)
+ else:
+ logger.error("Unexpected type in utcparse [%s]"%type(input))
+
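The dispatch in the patched `utcparse` can be sketched in isolation. This is a Python 3 stdlib-only illustration, not the project code: `datetime.fromisoformat` stands in for `dateutil.parser.parse`, which accepts far more string formats.

```python
from datetime import datetime, timezone

def utcparse_sketch(value):
    """Normalize value to a naive UTC datetime, mirroring the patched utcparse.

    Sketch only: the real code uses dateutil.parser.parse, which handles
    many more string formats than fromisoformat does.
    """
    if isinstance(value, datetime):
        return value                      # already a datetime: pass through
    if isinstance(value, str):
        t = datetime.fromisoformat(value)
        if t.utcoffset() is not None:
            # shift to UTC, then drop the tzinfo for plain-datetime compatibility
            t = (t - t.utcoffset()).replace(tzinfo=None)
        return t
    if isinstance(value, (int, float)):
        return datetime.fromtimestamp(value, tz=timezone.utc).replace(tzinfo=None)
    raise TypeError("unexpected type in utcparse [%s]" % type(value))
```
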
# SpecDict.plc_fields defines a one to one mapping of plc attribute to rspec
# attribute
-### $Id$
-### $URL$
-
from types import StringTypes, ListType
class SpecDict(dict):
--- /dev/null
+from ssl import SSLSocket
+
+import textwrap
+
+import _ssl # if we can't import it, let the error propagate
+
+from _ssl import SSLError
+from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
+from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
+from _ssl import RAND_status, RAND_egd, RAND_add
+from _ssl import \
+ SSL_ERROR_ZERO_RETURN, \
+ SSL_ERROR_WANT_READ, \
+ SSL_ERROR_WANT_WRITE, \
+ SSL_ERROR_WANT_X509_LOOKUP, \
+ SSL_ERROR_SYSCALL, \
+ SSL_ERROR_SSL, \
+ SSL_ERROR_WANT_CONNECT, \
+ SSL_ERROR_EOF, \
+ SSL_ERROR_INVALID_ERROR_CODE
+
+from socket import socket, _fileobject
+from socket import getnameinfo as _getnameinfo
+import base64 # for DER-to-PEM translation
+
+class SSLSocket(SSLSocket, socket):
+
+ """This class implements a subtype of socket.socket that wraps
+ the underlying OS socket in an SSL context when necessary, and
+ provides read and write methods over that channel."""
+
+ def __init__(self, sock, keyfile=None, certfile=None,
+ server_side=False, cert_reqs=CERT_NONE,
+ ssl_version=PROTOCOL_SSLv23, ca_certs=None,
+ do_handshake_on_connect=True,
+ suppress_ragged_eofs=True):
+ socket.__init__(self, _sock=sock._sock)
+ # the initializer for socket trashes the methods (tsk, tsk), so...
+ self.send = lambda data, flags=0: SSLSocket.send(self, data, flags)
+ self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
+ self.recv = lambda buflen=1024, flags=0: SSLSocket.recv(self, buflen, flags)
+ self.recvfrom = lambda addr, buflen=1024, flags=0: SSLSocket.recvfrom(self, addr, buflen, flags)
+ self.recv_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recv_into(self, buffer, nbytes, flags)
+ self.recvfrom_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recvfrom_into(self, buffer, nbytes, flags)
+
+ if certfile and not keyfile:
+ keyfile = certfile
+ # see if it's connected
+ try:
+ socket.getpeername(self)
+ except:
+ # no, no connection yet
+ self._sslobj = None
+ else:
+ # yes, create the SSL object
+ self._sslobj = _ssl.sslwrap(self._sock, server_side,
+ keyfile, certfile,
+ cert_reqs, ssl_version, ca_certs)
+ if do_handshake_on_connect:
+ timeout = self.gettimeout()
+ try:
+ if timeout == 0:
+ self.settimeout(None)
+ self.do_handshake()
+ finally:
+ self.settimeout(timeout)
+ self.keyfile = keyfile
+ self.certfile = certfile
+ self.cert_reqs = cert_reqs
+ self.ssl_version = ssl_version
+ self.ca_certs = ca_certs
+ self.do_handshake_on_connect = do_handshake_on_connect
+ self.suppress_ragged_eofs = suppress_ragged_eofs
+ self._makefile_refs = 0
+
+
import os
-
-from sfa.util.rspec import RecordSpec
-
+from sfa.util.xml import XML
class SimpleStorage(dict):
"""
Handles storing and loading python dictionaries. The storage file created
"""
Parse an xml file and store it as a dict
"""
- data = RecordSpec()
if os.path.exists(self.db_filename) and os.path.isfile(self.db_filename):
- data.parseFile(self.db_filename)
- dict.__init__(self, data.toDict())
+ xml = XML(self.db_filename)
+ dict.__init__(self, xml.todict())
elif os.path.exists(self.db_filename) and not os.path.isfile(self.db_filename):
raise IOError, '%s exists but is not a file. please remove it and try again' \
% self.db_filename
self.load()
def write(self):
- data = RecordSpec()
- data.parseDict(self)
+ xml = XML()
+ xml.parse_dict(self)
db_file = open(self.db_filename, 'w')
- db_file.write(data.toprettyxml())
+ db_file.write(xml.toxml())
db_file.close()
import traceback
import time
from Queue import Queue
+from sfa.util.sfalogging import logger
def ThreadedMethod(callable, results, errors):
"""
try:
results.put(callable(*args, **kwds))
except Exception, e:
+ logger.log_exc('ThreadManager: Error in thread: ')
errors.put(traceback.format_exc())
thread = ThreadInstance()
ThreadManager executes a callable in a thread and stores the result
in a thread safe queue.
"""
- results = Queue()
- errors = Queue()
- threads = []
+
+ def __init__(self):
+ self.results = Queue()
+ self.errors = Queue()
+ self.threads = []
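The move from class attributes to instance attributes in `__init__` above matters because class-level Queues are shared by every ThreadManager. A minimal Python 3 illustration of the difference (the `SharedQueues`/`PerInstanceQueues` names are hypothetical):

```python
from queue import Queue

class SharedQueues:
    results = Queue()           # class attribute: ONE queue for all instances

class PerInstanceQueues:
    def __init__(self):
        self.results = Queue()  # instance attribute: each manager gets its own

a, b = SharedQueues(), SharedQueues()
a.results.put("from a")
# b.results is the very same queue object, so b sees a's result

c, d = PerInstanceQueues(), PerInstanceQueues()
c.results.put("from c")
# d.results stays empty: the fix above isolates each instance
```
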
def run (self, method, *args, **kwds):
"""
--- /dev/null
+#!/usr/bin/python
+from lxml import etree
+from StringIO import StringIO
+from datetime import datetime, timedelta
+from sfa.util.xrn import *
+from sfa.util.plxrn import hostname_to_urn
+from sfa.util.faults import SfaNotImplemented, InvalidXML
+
+class XpathFilter:
+ @staticmethod
+ def xpath(filter={}):
+ xpath = ""
+ if filter:
+ filter_list = []
+ for (key, value) in filter.items():
+ if key == 'text':
+ key = 'text()'
+ else:
+ key = '@'+key
+ if isinstance(value, str):
+ filter_list.append('%s="%s"' % (key, value))
+ elif isinstance(value, list):
+ filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
+ if filter_list:
+ xpath = ' and '.join(filter_list)
+ xpath = '[' + xpath + ']'
+ return xpath
+
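`XpathFilter` turns a filter dict into an XPath predicate string. A standalone copy (same logic as the class above, rewritten for Python 3 illustration) shows the strings it produces:

```python
class XpathFilter:
    """Standalone copy of the XpathFilter above, for illustration only."""
    @staticmethod
    def xpath(filter=None):
        filter = filter or {}
        filter_list = []
        for (key, value) in filter.items():
            # 'text' matches element text; any other key matches an attribute
            key = 'text()' if key == 'text' else '@' + key
            if isinstance(value, str):
                filter_list.append('%s="%s"' % (key, value))
            elif isinstance(value, list):
                filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
        if filter_list:
            return '[' + ' and '.join(filter_list) + ']'
        return ""
```
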
+class XML:
+
+ def __init__(self, xml=None):
+ self.root = None
+ self.namespaces = None
+ self.default_namespace = None
+ self.schema = None
+ if isinstance(xml, basestring):
+ self.parse_xml(xml)
+ elif isinstance(xml, etree._ElementTree):
+ self.root = xml.getroot()
+ elif isinstance(xml, etree._Element):
+ self.root = xml
+
+ def parse_xml(self, xml):
+ """
+ parse rspec into etree
+ """
+ parser = etree.XMLParser(remove_blank_text=True)
+ try:
+ tree = etree.parse(xml, parser)
+ except IOError:
+ # 'rspec' file doesn't exist; 'rspec' is probably an xml string
+ try:
+ tree = etree.parse(StringIO(xml), parser)
+ except Exception, e:
+ raise InvalidXML(str(e))
+ self.root = tree.getroot()
+ # set namespaces map
+ self.namespaces = dict(self.root.nsmap)
+ # If a None key exists, it points to the default namespace. That makes
+ # it hard to write xpath queries for the default namespace, because lxml
+ # won't accept a None prefix. Associate the default namespace with a
+ # key named 'default' instead.
+ if None in self.namespaces:
+ default_namespace = self.namespaces.pop(None)
+ self.namespaces['default'] = default_namespace
+
+ # set schema
+ for key in self.root.attrib.keys():
+ if key.endswith('schemaLocation'):
+ # schema location should be at the end of the list
+ schema_parts = self.root.attrib[key].split(' ')
+ self.schema = schema_parts[1]
+ namespace, schema = schema_parts[0], schema_parts[1]
+ break
+
+ def parse_dict(self, d, root_tag_name='xml', element = None):
+ if element is None:
+ self.parse_xml('<%s/>' % root_tag_name)
+ element = self.root
+
+ if 'text' in d:
+ text = d.pop('text')
+ element.text = text
+
+ # handle repeating fields
+ for (key, value) in d.items():
+ if isinstance(value, list):
+ value = d.pop(key)
+ for val in value:
+ if isinstance(val, dict):
+ child_element = etree.SubElement(element, key)
+ self.parse_dict(val, key, child_element)
+
+ element.attrib.update(d)
+
+ def validate(self, schema):
+ """
+ Validate against rng schema
+ """
+ relaxng_doc = etree.parse(schema)
+ relaxng = etree.RelaxNG(relaxng_doc)
+ if not relaxng(self.root):
+ error = relaxng.error_log.last_error
+ message = "%s (line %s)" % (error.message, error.line)
+ raise InvalidXML(message)
+ return True
+
+ def xpath(self, xpath, namespaces=None):
+ if not namespaces:
+ namespaces = self.namespaces
+ return self.root.xpath(xpath, namespaces=namespaces)
+
+ def set(self, key, value):
+ return self.root.set(key, value)
+
+ def add_attribute(self, elem, name, value):
+ """
+ Add attribute to specified etree element
+ """
+ opt = etree.SubElement(elem, name)
+ opt.text = value
+
+ def add_element(self, name, attrs={}, parent=None, text=""):
+ """
+ Generic wrapper around etree.SubElement(). Adds an element to
+ the specified parent node, or to the root node if no parent is
+ specified.
+ """
+ if parent is None:
+ parent = self.root
+ element = etree.SubElement(parent, name)
+ if text:
+ element.text = text
+ if isinstance(attrs, dict):
+ for attr in attrs:
+ element.set(attr, attrs[attr])
+ return element
+
+ def remove_attribute(self, elem, name, value):
+ """
+ Removes an attribute from an element
+ """
+ if elem is not None:
+ opts = elem.iterfind(name)
+ if opts is not None:
+ for opt in opts:
+ if opt.text == value:
+ elem.remove(opt)
+
+ def remove_element(self, element_name, root_node = None):
+ """
+ Removes all occurrences of an element from the tree, starting at
+ root_node if specified, otherwise at the tree's root.
+ """
+ if not root_node:
+ root_node = self.root
+
+ if not element_name.startswith('//'):
+ element_name = '//' + element_name
+
+ elements = root_node.xpath(element_name, namespaces=self.namespaces)
+ for element in elements:
+ parent = element.getparent()
+ parent.remove(element)
+
+ def attributes_list(self, elem):
+ # convert a list of attribute tags into list of tuples
+ # (tagname, text_value)
+ opts = []
+ if elem is not None:
+ for e in elem:
+ opts.append((e.tag, str(e.text).strip()))
+ return opts
+
+ def get_element_attributes(self, elem=None, depth=0):
+ if elem is None:
+ elem = self.root
+ if not hasattr(elem, 'attrib'):
+ # this is probably not an element node with attributes; it could be
+ # just an attribute value, so return it as-is
+ return elem
+ attrs = dict(elem.attrib)
+ attrs['text'] = str(elem.text).strip()
+ attrs['parent'] = elem.getparent()
+ if isinstance(depth, int) and depth > 0:
+ for child_elem in list(elem):
+ key = str(child_elem.tag)
+ if key not in attrs:
+ attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
+ else:
+ attrs[key].append(self.get_element_attributes(child_elem, depth-1))
+ else:
+ attrs['child_nodes'] = list(elem)
+ return attrs
+
+ def merge(self, in_xml):
+ pass
+
+ def __str__(self):
+ return self.toxml()
+
+ def toxml(self):
+ return etree.tostring(self.root, pretty_print=True)
+
+ def todict(self, elem=None):
+ if elem is None:
+ elem = self.root
+ d = {}
+ d.update(elem.attrib)
+ d['text'] = elem.text
+ for child in elem.iterchildren():
+ if child.tag not in d:
+ d[child.tag] = []
+ d[child.tag].append(self.todict(child))
+ return d
+
+ def save(self, filename):
+ f = open(filename, 'w')
+ f.write(self.toxml())
+ f.close()
+
+if __name__ == '__main__':
+ xml = XML('/tmp/resources.rspec')
+ print xml
+
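The `todict` recursion above maps attributes, text, and repeated children onto a plain dict. The same shape can be reproduced with the stdlib ElementTree (the class itself uses lxml; the sample document here is invented):

```python
import xml.etree.ElementTree as ET

def todict(elem):
    # same recursion as XML.todict above: attributes first, then text,
    # then each child tag collected into a list
    d = dict(elem.attrib)
    d['text'] = elem.text
    for child in elem:
        d.setdefault(child.tag, []).append(todict(child))
    return d

root = ET.fromstring('<slice name="demo"><node id="n1"/><node id="n2"/></slice>')
```
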
# XMLRPC-specific code for SFA Client
-import httplib
import xmlrpclib
-
-from sfa.util.sfalogging import sfa_logger
-
+#from sfa.util.httpsProtocol import HTTPS, HTTPSConnection
+from httplib import HTTPS, HTTPSConnection
+from sfa.util.sfalogging import logger
##
# ServerException, ExceptionUnmarshaller
#
need_HTTPSConnection=hasattr(xmlrpclib.Transport().make_connection('localhost'),'getresponse')
class XMLRPCTransport(xmlrpclib.Transport):
- key_file = None
- cert_file = None
+
+ def __init__(self, key_file=None, cert_file=None, timeout=None):
+ xmlrpclib.Transport.__init__(self)
+ self.timeout=timeout
+ self.key_file = key_file
+ self.cert_file = cert_file
+
def make_connection(self, host):
# create a HTTPS connection object from a host descriptor
# host may be a string, or a (host, x509-dict) tuple
host, extra_headers, x509 = self.get_host_info(host)
if need_HTTPSConnection:
- return httplib.HTTPSConnection(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
+ #conn = HTTPSConnection(host, None, key_file=self.key_file, cert_file=self.cert_file, timeout=self.timeout) #**(x509 or {}))
+ conn = HTTPSConnection(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
else:
- return httplib.HTTPS(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
+ #conn = HTTPS(host, None, key_file=self.key_file, cert_file=self.cert_file, timeout=self.timeout) #**(x509 or {}))
+ conn = HTTPS(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
+
+ if hasattr(conn, 'set_timeout'):
+ conn.set_timeout(self.timeout)
+
+ # Some logic to deal with timeouts. It appears that some (or all) versions
+ # of python don't set the timeout after the socket is created. We'll do it
+ # ourselves by forcing the connection to connect, finding the socket, and
+ # calling settimeout() on it. (tested with python 2.6)
+ if self.timeout:
+ if hasattr(conn, "_conn"):
+ # HTTPS is a wrapper around HTTPSConnection
+ real_conn = conn._conn
+ else:
+ real_conn = conn
+ conn.connect()
+ if hasattr(real_conn, "sock") and hasattr(real_conn.sock, "settimeout"):
+ real_conn.sock.settimeout(float(self.timeout))
+
+ return conn
def getparser(self):
unmarshaller = ExceptionUnmarshaller()
return parser, unmarshaller
class XMLRPCServerProxy(xmlrpclib.ServerProxy):
- def __init__(self, url, transport, allow_none=True, options=None):
+ def __init__(self, url, transport, allow_none=True, verbose=False):
# remember url for GetVersion
self.url=url
- verbose = False
- if options and options.debug:
- verbose = True
-# sfa_logger().debug ("xmlrpcprotocol.XMLRPCServerProxy.__init__ %s (with verbose=%s)"%(url,verbose))
xmlrpclib.ServerProxy.__init__(self, url, transport, allow_none=allow_none, verbose=verbose)
def __getattr__(self, attr):
- sfa_logger().debug ("xml-rpc %s method:%s"%(self.url,attr))
+ logger.debug ("xml-rpc %s method:%s"%(self.url,attr))
return xmlrpclib.ServerProxy.__getattr__(self, attr)
-
-def get_server(url, key_file, cert_file, options=None):
- transport = XMLRPCTransport()
- transport.key_file = key_file
- transport.cert_file = cert_file
-
- return XMLRPCServerProxy(url, transport, allow_none=True, options=options)
+def get_server(url, key_file, cert_file, timeout=None, verbose=False):
+ transport = XMLRPCTransport(key_file, cert_file, timeout)
+ return XMLRPCServerProxy(url, transport, allow_none=True, verbose=verbose)
+#----------------------------------------------------------------------
+# Copyright (c) 2008 Board of Trustees, Princeton University
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and/or hardware specification (the "Work") to
+# deal in the Work without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Work, and to permit persons to whom the Work
+# is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Work.
+#
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS
+# IN THE WORK.
+#----------------------------------------------------------------------
+
import re
from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
# for convenience and smoother translation - we should get rid of these functions eventually
def get_leaf(hrn): return Xrn(hrn).get_leaf()
def get_authority(hrn): return Xrn(hrn).get_authority_hrn()
def urn_to_hrn(urn): xrn=Xrn(urn); return (xrn.hrn, xrn.type)
def hrn_to_urn(hrn,type): return Xrn(hrn, type=type).urn
+def hrn_authfor_hrn(parenthrn, hrn): return Xrn.hrn_is_auth_for_hrn(parenthrn, hrn)
+
+def urn_to_sliver_id(urn, slice_id, node_id, index=0):
+ return ":".join(map(str, [urn, slice_id, node_id, index]))
class Xrn:
# e.g. escape ('a.b') -> 'a\.b'
@staticmethod
def escape(token): return re.sub(r'([^\\])\.', r'\1\.', token)
+
# e.g. unescape ('a\.b') -> 'a.b'
@staticmethod
def unescape(token): return token.replace('\\.','.')
-
+
+ # Return the HRN authority chain from top to bottom.
+ # e.g. hrn_auth_chain('a\.b.c.d') -> ['a\.b', 'a\.b.c']
+ @staticmethod
+ def hrn_auth_chain(hrn):
+ parts = Xrn.hrn_auth_list(hrn)
+ chain = []
+ for i in range(len(parts)):
+ chain.append('.'.join(parts[:i+1]))
+ # Include the HRN itself?
+ #chain.append(hrn)
+ return chain
+
+ # Is the given HRN a true authority over the namespace of the other
+ # child HRN?
+ # A better alternative than childHRN.startswith(parentHRN)
+ # e.g. hrn_is_auth_for_hrn('a\.b', 'a\.b.c.d') -> True,
+ # but hrn_is_auth_for_hrn('a', 'a\.b.c.d') -> False
+ # Also hrn_is_auth_for_hrn('a\.b.c.d', 'a\.b.c.d') -> True
+ @staticmethod
+ def hrn_is_auth_for_hrn(parenthrn, hrn):
+ if parenthrn == hrn:
+ return True
+ for auth in Xrn.hrn_auth_chain(hrn):
+ if parenthrn == auth:
+ return True
+ return False
+
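The authority-chain check added above can be sketched with plain string handling. This sketch ignores the escaped-dot case (`'a\.b'` as one component) that the real `Xrn.hrn_auth_list` also handles:

```python
def hrn_auth_chain(hrn):
    # cumulative prefixes of the authority part: 'a.b.c.d' -> ['a', 'a.b', 'a.b.c']
    parts = hrn.split('.')[:-1]           # drop the leaf
    return ['.'.join(parts[:i + 1]) for i in range(len(parts))]

def hrn_is_auth_for_hrn(parenthrn, hrn):
    # an hrn is trivially an authority for itself; otherwise the parent
    # must appear somewhere in the chain above the leaf
    return parenthrn == hrn or parenthrn in hrn_auth_chain(hrn)
```
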
URN_PREFIX = "urn:publicid:IDN"
########## basic tools on URNs
self.hrn_to_urn()
# happens all the time ..
# if not type:
-# sfa_logger().debug("type-less Xrn's are not safe")
+# debug_logger.debug("type-less Xrn's are not safe")
def get_urn(self): return self.urn
def get_hrn(self): return self.hrn
+++ /dev/null
-these files used to be in geniwrapper/cmdline
from sfa.trust.certificate import *
from sfa.trust.credential import *
from sfa.util.sfaticket import *
-from sfa.util.rspec import *
from sfa.client import sfi
def random_string(size):
-#!/usr/bin/env python\r
-#-------------------------------------------------------------------------------\r
-import os\r
-import sys\r
-import glob\r
-import os.path\r
-from setuptools import setup\r
-#from distutils.core import setup\r
-#-------------------------------------------------------------------------------\r
-if 'upload' in sys.argv:\r
- # for .pypirc file\r
- try:\r
- os.environ['HOME']\r
- except KeyError:\r
- os.environ['HOME'] = '..\\'\r
-#-------------------------------------------------------------------------------\r
-fpath = lambda x : os.path.join(*x.split('/'))\r
-#-------------------------------------------------------------------------------\r
-PYPI_URL = 'http://pypi.python.org/pypi/xmlbuilder'\r
-ld = open(fpath('xmlbuilder/docs/long_descr.rst')).read()\r
-ld = ld.replace('&','&').replace('<','<').replace('>','>')\r
-setup(\r
- name = "xmlbuilder",\r
- fullname = "xmlbuilder",\r
- version = "0.9",\r
- packages = ["xmlbuilder"],\r
- package_dir = {'xmlbuilder':'xmlbuilder'},\r
- author = "koder",\r
- author_email = "koder_dot_mail@gmail_dot_com",\r
- maintainer = 'koder',\r
- maintainer_email = "koder_dot_mail@gmail_dot_com",\r
- description = "Pythonic way to create xml files",\r
- license = "MIT",\r
- keywords = "xml",\r
- test_suite = "xml_buider.tests",\r
- url = PYPI_URL,\r
- download_url = PYPI_URL,\r
- long_description = ld,\r
- #include_package_data = True,\r
- #package_data = {'xmlbuilder':["docs/*.rst"]},\r
- #data_files = [('', ['xmlbuilder/docs/long_descr.rst'])]\r
-)\r
-#-------------------------------------------------------------------------------\r
+#!/usr/bin/env python
+#-------------------------------------------------------------------------------
+import os
+import sys
+import glob
+import os.path
+from setuptools import setup
+#from distutils.core import setup
+#-------------------------------------------------------------------------------
+if 'upload' in sys.argv:
+ # for .pypirc file
+ try:
+ os.environ['HOME']
+ except KeyError:
+ os.environ['HOME'] = '..\\'
+#-------------------------------------------------------------------------------
+fpath = lambda x : os.path.join(*x.split('/'))
+#-------------------------------------------------------------------------------
+PYPI_URL = 'http://pypi.python.org/pypi/xmlbuilder'
+ld = open(fpath('xmlbuilder/docs/long_descr.rst')).read()
+ld = ld.replace('&amp;','&').replace('&lt;','<').replace('&gt;','>')
+setup(
+ name = "xmlbuilder",
+ fullname = "xmlbuilder",
+ version = "0.9",
+ packages = ["xmlbuilder"],
+ package_dir = {'xmlbuilder':'xmlbuilder'},
+ author = "koder",
+ author_email = "koder_dot_mail@gmail_dot_com",
+ maintainer = 'koder',
+ maintainer_email = "koder_dot_mail@gmail_dot_com",
+ description = "Pythonic way to create xml files",
+ license = "MIT",
+ keywords = "xml",
+ test_suite = "xmlbuilder.tests",
+ url = PYPI_URL,
+ download_url = PYPI_URL,
+ long_description = ld,
+ #include_package_data = True,
+ #package_data = {'xmlbuilder':["docs/*.rst"]},
+ #data_files = [('', ['xmlbuilder/docs/long_descr.rst'])]
+)
+#-------------------------------------------------------------------------------
-#!/usr/bin/env python\r
-#-------------------------------------------------------------------------------\r
-from __future__ import with_statement\r
-#-------------------------------------------------------------------------------\r
-from xml.etree.ElementTree import TreeBuilder,tostring\r
-#-------------------------------------------------------------------------------\r
-__all__ = ["XMLBuilder"]\r
-__doc__ = """\r
-XMLBuilder is simple library build on top of ElementTree.TreeBuilder to\r
-simplify xml files creation as much as possible. Althow it can produce\r
-structured result with identated child tags. `XMLBuilder` use python `with`\r
-statement to define xml tag levels and `<<` operator for simple cases -\r
-text and tag without childs.\r
-\r
-from __future__ import with_statement\r
-from xmlbuilder import XMLBuilder\r
-x = XMLBuilder(format=True)\r
-with x.root(a = 1):\r
- with x.data:\r
- [x << ('node',{'val':i}) for i in range(10)]\r
-\r
-etree_node = ~x\r
-print str(x)\r
-"""\r
-#-------------------------------------------------------------------------------\r
-class _XMLNode(object):\r
- """Class for internal usage"""\r
- def __init__(self,parent,name,builder):\r
- self.builder = builder\r
- self.name = name\r
- self.text = []\r
- self.attrs = {}\r
- self.entered = False\r
- self.parent = parent\r
- def __call__(self,*dt,**mp):\r
- text = "".join(dt)\r
- if self.entered:\r
- self.builder.data(text)\r
- else:\r
- self.text.append(text)\r
- if self.entered:\r
- raise ValueError("Can't add attributes to already opened element")\r
- smp = dict((k,str(v)) for k,v in mp.items())\r
- self.attrs.update(smp)\r
- return self\r
- def __enter__(self):\r
- self.parent += 1\r
- self.builder.start(self.name,self.attrs)\r
- self.builder.data("".join(self.text))\r
- self.entered = True\r
- return self\r
- def __exit__(self,x,y,z):\r
- self.parent -= 1\r
- self.builder.end(self.name)\r
- return False\r
-#-------------------------------------------------------------------------------\r
-class XMLBuilder(object):\r
- """XmlBuilder(encoding = 'utf-8', # result xml file encoding\r
- builder = None, #etree.TreeBuilder or compatible class\r
- tab_level = None, #current tabulation level - string\r
- format = False, # make formatted output\r
- tab_step = " " * 4) # tabulation step\r
- use str(builder) or unicode(builder) to get xml text or\r
- ~builder to obtaine etree.ElementTree\r
- """\r
- def __init__(self,encoding = 'utf-8',\r
- builder = None,\r
- tab_level = None,\r
- format = False,\r
- tab_step = " " * 4):\r
- self.__builder = builder or TreeBuilder()\r
- self.__encoding = encoding \r
- if format :\r
- if tab_level is None:\r
- tab_level = ""\r
- if tab_level is not None:\r
- if not format:\r
- raise ValueError("format is False, but tab_level not None")\r
- self.__tab_level = tab_level # current format level\r
- self.__tab_step = tab_step # format step\r
- self.__has_sub_tag = False # True, if current tag had childrens\r
- self.__node = None\r
- # called from _XMLNode when tag opened\r
- def __iadd__(self,val):\r
- self.__has_sub_tag = False\r
- if self.__tab_level is not None:\r
- self.__builder.data("\n" + self.__tab_level)\r
- self.__tab_level += self.__tab_step\r
- return self\r
- # called from XMLNode when tag closed\r
- def __isub__(self,val):\r
- if self.__tab_level is not None:\r
- self.__tab_level = self.__tab_level[:-len(self.__tab_step)]\r
- if self.__has_sub_tag:\r
- self.__builder.data("\n" + self.__tab_level)\r
- self.__has_sub_tag = True\r
- return self\r
- def __getattr__(self,name):\r
- return _XMLNode(self,name,self.__builder)\r
- def __call__(self,name,*dt,**mp):\r
- x = _XMLNode(self,name,self.__builder)\r
- x(*dt,**mp)\r
- return x\r
- #create new tag or add text\r
- #possible shift values\r
- #string - text\r
- #tuple(string1,string2,dict) - new tag with name string1,attrs = dict,and text string2\r
- #dict and string2 are optional\r
- def __lshift__(self,val):\r
- if isinstance(val,basestring):\r
- self.__builder.data(val)\r
- else:\r
- self.__has_sub_tag = True\r
- assert hasattr(val,'__len__'),\\r
- 'Shifted value should be tuple or list like object not %r' % val\r
- assert hasattr(val,'__getitem__'),\\r
- 'Shifted value should be tuple or list like object not %r' % val\r
- name = val[0]\r
- if len(val) == 3:\r
- text = val[1]\r
- attrs = val[2]\r
- elif len(val) == 1:\r
- text = ""\r
- attrs = {}\r
- elif len(val) == 2:\r
- if isinstance(val[1],basestring):\r
- text = val[1]\r
- attrs = {}\r
- else:\r
- text = ""\r
- attrs = val[1]\r
- if self.__tab_level is not None:\r
- self.__builder.data("\n" + self.__tab_level)\r
- self.__builder.start(name,\r
- dict((k,str(v)) for k,v in attrs.items()))\r
- if text:\r
- self.__builder.data(text)\r
- self.__builder.end(name)\r
- return self # to allow xml << some1 << some2 << some3\r
- #close builder\r
- def __invert__(self):\r
- if self.__node is not None:\r
- return self.__node\r
- self.__node = self.__builder.close()\r
- return self.__node\r
- def __str__(self):\r
- """return generated xml"""\r
- return tostring(~self,self.__encoding)\r
- def __unicode__(self):\r
- """return generated xml"""\r
- res = tostring(~self,self.__encoding)\r
- return res.decode(self.__encoding)\r
-#-------------------------------------------------------------------------------\r
+#!/usr/bin/env python
+#-------------------------------------------------------------------------------
+from __future__ import with_statement
+#-------------------------------------------------------------------------------
+from xml.etree.ElementTree import TreeBuilder,tostring
+#-------------------------------------------------------------------------------
+__all__ = ["XMLBuilder"]
+__doc__ = """
+XMLBuilder is a simple library built on top of ElementTree.TreeBuilder to
+simplify xml file creation as much as possible. It can also produce a
+structured result with indented child tags. `XMLBuilder` uses the python `with`
+statement to define xml tag levels and the `<<` operator for simple cases -
+text and tags without children.
+
+from __future__ import with_statement
+from xmlbuilder import XMLBuilder
+x = XMLBuilder(format=True)
+with x.root(a = 1):
+ with x.data:
+ [x << ('node',{'val':i}) for i in range(10)]
+
+etree_node = ~x
+print str(x)
+"""
+#-------------------------------------------------------------------------------
+class _XMLNode(object):
+ """Class for internal usage"""
+ def __init__(self,parent,name,builder):
+ self.builder = builder
+ self.name = name
+ self.text = []
+ self.attrs = {}
+ self.entered = False
+ self.parent = parent
+ def __call__(self,*dt,**mp):
+ text = "".join(dt)
+ if self.entered:
+ self.builder.data(text)
+ else:
+ self.text.append(text)
+ if self.entered:
+ raise ValueError("Can't add attributes to already opened element")
+ smp = dict((k,str(v)) for k,v in mp.items())
+ self.attrs.update(smp)
+ return self
+ def __enter__(self):
+ self.parent += 1
+ self.builder.start(self.name,self.attrs)
+ self.builder.data("".join(self.text))
+ self.entered = True
+ return self
+ def __exit__(self,x,y,z):
+ self.parent -= 1
+ self.builder.end(self.name)
+ return False
+#-------------------------------------------------------------------------------
+class XMLBuilder(object):
+ """XmlBuilder(encoding = 'utf-8', # result xml file encoding
+ builder = None, #etree.TreeBuilder or compatible class
+ tab_level = None, #current tabulation level - string
+ format = False, # make formatted output
+ tab_step = " " * 4) # tabulation step
+ use str(builder) or unicode(builder) to get the xml text, or
+ ~builder to obtain the etree.ElementTree
+ """
+ def __init__(self,encoding = 'utf-8',
+ builder = None,
+ tab_level = None,
+ format = False,
+ tab_step = " " * 4):
+ self.__builder = builder or TreeBuilder()
+ self.__encoding = encoding
+ if format :
+ if tab_level is None:
+ tab_level = ""
+ if tab_level is not None:
+ if not format:
+ raise ValueError("format is False, but tab_level not None")
+ self.__tab_level = tab_level # current format level
+ self.__tab_step = tab_step # format step
+ self.__has_sub_tag = False # True, if the current tag has children
+ self.__node = None
+ # called from _XMLNode when tag opened
+ def __iadd__(self,val):
+ self.__has_sub_tag = False
+ if self.__tab_level is not None:
+ self.__builder.data("\n" + self.__tab_level)
+ self.__tab_level += self.__tab_step
+ return self
+ # called from XMLNode when tag closed
+ def __isub__(self,val):
+ if self.__tab_level is not None:
+ self.__tab_level = self.__tab_level[:-len(self.__tab_step)]
+ if self.__has_sub_tag:
+ self.__builder.data("\n" + self.__tab_level)
+ self.__has_sub_tag = True
+ return self
+ def __getattr__(self,name):
+ return _XMLNode(self,name,self.__builder)
+ def __call__(self,name,*dt,**mp):
+ x = _XMLNode(self,name,self.__builder)
+ x(*dt,**mp)
+ return x
+ #create new tag or add text
+ #possible shift values
+ #string - text
+ #tuple(string1,string2,dict) - new tag with name string1,attrs = dict,and text string2
+ #dict and string2 are optional
+ def __lshift__(self,val):
+ if isinstance(val,basestring):
+ self.__builder.data(val)
+ else:
+ self.__has_sub_tag = True
+ assert hasattr(val,'__len__'),\
+ 'Shifted value should be tuple or list like object not %r' % val
+ assert hasattr(val,'__getitem__'),\
+ 'Shifted value should be tuple or list like object not %r' % val
+ name = val[0]
+ if len(val) == 3:
+ text = val[1]
+ attrs = val[2]
+ elif len(val) == 1:
+ text = ""
+ attrs = {}
+ elif len(val) == 2:
+ if isinstance(val[1],basestring):
+ text = val[1]
+ attrs = {}
+ else:
+ text = ""
+ attrs = val[1]
+ if self.__tab_level is not None:
+ self.__builder.data("\n" + self.__tab_level)
+ self.__builder.start(name,
+ dict((k,str(v)) for k,v in attrs.items()))
+ if text:
+ self.__builder.data(text)
+ self.__builder.end(name)
+ return self # to allow xml << some1 << some2 << some3
+ #close builder
+ def __invert__(self):
+ if self.__node is not None:
+ return self.__node
+ self.__node = self.__builder.close()
+ return self.__node
+ def __str__(self):
+ """return generated xml"""
+ return tostring(~self,self.__encoding)
+ def __unicode__(self):
+ """return generated xml"""
+ res = tostring(~self,self.__encoding)
+ return res.decode(self.__encoding)
+#-------------------------------------------------------------------------------
-#!/usr/bin/env python\r
-from __future__ import with_statement\r
-#-------------------------------------------------------------------------------\r
-import unittest\r
-from xml.etree.ElementTree import fromstring\r
-#-------------------------------------------------------------------------------\r
-from xmlbuilder import XMLBuilder\r
-#-------------------------------------------------------------------------------\r
-def xmlStructureEqual(xml1,xml2):\r
- tree1 = fromstring(xml1)\r
- tree2 = fromstring(xml2)\r
- return _xmlStructureEqual(tree1,tree2)\r
-#-------------------------------------------------------------------------------\r
-def _xmlStructureEqual(tree1,tree2):\r
- if tree1.tag != tree2.tag:\r
- return False\r
- attr1 = list(tree1.attrib.items())\r
- attr1.sort()\r
- attr2 = list(tree2.attrib.items())\r
- attr2.sort()\r
- if attr1 != attr2:\r
- return False\r
- return tree1.getchildren() == tree2.getchildren()\r
-#-------------------------------------------------------------------------------\r
-result1 = \\r
-"""\r
-<root>\r
- <array />\r
- <array len="10">\r
- <el val="0" />\r
- <el val="1">xyz</el>\r
- <el val="2">abc</el>\r
- <el val="3" />\r
- <el val="4" />\r
- <el val="5" />\r
- <sup-el val="23">test </sup-el>\r
- </array>\r
-</root>\r
-""".strip()\r
-#-------------------------------------------------------------------------------\r
-class TestXMLBuilder(unittest.TestCase):\r
- def testShift(self):\r
- xml = (XMLBuilder() << ('root',))\r
- self.assertEqual(str(xml),"<root />")\r
- \r
- xml = XMLBuilder()\r
- xml << ('root',"some text")\r
- self.assertEqual(str(xml),"<root>some text</root>")\r
- \r
- xml = XMLBuilder()\r
- xml << ('root',{'x':1,'y':'2'})\r
- self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>some text</root>"))\r
- \r
- xml = XMLBuilder()\r
- xml << ('root',{'x':1,'y':'2'})\r
- self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'></root>"))\r
-\r
- xml = XMLBuilder()\r
- xml << ('root',{'x':1,'y':'2'})\r
- self.assert_(not xmlStructureEqual(str(xml),"<root x='2' y='2'></root>"))\r
-\r
- \r
- xml = XMLBuilder()\r
- xml << ('root',"gonduras.ua",{'x':1,'y':'2'})\r
- self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.ua</root>"))\r
- \r
- xml = XMLBuilder()\r
- xml << ('root',"gonduras.ua",{'x':1,'y':'2'})\r
- self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.com</root>"))\r
- #---------------------------------------------------------------------------\r
- def testWith(self):\r
- xml = XMLBuilder()\r
- with xml.root(lenght = 12):\r
- pass\r
- self.assertEqual(str(xml),'<root lenght="12" />')\r
- \r
- xml = XMLBuilder()\r
- with xml.root():\r
- xml << "text1" << "text2" << ('some_node',)\r
- self.assertEqual(str(xml),"<root>text1text2<some_node /></root>")\r
- #---------------------------------------------------------------------------\r
- def testFormat(self):\r
- x = XMLBuilder('utf-8',format = True)\r
- with x.root():\r
- x << ('array',)\r
- with x.array(len = 10):\r
- with x.el(val = 0):\r
- pass\r
- with x.el('xyz',val = 1):\r
- pass\r
- x << ("el","abc",{'val':2}) << ('el',dict(val=3))\r
- x << ('el',dict(val=4)) << ('el',dict(val='5'))\r
- with x('sup-el',val = 23):\r
- x << "test "\r
- self.assertEqual(str(x),result1)\r
-#-------------------------------------------------------------------------------\r
-if __name__ == '__main__':\r
- unittest.main()\r
-#-------------------------------------------------------------------------------\r
+#!/usr/bin/env python
+from __future__ import with_statement
+#-------------------------------------------------------------------------------
+import unittest
+from xml.etree.ElementTree import fromstring
+#-------------------------------------------------------------------------------
+from xmlbuilder import XMLBuilder
+#-------------------------------------------------------------------------------
+def xmlStructureEqual(xml1,xml2):
+ tree1 = fromstring(xml1)
+ tree2 = fromstring(xml2)
+ return _xmlStructureEqual(tree1,tree2)
+#-------------------------------------------------------------------------------
+def _xmlStructureEqual(tree1,tree2):
+ if tree1.tag != tree2.tag:
+ return False
+ attr1 = list(tree1.attrib.items())
+ attr1.sort()
+ attr2 = list(tree2.attrib.items())
+ attr2.sort()
+ if attr1 != attr2:
+ return False
+ return tree1.getchildren() == tree2.getchildren()
+#-------------------------------------------------------------------------------
+result1 = \
+"""
+<root>
+ <array />
+ <array len="10">
+ <el val="0" />
+ <el val="1">xyz</el>
+ <el val="2">abc</el>
+ <el val="3" />
+ <el val="4" />
+ <el val="5" />
+ <sup-el val="23">test </sup-el>
+ </array>
+</root>
+""".strip()
+#-------------------------------------------------------------------------------
+class TestXMLBuilder(unittest.TestCase):
+ def testShift(self):
+ xml = (XMLBuilder() << ('root',))
+ self.assertEqual(str(xml),"<root />")
+
+ xml = XMLBuilder()
+ xml << ('root',"some text")
+ self.assertEqual(str(xml),"<root>some text</root>")
+
+ xml = XMLBuilder()
+ xml << ('root',{'x':1,'y':'2'})
+ self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>some text</root>"))
+
+ xml = XMLBuilder()
+ xml << ('root',{'x':1,'y':'2'})
+ self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'></root>"))
+
+ xml = XMLBuilder()
+ xml << ('root',{'x':1,'y':'2'})
+ self.assert_(not xmlStructureEqual(str(xml),"<root x='2' y='2'></root>"))
+
+
+ xml = XMLBuilder()
+ xml << ('root',"gonduras.ua",{'x':1,'y':'2'})
+ self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.ua</root>"))
+
+ xml = XMLBuilder()
+ xml << ('root',"gonduras.ua",{'x':1,'y':'2'})
+ self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.com</root>"))
+ #---------------------------------------------------------------------------
+ def testWith(self):
+ xml = XMLBuilder()
+ with xml.root(length = 12):
+ pass
+ self.assertEqual(str(xml),'<root length="12" />')
+
+ xml = XMLBuilder()
+ with xml.root():
+ xml << "text1" << "text2" << ('some_node',)
+ self.assertEqual(str(xml),"<root>text1text2<some_node /></root>")
+ #---------------------------------------------------------------------------
+ def testFormat(self):
+ x = XMLBuilder('utf-8',format = True)
+ with x.root():
+ x << ('array',)
+ with x.array(len = 10):
+ with x.el(val = 0):
+ pass
+ with x.el('xyz',val = 1):
+ pass
+ x << ("el","abc",{'val':2}) << ('el',dict(val=3))
+ x << ('el',dict(val=4)) << ('el',dict(val='5'))
+ with x('sup-el',val = 23):
+ x << "test "
+ self.assertEqual(str(x),result1)
+#-------------------------------------------------------------------------------
+if __name__ == '__main__':
+ unittest.main()
+#-------------------------------------------------------------------------------
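For reference, the tuple-shift dispatch that the new `__lshift__` implements can be sketched standalone on top of the stdlib `TreeBuilder`. The `shift` helper below is hypothetical (it is not part of the patch); it only illustrates the `(name, text, attrs)` handling added above:

```python
from xml.etree.ElementTree import TreeBuilder, tostring

def shift(builder, val):
    # Mirrors the diff's __lshift__ dispatch: a bare string becomes text,
    # a tuple (name[, text][, attrs]) becomes an immediately closed element.
    if isinstance(val, str):
        builder.data(val)
        return
    name, text, attrs = val[0], "", {}
    if len(val) == 3:
        text, attrs = val[1], val[2]
    elif len(val) == 2:
        if isinstance(val[1], str):
            text = val[1]
        else:
            attrs = val[1]
    # attribute values are stringified, as in the patch
    builder.start(name, {k: str(v) for k, v in attrs.items()})
    if text:
        builder.data(text)
    builder.end(name)

b = TreeBuilder()
b.start("root", {})
shift(b, ("el", "abc", {"val": 2}))   # tag with text and attrs
shift(b, ("el", {"val": 3}))          # tag with attrs only
b.end("root")
doc = tostring(b.close()).decode()
print(doc)  # <root><el val="2">abc</el><el val="3" /></root>
```

This is the same dispatch order as the patch: a 3-tuple is (name, text, attrs), and a 2-tuple is disambiguated by whether its second item is a string.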