Merge branch 'master' of ssh://git.f-lab.fr/git/sfa
authorNicolas Turro <nturro@sfa2.grenoble.senslab.info>
Thu, 27 Oct 2011 15:04:43 +0000 (17:04 +0200)
committerNicolas Turro <nturro@sfa2.grenoble.senslab.info>
Thu, 27 Oct 2011 15:04:43 +0000 (17:04 +0200)
136 files changed:
.gitignore
Makefile
TODO
config/default_config.xml
flashpolicy/__init__.py [new file with mode: 0644]
setup.py
sfa.spec
sfa/Makefile [deleted file]
sfa/client/Makefile
sfa/client/client_helper.py [new file with mode: 0644]
sfa/client/getNodes.py
sfa/client/getRecord.py
sfa/client/setRecord.py
sfa/client/sfadump.py
sfa/client/sfascan.py
sfa/client/sfi.py
sfa/client/sfiAddAttribute.py
sfa/client/sfiAddSliver.py
sfa/client/sfiDeleteAttribute.py
sfa/client/sfiDeleteSliver.py
sfa/client/sfiListNodes.py
sfa/client/sfiListSlivers.py
sfa/init.d/sfa
sfa/init.d/sfa-cm
sfa/managers/aggregate_manager.py [moved from sfa/managers/aggregate_manager_pl.py with 81% similarity]
sfa/managers/aggregate_manager_eucalyptus.py
sfa/managers/aggregate_manager_max.py
sfa/managers/aggregate_manager_openflow.py
sfa/managers/aggregate_manager_vini.py
sfa/managers/eucalyptus/eucalyptus.rnc
sfa/managers/eucalyptus/eucalyptus.rng
sfa/managers/eucalyptus/eucalyptus.xml
sfa/managers/import_manager.py [new file with mode: 0644]
sfa/managers/registry_manager.py [moved from sfa/managers/registry_manager_pl.py with 97% similarity]
sfa/managers/slice_manager.py [moved from sfa/managers/slice_manager_pl.py with 54% similarity]
sfa/managers/vini/topology.py
sfa/managers/vini/vini.rnc
sfa/managers/vini/vini.rng
sfa/methods/CreateGid.py [new file with mode: 0644]
sfa/methods/CreateSliver.py
sfa/methods/GetTicket.py
sfa/methods/RedeemTicket.py
sfa/methods/Register.py
sfa/methods/RegisterPeerObject.py
sfa/methods/RenewSliver.py
sfa/methods/Resolve.py
sfa/methods/Start.py
sfa/methods/Stop.py
sfa/methods/Update.py
sfa/methods/__init__.py
sfa/methods/register_peer_object.py
sfa/methods/reset_slice.py
sfa/plc/aggregate.py
sfa/plc/api.py
sfa/plc/network.py
sfa/plc/sfa-import-plc.py
sfa/plc/sfa-nuke-plc.py
sfa/plc/sfaImport.py
sfa/plc/slices.py
sfa/plc/vini_aggregate.py [new file with mode: 0644]
sfa/plc/vlink.py [new file with mode: 0644]
sfa/rspecs/elements/__init__.py [new file with mode: 0644]
sfa/rspecs/elements/element.py [new file with mode: 0644]
sfa/rspecs/elements/interface.py [new file with mode: 0644]
sfa/rspecs/elements/link.py [new file with mode: 0644]
sfa/rspecs/elements/network.py [new file with mode: 0644]
sfa/rspecs/elements/node.py [new file with mode: 0644]
sfa/rspecs/elements/sliver.py [new file with mode: 0644]
sfa/rspecs/elements/versions/__init__.py [new file with mode: 0644]
sfa/rspecs/elements/versions/pgv2Link.py [new file with mode: 0644]
sfa/rspecs/pg_rspec.py [deleted file]
sfa/rspecs/pg_rspec_converter.py
sfa/rspecs/resources/__init__.py [new file with mode: 0644]
sfa/rspecs/resources/ext/__init__.py [new file with mode: 0644]
sfa/rspecs/resources/ext/planetlab.rnc [new file with mode: 0644]
sfa/rspecs/resources/ext/planetlab.xsd [new file with mode: 0644]
sfa/rspecs/rspec.py
sfa/rspecs/rspec_converter.py
sfa/rspecs/rspec_elements.py [new file with mode: 0644]
sfa/rspecs/rspec_parser.py [deleted file]
sfa/rspecs/rspec_version.py [changed mode: 0755->0644]
sfa/rspecs/sfa_rspec_converter.py
sfa/rspecs/version_manager.py [new file with mode: 0644]
sfa/rspecs/versions/__init__.py [new file with mode: 0644]
sfa/rspecs/versions/pgv2.py [new file with mode: 0644]
sfa/rspecs/versions/pgv3.py [new file with mode: 0644]
sfa/rspecs/versions/sfav1.py [moved from sfa/rspecs/sfa_rspec.py with 58% similarity, mode: 0644]
sfa/server/aggregate.py
sfa/server/component.py
sfa/server/interface.py
sfa/server/modpython/SfaAggregateModPython.py
sfa/server/modpython/SfaRegistryModPython.py
sfa/server/modpython/SfaSliceMgrModPython.py
sfa/server/registry.py
sfa/server/sfa-server.py
sfa/server/slicemgr.py
sfa/trust/auth.py
sfa/trust/certificate.py
sfa/trust/credential.py
sfa/trust/credential.xsd
sfa/trust/credential_legacy.py
sfa/trust/gid.py
sfa/trust/hierarchy.py
sfa/trust/rights.py
sfa/trust/trustedroot.py [deleted file]
sfa/trust/trustedroots.py [new file with mode: 0644]
sfa/util/PostgreSQL.py
sfa/util/api.py
sfa/util/callids.py
sfa/util/componentserver.py
sfa/util/config.py
sfa/util/enumeration.py [new file with mode: 0644]
sfa/util/faults.py
sfa/util/httpsProtocol.py [new file with mode: 0644]
sfa/util/method.py
sfa/util/parameter.py
sfa/util/plxrn.py
sfa/util/policy.py
sfa/util/record.py
sfa/util/rspec.py [deleted file]
sfa/util/rspecHelper.py
sfa/util/server.py
sfa/util/sfalogging.py [changed mode: 0755->0644]
sfa/util/sfatime.py
sfa/util/specdict.py
sfa/util/ssl_socket.py [new file with mode: 0644]
sfa/util/storage.py
sfa/util/threadmanager.py [changed mode: 0755->0644]
sfa/util/xml.py [new file with mode: 0755]
sfa/util/xmlrpcprotocol.py
sfa/util/xrn.py
tests/client/README [deleted file]
tests/testInterfaces.py
xmlbuilder-0.9/setup.py
xmlbuilder-0.9/xmlbuilder/__init__.py
xmlbuilder-0.9/xmlbuilder/tests/__init__.py

index 9acfb22..7e5b62e 100644 (file)
@@ -15,3 +15,4 @@ sfa/client/*.version
 *.pkey
 *.cert
 *.cred
+.DS_Store
index 5bb1bba..f162f5f 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -15,7 +15,7 @@ uninstall: python-uninstall tests-uninstall
 
 .PHONY: all install clean uninstall
 
-VERSIONTAG=should-be-redefined-by-specfile
+VERSIONTAG=0.0-0-should.be-redefined-by-specfile
 SCMURL=should-be-redefined-by-specfile
 
 ##########
diff --git a/TODO b/TODO
index ab99b42..6af99bd 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,5 +1,7 @@
 RSpecs
 - CreateSlivers should update SliverTags/attributes 
+- ProtoGENI rspec integration testing
+- initscripts in the rspec
 
 Registry
 - Verify that sub authority certificates still work
index 35a23ab..212dee4 100644 (file)
@@ -35,6 +35,12 @@ Thierry Parmentelat
           <value>false</value>
           <description>Flag to turn debug on.</description>
         </variable>
+    
+        <variable id="max_slice_renew" type="int">
+          <name>Max Slice Renew</name>
+          <value>60</value>
+          <description>Maximum amout of days a user can extend/renew their slices to</description>
+        </variable>
 
         <variable id="session_key_path" type="string">
             <name>User Session Keys Path </name>
diff --git a/flashpolicy/__init__.py b/flashpolicy/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
index 2800555..b10be90 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,11 @@ package_dirs = [
     'sfa/trust',
     'sfa/util', 
     'sfa/managers',
+    'sfa/managers/vini',
     'sfa/rspecs',
+    'sfa/rspecs/elements',
+    'sfa/rspecs/elements/versions',
+    'sfa/rspecs/versions',
     'sfatables',
     'sfatables/commands',
     'sfatables/processors',
index 9b56813..4f10df4 100644 (file)
--- a/sfa.spec
+++ b/sfa.spec
@@ -1,6 +1,6 @@
 %define name sfa
 %define version 1.0
-%define taglevel 25
+%define taglevel 36
 
 %define release %{taglevel}%{?pldistro:.%{pldistro}}%{?date:.%{date}}
 %global python_sitearch        %( python -c "from distutils.sysconfig import get_python_lib; print get_python_lib(1)" )
@@ -46,13 +46,13 @@ Requires: python-dateutil
 #%endif
 
 %package cm
-Summary: the SFA wrapper around MyPLC NodeManager
+Summary: the SFA layer around MyPLC NodeManager
 Group: Applications/System
 Requires: sfa
 Requires: pyOpenSSL >= 0.6
 
 %package plc
-Summary: the SFA wrapper arounf MyPLC
+Summary: the SFA layer around MyPLC
 Group: Applications/System
 Requires: sfa
 Requires: python-psycopg2
@@ -193,10 +193,77 @@ if [ "$1" = 0 ] ; then
 fi
 
 %postun cm
-[ "$1" -ge "1" ] && service sfa-cm restart
-
+[ "$1" -ge "1" ] && service sfa-cm restart || :
 
 %changelog
+* Thu Sep 15 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-36
+- Unicode-friendliness for user names with accents/special chars.
+- Fix bug that could cause create the client to fail when calling CreateSliver for a slice that has the same hrn as a user.
+- CreaetSliver no longer fails for users that have a capital letter in their URN.
+- Fix bug in CreateSliver that generated incorrect login bases and email addresses for ProtoGENI requests. 
+- Allow files with .gid, .pem or .crt extension to be loaded into the server's list of trusted certs.
+- Fix bugs and missing imports     
+
+* Tue Aug 30 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-35
+- new method record.get_field for sface
+
+* Mon Aug 29 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-34
+- new option -c to sfa-nuke-plc.py
+- CreateSliver fixed for admin-only slice tags
+
+* Wed Aug 24 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-32
+- Fixed exploit that allowed an authorities to issue certs for objects that dont belong to them.
+- Fixed holes in certificate verification logic.
+- Aggregates no longer try to lookup slice and person records when processing CreateSliver requests. Clients are now required to specify this info in the 'users' argument. 
+- Added 'boot_state' as an attribute of the node element in SFA rspec.
+- Non authority certificates are marked as CA:FALSE.
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-32
+- fix typo in sfa-1.0-31 tag.
+- added CreateGid() Registry interface method.
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-31
+- fix typo in sfa-1.0-30 tag
+
+* Tue Aug 16 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-30
+- Declare namespace and schema location in the credential.
+- Fix bug that prevetend connections from timing out.
+- Fix slice delegation.
+- Add statistics to slicemaanger listresources/createsliver rspec.
+- Added SFA_MAX_SLICE_RENEW which allows operators to configure the max ammout
+  of days a user can extend their slice expiration.
+- CA certs are only issued to objects of type authority
+   
+* Fri Aug 05 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-29
+- tag 1.0-28 was broken due to typo in the changelog
+- new class sfa/util/httpsProtocol.py that supports timeouts
+
+* Thu Aug 4 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-28
+- Resolved issue that caused sfa hold onto idle db connections.
+- Fix bug that caused the registry to use the wrong type of credential.
+- Support authority+sm type.
+- Fix rspec merging bugs.
+- Only load certs that have .gid extension from /etc/sfa/trusted_roots/
+- Created a 'planetlab' extension to the ProtoGENI v2 rspec for supporting 
+ planetlab hosted initscripts using the <planetlab:initscript> tag  
+- Can now handle extraneous whitespace in the rspec without failing.   
+* Fri Jul 8 2011 Tony Mack <tmack@cs.princeton.edu> - sfa-1.0-27
+- ProtoGENI v2 RSpec updates.
+- Convert expiration timestamps with timezone info in credentials to utc.
+- Fixed redundant logging issue. 
+- Improved SliceManager and SFI client logging.
+- Support aggregates that don't support the optional 'call_id' argument. 
+- Only call get_trusted_certs() at aggreage interfaces that support the call.
+- CreateSliver() now handles MyPLC slice attributes/tags.
+- Cache now supports persistence.
+- Hide whitelisted nodes.
+
+* Tue Jun 21 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-26
+- fixed issues with sup authority signing
+- fixed bugs in remove_slivers and SliverStatus
+
 * Thu Jun 16 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-25
 - fix typo that prevented aggregates from operating properly
 
@@ -335,111 +402,6 @@ fi
   the api handler on every new server request, making it easier to access the 
   cache and use in more general ways.     
 
-%changelog
-* Thu Jun 16 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-25
-- fix typo that prevented aggregates from operating properly
-
-* Tue Jun 14 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-24
-- load trusted certs into ssl context prior to handshake
-- client's logfile lives in ~/.sfi/sfi.log
-
-* Fri Jun 10 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-23
-- includes a change on passphrases that was intended in 1.0-22
-
-* Wed Mar 16 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-21
-- stable sfascan
-- fix in initscript, *ENABLED tags in config now taken into account
-
-* Fri Mar 11 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-20
-- some commits had not been pushed in tag 19
-
-* Fri Mar 11 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-19
-- GetVersion should now report full URLs with path
-- scansfa has nicer output and new syntax (entry URLs as args and not options)
-- dos2unix'ed flash policy pill
-
-* Wed Mar 09 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-18
-- fix packaging again for f8
-
-* Wed Mar 09 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-17
-- fix packaging (apparently broken in 1.0-16)
-- first working version of sfascan
-- tweaks in GetVersion for exposing hrn(AM) and full set of aggregates(SM)
-- deprecated the sfa_geni_aggregate config category
-
-* Tue Mar 08 2011 Andy Bavier <acb@cs.princeton.edu> - sfa-1.0-16
-- Fix build problem
-- First version of SFA scanner
-
-* Mon Mar 07 2011 Andy Bavier <acb@cs.princeton.edu> - sfa-1.0-15
-- Add support for Flash clients using flashpolicy
-- Fix problems with tag handling in RSpec
-
-* Wed Mar 02 2011 Andy Bavier <acb@cs.princeton.edu> - sfa-1.0-14
-- Modifications to the Eucalyptus Aggregate Manager
-- Fixes for VINI RSpec
-- Fix tag handling for PL RSpec
-- Fix XML Schema ordering for <urn> element
-
-* Tue Feb 01 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-13
-- just set x509 version to 2
-
-* Wed Jan 26 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-12
-- added urn to the node area in rspecs
-- conversion to urn now exports fqdn
-- sfa-import-plc.py now creates a unique registry record for each SFA interface
-
-* Thu Dec 16 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-11
-- undo broken attempt for python-2.7
-
-* Wed Dec 15 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-10
-- SMs avoid duplicates for when call graph has dags;
-- just based on network's name, when a duplicate occurs, one is just dropped
-- does not try to merge/aggregate 2 networks
-- also reviewed logging with the hope to fix the sfa startup msg:
-- TypeError: not all arguments converted during string formatting
-
-* Tue Dec 07 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-9
-- verify credentials against xsd schema
-- Fix SM to SM communication
-- Fix bug in sfa.util.sfalogging, sfa-import.py now logs to sfa_import.log
-- new setting session_key_path
-
-* Tue Nov 09 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-8
-- fix registry credential regeneration and handle expiration
-- support for setting slice tags (min_role=user)
-- client can display its own version: sfi.py version --local
-- GetVersion to provide urn in addition to hrn
-- more code uses plxrn vs previous helper functions
-- import replaces '+' in email addresses with '_'
-
-* Fri Oct 22 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-7
-- fix GetVersion code_tag and add code_url
-
-* Fri Oct 22 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-6
-- extend GetVersion towards minimum federation introspection, and expose local tag
-
-* Wed Oct 20 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-5
-- fixed some legacy issues (list vs List)
-- deprecated sfa.util.namespace for xrn and plxrn
-- unit tests ship as the sfa-tests rpm
-
-* Mon Oct 11 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-2
-- deprecated old methods (e.g. List/list, and GetCredential/get_credential)
-- NOTE:  get_(self_)credential both have type and hrn swapped when moving to Get(Self)Credential
-- hrn-urn translations tweaked
-- fixed 'service sfa status'
-- sfa-nuke-plc has a -f/--file-system option to clean up /var/lib/authorities (exp.)
-- started to repair sfadump - although not usable yet
-- trust objects now have dump_string method that dump() actually prints
-- unit tests under review
-- logging cleanup ongoing (always safe to use sfalogging.sfa_logger())
-- binaries now support -v or -vv to increase loglevel
-- trashed obsolete sfa.util.client
-
-* Mon Oct 04 2010 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-1
-- various bugfixes and cleanup, improved/harmonized logging
-
 * Thu May 11 2010 Tony Mack <tmack@cs.princeton.edu> - sfa-0.9-11
 - SfaServer now uses a pool of threads to handle requests concurrently
 - sfa.util.rspec no longer used to process/manage rspecs (deprecated). This is now handled by sfa.plc.network and is not backwards compatible
diff --git a/sfa/Makefile b/sfa/Makefile
deleted file mode 100644 (file)
index 06d3355..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-tags:  
-       find . -name '*.py' | grep -v '/\.svn/' | xargs etags
-.PHONY: tags
-
-
index 8f334b5..af366fc 100644 (file)
@@ -1,5 +1,7 @@
 # recompute the SFA graphs from different locations
 
+SFASCAN = ./sfascan.py -v
+
 # AMs, at least MyPLC AMs, are boring
 #BUNDLES += http://planet-lab.eu:12346/@auto-ple-am
 BUNDLES += http://planet-lab.eu:12345/@auto-ple-reg 
@@ -38,7 +40,8 @@ BUNDLES-LR += http://www.planet-lab.jp:12347/@auto-plj-sa
 BUNDLES-LR += http://www.emanicslab.org:12345/@auto-elc-reg 
 BUNDLES-LR += http://www.emanicslab.org:12347/@auto-elc-sa
 
-EXTENSIONS := png svg
+#EXTENSIONS := png svg
+EXTENSIONS := png
 
 ####################
 ALL += $(foreach bundle,$(BUNDLES),$(word 2,$(subst @, ,$(bundle))))
@@ -46,10 +49,12 @@ ALL += $(foreach bundle,$(BUNDLES-LR),$(word 2,$(subst @, ,$(bundle)))-lr)
 
 all: $(ALL)
 
+ple: auto-ple-reg auto-ple-sa-lr.out
+
 ####################
 define bundle_scan_target
 $(word 2,$(subst @, ,$(1))):
-       ./sfascan.py $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1))).$(extension)) $(word 1,$(subst @, ,$(1))) >& $(word 2,$(subst @, ,$(1))).out
+       $(SFASCAN) $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1))).$(extension)) $(word 1,$(subst @, ,$(1))) >& .$(word 2,$(subst @, ,$(1))).out
 .PHONY: $(word 2,$(subst @, ,$(1)))
 endef
 
@@ -59,7 +64,7 @@ $(foreach bundle,$(BUNDLES),$(eval $(call bundle_scan_target,$(bundle))))
 #################### same but left-to-right
 define bundle_scan_target_lr
 $(word 2,$(subst @, ,$(1)))-lr:
-       ./sfascan.py -l $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1)))-lr.$(extension)) $(word 1,$(subst @, ,$(1))) >& $(word 2,$(subst @, ,$(1)))-lr.out
+       $(SFASCAN) -l $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1)))-lr.$(extension)) $(word 1,$(subst @, ,$(1))) >& .$(word 2,$(subst @, ,$(1)))-lr.out
 .PHONY: $(word 2,$(subst @, ,$(1)))-lr
 endef
 
@@ -87,9 +92,12 @@ clean:
        rm -f auto-*.{out,version}
        $(foreach extension,$(EXTENSIONS),rm -rf auto-*.$(extension);)
 
+DATE=$(shell date '+%Y-%m-%d')
 PUBEXTENSIONS=png
 publish:
-       $(foreach extension,$(PUBEXTENSIONS),rsync -av auto-*.$(extension) tparment@srv-planete.inria.fr:/proj/planete/www/Thierry.Parmentelat/sfascan/ ;)
+       echo $(DATE)
+       ssh tparment@srv-planete.inria.fr mkdir /proj/planete/www/Thierry.Parmentelat/sfascan/$(DATE)
+       $(foreach extension,$(PUBEXTENSIONS),rsync -av auto-*.$(extension) tparment@srv-planete.inria.fr:/proj/planete/www/Thierry.Parmentelat/sfascan/$(DATE) ;)
 
 #################### convenience, for debugging only
 # make +foo : prints the value of $(foo)
diff --git a/sfa/client/client_helper.py b/sfa/client/client_helper.py
new file mode 100644 (file)
index 0000000..32e21a1
--- /dev/null
@@ -0,0 +1,37 @@
+
+def pg_users_arg(records):
+    users = []  
+    for record in records:
+        if record['type'] != 'user': 
+            continue
+        user = {'urn': record['geni_urn'],
+                'keys': record['keys']}
+        users.append(user)
+    return users    
+
+def sfa_users_arg(records, slice_record):
+    users = []
+    for record in records:
+        if record['type'] != 'user': 
+            continue
+        user = {'urn': record['geni_urn'], #
+                'keys': record['keys'],
+                'email': record['email'], # needed for MyPLC
+                'person_id': record['person_id'], # needed for MyPLC
+                'first_name': record['first_name'], # needed for MyPLC
+                'last_name': record['last_name'], # needed for MyPLC
+                'slice_record': slice_record, # needed for legacy refresh peer
+                'key_ids': record['key_ids'] # needed for legacy refresh peer
+                }         
+        users.append(user)
+    return users        
+
+def sfa_to_pg_users_arg(users):
+
+    new_users = []
+    fields = ['urn', 'keys']
+    for user in users:
+        new_user = dict([item for item in user.items() \
+          if item[0] in fields])
+        new_users.append(new_user)
+    return new_users        
index 67f9a28..71d17f0 100644 (file)
@@ -6,8 +6,6 @@ from optparse import OptionParser
 from pprint import pprint
 from types import StringTypes
 
-from sfa.util.rspec import RSpec
-
 def create_parser():
     command = sys.argv[0]
     argv = sys.argv[1:]
index cb765e0..e2be593 100755 (executable)
@@ -14,9 +14,7 @@ import os
 from optparse import OptionParser
 from pprint import pprint
 from xml.parsers.expat import ExpatError
-
-from sfa.util.rspec import RecordSpec
-
+from sfa.util.xml import XML    
 
 def create_parser():
     command = sys.argv[0]
@@ -34,17 +32,17 @@ def create_parser():
     return parser    
 
 
-def printRec(record, filters, options):
+def printRec(record_dict, filters, options):
     line = ""
     if len(filters):
         for filter in filters:
             if options.DEBUG:  print "Filtering on %s" %filter
             line += "%s: %s\n" % (filter, 
-                printVal(record.dict["record"].get(filter, None)))
+                printVal(record_dict.get(filter, None)))
         print line
     else:
         # print the wole thing
-        for (key, value) in record.dict["record"].iteritems():
+        for (key, value) in record_dict.iteritems():
             if (not options.withkey and key in ('gid', 'keys')) or\
                 (not options.plinfo and key == 'pl_info'):
                 continue
@@ -69,16 +67,14 @@ def main():
 
     stdin = sys.stdin.read()
     
-    record = RecordSpec(xml = stdin)
+    record = XML(stdin)
+    record_dict = record.todict()
     
-    if not record.dict.has_key("record"):
-        raise "RecordError", "Input record does not have 'record' tag."
-
     if options.DEBUG: 
-        record.pprint()
+        pprint(record.toxml())
         print "#####################################################"
 
-    printRec(record, args, options)
+    printRec(record_dict, args, options)
 
 if __name__ == '__main__':
     try: main()
index 5f48e68..405c90d 100755 (executable)
@@ -14,9 +14,7 @@ sys.path.append('.')
 import os
 from optparse import OptionParser
 from pprint import pprint
-
-from sfa.util.rspec import RecordSpec
-
+from sfa.util.xml import XML
 
 def create_parser():
     command = sys.argv[0]
@@ -92,15 +90,14 @@ def main():
     parser = create_parser(); 
     (options, args) = parser.parse_args()
 
-    record = RecordSpec(xml = sys.stdin.read())
-
+    record = XML(sys.stdin.read())
+    record_dict = record.todict()
     if args:
-        editDict(args, record.dict["record"], options)
+        editDict(args, record_dict, options)
     if options.DEBUG:
-        print "New Record:\n%s" % record.dict
-        record.pprint()
-
-    record.parseDict(record.dict)
+        print "New Record:\n%s" % record_dict
+        
+    record.parse_dict(record_dict)
     s = record.toxml()
     sys.stdout.write(s)
 
index b1169b9..52a9105 100755 (executable)
@@ -12,8 +12,7 @@ from sfa.trust.certificate import Certificate
 from sfa.trust.credential import Credential
 from sfa.trust.gid import GID
 from sfa.util.record import SfaRecord
-from sfa.util.rspec import RSpec
-from sfa.util.sfalogging import sfa_logger, sfa_logger_goes_to_console
+from sfa.util.sfalogging import logger
 
 def determine_sfa_filekind(fn):
 
@@ -100,7 +99,6 @@ def handle_input_kind (filename, options, kind):
         print "%s: unknown filekind '%s'"% (filename,kind)
 
 def main():
-    sfa_logger_goes_to_console()
     usage = """%prog file1 [ .. filen]
 display info on input files"""
     parser = OptionParser(usage=usage)
@@ -111,7 +109,7 @@ display info on input files"""
     parser.add_option("-v", "--verbose", action='count', dest='verbose', default=0)
     (options, args) = parser.parse_args()
 
-    sfa_logger().setLevelFromOptVerbose(options.verbose)
+    logger.setLevelFromOptVerbose(options.verbose)
     if len(args) <= 0:
         parser.print_help()
         sys.exit(1)
index f85384d..494a727 100755 (executable)
@@ -1,4 +1,4 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 
 import sys
 import socket
@@ -10,7 +10,7 @@ import pygraphviz
 from optparse import OptionParser
 
 from sfa.client.sfi import Sfi
-from sfa.util.sfalogging import sfa_logger,sfa_logger_goes_to_console
+from sfa.util.sfalogging import logger, DEBUG
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 
 def url_hostname_port (url):
@@ -38,7 +38,6 @@ class Interface:
             self.ip=socket.gethostbyname(self.hostname)
             self.probed=False
         except:
-#            traceback.print_exc()
             self.hostname="unknown"
             self.ip='0.0.0.0'
             self.port="???"
@@ -62,17 +61,18 @@ class Interface:
             pass
         options=DummyOptions()
         options.verbose=False
+        options.timeout=10
         try:
             client=Sfi(options)
             client.read_config()
             key_file = client.get_key_file()
             cert_file = client.get_cert_file(key_file)
             url=self.url()
-            sfa_logger().info('issuing get version at %s'%url)
-            server=xmlrpcprotocol.get_server(url, key_file, cert_file, options)
+            logger.info('issuing get version at %s'%url)
+            logger.debug("GetVersion, using timeout=%d"%options.timeout)
+            server=xmlrpcprotocol.get_server(url, key_file, cert_file, timeout=options.timeout, verbose=options.verbose)
             self._version=server.GetVersion()
         except:
-#            traceback.print_exc()
             self._version={}
         self.probed=True
         return self._version
@@ -82,7 +82,6 @@ class Interface:
         result='<<TABLE BORDER="0" CELLBORDER="0"><TR><TD>' + \
             '</TD></TR><TR><TD>'.join(lines) + \
             '</TD></TR></TABLE>>'
-#        print 'multilines=',result
         return result
 
     # default is for when we can't determine the type of the service
@@ -158,17 +157,17 @@ class SfaScan:
                 # performing xmlrpc call
                 version=interface.get_version()
                 if self.verbose:
-                    sfa_logger().info("GetVersion at interface %s"%interface.url())
+                    logger.info("GetVersion at interface %s"%interface.url())
                     if not version:
-                        sfa_logger().info("<EMPTY GetVersion(); offline or cannot authenticate>")
+                        logger.info("<EMPTY GetVersion(); offline or cannot authenticate>")
                     else: 
                         for (k,v) in version.iteritems(): 
                             if not isinstance(v,dict):
-                                sfa_logger().info("\r\t%s:%s"%(k,v))
+                                logger.info("\r\t%s:%s"%(k,v))
                             else:
-                                sfa_logger().info(k)
+                                logger.info(k)
                                 for (k1,v1) in v.iteritems():
-                                    sfa_logger().info("\r\t\t%s:%s"%(k1,v1))
+                                    logger.info("\r\t\t%s:%s"%(k1,v1))
                 # 'geni_api' is expected if the call succeeded at all
                 # 'peers' is needed as well as AMs typically don't have peers
                 if 'geni_api' in version and 'peers' in version: 
@@ -195,13 +194,12 @@ class SfaScan:
                     for (k,v) in interface.get_layout().iteritems():
                         node.attr[k]=v
                 else:
-                    sfa_logger().error("MISSED interface with node %s"%node)
+                    logger.error("MISSED interface with node %s"%node)
     
 
 default_outfiles=['sfa.png','sfa.svg','sfa.dot']
 
 def main():
-    sfa_logger_goes_to_console()
     usage="%prog [options] url-entry-point(s)"
     parser=OptionParser(usage=usage)
     parser.add_option("-o","--output",action='append',dest='outfiles',default=[],
@@ -210,21 +208,27 @@ def main():
                       help="instead of top-to-bottom")
     parser.add_option("-v","--verbose",action='store_true',dest='verbose',default=False,
                       help="verbose")
+    parser.add_option("-d","--debug",action='store_true',dest='debug',default=False,
+                      help="debug")
     (options,args)=parser.parse_args()
     if not args:
         parser.print_help()
         sys.exit(1)
     if not options.outfiles:
         options.outfiles=default_outfiles
+    logger.enable_console()
+    if options.debug:
+        options.verbose=True
+        logger.setLevel(DEBUG)
     scanner=SfaScan(left_to_right=options.left_to_right, verbose=options.verbose)
     entries = [ Interface(entry) for entry in args ]
     g=scanner.graph(entries)
-    sfa_logger().info("creating layout")
+    logger.info("creating layout")
     g.layout(prog='dot')
     for outfile in options.outfiles:
-        sfa_logger().info("drawing in %s"%outfile)
+        logger.info("drawing in %s"%outfile)
         g.draw(outfile)
-    sfa_logger().info("done")
+    logger.info("done")
 
 if __name__ == '__main__':
     main()
index 80cb34b..83a66f9 100755 (executable)
@@ -6,26 +6,28 @@ import sys
 sys.path.append('.')
 import os, os.path
 import tempfile
-import traceback
 import socket
-import random
 import datetime
-import zlib
+import codecs
+import pickle
 from lxml import etree
 from StringIO import StringIO
-from types import StringTypes, ListType
 from optparse import OptionParser
-from sfa.util.sfalogging import _SfaLogger, logging
+from sfa.client.client_helper import pg_users_arg, sfa_users_arg
+from sfa.util.sfalogging import sfi_logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
 from sfa.trust.credential import Credential
 from sfa.util.sfaticket import SfaTicket
 from sfa.util.record import SfaRecord, UserRecord, SliceRecord, NodeRecord, AuthorityRecord
-from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.rspec_converter import RSpecConverter
+from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
 from sfa.util.version import version_core
 from sfa.util.cache import Cache
+from sfa.rspecs.version_manager import VersionManager
 
 AGGREGATE_PORT=12346
 CM_PORT=12346
@@ -76,23 +78,51 @@ def filter_records(type, records):
 
 
 # save methods
+def save_variable_to_file(var, filename, format="text"):
+    f = open(filename, "w")
+    if format == "text":
+        f.write(str(var))
+    elif format == "pickled":
+        f.write(pickle.dumps(var))
+    else:
+        # this should never happen
+        print "unknown output format", format
+
+
 def save_rspec_to_file(rspec, filename):
     if not filename.endswith(".rspec"):
         filename = filename + ".rspec"
-
     f = open(filename, 'w')
     f.write(rspec)
     f.close()
     return
 
-def save_records_to_file(filename, recordList):
-    index = 0
-    for record in recordList:
-        if index > 0:
-            save_record_to_file(filename + "." + str(index), record)
-        else:
-            save_record_to_file(filename, record)
-        index = index + 1
+def save_records_to_file(filename, recordList, format="xml"):
+    if format == "xml":
+        index = 0
+        for record in recordList:
+            if index > 0:
+                save_record_to_file(filename + "." + str(index), record)
+            else:
+                save_record_to_file(filename, record)
+            index = index + 1
+    elif format == "xmllist":
+        f = open(filename, "w")
+        f.write("<recordlist>\n")
+        for record in recordList:
+            record = SfaRecord(dict=record)
+            f.write('<record hrn="' + record.get_name() + '" type="' + record.get_type() + '" />\n')
+        f.write("</recordlist>\n")
+        f.close()
+    elif format == "hrnlist":
+        f = open(filename, "w")
+        for record in recordList:
+            record = SfaRecord(dict=record)
+            f.write(record.get_name() + "\n")
+        f.close()
+    else:
+        # this should never happen
+        print "unknown output format", format
 
 def save_record_to_file(filename, record):
     if record['type'] in ['user']:
@@ -106,13 +136,17 @@ def save_record_to_file(filename, record):
     else:
         record = SfaRecord(dict=record)
     str = record.save_to_string()
-    file(filename, "w").write(str)
+    f=codecs.open(filename, encoding='utf-8',mode="w")
+    f.write(str)
+    f.close()
     return
 
 
 # load methods
 def load_record_from_file(filename):
-    str = file(filename, "r").read()
+    f=codecs.open(filename, encoding="utf-8", mode="r")
+    str = f.read()
+    f.close()
     record = SfaRecord(string=str)
     return record
 
@@ -134,6 +168,8 @@ class Sfi:
         for opt in Sfi.required_options:
             if not hasattr(options,opt): setattr(options,opt,None)
         if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
+        # xxx oops, this is dangerous, sounds like ww sometimes have discrepency
+        # would be safer to remove self.sfi_dir altogether
         self.sfi_dir = options.sfi_dir
         self.options = options
         self.slicemgr = None
@@ -141,7 +177,8 @@ class Sfi:
         self.user = None
         self.authority = None
         self.hashrequest = False
-        self.logger = _SfaLogger(self.sfi_dir + 'sfi.log', level = logging.INFO)
+        self.logger = sfi_logger
+        self.logger.enable_console()
    
     def create_cmd_parser(self, command, additional_cmdargs=None):
         cmdargs = {"list": "authority",
@@ -151,6 +188,7 @@ class Sfi:
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
+                  "create_gid": "[name]",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
@@ -214,15 +252,27 @@ class Sfi:
                                 help="optional component information", default=None)
 
 
-        if command in ("resources", "show", "list"):
+        # 'create' does return the new rspec, makes sense to save that too
+        if command in ("resources", "show", "list", "create_gid", 'create'):
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
-        
+
         if command in ("show", "list"):
            parser.add_option("-f", "--format", dest="format", type="choice",
                              help="display format ([text]|xml)", default="text",
                              choices=("text", "xml"))
 
+           parser.add_option("-F", "--fileformat", dest="fileformat", type="choice",
+                             help="output file format ([xml]|xmllist|hrnlist)", default="xml",
+                             choices=("xml", "xmllist", "hrnlist"))
+
+        if command in ("status", "version"):
+           parser.add_option("-o", "--output", dest="file",
+                            help="output dictionary to file", metavar="FILE", default=None)
+           parser.add_option("-F", "--fileformat", dest="fileformat", type="choice",
+                             help="output file format ([text]|pickled)", default="text",
+                             choices=("text","pickled"))
+
         if command in ("delegate"):
            parser.add_option("-u", "--user",
                             action="store_true", dest="delegate_user", default=False,
@@ -272,13 +322,15 @@ class Sfi:
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
+        parser.add_option("-t", "--timeout", dest="timeout", default=None,
+                         help="Amout of time tom wait before timing out the request")
         parser.disable_interspersed_args()
 
         return parser
         
 
     def read_config(self):
-       config_file = self.options.sfi_dir + os.sep + "sfi_config"
+       config_file = os.path.join(self.options.sfi_dir,"sfi_config")
        try:
           config = Config (config_file)
        except:
@@ -346,16 +398,16 @@ class Sfi:
        self.cert_file = cert_file
        self.cert = GID(filename=cert_file)
        self.logger.info("Contacting Registry at: %s"%self.reg_url)
-       self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)  
+       self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)  
        self.logger.info("Contacting Slice Manager at: %s"%self.sm_url)
-       self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, self.options)
+       self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
        return
 
     def get_cached_server_version(self, server):
         # check local cache first
         cache = None
         version = None 
-        cache_file = self.sfi_dir + os.path.sep + 'sfi_cache.dat'
+        cache_file = os.path.join(self.options.sfi_dir,'sfi_cache.dat')
         cache_key = server.url + "-version"
         try:
             cache = Cache(cache_file)
@@ -370,6 +422,8 @@ class Sfi:
             version = server.GetVersion()
             # cache version for 24 hours
             cache.add(cache_key, version, ttl= 60*60*24)
+            self.logger.info("Updating cache file %s" % cache_file)
+            cache.save_to_file(cache_file)
 
 
         return version   
@@ -380,7 +434,7 @@ class Sfi:
         Returns true if server support the optional call_id arg, false otherwise. 
         """
         server_version = self.get_cached_server_version(server)
-        if 'sfa' in server_version:
+        if 'sfa' in server_version and 'code_tag' in server_version:
             code_tag = server_version['code_tag']
             code_tag_parts = code_tag.split("-")
             
@@ -439,7 +493,7 @@ class Sfi:
             self.logger.info("Getting Registry issued cert")
             self.read_config()
             # *hack.  need to set registyr before _get_gid() is called 
-            self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)
+            self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
             gid = self._get_gid(type='user')
             self.registry = None 
             self.logger.info("Writing certificate to %s"%cert_file)
@@ -479,7 +533,6 @@ class Sfi:
             hrn = self.user
  
         gidfile = os.path.join(self.options.sfi_dir, hrn + ".gid")
-        print gidfile
         gid = self.get_cached_gid(gidfile)
         if not gid:
             user_cred = self.get_user_cred()
@@ -613,7 +666,7 @@ class Sfi:
         host_parts = host.split('/')
         host_parts[0] = host_parts[0] + ":" + str(port)
         url =  "http://%s" %  "/".join(host_parts)    
-        return xmlrpcprotocol.get_server(url, keyfile, certfile, self.options)
+        return xmlrpcprotocol.get_server(url, keyfile, certfile, timeout=self.options.timeout, verbose=self.options.debug)
 
     # xxx opts could be retrieved in self.options
     def get_server_from_opts(self, opts):
@@ -638,7 +691,22 @@ class Sfi:
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
+
+    def create_gid(self, opts, args):
+        if len(args) < 1:
+            self.print_help()
+            sys.exit(1)
+        target_hrn = args[0]
+        user_cred = self.get_user_cred().save_to_string(save_parents=True)
+        gid = self.registry.CreateGid(user_cred, target_hrn, self.cert.save_to_string())
+        if opts.file:
+            filename = opts.file
+        else:
+            filename = os.sep.join([self.sfi_dir, '%s.gid' % target_hrn])
+        self.logger.info("writing %s gid to %s" % (target_hrn, filename))
+        GID(string=gid).save_to_file(filename)
+         
+     
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
@@ -650,17 +718,14 @@ class Sfi:
             list = self.registry.List(hrn, user_cred)
         except IndexError:
             raise Exception, "Not enough parameters for the 'list' command"
-          
-        # filter on person, slice, site, node, etc.  
+
+        # filter on person, slice, site, node, etc.
         # THis really should be in the self.filter_records funct def comment...
         list = filter_records(opts.type, list)
         for record in list:
-            print "%s (%s)" % (record['hrn'], record['type'])     
+            print "%s (%s)" % (record['hrn'], record['type'])
         if opts.file:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_records_to_file(file, list)
+            save_records_to_file(opts.file, list, opts.fileformat)
         return
     
     # show named registry record
@@ -689,12 +754,8 @@ class Sfi:
                 record.dump()  
             else:
                 print record.save_to_string() 
         if opts.file:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_records_to_file(file, records)
+            save_records_to_file(opts.file, records, opts.fileformat)
         return
     
     def delegate(self, opts, args):
@@ -854,6 +915,8 @@ class Sfi:
             version=server.GetVersion()
         for (k,v) in version.iteritems():
             print "%-20s: %s"%(k,v)
+        if opts.file:
+            save_variable_to_file(version, opts.file, opts.fileformat)
 
     # list instantiated slices
     def slices(self, opts, args):
@@ -891,7 +954,15 @@ class Sfi:
             delegated_cred = self.delegate_cred(cred, get_authority(self.authority))
             creds.append(delegated_cred)
         if opts.rspec_version:
-            call_options['rspec_version'] = opts.rspec_version 
+            version_manager = VersionManager()
+            server_version = self.get_cached_server_version(server)
+            if 'sfa' in server_version:
+                # just request the version the client wants 
+                call_options['rspec_version'] = version_manager.get_version(opts.rspec_version).to_dict()
+            else:
+                # this must be a protogeni aggregate. We should request a v2 ad rspec
+                # regardless of what the client user requested 
+                call_options['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()     
         #panos add info options
         if opts.info:
             call_options['info'] = opts.info 
@@ -900,59 +971,58 @@ class Sfi:
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
         result = server.ListResources(*call_args)
-        format = opts.format
         if opts.file is None:
-            display_rspec(result, format)
+            display_rspec(result, opts.format)
         else:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_rspec_to_file(result, file)
+            save_rspec_to_file(result, opts.file)
         return
-    
+
     # created named slice with given rspec
     def create(self, opts, args):
+        server = self.get_server_from_opts(opts)
+        server_version = self.get_cached_server_version(server)
         slice_hrn = args[0]
-        slice_urn = hrn_to_urn(slice_hrn, 'slice') 
+        slice_urn = hrn_to_urn(slice_hrn, 'slice')
         user_cred = self.get_user_cred()
         slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True)
-        creds = [slice_cred]
-        if opts.delegate:
-            delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority))
-            creds.append(delegated_cred)
+        # delegate the cred to the callers root authority
+        delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority)+'.slicemanager')
+        #delegated_cred = self.delegate_cred(slice_cred, get_authority(slice_hrn))
+        #creds.append(delegated_cred)
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
+        # need to pass along user keys to the aggregate.
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
-        #    keys: [<ssh key A>, <ssh key B>] 
+        #    keys: [<ssh key A>, <ssh key B>]
         #  }]
         users = []
-        server = self.get_server_from_opts(opts)
-        version = server.GetVersion()
-        if 'sfa' not in version:
-            # need to pass along user keys if this request is going to a ProtoGENI aggregate 
-            # ProtoGeni Aggregates will only install the keys of the user that is issuing the
-            # request. So we will only pass in one user that contains the keys for all
-            # users of the slice 
-            user = {'urn': user_cred.get_gid_caller().get_urn(),
-                    'keys': []}
-            slice_record = self.registry.Resolve(slice_urn, creds)
-            if slice_record and 'researchers' in slice_record:
-                user_hrns = slice_record['researchers']
-                user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns] 
-                user_records = self.registry.Resolve(user_urns, creds)
-                for user_record in user_records:
-                    if 'keys' in user_record:
-                        user['keys'].extend(user_record['keys'])
-            users.append(user)
-
+        slice_records = self.registry.Resolve(slice_urn, [user_cred.save_to_string(save_parents=True)])
+        if slice_records and 'researcher' in slice_records[0] and slice_records[0]['researcher']!=[]:
+            slice_record = slice_records[0]
+            user_hrns = slice_record['researcher']
+            user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns]
+            user_records = self.registry.Resolve(user_urns, [user_cred.save_to_string(save_parents=True)])
+
+            if 'sfa' not in server_version:
+                users = pg_users_arg(user_records)
+                rspec = RSpec(rspec)
+                rspec.filter({'component_manager_id': server_version['urn']})
+                rspec = RSpecConverter.to_pg_rspec(rspec.toxml(), content_type='request')
+                creds = [slice_cred]
+            else:
+                users = sfa_users_arg(user_records, slice_record)
+                creds = [slice_cred, delegated_cred]
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
-             
-        result =  server.CreateSliver(*call_args)
-        print result
+           
+        result = server.CreateSliver(*call_args)
+        if opts.file is None:
+            print result
+        else:
+            save_rspec_to_file (result, opts.file)
         return result
 
     # get a ticket for the specified slice
@@ -1089,7 +1159,10 @@ class Sfi:
         call_args = [slice_urn, creds]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
-        print server.SliverStatus(*call_args)
+        result = server.SliverStatus(*call_args)
+        print result
+        if opts.file:
+            save_variable_to_file(result, opts.file, opts.fileformat)
 
 
     def shutdown(self, opts, args):
@@ -1133,12 +1206,13 @@ class Sfi:
             self.logger.debug("resources cmd_opts %s" % cmd_opts.format)
         elif command in ("list", "show", "remove"):
             self.logger.debug("cmd_opts.type %s" % cmd_opts.type)
-        self.logger.debug('cmd_args %s',cmd_args)
+        self.logger.debug('cmd_args %s' % cmd_args)
 
         try:
             self.dispatch(command, cmd_opts, cmd_args)
         except KeyError:
             self.logger.critical ("Unknown command %s"%command)
+            raise
             sys.exit(1)
     
         return
index 0242a13..9c2eae5 100755 (executable)
@@ -3,7 +3,7 @@
 import sys
 from sfa.util.rspecHelper import RSpec, Commands
 from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
 
 command = Commands(usage="%prog [options] [node1 node2...]",
                    description="Add sliver attributes to the RSpec. " +
@@ -19,7 +19,7 @@ command.prep()
 
 if command.opts.infile:
     attrs = command.get_attribute_dict()
-    rspec = parse_rspec(command.opts.infile)
+    rspec = RSpec(command.opts.infile)
     nodes = []
     if command.opts.nodefile:
         f = open(command.opts.nodefile, "r")
@@ -32,13 +32,13 @@ if command.opts.infile:
         for value in attrs[name]:
             if not nodes:
                 try:
-                    rspec.add_default_sliver_attribute(name, value)
+                    rspec.version.add_default_sliver_attribute(name, value)
                 except:
                     print >> sys.stderr, "FAILED: on all nodes: %s=%s" % (name, value)
             else:
                 for node in nodes:
                     try:
-                        rspec.add_sliver_attribute(node, name, value)
+                        rspec.version.add_sliver_attribute(node, name, value)
                     except:
                         print >> sys.stderr, "FAILED: on node %s: %s=%s" % (node, name, value)
 
index a96c678..c72dee3 100755 (executable)
@@ -2,7 +2,8 @@
 
 import sys
 from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
 
 command = Commands(usage="%prog [options] node1 node2...",
                    description="Add slivers to the RSpec. " +
@@ -24,17 +25,20 @@ if command.opts.outfile:
     outfile=file(command.opts.outfile,"w")
 else:
     outfile=sys.stdout
-
-rspec = parse_rspec(infile)
-rspec.type = 'request'
+ad_rspec = RSpec(infile)
 nodes = file(command.opts.nodefile).read().split()
+version_manager = VersionManager()
 try:
-    if rspec.version['type'].lower() == 'protogeni':
-        rspec.xml.set('type', 'request')
+    type = ad_rspec.version.type
+    version_num = ad_rspec.version.version
+    request_version = version_manager._get_version(type, version_num, 'request')    
+    request_rspec = RSpec(version=request_version)
     slivers = [{'hostname': node} for node in nodes]
-    rspec.add_slivers(slivers)
+    request_rspec.version.merge(ad_rspec)
+    request_rspec.version.add_slivers(slivers)
 except:
     print >> sys.stderr, "FAILED: %s" % nodes
+    raise
     sys.exit(1)
-print >>outfile, rspec.toxml(cleanup=True)
+print >>outfile, request_rspec.toxml()
 sys.exit(0)
index f372488..53b2542 100755 (executable)
@@ -2,7 +2,7 @@
 
 import sys
 from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
 
 command = Commands(usage="%prog [options] [node1 node2...]",
                    description="Delete sliver attributes from the RSpec. " +
@@ -18,7 +18,7 @@ command.prep()
 
 if command.opts.infile:
     attrs = command.get_attribute_dict()
-    rspec = parse_rspec(command.opts.infile)
+    rspec = RSpec(command.opts.infile)
     nodes = []
     if command.opts.nodefile:
         f = open(command.opts.nodefile, "r")
@@ -31,13 +31,13 @@ if command.opts.infile:
         for value in attrs[name]:
             if not nodes:
                 try:
-                    rspec.remove_default_sliver_attribute(name, value)
+                    rspec.version.remove_default_sliver_attribute(name, value)
                 except:
                     print >> sys.stderr, "FAILED: on all nodes: %s=%s" % (name, value)
             else:
                 for node in nodes:
                     try:
-                        rspec.remove_sliver_attribute(node, name, value)
+                        rspec.version.remove_sliver_attribute(node, name, value)
                     except:
                         print >> sys.stderr, "FAILED: on node %s: %s=%s" % (node, name, value)
 
index 5b4f70b..be10f0b 100755 (executable)
@@ -2,7 +2,7 @@
 
 import sys
 from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
 
 command = Commands(usage="%prog [options] node1 node2...",
                    description="Delete slivers from the RSpec. " +
@@ -12,7 +12,7 @@ command.add_nodefile_option()
 command.prep()
 
 if command.opts.infile:
-    rspec = parse_rspec(command.opts.infile)
+    rspec = RSpec(command.opts.infile)
     nodes = []
     if command.opts.nodefile:
         f = open(command.opts.nodefile, "r")
@@ -21,11 +21,11 @@ if command.opts.infile:
        
     try:
         slivers = [{'hostname': node} for node in nodes]
-        rspec.remove_slivers(slivers)
+        rspec.version.remove_slivers(slivers)
+        print rspec.toxml()
     except:
         print >> sys.stderr, "FAILED: %s" % nodes 
 
-    print rspec.toxml()
     
 
     
index 305bf25..f9e794b 100755 (executable)
@@ -2,7 +2,7 @@
 
 import sys
 from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec 
+from sfa.rspecs.rspec import RSpec 
 
 command = Commands(usage="%prog [options]",
                    description="List all nodes in the RSpec. " + 
@@ -11,8 +11,8 @@ command = Commands(usage="%prog [options]",
 command.prep()
 
 if command.opts.infile:
-    rspec = parse_rspec(command.opts.infile)
-    nodes = rspec.get_nodes()
+    rspec = RSpec(command.opts.infile)
+    nodes = rspec.version.get_nodes()
     if command.opts.outfile:
         sys.stdout = open(command.opts.outfile, 'w')
     
index 08173ec..cf76823 100755 (executable)
@@ -2,7 +2,7 @@
 
 import sys
 from sfa.client.sfi_commands import Commands
-from sfa.rspecs.rspec_parser import parse_rspec
+from sfa.rspecs.rspec import RSpec
 
 command = Commands(usage="%prog [options]",
                    description="List all slivers in the RSpec. " + 
@@ -12,11 +12,11 @@ command.add_show_attributes_option()
 command.prep()
 
 if command.opts.infile:
-    rspec = parse_rspec(command.opts.infile)
-    nodes = rspec.get_nodes_with_slivers()
+    rspec = RSpec(command.opts.infile)
+    nodes = rspec.version.get_nodes_with_slivers()
     
     if command.opts.showatt:
-        defaults = rspec.get_default_sliver_attributes()
+        defaults = rspec.version.get_default_sliver_attributes()
         if defaults:
             print "ALL NODES"
             for (name, value) in defaults:
@@ -25,7 +25,7 @@ if command.opts.infile:
     for node in nodes:
         print node
         if command.opts.showatt:
-            atts = rspec.get_sliver_attributes(node)
+            atts = rspec.version.get_sliver_attributes(node)
             for (name, value) in atts:
                 print "  %s: %s" % (name, value)
 
index b039c24..e2fdb10 100755 (executable)
@@ -61,6 +61,9 @@ start() {
     
     reload
 
+    # install peer certs
+    action $"SFA installing peer certs" daemon /usr/bin/sfa-server.py -t -d $OPTIONS 
+
     if [ "$SFA_REGISTRY_ENABLED" -eq 1 ]; then
         action $"SFA Registry" daemon /usr/bin/sfa-server.py -r -d $OPTIONS
     fi
index e3bbd96..eea507c 100755 (executable)
@@ -6,9 +6,6 @@
 #
 # description:   Wraps PLCAPI into the SFA compliant API
 #
-# $Id: sfa 14304 2009-07-06 20:19:51Z thierry $
-# $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/init.d/sfa $
-#
 
 # Source config
 . /etc/sfa/sfa_config
similarity index 81%
rename from sfa/managers/aggregate_manager_pl.py
rename to sfa/managers/aggregate_manager.py
index 894823c..77af072 100644 (file)
@@ -1,5 +1,3 @@
-
-
 import datetime
 import time
 import traceback
@@ -8,9 +6,8 @@ import re
 from types import StringTypes
 
 from sfa.util.faults import *
-from sfa.util.xrn import get_authority, hrn_to_urn, urn_to_hrn, Xrn
+from sfa.util.xrn import get_authority, hrn_to_urn, urn_to_hrn, Xrn, urn_to_sliver_id
 from sfa.util.plxrn import slicename_to_hrn, hrn_to_pl_slicename, hostname_to_urn
-from sfa.util.rspec import *
 from sfa.util.specdict import *
 from sfa.util.record import SfaRecord
 from sfa.util.policy import Policy
@@ -24,23 +21,29 @@ from sfa.plc.api import SfaAPI
 from sfa.plc.aggregate import Aggregate
 from sfa.plc.slices import *
 from sfa.util.version import version_core
-from sfa.rspecs.rspec_version import RSpecVersion
-from sfa.rspecs.sfa_rspec import sfa_rspec_version
-from sfa.rspecs.pg_rspec import pg_rspec_ad_version, pg_rspec_request_version
-from sfa.rspecs.rspec_parser import parse_rspec 
+from sfa.rspecs.version_manager import VersionManager
+from sfa.rspecs.rspec import RSpec
 from sfa.util.sfatime import utcparse
 from sfa.util.callids import Callids
 
 def GetVersion(api):
+
+    version_manager = VersionManager()
+    ad_rspec_versions = []
+    request_rspec_versions = []
+    for rspec_version in version_manager.versions:
+        if rspec_version.content_type in ['*', 'ad']:
+            ad_rspec_versions.append(rspec_version.to_dict())
+        if rspec_version.content_type in ['*', 'request']:
+            request_rspec_versions.append(rspec_version.to_dict()) 
+    default_rspec_version = version_manager.get_version("sfa 1").to_dict()
     xrn=Xrn(api.hrn)
-    request_rspec_versions = [dict(pg_rspec_request_version), dict(sfa_rspec_version)]
-    ad_rspec_versions = [dict(pg_rspec_ad_version), dict(sfa_rspec_version)]
     version_more = {'interface':'aggregate',
                     'testbed':'myplc',
                     'hrn':xrn.get_hrn(),
                     'request_rspec_versions': request_rspec_versions,
                     'ad_rspec_versions': ad_rspec_versions,
-                    'default_ad_rspec': dict(sfa_rspec_version)
+                    'default_ad_rspec': default_rspec_version
                     }
     return version_core(version_more)
 
@@ -79,6 +82,7 @@ def __get_registry_objects(slice_xrn, creds, users):
 
         slice = {}
         
+        # get_expiration always returns a normalized datetime - no need to utcparse
         extime = Credential(string=creds[0]).get_expiration()
         # If the expiration time is > 60 days from now, set the expiration time to 60 days from now
         if extime > datetime.datetime.utcnow() + datetime.timedelta(days=60):
@@ -113,17 +117,16 @@ def SliverStatus(api, slice_xrn, creds, call_id):
 
     (hrn, type) = urn_to_hrn(slice_xrn)
     # find out where this slice is currently running
-    api.logger.info(hrn)
     slicename = hrn_to_pl_slicename(hrn)
     
-    slices = api.plshell.GetSlices(api.plauth, [slicename], ['node_ids','person_ids','name','expires'])
+    slices = api.plshell.GetSlices(api.plauth, [slicename], ['slice_id', 'node_ids','person_ids','name','expires'])
     if len(slices) == 0:        
-        raise Exception("Slice %s not found (used %s as slicename internally)" % slice_xrn, slicename)
+        raise Exception("Slice %s not found (used %s as slicename internally)" % (slice_xrn, slicename))
     slice = slices[0]
     
     # report about the local nodes only
     nodes = api.plshell.GetNodes(api.plauth, {'node_id':slice['node_ids'],'peer_id':None},
-                                 ['hostname', 'site_id', 'boot_state', 'last_contact'])
+                                 ['node_id', 'hostname', 'site_id', 'boot_state', 'last_contact'])
     site_ids = [node['site_id'] for node in nodes]
     sites = api.plshell.GetSites(api.plauth, site_ids, ['site_id', 'login_base'])
     sites_dict = dict ( [ (site['site_id'],site['login_base'] ) for site in sites ] )
@@ -132,7 +135,8 @@ def SliverStatus(api, slice_xrn, creds, call_id):
     top_level_status = 'unknown'
     if nodes:
         top_level_status = 'ready'
-    result['geni_urn'] = Xrn(slice_xrn, 'slice').get_urn()
+    slice_urn = Xrn(slice_xrn, 'slice').get_urn()
+    result['geni_urn'] = slice_urn
     result['pl_login'] = slice['name']
     result['pl_expires'] = datetime.datetime.fromtimestamp(slice['expires']).ctime()
     
@@ -144,7 +148,8 @@ def SliverStatus(api, slice_xrn, creds, call_id):
         res['pl_last_contact'] = node['last_contact']
         if node['last_contact'] is not None:
             res['pl_last_contact'] = datetime.datetime.fromtimestamp(node['last_contact']).ctime()
-        res['geni_urn'] = hostname_to_urn(api.hrn, sites_dict[node['site_id']], node['hostname'])
+        sliver_id = urn_to_sliver_id(slice_urn, slice['slice_id'], node['node_id']) 
+        res['geni_urn'] = sliver_id
         if node['boot_state'] == 'boot':
             res['geni_status'] = 'ready'
         else:
@@ -157,9 +162,6 @@ def SliverStatus(api, slice_xrn, creds, call_id):
         
     result['geni_status'] = top_level_status
     result['geni_resources'] = resources
-    # XX remove me
-    #api.logger.info(result)
-    # XX remove me
     return result
 
 def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):
@@ -169,47 +171,36 @@ def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):
     """
     if Callids().already_handled(call_id): return ""
 
-    reg_objects = __get_registry_objects(slice_xrn, creds, users)
-
-    (hrn, type) = urn_to_hrn(slice_xrn)
-    peer = None
     aggregate = Aggregate(api)
     slices = Slices(api)
+    (hrn, type) = urn_to_hrn(slice_xrn)
     peer = slices.get_peer(hrn)
     sfa_peer = slices.get_sfa_peer(hrn)
-    registry = api.registries[api.hrn]
-    credential = api.getCredential()
-    (site_id, remote_site_id) = slices.verify_site(registry, credential, hrn, 
-                                                   peer, sfa_peer, reg_objects)
-
-    slice = slices.verify_slice(registry, credential, hrn, site_id, 
-                                       remote_site_id, peer, sfa_peer, reg_objects)
-     
-    nodes = api.plshell.GetNodes(api.plauth, slice['node_ids'], ['hostname'])
-    current_slivers = [node['hostname'] for node in nodes] 
-    rspec = parse_rspec(rspec_string)
-    requested_slivers = [str(host) for host in rspec.get_nodes_with_slivers()]
-    # remove nodes not in rspec
-    deleted_nodes = list(set(current_slivers).difference(requested_slivers))
-
-    # add nodes from rspec
-    added_nodes = list(set(requested_slivers).difference(current_slivers))
-
-    try:
-        if peer:
-            api.plshell.UnBindObjectFromPeer(api.plauth, 'slice', slice['slice_id'], peer)
-
-        api.plshell.AddSliceToNodes(api.plauth, slice['name'], added_nodes) 
-        api.plshell.DeleteSliceFromNodes(api.plauth, slice['name'], deleted_nodes)
-
-        # TODO: update slice tags
-        #network.updateSliceTags()
+    slice_record=None    
+    if users:
+        slice_record = users[0].get('slice_record', {})
 
-    finally:
-        if peer:
-            api.plshell.BindObjectToPeer(api.plauth, 'slice', slice.id, peer, 
-                                         slice.peer_id)
+    # parse rspec
+    rspec = RSpec(rspec_string)
+    requested_attributes = rspec.version.get_slice_attributes()
+    
+    # ensure site record exists
+    site = slices.verify_site(hrn, slice_record, peer, sfa_peer)
+    # ensure slice record exists
+    slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)
+    # ensure person records exists
+    persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)
+    # ensure slice attributes exists
+    slices.verify_slice_attributes(slice, requested_attributes)
+    
+    # add/remove slice from nodes
+    requested_slivers = [str(host) for host in rspec.version.get_nodes_with_slivers()]
+    slices.verify_slice_nodes(slice, requested_slivers, peer) 
 
+    # hanlde MyPLC peer association.
+    # only used by plc and ple.
+    slices.handle_peer(site, slice, persons, peer)
+    
     return aggregate.get_rspec(slice_xrn=slice_xrn, version=rspec.version)
 
 
@@ -304,19 +295,20 @@ def ListSlices(api, creds, call_id):
 
     return slice_urns
     
-def ListResources(api, creds, options,call_id):
+def ListResources(api, creds, options, call_id):
     if Callids().already_handled(call_id): return ""
     # get slice's hrn from options
-    xrn = options.get('geni_slice_urn', '')
+    xrn = options.get('geni_slice_urn', None)
     (hrn, type) = urn_to_hrn(xrn)
 
+    version_manager = VersionManager()
     # get the rspec's return format from options
-    rspec_version = RSpecVersion(options.get('rspec_version'))
-    version_string = "rspec_%s" % (rspec_version.get_version_name())
+    rspec_version = version_manager.get_version(options.get('rspec_version'))
+    version_string = "rspec_%s" % (rspec_version.to_string())
 
     #panos adding the info option to the caching key (can be improved)
     if options.get('info'):
-       version_string = version_string + "_"+options.get('info', 'default')
+        version_string = version_string + "_"+options.get('info', 'default')
 
     # look in cache first
     if caching and api.cache and not xrn:
@@ -325,11 +317,9 @@ def ListResources(api, creds, options,call_id):
             api.logger.info("aggregate.ListResources: returning cached value for hrn %s"%hrn)
             return rspec 
 
-    #aggregate = Aggregate(api)
     #panos: passing user-defined options
     #print "manager options = ",options
     aggregate = Aggregate(api, options)
-
     rspec =  aggregate.get_rspec(slice_xrn=xrn, version=rspec_version)
 
     # cache the result
index 68669bd..6c7c1f4 100644 (file)
@@ -1,7 +1,9 @@
 from __future__ import with_statement 
 
 import sys
-import os
+import os, errno
+import logging
+import datetime
 
 import boto
 from boto.ec2.regioninfo import RegionInfo
@@ -12,13 +14,20 @@ from lxml import etree as ET
 from sqlobject import *
 
 from sfa.util.faults import *
-from sfa.util.xrn import urn_to_hrn
-from sfa.util.rspec import RSpec
+from sfa.util.xrn import urn_to_hrn, Xrn
 from sfa.server.registry import Registries
 from sfa.trust.credential import Credential
 from sfa.plc.api import SfaAPI
+from sfa.plc.aggregate import Aggregate
+from sfa.plc.slices import *
 from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn
 from sfa.util.callids import Callids
+from sfa.util.sfalogging import logger
+from sfa.rspecs.sfa_rspec import sfa_rspec_version
+from sfa.util.version import version_core
+
+from multiprocessing import Process
+from time import sleep
 
 ##
 # The data structure used to represent a cloud.
@@ -32,10 +41,18 @@ cloud = {}
 #
 EUCALYPTUS_RSPEC_SCHEMA='/etc/sfa/eucalyptus.rng'
 
-# Quick hack
-sys.stderr = file('/var/log/euca_agg.log', 'a+')
 api = SfaAPI()
 
+##
+# Meta data of an instance.
+#
+class Meta(SQLObject):
+    instance   = SingleJoin('EucaInstance')
+    state      = StringCol(default = 'new')
+    pub_addr   = StringCol(default = None)
+    pri_addr   = StringCol(default = None)
+    start_time = DateTimeCol(default = None)
+
 ##
 # A representation of an Eucalyptus instance. This is a support class
 # for instance <-> slice mapping.
@@ -47,7 +64,8 @@ class EucaInstance(SQLObject):
     ramdisk_id  = StringCol()
     inst_type   = StringCol()
     key_pair    = StringCol()
-    slice = ForeignKey('Slice')
+    slice       = ForeignKey('Slice')
+    meta        = ForeignKey('Meta')
 
     ##
     # Contacts Eucalyptus and tries to reserve this instance.
@@ -56,10 +74,11 @@ class EucaInstance(SQLObject):
     # @param pubKeys A list of public keys for the instance.
     #
     def reserveInstance(self, botoConn, pubKeys):
-        print >>sys.stderr, 'Reserving an instance: image: %s, kernel: ' \
-                            '%s, ramdisk: %s, type: %s, key: %s' % \
-                            (self.image_id, self.kernel_id, self.ramdisk_id, 
-                             self.inst_type, self.key_pair)
+        logger = logging.getLogger('EucaAggregate')
+        logger.info('Reserving an instance: image: %s, kernel: ' \
+                    '%s, ramdisk: %s, type: %s, key: %s' % \
+                    (self.image_id, self.kernel_id, self.ramdisk_id,
+                    self.inst_type, self.key_pair))
 
         # XXX The return statement is for testing. REMOVE in production
         #return
@@ -78,7 +97,7 @@ class EucaInstance(SQLObject):
         except EC2ResponseError, ec2RespErr:
             errTree = ET.fromstring(ec2RespErr.body)
             msg = errTree.find('.//Message')
-            print >>sys.stderr, msg.text
+            logger.error(msg.text)
             self.destroySelf()
 
 ##
@@ -94,10 +113,17 @@ class Slice(SQLObject):
 # Initialize the aggregate manager by reading a configuration file.
 #
 def init_server():
+    logger = logging.getLogger('EucaAggregate')
+    fileHandler = logging.FileHandler('/var/log/euca.log')
+    fileHandler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+    logger.addHandler(fileHandler)
+    fileHandler.setLevel(logging.DEBUG)
+    logger.setLevel(logging.DEBUG)
+
     configParser = ConfigParser()
     configParser.read(['/etc/sfa/eucalyptus_aggregate.conf', 'eucalyptus_aggregate.conf'])
     if len(configParser.sections()) < 1:
-        print >>sys.stderr, 'No cloud defined in the config file'
+        logger.error('No cloud defined in the config file')
         raise Exception('Cannot find cloud definition in configuration file.')
 
     # Only read the first section.
@@ -123,23 +149,28 @@ def init_server():
         detail = {'imageID' : i.id, 'kernelID' : i.kernel_id, 'ramdiskID' : i.ramdisk_id}
         cloud['imageBundles'][name] = detail
 
-    # Initialize sqlite3 database.
+    # Initialize sqlite3 database and tables.
     dbPath = '/etc/sfa/db'
     dbName = 'euca_aggregate.db'
 
     if not os.path.isdir(dbPath):
-        print >>sys.stderr, '%s not found. Creating directory ...' % dbPath
+        logger.info('%s not found. Creating directory ...' % dbPath)
         os.mkdir(dbPath)
 
     conn = connectionForURI('sqlite://%s/%s' % (dbPath, dbName))
     sqlhub.processConnection = conn
     Slice.createTable(ifNotExists=True)
     EucaInstance.createTable(ifNotExists=True)
+    Meta.createTable(ifNotExists=True)
+
+    # Start the update process to keep track of the meta data
+    # about Eucalyptus instance.
+    Process(target=updateMeta).start()
 
     # Make sure the schema exists.
     if not os.path.exists(EUCALYPTUS_RSPEC_SCHEMA):
         err = 'Cannot location schema at %s' % EUCALYPTUS_RSPEC_SCHEMA
-        print >>sys.stderr, err
+        logger.error(err)
         raise Exception(err)
 
 ##
@@ -156,10 +187,11 @@ def getEucaConnection():
     useSSL    = False
     srvPath   = '/'
     eucaPort  = 8773
+    logger    = logging.getLogger('EucaAggregate')
 
     if not accessKey or not secretKey or not eucaURL:
-        print >>sys.stderr, 'Please set ALL of the required environment ' \
-                            'variables by sourcing the eucarc file.'
+        logger.error('Please set ALL of the required environment ' \
+                     'variables by sourcing the eucarc file.')
         return None
     
     # Split the url into parts
@@ -188,36 +220,31 @@ def getEucaConnection():
 # @param sliceHRN The hunman readable name of the slice.
 # @return sting()
 #
-def getKeysForSlice(sliceHRN):
-    try:
-        # convert hrn to slice name
-        plSliceName = hrn_to_pl_slicename(sliceHRN)
-    except IndexError, e:
-        print >>sys.stderr, 'Invalid slice name (%s)' % sliceHRN
-        return []
-
-    # Get the slice's information
-    sliceData = api.plshell.GetSlices(api.plauth, {'name':plSliceName})
-    if not sliceData:
-        print >>sys.stderr, 'Cannot get any data for slice %s' % plSliceName
+# This method is no longer needed because the user keys are passed into
+# CreateSliver
+#
+def getKeysForSlice(api, sliceHRN):
+    logger   = logging.getLogger('EucaAggregate')
+    cred     = api.getCredential()
+    registry = api.registries[api.hrn]
+    keys     = []
+
+    # Get the slice record
+    records = registry.Resolve(sliceHRN, cred)
+    if not records:
+        logging.warn('Cannot find any record for slice %s' % sliceHRN)
         return []
 
-    # It should only return a list with len = 1
-    sliceData = sliceData[0]
+    # Find who can log into this slice
+    persons = records[0]['persons']
 
-    keys = []
-    person_ids = sliceData['person_ids']
-    if not person_ids: 
-        print >>sys.stderr, 'No users in slice %s' % sliceHRN
-        return []
+    # Extract the keys from persons records
+    for p in persons:
+        sliceUser = registry.Resolve(p, cred)
+        userKeys = sliceUser[0]['keys']
+        keys += userKeys
 
-    persons = api.plshell.GetPersons(api.plauth, person_ids)
-    for person in persons:
-        pkeys = api.plshell.GetKeys(api.plauth, person['key_ids'])
-        for key in pkeys:
-            keys.append(key['key'])
-    return ''.join(keys)
+    return '\n'.join(keys)
 
 ##
 # A class that builds the RSpec for Eucalyptus.
@@ -338,14 +365,15 @@ class EucaRSpecBuilder(object):
     # Generates the RSpec.
     #
     def toXML(self):
+        logger = logging.getLogger('EucaAggregate')
         if not self.cloudInfo:
-            print >>sys.stderr, 'No cloud information'
+            logger.error('No cloud information')
             return ''
 
         xml = self.eucaRSpec
         cloud = self.cloudInfo
         with xml.RSpec(type='eucalyptus'):
-            with xml.cloud(id=cloud['name']):
+            with xml.network(name=cloud['name']):
                 with xml.ipv4:
                     xml << cloud['ip']
                 #self.__keyPairsXML(cloud['keypairs'])
@@ -400,6 +428,7 @@ def ListResources(api, creds, options, call_id):
     # get slice's hrn from options
     xrn = options.get('geni_slice_urn', '')
     hrn, type = urn_to_hrn(xrn)
+    logger = logging.getLogger('EucaAggregate')
 
     # get hrn of the original caller
     origin_hrn = options.get('origin_hrn', None)
@@ -409,7 +438,7 @@ def ListResources(api, creds, options, call_id):
     conn = getEucaConnection()
 
     if not conn:
-        print >>sys.stderr, 'Error: Cannot create a connection to Eucalyptus'
+        logger.error('Cannot create a connection to Eucalyptus')
         return 'Cannot create a connection to Eucalyptus'
 
     try:
@@ -476,7 +505,7 @@ def ListResources(api, creds, options, call_id):
     except EC2ResponseError, ec2RespErr:
         errTree = ET.fromstring(ec2RespErr.body)
         errMsgE = errTree.find('.//Message')
-        print >>sys.stderr, errMsgE.text
+        logger.error(errMsgE.text)
 
     rspec = EucaRSpecBuilder(cloud).toXML()
 
@@ -490,27 +519,52 @@ def ListResources(api, creds, options, call_id):
 """
 Hook called via 'sfi.py create'
 """
-def CreateSliver(api, xrn, creds, xml, users, call_id):
+def CreateSliver(api, slice_xrn, creds, xml, users, call_id):
     if Callids().already_handled(call_id): return ""
 
     global cloud
-    hrn = urn_to_hrn(xrn)[0]
+    logger = logging.getLogger('EucaAggregate')
+    logger.debug("In CreateSliver")
+
+    aggregate = Aggregate(api)
+    slices = Slices(api)
+    (hrn, type) = urn_to_hrn(slice_xrn)
+    peer = slices.get_peer(hrn)
+    sfa_peer = slices.get_sfa_peer(hrn)
+    slice_record=None
+    if users:
+        slice_record = users[0].get('slice_record', {})
 
     conn = getEucaConnection()
     if not conn:
-        print >>sys.stderr, 'Error: Cannot create a connection to Eucalyptus'
+        logger.error('Cannot create a connection to Eucalyptus')
         return ""
 
     # Validate RSpec
     schemaXML = ET.parse(EUCALYPTUS_RSPEC_SCHEMA)
     rspecValidator = ET.RelaxNG(schemaXML)
     rspecXML = ET.XML(xml)
+    for network in rspecXML.iterfind("./network"):
+        if network.get('name') != cloud['name']:
+            # Throw away everything except my own RSpec
+            # sfa_logger().error("CreateSliver: deleting %s from rspec"%network.get('id'))
+            network.getparent().remove(network)
     if not rspecValidator(rspecXML):
         error = rspecValidator.error_log.last_error
         message = '%s (line %s)' % (error.message, error.line) 
-        # XXX: InvalidRSpec is new. Currently, I am not working with Trunk code.
-        #raise InvalidRSpec(message)
-        raise Exception(message)
+        raise InvalidRSpec(message)
+
+    """
+    Create the sliver[s] (slice) at this aggregate.
+    Verify HRN and initialize the slice record in PLC if necessary.
+    """
+
+    # ensure site record exists
+    site = slices.verify_site(hrn, slice_record, peer, sfa_peer)
+    # ensure slice record exists
+    slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)
+    # ensure person records exists
+    persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)
 
     # Get the slice from db or create one.
     s = Slice.select(Slice.q.slice_hrn == hrn).getOne(None)
@@ -521,24 +575,31 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     pendingRmInst = []
     for sliceInst in s.instances:
         pendingRmInst.append(sliceInst.instance_id)
-    existingInstGroup = rspecXML.findall('.//euca_instances')
+    existingInstGroup = rspecXML.findall(".//euca_instances")
     for instGroup in existingInstGroup:
         for existingInst in instGroup:
             if existingInst.get('id') in pendingRmInst:
                 pendingRmInst.remove(existingInst.get('id'))
     for inst in pendingRmInst:
-        print >>sys.stderr, 'Instance %s will be terminated' % inst
         dbInst = EucaInstance.select(EucaInstance.q.instance_id == inst).getOne(None)
-        dbInst.destroySelf()
-    conn.terminate_instances(pendingRmInst)
+        if dbInst.meta.state != 'deleted':
+            logger.debug('Instance %s will be terminated' % inst)
+            # Terminate instances one at a time for robustness
+            conn.terminate_instances([inst])
+            # Only change the state but do not remove the entry from the DB.
+            dbInst.meta.state = 'deleted'
+            #dbInst.destroySelf()
 
     # Process new instance requests
-    requests = rspecXML.findall('.//request')
+    requests = rspecXML.findall(".//request")
     if requests:
         # Get all the public keys associate with slice.
-        pubKeys = getKeysForSlice(s.slice_hrn)
-        print >>sys.stderr, "Passing the following keys to the instance:\n%s" % pubKeys
-        sys.stderr.flush()
+        keys = []
+        for user in users:
+            keys += user['keys']
+            logger.debug("Keys: %s" % user['keys'])
+        pubKeys = '\n'.join(keys)
+        logger.debug('Passing the following keys to the instance:\n%s' % pubKeys)
     for req in requests:
         vmTypeElement = req.getparent()
         instType = vmTypeElement.get('name')
@@ -546,7 +607,7 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
         
         bundleName = req.find('bundle').text
         if not cloud['imageBundles'][bundleName]:
-            print >>sys.stderr, 'Cannot find bundle %s' % bundleName
+            logger.error('Cannot find bundle %s' % bundleName)
         bundleInfo = cloud['imageBundles'][bundleName]
         instKernel  = bundleInfo['kernelID']
         instDiskImg = bundleInfo['imageID']
@@ -555,18 +616,112 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
 
         # Create the instances
         for i in range(0, numInst):
-            eucaInst = EucaInstance(slice = s, 
-                                    kernel_id = instKernel,
-                                    image_id = instDiskImg,
+            eucaInst = EucaInstance(slice      = s,
+                                    kernel_id  = instKernel,
+                                    image_id   = instDiskImg,
                                     ramdisk_id = instRamDisk,
-                                    key_pair = instKey,
-                                    inst_type = instType)
+                                    key_pair   = instKey,
+                                    inst_type  = instType,
+                                    meta       = Meta(start_time=datetime.datetime.now()))
             eucaInst.reserveInstance(conn, pubKeys)
 
     # xxx - should return altered rspec 
     # with enough data for the client to understand what's happened
     return xml
 
+##
+# Return information on the IP addresses bound to each slice's instances
+#
+def dumpInstanceInfo():
+    logger = logging.getLogger('EucaMeta')
+    outdir = "/var/www/html/euca/"
+    outfile = outdir + "instances.txt"
+
+    try:
+        os.makedirs(outdir)
+    except OSError, e:
+        if e.errno != errno.EEXIST:
+            raise
+
+    dbResults = Meta.select(
+        AND(Meta.q.pri_addr != None,
+            Meta.q.state    == 'running')
+        )
+    dbResults = list(dbResults)
+    f = open(outfile, "w")
+    for r in dbResults:
+        instId = r.instance.instance_id
+        ipaddr = r.pri_addr
+        hrn = r.instance.slice.slice_hrn
+        logger.debug('[dumpInstanceInfo] %s %s %s' % (instId, ipaddr, hrn))
+        f.write("%s %s %s\n" % (instId, ipaddr, hrn))
+    f.close()
+
+##
+# A separate process that will update the meta data.
+#
+def updateMeta():
+    logger = logging.getLogger('EucaMeta')
+    fileHandler = logging.FileHandler('/var/log/euca_meta.log')
+    fileHandler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+    logger.addHandler(fileHandler)
+    fileHandler.setLevel(logging.DEBUG)
+    logger.setLevel(logging.DEBUG)
+
+    while True:
+        sleep(30)
+
+        # Get IDs of the instances that don't have IPs yet.
+        dbResults = Meta.select(
+                      AND(Meta.q.pri_addr == None,
+                          Meta.q.state    != 'deleted')
+                    )
+        dbResults = list(dbResults)
+        logger.debug('[update process] dbResults: %s' % dbResults)
+        instids = []
+        for r in dbResults:
+            if not r.instance:
+                continue
+            instids.append(r.instance.instance_id)
+        logger.debug('[update process] Instance Id: %s' % ', '.join(instids))
+
+        # Get instance information from Eucalyptus
+        conn = getEucaConnection()
+        vmInstances = []
+        reservations = conn.get_all_instances(instids)
+        for reservation in reservations:
+            vmInstances += reservation.instances
+
+        # Check the IPs
+        instIPs = [ {'id':i.id, 'pri_addr':i.private_dns_name, 'pub_addr':i.public_dns_name}
+                    for i in vmInstances if i.private_dns_name != '0.0.0.0' ]
+        logger.debug('[update process] IP dict: %s' % str(instIPs))
+
+        # Update the local DB
+        for ipData in instIPs:
+            dbInst = EucaInstance.select(EucaInstance.q.instance_id == ipData['id']).getOne(None)
+            if not dbInst:
+                logger.info('[update process] Could not find %s in DB' % ipData['id'])
+                continue
+            dbInst.meta.pri_addr = ipData['pri_addr']
+            dbInst.meta.pub_addr = ipData['pub_addr']
+            dbInst.meta.state    = 'running'
+
+        dumpInstanceInfo()
+
+def GetVersion(api):
+    xrn=Xrn(api.hrn)
+    request_rspec_versions = [dict(sfa_rspec_version)]
+    ad_rspec_versions = [dict(sfa_rspec_version)]
+    version_more = {'interface':'aggregate',
+                    'testbed':'myplc',
+                    'hrn':xrn.get_hrn(),
+                    'request_rspec_versions': request_rspec_versions,
+                    'ad_rspec_versions': ad_rspec_versions,
+                    'default_ad_rspec': dict(sfa_rspec_version)
+                    }
+    return version_core(version_more)
+
 def main():
     init_server()
 
@@ -577,7 +732,11 @@ def main():
 
     #rspec = ListResources('euca', 'planetcloud.pc.test', 'planetcloud.pc.marcoy', 'test_euca')
     #print rspec
-    print getKeysForSlice('gc.gc.test1')
+
+    server_key_file = '/var/lib/sfa/authorities/server.key'
+    server_cert_file = '/var/lib/sfa/authorities/server.cert'
+    api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file, interface='aggregate')
+    print getKeysForSlice(api, 'gc.gc.test1')
 
 if __name__ == "__main__":
     main()
index 0c374b4..d7d3776 100644 (file)
-from sfa.util.xrn import urn_to_hrn, hrn_to_urn, get_authority
-from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import RSpec
-from sfa.util.sfalogging import sfa_logger
-from sfa.util.config import Config
-from sfa.managers.aggregate_manager_pl import GetVersion, __get_registry_objects
-from sfa.plc.slices import Slices
-import os
-import time
-
-RSPEC_TMP_FILE_PREFIX = "/tmp/max_rspec"
-
-# execute shell command and return both exit code and text output
-def shell_execute(cmd, timeout):
-    pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')
-    pipe = os.popen(cmd + ' 2>&1', 'r')
-    text = ''
-    while timeout:
-        line = pipe.read()
-        text += line
-        time.sleep(1)
-        timeout = timeout-1
-    code = pipe.close()
-    if code is None: code = 0
-    if text[-1:] == '\n': text = text[:-1]
-    return code, text
-
-"""
- call AM API client with command like in the following example:
- cd aggregate_client; java -classpath AggregateWS-client-api.jar:lib/* \
-      net.geni.aggregate.client.examples.CreateSliceNetworkClient \
-      ./repo https://geni:8443/axis2/services/AggregateGENI \
-      ... params ...
-"""
-
-def call_am_apiclient(client_app, params, timeout):
-    (client_path, am_url) = Config().get_max_aggrMgr_info()
-    sys_cmd = "cd " + client_path + "; java -classpath AggregateWS-client-api.jar:lib/* net.geni.aggregate.client.examples." + client_app + " ./repo " + am_url + " " + ' '.join(params)
-    ret = shell_execute(sys_cmd, timeout)
-    sfa_logger().debug("shell_execute cmd: %s returns %s" % (sys_cmd, ret))
-# save request RSpec xml content to a tmp file
-def save_rspec_to_file(rspec):
-    path = RSPEC_TMP_FILE_PREFIX + "_" + time.strftime('%Y%m%dT%H:%M:%S', time.gmtime(time.time())) +".xml"
-    file = open(path, "w")
-    file.write(rspec)
-    file.close()
-    return path
-
-# get stripped down slice id/name plc:maxpl:xi_slice1 --> xi_slice1
-def get_short_slice_id(cred, hrn):
-    if hrn == None:
-        return None
-    slice_id = hrn[hrn.rfind('+')+1:]
-    if slice_id == None:
-        slice_id = hrn[hrn.rfind(':')+1:]
-    if slice_id == None:
-       return hrn
-       pass
-    return str(slice_id)
-
-# extract xml 
-def get_xml_by_tag(text, tag):
-    indx1 = text.find('<'+tag)
-    indx2 = text.find('/'+tag+'>')
-    xml = None
-    if indx1!=-1 and indx2>indx1:
-        xml = text[indx1:indx2+len(tag)+2]
-    return xml
-
-def prepare_slice(api, xrn, users):
-    reg_objects = __get_registry_objects(slice_xrn, creds, users)
-    (hrn, type) = urn_to_hrn(slice_xrn)
-    slices = Slices(api)
-    peer = slices.get_peer(hrn)
-    sfa_peer = slices.get_sfa_peer(hrn)
-    registry = api.registries[api.hrn]
-    credential = api.getCredential()
-    (site_id, remote_site_id) = slices.verify_site(registry, credential, hrn, peer, sfa_peer, reg_objects)
-    slices.verify_slice(registry, credential, hrn, site_id, remote_site_id, peer, sfa_peer, reg_objects)
-
-def create_slice(api, xrn, cred, rspec, users):
-    indx1 = rspec.find("<RSpec")
-    indx2 = rspec.find("</RSpec>")
-    if indx1 > -1 and indx2 > indx1:
-        rspec = rspec[indx1+len("<RSpec type=\"SFA\">"):indx2-1]
-    rspec_path = save_rspec_to_file(rspec)
-    prepare_slice(api, xrn, users)
-    (ret, output) = call_am_apiclient("CreateSliceNetworkClient", [rspec_path,], 3)
-    # parse output ?
-    rspec = "<RSpec type=\"SFA\"> Done! </RSpec>"
-def delete_slice(api, xrn, cred):
-    slice_id = get_short_slice_id(cred, xrn)
-    (ret, output) = call_am_apiclient("DeleteSliceNetworkClient", [slice_id,], 3)
-    # parse output ?
-def get_rspec(api, cred, options):
-    #geni_slice_urn: urn:publicid:IDN+plc:maxpl+slice+xi_rspec_test1
-    urn = options.get('geni_slice_urn')
-    slice_id = get_short_slice_id(cred, urn)
-    if slice_id == None:
-        (ret, output) = call_am_apiclient("GetResourceTopology", ['all', '\"\"'], 5)
-        (ret, output) = call_am_apiclient("GetResourceTopology", ['all', slice_id,], 5)
-    # parse output into rspec XML
-    if output.find("No resouce found") > 0:
-        rspec = "<RSpec type=\"SFA\"> <Fault>No resource found</Fault> </RSpec>"
-    else:
-        comp_rspec = get_xml_by_tag(output, 'computeResource')
-        sfa_logger().debug("#### computeResource %s" % comp_rspec)
-        topo_rspec = get_xml_by_tag(output, 'topology')
-        sfa_logger().debug("#### topology %s" % topo_rspec)
-        rspec = "<RSpec type=\"SFA\"> <network name=\"" + Config().get_interface_hrn() + "\">";
-        if comp_rspec != None:
-            rspec = rspec + get_xml_by_tag(output, 'computeResource')
-        if topo_rspec != None:
-            rspec = rspec + get_xml_by_tag(output, 'topology')
-        rspec = rspec + "</network> </RSpec>"
-
-    return (rspec)
-
-def start_slice(api, xrn, cred):
-    # service not supported
-    return None
-
-def stop_slice(api, xrn, cred):
-    # service not supported
-    return None
-
-def reset_slices(api, xrn):
-    # service not supported
-    return None
-
-"""
-Returns the request context required by sfatables. At some point, this mechanism should be changed
-to refer to "contexts", which is the information that sfatables is requesting. But for now, we just
-return the basic information needed in a dict.
-"""
-def fetch_context(slice_hrn, user_hrn, contexts):
-    base_context = {'sfa':{'user':{'hrn':user_hrn}}}
-    return base_context
-    api = SfaAPI()
-    create_slice(api, "plc.maxpl.test000", None, rspec_xml, None)
-
+from sfa.plc.slices import Slices\r
+from sfa.server.registry import Registries\r
+from sfa.util.xrn import urn_to_hrn, hrn_to_urn, get_authority, Xrn\r
+from sfa.util.plxrn import hrn_to_pl_slicename\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.faults import *\r
+from sfa.util.config import Config\r
+from sfa.util.sfatime import utcparse\r
+from sfa.util.callids import Callids\r
+from sfa.util.version import version_core\r
+from sfa.rspecs.rspec_version import RSpecVersion\r
+from sfa.rspecs.sfa_rspec import sfa_rspec_version\r
+from sfa.rspecs.rspec_parser import parse_rspec\r
+from sfa.managers.aggregate_manager_pl import __get_registry_objects, ListSlices\r
+import os\r
+import time\r
+import re\r
+\r
+RSPEC_TMP_FILE_PREFIX = "/tmp/max_rspec"\r
+\r
+# execute shell command and return both exit code and text output\r
+def shell_execute(cmd, timeout):\r
+    pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')\r
+    pipe = os.popen(cmd + ' 2>&1', 'r')\r
+    text = ''\r
+    while timeout:\r
+        line = pipe.read()\r
+        text += line\r
+        time.sleep(1)\r
+        timeout = timeout-1\r
+    code = pipe.close()\r
+    if code is None: code = 0\r
+    if text[-1:] == '\n': text = text[:-1]\r
+    return code, text\r
+\r
+"""\r
+ call AM API client with command like in the following example:\r
+ cd aggregate_client; java -classpath AggregateWS-client-api.jar:lib/* \\r
+      net.geni.aggregate.client.examples.CreateSliceNetworkClient \\r
+      ./repo https://geni:8443/axis2/services/AggregateGENI \\r
+      ... params ...\r
+"""\r
+\r
+def call_am_apiclient(client_app, params, timeout):\r
+    (client_path, am_url) = Config().get_max_aggrMgr_info()\r
+    sys_cmd = "cd " + client_path + "; java -classpath AggregateWS-client-api.jar:lib/* net.geni.aggregate.client.examples." + client_app + " ./repo " + am_url + " " + ' '.join(params)\r
+    ret = shell_execute(sys_cmd, timeout)\r
+    logger.debug("shell_execute cmd: %s returns %s" % (sys_cmd, ret))\r
+    return ret\r
+\r
+# save request RSpec xml content to a tmp file\r
+def save_rspec_to_file(rspec):\r
+    path = RSPEC_TMP_FILE_PREFIX + "_" + time.strftime('%Y%m%dT%H:%M:%S', time.gmtime(time.time())) +".xml"\r
+    file = open(path, "w")\r
+    file.write(rspec)\r
+    file.close()\r
+    return path\r
+\r
+# get stripped down slice id/name plc.maxpl.xislice1 --> maxpl_xislice1\r
+def get_plc_slice_id(cred, xrn):\r
+    (hrn, type) = urn_to_hrn(xrn)\r
+    slice_id = hrn.find(':')\r
+    sep = '.'\r
+    if hrn.find(':') != -1:\r
+        sep=':'\r
+    elif hrn.find('+') != -1:\r
+        sep='+'\r
+    else:\r
+        sep='.'\r
+    slice_id = hrn.split(sep)[-2] + '_' + hrn.split(sep)[-1]\r
+    return slice_id\r
+\r
+# extract xml \r
+def get_xml_by_tag(text, tag):\r
+    indx1 = text.find('<'+tag)\r
+    indx2 = text.find('/'+tag+'>')\r
+    xml = None\r
+    if indx1!=-1 and indx2>indx1:\r
+        xml = text[indx1:indx2+len(tag)+2]\r
+    return xml\r
+\r
+def prepare_slice(api, slice_xrn, creds, users):\r
+    reg_objects = __get_registry_objects(slice_xrn, creds, users)\r
+    (hrn, type) = urn_to_hrn(slice_xrn)\r
+    slices = Slices(api)\r
+    peer = slices.get_peer(hrn)\r
+    sfa_peer = slices.get_sfa_peer(hrn)\r
+    slice_record=None\r
+    if users:\r
+        slice_record = users[0].get('slice_record', {})\r
+    registry = api.registries[api.hrn]\r
+    credential = api.getCredential()\r
+    # ensure site record exists\r
+    site = slices.verify_site(hrn, slice_record, peer, sfa_peer)\r
+    # ensure slice record exists\r
+    slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)\r
+    # ensure person records exists\r
+    persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)\r
+\r
+def parse_resources(text, slice_xrn):\r
+    resources = []\r
+    urn = hrn_to_urn(slice_xrn, 'sliver')\r
+    plc_slice = re.search("Slice Status => ([^\n]+)", text)\r
+    if plc_slice.group(1) != 'NONE':\r
+        res = {}\r
+        res['geni_urn'] = urn + '_plc_slice'\r
+        res['geni_error'] = ''\r
+        res['geni_status'] = 'unknown'\r
+        if plc_slice.group(1) == 'CREATED':\r
+            res['geni_status'] = 'ready'\r
+        resources.append(res)\r
+    vlans = re.findall("GRI => ([^\n]+)\n\t  Status => ([^\n]+)", text)\r
+    for vlan in vlans:\r
+        res = {}\r
+        res['geni_error'] = ''\r
+        res['geni_urn'] = urn + '_vlan_' + vlan[0]\r
+        if vlan[1] == 'ACTIVE':\r
+            res['geni_status'] = 'ready'\r
+        elif vlan[1] == 'FAILED':\r
+            res['geni_status'] = 'failed'\r
+        else:\r
+            res['geni_status'] = 'configuring'\r
+        resources.append(res)\r
+    return resources\r
+\r
+def slice_status(api, slice_xrn, creds):\r
+    urn = hrn_to_urn(slice_xrn, 'slice')\r
+    result = {}\r
+    top_level_status = 'unknown'\r
+    slice_id = get_plc_slice_id(creds, urn)\r
+    (ret, output) = call_am_apiclient("QuerySliceNetworkClient", [slice_id,], 5)\r
+    # parse output into rspec XML\r
+    if output.find("Unkown Rspec:") > 0:\r
+        top_level_staus = 'failed'\r
+        result['geni_resources'] = ''\r
+    else:\r
+        has_failure = 0\r
+        all_active = 0\r
+        if output.find("Status => FAILED") > 0:\r
+            top_level_staus = 'failed'\r
+        elif (    output.find("Status => ACCEPTED") > 0 or output.find("Status => PENDING") > 0\r
+               or output.find("Status => INSETUP") > 0 or output.find("Status => INCREATE") > 0\r
+             ):\r
+            top_level_status = 'configuring'\r
+        else:\r
+            top_level_status = 'ready'\r
+        result['geni_resources'] = parse_resources(output, slice_xrn)\r
+    result['geni_urn'] = urn\r
+    result['geni_status'] = top_level_status\r
+    return result\r
+\r
+def create_slice(api, xrn, cred, rspec, users):\r
+    indx1 = rspec.find("<RSpec")\r
+    indx2 = rspec.find("</RSpec>")\r
+    if indx1 > -1 and indx2 > indx1:\r
+        rspec = rspec[indx1+len("<RSpec type=\"SFA\">"):indx2-1]\r
+    rspec_path = save_rspec_to_file(rspec)\r
+    prepare_slice(api, xrn, cred, users)\r
+    slice_id = get_plc_slice_id(cred, xrn)\r
+    sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" +slice_id+ "/g\" " + rspec_path + ";sed -i \"s/:rspec=[^:'<\\\" ]*/:rspec=" +slice_id+ "/g\" " + rspec_path\r
+    ret = shell_execute(sys_cmd, 1)\r
+    sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" + rspec_path + "/g\""\r
+    ret = shell_execute(sys_cmd, 1)\r
+    (ret, output) = call_am_apiclient("CreateSliceNetworkClient", [rspec_path,], 3)\r
+    # parse output ?\r
+    rspec = "<RSpec type=\"SFA\"> Done! </RSpec>"\r
+    return True\r
+\r
+def delete_slice(api, xrn, cred):\r
+    slice_id = get_plc_slice_id(cred, xrn)\r
+    (ret, output) = call_am_apiclient("DeleteSliceNetworkClient", [slice_id,], 3)\r
+    # parse output ?\r
+    return 1\r
+\r
+\r
+def get_rspec(api, cred, slice_urn):\r
+    logger.debug("#### called max-get_rspec")\r
+    #geni_slice_urn: urn:publicid:IDN+plc:maxpl+slice+xi_rspec_test1\r
+    if slice_urn == None:\r
+        (ret, output) = call_am_apiclient("GetResourceTopology", ['all', '\"\"'], 5)\r
+    else:\r
+        slice_id = get_plc_slice_id(cred, slice_urn)\r
+        (ret, output) = call_am_apiclient("GetResourceTopology", ['all', slice_id,], 5)\r
+    # parse output into rspec XML\r
+    if output.find("No resouce found") > 0:\r
+        rspec = "<RSpec type=\"SFA\"> <Fault>No resource found</Fault> </RSpec>"\r
+    else:\r
+        comp_rspec = get_xml_by_tag(output, 'computeResource')\r
+        logger.debug("#### computeResource %s" % comp_rspec)\r
+        topo_rspec = get_xml_by_tag(output, 'topology')\r
+        logger.debug("#### topology %s" % topo_rspec)\r
+        rspec = "<RSpec type=\"SFA\"> <network name=\"" + Config().get_interface_hrn() + "\">";\r
+        if comp_rspec != None:\r
+            rspec = rspec + get_xml_by_tag(output, 'computeResource')\r
+        if topo_rspec != None:\r
+            rspec = rspec + get_xml_by_tag(output, 'topology')\r
+        rspec = rspec + "</network> </RSpec>"\r
+    return (rspec)\r
+\r
+def start_slice(api, xrn, cred):\r
+    # service not supported\r
+    return None\r
+\r
+def stop_slice(api, xrn, cred):\r
+    # service not supported\r
+    return None\r
+\r
+def reset_slices(api, xrn):\r
+    # service not supported\r
+    return None\r
+\r
+"""\r
+    GENI AM API Methods\r
+"""\r
+\r
+def GetVersion(api):\r
+    xrn=Xrn(api.hrn)\r
+    request_rspec_versions = [dict(sfa_rspec_version)]\r
+    ad_rspec_versions = [dict(sfa_rspec_version)]\r
+    #TODO: MAX-AM specific\r
+    version_more = {'interface':'aggregate',\r
+                    'testbed':'myplc',\r
+                    'hrn':xrn.get_hrn(),\r
+                    'request_rspec_versions': request_rspec_versions,\r
+                    'ad_rspec_versions': ad_rspec_versions,\r
+                    'default_ad_rspec': dict(sfa_rspec_version)\r
+                    }\r
+    return version_core(version_more)\r
+\r
+def SliverStatus(api, slice_xrn, creds, call_id):\r
+    if Callids().already_handled(call_id): return {}\r
+    return slice_status(api, slice_xrn, creds)\r
+\r
+def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):\r
+    if Callids().already_handled(call_id): return ""\r
+    #TODO: create real CreateSliver response rspec\r
+    ret = create_slice(api, slice_xrn, creds, rspec_string, users)\r
+    if ret:\r
+        return get_rspec(api, creds, slice_xrn)\r
+    else:\r
+        return "<?xml version=\"1.0\" ?> <RSpec type=\"SFA\"> Error! </RSpec>"\r
+\r
+def DeleteSliver(api, xrn, creds, call_id):\r
+    if Callids().already_handled(call_id): return ""\r
+    return delete_slice(api, xrn, creds)\r
+\r
+# no caching\r
+def ListResources(api, creds, options,call_id):\r
+    if Callids().already_handled(call_id): return ""\r
+    # version_string = "rspec_%s" % (rspec_version.get_version_name())\r
+    slice_urn = options.get('geni_slice_urn')\r
+    return get_rspec(api, creds, slice_urn)\r
+\r
+"""\r
+Returns the request context required by sfatables. At some point, this mechanism should be changed\r
+to refer to "contexts", which is the information that sfatables is requesting. But for now, we just\r
+return the basic information needed in a dict.\r
+"""\r
+def fetch_context(slice_hrn, user_hrn, contexts):\r
+    base_context = {'sfa':{'user':{'hrn':user_hrn}}}\r
+    return base_context\r
+    api = SfaAPI()\r
+    create_slice(api, "plc.maxpl.test000", None, rspec_xml, None)\r
+\r
index 1edc90b..a804a65 100755 (executable)
@@ -10,7 +10,6 @@ import struct
 
 from sfa.util.faults import *
 from sfa.util.xrn import urn_to_hrn
-from sfa.util.rspec import RSpec
 from sfa.server.registry import Registries
 from sfa.util.config import Config
 from sfa.plc.nodes import *
index 66fb5cc..eadcbfd 100644 (file)
@@ -6,7 +6,6 @@ import sys
 from types import StringTypes
 from sfa.util.xrn import urn_to_hrn, Xrn
 from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import *
 from sfa.util.specdict import *
 from sfa.util.faults import *
 from sfa.util.record import SfaRecord
@@ -17,6 +16,8 @@ from sfa.server.registry import Registries
 from sfa.plc.slices import Slices
 import sfa.plc.peers as peers
 from sfa.managers.vini.vini_network import *
+from sfa.plc.vini_aggregate import ViniAggregate
+from sfa.rspecs.version_manager import VersionManager
 from sfa.plc.api import SfaAPI
 from sfa.plc.slices import *
 from sfa.managers.aggregate_manager_pl import __get_registry_objects, __get_hostnames
@@ -96,20 +97,22 @@ def ListResources(api, creds, options,call_id):
     # get slice's hrn from options
     xrn = options.get('geni_slice_urn', '')
     hrn, type = urn_to_hrn(xrn)
+
+    version_manager = VersionManager()
+    # get the rspec's return format from options
+    rspec_version = version_manager.get_version(options.get('rspec_version'))
+    version_string = "rspec_%s" % (rspec_version.to_string())
     
     # look in cache first
     if api.cache and not xrn:
-        rspec = api.cache.get('nodes')
+        rspec = api.cache.get(version_string)
         if rspec:
+            api.logger.info("aggregate.ListResources: returning cached value for hrn %s"%hrn)
             return rspec
 
-    network = ViniNetwork(api)
-    if (hrn):
-        if network.get_slice(api, hrn):
-            network.addSlice()
-
-    rspec =  network.toxml()
-
+    aggregate = ViniAggregate(api, options) 
+    rspec =  aggregate.get_rspec(slice_xrn=xrn, version=rspec_version)
+           
     # cache the result
     if api.cache and not xrn:
         api.cache.add('nodes', rspec)
index ba9758c..ead2a36 100644 (file)
@@ -1,10 +1,12 @@
 start = RSpec
 RSpec = element RSpec {
+    attribute expires { xsd:NMTOKEN },
+    attribute generated { xsd:NMTOKEN },
     attribute type { xsd:NMTOKEN },
-    cloud
+    network
 }
-cloud = element cloud {
-    attribute id { xsd:NMTOKEN },
+network = element network {
+    attribute name { xsd:NMTOKEN },
     user_info?,
     ipv4,
     bundles,
index 51d23c6..9f40a5e 100644 (file)
@@ -5,15 +5,23 @@
   </start>
   <define name="RSpec">
     <element name="RSpec">
+      <attribute name="expires">
+        <data type="NMTOKEN"/>
+      </attribute>
+      <attribute name="generated">
+        <data type="NMTOKEN"/>
+      </attribute>
       <attribute name="type">
         <data type="NMTOKEN"/>
       </attribute>
-      <ref name="cloud"/>
+      <oneOrMore>
+        <ref name="network"/>
+      </oneOrMore>
     </element>
   </define>
-  <define name="cloud">
-    <element name="cloud">
-      <attribute name="id">
+  <define name="network">
+    <element name="network">
+      <attribute name="name">
         <data type="NMTOKEN"/>
       </attribute>
       <optional>
index 22162b1..cca6190 100644 (file)
-<RSpec type="eucalyptus">
-  <cloud id="OpenCirrus">
-    <ipv4>198.55.32.86</ipv4>
+<?xml version="1.0"?>
+<RSpec expires="2011-09-26T21:03:16Z" generated="2011-09-26T20:03:16Z" type="SFA">
+  <statistics call="ListResources">
+    <aggregate status="success" name="genicloud.hplabs" elapsed="0.697860002518"/>
+    <aggregate status="success" name="genicloud.ucsd" elapsed="0.901086091995"/>
+  </statistics>
+  <network name="HpLabs-Cloud">
+    <ipv4>198.55.32.75</ipv4>
     <bundles>
-      <bundle id="fc12" />
-      <bundle id="f12-plab" />
-      <bundle id="f12-planetlab" />
-      <bundle id="fc11" />
+      <bundle id="ubuntu904"/>
     </bundles>
-    <cluster id="euca-oc">
-      <ipv4>198.55.32.86</ipv4>
+    <cluster id="hplabs">
+      <ipv4>198.55.32.75</ipv4>
       <vm_types>
         <vm_type name="m1.small">
-          <free_slots>41</free_slots>
-          <max_instances>44</max_instances>
+          <free_slots>39</free_slots>
+          <max_instances>40</max_instances>
           <cores>1</cores>
-          <memory unit="MB">192</memory>
+          <memory unit="MB">128</memory>
           <disk_space unit="GB">2</disk_space>
         </vm_type>
         <vm_type name="c1.medium">
-          <free_slots>41</free_slots>
-          <max_instances>44</max_instances>
+          <free_slots>39</free_slots>
+          <max_instances>40</max_instances>
           <cores>1</cores>
           <memory unit="MB">256</memory>
           <disk_space unit="GB">5</disk_space>
-          <request>
+           <request>
             <instances>1</instances>
-            <bundle>f12-plab</bundle>
+            <bundle>ubuntu904</bundle>
           </request>
-        </vm_type>
+       </vm_type>
         <vm_type name="m1.large">
           <free_slots>19</free_slots>
-          <max_instances>22</max_instances>
+          <max_instances>20</max_instances>
           <cores>2</cores>
           <memory unit="MB">512</memory>
           <disk_space unit="GB">10</disk_space>
         </vm_type>
         <vm_type name="m1.xlarge">
           <free_slots>19</free_slots>
-          <max_instances>22</max_instances>
+          <max_instances>20</max_instances>
+          <cores>2</cores>
+          <memory unit="MB">1024</memory>
+          <disk_space unit="GB">20</disk_space>
+        </vm_type>
+        <vm_type name="c1.xlarge">
+          <free_slots>9</free_slots>
+          <max_instances>10</max_instances>
+          <cores>4</cores>
+          <memory unit="MB">2048</memory>
+          <disk_space unit="GB">20</disk_space>
+        </vm_type>
+      </vm_types>
+    </cluster>
+  </network>
+  <network name="UCSD-Cloud">
+    <ipv4>169.228.66.144</ipv4>
+    <bundles>
+      <bundle id="ubuntu904"/>
+    </bundles>
+    <cluster id="ucsd">
+      <ipv4>169.228.66.144</ipv4>
+      <vm_types>
+        <vm_type name="m1.small">
+          <free_slots>15</free_slots>
+          <max_instances>16</max_instances>
+          <cores>1</cores>
+          <memory unit="MB">128</memory>
+          <disk_space unit="GB">2</disk_space>
+        </vm_type>
+        <vm_type name="c1.medium">
+          <free_slots>15</free_slots>
+          <max_instances>16</max_instances>
+          <cores>1</cores>
+          <memory unit="MB">256</memory>
+          <disk_space unit="GB">5</disk_space>
+        </vm_type>
+        <vm_type name="m1.large">
+          <free_slots>7</free_slots>
+          <max_instances>8</max_instances>
+          <cores>2</cores>
+          <memory unit="MB">512</memory>
+          <disk_space unit="GB">10</disk_space>
+        </vm_type>
+        <vm_type name="m1.xlarge">
+          <free_slots>7</free_slots>
+          <max_instances>8</max_instances>
           <cores>2</cores>
           <memory unit="MB">1024</memory>
           <disk_space unit="GB">20</disk_space>
         </vm_type>
         <vm_type name="c1.xlarge">
-          <free_slots>8</free_slots>
-          <max_instances>11</max_instances>
+          <free_slots>3</free_slots>
+          <max_instances>4</max_instances>
           <cores>4</cores>
           <memory unit="MB">2048</memory>
           <disk_space unit="GB">20</disk_space>
         </vm_type>
       </vm_types>
     </cluster>
-  </cloud>
+  </network>
 </RSpec>
diff --git a/sfa/managers/import_manager.py b/sfa/managers/import_manager.py
new file mode 100644 (file)
index 0000000..f5f30c4
--- /dev/null
@@ -0,0 +1,26 @@
+from sfa.util.sfalogging import logger
+
+def import_manager(kind, type):
+    """
+    kind expected in ['registry', 'aggregate', 'slice', 'component']
+    type is e.g. 'pl' or 'max' or whatever
+    """
+    basepath = 'sfa.managers'
+    qualified = "%s.%s_manager_%s"%(basepath,kind,type)
+    generic = "%s.%s_manager"%(basepath,kind)
+
+    message="import_manager for kind=%s and type=%s"%(kind,type)
+    try: 
+        manager = __import__(qualified, fromlist=[basepath])
+        logger.info ("%s: loaded %s"%(message,qualified))
+    except:
+        try:
+            manager = __import__ (generic, fromlist=[basepath])
+            if type != 'pl' : 
+                logger.warn ("%s: using generic with type!='pl'"%(message))
+            logger.info("%s: loaded %s"%(message,generic))
+        except:
+            manager=None
+            logger.log_exc("%s: unable to import either %s or %s"%(message,qualified,generic))
+    return manager
+    
similarity index 97%
rename from sfa/managers/registry_manager_pl.py
rename to sfa/managers/registry_manager.py
index 8bec1f6..6052eee 100644 (file)
@@ -174,6 +174,18 @@ def list(api, xrn, origin_hrn=None):
     return records
 
 
+def create_gid(api, xrn, cert):
+    # get the authority
+    authority = Xrn(xrn=xrn).get_authority_hrn()
+    auth_info = api.auth.get_auth_info(authority)
+    if not cert:
+        pkey = Keypair(create=True)
+    else:
+        certificate = Certificate(string=cert)
+        pkey = certificate.get_pubkey()    
+    gid = api.auth.hierarchy.create_gid(xrn, create_uuid(), pkey) 
+    return gid.save_to_string(save_parents=True)
+    
 def register(api, record):
 
     hrn, type = record['hrn'], record['type']
@@ -192,7 +204,6 @@ def register(api, record):
     record['authority'] = get_authority(record['hrn'])
     type = record['type']
     hrn = record['hrn']
-    api.auth.verify_object_permission(hrn)
     auth_info = api.auth.get_auth_info(record['authority'])
     pub_key = None
     # make sure record has a gid
@@ -288,7 +299,6 @@ def update(api, record_dict):
     type = new_record['type']
     hrn = new_record['hrn']
     urn = hrn_to_urn(hrn,type)
-    api.auth.verify_object_permission(hrn)
     table = SfaTable()
     # make sure the record exists
     records = table.findObjects({'type': type, 'hrn': hrn})
similarity index 54%
rename from sfa/managers/slice_manager_pl.py
rename to sfa/managers/slice_manager.py
index 1077dcf..3175310 100644 (file)
@@ -1,4 +1,4 @@
-# 
+#
 import sys
 import time,datetime
 from StringIO import StringIO
@@ -7,21 +7,17 @@ from copy import deepcopy
 from copy import copy
 from lxml import etree
 
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 from sfa.util.rspecHelper import merge_rspecs
 from sfa.util.xrn import Xrn, urn_to_hrn, hrn_to_urn
 from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import *
 from sfa.util.specdict import *
 from sfa.util.faults import *
 from sfa.util.record import SfaRecord
-from sfa.rspecs.pg_rspec import PGRSpec
-from sfa.rspecs.sfa_rspec import SfaRSpec
 from sfa.rspecs.rspec_converter import RSpecConverter
-from sfa.rspecs.rspec_parser import parse_rspec    
-from sfa.rspecs.rspec_version import RSpecVersion
-from sfa.rspecs.sfa_rspec import sfa_rspec_version
-from sfa.rspecs.pg_rspec import pg_rspec_ad_version, pg_rspec_request_version   
+from sfa.client.client_helper import sfa_to_pg_users_arg
+from sfa.rspecs.version_manager import VersionManager
+from sfa.rspecs.rspec import RSpec 
 from sfa.util.policy import Policy
 from sfa.util.prefixTree import prefixTree
 from sfa.util.sfaticket import *
@@ -32,29 +28,55 @@ import sfa.plc.peers as peers
 from sfa.util.version import version_core
 from sfa.util.callids import Callids
 
+
+def _call_id_supported(api, server):
+    """
+    Returns true if server support the optional call_id arg, false otherwise.
+    """
+    server_version = api.get_cached_server_version(server)
+
+    if 'sfa' in server_version:
+        code_tag = server_version['code_tag']
+        code_tag_parts = code_tag.split("-")
+
+        version_parts = code_tag_parts[0].split(".")
+        major, minor = version_parts[0:2]
+        rev = code_tag_parts[1]
+        if int(major) > 1:
+            if int(minor) > 0 or int(rev) > 20:
+                return True
+    return False
+
 # we have specialized xmlrpclib.ServerProxy to remember the input url
 # OTOH it's not clear if we're only dealing with XMLRPCServerProxy instances
 def get_serverproxy_url (server):
     try:
-        return server.url
+        return server.get_url()
     except:
-        sfa_logger().warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
-        return server._ServerProxy__host + server._ServerProxy__handler 
+        logger.warning("GetVersion, falling back to xmlrpclib.ServerProxy internals")
+        return server._ServerProxy__host + server._ServerProxy__handler
 
 def GetVersion(api):
     # peers explicitly in aggregates.xml
-    peers =dict ([ (peername,get_serverproxy_url(v)) for (peername,v) in api.aggregates.iteritems() 
+    peers =dict ([ (peername,get_serverproxy_url(v)) for (peername,v) in api.aggregates.iteritems()
                    if peername != api.hrn])
-    xrn=Xrn (api.hrn)
-    request_rspec_versions = [dict(pg_rspec_request_version), dict(sfa_rspec_version)]
-    ad_rspec_versions = [dict(pg_rspec_ad_version), dict(sfa_rspec_version)]
+    version_manager = VersionManager()
+    ad_rspec_versions = []
+    request_rspec_versions = []
+    for rspec_version in version_manager.versions:
+        if rspec_version.content_type in ['*', 'ad']:
+            ad_rspec_versions.append(rspec_version.to_dict())
+        if rspec_version.content_type in ['*', 'request']:
+            request_rspec_versions.append(rspec_version.to_dict())
+    default_rspec_version = version_manager.get_version("sfa 1").to_dict()
+    xrn=Xrn(api.hrn, 'authority+sa')
     version_more = {'interface':'slicemgr',
                     'hrn' : xrn.get_hrn(),
                     'urn' : xrn.get_urn(),
                     'peers': peers,
                     'request_rspec_versions': request_rspec_versions,
                     'ad_rspec_versions': ad_rspec_versions,
-                    'default_ad_rspec': dict(sfa_rspec_version)
+                    'default_ad_rspec': default_rspec_version
                     }
     sm_version=version_core(version_more)
     # local aggregate if present needs to have localhost resolved
@@ -63,43 +85,151 @@ def GetVersion(api):
         sm_version['peers'][api.hrn]=local_am_url.replace('localhost',sm_version['hostname'])
     return sm_version
 
-def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
+def drop_slicemgr_stats(rspec):
+    try:
+        stats_elements = rspec.xml.xpath('//statistics')
+        for node in stats_elements:
+            node.getparent().remove(node)
+    except Exception, e:
+        api.logger.warn("drop_slicemgr_stats failed: %s " % (str(e)))
 
-    def _CreateSliver(aggregate, xrn, credential, rspec, users, call_id):
-            # Need to call GetVersion at an aggregate to determine the supported 
-            # rspec type/format beofre calling CreateSliver at an Aggregate. 
-            # The Aggregate's verion info is cached 
-            server = api.aggregates[aggregate]
-            # get cached aggregate version
-            aggregate_version_key = 'version_'+ aggregate
-            aggregate_version = api.cache.get(aggregate_version_key)
-            if not aggregate_version:
-                # get current aggregate version anc cache it for 24 hours
-                aggregate_version = server.GetVersion()
-                api.cache.add(aggregate_version_key, aggregate_version, 60 * 60 * 24)
-                
-            if 'sfa' not in aggregate_version and 'geni_api' in aggregate_version:
-                # sfa aggregtes support both sfa and pg rspecs, no need to convert
-                # if aggregate supports sfa rspecs. othewise convert to pg rspec
-                rspec = RSpecConverter.to_pg_rspec(rspec)
+def add_slicemgr_stat(rspec, callname, aggname, elapsed, status):
+    try:
+        stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname)
+        if stats_tags:
+            stats_tag = stats_tags[0]
+        else:
+            stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname)
 
-            return server.CreateSliver(xrn, credential, rspec, users, call_id)
-                
+        etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status))
+    except Exception, e:
+        api.logger.warn("add_slicemgr_stat failed on  %s: %s" %(aggname, str(e)))
+
+def ListResources(api, creds, options, call_id):
+    version_manager = VersionManager()
+    def _ListResources(aggregate, server, credential, opts, call_id):
+
+        my_opts = copy(opts)
+        args = [credential, my_opts]
+        tStart = time.time()
+        try:
+            if _call_id_supported(api, server):
+                args.append(call_id)
+            version = api.get_cached_server_version(server)
+            # force ProtoGENI aggregates to give us a v2 RSpec
+            if 'sfa' not in version.keys():
+                my_opts['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()
+            rspec = server.ListResources(*args)
+            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
+        except Exception, e:
+            api.logger.log_exc("ListResources failed at %s" %(server.url))
+            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}
 
     if Callids().already_handled(call_id): return ""
 
+    # get slice's hrn from options
+    xrn = options.get('geni_slice_urn', '')
+    (hrn, type) = urn_to_hrn(xrn)
+    if 'geni_compressed' in options:
+        del(options['geni_compressed'])
+
+    # get the rspec's return format from options
+    rspec_version = version_manager.get_version(options.get('rspec_version'))
+    version_string = "rspec_%s" % (rspec_version.to_string())
+
+    # look in cache first
+    if caching and api.cache and not xrn:
+        rspec =  api.cache.get(version_string)
+        if rspec:
+            return rspec
+
+    # get the callers hrn
+    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
+    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
+
+    # attempt to use delegated credential first
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        # prevent infinite loop. Dont send request back to caller
+        # unless the caller is the aggregate's SM
+        if caller_hrn == aggregate and aggregate != api.hrn:
+            continue
+
+        # get the rspec from the aggregate
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
+        threads.run(_ListResources, aggregate, server, [cred], options, call_id)
+
+
+    results = threads.get_results()
+    rspec_version = version_manager.get_version(options.get('rspec_version'))
+    if xrn:    
+        result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'manifest')
+    else: 
+        result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'ad')
+    rspec = RSpec(version=result_version)
+    for result in results:
+        add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], result["status"])
+        if result["status"]=="success":
+            try:
+                rspec.version.merge(result["rspec"])
+            except:
+                api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec")
+
+    # cache the result
+    if caching and api.cache and not xrn:
+        api.cache.add(version_string, rspec.toxml())
+
+    return rspec.toxml()
+
+
+def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
+
+    version_manager = VersionManager()
+    def _CreateSliver(aggregate, server, xrn, credential, rspec, users, call_id):
+        tStart = time.time()
+        try:
+            # Need to call GetVersion at an aggregate to determine the supported
+            # rspec type/format beofre calling CreateSliver at an Aggregate.
+            server_version = api.get_cached_server_version(server)
+            requested_users = users
+            if 'sfa' not in server_version and 'geni_api' in server_version:
+                # sfa aggregtes support both sfa and pg rspecs, no need to convert
+                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
+                rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request'))
+                filter = {'component_manager_id': server_version['urn']}
+                rspec.filter(filter)
+                rspec = rspec.toxml()
+                requested_users = sfa_to_pg_users_arg(users)
+            args = [xrn, credential, rspec, requested_users]
+            if _call_id_supported(api, server):
+                args.append(call_id)
+            rspec = server.CreateSliver(*args)
+            return {"aggregate": aggregate, "rspec": rspec, "elapsed": time.time()-tStart, "status": "success"}
+        except: 
+            logger.log_exc('Something wrong in _CreateSliver with URL %s'%server.url)
+            return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception"}
+
+    if Callids().already_handled(call_id): return ""
     # Validate the RSpec against PlanetLab's schema --disabled for now
     # The schema used here needs to aggregate the PL and VINI schemas
     # schema = "/var/www/html/schemas/pl.rng"
-    rspec = parse_rspec(rspec_str)
+    rspec = RSpec(rspec_str)
     schema = None
     if schema:
         rspec.validate(schema)
 
+    # if there is a <statistics> section, the aggregates don't care about it,
+    # so delete it.
+    drop_slicemgr_stats(rspec)
+
     # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
 
     # get the callers hrn
     hrn, type = urn_to_hrn(xrn)
@@ -111,17 +241,31 @@ def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
         # unless the caller is the aggregate's SM 
         if caller_hrn == aggregate and aggregate != api.hrn:
             continue
-            
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
         # Just send entire RSpec to each aggregate
-        threads.run(_CreateSliver, aggregate, xrn, credential, rspec.toxml(), users, call_id)
+        threads.run(_CreateSliver, aggregate, server, xrn, [cred], rspec.toxml(), users, call_id)
             
     results = threads.get_results()
-    rspec = SfaRSpec()
+    manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest')
+    result_rspec = RSpec(version=manifest_version)
     for result in results:
-        rspec.merge(result)     
-    return rspec.toxml()
+        add_slicemgr_stat(result_rspec, "CreateSliver", result["aggregate"], result["elapsed"], result["status"])
+        if result["status"]=="success":
+            try:
+                result_rspec.version.merge(result["rspec"])
+            except:
+                api.logger.log_exc("SM.CreateSliver: Failed to merge aggregate rspec")
+    return result_rspec.toxml()
 
 def RenewSliver(api, xrn, creds, expiration_time, call_id):
+    def _RenewSliver(server, xrn, creds, expiration_time, call_id):
+        server_version = api.get_cached_server_version(server)
+        args =  [xrn, creds, expiration_time, call_id]
+        if _call_id_supported(api, server):
+            args.append(call_id)
+        return server.RenewSliver(*args)
+
     if Callids().already_handled(call_id): return True
 
     (hrn, type) = urn_to_hrn(xrn)
@@ -130,21 +274,144 @@ def RenewSliver(api, xrn, creds, expiration_time, call_id):
     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
 
     # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
     threads = ThreadManager()
     for aggregate in api.aggregates:
         # prevent infinite loop. Dont send request back to caller
         # unless the caller is the aggregate's SM
         if caller_hrn == aggregate and aggregate != api.hrn:
             continue
-
-        server = api.aggregates[aggregate]
-        threads.run(server.RenewSliver, xrn, [credential], expiration_time, call_id)
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
+        threads.run(_RenewSliver, server, xrn, [cred], expiration_time, call_id)
     # 'and' the results
     return reduce (lambda x,y: x and y, threads.get_results() , True)
 
+def DeleteSliver(api, xrn, creds, call_id):
+    def _DeleteSliver(server, xrn, creds, call_id):
+        server_version = api.get_cached_server_version(server)
+        args =  [xrn, creds]
+        if _call_id_supported(api, server):
+            args.append(call_id)
+        return server.DeleteSliver(*args)
+
+    if Callids().already_handled(call_id): return ""
+    (hrn, type) = urn_to_hrn(xrn)
+    # get the callers hrn
+    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
+    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
+
+    # attempt to use delegated credential first
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        # prevent infinite loop. Dont send request back to caller
+        # unless the caller is the aggregate's SM
+        if caller_hrn == aggregate and aggregate != api.hrn:
+            continue
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
+        threads.run(_DeleteSliver, server, xrn, [cred], call_id)
+    threads.get_results()
+    return 1
+
+
+# first draft at a merging SliverStatus
+def SliverStatus(api, slice_xrn, creds, call_id):
+    def _SliverStatus(server, xrn, creds, call_id):
+        server_version = api.get_cached_server_version(server)
+        args =  [xrn, creds]
+        if _call_id_supported(api, server):
+            args.append(call_id)
+        return server.SliverStatus(*args)
+    
+    if Callids().already_handled(call_id): return {}
+    # attempt to use delegated credential first
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
+        threads.run (_SliverStatus, server, slice_xrn, [cred], call_id)
+    results = threads.get_results()
+
+    # get rid of any void result - e.g. when call_id was hit where by convention we return {}
+    results = [ result for result in results if result and result['geni_resources']]
+
+    # do not try to combine if there's no result
+    if not results : return {}
+
+    # otherwise let's merge stuff
+    overall = {}
+
+    # mmh, it is expected that all results carry the same urn
+    overall['geni_urn'] = results[0]['geni_urn']
+    overall['pl_login'] = results[0]['pl_login']
+    # append all geni_resources
+    overall['geni_resources'] = \
+        reduce (lambda x,y: x+y, [ result['geni_resources'] for result in results] , [])
+    overall['status'] = 'unknown'
+    if overall['geni_resources']:
+        overall['status'] = 'ready'
+
+    return overall
+
+caching=True
+#caching=False
+def ListSlices(api, creds, call_id):
+    def _ListSlices(server, creds, call_id):
+        server_version = api.get_cached_server_version(server)
+        args =  [creds]
+        if _call_id_supported(api, server):
+            args.append(call_id)
+        return server.ListSlices(*args)
+
+    if Callids().already_handled(call_id): return []
+
+    # look in cache first
+    if caching and api.cache:
+        slices = api.cache.get('slices')
+        if slices:
+            return slices
+
+    # get the callers hrn
+    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
+    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
+
+    # attempt to use delegated credential first
+    cred= api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
+    threads = ThreadManager()
+    # fetch from aggregates
+    for aggregate in api.aggregates:
+        # prevent infinite loop. Dont send request back to caller
+        # unless the caller is the aggregate's SM
+        if caller_hrn == aggregate and aggregate != api.hrn:
+            continue
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
+        threads.run(_ListSlices, server, [cred], call_id)
+
+    # combime results
+    results = threads.get_results()
+    slices = []
+    for result in results:
+        slices.extend(result)
+
+    # cache the result
+    if caching and api.cache:
+        api.cache.add('slices', slices)
+
+    return slices
+
+
 def get_ticket(api, xrn, creds, rspec, users):
     slice_hrn, type = urn_to_hrn(xrn)
     # get the netspecs contained within the clients rspec
@@ -160,33 +427,19 @@ def get_ticket(api, xrn, creds, rspec, users):
     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
 
     # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential() 
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential() 
     threads = ThreadManager()
     for (aggregate, aggregate_rspec) in aggregate_rspecs.iteritems():
         # prevent infinite loop. Dont send request back to caller
         # unless the caller is the aggregate's SM
         if caller_hrn == aggregate and aggregate != api.hrn:
             continue
-        server = None
-        if aggregate in api.aggregates:
-            server = api.aggregates[aggregate]
-        else:
-            net_urn = hrn_to_urn(aggregate, 'authority')     
-            # we may have a peer that knows about this aggregate
-            for agg in api.aggregates:
-                target_aggs = api.aggregates[agg].get_aggregates(credential, net_urn)
-                if not target_aggs or not 'hrn' in target_aggs[0]:
-                    continue
-                # send the request to this address 
-                url = target_aggs[0]['url']
-                server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file)
-                # aggregate found, no need to keep looping
-                break   
-        if server is None:
-            continue 
-        threads.run(server.GetTicket, xrn, credential, aggregate_rspec, users)
+        
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
+        threads.run(server.GetTicket, xrn, [cred], aggregate_rspec, users)
 
     results = threads.get_results()
     
@@ -222,29 +475,6 @@ def get_ticket(api, xrn, creds, rspec, users):
     ticket.sign()          
     return ticket.save_to_string(save_parents=True)
 
-
-def DeleteSliver(api, xrn, creds, call_id):
-    if Callids().already_handled(call_id): return ""
-    (hrn, type) = urn_to_hrn(xrn)
-    # get the callers hrn
-    valid_cred = api.auth.checkCredentials(creds, 'deletesliver', hrn)[0]
-    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
-
-    # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
-    threads = ThreadManager()
-    for aggregate in api.aggregates:
-        # prevent infinite loop. Dont send request back to caller
-        # unless the caller is the aggregate's SM
-        if caller_hrn == aggregate and aggregate != api.hrn:
-            continue
-        server = api.aggregates[aggregate]
-        threads.run(server.DeleteSliver, xrn, credential, call_id)
-    threads.get_results()
-    return 1
-
 def start_slice(api, xrn, creds):
     hrn, type = urn_to_hrn(xrn)
 
@@ -253,17 +483,18 @@ def start_slice(api, xrn, creds):
     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
 
     # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
     threads = ThreadManager()
     for aggregate in api.aggregates:
         # prevent infinite loop. Dont send request back to caller
         # unless the caller is the aggregate's SM
         if caller_hrn == aggregate and aggregate != api.hrn:
             continue
-        server = api.aggregates[aggregate]
-        threads.run(server.Start, xrn, credential)
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)    
+        threads.run(server.Start, xrn, cred)
     threads.get_results()    
     return 1
  
@@ -275,17 +506,18 @@ def stop_slice(api, xrn, creds):
     caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
 
     # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
+    cred = api.getDelegatedCredential(creds)
+    if not cred:
+        cred = api.getCredential()
     threads = ThreadManager()
     for aggregate in api.aggregates:
         # prevent infinite loop. Dont send request back to caller
         # unless the caller is the aggregate's SM
         if caller_hrn == aggregate and aggregate != api.hrn:
             continue
-        server = api.aggregates[aggregate]
-        threads.run(server.Stop, xrn, credential)
+        interface = api.aggregates[aggregate]
+        server = api.get_server(interface, cred)
+        threads.run(server.Stop, xrn, cred)
     threads.get_results()    
     return 1
 
@@ -307,141 +539,6 @@ def status(api, xrn, creds):
     """
     return 1
 
-# Thierry : caching at the slicemgr level makes sense to some extent
-caching=True
-#caching=False
-def ListSlices(api, creds, call_id):
-
-    if Callids().already_handled(call_id): return []
-
-    # look in cache first
-    if caching and api.cache:
-        slices = api.cache.get('slices')
-        if slices:
-            return slices    
-
-    # get the callers hrn
-    valid_cred = api.auth.checkCredentials(creds, 'listslices', None)[0]
-    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
-
-    # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
-    threads = ThreadManager()
-    # fetch from aggregates
-    for aggregate in api.aggregates:
-        # prevent infinite loop. Dont send request back to caller
-        # unless the caller is the aggregate's SM
-        if caller_hrn == aggregate and aggregate != api.hrn:
-            continue
-        server = api.aggregates[aggregate]
-        threads.run(server.ListSlices, credential, call_id)
-
-    # combime results
-    results = threads.get_results()
-    slices = []
-    for result in results:
-        slices.extend(result)
-    
-    # cache the result
-    if caching and api.cache:
-        api.cache.add('slices', slices)
-
-    return slices
-
-
-def ListResources(api, creds, options, call_id):
-
-    if Callids().already_handled(call_id): return ""
-
-    # get slice's hrn from options
-    xrn = options.get('geni_slice_urn', '')
-    (hrn, type) = urn_to_hrn(xrn)
-
-    # get the rspec's return format from options
-    rspec_version = RSpecVersion(options.get('rspec_version'))
-    version_string = "rspec_%s" % (rspec_version.get_version_name())
-
-    # look in cache first
-    if caching and api.cache and not xrn:
-        rspec =  api.cache.get(version_string)
-        if rspec:
-            return rspec
-
-    # get the callers hrn
-    valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0]
-    caller_hrn = Credential(string=valid_cred).get_gid_caller().get_hrn()
-
-    # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
-    threads = ThreadManager()
-    for aggregate in api.aggregates:
-        # prevent infinite loop. Dont send request back to caller
-        # unless the caller is the aggregate's SM
-        if caller_hrn == aggregate and aggregate != api.hrn:
-            continue
-        # get the rspec from the aggregate
-        server = api.aggregates[aggregate]
-        my_opts = copy(options)
-        my_opts['geni_compressed'] = False
-        threads.run(server.ListResources, credential, my_opts, call_id)
-                    
-    results = threads.get_results()
-    rspec_version = RSpecVersion(my_opts.get('rspec_version'))
-    if rspec_version['type'] == pg_rspec_ad_version['type']:
-        rspec = PGRSpec()
-    else:
-        rspec = SfaRSpec()
-
-    for result in results:
-        try:
-            rspec.merge(result)
-        except:
-            api.logger.info("SM.ListResources: Failed to merge aggregate rspec")
-
-    # cache the result
-    if caching and api.cache and not xrn:
-        api.cache.add(version_string, rspec.toxml())
-    return rspec.toxml()
-
-# first draft at a merging SliverStatus
-def SliverStatus(api, slice_xrn, creds, call_id):
-    if Callids().already_handled(call_id): return {}
-    # attempt to use delegated credential first
-    credential = api.getDelegatedCredential(creds)
-    if not credential:
-        credential = api.getCredential()
-    threads = ThreadManager()
-    for aggregate in api.aggregates:
-        server = api.aggregates[aggregate]
-        threads.run (server.SliverStatus, slice_xrn, credential, call_id)
-    results = threads.get_results()
-
-    # get rid of any void result - e.g. when call_id was hit where by convention we return {}
-    results = [ result for result in results if result and result['geni_resources']]
-
-    # do not try to combine if there's no result
-    if not results : return {}
-
-    # otherwise let's merge stuff
-    overall = {}
-
-    # mmh, it is expected that all results carry the same urn
-    overall['geni_urn'] = results[0]['geni_urn']
-    overall['pl_login'] = results[0]['pl_login']
-    # append all geni_resources
-    overall['geni_resources'] = \
-        reduce (lambda x,y: x+y, [ result['geni_resources'] for result in results] , [])
-    overall['status'] = 'unknown'
-    if overall['geni_resources']:
-        overall['status'] = 'ready'
-
-    return overall
-
 def main():
     r = RSpec()
     r.parseFile(sys.argv[1])
@@ -450,4 +547,4 @@ def main():
 
 if __name__ == "__main__":
     main()
-    
+
index b905c2f..7ec3e95 100644 (file)
@@ -1,8 +1,4 @@
 #!/usr/bin/python
-
-# $Id: topology.py 14181 2009-07-01 19:46:07Z acb $
-# $URL: https://svn.planet-lab.org/svn/NodeManager-topo/trunk/topology.py $
-
 #
 # Links in the physical topology, gleaned from looking at the Internet2
 # topology map.  Link (a, b) connects sites with IDs a and b.
index 0be7640..09cf902 100644 (file)
@@ -1,5 +1,7 @@
 start = RSpec
 RSpec = element RSpec { 
+   attribute expires { xsd:NMTOKEN },
+   attribute generated { xsd:NMTOKEN },
    attribute type { xsd:NMTOKEN },
    ( network | request )
 }
index 1545cb5..387c831 100644 (file)
@@ -5,6 +5,12 @@
   </start>
   <define name="RSpec">
     <element name="RSpec">
+      <attribute name="expires">
+        <data type="NMTOKEN"/>
+      </attribute>
+      <attribute name="generated">
+        <data type="NMTOKEN"/>
+      </attribute>
       <attribute name="type">
         <data type="NMTOKEN"/>
       </attribute>
diff --git a/sfa/methods/CreateGid.py b/sfa/methods/CreateGid.py
new file mode 100644 (file)
index 0000000..b25fbd1
--- /dev/null
@@ -0,0 +1,48 @@
+
+from sfa.util.xrn import urn_to_hrn
+from sfa.util.method import Method
+from sfa.util.parameter import Parameter, Mixed
+from sfa.trust.credential import Credential
+
+class CreateGid(Method):
+    """
+    Create a signed credential for the s object with the registry. In addition to being stored in the
+    SFA database, the appropriate records will also be created in the
+    PLC databases
+    
+    @param xrn urn or hrn of certificate owner
+    @param cert caller's certificate
+    @param cred credential string
+    
+    @return gid string representation
+    """
+
+    interfaces = ['registry']
+    
+    accepts = [
+        Mixed(Parameter(str, "Credential string"),
+              Parameter(type([str]), "List of credentials")),
+        Parameter(str, "URN or HRN of certificate owner"),
+        Parameter(str, "Certificate string"),
+        ]
+
+    returns = Parameter(int, "String representation of gid object")
+    
+    def call(self, creds, xrn, cert=None):
+        # TODO: is there a better right to check for or is 'update good enough? 
+        valid_creds = self.api.auth.checkCredentials(creds, 'update')
+
+        # verify permissions
+        hrn, type = urn_to_hrn(xrn)
+        self.api.auth.verify_object_permission(hrn)
+
+        #log the call
+        origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
+
+        # log
+        origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
+        self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrn, self.name))
+
+        manager = self.api.get_interface_manager()
+
+        return manager.create_gid(self.api, xrn, cert)
index e62e6f4..7895de3 100644 (file)
@@ -37,6 +37,11 @@ class CreateSliver(Method):
         valid_creds = self.api.auth.checkCredentials(creds, 'createsliver', hrn)
         origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
 
+        # make sure users info is specified
+        if not users:
+            msg = "'users' musst be specified and cannot be null. You may need to update your client." 
+            raise SfaInvalidArgument(name='users', extra=msg)  
+
         manager = self.api.get_interface_manager()
         
         # flter rspec through sfatables
index e175cfe..3a250d5 100644 (file)
@@ -1,5 +1,3 @@
-### $Id: get_ticket.py 17732 2010-04-19 21:10:45Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/get_ticket.py $
 import time
 from sfa.util.faults import *
 from sfa.util.xrn import urn_to_hrn
index f66e90c..3aff1e7 100644 (file)
@@ -1,5 +1,3 @@
-### $Id: reset_slice.py 15428 2009-10-23 15:28:03Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfacomponent/methods/reset_slice.py $
 import xmlrpclib
 from sfa.util.faults import *
 from sfa.util.method import Method
index 1233fa8..f4b7801 100644 (file)
@@ -1,5 +1,3 @@
-### $Id: register.py 16477 2010-01-05 16:31:37Z thierry $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/register.py $
 
 from sfa.trust.certificate import Keypair, convert_public_key
 from sfa.trust.gid import *
@@ -34,15 +32,15 @@ class Register(Method):
     returns = Parameter(int, "String representation of gid object")
     
     def call(self, record, creds):
-        
+        # validate cred    
         valid_creds = self.api.auth.checkCredentials(creds, 'register')
+        
+        # verify permissions
+        hrn = record.get('hrn', '')
+        self.api.auth.verify_object_permission(hrn)
 
         #log the call
         origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
-
-        hrn = None
-        if 'hrn' in record:
-            hrn = record['hrn']
         self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name))
         
         manager = self.api.get_interface_manager()
index 864b1d5..11fd1bd 100644 (file)
@@ -1,6 +1,3 @@
-### $Id: register.py 15001 2009-09-11 20:18:54Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/register.py $
-
 from sfa.trust.certificate import Keypair, convert_public_key
 from sfa.trust.gid import *
 
index 6cbde77..1669517 100644 (file)
@@ -34,10 +34,11 @@ class RenewSliver(Method):
 
         # Validate that the time does not go beyond the credential's expiration time
         requested_time = utcparse(expiration_time)
+        max_renew_days = int(self.api.config.SFA_MAX_SLICE_RENEW)
         if requested_time > Credential(string=valid_creds[0]).get_expiration():
             raise InsufficientRights('Renewsliver: Credential expires before requested expiration time')
-        if requested_time > datetime.datetime.utcnow() + datetime.timedelta(days=60):
-            raise Exception('Cannot renew > 60 days from now')
+        if requested_time > datetime.datetime.utcnow() + datetime.timedelta(days=max_renew_days):
+            raise Exception('Cannot renew > %s days from now' % max_renew_days)
         manager = self.api.get_interface_manager()
         return manager.RenewSliver(self.api, slice_xrn, valid_creds, expiration_time, call_id)    
     
index 49104b2..36b2bde 100644 (file)
@@ -1,5 +1,3 @@
-### $Id: resolve.py 17157 2010-02-21 04:19:34Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/resolve.py $
 import traceback
 import types
 from sfa.util.faults import *
index e1ca60e..e119f3c 100644 (file)
@@ -1,6 +1,3 @@
-### $Id: stop_slice.py 17732 2010-04-19 21:10:45Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/stop_slice.py $
-
 from sfa.util.faults import *
 from sfa.util.xrn import urn_to_hrn
 from sfa.util.method import Method
index 579a77d..cdae0fc 100644 (file)
@@ -1,6 +1,3 @@
-### $Id: stop_slice.py 17732 2010-04-19 21:10:45Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/stop_slice.py $
-
 from sfa.util.faults import *
 from sfa.util.xrn import urn_to_hrn
 from sfa.util.method import Method
index d36ea36..aa881ea 100644 (file)
@@ -1,6 +1,3 @@
-### $Id: update.py 16477 2010-01-05 16:31:37Z thierry $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/update.py $
-
 import time
 from sfa.util.faults import *
 from sfa.util.method import Method
@@ -31,8 +28,14 @@ class Update(Method):
     def call(self, record_dict, creds):
         # validate the cred
         valid_creds = self.api.auth.checkCredentials(creds, "update")
+        
+        # verify permissions
+        hrn = record_dict.get('hrn', '')  
+        self.api.auth.verify_object_permission(hrn)
+    
+        # log
         origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn()
-        self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, None, self.name))
+        self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name))
        
         manager = self.api.get_interface_manager()
  
index a585d93..eef24de 100644 (file)
@@ -1,6 +1,7 @@
 ## Please use make index to update this file
 all = """
 CreateSliver
+CreateGid
 DeleteSliver
 GetCredential
 GetGids
index e89f18b..484fae5 100644 (file)
@@ -1,6 +1,3 @@
-### $Id: register.py 15001 2009-09-11 20:18:54Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/register.py $
-
 from sfa.trust.certificate import Keypair, convert_public_key
 from sfa.trust.gid import *
 
index 9d02364..a4e9c5e 100644 (file)
@@ -1,6 +1,3 @@
-### $Id: reset_slices.py 15428 2009-10-23 15:28:03Z tmack $
-### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/methods/reset_slices.py $
-
 from sfa.util.faults import *
 from sfa.util.xrn import urn_to_hrn
 from sfa.util.method import Method
index 12580b8..5ad2bff 100644 (file)
@@ -1,9 +1,12 @@
 #!/usr/bin/python
 from sfa.util.xrn import *
 from sfa.util.plxrn import *
-from sfa.rspecs.sfa_rspec import SfaRSpec
-from sfa.rspecs.pg_rspec  import PGRSpec
-from sfa.rspecs.rspec_version import RSpecVersion
+#from sfa.rspecs.sfa_rspec import SfaRSpec
+#from sfa.rspecs.pg_rspec  import PGRSpec
+#from sfa.rspecs.rspec_version import RSpecVersion
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
+from sfa.plc.vlink import get_tc_rate
 
 class Aggregate:
 
@@ -13,6 +16,7 @@ class Aggregate:
     interfaces = {}
     links = {}
     node_tags = {}
+    pl_initscripts = {} 
     prepared=False
     #panos new user options variable
     user_options = {}
@@ -28,7 +32,19 @@ class Aggregate:
     
     def prepare_nodes(self, force=False):
         if not self.nodes or force:
-            for node in self.api.plshell.GetNodes(self.api.plauth):
+            for node in self.api.plshell.GetNodes(self.api.plauth, {'peer_id': None}):
+                # add site/interface info to nodes.
+                # assumes that sites, interfaces and tags have already been prepared.
+                site = self.sites[node['site_id']]
+                interfaces = [self.interfaces[interface_id] for interface_id in node['interface_ids']]
+                tags = [self.node_tags[tag_id] for tag_id in node['node_tag_ids']]
+                node['network'] = self.api.hrn
+                node['network_urn'] = hrn_to_urn(self.api.hrn, 'authority+am')
+                node['urn'] = hostname_to_urn(self.api.hrn, site['login_base'], node['hostname'])
+                node['site_urn'] = hrn_to_urn(PlXrn.site_hrn(self.api.hrn, site['login_base']), 'authority+sa')
+                node['site'] = site
+                node['interfaces'] = interfaces
+                node['tags'] = tags
                 self.nodes[node['node_id']] = node
 
     def prepare_interfaces(self, force=False):
@@ -45,72 +61,102 @@ class Aggregate:
             for node_tag in self.api.plshell.GetNodeTags(self.api.plauth):
                 self.node_tags[node_tag['node_tag_id']] = node_tag
 
+    def prepare_pl_initscripts(self, force=False):
+        if not self.pl_initscripts or force:
+            for initscript in self.api.plshell.GetInitScripts(self.api.plauth, {'enabled': True}):
+                self.pl_initscripts[initscript['initscript_id']] = initscript
+
     def prepare(self, force=False):
         if not self.prepared or force:
             self.prepare_sites(force)
-            self.prepare_nodes(force)
             self.prepare_interfaces(force)
-            self.prepare_links(force)
             self.prepare_node_tags(force)
-            # add site/interface info to nodes
-            for node_id in self.nodes:
-                node = self.nodes[node_id]
-                site = self.sites[node['site_id']]
-                interfaces = [self.interfaces[interface_id] for interface_id in node['interface_ids']]
-                tags = [self.node_tags[tag_id] for tag_id in node['node_tag_ids']]
-                node['network'] = self.api.hrn
-                node['network_urn'] = hrn_to_urn(self.api.hrn, 'authority+am')
-                node['urn'] = hostname_to_urn(self.api.hrn, site['login_base'], node['hostname'])
-                node['site_urn'] = hrn_to_urn(PlXrn.site_hrn(self.api.hrn, site['login_base']), 'authority+sa') 
-                node['site'] = site
-                node['interfaces'] = interfaces
-                node['tags'] = tags
-
+            self.prepare_nodes(force)
+            self.prepare_links(force)
+            self.prepare_pl_initscripts()
         self.prepared = True  
 
     def get_rspec(self, slice_xrn=None, version = None):
         self.prepare()
-        rspec = None
-        rspec_version = RSpecVersion(version)
-        if slice_xrn:
-            type = 'manifest'
+        version_manager = VersionManager()
+        version = version_manager.get_version(version)
+        if not slice_xrn:
+            rspec_version = version_manager._get_version(version.type, version.version, 'ad')
         else:
-            type = 'advertisement' 
-        if rspec_version['type'].lower() == 'protogeni':
-            rspec = PGRSpec(type=type)
-        elif rspec_version['type'].lower() == 'sfa':
-            rspec = SfaRSpec(type=type, user_options=self.user_options)
-        else:
-            rspec = SfaRSpec(type=type, user_options=self.user_options)
-
-
-        rspec.add_nodes(self.nodes.values())
-        rspec.add_interfaces(self.interfaces.values()) 
-        rspec.add_links(self.links.values())
-
+            rspec_version = version_manager._get_version(version.type, version.version, 'manifest')
+               
+        rspec = RSpec(version=rspec_version, user_options=self.user_options)
+        # get slice details if specified
+        slice = None
         if slice_xrn:
-            # If slicename is specified then resulting rspec is a manifest. 
-            # Add sliver details to rspec and remove 'advertisement' elements
             slice_hrn, _ = urn_to_hrn(slice_xrn)
             slice_name = hrn_to_pl_slicename(slice_hrn)
             slices = self.api.plshell.GetSlices(self.api.plauth, slice_name)
             if slices:
-                slice = slices[0]
-                slivers = []
-                tags = self.api.plshell.GetSliceTags(self.api.plauth, slice['slice_tag_ids'])
-                for node_id in slice['node_ids']:
+                slice = slices[0]            
+
+        # filter out nodes with a whitelist:
+        valid_nodes = [] 
+        for node in self.nodes.values():
+            # only doing this because protogeni rspec needs
+            # to advertise available initscripts 
+            node['pl_initscripts'] = self.pl_initscripts
+
+            if slice and node['node_id'] in slice['node_ids']:
+                valid_nodes.append(node)
+            elif slice and slice['slice_id'] in node['slice_ids_whitelist']:
+                valid_nodes.append(node)
+            elif not slice and not node['slice_ids_whitelist']:
+                valid_nodes.append(node)
+    
+        rspec.version.add_nodes(valid_nodes)
+        rspec.version.add_interfaces(self.interfaces.values()) 
+        rspec.version.add_links(self.links.values())
+
+        # add slivers
+        if slice_xrn and slice:
+            slivers = []
+            tags = self.api.plshell.GetSliceTags(self.api.plauth, slice['slice_tag_ids'])
+
+            # add default tags
+            for tag in tags:
+                # if tag isn't bound to a node then it applies to all slivers
+                # and belongs in the <sliver_defaults> tag
+                if not tag['node_id']:
+                    rspec.version.add_default_sliver_attribute(tag['tagname'], tag['value'], self.api.hrn)
+                if tag['tagname'] == 'topo_rspec' and tag['node_id']:
+                    node = self.nodes[tag['node_id']]
+                    value = eval(tag['value'])
+                    for (id, realip, bw, lvip, rvip, vnet) in value:
+                        bps = get_tc_rate(bw)
+                        remote = self.nodes[id]
+                        site1 = self.sites[node['site_id']]
+                        site2 = self.sites[remote['site_id']]
+                        link1_name = '%s:%s' % (site1['login_base'], site2['login_base']) 
+                        link2_name = '%s:%s' % (site2['login_base'], site1['login_base']) 
+                        p_link = None
+                        if link1_name in self.links:
+                            link = self.links[link1_name] 
+                        elif link2_name in self.links:
+                            link = self.links[link2_name]
+                        v_link = Link()
+                        
+                        link.capacity = bps 
+            for node_id in slice['node_ids']:
+                try:
                     sliver = {}
                     sliver['hostname'] = self.nodes[node_id]['hostname']
+                    sliver['node_id'] = node_id
+                    sliver['slice_id'] = slice['slice_id']    
                     sliver['tags'] = []
                     slivers.append(sliver)
+
+                    # add tags for this node only
                     for tag in tags:
-                        # if tag isn't bound to a node then it applies to all slivers
-                        if not tag['node_id']:
+                        if tag['node_id'] and (tag['node_id'] == node_id):
                             sliver['tags'].append(tag)
-                        else:
-                            tag_host = self.nodes[tag['node_id']]['hostname']
-                            if tag_host == sliver['hostname']:
-                                sliver['tags'].append(tag)
-                rspec.add_slivers(slivers, sliver_urn=slice_xrn)
+                except:
+                    self.api.logger.log_exc('unable to add sliver %s to node %s' % (slice['name'], node_id))
+            rspec.version.add_slivers(slivers, sliver_urn=slice_xrn)
 
-        return rspec.toxml(cleanup=True)          
+        return rspec.toxml()
index 8e15ea7..cad2267 100644 (file)
@@ -12,7 +12,7 @@ import xmlrpclib
 from sfa.util.faults import *
 from sfa.util.api import *
 from sfa.util.config import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.trust.auth import Auth
 from sfa.trust.rights import Right, Rights, determine_rights
@@ -106,24 +106,48 @@ class SfaAPI(BaseAPI):
 
         self.hrn = self.config.SFA_INTERFACE_HRN
         self.time_format = "%Y-%m-%d %H:%M:%S"
-        self.logger=sfa_logger()
 
+    
     def getPLCShell(self):
         self.plauth = {'Username': self.config.SFA_PLC_USER,
                        'AuthMethod': 'password',
                        'AuthString': self.config.SFA_PLC_PASSWORD}
-        try:
-            sys.path.append(os.path.dirname(os.path.realpath("/usr/bin/plcsh")))
-            self.plshell_type = 'direct'
-            import PLC.Shell
-            shell = PLC.Shell.Shell(globals = globals())
-        except:
-            self.plshell_type = 'xmlrpc' 
-            url = self.config.SFA_PLC_URL
-            shell = xmlrpclib.Server(url, verbose = 0, allow_none = True)
+
+        # The native shell (PLC.Shell.Shell) is more efficient than xmlrpc,
+        # but it leaves idle db connections open. use xmlrpc until we can figure
+        # out why PLC.Shell.Shell doesn't close db connection properly     
+        #try:
+        #    sys.path.append(os.path.dirname(os.path.realpath("/usr/bin/plcsh")))
+        #    self.plshell_type = 'direct'
+        #    import PLC.Shell
+        #    shell = PLC.Shell.Shell(globals = globals())
+        #except:
         
+        self.plshell_type = 'xmlrpc' 
+        url = self.config.SFA_PLC_URL
+        shell = xmlrpclib.Server(url, verbose = 0, allow_none = True)
         return shell
 
+    def get_server(self, interface, cred, timeout=30):
+        """
+        Returns a connection to the specified interface. Use the specified
+        credential to determine the caller and look for the caller's key/cert 
+        in the registry hierarchy cache. 
+        """       
+        from sfa.trust.hierarchy import Hierarchy
+        if not isinstance(cred, Credential):
+            cred_obj = Credential(string=cred)
+        else:
+            cred_obj = cred
+        caller_gid = cred_obj.get_gid_caller()
+        hierarchy = Hierarchy()
+        auth_info = hierarchy.get_auth_info(caller_gid.get_hrn())
+        key_file = auth_info.get_privkey_filename()
+        cert_file = auth_info.get_gid_filename()
+        server = interface.get_server(key_file, cert_file, timeout)
+        return server
+               
+        
     def getCredential(self):
         """
         Return a valid credential for this interface. 
@@ -137,7 +161,7 @@ class SfaAPI(BaseAPI):
             cred = Credential(filename = cred_filename)
             # make sure cred isnt expired
             if not cred.get_expiration or \
-               datetime.datetime.today() < cred.get_expiration():    
+               datetime.datetime.utcnow() < cred.get_expiration():    
                 return cred.save_to_string(save_parents=True)
 
         # get a new credential
@@ -155,20 +179,25 @@ class SfaAPI(BaseAPI):
         Attempt to find a credential delegated to us in
         the specified list of creds.
         """
+        from sfa.trust.hierarchy import Hierarchy
         if creds and not isinstance(creds, list): 
             creds = [creds]
-        delegated_creds = filter_creds_by_caller(creds,self.hrn)
-        if not delegated_creds:
-            return None
-        return delegated_creds[0]
+        hierarchy = Hierarchy()
+                
+        delegated_cred = None
+        for cred in creds:
+            if hierarchy.auth_exists(Credential(string=cred).get_gid_caller().get_hrn()):
+                delegated_cred = cred
+                break
+        return delegated_cred
  
     def __getCredential(self):
         """ 
         Get our credential from a remote registry 
         """
         from sfa.server.registry import Registries
-        registries = Registries(self)
-        registry = registries[self.hrn]
+        registries = Registries()
+        registry = registries.get_server(self.hrn, self.key_file, self.cert_file)
         cert_string=self.cert.save_to_string(save_parents=True)
         # get self credential
         self_cred = registry.GetSelfCredential(cert_string, self.hrn, 'authority')
@@ -189,7 +218,7 @@ class SfaAPI(BaseAPI):
             auth_hrn = hrn
         auth_info = self.auth.get_auth_info(auth_hrn)
         table = self.SfaTable()
-        records = table.findObjects(hrn)
+        records = table.findObjects({'hrn': hrn, 'type': 'authority+sa'})
         if not records:
             raise RecordNotFound
         record = records[0]
@@ -223,6 +252,8 @@ class SfaAPI(BaseAPI):
         except IOError:
             self.credential = self.getCredentialFromRegistry()
 
+
+
     ##
     # Convert SFA fields to PLC fields for use when registering up updating
     # registry record in the PLC database
@@ -340,7 +371,7 @@ class SfaAPI(BaseAPI):
             # fill in key info
             if record['type'] == 'user':
                 if 'key_ids' not in record:
-                    self.logger.info("user record has no 'key_ids' - need to import from myplc ?")
+                    logger.info("user record has no 'key_ids' - need to import from myplc ?")
                 else:
                     pubkeys = [keys[key_id]['key'] for key_id in record['key_ids'] if key_id in keys] 
                     record['keys'] = pubkeys
@@ -502,7 +533,8 @@ class SfaAPI(BaseAPI):
             elif (type.startswith("authority")):
                 record['url'] = None
                 if record['hrn'] in self.aggregates:
-                    record['url'] = self.aggregates[record['hrn']].url
+                    
+                    record['url'] = self.aggregates[record['hrn']].get_url()
 
                 if record['pointer'] != -1:
                     record['PI'] = []
index 203d321..9276fb0 100644 (file)
@@ -7,7 +7,6 @@ from lxml import etree
 from xmlbuilder import XMLBuilder
 
 from sfa.util.faults import *
-#from sfa.util.sfalogging import sfa_logger
 from sfa.util.xrn import get_authority
 from sfa.util.plxrn import hrn_to_pl_slicename, hostname_to_urn
 
index 7b96c35..95793a1 100755 (executable)
@@ -24,12 +24,12 @@ from sfa.util.xrn import get_leaf, get_authority
 from sfa.util.plxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_pl_slicename
 from sfa.util.config import Config
 from sfa.trust.certificate import convert_public_key, Keypair
-from sfa.trust.trustedroot import *
+from sfa.trust.trustedroots import *
 from sfa.trust.hierarchy import *
 from sfa.util.xrn import Xrn
 from sfa.plc.api import *
 from sfa.trust.gid import create_uuid
-from sfa.plc.sfaImport import sfaImport
+from sfa.plc.sfaImport import sfaImport, _cleanup_string
 
 def process_options():
 
@@ -55,6 +55,16 @@ def save_keys(filename, keys):
     f.write("keys = %s" % str(keys))
     f.close()
 
+def _get_site_hrn(interface_hrn, site):
+    # Hardcode 'internet2' into the hrn for sites hosting
+    # internet2 nodes. This is a special operation for some vini
+    # sites only
+    hrn = ".".join([interface_hrn, site['login_base']]) 
+    if ".vini" in interface_hrn and interface_hrn.endswith('vini'):
+        if site['login_base'].startswith("i2") or site['login_base'].startswith("nlr"):
+            hrn = ".".join([interface_hrn, "internet2", site['login_base']])
+    return hrn
+
 def main():
 
     process_options()
@@ -79,6 +89,9 @@ def main():
     if not root_auth == interface_hrn:
         sfaImporter.create_top_level_auth_records(interface_hrn)
 
+    # create s user record for the slice manager
+    sfaImporter.create_sm_client_record()
+
     # create interface records
     sfaImporter.logger.info("Import: creating interface records")
     sfaImporter.create_interface_records()
@@ -145,21 +158,23 @@ def main():
         slices_dict[slice['slice_id']] = slice
     # start importing 
     for site in sites:
-        site_hrn = interface_hrn + "." + site['login_base']
-        sfa_logger().info("Importing site: %s" % site_hrn)
+        site_hrn = _get_site_hrn(interface_hrn, site)
+        sfaImporter.logger.info("Importing site: %s" % site_hrn)
 
         # import if hrn is not in list of existing hrns or if the hrn exists
         # but its not a site record
         if site_hrn not in existing_hrns or \
            (site_hrn, 'authority') not in existing_records:
-            site_hrn = sfaImporter.import_site(interface_hrn, site)
+            sfaImporter.import_site(site_hrn, site)
              
         # import node records
         for node_id in site['node_ids']:
             if node_id not in nodes_dict:
                 continue 
             node = nodes_dict[node_id]
-            hrn =  hostname_to_hrn(interface_hrn, site['login_base'], node['hostname'])
+            site_auth = get_authority(site_hrn)
+            site_name = get_leaf(site_hrn)
+            hrn =  hostname_to_hrn(site_auth, site_name, node['hostname'])
             if hrn not in existing_hrns or \
                (hrn, 'node') not in existing_records:
                 sfaImporter.import_node(hrn, node)
@@ -195,14 +210,17 @@ def main():
                (hrn, 'user') not in existing_records or update_record:
                 sfaImporter.import_person(site_hrn, person)
 
+    
     # remove stale records    
+    system_records = [interface_hrn, root_auth, interface_hrn + '.slicemanager']
     for (record_hrn, type) in existing_records.keys():
+        if record_hrn in system_records:
+            continue
+        
         record = existing_records[(record_hrn, type)]
-        # if this is the interface name dont do anything
-        if record_hrn == interface_hrn or \
-           record_hrn == root_auth or \
-           record['peer_authority']:
+        if record['peer_authority']:
             continue
+
         # dont delete vini's internet2 placeholdder record
         # normally this would be deleted becuase it does not have a plc record 
         if ".vini" in interface_hrn and interface_hrn.endswith('vini') and \
index 8436035..fb84020 100755 (executable)
@@ -14,23 +14,39 @@ from optparse import OptionParser
 from sfa.trust.hierarchy import *
 from sfa.util.record import *
 from sfa.util.table import SfaTable
-from sfa.util.sfalogging import sfa_logger_goes_to_import,sfa_logger
+from sfa.util.sfalogging import logger
 
 def main():
    usage="%prog: trash the registry DB (the 'sfa' table in the 'planetlab5' database)"
    parser = OptionParser(usage=usage)
    parser.add_option('-f','--file-system',dest='clean_fs',action='store_true',default=False,
                      help='Clean up the /var/lib/sfa/authorities area as well')
+   parser.add_option('-c','--certs',dest='clean_certs',action='store_true',default=False,
+                     help='Remove all cached certs/gids found in /var/lib/sfa/authorities area as well')
    (options,args)=parser.parse_args()
    if args:
       parser.print_help()
       sys.exit(1)
-   sfa_logger_goes_to_import()
-   sfa_logger().info("Purging SFA records from database")
+   logger.info("Purging SFA records from database")
    table = SfaTable()
    table.sfa_records_purge()
+
+   if options.clean_certs:
+      # remove the server certificate and all gids found in /var/lib/sfa/authorities
+      logger.info("Purging cached certificates")
+      for (dir, _, files) in os.walk('/var/lib/sfa/authorities'):
+         for file in files:
+            if file.endswith('.gid') or file == 'server.cert':
+               path=dir+os.sep+file
+               os.unlink(path)
+               if not os.path.exists(path):
+                  logger.info("Unlinked file %s"%path)
+               else:
+                  logger.error("Could not unlink file %s"%path)
+
    if options.clean_fs:
       # just remove all files that do not match 'server.key' or 'server.cert'
+      logger.info("Purging registry filesystem cache")
       preserved_files = [ 'server.key', 'server.cert']
       for (dir,_,files) in os.walk('/var/lib/sfa/authorities'):
          for file in files:
@@ -38,8 +54,8 @@ def main():
             path=dir+os.sep+file
             os.unlink(path)
             if not os.path.exists(path):
-               sfa_logger().info("Unlinked file %s"%path)
+               logger.info("Unlinked file %s"%path)
             else:
-               sfa_logger().error("Could not unlink file %s"%path)
+               logger.error("Could not unlink file %s"%path)
 if __name__ == "__main__":
    main()
index 238b5e1..1effe71 100644 (file)
@@ -12,7 +12,7 @@ import getopt
 import sys
 import tempfile
 
-from sfa.util.sfalogging import sfa_logger_goes_to_import,sfa_logger
+from sfa.util.sfalogging import _SfaLogger
 
 from sfa.util.record import *
 from sfa.util.table import SfaTable
@@ -20,7 +20,7 @@ from sfa.util.xrn import get_authority, hrn_to_urn
 from sfa.util.plxrn import email_to_hrn
 from sfa.util.config import Config
 from sfa.trust.certificate import convert_public_key, Keypair
-from sfa.trust.trustedroot import *
+from sfa.trust.trustedroots import TrustedRoots
 from sfa.trust.hierarchy import *
 from sfa.trust.gid import create_uuid
 
@@ -52,11 +52,10 @@ def _cleanup_string(str):
 class sfaImport:
 
     def __init__(self):
-       sfa_logger_goes_to_import()
-       self.logger = sfa_logger()
+       self.logger = _SfaLogger(logfile='/var/log/sfa_import.log', loggername='importlog')
        self.AuthHierarchy = Hierarchy()
        self.config = Config()
-       self.TrustedRoots = TrustedRootList(Config.get_trustedroots_dir(self.config))
+       self.TrustedRoots = TrustedRoots(Config.get_trustedroots_dir(self.config))
        self.plc_auth = self.config.get_plc_auth()
        self.root_auth = self.config.SFA_REGISTRY_ROOT_AUTH
         
@@ -97,6 +96,24 @@ class sfaImport:
             self.logger.info("Import: inserting authority record for %s"%hrn)
             table.insert(auth_record)
 
+    def create_sm_client_record(self):
+        """
+        Create a user record for the Slicemanager service.
+        """
+        hrn = self.config.SFA_INTERFACE_HRN + '.slicemanager'
+        urn = hrn_to_urn(hrn, 'user')
+        if not self.AuthHierarchy.auth_exists(urn):
+            self.logger.info("Import: creating Slice Manager user")
+            self.AuthHierarchy.create_auth(urn)
+
+        auth_info = self.AuthHierarchy.get_auth_info(hrn)
+        table = SfaTable()
+        sm_user_record = table.find({'type': 'user', 'hrn': hrn})
+        if not sm_user_record:
+            record = SfaRecord(hrn=hrn, gid=auth_info.get_gid_object(), type="user", pointer=-1)
+            record['authority'] = get_authority(record['hrn'])
+            table.insert(record)    
+
     def create_interface_records(self):
         """
         Create a record for each SFA interface
@@ -117,7 +134,9 @@ class sfaImport:
                 record = SfaRecord(hrn=interface_hrn, gid=gid, type=interface, pointer=-1)  
                 record['authority'] = get_authority(interface_hrn)
                 table.insert(record) 
+                                
 
+    
     def import_person(self, parent_hrn, person):
         """
         Register a user record 
@@ -136,12 +155,16 @@ class sfaImport:
             # to planetlab
             keys = self.shell.GetKeys(self.plc_auth, key_ids)
             key = keys[0]['key']
-            pkey = convert_public_key(key)
+            pkey = None
+            try:
+                pkey = convert_public_key(key)
+            except:
+                self.logger.warn('unable to convert public key for %s' % hrn) 
             if not pkey:
                 pkey = Keypair(create=True)
         else:
             # the user has no keys
-            self.logger.warning("Import: person %s does not have a PL public key"%hrn)
+            self.logger.warn("Import: person %s does not have a PL public key"%hrn)
             # if a key is unavailable, then we still need to put something in the
             # user's GID. So make one up.
             pkey = Keypair(create=True)
@@ -210,23 +233,9 @@ class sfaImport:
             table.update(node_record)
 
     
-    def import_site(self, parent_hrn, site):
+    def import_site(self, hrn, site):
         shell = self.shell
         plc_auth = self.plc_auth
-        sitename = site['login_base']
-        sitename = _cleanup_string(sitename)
-        hrn = parent_hrn + "." + sitename
-        # Hardcode 'internet2' into the hrn for sites hosting
-        # internet2 nodes. This is a special operation for some vini
-        # sites only
-        if ".vini" in parent_hrn and parent_hrn.endswith('vini'):
-            if sitename.startswith("i2"):
-                #sitename = sitename.replace("ii", "")
-                hrn = ".".join([parent_hrn, "internet2", sitename])
-            elif sitename.startswith("nlr"):
-                #sitename = sitename.replace("nlr", "")
-                hrn = ".".join([parent_hrn, "internet2", sitename])
-
         urn = hrn_to_urn(hrn, 'authority')
         self.logger.info("Import: site %s"%hrn)
 
index f99ddc1..557fc37 100644 (file)
@@ -4,14 +4,15 @@ import traceback
 import sys
 
 from types import StringTypes
-from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn, urn_to_hrn
-from sfa.util.plxrn import hrn_to_pl_slicename
-from sfa.util.rspec import *
+from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn, urn_to_hrn
+from sfa.util.plxrn import hrn_to_pl_slicename, hrn_to_pl_login_base
 from sfa.util.specdict import *
 from sfa.util.faults import *
 from sfa.util.record import SfaRecord
 from sfa.util.policy import Policy
+from sfa.plc.vlink import VLink
 from sfa.util.prefixTree import prefixTree
+from collections import defaultdict
 
 MAXINT =  2L**31-1
 
@@ -24,6 +25,10 @@ class Slices:
         #filepath = path + os.sep + filename
         self.policy = Policy(self.api)    
         self.origin_hrn = origin_hrn
+        self.registry = api.registries[api.hrn]
+        self.credential = api.getCredential()
+        self.nodes = []
+        self.persons = []
 
     def get_slivers(self, xrn, node=None):
         hrn, type = urn_to_hrn(xrn)
@@ -148,7 +153,7 @@ class Slices:
         for peer_record in peers:
             names = [name.lower() for name in peer_record.values() if isinstance(name, StringTypes)]
             if site_authority in names:
-                peer = peer_record['shortname']
+                peer = peer_record
 
         return peer
 
@@ -163,206 +168,385 @@ class Slices:
         if site_authority != self.api.hrn:
             sfa_peer = site_authority
 
-        return sfa_peer 
+        return sfa_peer
 
-    def verify_site(self, registry, credential, slice_hrn, peer, sfa_peer, reg_objects=None):
-        authority = get_authority(slice_hrn)
-        authority_urn = hrn_to_urn(authority, 'authority')
-        login_base = None
-        if reg_objects:
-            site = reg_objects['site']
-            login_base = site['login_base']
-        else:
-            site_records = registry.Resolve(authority_urn, [credential])
-            site = {}            
-            for site_record in site_records:            
-                if site_record['type'] == 'authority':
-                    site = site_record
-            if not site:
-                raise RecordNotFound(authority)
-            
-        remote_site_id = site.pop('site_id')    
+    def verify_slice_nodes(self, slice, requested_slivers, peer):
         
-        if login_base is None:
-            login_base = get_leaf(authority)
-        sites = self.api.plshell.GetSites(self.api.plauth, login_base)
+        nodes = self.api.plshell.GetNodes(self.api.plauth, slice['node_ids'], ['hostname'])
+        current_slivers = [node['hostname'] for node in nodes]
 
-        if not sites:
-            site_id = self.api.plshell.AddSite(self.api.plauth, site)
-            if peer:
-                try:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)   
-                except Exception,e:
-                    self.api.plshell.DeleteSite(self.api.plauth, site_id)
-                    raise e
-            # mark this site as an sfa peer record
-            if sfa_peer and not reg_objects:
-                peer_dict = {'type': 'authority', 'hrn': authority, 'peer_authority': sfa_peer, 'pointer': site_id}
-                registry.register_peer_object(credential, peer_dict)
+        # remove nodes not in rspec
+        deleted_nodes = list(set(current_slivers).difference(requested_slivers))
 
-            # exempt federated sites from monitor policies
-            self.api.plshell.AddSiteTag(site_id, 'exempt_site_until', "20200101")
-             
-        else:
-            site_id = sites[0]['site_id']
-            remote_site_id = sites[0]['peer_site_id']
-            old_site = sites[0]
-            #the site is already on the remote agg. Let us update(e.g. max_slices field) it with the latest info.
-            self.sync_site(old_site, site, peer)
+        # add nodes from rspec
+        added_nodes = list(set(requested_slivers).difference(current_slivers))        
 
+        try:
+            if peer:
+                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', slice['slice_id'], peer['shortname'])
+            self.api.plshell.AddSliceToNodes(self.api.plauth, slice['name'], added_nodes)
+            self.api.plshell.DeleteSliceFromNodes(self.api.plauth, slice['name'], deleted_nodes)
+
+        except: 
+            self.api.logger.log_exc('Failed to add/remove slice from nodes')
+
+    def verify_slice_links(self, slice, links, peer=None):
+        if not links or not nodes:
+            return 
+        for link in links:
+            topo_rspec = VLink.get_topo_rspec(link)            
+        
 
-        return (site_id, remote_site_id) 
+    def handle_peer(self, site, slice, persons, peer):
+        if peer:
+            # bind site
+            try:
+                if site:
+                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', \
+                       site['site_id'], peer['shortname'], slice['site_id'])
+            except Exception,e:
+                self.api.plshell.DeleteSite(self.api.plauth, site['site_id'])
+                raise e
+            
+            # bind slice
+            try:
+                if slice:
+                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', \
+                       slice['slice_id'], peer['shortname'], slice['slice_id'])
+            except Exception,e:
+                self.api.plshell.DeleteSlice(self.api.plauth, slice['slice_id'])
+                raise e 
+
+            # bind persons
+            for person in persons:
+                try:
+                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', \
+                        person['person_id'], peer['shortname'], person['peer_person_id'])
+
+                    for (key, remote_key_id) in zip(person['keys'], person['key_ids']):
+                        try:
+                            self.api.plshell.BindObjectToPeer(self.api.plauth, 'key',\
+                                key['key_id'], peer['shortname'], remote_key_id)
+                        except:
+                            self.api.plshell.DeleteKey(self.api.plauth, key['key_id'])
+                            self.api.logger("failed to bind key: %s to peer: %s " % (key['key_id'], peer['shortname']))
+                except Exception,e:
+                    self.api.plshell.DeletePerson(self.api.plauth, person['person_id'])
+                    raise e       
 
-    def verify_slice(self, registry, credential, slice_hrn, site_id, remote_site_id, peer, sfa_peer, reg_objects=None):
-        slice = {}
-        slice_record = None
-        authority = get_authority(slice_hrn)
+        return slice
 
-        if reg_objects:
-            slice_record = reg_objects['slice_record']
-        else:
-            slice_records = registry.Resolve(slice_hrn, [credential])
-    
-            for record in slice_records:
-                if record['type'] in ['slice']:
-                    slice_record = record
-            if not slice_record:
-                raise RecordNotFound(hrn)
+    def verify_site(self, slice_xrn, slice_record={}, peer=None, sfa_peer=None):
+        (slice_hrn, type) = urn_to_hrn(slice_xrn)
+        site_hrn = get_authority(slice_hrn)
+        # login base can't be longer than 20 characters
+        slicename = hrn_to_pl_slicename(slice_hrn)
+        authority_name = slicename.split('_')[0]
+        login_base = authority_name[:20]
+        sites = self.api.plshell.GetSites(self.api.plauth, login_base)
+        if not sites:
+            # create new site record
+            site = {'name': 'geni.%s' % authority_name,
+                    'abbreviated_name': authority_name,
+                    'login_base': login_base,
+                    'max_slices': 100,
+                    'max_slivers': 1000,
+                    'enabled': True,
+                    'peer_site_id': None}
+            if peer:
+                site['peer_site_id'] = slice_record.get('site_id', None)
+            site['site_id'] = self.api.plshell.AddSite(self.api.plauth, site)
+            # exempt federated sites from monitor policies
+            self.api.plshell.AddSiteTag(self.api.plauth, site['site_id'], 'exempt_site_until', "20200101")
             
+            # is this still necessary?
+            # add record to the local registry 
+            if sfa_peer and slice_record:
+                peer_dict = {'type': 'authority', 'hrn': site_hrn, \
+                             'peer_authority': sfa_peer, 'pointer': site['site_id']}
+                self.registry.register_peer_object(self.credential, peer_dict)
+        else:
+            site =  sites[0]
+            if peer:
+                # unbind from peer so we can modify if necessary. Will bind back later
+                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site['site_id'], peer['shortname']) 
         
+        return site        
+
+    def verify_slice(self, slice_hrn, slice_record, peer, sfa_peer):
         slicename = hrn_to_pl_slicename(slice_hrn)
         parts = slicename.split("_")
         login_base = parts[0]
         slices = self.api.plshell.GetSlices(self.api.plauth, [slicename]) 
         if not slices:
-            slice_fields = {}
-            slice_keys = ['name', 'url', 'description']
-            for key in slice_keys:
-                if key in slice_record and slice_record[key]:
-                    slice_fields[key] = slice_record[key]
+            slice = {'name': slicename,
+                     'url': slice_record.get('url', slice_hrn), 
+                     'description': slice_record.get('description', slice_hrn)}
             # add the slice                          
-            slice_id = self.api.plshell.AddSlice(self.api.plauth, slice_fields)
-            slice = slice_fields
-            slice['slice_id'] = slice_id
-
+            slice['slice_id'] = self.api.plshell.AddSlice(self.api.plauth, slice)
+            slice['node_ids'] = []
+            slice['person_ids'] = []
+            if peer:
+                slice['peer_slice_id'] = slice_record.get('slice_id', None) 
             # mark this slice as an sfa peer record
             if sfa_peer:
-                peer_dict = {'type': 'slice', 'hrn': slice_hrn, 'peer_authority': sfa_peer, 'pointer': slice_id}
-                registry.register_peer_object(credential, peer_dict)
-
-            #this belongs to a peer
-            if peer:
-                try:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', slice_id, peer, slice_record['pointer'])
-                except Exception,e:
-                    self.api.plshell.DeleteSlice(self.api.plauth,slice_id)
-                    raise e
-            slice['node_ids'] = []
+                peer_dict = {'type': 'slice', 'hrn': slice_hrn, 
+                             'peer_authority': sfa_peer, 'pointer': slice['slice_id']}
+                self.registry.register_peer_object(self.credential, peer_dict)
         else:
             slice = slices[0]
-            slice_id = slice['slice_id']
-            site_id = slice['site_id']
-           #the slice is alredy on the remote agg. Let us update(e.g. expires field) it with the latest info.
-           self.sync_slice(slice, slice_record, peer)
-
-        slice['peer_slice_id'] = slice_record['pointer']
-        self.verify_persons(registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer, reg_objects)
-    
-        return slice        
-
-    def verify_persons(self, registry, credential, slice_record, site_id, remote_site_id, peer, sfa_peer, reg_objects=None):
-        # get the list of valid slice users from the registry and make 
-        # sure they are added to the slice 
-        slicename = hrn_to_pl_slicename(slice_record['hrn'])
-        if reg_objects:
-            researchers = reg_objects['users'].keys()
-        else:
-            researchers = slice_record.get('researcher', [])
-        for researcher in researchers:
-            if reg_objects:
-                person_dict = reg_objects['users'][researcher]
-            else:
-                person_records = registry.Resolve(researcher, [credential])
-                for record in person_records:
-                    if record['type'] in ['user'] and record['enabled']:
-                        person_record = record
-                if not person_record:
-                    return 1
-                person_dict = person_record
-
-            local_person=False
             if peer:
-                peer_id = self.api.plshell.GetPeers(self.api.plauth, {'shortname': peer}, ['peer_id'])[0]['peer_id']
-                persons = self.api.plshell.GetPersons(self.api.plauth, {'email': [person_dict['email']], 'peer_id': peer_id}, ['person_id', 'key_ids'])
-                if not persons:
-                    persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])
-                    if persons:
-                        local_person=True
-                        
-            else:
-                persons = self.api.plshell.GetPersons(self.api.plauth, [person_dict['email']], ['person_id', 'key_ids'])   
-        
-            if not persons:
-                person_id=self.api.plshell.AddPerson(self.api.plauth, person_dict)
-                self.api.plshell.UpdatePerson(self.api.plauth, person_id, {'enabled' : True})
+                slice['peer_slice_id'] = slice_record.get('slice_id', None)
+                # unbind from peer so we can modify if necessary. Will bind back later
+                self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice',\
+                             slice['slice_id'], peer['shortname'])
+               #Update existing record (e.g. expires field) it with the latest info.
+            if slice_record and slice['expires'] != slice_record['expires']:
+                self.api.plshell.UpdateSlice(self.api.plauth, slice['slice_id'],\
+                             {'expires' : slice_record['expires']})
+       
+        return slice
+
+    #def get_existing_persons(self, users):
+    def verify_persons(self, slice_hrn, slice_record, users, peer, sfa_peer, append=True):
+        users_by_email = {}
+        users_by_site = defaultdict(list)
+
+        users_dict = {} 
+        for user in users:
+            if 'append' in user and user['append'] == False:
+                append = False
+            if 'email' in user:
+                users_by_email[user['email']] = user
+                users_dict[user['email']] = user
+            elif 'urn' in user:
+                hrn, type = urn_to_hrn(user['urn'])
+                username = get_leaf(hrn) 
+                login_base = get_leaf(get_authority(user['urn']))
+                user['username'] = username 
+                users_by_site[login_base].append(user)
+
+        existing_user_ids = []
+        if users_by_email:
+            # get existing users by email 
+            existing_users = self.api.plshell.GetPersons(self.api.plauth, \
+                {'email': users_by_email.keys()}, ['person_id', 'key_ids', 'email'])
+            existing_user_ids.extend([user['email'] for user in existing_users])
+
+        if users_by_site:
+            # get a list of user sites (based on requeste user urns
+            site_list = self.api.plshell.GetSites(self.api.plauth, users_by_site.keys(), \
+                ['site_id', 'login_base', 'person_ids'])
+            sites = {}
+            site_user_ids = []
+            
+            # get all existing users at these sites
+            for site in site_list:
+                sites[site['site_id']] = site
+                site_user_ids.extend(site['person_ids'])
+
+            existing_site_persons_list = self.api.plshell.GetPersons(self.api.plauth, \
+              site_user_ids,  ['person_id', 'key_ids', 'email', 'site_ids'])
+
+            # all requested users are either existing users or new (added) users      
+            for login_base in users_by_site:
+                requested_site_users = users_by_site[login_base]
+                for requested_user in requested_site_users:
+                    user_found = False
+                    for existing_user in existing_site_persons_list:
+                        for site_id in existing_user['site_ids']:
+                            site = sites[site_id]
+                            if login_base == site['login_base'] and \
+                               existing_user['email'].startswith(requested_user['username']):
+                                existing_user_ids.append(existing_user['email'])
+                                users_dict[existing_user['email']] = requested_user
+                                user_found = True
+                                break
+                        if user_found:
+                            break
+      
+                    if user_found == False:
+                        fake_email = requested_user['username'] + '@geni.net'
+                        users_dict[fake_email] = requested_user
                 
-                # mark this person as an sfa peer record
-                if sfa_peer:
-                    peer_dict = {'type': 'user', 'hrn': researcher, 'peer_authority': sfa_peer, 'pointer': person_id}
-                    registry.register_peer_object(credential, peer_dict)
 
-                if peer:
-                    try:
-                        self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
-                    except Exception,e:
-                        self.api.plshell.DeletePerson(self.api.plauth,person_id)
-                        raise e
-                key_ids = []
-            else:
-                person_id = persons[0]['person_id']
-                key_ids = persons[0]['key_ids']
+        # requested slice users        
+        requested_user_ids = users_dict.keys()
+        # existing slice users
+        existing_slice_users_filter = {'person_id': slice_record.get('person_ids', [])}
+        existing_slice_users = self.api.plshell.GetPersons(self.api.plauth, \
+             existing_slice_users_filter, ['person_id', 'key_ids', 'email'])
+        existing_slice_user_ids = [user['email'] for user in existing_slice_users]
+        
+        # users to be added, removed or updated
+        added_user_ids = set(requested_user_ids).difference(existing_user_ids)
+        added_slice_user_ids = set(requested_user_ids).difference(existing_slice_user_ids)
+        removed_user_ids = set(existing_slice_user_ids).difference(requested_user_ids)
+        updated_user_ids = set(existing_slice_user_ids).intersection(requested_user_ids)
+
+        # Remove stale users (only if we are not appending).
+        if append == False:
+            for removed_user_id in removed_user_ids:
+                self.api.plshell.DeletePersonFromSlice(self.api.plauth, removed_user_id, slice_record['name'])
+        # update_existing users
+        updated_users_list = [user for user in existing_slice_users if user['email'] in \
+          updated_user_ids]
+        self.verify_keys(existing_slice_users, updated_users_list, peer, append)
+
+        added_persons = []
+        # add new users
+        for added_user_id in added_user_ids:
+            added_user = users_dict[added_user_id]
+            hrn, type = urn_to_hrn(added_user['urn'])  
+            person = {
+                'first_name': added_user.get('first_name', hrn),
+                'last_name': added_user.get('last_name', hrn),
+                'email': added_user_id,
+                'peer_person_id': None,
+                'keys': [],
+                'key_ids': added_user.get('key_ids', []),
+            }
+            person['person_id'] = self.api.plshell.AddPerson(self.api.plauth, person)
+            if peer:
+                person['peer_person_id'] = added_user['person_id']
+            added_persons.append(person)
+           
+            # enable the account 
+            self.api.plshell.UpdatePerson(self.api.plauth, person['person_id'], {'enabled': True})
+            
+            # add person to site
+            self.api.plshell.AddPersonToSite(self.api.plauth, added_user_id, login_base)
 
+            for key_string in added_user.get('keys', []):
+                key = {'key':key_string, 'key_type':'ssh'}
+                key['key_id'] = self.api.plshell.AddPersonKey(self.api.plauth, person['person_id'], key)
+                person['keys'].append(key)
 
-            # if this is a peer person, we must unbind them from the peer or PLCAPI will throw
-            # an error
-            try:
-                if peer and not local_person:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
-                if peer:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', site_id,  peer)
-
-                self.api.plshell.AddPersonToSlice(self.api.plauth, person_dict['email'], slicename)
-                self.api.plshell.AddPersonToSite(self.api.plauth, person_dict['email'], site_id)
-            finally:
-                if peer:
-                    try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', site_id, peer, remote_site_id)
-                    except: pass
-                if peer and not local_person:
-                    try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
-                    except: pass
+            # add the registry record
+            if sfa_peer:
+                peer_dict = {'type': 'user', 'hrn': hrn, 'peer_authority': sfa_peer, \
+                    'pointer': person['person_id']}
+                self.registry.register_peer_object(self.credential, peer_dict)
+    
+        for added_slice_user_id in added_slice_user_ids.union(added_user_ids):
+            # add person to the slice 
+            self.api.plshell.AddPersonToSlice(self.api.plauth, added_slice_user_id, slice_record['name'])
+            # if this is a peer record then it should already be bound to a peer.
+            # no need to return worry about it getting bound later 
+
+        return added_persons
             
-            self.verify_keys(registry, credential, person_dict, key_ids, person_id, peer, local_person)
 
-    def verify_keys(self, registry, credential, person_dict, key_ids, person_id,  peer, local_person):
-        keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key'])
-        keys = [key['key'] for key in keylist]
+    def verify_keys(self, persons, users, peer, append=True):
+        # existing keys 
+        key_ids = []
+        for person in persons:
+            key_ids.extend(person['key_ids'])
+        keylist = self.api.plshell.GetKeys(self.api.plauth, key_ids, ['key_id', 'key'])
+        keydict = {}
+        for key in keylist:
+            keydict[key['key']] = key['key_id']     
+        existing_keys = keydict.keys()
+        persondict = {}
+        for person in persons:
+            persondict[person['email']] = person    
+    
+        # add new keys
+        requested_keys = []
+        updated_persons = []
+        for user in users:
+            user_keys = user.get('keys', [])
+            updated_persons.append(user)
+            for key_string in user_keys:
+                requested_keys.append(key_string)
+                if key_string not in existing_keys:
+                    key = {'key': key_string, 'key_type': 'ssh'}
+                    try:
+                        if peer:
+                            person = persondict[user['email']]
+                            self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person['person_id'], peer['shortname'])
+                        key['key_id'] = self.api.plshell.AddPersonKey(self.api.plauth, user['email'], key)
+                        if peer:
+                            key_index = user_keys.index(key['key'])
+                            remote_key_id = user['key_ids'][key_index]
+                            self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key['key_id'], peer['shortname'], remote_key_id)
+                            
+                    finally:
+                        if peer:
+                            self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person['person_id'], peer['shortname'], user['person_id'])
         
-        #add keys that arent already there
-        key_ids = person_dict['key_ids']
-        for personkey in person_dict['keys']:
-            if personkey not in keys:
-                key = {'key_type': 'ssh', 'key': personkey}
-                try:
-                    if peer:
-                        self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'person', person_id, peer)
-                    key_id = self.api.plshell.AddPersonKey(self.api.plauth, person_dict['email'], key)
-                finally:
-                    if peer and not local_person:
-                        self.api.plshell.BindObjectToPeer(self.api.plauth, 'person', person_id, peer, person_dict['pointer'])
-                    if peer:
-                        # xxx - thierry how are we getting the peer_key_id in here ?
-                        try: self.api.plshell.BindObjectToPeer(self.api.plauth, 'key', key_id, peer, key_ids.pop(0))
-                        except: pass   
+        # remove old keys (only if we are not appending)
+        if append == False: 
+            removed_keys = set(existing_keys).difference(requested_keys)
+            for existing_key_id in keydict:
+                if keydict[existing_key_id] in removed_keys:
+                    try:
+                        if peer:
+                            self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'key', existing_key_id, peer['shortname'])
+                        self.api.plshell.DeleteKey(self.api.plauth, existing_key_id)
+                    except:
+                        pass   
+
+    def verify_slice_attributes(self, slice, requested_slice_attributes):
+        # get list of attributes users ar able to manage
+        slice_attributes = self.api.plshell.GetTagTypes(self.api.plauth, {'category': '*slice*', '|roles': ['user']})
+        valid_slice_attribute_names = [attribute['tagname'] for attribute in slice_attributes]
+
+        # get sliver attributes
+        added_slice_attributes = []
+        removed_slice_attributes = []
+        ignored_slice_attribute_names = []
+        existing_slice_attributes = self.api.plshell.GetSliceTags(self.api.plauth, {'slice_id': slice['slice_id']})
+
+        # get attributes that should be removed
+        for slice_tag in existing_slice_attributes:
+            if slice_tag['tagname'] in ignored_slice_attribute_names:
+                # If a slice already has a admin only role it was probably given to them by an
+                # admin, so we should ignore it.
+                ignored_slice_attribute_names.append(slice_tag['tagname'])
+            else:
+                # If an existing slice attribute was not found in the request it should
+                # be removed
+                attribute_found=False
+                for requested_attribute in requested_slice_attributes:
+                    if requested_attribute['name'] == slice_tag['tagname'] and \
+                       requested_attribute['value'] == slice_tag['value']:
+                        attribute_found=True
+                        break
+
+            if not attribute_found:
+                removed_slice_attributes.append(slice_tag)
+        
+        # get attributes that should be added:
+        for requested_attribute in requested_slice_attributes:
+            # if the requested attribute wasn't found  we should add it
+            if requested_attribute['name'] in valid_slice_attribute_names:
+                attribute_found = False
+                for existing_attribute in existing_slice_attributes:
+                    if requested_attribute['name'] == existing_attribute['tagname'] and \
+                       requested_attribute['value'] == existing_attribute['value']:
+                        attribute_found=True
+                        break
+                if not attribute_found:
+                    added_slice_attributes.append(requested_attribute)
+
+
+        # remove stale attributes
+        for attribute in removed_slice_attributes:
+            try:
+                self.api.plshell.DeleteSliceTag(self.api.plauth, attribute['slice_tag_id'])
+            except Exception, e:
+                self.api.logger.warn('Failed to remove sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
+                                % (name, value,  node_id, str(e)))
+
+        # add requested_attributes
+        for attribute in added_slice_attributes:
+            try:
+                name, value, node_id = attribute['name'], attribute['value'], attribute.get('node_id', None)
+                self.api.plshell.AddSliceTag(self.api.plauth, slice['name'], name, value, node_id)
+            except Exception, e:
+                self.api.logger.warn('Failed to add sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
+                                % (name, value,  node_id, str(e)))
 
     def create_slice_aggregate(self, xrn, rspec):
         hrn, type = urn_to_hrn(xrn)
@@ -443,27 +627,3 @@ class Slices:
 
         return 1
 
-    def sync_site(self, old_record, new_record, peer):
-        if old_record['max_slices'] != new_record['max_slices'] or old_record['max_slivers'] != new_record['max_slivers']:
-            try:
-                if peer:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'site', old_record['site_id'], peer)
-                if old_record['max_slices'] != new_record['max_slices']:
-                    self.api.plshell.UpdateSite(self.api.plauth, old_record['site_id'], {'max_slices' : new_record['max_slices']})
-                if old_record['max_slivers'] != new_record['max_slivers']:
-                    self.api.plshell.UpdateSite(self.api.plauth, old_record['site_id'], {'max_slivers' : new_record['max_slivers']})
-            finally:
-                if peer:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'site', old_record['site_id'], peer, old_record['peer_site_id'])
-       return 1
-
-    def sync_slice(self, old_record, new_record, peer):
-        if old_record['expires'] != new_record['expires']:
-            try:
-                if peer:
-                    self.api.plshell.UnBindObjectFromPeer(self.api.plauth, 'slice', old_record['slice_id'], peer)
-                self.api.plshell.UpdateSlice(self.api.plauth, old_record['slice_id'], {'expires' : new_record['expires']})
-            finally:
-                if peer:
-                    self.api.plshell.BindObjectToPeer(self.api.plauth, 'slice', old_record['slice_id'], peer, old_record['peer_slice_id'])
-       return 1
diff --git a/sfa/plc/vini_aggregate.py b/sfa/plc/vini_aggregate.py
new file mode 100644 (file)
index 0000000..b5663b2
--- /dev/null
@@ -0,0 +1,40 @@
+from sfa.plc.aggregate import Aggregate
+from sfa.managers.vini.topology import PhysicalLinks
+from sfa.rspecs.elements.link import Link
+from sfa.util.xrn import hrn_to_urn
+from sfa.util.plxrn import PlXrn
+
+class ViniAggregate(Aggregate):
+
+    def prepare_links(self, force=False):
+        for (site_id1, site_id2) in PhysicalLinks:
+            link = Link()
+            if not site_id1 in self.sites or site_id2 not in self.sites:
+                continue 
+            site1 = self.sites[site_id1]
+            site2 = self.sites[site_id2]
+            # get hrns
+            site1_hrn = self.api.hrn + '.' + site1['login_base']
+            site2_hrn = self.api.hrn + '.' + site2['login_base']
+            # get the first node
+            node1 = self.nodes[site1['node_id'][0]]
+            node2 = self.nodes[site2['node_id'][0]]
+        
+            # set interfaces
+            # just get first interface of the first node 
+            if1_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node1['node_id']))   
+            if2_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node2['node_id']))
+               
+            if1 = Interface({'component_id': if1_xrn.urn} )  
+            if2 = Interface({'component_id': if2_xrn.urn} )  
+            
+            # set link
+            link = Link({'capacity': '1000000', 'latency': '0', 'packet_loss': '0', 'type': 'ipv4'})
+            link['interface1'] = if1
+            link['interface2'] = if2
+            link['component_name'] = "%s:%s" % (site1['login_base'], site2['login_base'])
+            link['component_id'] = PlXrn(auth=self.api.hrn, link=link['component_name'])
+            link['component_manager_id'] =  hrn_to_urn(self.api.hrn, 'authority+am')
+            self.links[link['component_name']] = link
+        
+        
diff --git a/sfa/plc/vlink.py b/sfa/plc/vlink.py
new file mode 100644 (file)
index 0000000..8aeee49
--- /dev/null
@@ -0,0 +1,112 @@
+
+from sfa.util.plxrn import PlXrn
+# Taken from bwlimit.py
+#
+# See tc_util.c and http://physics.nist.gov/cuu/Units/binary.html. Be
+# warned that older versions of tc interpret "kbps", "mbps", "mbit",
+# and "kbit" to mean (in this system) "kibps", "mibps", "mibit", and
+# "kibit" and that if an older version is installed, all rates will
+# be off by a small fraction.
+suffixes = {
+    "":         1,
+    "bit":  1,
+    "kibit":    1024,
+    "kbit": 1000,
+    "mibit":    1024*1024,
+    "mbit": 1000000,
+    "gibit":    1024*1024*1024,
+    "gbit": 1000000000,
+    "tibit":    1024*1024*1024*1024,
+    "tbit": 1000000000000,
+    "bps":  8,
+    "kibps":    8*1024,
+    "kbps": 8000,
+    "mibps":    8*1024*1024,
+    "mbps": 8000000,
+    "gibps":    8*1024*1024*1024,
+    "gbps": 8000000000,
+    "tibps":    8*1024*1024*1024*1024,
+    "tbps": 8000000000000
+}
+
+
+def get_tc_rate(s):
+    """
+    Parses an integer or a tc rate string (e.g., 1.5mbit) into bits/second
+    """
+
+    if type(s) == int:
+        return s
+    m = re.match(r"([0-9.]+)(\D*)", s)
+    if m is None:
+        return -1
+    suffix = m.group(2).lower()
+    if suffixes.has_key(suffix):
+        return int(float(m.group(1)) * suffixes[suffix])
+    else:
+        return -1
+
+def format_tc_rate(rate):
+    """
+    Formats a bits/second rate into a tc rate string
+    """
+
+    if rate >= 1000000000 and (rate % 1000000000) == 0:
+        return "%.0fgbit" % (rate / 1000000000.)
+    elif rate >= 1000000 and (rate % 1000000) == 0:
+        return "%.0fmbit" % (rate / 1000000.)
+    elif rate >= 1000:
+        return "%.0fkbit" % (rate / 1000.)
+    else:
+        return "%.0fbit" % rate
+
+class VLink:
+    @staticmethod
+    def get_link_id(if1, if2):
+        if if1['id'] < if2['id']:
+            link = (if1['id']<<7) + if2['id']
+        else:
+            link = (if2['id']<<7) + if1['id']
+        return link
+
+    @staticmethod
+    def get_iface_id(if1, if2):
+        if if1['id'] < if2['id']:
+            iface = 1
+        else:
+            iface = 2
+        return iface
+
+    @staticmethod
+    def get_virt_ip(if1, if2):
+        link_id = get_link_id(if1, if2)
+        iface_id = get_iface_id(if1, if2)
+        first = link_id >> 6
+        second = ((link_id & 0x3f)<<2) + iface_id
+        return "192.168.%d.%s" % (frist, second)
+
+    @staticmethod
+    def get_virt_net(link):
+        link_id = self.get_link_id(link)
+        first = link_id >> 6
+        second = (link_id & 0x3f)<<2
+        return "192.168.%d.%d/30" % (first, second)
+
+    @staticmethod
+    def get_interface_id(interface):
+        if_name = PlXrn(interface=interface['component_id']).interface_name()
+        node, dev = if_name.split(":")
+        node_id = int(node.replace("pc", ""))
+        return node_id
+
+        
+    @staticmethod
+    def get_topo_rspec(link):
+        link['interface1']['id'] = VLink.get_interface_id(link['interface1'])
+        link['interface2']['id'] = VLink.get_interface_id(link['interface2'])
+        my_ip = VLink.get_virt_ip(link['interface1'], link['interface2'])
+        remote_ip = VLink.get_virt_ip(link['interface2'], link['interface1'])
+        net = VLink.get_virt_net(link)
+        bw = format_tc_rate(long(link['capacity']))
+        ipaddr = remote.get_primary_iface().ipv4
+        return (link['interface2']['id'], ipaddr, bw, my_ip, remote_ip, net) 
diff --git a/sfa/rspecs/elements/__init__.py b/sfa/rspecs/elements/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sfa/rspecs/elements/element.py b/sfa/rspecs/elements/element.py
new file mode 100644 (file)
index 0000000..8217c11
--- /dev/null
@@ -0,0 +1,85 @@
+from lxml import etree
+
+class Element:
+    def __init__(self, root_node, namespaces = None):
+        self.root_node = root_node
+        self.namespaces = namespaces
+
+    def xpath(self, xpath):
+        return this.root_node.xpath(xpath, namespaces=self.namespaces) 
+
+    def add_element(self, name, attrs={}, parent=None, text=""):
+        """
+        Generic wrapper around etree.SubElement(). Adds an element to
+        specified parent node. Adds element to root node is parent is
+        not specified.
+        """
+        if parent == None:
+            parent = self.root_node
+        element = etree.SubElement(parent, name)
+        if text:
+            element.text = text
+        if isinstance(attrs, dict):
+            for attr in attrs:
+                element.set(attr, attrs[attr])
+        return element
+
+    def remove_element(self, element_name, root_node = None):
+        """
+        Removes all occurences of an element from the tree. Start at
+        specified root_node if specified, otherwise start at tree's root.
+        """
+        if not root_node:
+            root_node = self.root_node
+
+        if not element_name.startswith('//'):
+            element_name = '//' + element_name
+
+        elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
+        for element in elements:
+            parent = element.getparent()
+            parent.remove(element)
+
+    
+    def add_attribute(self, elem, name, value):
+        """
+        Add attribute to specified etree element
+        """
+        opt = etree.SubElement(elem, name)
+        opt.text = value
+
+    def remove_attribute(self, elem, name, value):
+        """
+        Removes an attribute from an element
+        """
+        if not elem == None:
+            opts = elem.iterfind(name)
+            if opts is not None:
+                for opt in opts:
+                    if opt.text == value:
+                        elem.remove(opt)
+
+    def get_attributes(self, elem=None, depth=None):
+        if elem == None:
+            elem = self.root_node
+        attrs = dict(elem.attrib)
+        attrs['text'] = str(elem.text).strip()
+        if depth is None or isinstance(depth, int) and depth > 0: 
+            for child_elem in list(elem):
+                key = str(child_elem.tag)
+                if key not in attrs:
+                    attrs[key] = [self.get_attributes(child_elem, recursive)]
+                else:
+                    attrs[key].append(self.get_attributes(child_elem, recursive))
+        return attrs
+    
+    def attributes_list(self, elem):
+        # convert a list of attribute tags into list of tuples
+        # (tagnme, text_value)
+        opts = []
+        if not elem == None:
+            for e in elem:
+                opts.append((e.tag, e.text))
+        return opts
+
+    
diff --git a/sfa/rspecs/elements/interface.py b/sfa/rspecs/elements/interface.py
new file mode 100644 (file)
index 0000000..d2022d8
--- /dev/null
@@ -0,0 +1,11 @@
+class Interface(dict):
+    fields = {'component_id': None,
+              'role': None,
+              'client_id': None,
+              'ipv4': None 
+    }    
+    def __init__(self, fields={}):
+        dict.__init__(self, Interface.fields)
+        self.update(fields)
+        
+    
diff --git a/sfa/rspecs/elements/link.py b/sfa/rspecs/elements/link.py
new file mode 100644 (file)
index 0000000..4722cf8
--- /dev/null
@@ -0,0 +1,22 @@
+from sfa.rspecs.elements.interface import Interface
+
+class Link(dict):
+    
+    fields = {
+        'client_id': None, 
+        'component_id': None,
+        'component_name': None,
+        'component_manager': None,
+        'type': None,
+        'interface1': None,
+        'interface2': None,
+        'capacity': None,
+        'latency': None,
+        'packet_loss': None,
+        'description': None,
+    }
+    
+    def __init__(self, fields={}):
+        dict.__init__(self, Link.fields)
+        self.update(fields)
+
diff --git a/sfa/rspecs/elements/network.py b/sfa/rspecs/elements/network.py
new file mode 100644 (file)
index 0000000..6a358a4
--- /dev/null
@@ -0,0 +1,11 @@
+from sfa.rspecs.elements.element import Element
+from sfa.util.sfalogging import logger
+
+class Network(Element):
+
+    def get_networks(*args, **kwds):
+        logger.info("sfa.rspecs.networks: get_networks not implemented")
+
+    def add_networks(*args, **kwds):
+        logger.info("sfa.rspecs.networks: add_network not implemented")
+        
diff --git a/sfa/rspecs/elements/node.py b/sfa/rspecs/elements/node.py
new file mode 100644 (file)
index 0000000..db6e119
--- /dev/null
@@ -0,0 +1,13 @@
+from sfa.rspecs.elements.element import Element
+from sfa.util.faults import SfaNotImplemented 
+from sfa.util.sfalogging import logger
+class Node(Element):
+
+    def get_nodes(*args):
+        logger.info("sfa.rspecs.nodes: get_nodes not implemented") 
+    
+    def add_nodes(*args):
+        logger.info("sfa.rspecs.nodes: add_nodes not implemented") 
+                
+      
diff --git a/sfa/rspecs/elements/sliver.py b/sfa/rspecs/elements/sliver.py
new file mode 100644 (file)
index 0000000..67105dc
--- /dev/null
@@ -0,0 +1,29 @@
+from sfa.rspecs.elements.element import Element
+from sfa.util.sfalogging import logger
+
+class Slivers(Element):
+
+    def get_slivers(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: get_slivers not implemented")
+
+    def add_slivers(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: add_slivers not implemented")
+
+    def remove_slivers(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: remove_slivers not implemented")
+
+    def get_sliver_defaults(*args, **kwds):    
+        logger.debug("sfa.rspecs.slivers: get_sliver_defaults not implemented")
+    
+    def add_default_sliver_attribute(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: add_default_sliver_attributes not implemented")
+
+    def add_sliver_attribute(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: add_sliver_attribute not implemented")
+
+    def remove_default_sliver_attribute(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: remove_default_sliver_attributes not implemented")
+
+    def remove_sliver_attribute(*args, **kwds):
+        logger.debuv("sfa.rspecs.slivers: remove_sliver_attribute not implemented")
+        
diff --git a/sfa/rspecs/elements/versions/__init__.py b/sfa/rspecs/elements/versions/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sfa/rspecs/elements/versions/pgv2Link.py b/sfa/rspecs/elements/versions/pgv2Link.py
new file mode 100644 (file)
index 0000000..1f19a7a
--- /dev/null
@@ -0,0 +1,75 @@
+from lxml import etree
+from sfa.rspecs.elements.link import Link
+from sfa.rspecs.elements.interface import Interface
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements
+
+class PGv2Link:
+
+    elements = {
+        'link': RSpecElement(RSpecElements.LINK, '//default:link | //link'),
+        'component_manager': RSpecElement(RSpecElements.COMPONENT_MANAGER, './default:component_manager | ./component_manager')
+    }
+    
+    @staticmethod
+    def add_links(xml, links):
+        for link in links:
+            link_elem = etree.SubElement(xml, 'link')
+            for attrib in ['component_name', 'component_id', 'client_id']:
+                if attrib in link and link[attrib]:
+                    link_elem.set(attrib, link[attrib])
+            if 'component_manager' in link and link['component_manager']:
+                cm_element = etree.SubElement(xml, 'component_manager', name=link['component_manager'])
+            for if_ref in [link['interface1'], link['interface2']]:
+                if_ref_elem = etree.SubElement(xml, 'interface_ref')
+                for attrib in Interface.fields:
+                    if attrib in if_ref and if_ref[attrib]:
+                        if_ref_elem.attrib[attrib] = if_ref[attrib]  
+            prop1 = etree.SubElement(xml, 'property', source_id = link['interface1']['component_id'],
+                dest_id = link['interface2']['component_id'], capacity=link['capacity'], 
+                latency=link['latency'], packet_loss=link['packet_loss'])
+            prop2 = etree.SubElement(xml, 'property', source_id = link['interface2']['component_id'],
+                dest_id = link['interface1']['component_id'], capacity=link['capacity'], 
+                latency=link['latency'], packet_loss=link['packet_loss'])
+            if 'type' in link and link['type']:
+                type_elem = etree.SubElement(xml, 'link_type', name=link['type'])             
+   
+    @staticmethod 
+    def get_links(xml, namespaces=None):
+        links = []
+        link_elems = xml.xpath('//default:link', namespaces=namespaces)
+        for link_elem in link_elems:
+            # set client_id, component_id, component_name
+            link = Link(link_elem.attrib)
+            # set component manager
+            cm = link_elem.xpath('./default:component_manager', namespaces=namespaces)
+            if len(cm) >  0:
+                cm = cm[0]
+                if  'name' in cm.attrib:
+                    link['component_manager'] = cm.attrib['name'] 
+            # set link type
+            link_types = link_elem.xpath('./default:link_type', namespaces=namespaces)
+            if len(link_types) > 0:
+                link_type = link_types[0]
+                if 'name' in link_type.attrib:
+                    link['type'] = link_type.attrib['name']
+          
+            # get capacity, latency and packet_loss from first property  
+            props = link_elem.xpath('./default:property', namespaces=namespaces)
+            if len(props) > 0:
+                prop = props[0]
+                for attrib in ['capacity', 'latency', 'packet_loss']:
+                    if attrib in prop.attrib:
+                        link[attrib] = prop.attrib[attrib]
+                             
+            # get interfaces 
+            if_elems = link_elem.xpath('./default:interface_ref', namespaces=namespaces)
+            ifs = []
+            for if_elem in if_elems:
+                if_ref = Interface(if_elem.attrib)                 
+                ifs.append(if_ref)
+            if len(ifs) > 1:
+                link['interface1'] = ifs[0]
+                link['interface2'] = ifs[1] 
+            links.append(link)
+        return links 
+
diff --git a/sfa/rspecs/pg_rspec.py b/sfa/rspecs/pg_rspec.py
deleted file mode 100755 (executable)
index 3d2ae4f..0000000
+++ /dev/null
@@ -1,179 +0,0 @@
-#!/usr/bin/python 
-from lxml import etree
-from StringIO import StringIO
-from sfa.rspecs.rspec import RSpec 
-from sfa.util.xrn import *
-from sfa.util.plxrn import hostname_to_urn
-from sfa.util.config import Config 
-from sfa.rspecs.rspec_version import RSpecVersion 
-
-_ad_version = {'type':  'ProtoGENI',
-            'version': '2',
-            'schema': 'http://www.protogeni.net/resources/rspec/2/ad.xsd',
-            'namespace': 'http://www.protogeni.net/resources/rspec/2',
-            'extensions':  [
-                'http://www.protogeni.net/resources/rspec/ext/gre-tunnel/1',
-                'http://www.protogeni.net/resources/rspec/ext/other-ext/3'
-            ]
-}
-
-_request_version = {'type':  'ProtoGENI',
-            'version': '2',
-            'schema': 'http://www.protogeni.net/resources/rspec/2/request.xsd',
-            'namespace': 'http://www.protogeni.net/resources/rspec/2',
-            'extensions':  [
-                'http://www.protogeni.net/resources/rspec/ext/gre-tunnel/1',
-                'http://www.protogeni.net/resources/rspec/ext/other-ext/3'
-            ]
-}
-pg_rspec_ad_version = RSpecVersion(_ad_version)
-pg_rspec_request_version = RSpecVersion(_request_version)
-
-class PGRSpec(RSpec):
-    xml = None
-    header = '<?xml version="1.0"?>\n'
-    template = """<rspec xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.protogeni.net/resources/rspec/2" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/%(rspec_type)s.xsd"></rspec>"""
-
-    def __init__(self, rspec="", namespaces={}, type=None):
-        if not type:
-            type = 'advertisement'
-        self.type = type
-
-        if type == 'advertisement':
-            self.version = pg_rspec_ad_version
-            rspec_type = 'ad'
-        else:
-            self.version = pg_rspec_request_version
-            rspec_type = type
-        
-        self.template = self.template % {'rspec_type': rspec_type}
-
-        if not namespaces:
-            self.namespaces = {'rspecv2': self.version['namespace']}
-        else:
-            self.namespaces = namespaces 
-
-        if rspec:
-            self.parse_rspec(rspec, self.namespaces)
-        else: 
-            self.create()
-
-    def create(self):
-        RSpec.create(self)
-        if self.type:
-            self.xml.set('type', self.type) 
-       
-    def get_network(self):
-        network = None 
-        nodes = self.xml.xpath('//rspecv2:node[@component_manager_uuid][1]', namespaces=self.namespaces)
-        if nodes:
-            network  = nodes[0].get('component_manager_uuid')
-        return network
-
-    def get_networks(self):
-        networks = self.xml.xpath('//rspecv2:node[@component_manager_uuid]/@component_manager_uuid', namespaces=self.namespaces)
-        return set(networks)
-
-    def get_node_elements(self):
-        nodes = self.xml.xpath('//rspecv2:node | //node', namespaces=self.namespaces)
-        return nodes
-
-    def get_nodes(self, network=None):
-        xpath = '//rspecv2:node[@component_name]/@component_name | //node[@component_name]/@component_name'
-        return self.xml.xpath(xpath, namespaces=self.namespaces) 
-
-    def get_nodes_with_slivers(self, network=None):
-        if network:
-            return self.xml.xpath('//rspecv2:node[@component_manager_id="%s"][sliver_type]/@component_name' % network, namespaces=self.namespaces)
-        else:
-            return self.xml.xpath('//rspecv2:node[rspecv2:sliver_type]/@component_name', namespaces=self.namespaces)
-
-    def get_nodes_without_slivers(self, network=None):
-        pass
-
-    def add_nodes(self, nodes, check_for_dupes=False):
-        if not isinstance(nodes, list):
-            nodes = [nodes]
-        for node in nodes:
-            urn = ""
-            if check_for_dupes and \
-              self.xml.xpath('//rspecv2:node[@component_uuid="%s"]' % urn, namespaces=self.namespaces):
-                # node already exists
-                continue
-                
-            node_tag = etree.SubElement(self.xml, 'node', exclusive='false')
-            if 'network_urn' in node:
-                node_tag.set('component_manager_id', node['network_urn'])
-            if 'urn' in node:
-                node_tag.set('component_id', node['urn'])
-            if 'hostname' in node:
-                node_tag.set('component_name', node['hostname'])
-            # TODO: should replace plab-pc with pc model 
-            node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='plab-pc')
-            node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='pc')
-            available_tag = etree.SubElement(node_tag, 'available', now='true')
-            location_tag = etree.SubElement(node_tag, 'location', country="us")
-            if 'site' in node:
-                if 'longitude' in node['site']:
-                    location_tag.set('longitude', str(node['site']['longitude']))
-                if 'latitude' in node['site']:
-                    location_tag.set('latitude', str(node['site']['latitude']))
-            #if 'interfaces' in node:
-            
-
-    def add_slivers(self, slivers, sliver_urn=None, no_dupes=False): 
-        slivers = self._process_slivers(slivers)
-        nodes_with_slivers = self.get_nodes_with_slivers()
-        for sliver in slivers:
-            hostname = sliver['hostname']
-            if hostname in nodes_with_slivers:
-                continue
-            nodes = self.xml.xpath('//rspecv2:node[@component_name="%s"] | //node[@component_name="%s"]' % (hostname, hostname), namespaces=self.namespaces)
-            if nodes:
-                node = nodes[0]
-                node.set('client_id', hostname)
-                if sliver_urn:
-                    node.set('sliver_id', sliver_urn)
-                etree.SubElement(node, 'sliver_type', name='plab-vnode')
-
-    def add_interfaces(self, interfaces, no_dupes=False):
-        pass
-
-    def add_links(self, links, no_dupes=False):
-        pass
-
-
-    def merge(self, in_rspec):
-        """
-        Merge contents for specified rspec with current rspec
-        """
-        
-        # just copy over all the child elements under the root element
-        tree = etree.parse(StringIO(in_rspec))
-        root = tree.getroot()
-        for child in root.getchildren():
-            self.xml.append(child)
-                  
-    def cleanup(self):
-        # remove unncecessary elements, attributes
-        if self.type in ['request', 'manifest']:
-            # remove nodes without slivers
-            nodes = self.get_node_elements()
-            for node in nodes:
-                delete = True
-                hostname = node.get('component_name')
-                parent = node.getparent()
-                children = node.getchildren()
-                for child in children:
-                    if child.tag.endswith('sliver_type'):
-                        delete = False
-                if delete:
-                    parent.remove(node)
-
-            # remove 'available' element from remaining node elements
-            self.remove_element('//rspecv2:available | //available')
-
-if __name__ == '__main__':
-    rspec = PGRSpec()
-    rspec.add_nodes([1])
-    print rspec
index e9a34eb..42e7ccd 100755 (executable)
@@ -2,8 +2,8 @@
 from lxml import etree
 from StringIO import StringIO
 from sfa.util.xrn import *
-from sfa.rspecs.pg_rspec import PGRSpec 
-from sfa.rspecs.sfa_rspec import SfaRSpec
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
 
 xslt='''<xsl:stylesheet version="1.0" xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
 <xsl:output method="xml" indent="no"/>
@@ -34,33 +34,39 @@ transform=etree.XSLT(xslt_doc)
 class PGRSpecConverter:
 
     @staticmethod
-    def to_sfa_rspec(rspec):
-        if isinstance(rspec, PGRSpec):
+    def to_sfa_rspec(rspec, content_type = None):
+        if not isinstance(rspec, RSpec):
+            pg_rspec = RSpec(rspec)
+        else:
             pg_rspec = rspec
-        else:        
-            pg_rspec = PGRSpec(rspec=rspec)
-        sfa_rspec = SfaRSpec()
+        
+        version_manager = VersionManager()
+        sfa_version = version_manager._get_version('sfa', '1')    
+        sfa_rspec = RSpec(version=sfa_version)
 
         # get network
-        network_urn = pg_rspec.get_network()
+        network_urn = pg_rspec.version.get_network()
         network,  _ = urn_to_hrn(network_urn)
-        network_element = sfa_rspec.add_element('network', {'name': network, 'id': network})
+        network_element = sfa_rspec.xml.add_element('network', {'name': network, 'id': network})
         
         # get nodes
-        pg_nodes_elements = pg_rspec.get_node_elements()
-        nodes_with_slivers = pg_rspec.get_nodes_with_slivers()
+        pg_nodes_elements = pg_rspec.version.get_node_elements()
+        nodes_with_slivers = pg_rspec.version.get_nodes_with_slivers()
         i = 1
         for pg_node_element in pg_nodes_elements:
-            node_element = sfa_rspec.add_element('node', {'id': 'n'+str(i)}, parent=network_element)
-            urn = pg_node_element.xpath('@component_uuid | @component_id')
+            attribs = dict(pg_node_element.attrib.items()) 
+            attribs['id'] = 'n'+str(i)
+            
+            node_element = sfa_rspec.xml.add_element('node', attribs, parent=network_element)
+            urn = pg_node_element.xpath('@component_id', namespaces=pg_rspec.namespaces)
             if urn:
                 urn = urn[0]
                 hostname = Xrn.urn_split(urn)[-1]
-                hostname_element = sfa_rspec.add_element('hostname', parent=node_element, text=hostname)
+                hostname_element = sfa_rspec.xml.add_element('hostname', parent=node_element, text=hostname)
                 if hostname in nodes_with_slivers:
-                    sfa_rspec.add_element('sliver', parent=node_element)
+                    sfa_rspec.xml.add_element('sliver', parent=node_element)
                      
-            urn_element = sfa_rspec.add_element('urn', parent=node_element, text=urn)
+            urn_element = sfa_rspec.xml.add_element('urn', parent=node_element, text=urn)
 
 
             # just copy over remaining child elements  
diff --git a/sfa/rspecs/resources/__init__.py b/sfa/rspecs/resources/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sfa/rspecs/resources/ext/__init__.py b/sfa/rspecs/resources/ext/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sfa/rspecs/resources/ext/planetlab.rnc b/sfa/rspecs/resources/ext/planetlab.rnc
new file mode 100644 (file)
index 0000000..f1ff971
--- /dev/null
@@ -0,0 +1,13 @@
+#
+## Extension for the "initscript" type for RSpecV2 on PlanetLab
+## Version 1
+##
+
+default namespace = "http://www.planet-lab.org/resources/ext/initscript/1"
+
+Node = element initscript {
+   attribute name { text }
+}
+
+start = Node
+
diff --git a/sfa/rspecs/resources/ext/planetlab.xsd b/sfa/rspecs/resources/ext/planetlab.xsd
new file mode 100644 (file)
index 0000000..e862877
--- /dev/null
@@ -0,0 +1,17 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+
+   Extension for the "initscript" type for RSpecV2 on PlanetLab
+   Version 1
+
+-->
+<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema" 
+elementFormDefault="qualified" 
+targetNamespace="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" 
+xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1">
+   <xs:element name="initscript">
+     <xs:complexType>
+       <xs:attribute name="name" use="required"/>
+     </xs:complexType>
+   </xs:element>
+</xs:schema>
index c127ae1..b86f996 100755 (executable)
 #!/usr/bin/python 
-from lxml import etree
-from StringIO import StringIO
 from datetime import datetime, timedelta
+from sfa.util.xml import XML, XpathFilter
+from sfa.rspecs.version_manager import VersionManager
 from sfa.util.xrn import *
 from sfa.util.plxrn import hostname_to_urn
-from sfa.util.faults import SfaNotImplemented, InvalidRSpec
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements 
+from sfa.util.faults import SfaNotImplemented, InvalidRSpec, InvalidRSpecElement
 
 class RSpec:
-    header = '<?xml version="1.0"?>\n'
-    template = """<RSpec></RSpec>"""
-    xml = None
-    type = None
-    version = None
-    namespaces = None
-    user_options = {}
-  
-    def __init__(self, rspec="", namespaces={}, type=None, user_options={}):
-        self.type = type
+    def __init__(self, rspec="", version=None, user_options={}):
+        self.header = '<?xml version="1.0"?>\n'
+        self.template = """<RSpec></RSpec>"""
+        self.version = None
+        self.xml = XML()
+        self.version_manager = VersionManager()
         self.user_options = user_options
+        self.elements = {}
         if rspec:
-            self.parse_rspec(rspec, namespaces)
+            self.parse_xml(rspec)
         else:
-            self.create()
+            self.create(version)
 
-    def create(self):
+    def create(self, version=None):
         """
         Create root element
         """
+        self.version = self.version_manager.get_version(version)
+        self.namespaces = self.version.namespaces
+        self.parse_xml(self.version.template) 
         # eg. 2011-03-23T19:53:28Z 
         date_format = '%Y-%m-%dT%H:%M:%SZ'
         now = datetime.utcnow()
         generated_ts = now.strftime(date_format)
         expires_ts = (now + timedelta(hours=1)).strftime(date_format) 
-        self.parse_rspec(self.template, self.namespaces)
         self.xml.set('expires', expires_ts)
         self.xml.set('generated', generated_ts)
-    
-    def parse_rspec(self, rspec, namespaces={}):
-        """
-        parse rspec into etree
-        """
-        parser = etree.XMLParser(remove_blank_text=True)
-        try:
-            tree = etree.parse(rspec, parser)
-        except IOError:
-            # 'rspec' file doesnt exist. 'rspec' is proably an xml string
-            try:
-                tree = etree.parse(StringIO(rspec), parser)
-            except:
-                raise InvalidRSpec('Must specify a xml file or xml string. Received: ' + rspec )
-        self.xml = tree.getroot()  
-        if namespaces:
-           self.namespaces = namespaces
 
-    def xpath(self, xpath):
-        return this.xml.xpath(xpath, namespaces=self.namespaces)
 
-    def add_attribute(self, elem, name, value):
-        """
-        Add attribute to specified etree element    
-        """
-        opt = etree.SubElement(elem, name)
-        opt.text = value
+    def parse_xml(self, xml):
+        self.xml.parse_xml(xml)
+        self.version = None
+        if self.xml.schema:
+            self.version = self.version_manager.get_version_by_schema(self.xml.schema)
+        else:
+            #raise InvalidRSpec('unknown rspec schema: %s' % schema)
+            # TODO: Should start raising an exception once SFA defines a schema.
+            # for now we just use the default  
+            self.version = self.version_manager.get_version()
+        self.version.xml = self.xml    
+        self.namespaces = self.xml.namespaces
+    
+    def load_rspec_elements(self, rspec_elements):
+        self.elements = {}
+        for rspec_element in rspec_elements:
+            if isinstance(rspec_element, RSpecElement):
+                self.elements[rspec_element.type] = rspec_element
 
-    def add_element(self, name, attrs={}, parent=None, text=""):
-        """
-        Generic wrapper around etree.SubElement(). Adds an element to 
-        specified parent node. Adds element to root node is parent is 
-        not specified. 
-        """
-        if parent == None:
-            parent = self.xml
-        element = etree.SubElement(parent, name)
-        if text:
-            element.text = text
-        if isinstance(attrs, dict):
-            for attr in attrs:
-                element.set(attr, attrs[attr])  
-        return element
+    def register_rspec_element(self, element_type, element_name, element_path):
+        if element_type not in RSpecElements:
+            raise InvalidRSpecElement(element_type, extra="no such element type: %s. Must specify a valid RSpecElement" % element_type)
+        self.elements[element_type] = RSpecElement(element_type, element_name, element_path)
 
-    def remove_attribute(self, elem, name, value):
-        """
-        Removes an attribute from an element
-        """
-        if elem is not None:
-            opts = elem.iterfind(name)
-            if opts is not None:
-                for opt in opts:
-                    if opt.text == value:
-                        elem.remove(opt)
+    def get_rspec_element(self, element_type):
+        if element_type not in self.elements:
+            msg = "ElementType %s not registerd for this rspec" % element_type
+            raise InvalidRSpecElement(element_type, extra=msg)
+        return self.elements[element_type]
 
-    def remove_element(self, element_name, root_node = None):
+    def get(self, element_type, filter={}, depth=0):
+        elements = self.get_elements(element_type, filter)
+        elements = [self.get_element_attributes(element, depth=depth) for element in elements]
+        return elements
+
+    def get_elements(self, element_type, filter={}):
         """
-        Removes all occurences of an element from the tree. Start at 
-        specified root_node if specified, otherwise start at tree's root.   
+        search for a registered element
         """
-        if not root_node:
-            root_node = self.xml
-
-        if not element_name.startswith('//'):
-            element_name = '//' + element_name
-
-        elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
-        for element in elements:
-            parent = element.getparent()
-            parent.remove(element)
-         
+        if element_type not in self.elements:
+            msg = "Unable to search for element %s in rspec, expath expression not found." % \
+                   element_type
+            raise InvalidRSpecElement(element_type, extra=msg)
+        rspec_element = self.get_rspec_element(element_type)
+        xpath = rspec_element.path + XpathFilter.xpath(filter)
+        return self.xpath(xpath)
 
     def merge(self, in_rspec):
-        pass
+        self.version.merge(in_rspec)
 
-    def validate(self, schema):
-        """
-        Validate against rng schema
-        """
+    def filter(self, filter):
+        if 'component_manager_id' in filter:    
+            nodes = self.version.get_node_elements()
+            for node in nodes:
+                if 'component_manager_id' not in node.attrib or \
+                  node.attrib['component_manager_id'] != filter['component_manager_id']:
+                    parent = node.getparent()
+                    parent.remove(node) 
         
-        relaxng_doc = etree.parse(schema)
-        relaxng = etree.RelaxNG(relaxng_doc)
-        if not relaxng(self.xml):
-            error = relaxng.error_log.last_error
-            message = "%s (line %s)" % (error.message, error.line)
-            raise InvalidRSpec(message)
-        return True
-        
-    def cleanup(self):
-        """
-        Optional method which inheriting classes can choose to implent. 
-        """
-        pass 
 
-    def _process_slivers(self, slivers):
-        """
-        Creates a dict of sliver details for each sliver host
-        
-        @param slivers a single hostname, list of hostanmes or list of dicts keys on hostname,
-        Returns a list of dicts 
-        """
-        if not isinstance(slivers, list):
-            slivers = [slivers]
-        dicts = []
-        for sliver in slivers:
-            if isinstance(sliver, dict):
-                dicts.append(sliver)
-            elif isinstance(sliver, basestring):
-                dicts.append({'hostname': sliver}) 
-        return dicts
-
-    def __str__(self):
-        return self.toxml()
+    def toxml(self, header=True):
+        if header:
+            return self.header + self.xml.toxml()
+        else:
+            return self.xml.toxml()
+    
 
-    def toxml(self, cleanup=False):
-        if cleanup:
-            self.cleanup()
-        return self.header + etree.tostring(self.xml, pretty_print=True)  
-        
     def save(self, filename):
-        f = open(filename, 'w')
-        f.write(self.toxml())
-        f.close()
+        return self.xml.save(filename)
+         
 if __name__ == '__main__':
-    rspec = RSpec()
+    rspec = RSpec('/tmp/resources.rspec')
     print rspec
+    rspec.register_rspec_element(RSpecElements.NETWORK, 'network', '//network')
+    rspec.register_rspec_element(RSpecElements.NODE, 'node', '//node')
+    print rspec.get(RSpecElements.NODE)[0]
+    print rspec.get(RSpecElements.NODE, depth=1)[0]
+
index 89f03a4..7dff2f0 100755 (executable)
@@ -2,30 +2,35 @@
 
 from sfa.rspecs.pg_rspec_converter import PGRSpecConverter
 from sfa.rspecs.sfa_rspec_converter import SfaRSpecConverter
-from sfa.rspecs.sfa_rspec import sfa_rspec_version
-from sfa.rspecs.pg_rspec import pg_rspec_ad_version, pg_rspec_request_version
-from sfa.rspecs.rspec_parser import parse_rspec
-
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
 
 class RSpecConverter:
 
     @staticmethod
-    def to_sfa_rspec(in_rspec):
-        rspec = parse_rspec(in_rspec)
-        if rspec.version['type'] == sfa_rspec_version['type']: 
+    def to_sfa_rspec(in_rspec, content_type=None):
+        rspec = RSpec(in_rspec)
+        version_manager = VersionManager()
+        sfa_version = version_manager._get_version('sfa', '1')
+        pg_version = version_manager._get_version('protogeni', '2')
+        if rspec.version.type.lower() == sfa_version.type.lower(): 
           return in_rspec
-        elif rspec.version['type'] == pg_rspec_ad_version['type']:
-            return PGRSpecConverter.to_sfa_rspec(in_rspec)
+        elif rspec.version.type.lower() == pg_version.type.lower(): 
+            return PGRSpecConverter.to_sfa_rspec(in_rspec, content_type)
         else:
-             return in_rspec 
+            return in_rspec 
 
     @staticmethod 
-    def to_pg_rspec(in_rspec):
-        rspec = parse_rspec(in_rspec)
-        if rspec.version['type'] == pg_rspec_ad_version['type']:
+    def to_pg_rspec(in_rspec, content_type=None):
+        rspec = RSpec(in_rspec)
+        version_manager = VersionManager()
+        sfa_version = version_manager._get_version('sfa', '1')
+        pg_version = version_manager._get_version('protogeni', '2')
+
+        if rspec.version.type.lower() == pg_version.type.lower(): 
             return in_rspec
-        elif rspec.version['type'] == sfa_rspec_version['type']:
-            return SfaRSpecConverter.to_pg_rspec(in_rspec)
+        elif rspec.version.type.lower() == sfa_version.type.lower(): 
+            return SfaRSpecConverter.to_pg_rspec(in_rspec, content_type)
         else:
             return in_rspec 
 
diff --git a/sfa/rspecs/rspec_elements.py b/sfa/rspecs/rspec_elements.py
new file mode 100644 (file)
index 0000000..4209139
--- /dev/null
@@ -0,0 +1,18 @@
+from sfa.util.enumeration import Enum
+
+# recognized top level rspec elements
+RSpecElements = Enum(NETWORK='NETWORK', 
+                     COMPONENT_MANAGER='COMPONENT_MANAGER', 
+                     SLIVER='SLIVER', 
+                     NODE='NODE', 
+                     INTERFACE='INTERFACE', 
+                     LINK='LINK', 
+                     SERVICE='SERVICE'
+                )
+
+class RSpecElement:
+    def __init__(self, element_type, path):
+        if not element_type in RSpecElements:
+            raise InvalidRSpecElement(element_type)
+        self.type = element_type
+        self.path = path
diff --git a/sfa/rspecs/rspec_parser.py b/sfa/rspecs/rspec_parser.py
deleted file mode 100755 (executable)
index 8e3bced..0000000
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/python
-from sfa.rspecs.sfa_rspec import SfaRSpec
-from sfa.rspecs.pg_rspec import PGRSpec
-from sfa.rspecs.rspec import RSpec
-from lxml import etree 
-
-def parse_rspec(in_rspec):
-    rspec = RSpec(rspec=in_rspec)
-    # really simple check
-    # TODO: check against schema instead
-    out_rspec = None 
-    if rspec.xml.xpath('//network'):
-        #out_rspec = SfaRSpec(in_rspec)
-        out_rspec = SfaRSpec()
-        out_rspec.xml = rspec.xml
-    else:
-        #out_rspec = PGRSpec(in_rspec)
-        # TODO: determine if this is an ad or request
-        out_rspec = PGRSpec()
-        out_rspec.xml = rspec.xml
-    return out_rspec
-
-if __name__ == '__main__':
-    
-    print "Parsing SFA RSpec:", 
-    rspec = parse_rspec('nodes.rspec')
-    print rspec.version
-    rspec = parse_rspec('protogeni.rspec')
-    print "Parsing ProtoGENI RSpec:", 
-    print rspec.version
-    
-    
-
old mode 100755 (executable)
new mode 100644 (file)
index b96a765..8e311b7
@@ -1,42 +1,30 @@
 #!/usr/bin/python
-from sfa.util.sfalogging import _SfaLogger
+from sfa.util.sfalogging import logger
 
-class RSpecVersion(dict):
+class BaseVersion:
+    type = None
+    content_type = None
+    version = None
+    schema = None
+    namespace = None
+    extensions = {}
+    namespaces = dict(extensions.items() + [('default', namespace)])
+    elements = []
+    enabled = False
 
-    fields = {'type': None,
-              'version': None,
-              'schema': None,
-              'namespace': None,
-              'extensions': []
-        }
-    def __init__(self, version={}):
-        
-        self.logger = _SfaLogger('/var/log/sfa.log')
-        dict.__init__(self, self.fields)
-
-        if not version:
-            from sfa.rspecs.sfa_rspec import sfa_rspec_version
-            self.update(sfa_rspec_version)          
-        elif isinstance(version, dict):
-            self.update(version)
-        elif isinstance(version, basestring):
-            version_parts = version.split(' ')
-            num_parts = len(version_parts)
-            self['type'] = version_parts[0]
-            if num_parts > 1:
-                self['version'] = version_parts[1]
-        else:
-            logger.info("Unable to parse rspec version, using default")
+    def __init__(self, xml=None):
+        self.xml = xml
 
-    def get_version_name(self):
-        return "%s %s" % (str(self['type']), str(self['version']))
-
-if __name__ == '__main__':
+    def to_dict(self):
+        return {
+            'type': self.type,
+            'version': self.version,
+            'schema': self.schema,
+            'namespace': self.namespace,
+            'extensions': self.extensions
+        }
 
-    from sfa.rspecs.pl_rspec_version import ad_rspec_versions
-    for version in [RSpecVersion(), 
-                    RSpecVersion("SFA"), 
-                    RSpecVersion("SFA 1"),
-                    RSpecVersion(ad_rspec_versions[0])]: 
-        print version.get_version_name() + ": " + str(version)
+    def to_string(self):
+        return "%s %s" % (self.type, self.version)
+    
 
index 096f422..6ba56c1 100755 (executable)
@@ -3,56 +3,82 @@
 from lxml import etree
 from StringIO import StringIO
 from sfa.util.xrn import *
-from sfa.rspecs.sfa_rspec import SfaRSpec
-from sfa.rspecs.pg_rspec import PGRSpec
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.version_manager import VersionManager
 
 class SfaRSpecConverter:
 
     @staticmethod
-    def to_pg_rspec(rspec):
-        if isinstance(rspec, SfaRSpec):
-            sfa_rspec = rspec
+    def to_pg_rspec(rspec, content_type = None):
+        if not isinstance(rspec, RSpec):
+            sfa_rspec = RSpec(rspec)
         else:
-            sfa_rspec = SfaRSpec(rspec=rspec)
-        pg_rspec = PGRSpec()
-    
+            sfa_rspec = rspec
+  
+        if not content_type or content_type not in \
+          ['ad', 'request', 'manifest']:
+            content_type = sfa_rspec.version.content_type
+     
+        version_manager = VersionManager()
+        pg_version = version_manager._get_version('protogeni', '2', 'request')
+        pg_rspec = RSpec(version=pg_version)
         # get networks
-        networks = sfa_rspec.get_networks()
+        networks = sfa_rspec.version.get_networks()
         
         for network in networks:
             # get nodes
-            sfa_node_elements = sfa_rspec.get_node_elements(network=network)
+            sfa_node_elements = sfa_rspec.version.get_node_elements(network=network)
             for sfa_node_element in sfa_node_elements:
                 # create node element
                 node_attrs = {}
                 node_attrs['exclusive'] = 'false'
-                node_attrs['component_manager_id'] = network
-                if sfa_node_element.find('hostname') != None:
-                    node_attrs['component_name'] = sfa_node_element.find('hostname').text
-                if sfa_node_element.find('urn') != None:    
-                    node_attrs['component_id'] = sfa_node_element.find('urn').text
-                node_element = pg_rspec.add_element('node', node_attrs)
+                if 'component_manager_id' in sfa_node_element.attrib:
+                    node_attrs['component_manager_id'] = sfa_node_element.attrib['component_manager_id']
+                else:
+                    node_attrs['component_manager_id'] = hrn_to_urn(network, 'authority+cm')
 
-                # create node_type element
-                for hw_type in ['plab-pc', 'pc']:
-                    hdware_type_element = pg_rspec.add_element('hardware_type', {'name': hw_type}, parent=node_element)
-                # create available element
-                pg_rspec.add_element('available', {'now': 'true'}, parent=node_element)
-                # create locaiton element
-                # We don't actually associate nodes with a country. 
-                # Set country to "unknown" until we figure out how to make
-                # sure this value is always accurate.
-                location = sfa_node_element.find('location')
-                if location != None:
-                    location_attrs = {}      
-                    location_attrs['country'] = locatiton.get('country', 'unknown')
-                    location_attrs['latitude'] = location.get('latitiue', 'None')
-                    location_attrs['longitude'] = location.get('longitude', 'None')
-                    pg_rspec.add_element('location', location_attrs, parent=node_element)
+                if 'component_id' in sfa_node_element.attrib:
+                    node_attrs['compoenent_id'] = sfa_node_element.attrib['component_id']
 
-                sliver_element = sfa_node_element.find('sliver')
-                if sliver_element != None:
-                    pg_rspec.add_element('sliver_type', {'name': 'planetlab-vnode'}, parent=node_element)
+                if sfa_node_element.find('hostname') != None:
+                    hostname = sfa_node_element.find('hostname').text
+                    node_attrs['component_name'] = hostname
+                    node_attrs['client_id'] = hostname
+                node_element = pg_rspec.xml.add_element('node', node_attrs)    
+            
+                if content_type == 'request':
+                    sliver_element = sfa_node_element.find('sliver')
+                    sliver_type_elements = sfa_node_element.xpath('./sliver_type', namespaces=sfa_rspec.namespaces)
+                    available_sliver_types = [element.attrib['name'] for element in sliver_type_elements]
+                    valid_sliver_types = ['emulab-openvz', 'raw-pc']
+                   
+                    # determine sliver type 
+                    requested_sliver_type = 'emulab-openvz'
+                    for available_sliver_type in available_sliver_types:
+                        if available_sliver_type in valid_sliver_types:
+                            requested_sliver_type = available_sliver_type
+                                
+                    if sliver_element != None:
+                        pg_rspec.xml.add_element('sliver_type', {'name': requested_sliver_type}, parent=node_element) 
+                else:
+                    # create node_type element
+                    for hw_type in ['plab-pc', 'pc']:
+                        hdware_type_element = pg_rspec.xml.add_element('hardware_type', {'name': hw_type}, parent=node_element)
+                    # create available element
+                    pg_rspec.xml.add_element('available', {'now': 'true'}, parent=node_element)
+                    # create locaiton element
+                    # We don't actually associate nodes with a country. 
+                    # Set country to "unknown" until we figure out how to make
+                    # sure this value is always accurate.
+                    location = sfa_node_element.find('location')
+                    if location != None:
+                        location_attrs = {}      
+                        location_attrs['country'] =  location.get('country', 'unknown')
+                        location_attrs['latitude'] = location.get('latitude', 'None')
+                        location_attrs['longitude'] = location.get('longitude', 'None')
+                        pg_rspec.xml.add_element('location', location_attrs, parent=node_element)
 
         return pg_rspec.toxml()
 
diff --git a/sfa/rspecs/version_manager.py b/sfa/rspecs/version_manager.py
new file mode 100644 (file)
index 0000000..f53ec6f
--- /dev/null
@@ -0,0 +1,80 @@
+import os
+from sfa.util.faults import InvalidRSpec
+from sfa.rspecs.rspec_version import BaseVersion 
+from sfa.util.sfalogging import logger    
+
+class VersionManager:
+    default_type = 'SFA'
+    default_version_num = '1'     
+        
+    def __init__(self):
+        self.versions = []
+        self.load_versions()
+
+    def load_versions(self):
+        path = os.path.dirname(os.path.abspath( __file__ ))
+        versions_path = path + os.sep + 'versions'
+        versions_module_path = 'sfa.rspecs.versions'
+        valid_module = lambda x: os.path.isfile(os.sep.join([versions_path, x])) \
+                        and not x.endswith('.pyc') and x not in ['__init__.py']
+        files = [f for f in os.listdir(versions_path) if valid_module(f)]
+        for filename in files:
+            basename = filename.split('.')[0]
+            module_path = versions_module_path +'.'+basename
+            module = __import__(module_path, fromlist=module_path)
+            for attr_name in dir(module):
+                attr = getattr(module, attr_name)
+                if hasattr(attr, 'version') and hasattr(attr, 'enabled') and attr.enabled == True:
+                    self.versions.append(attr())
+
+    def _get_version(self, type, version_num=None, content_type=None):
+        retval = None
+        for version in self.versions:
+            if type is None or type.lower() == version.type.lower():
+                if version_num is None or version_num == version.version:
+                    if content_type is None or content_type.lower() == version.content_type.lower() \
+                      or version.content_type == '*':
+                        retval = version
+        if not retval:
+            raise InvalidRSpec("No such version format: %s version: %s type:%s "% (type, version_num, content_type))
+        return retval
+
+    def get_version(self, version=None):
+        retval = None
+        if isinstance(version, dict):
+            retval =  self._get_version(version.get('type'), version.get('version'), version.get('content_type'))
+        elif isinstance(version, basestring):
+            version_parts = version.split(' ')     
+            num_parts = len(version_parts)
+            type = version_parts[0]
+            version_num = None
+            content_type = None
+            if num_parts > 1:
+                version_num = version_parts[1]
+            if num_parts > 2:
+                content_type = version_parts[2]
+            retval = self._get_version(type, version_num, content_type) 
+        elif isinstance(version, BaseVersion):
+            retval = version
+        else:
+            retval = self._get_version(self.default_type, self.default_version_num)   
+        return retval
+
+    def get_version_by_schema(self, schema):
+        retval = None
+        for version in self.versions:
+            if schema == version.schema:
+                retval = version
+        if not retval:
+            raise InvalidRSpec("Unkwnown RSpec schema: %s" % schema)
+        return retval
+
+if __name__ == '__main__':
+    v = VersionManager()
+    print v.versions
+    print v.get_version('sfa 1') 
+    print v.get_version('protogeni 2') 
+    print v.get_version('protogeni 2 advertisement') 
+    print v.get_version_by_schema('http://www.protogeni.net/resources/rspec/2/ad.xsd') 
+
diff --git a/sfa/rspecs/versions/__init__.py b/sfa/rspecs/versions/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sfa/rspecs/versions/pgv2.py b/sfa/rspecs/versions/pgv2.py
new file mode 100644 (file)
index 0000000..3e995ad
--- /dev/null
@@ -0,0 +1,271 @@
+from lxml import etree
+from copy import deepcopy
+from StringIO import StringIO
+from sfa.util.xrn import *
+from sfa.util.plxrn import hostname_to_urn, xrn_to_hostname 
+from sfa.rspecs.rspec_version import BaseVersion
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements
+from sfa.rspecs.elements.versions.pgv2Link import PGv2Link
+class PGv2(BaseVersion):
+    type = 'ProtoGENI'
+    content_type = 'ad'
+    version = '2'
+    schema = 'http://www.protogeni.net/resources/rspec/2/ad.xsd'
+    namespace = 'http://www.protogeni.net/resources/rspec/2'
+    extensions = {
+        'flack': "http://www.protogeni.net/resources/rspec/ext/flack/1",
+        'planetlab': "http://www.planet-lab.org/resources/sfa/ext/planetlab/1",
+    }
+    namespaces = dict(extensions.items() + [('default', namespace)])
+    elements = []
+
+    def get_network(self):
+        network = None
+        nodes = self.xml.xpath('//default:node[@component_manager_id][1]', namespaces=self.namespaces)
+        if nodes:
+            network  = nodes[0].get('component_manager_id')
+        return network
+
+    def get_networks(self):
+        networks = self.xml.xpath('//default:node[@component_manager_id]/@component_manager_id', namespaces=self.namespaces)
+        return set(networks)
+
+    def get_node_element(self, hostname, network=None):
+        nodes = self.xml.xpath('//default:node[@component_id[contains(., "%s")]] | node[@component_id[contains(., "%s")]]' % (hostname, hostname), namespaces=self.namespaces)
+        if isinstance(nodes,list) and nodes:
+            return nodes[0]
+        else:
+            return None
+
+    def get_node_elements(self, network=None):
+        nodes = self.xml.xpath('//default:node | //node', namespaces=self.namespaces)
+        return nodes
+
+
+    def get_nodes(self, network=None):
+        xpath = '//default:node[@component_name]/@component_id | //node[@component_name]/@component_id'
+        nodes = self.xml.xpath(xpath, namespaces=self.namespaces)
+        nodes = [xrn_to_hostname(node) for node in nodes]
+        return nodes
+
+    def get_nodes_with_slivers(self, network=None):
+        if network:
+            nodes = self.xml.xpath('//default:node[@component_manager_id="%s"][sliver_type]/@component_id' % network, namespaces=self.namespaces)
+        else:
+            nodes = self.xml.xpath('//default:node[default:sliver_type]/@component_id', namespaces=self.namespaces)
+        nodes = [xrn_to_hostname(node) for node in nodes]
+        return nodes
+
+    def get_nodes_without_slivers(self, network=None):
+        return []
+
+    def get_sliver_attributes(self, hostname, network=None):
+        node = self.get_node_element(hostname, network)
+        sliver = node.xpath('./default:sliver_type', namespaces=self.namespaces)
+        if sliver is not None and isinstance(sliver, list):
+            sliver = sliver[0]
+        return self.attributes_list(sliver)
+
+    def get_slice_attributes(self, network=None):
+        slice_attributes = []
+        nodes_with_slivers = self.get_nodes_with_slivers(network)
+        # TODO: default sliver attributes in the PG rspec?
+        default_ns_prefix = self.namespaces['default']
+        for node in nodes_with_slivers:
+            sliver_attributes = self.get_sliver_attributes(node, network)
+            for sliver_attribute in sliver_attributes:
+                name=str(sliver_attribute[0])
+                text =str(sliver_attribute[1])
+                attribs = sliver_attribute[2]
+                # we currently only suppor the <initscript> and <flack> attributes
+                if  'info' in name:
+                    attribute = {'name': 'flack_info', 'value': str(attribs), 'node_id': node}
+                    slice_attributes.append(attribute)
+                elif 'initscript' in name:
+                    if attribs is not None and 'name' in attribs:
+                        value = attribs['name']
+                    else:
+                        value = text
+                    attribute = {'name': 'initscript', 'value': value, 'node_id': node}
+                    slice_attributes.append(attribute)
+
+        return slice_attributes
+
+    def get_links(self, network=None):
+        links = PGv2Link.get_links(self.xml.root, self.namespaces)
+        return links
+
+    def add_links(self, links):
+        PGv2Link.add_links(self.xml.root, links)
+
+    def attributes_list(self, elem):
+        opts = []
+        if elem is not None:
+            for e in elem:
+                opts.append((e.tag, str(e.text).strip(), e.attrib))
+        return opts
+
+    def get_default_sliver_attributes(self, network=None):
+        return []
+
+    def add_default_sliver_attribute(self, name, value, network=None):
+        pass
+
+    def add_nodes(self, nodes, check_for_dupes=False):
+        if not isinstance(nodes, list):
+            nodes = [nodes]
+        for node in nodes:
+            urn = ""
+            if check_for_dupes and \
+              self.xml.xpath('//default:node[@component_uuid="%s"]' % urn, namespaces=self.namespaces):
+                # node already exists
+                continue
+
+            node_tag = etree.SubElement(self.xml.root, 'node', exclusive='false')
+            if 'network_urn' in node:
+                node_tag.set('component_manager_id', node['network_urn'])
+            if 'urn' in node:
+                node_tag.set('component_id', node['urn'])
+            if 'hostname' in node:
+                node_tag.set('component_name', node['hostname'])
+            # TODO: should replace plab-pc with pc model
+            node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='plab-pc')
+            node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='pc')
+            available_tag = etree.SubElement(node_tag, 'available', now='true')
+            sliver_type_tag = etree.SubElement(node_tag, 'sliver_type', name='plab-vserver')
+
+            pl_initscripts = node.get('pl_initscripts', {})
+            for pl_initscript in pl_initscripts.values():
+                etree.SubElement(sliver_type_tag, '{%s}initscript' % self.namespaces['planetlab'], name=pl_initscript['name'])
+
+            # protogeni uses the <sliver_type> tag to identify the types of
+            # vms available at the node.
+            # only add location tag if longitude and latitude are not null
+            if 'site' in node:
+                longitude = node['site'].get('longitude', None)
+                latitude = node['site'].get('latitude', None)
+                if longitude and latitude:
+                    location_tag = etree.SubElement(node_tag, 'location', country="us", \
+                                                    longitude=str(longitude), latitude=str(latitude))
+
+    def merge_node(self, source_node_tag):
+        # this is untested
+        self.xml.root.append(deepcopy(source_node_tag))
+
+    def add_slivers(self, slivers, sliver_urn=None, no_dupes=False, append=False):
+
+        # all nodes hould already be present in the rspec. Remove all
+        # nodes that done have slivers
+        slivers_dict = {}
+        for sliver in slivers:
+            if isinstance(sliver, basestring):
+                slivers_dict[sliver] = {'hostname': sliver}
+            elif isinstance(sliver, dict):
+                slivers_dict[sliver['hostname']] = sliver        
+
+        nodes = self.get_node_elements()
+        for node in nodes:
+            urn = node.get('component_id')
+            hostname = xrn_to_hostname(urn)
+            if hostname not in slivers_dict and not append:
+                parent = node.getparent()
+                parent.remove(node)
+            else:
+                sliver_info = slivers_dict[hostname]
+                sliver_type_elements = node.xpath('./default:sliver_type', namespaces=self.namespaces)
+                available_sliver_types = [element.attrib['name'] for element in sliver_type_elements]
+                valid_sliver_types = ['emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode']
+                requested_sliver_type = None
+                for valid_sliver_type in valid_sliver_types:
+                    if valid_sliver_type in available_sliver_types:
+                        requested_sliver_type = valid_sliver_type
+                if requested_sliver_type:
+                    # remove existing sliver_type tags,it needs to be recreated
+                    sliver_elem = node.xpath('./default:sliver_type | ./sliver_type', namespaces=self.namespaces)
+                    if sliver_elem and isinstance(sliver_elem, list):
+                        sliver_elem = sliver_elem[0]
+                        node.remove(sliver_elem)
+                    # set the client id
+                    node.set('client_id', hostname)
+                    if sliver_urn:
+                        # set the sliver id
+                        slice_id = sliver_info.get('slice_id', -1)
+                        node_id = sliver_info.get('node_id', -1)
+                        sliver_id = urn_to_sliver_id(sliver_urn, slice_id, node_id)
+                        node.set('sliver_id', sliver_id)
+
+                    # add the sliver element
+                    sliver_elem = etree.SubElement(node, 'sliver_type', name=requested_sliver_type)
+                    for tag in sliver_info.get('tags', []):
+                        if tag['tagname'] == 'flack_info':
+                            e = etree.SubElement(sliver_elem, '{%s}info' % self.namespaces['flack'], attrib=eval(tag['value']))
+                        elif tag['tagname'] == 'initscript':
+                            e = etree.SubElement(sliver_elem, '{%s}initscript' % self.namespaces['planetlab'], attrib={'name': tag['value']})                
+                else:
+                    # node isn't usable. just remove it from the request     
+                    parent = node.getparent()
+                    parent.remove(node)
+
+    
+
+    def remove_slivers(self, slivers, network=None, no_dupes=False):
+        for sliver in slivers:
+            node_elem = self.get_node_element(sliver['hostname'])
+            sliver_elem = node_elem.xpath('./default:sliver_type', self.namespaces)
+            if sliver_elem != None and sliver_elem != []:
+                node_elem.remove(sliver_elem[0])
+
+    def add_default_sliver_attribute(self, name, value, network=None):
+        pass
+
+    def add_interfaces(self, interfaces, no_dupes=False):
+        pass
+
+    def merge(self, in_rspec):
+        """
+        Merge contents for specified rspec with current rspec
+        """
+        from sfa.rspecs.rspec import RSpec
+        # just copy over all the child elements under the root element
+        if isinstance(in_rspec, RSpec):
+            in_rspec = in_rspec.toxml()
+        tree = etree.parse(StringIO(in_rspec))
+        root = tree.getroot()
+        for child in root.getchildren():
+            self.xml.root.append(child)
+
+    def cleanup(self):
+        # remove unncecessary elements, attributes
+        if self.type in ['request', 'manifest']:
+            # remove 'available' element from remaining node elements
+            self.xml.remove_element('//default:available | //available')
+
+class PGv2Ad(PGv2):
+    enabled = True
+    content_type = 'ad'
+    schema = 'http://www.protogeni.net/resources/rspec/2/ad.xsd'
+    template = '<rspec type="advertisement" xmlns="http://www.protogeni.net/resources/rspec/2" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/ad.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv2Request(PGv2):
+    enabled = True
+    content_type = 'request'
+    schema = 'http://www.protogeni.net/resources/rspec/2/request.xsd'
+    template = '<rspec type="request" xmlns="http://www.protogeni.net/resources/rspec/2" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/request.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv2Manifest(PGv2):
+    enabled = True
+    content_type = 'manifest'
+    schema = 'http://www.protogeni.net/resources/rspec/2/manifest.xsd'
+    template = '<rspec type="manifest" xmlns="http://www.protogeni.net/resources/rspec/2" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.protogeni.net/resources/rspec/2 http://www.protogeni.net/resources/rspec/2/manifest.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+     
+
+
+if __name__ == '__main__':
+    from sfa.rspecs.rspec import RSpec
+    from sfa.rspecs.rspec_elements import *
+    r = RSpec('/tmp/pg.rspec')
+    r.load_rspec_elements(PGv2.elements)
+    r.namespaces = PGv2.namespaces
+    print r.get(RSpecElements.NODE)
diff --git a/sfa/rspecs/versions/pgv3.py b/sfa/rspecs/versions/pgv3.py
new file mode 100644 (file)
index 0000000..3fe60e5
--- /dev/null
@@ -0,0 +1,34 @@
+from sfa.rspecs.versions.pgv2 import PGv2
+
+class PGv3(PGv2):
+    type = 'GENI'
+    content_type = 'ad'
+    version = '3'
+    schema = 'http://www.geni.net/resources/rspec/3/ad.xsd'
+    namespace = 'http://www.geni.net/resources/rspec/3'
+    extensions = {
+        'flack': "http://www.protogeni.net/resources/rspec/ext/flack/1",
+        'planetlab': "http://www.planet-lab.org/resources/sfa/ext/planetlab/1",
+    }
+    namespaces = dict(extensions.items() + [('default', namespace)])
+    elements = []
+
+
+class PGv3Ad(PGv3):
+    enabled = True
+    content_type = 'ad'
+    schema = 'http://www.geni.net/resources/rspec/3/ad.xsd'
+    template = '<rspec type="advertisement" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.geni.net/resources/rspec/3" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xsi:schemaLocation="http://www.geni.net/resources/rspec/3 http://www.geni.net/resources/rspec/3/ad.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv3Request(PGv3):
+    enabled = True
+    content_type = 'request'
+    schema = 'http://www.geni.net/resources/rspec/3/request.xsd'
+    template = '<rspec type="request" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.geni.net/resources/rspec/3" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xsi:schemaLocation="http://www.geni.net/resources/rspec/3 http://www.geni.net/resources/rspec/3/request.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+
+class PGv2Manifest(PGv3):
+    enabled = True
+    content_type = 'manifest'
+    schema = 'http://www.geni.net/resources/rspec/3/manifest.xsd'
+    template = '<rspec type="manifest" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://www.geni.net/resources/rspec/3" xmlns:flack="http://www.protogeni.net/resources/rspec/ext/flack/1" xmlns:planetlab="http://www.planet-lab.org/resources/sfa/ext/planetlab/1" xsi:schemaLocation="http://www.geni.net/resources/rspec/3 http://www.geni.net/resources/rspec/3/ad.xsd http://www.planet-lab.org/resources/sfa/ext/planetlab/1 http://www.planet-lab.org/resources/sfa/ext/planetlab/1/planetlab.xsd"/>'
+     
old mode 100755 (executable)
new mode 100644 (file)
similarity index 58%
rename from sfa/rspecs/sfa_rspec.py
rename to sfa/rspecs/versions/sfav1.py
index 85d5fcc..06f277a
@@ -1,31 +1,21 @@
-#!/usr/bin/python 
 from lxml import etree
-from StringIO import StringIO
-from sfa.rspecs.rspec import RSpec 
-from sfa.util.xrn import *
-from sfa.util.plxrn import hostname_to_urn
-from sfa.util.config import Config
-from sfa.rspecs.rspec_version import RSpecVersion  
+from sfa.util.xrn import hrn_to_urn, urn_to_hrn
+from sfa.rspecs.rspec_version import BaseVersion
+from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements
+from sfa.rspecs.elements.versions.pgv2Link import PGv2Link
+
+class SFAv1(BaseVersion):
+    enabled = True
+    type = 'SFA'
+    content_type = '*'
+    version = '1'
+    schema = None
+    namespace = None
+    extensions = {}
+    namespaces = None
+    elements = [] 
+    template = '<RSpec type="%s"></RSpec>' % type
 
-
-_version = { 'type': 'SFA', 
-             'version': '1' 
-}
-
-sfa_rspec_version = RSpecVersion(_version)
-
-class SfaRSpec(RSpec):
-    xml = None
-    header = '<?xml version="1.0"?>\n'
-    version = sfa_rspec_version
-
-    def create(self):
-        RSpec.create(self)
-        self.xml.set('type', 'SFA')
-
-    ###################
-    # Parser
-    ###################
     def get_network_elements(self):
         return self.xml.xpath('//network')
 
@@ -38,10 +28,10 @@ class SfaRSpec(RSpec):
         else:
             names = self.xml.xpath('//node/hostname')
         for name in names:
-            if name.text == hostname:
+            if str(name.text).strip() == hostname:
                 return name.getparent()
         return None
+
     def get_node_elements(self, network=None):
         if network:
             return self.xml.xpath('//network[@name="%s"]//node' % network)
@@ -53,43 +43,66 @@ class SfaRSpec(RSpec):
             nodes = self.xml.xpath('//node/hostname/text()')
         else:
             nodes = self.xml.xpath('//network[@name="%s"]//node/hostname/text()' % network)
+
+        nodes = [node.strip() for node in nodes]
         return nodes
 
     def get_nodes_with_slivers(self, network = None):
         if network:
-            return self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)   
+            nodes =  self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)  
         else:
-            return self.xml.xpath('//node[sliver]/hostname/text()')
+            nodes = self.xml.xpath('//node[sliver]/hostname/text()')
 
-    def get_nodes_without_slivers(self, network=None): 
+        nodes = [node.strip() for node in nodes]
+        return nodes
+
+    def get_nodes_without_slivers(self, network=None):
         xpath_nodes_without_slivers = '//node[not(sliver)]/hostname/text()'
-        xpath_nodes_without_slivers_in_network = '//network[@name="%s"]//node[not(sliver)]/hostname/text()' 
+        xpath_nodes_without_slivers_in_network = '//network[@name="%s"]//node[not(sliver)]/hostname/text()'
         if network:
             return self.xml.xpath('//network[@name="%s"]//node[not(sliver)]/hostname/text()' % network)
         else:
-            return self.xml.xpath('//node[not(sliver)]/hostname/text()')      
-
+            return self.xml.xpath('//node[not(sliver)]/hostname/text()')
 
     def attributes_list(self, elem):
         # convert a list of attribute tags into list of tuples
-        # (tagnme, text_value) 
+        # (tagnme, text_value)
         opts = []
         if elem is not None:
             for e in elem:
-                opts.append((e.tag, e.text))
+                opts.append((e.tag, str(e.text).strip()))
         return opts
 
     def get_default_sliver_attributes(self, network=None):
         if network:
-            defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)        
+            defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
         else:
-            defaults = self.xml.xpath("//network/sliver_defaults" % network)
+            defaults = self.xml.xpath("//sliver_defaults")
+        if isinstance(defaults, list) and defaults:
+            defaults = defaults[0]
         return self.attributes_list(defaults)
 
     def get_sliver_attributes(self, hostname, network=None):
+        attributes = []
         node = self.get_node_element(hostname, network)
-        sliver = node.find("sliver")
-        return self.attributes_list(sliver)
+        #sliver = node.find("sliver")
+        slivers = node.xpath('./sliver')
+        if isinstance(slivers, list) and slivers:
+            attributes = self.attributes_list(slivers[0])
+        return attributes
+
+    def get_slice_attributes(self, network=None):
+        slice_attributes = []
+        nodes_with_slivers = self.get_nodes_with_slivers(network)
+        for default_attribute in self.get_default_sliver_attributes(network):
+            attribute = {'name': str(default_attribute[0]), 'value': str(default_attribute[1]), 'node_id': None}
+            slice_attributes.append(attribute)
+        for node in nodes_with_slivers:
+            sliver_attributes = self.get_sliver_attributes(node, network)
+            for sliver_attribute in sliver_attributes:
+                attribute = {'name': str(sliver_attribute[0]), 'value': str(sliver_attribute[1]), 'node_id': node}
+                slice_attributes.append(attribute)
+        return slice_attributes
 
     def get_site_nodes(self, siteid, network=None):
         if network:
@@ -98,20 +111,10 @@ class SfaRSpec(RSpec):
         else:
             nodes = self.xml.xpath('//site[@id="%s"]/node/hostname/text()' % siteid)
         return nodes
-        
+
     def get_links(self, network=None):
-        if network: 
-            links = self.xml.xpath('//network[@name="%s"]/link' % network)
-        else:
-            links = self.xml.xpath('//link')    
-        linklist = []
-        for link in links:
-            (end1, end2) = link.get("endpoints").split()
-            name = link.find("description")
-            linklist.append((name.text,
-                             self.get_site_nodes(end1, network),
-                             self.get_site_nodes(end2, network)))
-        return linklist
+        links = PGv2Link.get_links(self.xml, self.namespaces)
+        return links
 
     def get_link(self, fromnode, tonode, network=None):
         fromsite = fromnode.getparent()
@@ -133,19 +136,19 @@ class SfaRSpec(RSpec):
 
     def get_vlinks(self, network=None):
         vlinklist = []
-        if network: 
+        if network:
             vlinks = self.xml.xpath("//network[@name='%s']//vlink" % network)
         else:
-            vlinks = self.xml.xpath("//vlink") 
+            vlinks = self.xml.xpath("//vlink")
         for vlink in vlinks:
             endpoints = vlink.get("endpoints")
             (end1, end2) = endpoints.split()
-            if network: 
+            if network:
                 node1 = self.xml.xpath('//network[@name="%s"]//node[@id="%s"]/hostname/text()' % \
                                        (network, end1))[0]
                 node2 = self.xml.xpath('//network[@name="%s"]//node[@id="%s"]/hostname/text()' % \
                                        (network, end2))[0]
-            else: 
+            else:
                 node1 = self.xml.xpath('//node[@id="%s"]/hostname/text()' % end1)[0]
                 node2 = self.xml.xpath('//node[@id="%s"]/hostname/text()' % end2)[0]
             desc = "%s <--> %s" % (node1, node2)
@@ -160,16 +163,22 @@ class SfaRSpec(RSpec):
             query = "//vlink[@endpoints = '%s']" % (network, endpoints)
         results = self.rspec.xpath(query)
         return results
-        
+
     def query_vlinks(self, endpoints, network=None):
         return get_vlink(endpoints,network)
 
+    
     ##################
     # Builder
     ##################
 
     def add_network(self, network):
-        network_tag = etree.SubElement(self.xml, 'network', id=network)     
+        network_tags = self.xml.xpath('//network[@name="%s"]' % network)
+        if not network_tags:
+            network_tag = etree.SubElement(self.xml.root, 'network', name=network)
+        else:
+            network_tag = network_tags[0]
+        return network_tag
 
     def add_nodes(self, nodes, network = None, no_dupes=False):
         if not isinstance(nodes, list):
@@ -180,96 +189,125 @@ class SfaRSpec(RSpec):
                 # node already exists
                 continue
 
-            network_tag = self.xml
+            network_tag = self.xml.root
             if 'network' in node:
                 network = node['network']
-                network_tags = self.xml.xpath('//network[@name="%s"]' % network)
-                if not network_tags:
-                    network_tag = etree.SubElement(self.xml, 'network', name=network)
-                else:
-                    network_tag = network_tags[0]
-                     
+                network_tag = self.add_network(network)
+
             node_tag = etree.SubElement(network_tag, 'node')
             if 'network' in node:
-                node_tag.set('component_manager_id', network)
+                node_tag.set('component_manager_id', hrn_to_urn(network, 'authority+sa'))
             if 'urn' in node:
-                node_tag.set('component_id', node['urn']) 
+                node_tag.set('component_id', node['urn'])
             if 'site_urn' in node:
                 node_tag.set('site_id', node['site_urn'])
-            if 'node_id' in node: 
+            if 'node_id' in node:
                 node_tag.set('node_id', 'n'+str(node['node_id']))
+            if 'boot_state' in node:
+                node_tag.set('boot_state', node['boot_state'])
             if 'hostname' in node:
+                node_tag.set('component_name', node['hostname']) 
                 hostname_tag = etree.SubElement(node_tag, 'hostname').text = node['hostname']
             if 'interfaces' in node:
+                i = 0
                 for interface in node['interfaces']:
                     if 'bwlimit' in interface and interface['bwlimit']:
                         bwlimit = etree.SubElement(node_tag, 'bw_limit', units='kbps').text = str(interface['bwlimit']/1000)
+                    comp_id = hrn_to_urn(network, 'pc%s:eth%s' % (node['node_id'], i)) 
+                    interface_tag = etree.SubElement(node_tag, 'interface', component_id=comp_id)
+                    i+=1
+            if 'bw_unallocated' in node:
+                bw_unallocated = etree.SubElement(node_tag, 'bw_unallocated', units='kbps').text = str(node['bw_unallocated']/1000) 
             if 'tags' in node:
                 for tag in node['tags']:
-                   # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category 
+                   # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category
                    if tag['tagname'] in ['fcdistro', 'arch'] or 'sfa' in tag['category'].split('/'):
-                        tag_element = etree.SubElement(node_tag, tag['tagname'], value=tag['value'])
+                        tag_element = etree.SubElement(node_tag, tag['tagname']).text=tag['value']
 
             if 'site' in node:
                 longitude = str(node['site']['longitude'])
                 latitude = str(node['site']['latitude'])
                 location = etree.SubElement(node_tag, 'location', country='unknown', \
-                                            longitude=longitude, latitude=latitude)                
+                                            longitude=longitude, latitude=latitude)
+
+    def merge_node(self, source_node_tag, network, no_dupes=False):
+        if no_dupes and self.get_node_element(node['hostname']):
+            # node already exists
+            return
+
+        network_tag = self.add_network(network)
+        network_tag.append(deepcopy(source_node_tag))
 
     def add_interfaces(self, interfaces):
-        pass     
+        pass
 
     def add_links(self, links):
-        pass
-    
-    def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False):
-        slivers = self._process_slivers(slivers)
-        nodes_with_slivers = self.get_nodes_with_slivers(network)
+        PGv2Link.add_links(self.xml, links)
+
+    def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False, append=False):
+        # add slice name to network tag
+        network_tags = self.xml.xpath('//network')
+        if network_tags:
+            network_tag = network_tags[0]
+            network_tag.set('slice', urn_to_hrn(sliver_urn)[0])
+        
+        all_nodes = self.get_nodes()
+        nodes_with_slivers = [sliver['hostname'] for sliver in slivers]
+        nodes_without_slivers = set(all_nodes).difference(nodes_with_slivers)
+        
+        # add slivers
         for sliver in slivers:
-            if sliver['hostname'] in nodes_with_slivers:
-                continue
             node_elem = self.get_node_element(sliver['hostname'], network)
+            if not node_elem: continue
             sliver_elem = etree.SubElement(node_elem, 'sliver')
             if 'tags' in sliver:
                 for tag in sliver['tags']:
-                    etree.SubElement(sliver_elem, tag['tagname'], value=tag['value'])
+                    etree.SubElement(sliver_elem, tag['tagname']).text = value=tag['value']
+            
+        # remove all nodes without slivers
+        if not append:
+            for node in nodes_without_slivers:
+                node_elem = self.get_node_element(node)
+                parent = node_elem.getparent()
+                parent.remove(node_elem)
 
     def remove_slivers(self, slivers, network=None, no_dupes=False):
-        if not isinstance(slivers, list):
-            slivers = [slivers]
         for sliver in slivers:
             node_elem = self.get_node_element(sliver['hostname'], network)
-            sliver_elem = node.find('sliver')
+            sliver_elem = node_elem.find('sliver')
             if sliver_elem != None:
-                node_elem.remove(sliver)                 
-    
+                node_elem.remove(sliver_elem)
+
     def add_default_sliver_attribute(self, name, value, network=None):
         if network:
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
         else:
             defaults = self.xml.xpath("//sliver_defaults" % network)
-        if defaults is None:
-            defaults = etree.Element("sliver_defaults")
-            network = self.xml.xpath("//network[@name='%s']" % network)
-            network.insert(0, defaults)
-        self.add_attribute(defaults, name, value)
+        if not defaults :
+            network_tag = self.xml.xpath("//network[@name='%s']" % network)
+            if isinstance(network_tag, list):
+                network_tag = network_tag[0]
+            defaults = self.xml.add_element('sliver_defaults', attrs={}, parent=network_tag)
+        elif isinstance(defaults, list):
+            defaults = defaults[0]
+        self.xml.add_attribute(defaults, name, value)
 
     def add_sliver_attribute(self, hostname, name, value, network=None):
         node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
-        self.add_attribute(sliver, name, value)
+        self.xml.add_attribute(sliver, name, value)
 
     def remove_default_sliver_attribute(self, name, value, network=None):
         if network:
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
         else:
             defaults = self.xml.xpath("//sliver_defaults" % network)
-        self.remove_attribute(defaults, name, value)
+        self.xml.remove_attribute(defaults, name, value)
 
     def remove_sliver_attribute(self, hostname, name, value, network=None):
         node = self.get_node_element(hostname, network)
         sliver = node.find("sliver")
-        self.remove_attribute(sliver, name, value)
+        self.xml.remove_attribute(sliver, name, value)
 
     def add_vlink(self, fromhost, tohost, kbps, network=None):
         fromnode = self.get_node_element(fromhost, network)
@@ -281,7 +319,7 @@ class SfaRSpec(RSpec):
             fromid = fromnode.get("id")
             toid = tonode.get("id")
             vlink.set("endpoints", "%s %s" % (fromid, toid))
-            self.add_attribute(vlink, "kbps", kbps)
+            self.xml.add_attribute(vlink, "kbps", kbps)
 
 
     def remove_vlink(self, endpoints, network=None):
@@ -292,39 +330,31 @@ class SfaRSpec(RSpec):
 
     def merge(self, in_rspec):
         """
-        Merge contents for specified rspec with current rspec 
+        Merge contents for specified rspec with current rspec
         """
 
+        from sfa.rspecs.rspec import RSpec
+        if isinstance(in_rspec, RSpec):
+            rspec = in_rspec
+        else:
+            rspec = RSpec(in_rspec)
+        if rspec.version.type.lower() == 'protogeni':
+            from sfa.rspecs.rspec_converter import RSpecConverter
+            in_rspec = RSpecConverter.to_sfa_rspec(rspec.toxml())
+            rspec = RSpec(in_rspec)
+
         # just copy over all networks
         current_networks = self.get_networks()
-        rspec = SfaRSpec(rspec=in_rspec)
-        networks = rspec.get_network_elements()
+        networks = rspec.version.get_network_elements()
         for network in networks:
             current_network = network.get('name')
-            if not current_network in current_networks:
-                self.xml.append(network)
+            if current_network and current_network not in current_networks:
+                self.xml.root.append(network)
                 current_networks.append(current_network)
-        
-         
 
 if __name__ == '__main__':
-    rspec = SfaRSpec()
-    nodes = [
-    {'network': 'plc',
-     'hostname': 'node1.planet-lab.org',
-     'site_urn': 'urn:publicid:IDN+plc+authority+cm',
-      'node_id': 1,
-    },
-    {'network': 'plc',
-     'hostname': 'node2.planet-lab.org',
-     'site_urn': 'urn:publicid:IDN+plc+authority+cm',
-      'node_id': 1,
-    },
-    {'network': 'ple',
-     'hostname': 'node1.planet-lab.eu',
-     'site_urn': 'urn:publicid:IDN+plc+authority+cm',
-      'node_id': 1,
-    },
-    ]
-    rspec.add_nodes(nodes)
-    print rspec
+    from sfa.rspecs.rspec import RSpec
+    from sfa.rspecs.rspec_elements import *
+    r = RSpec('/tmp/resources.rspec')
+    r.load_rspec_elements(SFAv1.elements)
+    print r.get(RSpecElements.NODE)
index 1a96e15..59a3e6b 100644 (file)
@@ -1,7 +1,8 @@
 from sfa.util.faults import *
 from sfa.util.server import SfaServer
 from sfa.util.xrn import hrn_to_urn
-from sfa.server.interface import Interfaces
+from sfa.server.interface import Interfaces, Interface
+from sfa.util.config import Config     
 
 class Aggregate(SfaServer):
 
@@ -22,19 +23,13 @@ class Aggregates(Interfaces):
 
     default_dict = {'aggregates': {'aggregate': [Interfaces.default_fields]}}
  
-    def __init__(self, api, conf_file = "/etc/sfa/aggregates.xml"):
-        Interfaces.__init__(self, api, conf_file)
+    def __init__(self, conf_file = "/etc/sfa/aggregates.xml"):
+        Interfaces.__init__(self, conf_file)
+        sfa_config = Config() 
         # set up a connection to the local aggregate
-        if self.api.config.SFA_AGGREGATE_ENABLED:
-            address = self.api.config.SFA_AGGREGATE_HOST
-            port = self.api.config.SFA_AGGREGATE_PORT
-            url = 'http://%(address)s:%(port)s' % locals()
-            local_aggregate = {'hrn': self.api.hrn,
-                               'urn': hrn_to_urn(self.api.hrn, 'authority'),
-                               'addr': address,
-                               'port': port,
-                               'url': url}
-            self.interfaces[self.api.hrn] = local_aggregate
-
-        # get connections
-        self.update(self.get_connections())
+        if sfa_config.SFA_AGGREGATE_ENABLED:
+            addr = sfa_config.SFA_AGGREGATE_HOST
+            port = sfa_config.SFA_AGGREGATE_PORT
+            hrn = sfa_config.SFA_INTERFACE_HRN
+            interface = Interface(hrn, addr, port)
+            self[hrn] = interface
index c83ac4a..1dc6652 100644 (file)
@@ -1,10 +1,6 @@
 #
 # Component is a SfaServer that implements the Component interface
 #
-### $Id: 
-### $URL: 
-#
-
 import tempfile
 import os
 import time
index a3b06ef..dbc8ef2 100644 (file)
@@ -1,7 +1,6 @@
 import traceback
 import os.path
 
-from sfa.util.sfalogging import sfa_logger
 from sfa.util.faults import *
 from sfa.util.storage import XmlStorage
 from sfa.util.xrn import get_authority, hrn_to_urn
@@ -17,6 +16,29 @@ except ImportError:
     GeniClientLight = None            
 
 
+
+class Interface:
+    
+    def __init__(self, hrn, addr, port, client_type='sfa'):
+        self.hrn = hrn
+        self.addr = addr
+        self.port = port
+        self.client_type = client_type
+  
+    def get_url(self):
+        address_parts = self.addr.split('/')
+        address_parts[0] = address_parts[0] + ":" + str(self.port)
+        url =  "http://%s" %  "/".join(address_parts)
+        return url
+
+    def get_server(self, key_file, cert_file, timeout=30):
+        server = None 
+        if  self.client_type ==  'geniclientlight' and GeniClientLight:
+            server = GeniClientLight(url, self.api.key_file, self.api.cert_file)
+        else:
+            server = xmlrpcprotocol.get_server(self.get_url(), key_file, cert_file, timeout) 
+        return server       
 ##
 # In is a dictionary of registry connections keyed on the registry
 # hrn
@@ -24,12 +46,7 @@ except ImportError:
 class Interfaces(dict):
     """
     Interfaces is a base class for managing information on the
-    peers we are federated with. It is responsible for the following:
-
-    1) Makes sure a record exist in the local registry for the each 
-       fedeated peer   
-    2) Attempts to fetch and install trusted gids   
-    3) Provides connections (xmlrpc or soap) to federated peers
+    peers we are federated with. Provides connections (xmlrpc or soap) to federated peers
     """
 
     # fields that must be specified in the config file
@@ -42,165 +59,24 @@ class Interfaces(dict):
     # defined by the class 
     default_dict = {}
 
-    types = ['authority']
-
-    def __init__(self, api, conf_file, type='authority'):
-        if type not in self.types:
-            raise SfaInfaildArgument('Invalid type %s: must be in %s' % (type, self.types))    
+    def __init__(self, conf_file):
         dict.__init__(self, {})
-        self.api = api
-        self.type = type  
         # load config file
         self.interface_info = XmlStorage(conf_file, self.default_dict)
         self.interface_info.load()
-        interfaces = self.interface_info.values()[0].values()[0]
-        if not isinstance(interfaces, list):
-            interfaces = [self.interfaces]
-        # set the url and urn 
-        for interface in interfaces:
-            # port is appended onto the domain, before the path. Should look like:
-            # http://domain:port/path
-            hrn, address, port = interface['hrn'], interface['addr'], interface['port']
-            address_parts = address.split('/')
-            address_parts[0] = address_parts[0] + ":" + str(port)
-            url =  "http://%s" %  "/".join(address_parts)
-            interface['url'] = url
-            interface['urn'] = hrn_to_urn(hrn, 'authority')
-    
-        self.interfaces = {}
-        required_fields = self.default_fields.keys()
-        for interface in interfaces:
-            valid = True
-            # skp any interface definition that has a null hrn, 
-            # address or port
-            for field in required_fields:
-                if field not in interface or not interface[field]:
-                    valid = False
-                    break
-            if valid:     
-                self.interfaces[interface['hrn']] = interface
-
-
-    def sync_interfaces(self):
-        """
-        Install missing trusted gids and db records for our federated
-        interfaces
-        """     
-        # Attempt to get any missing peer gids
-        # There should be a gid file in /etc/sfa/trusted_roots for every
-        # peer registry found in in the registries.xml config file. If there
-        # are any missing gids, request a new one from the peer registry.
-        gids_current = self.api.auth.trusted_cert_list
-        hrns_current = [gid.get_hrn() for gid in gids_current] 
-        hrns_expected = self.interfaces.keys() 
-        new_hrns = set(hrns_expected).difference(hrns_current)
-        gids = self.get_peer_gids(new_hrns) + gids_current
-        # make sure there is a record for every gid
-        self.update_db_records(self.type, gids)
-        
-    def get_peer_gids(self, new_hrns):
-        """
-        Install trusted gids from the specified interfaces.  
-        """
-        peer_gids = []
-        if not new_hrns:
-            return peer_gids
-        trusted_certs_dir = self.api.config.get_trustedroots_dir()
-        for new_hrn in new_hrns:
-            if not new_hrn:
-                continue
-            # the gid for this interface should already be installed  
-            if new_hrn == self.api.config.SFA_INTERFACE_HRN:
-                continue
-            try:
-                # get gid from the registry
-                interface_info =  self.interfaces[new_hrn]
-                interface = self[new_hrn]
-                trusted_gids = interface.get_trusted_certs()
-                if trusted_gids:
-                    # the gid we want shoudl be the first one in the list, 
-                    # but lets make sure
-                    for trusted_gid in trusted_gids:
-                        # default message
-                        message = "interface: %s\t" % (self.api.interface)
-                        message += "unable to install trusted gid for %s" % \
-                                   (new_hrn) 
-                        gid = GID(string=trusted_gids[0])
-                        peer_gids.append(gid) 
-                        if gid.get_hrn() == new_hrn:
-                            gid_filename = os.path.join(trusted_certs_dir, '%s.gid' % new_hrn)
-                            gid.save_to_file(gid_filename, save_parents=True)
-                            message = "interface: %s\tinstalled trusted gid for %s" % \
-                                (self.api.interface, new_hrn)
-                        # log the message
-                        self.api.logger.info(message)
-            except:
-                message = "interface: %s\tunable to install trusted gid for %s" % \
-                            (self.api.interface, new_hrn) 
-                self.api.logger.log_exc(message)
-        
-        # reload the trusted certs list
-        self.api.auth.load_trusted_certs()
-        return peer_gids
-
-    def update_db_records(self, type, gids):
-        """
-        Make sure there is a record in the local db for allowed registries
-        defined in the config file (registries.xml). Removes old records from
-        the db.         
-        """
-        # import SfaTable here so this module can be loaded by ComponentAPI 
-        from sfa.util.table import SfaTable
-        if not gids: 
-            return
+        records = self.interface_info.values()[0]
+        if not isinstance(records, list):
+            records = [records]
         
-        # hrns that should have a record
-        hrns_expected = [gid.get_hrn() for gid in gids]
-
-        # get hrns that actually exist in the db
-        table = SfaTable()
-        records = table.find({'type': type, 'pointer': -1})
-        hrns_found = [record['hrn'] for record in records]
-      
-        # remove old records
-        for record in records:
-            if record['hrn'] not in hrns_expected and \
-                record['hrn'] != self.api.config.SFA_INTERFACE_HRN:
-                table.remove(record)
-
-        # add new records
-        for gid in gids:
-            hrn = gid.get_hrn()
-            if hrn not in hrns_found:
-                record = {
-                    'hrn': hrn,
-                    'type': type,
-                    'pointer': -1, 
-                    'authority': get_authority(hrn),
-                    'gid': gid.save_to_string(save_parents=True),
-                }
-                record = SfaRecord(dict=record)
-                table.insert(record)
-                        
-    def get_connections(self):
-        """
-        read connection details for the trusted peer registries from file return 
-        a dictionary of connections keyed on interface hrn. 
-        """
-        connections = {}
         required_fields = self.default_fields.keys()
-        for interface in self.interfaces.values():
-            url = interface['url']
-#            sfa_logger().debug("Interfaces.get_connections - looping on neighbour %s"%url)
-            # check which client we should use
-            # sfa.util.xmlrpcprotocol is default
-            client_type = 'xmlrpcprotocol'
-            if interface.has_key('client') and \
-               interface['client'] in ['geniclientlight'] and \
-               GeniClientLight:
-                client_type = 'geniclientlight'
-                connections[hrn] = GeniClientLight(url, self.api.key_file, self.api.cert_file) 
-            else:
-                connections[interface['hrn']] = xmlrpcprotocol.get_server(url, self.api.key_file, self.api.cert_file)
+        for record in records:
+            if not record or not set(required_fields).issubset(record.keys()):
+                continue
+            # port is appended onto the domain, before the path. Should look like:
+            # http://domain:port/path
+            hrn, address, port = record['hrn'], record['addr'], record['port']
+            interface = Interface(hrn, address, port) 
+            self[hrn] = interface
 
-        return connections 
+    def get_server(self, hrn, key_file, cert_file, timeout=30):
+        return self[hrn].get_server(key_file, cert_file, timeout)
index bb61bb7..deaf89f 100755 (executable)
@@ -13,7 +13,7 @@ import xmlrpclib
 from mod_python import apache
 
 from sfa.plc.api import SfaAPI
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 api = SfaAPI(interface='aggregate')
 
@@ -53,5 +53,5 @@ def handler(req):
 
     except Exception, err:
         # Log error in /var/log/httpd/(ssl_)?error_log
-        sfa_logger().log_exc('%r'%err)
+        logger.log_exc('%r'%err)
         return apache.HTTP_INTERNAL_SERVER_ERROR
index 0c46084..8879813 100755 (executable)
@@ -13,7 +13,7 @@ import xmlrpclib
 from mod_python import apache
 
 from sfa.plc.api import SfaAPI
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 api = SfaAPI(interface='registry')
 
@@ -53,5 +53,5 @@ def handler(req):
 
     except Exception, err:
         # Log error in /var/log/httpd/(ssl_)?error_log
-        sfa_logger().log_exc('%r'%err)
+        logger.log_exc('%r'%err)
         return apache.HTTP_INTERNAL_SERVER_ERROR
index a28b002..e0f2b92 100755 (executable)
@@ -13,7 +13,7 @@ import xmlrpclib
 from mod_python import apache
 
 from sfa.plc.api import SfaAPI
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 api = SfaAPI(interface='slicemgr')
 
@@ -53,5 +53,5 @@ def handler(req):
 
     except Exception, err:
         # Log error in /var/log/httpd/(ssl_)?error_log
-        sfa_logger().log_exc('%r'%err)
+        logger.log_exc('%r'%err)
         return apache.HTTP_INTERNAL_SERVER_ERROR
index b7bfdd8..b254811 100644 (file)
@@ -1,17 +1,11 @@
 #
 # Registry is a SfaServer that implements the Registry interface
 #
-### $Id$
-### $URL$
-#
-
 from sfa.util.server import SfaServer
 from sfa.util.faults import *
 from sfa.util.xrn import hrn_to_urn
-from sfa.server.interface import Interfaces
-import sfa.util.xmlrpcprotocol as xmlrpcprotocol
-import sfa.util.soapprotocol as soapprotocol
+from sfa.server.interface import Interfaces, Interface
+from sfa.util.config import Config 
 
 ##
 # Registry is a SfaServer that serves registry and slice operations at PLC.
@@ -35,17 +29,12 @@ class Registries(Interfaces):
     
     default_dict = {'registries': {'registry': [Interfaces.default_fields]}}
 
-    def __init__(self, api, conf_file = "/etc/sfa/registries.xml"):
-        Interfaces.__init__(self, api, conf_file) 
-        address = self.api.config.SFA_REGISTRY_HOST
-        port = self.api.config.SFA_REGISTRY_PORT
-        url = 'http://%(address)s:%(port)s' % locals()
-        local_registry = {'hrn': self.api.hrn,
-                           'urn': hrn_to_urn(self.api.hrn, 'authority'),
-                           'addr': address,
-                           'port': port,
-                           'url': url}
-        self.interfaces[self.api.hrn] = local_registry
-       
-        # get connections
-        self.update(self.get_connections()) 
+    def __init__(self, conf_file = "/etc/sfa/registries.xml"):
+        Interfaces.__init__(self, conf_file) 
+        sfa_config = Config() 
+        if sfa_config.SFA_REGISTRY_ENABLED:
+            addr = sfa_config.SFA_REGISTRY_HOST
+            port = sfa_config.SFA_REGISTRY_PORT
+            hrn = sfa_config.SFA_INTERFACE_HRN
+            interface = Interface(hrn, addr, port)
+            self[hrn] = interface
index 984b41c..fadb1d3 100755 (executable)
@@ -1,13 +1,12 @@
 #!/usr/bin/python
 #
-# SFA PLC Wrapper
+# PlanetLab SFA implementation
 #
-# This wrapper implements the SFA Registry and Slice Interfaces on PLC.
+# This implements the SFA Registry and Slice Interfaces on PLC.
 # Depending on command line options, it starts some combination of a
 # Registry, an Aggregate Manager, and a Slice Manager.
 #
-# There are several items that need to be done before starting the wrapper
-# server.
+# There are several items that need to be done before starting the servers.
 #
 # NOTE:  Many configuration settings, including the PLC maintenance account
 # credentials, URI of the PLCAPI, and PLC DB URI and admin credentials are initialized
@@ -35,10 +34,10 @@ component_port=12346
 import os, os.path
 import traceback
 import sys
+import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from optparse import OptionParser
 
-from sfa.util.sfalogging import sfa_logger
-from sfa.trust.trustedroot import TrustedRootList
+from sfa.util.sfalogging import logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.hierarchy import Hierarchy
 from sfa.trust.gid import GID
@@ -46,6 +45,10 @@ from sfa.util.config import Config
 from sfa.plc.api import SfaAPI
 from sfa.server.registry import Registries
 from sfa.server.aggregate import Aggregates
+from sfa.util.xrn import get_authority, hrn_to_urn
+from sfa.util.sfalogging import logger
+
+from sfa.managers.import_manager import import_manager
 
 # after http://www.erlenstar.demon.co.uk/unix/faq_2.html
 def daemon():
@@ -83,8 +86,9 @@ def init_server_key(server_key_file, server_cert_file, config, hierarchy):
         if not os.path.exists(key_file):
             # if it doesnt exist then this is probably a fresh interface
             # with no records. Generate a random keypair for now
-            sfa_logger().debug("server's public key not found in %s" % key_file)
-            sfa_logger().debug("generating a random server key pair")
+            logger.debug("server's public key not found in %s" % key_file)
+
+            logger.debug("generating a random server key pair")
             key = Keypair(create=True)
             key.save_to_file(server_key_file)
             init_server_cert(hrn, key, server_cert_file, self_signed=True)    
@@ -113,18 +117,18 @@ def init_server_cert(hrn, key, server_cert_file, self_signed=False):
     else:
         try:
             # look for gid file
-            sfa_logger().debug("generating server cert from gid: %s"% hrn)
+            logger.debug("generating server cert from gid: %s"% hrn)
             hierarchy = Hierarchy()
             auth_info = hierarchy.get_auth_info(hrn)
             gid = GID(filename=auth_info.gid_filename)
             gid.save_to_file(filename=server_cert_file)
         except:
             # fall back to self signed cert
-            sfa_logger().debug("gid for %s not found" % hrn)
+            logger.debug("gid for %s not found" % hrn)
             init_self_signed_cert(hrn, key, server_cert_file)        
         
 def init_self_signed_cert(hrn, key, server_cert_file):
-    sfa_logger().debug("generating self signed cert")
+    logger.debug("generating self signed cert")
     # generate self signed certificate
     cert = Certificate(subject=hrn)
     cert.set_issuer(key=key, subject=hrn)
@@ -134,43 +138,122 @@ def init_self_signed_cert(hrn, key, server_cert_file):
 
 def init_server(options, config):
     """
-    Execute the init method defined in the manager file 
+    Locate the manager based on config.*TYPE
+    Execute the init_server method (well in fact function, sigh) if defined in that module
+    In order to migrate to a more generic approach:
+    * search for <>_manager_<type>.py
+    * if not found, try <>_manager.py (and issue a warning if <type>!='pl')
     """
-    def init_manager(manager_module, manager_base):
-        try: manager = __import__(manager_module, fromlist=[manager_base])
-        except: manager = None
-        if manager and hasattr(manager, 'init_server'):
-            manager.init_server()
-    
-    manager_base = 'sfa.managers'
     if options.registry:
-        mgr_type = config.SFA_REGISTRY_TYPE
-        manager_module = manager_base + ".registry_manager_%s" % mgr_type
-        init_manager(manager_module, manager_base)    
-    if options.am:
-        mgr_type = config.SFA_AGGREGATE_TYPE
-        manager_module = manager_base + ".aggregate_manager_%s" % mgr_type
-        init_manager(manager_module, manager_base)    
+        manager=import_manager ("registry",       config.SFA_REGISTRY_TYPE)
+        if manager and hasattr(manager, 'init_server'): manager.init_server()
+    if options.am:      
+        manager=import_manager ("aggregate",      config.SFA_AGGREGATE_TYPE)
+        if manager and hasattr(manager, 'init_server'): manager.init_server()
     if options.sm:
-        mgr_type = config.SFA_SM_TYPE
-        manager_module = manager_base + ".slice_manager_%s" % mgr_type
-        init_manager(manager_module, manager_base)    
+        manager=import_manager ("slice",          config.SFA_SM_TYPE)
+        if manager and hasattr(manager, 'init_server'): manager.init_server()
     if options.cm:
-        mgr_type = config.SFA_CM_TYPE
-        manager_module = manager_base + ".component_manager_%s" % mgr_type
-        init_manager(manager_module, manager_base)    
+        manager=import_manager ("component",      config.SFA_CM_TYPE)
+        if manager and hasattr(manager, 'init_server'): manager.init_server()
+
 
-def sync_interfaces(server_key_file, server_cert_file):
+def install_peer_certs(server_key_file, server_cert_file):
     """
     Attempt to install missing trusted gids and db records for 
     our federated interfaces
     """
+    # Attempt to get any missing peer gids
+    # There should be a gid file in /etc/sfa/trusted_roots for every
+    # peer registry found in in the registries.xml config file. If there
+    # are any missing gids, request a new one from the peer registry.
     api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file)
-    registries = Registries(api)
-    aggregates = Aggregates(api)
-    registries.sync_interfaces()
-    aggregates.sync_interfaces()
+    registries = Registries()
+    aggregates = Aggregates()
+    interfaces = dict(registries.items() + aggregates.items())
+    gids_current = api.auth.trusted_cert_list
+    hrns_current = [gid.get_hrn() for gid in gids_current]
+    hrns_expected = set([hrn for hrn in interfaces])
+    new_hrns = set(hrns_expected).difference(hrns_current)
+    #gids = self.get_peer_gids(new_hrns) + gids_current
+    peer_gids = []
+    if not new_hrns:
+        return 
+
+    trusted_certs_dir = api.config.get_trustedroots_dir()
+    for new_hrn in new_hrns:
+        if not new_hrn: continue
+        # the gid for this interface should already be installed
+        if new_hrn == api.config.SFA_INTERFACE_HRN: continue
+        try:
+            # get gid from the registry
+            url = interfaces[new_hrn].get_url()
+            interface = interfaces[new_hrn].get_server(server_key_file, server_cert_file, timeout=30)
+            # skip non sfa aggregates
+            server_version = api.get_cached_server_version(interface)
+            if 'sfa' not in server_version:
+                logger.info("get_trusted_certs: skipping non sfa aggregate: %s" % new_hrn)
+                continue
+      
+            trusted_gids = interface.get_trusted_certs()
+            if trusted_gids:
+                # the gid we want should be the first one in the list,
+                # but lets make sure
+                for trusted_gid in trusted_gids:
+                    # default message
+                    message = "interface: %s\t" % (api.interface)
+                    message += "unable to install trusted gid for %s" % \
+                               (new_hrn)
+                    gid = GID(string=trusted_gids[0])
+                    peer_gids.append(gid)
+                    if gid.get_hrn() == new_hrn:
+                        gid_filename = os.path.join(trusted_certs_dir, '%s.gid' % new_hrn)
+                        gid.save_to_file(gid_filename, save_parents=True)
+                        message = "installed trusted cert for %s" % new_hrn
+                    # log the message
+                    api.logger.info(message)
+        except:
+            message = "interface: %s\tunable to install trusted gid for %s" % \
+                        (api.interface, new_hrn)
+            api.logger.log_exc(message)
+    # doesnt matter witch one
+    update_cert_records(peer_gids)
+
+def update_cert_records(gids):
+    """
+    Make sure there is a record in the registry for the specified gids. 
+    Removes old records from the db.
+    """
+    # import SfaTable here so this module can be loaded by ComponentAPI
+    from sfa.util.table import SfaTable
+    from sfa.util.record import SfaRecord
+    if not gids:
+        return
+    table = SfaTable()
+    # get records that actually exist in the db
+    gid_urns = [gid.get_urn() for gid in gids]
+    hrns_expected = [gid.get_hrn() for gid in gids]
+    records_found = table.find({'hrn': hrns_expected, 'pointer': -1}) 
 
+    # remove old records
+    for record in records_found:
+        if record['hrn'] not in hrns_expected and \
+            record['hrn'] != self.api.config.SFA_INTERFACE_HRN:
+            table.remove(record)
+
+    # TODO: store urn in the db so we do this in 1 query 
+    for gid in gids:
+        hrn, type = gid.get_hrn(), gid.get_type()
+        record = table.find({'hrn': hrn, 'type': type, 'pointer': -1})
+        if not record:
+            record = {
+                'hrn': hrn, 'type': type, 'pointer': -1,
+                'authority': get_authority(hrn),
+                'gid': gid.save_to_string(save_parents=True),
+            }
+            record = SfaRecord(dict=record)
+            table.insert(record)
+        
 def main():
     # Generate command line parser
     parser = OptionParser(usage="sfa-server [options]")
@@ -182,24 +265,28 @@ def main():
          help="run aggregate manager", default=False)
     parser.add_option("-c", "--component", dest="cm", action="store_true",
          help="run component server", default=False)
+    parser.add_option("-t", "--trusted-certs", dest="trusted_certs", action="store_true",
+         help="refresh trusted certs", default=False)
     parser.add_option("-v", "--verbose", action="count", dest="verbose", default=0,
          help="verbose mode - cumulative")
     parser.add_option("-d", "--daemon", dest="daemon", action="store_true",
          help="Run as daemon.", default=False)
     (options, args) = parser.parse_args()
-    sfa_logger().setLevelFromOptVerbose(options.verbose)
-
+    
     config = Config()
-    if config.SFA_API_DEBUG: sfa_logger().setLevelDebug()
+    if config.SFA_API_DEBUG: pass
     hierarchy = Hierarchy()
     server_key_file = os.path.join(hierarchy.basedir, "server.key")
     server_cert_file = os.path.join(hierarchy.basedir, "server.cert")
 
     init_server_key(server_key_file, server_cert_file, config, hierarchy)
     init_server(options, config)
-    sync_interfaces(server_key_file, server_cert_file)   
  
     if (options.daemon):  daemon()
+    
+    if options.trusted_certs:
+        install_peer_certs(server_key_file, server_cert_file)   
+    
     # start registry server
     if (options.registry):
         from sfa.server.registry import Registry
@@ -227,4 +314,4 @@ if __name__ == "__main__":
     try:
         main()
     except:
-        sfa_logger().log_exc_critical("SFA server is exiting")
+        logger.log_exc_critical("SFA server is exiting")
index 8842eae..c0fbd6a 100644 (file)
@@ -1,6 +1,3 @@
-### $Id$
-### $URL$
-
 import os
 import sys
 import datetime
index 218783e..41c71cf 100644 (file)
@@ -5,14 +5,14 @@ import sys
 
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
-from sfa.trust.trustedroot import TrustedRootList
+from sfa.trust.trustedroots import TrustedRoots
 from sfa.util.faults import *
 from sfa.trust.hierarchy import Hierarchy
 from sfa.util.config import *
 from sfa.util.xrn import get_authority
 from sfa.util.sfaticket import *
 
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 class Auth:
     """
@@ -27,8 +27,8 @@ class Auth:
         self.load_trusted_certs()
 
     def load_trusted_certs(self):
-        self.trusted_cert_list = TrustedRootList(self.config.get_trustedroots_dir()).get_list()
-        self.trusted_cert_file_list = TrustedRootList(self.config.get_trustedroots_dir()).get_file_list()
+        self.trusted_cert_list = TrustedRoots(self.config.get_trustedroots_dir()).get_list()
+        self.trusted_cert_file_list = TrustedRoots(self.config.get_trustedroots_dir()).get_file_list()
 
         
         
@@ -36,14 +36,14 @@ class Auth:
         valid = []
         if not isinstance(creds, list):
             creds = [creds]
-        sfa_logger().debug("Auth.checkCredentials with %d creds"%len(creds))
+        logger.debug("Auth.checkCredentials with %d creds"%len(creds))
         for cred in creds:
             try:
                 self.check(cred, operation, hrn)
                 valid.append(cred)
             except:
                 cred_obj=Credential(string=cred)
-                sfa_logger().debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
+                logger.debug("failed to validate credential - dump=%s"%cred_obj.dump_string(dump_parents=True))
                 error = sys.exc_info()[:2]
                 continue
             
@@ -303,7 +303,7 @@ class Auth:
     def get_authority(self, hrn):
         return get_authority(hrn)
 
-    def filter_creds_by_caller(self, creds, caller_hrn):
+    def filter_creds_by_caller(self, creds, caller_hrn_list):
         """
         Returns a list of creds who's gid caller matches the 
         specified caller hrn
@@ -311,10 +311,12 @@ class Auth:
         if not isinstance(creds, list):
             creds = [creds]
         creds = []
+        if not isinistance(caller_hrn_list, list):
+            caller_hrn_list = [caller_hrn_list]
         for cred in creds:
             try:
                 tmp_cred = Credential(string=cred)
-                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
+                if tmp_cred.get_gid_caller().get_hrn() in [caller_hrn_list]:
                     creds.append(cred)
             except: pass
         return creds
index 9a2e862..bcec9d6 100644 (file)
-#----------------------------------------------------------------------
-# Copyright (c) 2008 Board of Trustees, Princeton University
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and/or hardware specification (the "Work") to
-# deal in the Work without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Work, and to permit persons to whom the Work
-# is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Work.
-#
-# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
-# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
-# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
-# IN THE WORK.
-#----------------------------------------------------------------------
-
-##
-# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement
-# the necessary crypto functionality. Ideally just one of these libraries
-# would be used, but unfortunately each of these libraries is independently
-# lacking. The pyOpenSSL library is missing many necessary functions, and
-# the M2Crypto library has crashed inside of some of the functions. The
-# design decision is to use pyOpenSSL whenever possible as it seems more
-# stable, and only use M2Crypto for those functions that are not possible
-# in pyOpenSSL.
-#
-# This module exports two classes: Keypair and Certificate.
-##
-#
-
-import functools
-import os
-import tempfile
-import base64
-import traceback
-from tempfile import mkstemp
-
-from OpenSSL import crypto
-import M2Crypto
-from M2Crypto import X509
-
-from sfa.util.sfalogging import sfa_logger
-from sfa.util.xrn import urn_to_hrn
-from sfa.util.faults import *
-
-glo_passphrase_callback = None
-
-##
-# A global callback msy be implemented for requesting passphrases from the
-# user. The function will be called with three arguments:
-#
-#    keypair_obj: the keypair object that is calling the passphrase
-#    string: the string containing the private key that's being loaded
-#    x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto
-#
-# The callback should return a string containing the passphrase.
-
-def set_passphrase_callback(callback_func):
-    global glo_passphrase_callback
-
-    glo_passphrase_callback = callback_func
-
-##
-# Sets a fixed passphrase.
-
-def set_passphrase(passphrase):
-    set_passphrase_callback( lambda k,s,x: passphrase )
-
-##
-# Check to see if a passphrase works for a particular private key string.
-# Intended to be used by passphrase callbacks for input validation.
-
-def test_passphrase(string, passphrase):
-    try:
-        crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))
-        return True
-    except:
-        return False
-
-def convert_public_key(key):
-    keyconvert_path = "/usr/bin/keyconvert.py"
-    if not os.path.isfile(keyconvert_path):
-        raise IOError, "Could not find keyconvert in %s" % keyconvert_path
-
-    # we can only convert rsa keys
-    if "ssh-dss" in key:
-        return None
-
-    (ssh_f, ssh_fn) = tempfile.mkstemp()
-    ssl_fn = tempfile.mktemp()
-    os.write(ssh_f, key)
-    os.close(ssh_f)
-
-    cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn
-    os.system(cmd)
-
-    # this check leaves the temporary file containing the public key so
-    # that it can be expected to see why it failed.
-    # TODO: for production, cleanup the temporary files
-    if not os.path.exists(ssl_fn):
-        return None
-
-    k = Keypair()
-    try:
-        k.load_pubkey_from_file(ssl_fn)
-    except:
-        sfa_logger().log_exc("convert_public_key caught exception")
-        k = None
-
-    # remove the temporary files
-    os.remove(ssh_fn)
-    os.remove(ssl_fn)
-
-    return k
-
-##
-# Public-private key pairs are implemented by the Keypair class.
-# A Keypair object may represent both a public and private key pair, or it
-# may represent only a public key (this usage is consistent with OpenSSL).
-
-class Keypair:
-    key = None       # public/private keypair
-    m2key = None     # public key (m2crypto format)
-
-    ##
-    # Creates a Keypair object
-    # @param create If create==True, creates a new public/private key and
-    #     stores it in the object
-    # @param string If string!=None, load the keypair from the string (PEM)
-    # @param filename If filename!=None, load the keypair from the file
-
-    def __init__(self, create=False, string=None, filename=None):
-        if create:
-            self.create()
-        if string:
-            self.load_from_string(string)
-        if filename:
-            self.load_from_file(filename)
-
-    ##
-    # Create a RSA public/private key pair and store it inside the keypair object
-
-    def create(self):
-        self.key = crypto.PKey()
-        self.key.generate_key(crypto.TYPE_RSA, 1024)
-
-    ##
-    # Save the private key to a file
-    # @param filename name of file to store the keypair in
-
-    def save_to_file(self, filename):
-        open(filename, 'w').write(self.as_pem())
-        self.filename=filename
-
-    ##
-    # Load the private key from a file. Implicity the private key includes the public key.
-
-    def load_from_file(self, filename):
-        self.filename=filename
-        buffer = open(filename, 'r').read()
-        self.load_from_string(buffer)
-
-    ##
-    # Load the private key from a string. Implicitly the private key includes the public key.
-
-    def load_from_string(self, string):
-        if glo_passphrase_callback:
-            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )
-            self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )
-        else:
-            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)
-            self.m2key = M2Crypto.EVP.load_key_string(string)
-
-    ##
-    #  Load the public key from a string. No private key is loaded.
-
-    def load_pubkey_from_file(self, filename):
-        # load the m2 public key
-        m2rsakey = M2Crypto.RSA.load_pub_key(filename)
-        self.m2key = M2Crypto.EVP.PKey()
-        self.m2key.assign_rsa(m2rsakey)
-
-        # create an m2 x509 cert
-        m2name = M2Crypto.X509.X509_Name()
-        m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)
-        m2x509 = M2Crypto.X509.X509()
-        m2x509.set_pubkey(self.m2key)
-        m2x509.set_serial_number(0)
-        m2x509.set_issuer_name(m2name)
-        m2x509.set_subject_name(m2name)
-        ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()
-        ASN1.set_time(500)
-        m2x509.set_not_before(ASN1)
-        m2x509.set_not_after(ASN1)
-        # x509v3 so it can have extensions
-        # prob not necc since this cert itself is junk but still...
-        m2x509.set_version(2)
-        junk_key = Keypair(create=True)
-        m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")
-
-        # convert the m2 x509 cert to a pyopenssl x509
-        m2pem = m2x509.as_pem()
-        pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)
-
-        # get the pyopenssl pkey from the pyopenssl x509
-        self.key = pyx509.get_pubkey()
-        self.filename=filename
-
-    ##
-    # Load the public key from a string. No private key is loaded.
-
-    def load_pubkey_from_string(self, string):
-        (f, fn) = tempfile.mkstemp()
-        os.write(f, string)
-        os.close(f)
-        self.load_pubkey_from_file(fn)
-        os.remove(fn)
-
-    ##
-    # Return the private key in PEM format.
-
-    def as_pem(self):
-        return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)
-
-    ##
-    # Return an M2Crypto key object
-
-    def get_m2_pkey(self):
-        if not self.m2key:
-            self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())
-        return self.m2key
-
-    ##
-    # Returns a string containing the public key represented by this object.
-
-    def get_pubkey_string(self):
-        m2pkey = self.get_m2_pkey()
-        return base64.b64encode(m2pkey.as_der())
-
-    ##
-    # Return an OpenSSL pkey object
-
-    def get_openssl_pkey(self):
-        return self.key
-
-    ##
-    # Given another Keypair object, return TRUE if the two keys are the same.
-
-    def is_same(self, pkey):
-        return self.as_pem() == pkey.as_pem()
-
-    def sign_string(self, data):
-        k = self.get_m2_pkey()
-        k.sign_init()
-        k.sign_update(data)
-        return base64.b64encode(k.sign_final())
-
-    def verify_string(self, data, sig):
-        k = self.get_m2_pkey()
-        k.verify_init()
-        k.verify_update(data)
-        return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)
-
-    def compute_hash(self, value):
-        return self.sign_string(str(value))
-
-    # only informative
-    def get_filename(self):
-        return getattr(self,'filename',None)
-
-    def dump (self, *args, **kwargs):
-        print self.dump_string(*args, **kwargs)
-
-    def dump_string (self):
-        result=""
-        result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()
-        filename=self.get_filename()
-        if filename: result += "Filename %s\n"%filename
-        return result
-    
-##
-# The certificate class implements a general purpose X509 certificate, making
-# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds
-# several addition features, such as the ability to maintain a chain of
-# parent certificates, and storage of application-specific data.
-#
-# Certificates include the ability to maintain a chain of parents. Each
-# certificate includes a pointer to it's parent certificate. When loaded
-# from a file or a string, the parent chain will be automatically loaded.
-# When saving a certificate to a file or a string, the caller can choose
-# whether to save the parent certificates as well.
-
-class Certificate:
-    digest = "md5"
-
-    cert = None
-    issuerKey = None
-    issuerSubject = None
-    parent = None
-
-    separator="-----parent-----"
-
-    ##
-    # Create a certificate object.
-    #
-    # @param create If create==True, then also create a blank X509 certificate.
-    # @param subject If subject!=None, then create a blank certificate and set
-    #     it's subject name.
-    # @param string If string!=None, load the certficate from the string.
-    # @param filename If filename!=None, load the certficiate from the file.
-
-    def __init__(self, create=False, subject=None, string=None, filename=None, intermediate=None):
-        self.data = {}
-        if create or subject:
-            self.create()
-        if subject:
-            self.set_subject(subject)
-        if string:
-            self.load_from_string(string)
-        if filename:
-            self.load_from_file(filename)
-
-        if intermediate:
-            self.set_intermediate_ca(intermediate)
-
-    ##
-    # Create a blank X509 certificate and store it in this object.
-
-    def create(self):
-        self.cert = crypto.X509()
-        self.cert.set_serial_number(3)
-        self.cert.gmtime_adj_notBefore(0)
-        self.cert.gmtime_adj_notAfter(60*60*24*365*5) # five years
-        self.cert.set_version(2) # x509v3 so it can have extensions        
-
-
-    ##
-    # Given a pyOpenSSL X509 object, store that object inside of this
-    # certificate object.
-
-    def load_from_pyopenssl_x509(self, x509):
-        self.cert = x509
-
-    ##
-    # Load the certificate from a string
-
-    def load_from_string(self, string):
-        # if it is a chain of multiple certs, then split off the first one and
-        # load it (support for the ---parent--- tag as well as normal chained certs)
-
-        string = string.strip()
-        
-        # If it's not in proper PEM format, wrap it
-        if string.count('-----BEGIN CERTIFICATE') == 0:
-            string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string
-
-        # If there is a PEM cert in there, but there is some other text first
-        # such as the text of the certificate, skip the text
-        beg = string.find('-----BEGIN CERTIFICATE')
-        if beg > 0:
-            # skipping over non cert beginning                                                                                                              
-            string = string[beg:]
-
-        parts = []
-
-        if string.count('-----BEGIN CERTIFICATE-----') > 1 and \
-               string.count(Certificate.separator) == 0:
-            parts = string.split('-----END CERTIFICATE-----',1)
-            parts[0] += '-----END CERTIFICATE-----'
-        else:
-            parts = string.split(Certificate.separator, 1)
-
-        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])
-
-        # if there are more certs, then create a parent and let the parent load
-        # itself from the remainder of the string
-        if len(parts) > 1 and parts[1] != '':
-            self.parent = self.__class__()
-            self.parent.load_from_string(parts[1])
-
-    ##
-    # Load the certificate from a file
-
-    def load_from_file(self, filename):
-        file = open(filename)
-        string = file.read()
-        self.load_from_string(string)
-        self.filename=filename
-
-    ##
-    # Save the certificate to a string.
-    #
-    # @param save_parents If save_parents==True, then also save the parent certificates.
-
-    def save_to_string(self, save_parents=True):
-        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)
-        if save_parents and self.parent:
-            string = string + self.parent.save_to_string(save_parents)
-        return string
-
-    ##
-    # Save the certificate to a file.
-    # @param save_parents If save_parents==True, then also save the parent certificates.
-
-    def save_to_file(self, filename, save_parents=True, filep=None):
-        string = self.save_to_string(save_parents=save_parents)
-        if filep:
-            f = filep
-        else:
-            f = open(filename, 'w')
-        f.write(string)
-        f.close()
-        self.filename=filename
-
-    ##
-    # Save the certificate to a random file in /tmp/
-    # @param save_parents If save_parents==True, then also save the parent certificates.
-    def save_to_random_tmp_file(self, save_parents=True):
-        fp, filename = mkstemp(suffix='cert', text=True)
-        fp = os.fdopen(fp, "w")
-        self.save_to_file(filename, save_parents=True, filep=fp)
-        return filename
-
-    ##
-    # Sets the issuer private key and name
-    # @param key Keypair object containing the private key of the issuer
-    # @param subject String containing the name of the issuer
-    # @param cert (optional) Certificate object containing the name of the issuer
-
-    def set_issuer(self, key, subject=None, cert=None):
-        self.issuerKey = key
-        if subject:
-            # it's a mistake to use subject and cert params at the same time
-            assert(not cert)
-            if isinstance(subject, dict) or isinstance(subject, str):
-                req = crypto.X509Req()
-                reqSubject = req.get_subject()
-                if (isinstance(subject, dict)):
-                    for key in reqSubject.keys():
-                        setattr(reqSubject, key, subject[key])
-                else:
-                    setattr(reqSubject, "CN", subject)
-                subject = reqSubject
-                # subject is not valid once req is out of scope, so save req
-                self.issuerReq = req
-        if cert:
-            # if a cert was supplied, then get the subject from the cert
-            subject = cert.cert.get_subject()
-        assert(subject)
-        self.issuerSubject = subject
-
-    ##
-    # Get the issuer name
-
-    def get_issuer(self, which="CN"):
-        x = self.cert.get_issuer()
-        return getattr(x, which)
-
-    ##
-    # Set the subject name of the certificate
-
-    def set_subject(self, name):
-        req = crypto.X509Req()
-        subj = req.get_subject()
-        if (isinstance(name, dict)):
-            for key in name.keys():
-                setattr(subj, key, name[key])
-        else:
-            setattr(subj, "CN", name)
-        self.cert.set_subject(subj)
-    ##
-    # Get the subject name of the certificate
-
-    def get_subject(self, which="CN"):
-        x = self.cert.get_subject()
-        return getattr(x, which)
-
-    ##
-    # Get the public key of the certificate.
-    #
-    # @param key Keypair object containing the public key
-
-    def set_pubkey(self, key):
-        assert(isinstance(key, Keypair))
-        self.cert.set_pubkey(key.get_openssl_pkey())
-
-    ##
-    # Get the public key of the certificate.
-    # It is returned in the form of a Keypair object.
-
-    def get_pubkey(self):
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        pkey = Keypair()
-        pkey.key = self.cert.get_pubkey()
-        pkey.m2key = m2x509.get_pubkey()
-        return pkey
-
-    def set_intermediate_ca(self, val):
-        self.intermediate = val
-        if val:
-            self.add_extension('basicConstraints', 1, 'CA:TRUE')
-
-
-
-    ##
-    # Add an X509 extension to the certificate. Add_extension can only be called
-    # once for a particular extension name, due to limitations in the underlying
-    # library.
-    #
-    # @param name string containing name of extension
-    # @param value string containing value of the extension
-
-    def add_extension(self, name, critical, value):
-        ext = crypto.X509Extension (name, critical, value)
-        self.cert.add_extensions([ext])
-
-    ##
-    # Get an X509 extension from the certificate
-
-    def get_extension(self, name):
-
-        # pyOpenSSL does not have a way to get extensions
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        value = m2x509.get_ext(name).get_value()
-        
-        return value
-
-    ##
-    # Set_data is a wrapper around add_extension. It stores the parameter str in
-    # the X509 subject_alt_name extension. Set_data can only be called once, due
-    # to limitations in the underlying library.
-
-    def set_data(self, str, field='subjectAltName'):
-        # pyOpenSSL only allows us to add extensions, so if we try to set the
-        # same extension more than once, it will not work
-        if self.data.has_key(field):
-            raise "Cannot set ", field, " more than once"
-        self.data[field] = str
-        self.add_extension(field, 0, str)
-
-    ##
-    # Return the data string that was previously set with set_data
-
-    def get_data(self, field='subjectAltName'):
-        if self.data.has_key(field):
-            return self.data[field]
-
-        try:
-            uri = self.get_extension(field)
-            self.data[field] = uri
-        except LookupError:
-            return None
-
-        return self.data[field]
-
-    ##
-    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().
-
-    def sign(self):
-        sfa_logger().debug('certificate.sign')
-        assert self.cert != None
-        assert self.issuerSubject != None
-        assert self.issuerKey != None
-        self.cert.set_issuer(self.issuerSubject)
-        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)
-
-    ##
-    # Verify the authenticity of a certificate.
-    # @param pkey is a Keypair object representing a public key. If Pkey
-    #     did not sign the certificate, then an exception will be thrown.
-
-    def verify(self, pkey):
-        # pyOpenSSL does not have a way to verify signatures
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        m2pkey = pkey.get_m2_pkey()
-        # verify it
-        return m2x509.verify(m2pkey)
-
-        # XXX alternatively, if openssl has been patched, do the much simpler:
-        # try:
-        #   self.cert.verify(pkey.get_openssl_key())
-        #   return 1
-        # except:
-        #   return 0
-
-    ##
-    # Return True if pkey is identical to the public key that is contained in the certificate.
-    # @param pkey Keypair object
-
-    def is_pubkey(self, pkey):
-        return self.get_pubkey().is_same(pkey)
-
-    ##
-    # Given a certificate cert, verify that this certificate was signed by the
-    # public key contained in cert. Throw an exception otherwise.
-    #
-    # @param cert certificate object
-
-    def is_signed_by_cert(self, cert):
-        k = cert.get_pubkey()
-        result = self.verify(k)
-        return result
-
-    ##
-    # Set the parent certficiate.
-    #
-    # @param p certificate object.
-
-    def set_parent(self, p):
-        self.parent = p
-
-    ##
-    # Return the certificate object of the parent of this certificate.
-
-    def get_parent(self):
-        return self.parent
-
-    ##
-    # Verification examines a chain of certificates to ensure that each parent
-    # signs the child, and that some certificate in the chain is signed by a
-    # trusted certificate.
-    #
-    # Verification is a basic recursion: <pre>
-    #     if this_certificate was signed by trusted_certs:
-    #         return
-    #     else
-    #         return verify_chain(parent, trusted_certs)
-    # </pre>
-    #
-    # At each recursion, the parent is tested to ensure that it did sign the
-    # child. If a parent did not sign a child, then an exception is thrown. If
-    # the bottom of the recursion is reached and the certificate does not match
-    # a trusted root, then an exception is thrown.
-    #
-    # @param Trusted_certs is a list of certificates that are trusted.
-    #
-
-    def verify_chain(self, trusted_certs = None):
-        # Verify a chain of certificates. Each certificate must be signed by
-        # the public key contained in it's parent. The chain is recursed
-        # until a certificate is found that is signed by a trusted root.
-
-        # verify expiration time
-        if self.cert.has_expired():
-            sfa_logger().debug("verify_chain: NO our certificate has expired")
-            raise CertExpired(self.get_subject(), "client cert")   
-        
-        # if this cert is signed by a trusted_cert, then we are set
-        for trusted_cert in trusted_certs:
-            if self.is_signed_by_cert(trusted_cert):
-                # verify expiration of trusted_cert ?
-                if not trusted_cert.cert.has_expired():
-                    sfa_logger().debug("verify_chain: YES cert %s signed by trusted cert %s"%(
-                            self.get_subject(), trusted_cert.get_subject()))
-                    return trusted_cert
-                else:
-                    sfa_logger().debug("verify_chain: NO cert %s is signed by trusted_cert %s, but this is expired..."%(
-                            self.get_subject(),trusted_cert.get_subject()))
-                    raise CertExpired(self.get_subject(),"trusted_cert %s"%trusted_cert.get_subject())
-
-        # if there is no parent, then no way to verify the chain
-        if not self.parent:
-            sfa_logger().debug("verify_chain: NO %s has no parent and is not in trusted roots"%self.get_subject())
-            raise CertMissingParent(self.get_subject())
-
-        # if it wasn't signed by the parent...
-        if not self.is_signed_by_cert(self.parent):
-            sfa_logger().debug("verify_chain: NO %s is not signed by parent"%self.get_subject())
-            return CertNotSignedByParent(self.get_subject())
-
-        # if the parent isn't verified...
-        sfa_logger().debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_subject(),self.parent.get_subject()))
-        self.parent.verify_chain(trusted_certs)
-
-        return
-
-    ### more introspection
-    def get_extensions(self):
-        # pyOpenSSL does not have a way to get extensions
-        triples=[]
-        m2x509 = X509.load_cert_string(self.save_to_string())
-        nb_extensions=m2x509.get_ext_count()
-        sfa_logger().debug("X509 had %d extensions"%nb_extensions)
-        for i in range(nb_extensions):
-            ext=m2x509.get_ext_at(i)
-            triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )
-        return triples
-
-    def get_data_names(self):
-        return self.data.keys()
-
-    def get_all_datas (self):
-        triples=self.get_extensions()
-        for name in self.get_data_names(): 
-            triples.append( (name,self.get_data(name),'data',) )
-        return triples
-
-    # only informative
-    def get_filename(self):
-        return getattr(self,'filename',None)
-
-    def dump (self, *args, **kwargs):
-        print self.dump_string(*args, **kwargs)
-
-    def dump_string (self,show_extensions=False):
-        result = ""
-        result += "CERTIFICATE for %s\n"%self.get_subject()
-        result += "Issued by %s\n"%self.get_issuer()
-        filename=self.get_filename()
-        if filename: result += "Filename %s\n"%filename
-        if show_extensions:
-            all_datas=self.get_all_datas()
-            result += " has %d extensions/data attached"%len(all_datas)
-            for (n,v,c) in all_datas:
-                if c=='data':
-                    result += "   data: %s=%s\n"%(n,v)
-                else:
-                    result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)
-        return result
+#----------------------------------------------------------------------\r
+# Copyright (c) 2008 Board of Trustees, Princeton University\r
+#\r
+# Permission is hereby granted, free of charge, to any person obtaining\r
+# a copy of this software and/or hardware specification (the "Work") to\r
+# deal in the Work without restriction, including without limitation the\r
+# rights to use, copy, modify, merge, publish, distribute, sublicense,\r
+# and/or sell copies of the Work, and to permit persons to whom the Work\r
+# is furnished to do so, subject to the following conditions:\r
+#\r
+# The above copyright notice and this permission notice shall be\r
+# included in all copies or substantial portions of the Work.\r
+#\r
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS \r
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF \r
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND \r
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT \r
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, \r
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, \r
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS \r
+# IN THE WORK.\r
+#----------------------------------------------------------------------\r
+\r
+##\r
+# SFA uses two crypto libraries: pyOpenSSL and M2Crypto to implement\r
+# the necessary crypto functionality. Ideally just one of these libraries\r
+# would be used, but unfortunately each of these libraries is independently\r
+# lacking. The pyOpenSSL library is missing many necessary functions, and\r
+# the M2Crypto library has crashed inside of some of the functions. The\r
+# design decision is to use pyOpenSSL whenever possible as it seems more\r
+# stable, and only use M2Crypto for those functions that are not possible\r
+# in pyOpenSSL.\r
+#\r
+# This module exports two classes: Keypair and Certificate.\r
+##\r
+#\r
+\r
+import functools\r
+import os\r
+import tempfile\r
+import base64\r
+import traceback\r
+from tempfile import mkstemp\r
+\r
+from OpenSSL import crypto\r
+import M2Crypto\r
+from M2Crypto import X509\r
+\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.xrn import urn_to_hrn\r
+from sfa.util.faults import *\r
+from sfa.util.sfalogging import logger\r
+\r
+glo_passphrase_callback = None\r
+\r
+##\r
+# A global callback msy be implemented for requesting passphrases from the\r
+# user. The function will be called with three arguments:\r
+#\r
+#    keypair_obj: the keypair object that is calling the passphrase\r
+#    string: the string containing the private key that's being loaded\r
+#    x: unknown, appears to be 0, comes from pyOpenSSL and/or m2crypto\r
+#\r
+# The callback should return a string containing the passphrase.\r
+\r
+def set_passphrase_callback(callback_func):\r
+    global glo_passphrase_callback\r
+\r
+    glo_passphrase_callback = callback_func\r
+\r
+##\r
+# Sets a fixed passphrase.\r
+\r
+def set_passphrase(passphrase):\r
+    set_passphrase_callback( lambda k,s,x: passphrase )\r
+\r
+##\r
+# Check to see if a passphrase works for a particular private key string.\r
+# Intended to be used by passphrase callbacks for input validation.\r
+\r
+def test_passphrase(string, passphrase):\r
+    try:\r
+        crypto.load_privatekey(crypto.FILETYPE_PEM, string, (lambda x: passphrase))\r
+        return True\r
+    except:\r
+        return False\r
+\r
+def convert_public_key(key):\r
+    keyconvert_path = "/usr/bin/keyconvert.py"\r
+    if not os.path.isfile(keyconvert_path):\r
+        raise IOError, "Could not find keyconvert in %s" % keyconvert_path\r
+\r
+    # we can only convert rsa keys\r
+    if "ssh-dss" in key:\r
+        return None\r
+\r
+    (ssh_f, ssh_fn) = tempfile.mkstemp()\r
+    ssl_fn = tempfile.mktemp()\r
+    os.write(ssh_f, key)\r
+    os.close(ssh_f)\r
+\r
+    cmd = keyconvert_path + " " + ssh_fn + " " + ssl_fn\r
+    os.system(cmd)\r
+\r
+    # this check leaves the temporary file containing the public key so\r
+    # that it can be expected to see why it failed.\r
+    # TODO: for production, cleanup the temporary files\r
+    if not os.path.exists(ssl_fn):\r
+        return None\r
+\r
+    k = Keypair()\r
+    try:\r
+        k.load_pubkey_from_file(ssl_fn)\r
+    except:\r
+        logger.log_exc("convert_public_key caught exception")\r
+        k = None\r
+\r
+    # remove the temporary files\r
+    os.remove(ssh_fn)\r
+    os.remove(ssl_fn)\r
+\r
+    return k\r
+\r
+##\r
+# Public-private key pairs are implemented by the Keypair class.\r
+# A Keypair object may represent both a public and private key pair, or it\r
+# may represent only a public key (this usage is consistent with OpenSSL).\r
+\r
+class Keypair:\r
+    key = None       # public/private keypair\r
+    m2key = None     # public key (m2crypto format)\r
+\r
+    ##\r
+    # Creates a Keypair object\r
+    # @param create If create==True, creates a new public/private key and\r
+    #     stores it in the object\r
+    # @param string If string!=None, load the keypair from the string (PEM)\r
+    # @param filename If filename!=None, load the keypair from the file\r
+\r
+    def __init__(self, create=False, string=None, filename=None):\r
+        if create:\r
+            self.create()\r
+        if string:\r
+            self.load_from_string(string)\r
+        if filename:\r
+            self.load_from_file(filename)\r
+\r
+    ##\r
+    # Create a RSA public/private key pair and store it inside the keypair object\r
+\r
+    def create(self):\r
+        self.key = crypto.PKey()\r
+        self.key.generate_key(crypto.TYPE_RSA, 1024)\r
+\r
+    ##\r
+    # Save the private key to a file\r
+    # @param filename name of file to store the keypair in\r
+\r
+    def save_to_file(self, filename):\r
+        open(filename, 'w').write(self.as_pem())\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Load the private key from a file. Implicity the private key includes the public key.\r
+\r
+    def load_from_file(self, filename):\r
+        self.filename=filename\r
+        buffer = open(filename, 'r').read()\r
+        self.load_from_string(buffer)\r
+\r
+    ##\r
+    # Load the private key from a string. Implicitly the private key includes the public key.\r
+\r
+    def load_from_string(self, string):\r
+        if glo_passphrase_callback:\r
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string, functools.partial(glo_passphrase_callback, self, string) )\r
+            self.m2key = M2Crypto.EVP.load_key_string(string, functools.partial(glo_passphrase_callback, self, string) )\r
+        else:\r
+            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, string)\r
+            self.m2key = M2Crypto.EVP.load_key_string(string)\r
+\r
+    ##\r
+    #  Load the public key from a string. No private key is loaded.\r
+\r
+    def load_pubkey_from_file(self, filename):\r
+        # load the m2 public key\r
+        m2rsakey = M2Crypto.RSA.load_pub_key(filename)\r
+        self.m2key = M2Crypto.EVP.PKey()\r
+        self.m2key.assign_rsa(m2rsakey)\r
+\r
+        # create an m2 x509 cert\r
+        m2name = M2Crypto.X509.X509_Name()\r
+        m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0)\r
+        m2x509 = M2Crypto.X509.X509()\r
+        m2x509.set_pubkey(self.m2key)\r
+        m2x509.set_serial_number(0)\r
+        m2x509.set_issuer_name(m2name)\r
+        m2x509.set_subject_name(m2name)\r
+        ASN1 = M2Crypto.ASN1.ASN1_UTCTIME()\r
+        ASN1.set_time(500)\r
+        m2x509.set_not_before(ASN1)\r
+        m2x509.set_not_after(ASN1)\r
+        # x509v3 so it can have extensions\r
+        # prob not necc since this cert itself is junk but still...\r
+        m2x509.set_version(2)\r
+        junk_key = Keypair(create=True)\r
+        m2x509.sign(pkey=junk_key.get_m2_pkey(), md="sha1")\r
+\r
+        # convert the m2 x509 cert to a pyopenssl x509\r
+        m2pem = m2x509.as_pem()\r
+        pyx509 = crypto.load_certificate(crypto.FILETYPE_PEM, m2pem)\r
+\r
+        # get the pyopenssl pkey from the pyopenssl x509\r
+        self.key = pyx509.get_pubkey()\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Load the public key from a string. No private key is loaded.\r
+\r
+    def load_pubkey_from_string(self, string):\r
+        (f, fn) = tempfile.mkstemp()\r
+        os.write(f, string)\r
+        os.close(f)\r
+        self.load_pubkey_from_file(fn)\r
+        os.remove(fn)\r
+\r
+    ##\r
+    # Return the private key in PEM format.\r
+\r
+    def as_pem(self):\r
+        return crypto.dump_privatekey(crypto.FILETYPE_PEM, self.key)\r
+\r
+    ##\r
+    # Return an M2Crypto key object\r
+\r
+    def get_m2_pkey(self):\r
+        if not self.m2key:\r
+            self.m2key = M2Crypto.EVP.load_key_string(self.as_pem())\r
+        return self.m2key\r
+\r
+    ##\r
+    # Returns a string containing the public key represented by this object.\r
+\r
+    def get_pubkey_string(self):\r
+        m2pkey = self.get_m2_pkey()\r
+        return base64.b64encode(m2pkey.as_der())\r
+\r
+    ##\r
+    # Return an OpenSSL pkey object\r
+\r
+    def get_openssl_pkey(self):\r
+        return self.key\r
+\r
+    ##\r
+    # Given another Keypair object, return TRUE if the two keys are the same.\r
+\r
+    def is_same(self, pkey):\r
+        return self.as_pem() == pkey.as_pem()\r
+\r
+    def sign_string(self, data):\r
+        k = self.get_m2_pkey()\r
+        k.sign_init()\r
+        k.sign_update(data)\r
+        return base64.b64encode(k.sign_final())\r
+\r
+    def verify_string(self, data, sig):\r
+        k = self.get_m2_pkey()\r
+        k.verify_init()\r
+        k.verify_update(data)\r
+        return M2Crypto.m2.verify_final(k.ctx, base64.b64decode(sig), k.pkey)\r
+\r
+    def compute_hash(self, value):\r
+        return self.sign_string(str(value))\r
+\r
+    # only informative\r
+    def get_filename(self):\r
+        return getattr(self,'filename',None)\r
+\r
+    def dump (self, *args, **kwargs):\r
+        print self.dump_string(*args, **kwargs)\r
+\r
+    def dump_string (self):\r
+        result=""\r
+        result += "KEYPAIR: pubkey=%40s..."%self.get_pubkey_string()\r
+        filename=self.get_filename()\r
+        if filename: result += "Filename %s\n"%filename\r
+        return result\r
+\r
+##\r
+# The certificate class implements a general purpose X509 certificate, making\r
+# use of the appropriate pyOpenSSL or M2Crypto abstractions. It also adds\r
+# several addition features, such as the ability to maintain a chain of\r
+# parent certificates, and storage of application-specific data.\r
+#\r
+# Certificates include the ability to maintain a chain of parents. Each\r
+# certificate includes a pointer to it's parent certificate. When loaded\r
+# from a file or a string, the parent chain will be automatically loaded.\r
+# When saving a certificate to a file or a string, the caller can choose\r
+# whether to save the parent certificates as well.\r
+\r
+class Certificate:\r
+    digest = "md5"\r
+\r
+    cert = None\r
+    issuerKey = None\r
+    issuerSubject = None\r
+    parent = None\r
+    isCA = None # will be a boolean once set\r
+\r
+    separator="-----parent-----"\r
+\r
+    ##\r
+    # Create a certificate object.\r
+    #\r
+    # @param lifeDays life of cert in days - default is 1825==5 years\r
+    # @param create If create==True, then also create a blank X509 certificate.\r
+    # @param subject If subject!=None, then create a blank certificate and set\r
+    #     it's subject name.\r
+    # @param string If string!=None, load the certficate from the string.\r
+    # @param filename If filename!=None, load the certficiate from the file.\r
+    # @param isCA If !=None, set whether this cert is for a CA\r
+\r
+    def __init__(self, lifeDays=1825, create=False, subject=None, string=None, filename=None, isCA=None):\r
+        self.data = {}\r
+        if create or subject:\r
+            self.create(lifeDays)\r
+        if subject:\r
+            self.set_subject(subject)\r
+        if string:\r
+            self.load_from_string(string)\r
+        if filename:\r
+            self.load_from_file(filename)\r
+\r
+        # Set the CA bit if a value was supplied\r
+        if isCA != None:\r
+            self.set_is_ca(isCA)\r
+\r
+    # Create a blank X509 certificate and store it in this object.\r
+\r
+    def create(self, lifeDays=1825):\r
+        self.cert = crypto.X509()\r
+        # FIXME: Use different serial #s\r
+        self.cert.set_serial_number(3)\r
+        self.cert.gmtime_adj_notBefore(0) # 0 means now\r
+        self.cert.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default\r
+        self.cert.set_version(2) # x509v3 so it can have extensions\r
+\r
+\r
+    ##\r
+    # Given a pyOpenSSL X509 object, store that object inside of this\r
+    # certificate object.\r
+\r
+    def load_from_pyopenssl_x509(self, x509):\r
+        self.cert = x509\r
+\r
+    ##\r
+    # Load the certificate from a string\r
+\r
+    def load_from_string(self, string):\r
+        # if it is a chain of multiple certs, then split off the first one and\r
+        # load it (support for the ---parent--- tag as well as normal chained certs)\r
+\r
+        string = string.strip()\r
+        \r
+        # If it's not in proper PEM format, wrap it\r
+        if string.count('-----BEGIN CERTIFICATE') == 0:\r
+            string = '-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----' % string\r
+\r
+        # If there is a PEM cert in there, but there is some other text first\r
+        # such as the text of the certificate, skip the text\r
+        beg = string.find('-----BEGIN CERTIFICATE')\r
+        if beg > 0:\r
+            # skipping over non cert beginning                                                                                                              \r
+            string = string[beg:]\r
+\r
+        parts = []\r
+\r
+        if string.count('-----BEGIN CERTIFICATE-----') > 1 and \\r
+               string.count(Certificate.separator) == 0:\r
+            parts = string.split('-----END CERTIFICATE-----',1)\r
+            parts[0] += '-----END CERTIFICATE-----'\r
+        else:\r
+            parts = string.split(Certificate.separator, 1)\r
+\r
+        self.cert = crypto.load_certificate(crypto.FILETYPE_PEM, parts[0])\r
+\r
+        # if there are more certs, then create a parent and let the parent load\r
+        # itself from the remainder of the string\r
+        if len(parts) > 1 and parts[1] != '':\r
+            self.parent = self.__class__()\r
+            self.parent.load_from_string(parts[1])\r
+\r
+    ##\r
+    # Load the certificate from a file\r
+\r
+    def load_from_file(self, filename):\r
+        file = open(filename)\r
+        string = file.read()\r
+        self.load_from_string(string)\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Save the certificate to a string.\r
+    #\r
+    # @param save_parents If save_parents==True, then also save the parent certificates.\r
+\r
+    def save_to_string(self, save_parents=True):\r
+        string = crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert)\r
+        if save_parents and self.parent:\r
+            string = string + self.parent.save_to_string(save_parents)\r
+        return string\r
+\r
+    ##\r
+    # Save the certificate to a file.\r
+    # @param save_parents If save_parents==True, then also save the parent certificates.\r
+\r
+    def save_to_file(self, filename, save_parents=True, filep=None):\r
+        string = self.save_to_string(save_parents=save_parents)\r
+        if filep:\r
+            f = filep\r
+        else:\r
+            f = open(filename, 'w')\r
+        f.write(string)\r
+        f.close()\r
+        self.filename=filename\r
+\r
+    ##\r
+    # Save the certificate to a random file in /tmp/\r
+    # @param save_parents If save_parents==True, then also save the parent certificates.\r
+    def save_to_random_tmp_file(self, save_parents=True):\r
+        fp, filename = mkstemp(suffix='cert', text=True)\r
+        fp = os.fdopen(fp, "w")\r
+        self.save_to_file(filename, save_parents=True, filep=fp)\r
+        return filename\r
+\r
+    ##\r
+    # Sets the issuer private key and name\r
+    # @param key Keypair object containing the private key of the issuer\r
+    # @param subject String containing the name of the issuer\r
+    # @param cert (optional) Certificate object containing the name of the issuer\r
+\r
+    def set_issuer(self, key, subject=None, cert=None):\r
+        self.issuerKey = key\r
+        if subject:\r
+            # it's a mistake to use subject and cert params at the same time\r
+            assert(not cert)\r
+            if isinstance(subject, dict) or isinstance(subject, str):\r
+                req = crypto.X509Req()\r
+                reqSubject = req.get_subject()\r
+                if (isinstance(subject, dict)):\r
+                    for key in reqSubject.keys():\r
+                        setattr(reqSubject, key, subject[key])\r
+                else:\r
+                    setattr(reqSubject, "CN", subject)\r
+                subject = reqSubject\r
+                # subject is not valid once req is out of scope, so save req\r
+                self.issuerReq = req\r
+        if cert:\r
+            # if a cert was supplied, then get the subject from the cert\r
+            subject = cert.cert.get_subject()\r
+        assert(subject)\r
+        self.issuerSubject = subject\r
+\r
+    ##\r
+    # Get the issuer name\r
+\r
+    def get_issuer(self, which="CN"):\r
+        x = self.cert.get_issuer()\r
+        return getattr(x, which)\r
+\r
+    ##\r
+    # Set the subject name of the certificate\r
+\r
+    def set_subject(self, name):\r
+        req = crypto.X509Req()\r
+        subj = req.get_subject()\r
+        if (isinstance(name, dict)):\r
+            for key in name.keys():\r
+                setattr(subj, key, name[key])\r
+        else:\r
+            setattr(subj, "CN", name)\r
+        self.cert.set_subject(subj)\r
+\r
+    ##\r
+    # Get the subject name of the certificate\r
+\r
+    def get_subject(self, which="CN"):\r
+        x = self.cert.get_subject()\r
+        return getattr(x, which)\r
+\r
+    ##\r
+    # Get a pretty-print subject name of the certificate\r
+\r
+    def get_printable_subject(self):\r
+        x = self.cert.get_subject()\r
+        return "[ OU: %s, CN: %s, SubjectAltName: %s ]" % (getattr(x, "OU"), getattr(x, "CN"), self.get_data())\r
+\r
+    ##\r
+    # Get the public key of the certificate.\r
+    #\r
+    # @param key Keypair object containing the public key\r
+\r
+    def set_pubkey(self, key):\r
+        assert(isinstance(key, Keypair))\r
+        self.cert.set_pubkey(key.get_openssl_pkey())\r
+\r
+    ##\r
+    # Get the public key of the certificate.\r
+    # It is returned in the form of a Keypair object.\r
+\r
+    def get_pubkey(self):\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        pkey = Keypair()\r
+        pkey.key = self.cert.get_pubkey()\r
+        pkey.m2key = m2x509.get_pubkey()\r
+        return pkey\r
+\r
+    def set_intermediate_ca(self, val):\r
+        return self.set_is_ca(val)\r
+\r
+    # Set whether this cert is for a CA. All signers and only signers should be CAs.\r
+    # The local member starts unset, letting us check that you only set it once\r
+    # @param val Boolean indicating whether this cert is for a CA\r
+    def set_is_ca(self, val):\r
+        if val is None:\r
+            return\r
+\r
+        if self.isCA != None:\r
+            # Can't double set properties\r
+            raise "Cannot set basicConstraints CA:?? more than once. Was %s, trying to set as %s" % (self.isCA, val)\r
+\r
+        self.isCA = val\r
+        if val:\r
+            self.add_extension('basicConstraints', 1, 'CA:TRUE')\r
+        else:\r
+            self.add_extension('basicConstraints', 1, 'CA:FALSE')\r
+\r
+\r
+\r
+    ##\r
+    # Add an X509 extension to the certificate. Add_extension can only be called\r
+    # once for a particular extension name, due to limitations in the underlying\r
+    # library.\r
+    #\r
+    # @param name string containing name of extension\r
+    # @param value string containing value of the extension\r
+\r
+    def add_extension(self, name, critical, value):\r
+        oldExtVal = None\r
+        try:\r
+            oldExtVal = self.get_extension(name)\r
+        except:\r
+            # M2Crypto LookupError when the extension isn't there (yet)\r
+            pass\r
+\r
+        # This code limits you from adding the extension with the same value\r
+        # The method comment says you shouldn't do this with the same name\r
+        # But actually it (m2crypto) appears to allow you to do this.\r
+        if oldExtVal and oldExtVal == value:\r
+            # don't add this extension again\r
+            # just do nothing as here\r
+            return\r
+        # FIXME: What if they are trying to set with a different value?\r
+        # Is this ever OK? Or should we raise an exception?\r
+#        elif oldExtVal:\r
+#            raise "Cannot add extension %s which had val %s with new val %s" % (name, oldExtVal, value)\r
+\r
+        ext = crypto.X509Extension (name, critical, value)\r
+        self.cert.add_extensions([ext])\r
+\r
+    ##\r
+    # Get an X509 extension from the certificate\r
+\r
+    def get_extension(self, name):\r
+\r
+        # pyOpenSSL does not have a way to get extensions\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        value = m2x509.get_ext(name).get_value()\r
+\r
+        return value\r
+\r
+    ##\r
+    # Set_data is a wrapper around add_extension. It stores the parameter str in\r
+    # the X509 subject_alt_name extension. Set_data can only be called once, due\r
+    # to limitations in the underlying library.\r
+\r
+    def set_data(self, str, field='subjectAltName'):\r
+        # pyOpenSSL only allows us to add extensions, so if we try to set the\r
+        # same extension more than once, it will not work\r
+        if self.data.has_key(field):\r
+            raise "Cannot set ", field, " more than once"\r
+        self.data[field] = str\r
+        self.add_extension(field, 0, str)\r
+\r
+    ##\r
+    # Return the data string that was previously set with set_data\r
+\r
+    def get_data(self, field='subjectAltName'):\r
+        if self.data.has_key(field):\r
+            return self.data[field]\r
+\r
+        try:\r
+            uri = self.get_extension(field)\r
+            self.data[field] = uri\r
+        except LookupError:\r
+            return None\r
+\r
+        return self.data[field]\r
+\r
+    ##\r
+    # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer().\r
+\r
+    def sign(self):\r
+        logger.debug('certificate.sign')\r
+        assert self.cert != None\r
+        assert self.issuerSubject != None\r
+        assert self.issuerKey != None\r
+        self.cert.set_issuer(self.issuerSubject)\r
+        self.cert.sign(self.issuerKey.get_openssl_pkey(), self.digest)\r
+\r
+    ##\r
+    # Verify the authenticity of a certificate.\r
+    # @param pkey is a Keypair object representing a public key. If Pkey\r
+    #     did not sign the certificate, then an exception will be thrown.\r
+\r
+    def verify(self, pkey):\r
+        # pyOpenSSL does not have a way to verify signatures\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        m2pkey = pkey.get_m2_pkey()\r
+        # verify it\r
+        return m2x509.verify(m2pkey)\r
+\r
+        # XXX alternatively, if openssl has been patched, do the much simpler:\r
+        # try:\r
+        #   self.cert.verify(pkey.get_openssl_key())\r
+        #   return 1\r
+        # except:\r
+        #   return 0\r
+\r
+    ##\r
+    # Return True if pkey is identical to the public key that is contained in the certificate.\r
+    # @param pkey Keypair object\r
+\r
+    def is_pubkey(self, pkey):\r
+        return self.get_pubkey().is_same(pkey)\r
+\r
+    ##\r
+    # Given a certificate cert, verify that this certificate was signed by the\r
+    # public key contained in cert. Throw an exception otherwise.\r
+    #\r
+    # @param cert certificate object\r
+\r
+    def is_signed_by_cert(self, cert):\r
+        k = cert.get_pubkey()\r
+        result = self.verify(k)\r
+        return result\r
+\r
+    ##\r
+    # Set the parent certficiate.\r
+    #\r
+    # @param p certificate object.\r
+\r
+    def set_parent(self, p):\r
+        self.parent = p\r
+\r
+    ##\r
+    # Return the certificate object of the parent of this certificate.\r
+\r
+    def get_parent(self):\r
+        return self.parent\r
+\r
+    ##\r
+    # Verification examines a chain of certificates to ensure that each parent\r
+    # signs the child, and that some certificate in the chain is signed by a\r
+    # trusted certificate.\r
+    #\r
+    # Verification is a basic recursion: <pre>\r
+    #     if this_certificate was signed by trusted_certs:\r
+    #         return\r
+    #     else\r
+    #         return verify_chain(parent, trusted_certs)\r
+    # </pre>\r
+    #\r
+    # At each recursion, the parent is tested to ensure that it did sign the\r
+    # child. If a parent did not sign a child, then an exception is thrown. If\r
+    # the bottom of the recursion is reached and the certificate does not match\r
+    # a trusted root, then an exception is thrown.\r
+    # Also require that parents are CAs.\r
+    #\r
+    # @param Trusted_certs is a list of certificates that are trusted.\r
+    #\r
+\r
+    def verify_chain(self, trusted_certs = None):\r
+        # Verify a chain of certificates. Each certificate must be signed by\r
+        # the public key contained in it's parent. The chain is recursed\r
+        # until a certificate is found that is signed by a trusted root.\r
+\r
+        # verify expiration time\r
+        if self.cert.has_expired():\r
+            logger.debug("verify_chain: NO, Certificate %s has expired" % self.get_printable_subject())\r
+            raise CertExpired(self.get_printable_subject(), "client cert")\r
+\r
+        # if this cert is signed by a trusted_cert, then we are set\r
+        for trusted_cert in trusted_certs:\r
+            if self.is_signed_by_cert(trusted_cert):\r
+                # verify expiration of trusted_cert ?\r
+                if not trusted_cert.cert.has_expired():\r
+                    logger.debug("verify_chain: YES. Cert %s signed by trusted cert %s"%(\r
+                            self.get_printable_subject(), trusted_cert.get_printable_subject()))\r
+                    return trusted_cert\r
+                else:\r
+                    logger.debug("verify_chain: NO. Cert %s is signed by trusted_cert %s, but that signer is expired..."%(\r
+                            self.get_printable_subject(),trusted_cert.get_printable_subject()))\r
+                    raise CertExpired(self.get_printable_subject()," signer trusted_cert %s"%trusted_cert.get_printable_subject())\r
+\r
+        # if there is no parent, then no way to verify the chain\r
+        if not self.parent:\r
+            logger.debug("verify_chain: NO. %s has no parent and issuer %s is not in %d trusted roots"%(self.get_printable_subject(), self.get_issuer(), len(trusted_certs)))\r
+            raise CertMissingParent(self.get_printable_subject() + ": Issuer %s not trusted by any of %d trusted roots, and cert has no parent." % (self.get_issuer(), len(trusted_certs)))\r
+\r
+        # if it wasn't signed by the parent...\r
+        if not self.is_signed_by_cert(self.parent):\r
+            logger.debug("verify_chain: NO. %s is not signed by parent %s, but by %s"%self.get_printable_subject(), self.parent.get_printable_subject(), self.get_issuer())\r
+            raise CertNotSignedByParent(self.get_printable_subject() + ": Parent %s, issuer %s" % (self.parent.get_printable_subject(), self.get_issuer()))\r
+\r
+        # Confirm that the parent is a CA. Only CAs can be trusted as\r
+        # signers.\r
+        # Note that trusted roots are not parents, so don't need to be\r
+        # CAs.\r
+        # Ugly - cert objects aren't parsed so we need to read the\r
+        # extension and hope there are no other basicConstraints\r
+        if not self.parent.isCA and not (self.parent.get_extension('basicConstraints') == 'CA:TRUE'):\r
+            logger.warn("verify_chain: cert %s's parent %s is not a CA" % (self.get_printable_subject(), self.parent.get_printable_subject()))\r
+            raise CertNotSignedByParent(self.get_printable_subject() + ": Parent %s not a CA" % self.parent.get_printable_subject())\r
+\r
+        # if the parent isn't verified...\r
+        logger.debug("verify_chain: .. %s, -> verifying parent %s"%(self.get_printable_subject(),self.parent.get_printable_subject()))\r
+        self.parent.verify_chain(trusted_certs)\r
+\r
+        return\r
+\r
+    ### more introspection\r
+    def get_extensions(self):\r
+        # pyOpenSSL does not have a way to get extensions\r
+        triples=[]\r
+        m2x509 = X509.load_cert_string(self.save_to_string())\r
+        nb_extensions=m2x509.get_ext_count()\r
+        logger.debug("X509 had %d extensions"%nb_extensions)\r
+        for i in range(nb_extensions):\r
+            ext=m2x509.get_ext_at(i)\r
+            triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) )\r
+        return triples\r
+\r
+    def get_data_names(self):\r
+        return self.data.keys()\r
+\r
+    def get_all_datas (self):\r
+        triples=self.get_extensions()\r
+        for name in self.get_data_names():\r
+            triples.append( (name,self.get_data(name),'data',) )\r
+        return triples\r
+\r
+    # only informative\r
+    def get_filename(self):\r
+        return getattr(self,'filename',None)\r
+\r
+    def dump (self, *args, **kwargs):\r
+        print self.dump_string(*args, **kwargs)\r
+\r
+    def dump_string (self,show_extensions=False):\r
+        result = ""\r
+        result += "CERTIFICATE for %s\n"%self.get_printable_subject()\r
+        result += "Issued by %s\n"%self.get_issuer()\r
+        filename=self.get_filename()\r
+        if filename: result += "Filename %s\n"%filename\r
+        if show_extensions:\r
+            all_datas=self.get_all_datas()\r
+            result += " has %d extensions/data attached"%len(all_datas)\r
+            for (n,v,c) in all_datas:\r
+                if c=='data':\r
+                    result += "   data: %s=%s\n"%(n,v)\r
+                else:\r
+                    result += "    ext: %s (crit=%s)=<<<%s>>>\n"%(n,c,v)\r
+        return result\r
index 5ac987a..a18019d 100644 (file)
-#----------------------------------------------------------------------
-# Copyright (c) 2008 Board of Trustees, Princeton University
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and/or hardware specification (the "Work") to
-# deal in the Work without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Work, and to permit persons to whom the Work
-# is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Work.
-#
-# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
-# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
-# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
-# IN THE WORK.
-#----------------------------------------------------------------------
-##
-# Implements SFA Credentials
-#
-# Credentials are signed XML files that assign a subject gid privileges to an object gid
-##
-
-import os
-import datetime
-from tempfile import mkstemp
-import dateutil.parser
-from StringIO import StringIO 
-from xml.dom.minidom import Document, parseString
-from lxml import etree
-
-from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
-from sfa.trust.certificate import Keypair
-from sfa.trust.credential_legacy import CredentialLegacy
-from sfa.trust.rights import Right, Rights
-from sfa.trust.gid import GID
-from sfa.util.xrn import urn_to_hrn
-
-# 2 weeks, in seconds 
-DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14
-
-
-# TODO:
-# . make privs match between PG and PL
-# . Need to add support for other types of credentials, e.g. tickets
-
-
-signature_template = \
-'''
-<Signature xml:id="Sig_%s" xmlns="http://www.w3.org/2000/09/xmldsig#">
-    <SignedInfo>
-      <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>
-      <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>
-      <Reference URI="#%s">
-      <Transforms>
-        <Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />
-      </Transforms>
-      <DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>
-      <DigestValue></DigestValue>
-      </Reference>
-    </SignedInfo>
-    <SignatureValue />
-      <KeyInfo>
-        <X509Data>
-          <X509SubjectName/>
-          <X509IssuerSerial/>
-          <X509Certificate/>
-        </X509Data>
-      <KeyValue />
-      </KeyInfo>
-    </Signature>
-'''
-
-##
-# Convert a string into a bool
-
-def str2bool(str):
-    if str.lower() in ['yes','true','1']:
-        return True
-    return False
-
-
-##
-# Utility function to get the text of an XML element
-
-def getTextNode(element, subele):
-    sub = element.getElementsByTagName(subele)[0]
-    if len(sub.childNodes) > 0:            
-        return sub.childNodes[0].nodeValue
-    else:
-        return None
-        
-##
-# Utility function to set the text of an XML element
-# It creates the element, adds the text to it,
-# and then appends it to the parent.
-
-def append_sub(doc, parent, element, text):
-    ele = doc.createElement(element)
-    ele.appendChild(doc.createTextNode(text))
-    parent.appendChild(ele)
-
-##
-# Signature contains information about an xmlsec1 signature
-# for a signed-credential
-#
-
-class Signature(object):
-   
-    def __init__(self, string=None):
-        self.refid = None
-        self.issuer_gid = None
-        self.xml = None
-        if string:
-            self.xml = string
-            self.decode()
-
-
-    def get_refid(self):
-        if not self.refid:
-            self.decode()
-        return self.refid
-
-    def get_xml(self):
-        if not self.xml:
-            self.encode()
-        return self.xml
-
-    def set_refid(self, id):
-        self.refid = id
-
-    def get_issuer_gid(self):
-        if not self.gid:
-            self.decode()
-        return self.gid        
-
-    def set_issuer_gid(self, gid):
-        self.gid = gid
-
-    def decode(self):
-        doc = parseString(self.xml)
-        sig = doc.getElementsByTagName("Signature")[0]
-        self.set_refid(sig.getAttribute("xml:id").strip("Sig_"))
-        keyinfo = sig.getElementsByTagName("X509Data")[0]
-        szgid = getTextNode(keyinfo, "X509Certificate")
-        szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid
-        self.set_issuer_gid(GID(string=szgid))        
-        
-    def encode(self):
-        self.xml = signature_template % (self.get_refid(), self.get_refid())
-
-
-##
-# A credential provides a caller gid with privileges to an object gid.
-# A signed credential is signed by the object's authority.
-#
-# Credentials are encoded in one of two ways.  The legacy style places
-# it in the subjectAltName of an X509 certificate.  The new credentials
-# are placed in signed XML.
-#
-# WARNING:
-# In general, a signed credential obtained externally should
-# not be changed else the signature is no longer valid.  So, once
-# you have loaded an existing signed credential, do not call encode() or sign() on it.
-
-def filter_creds_by_caller(creds, caller_hrn):
-        """
-        Returns a list of creds who's gid caller matches the
-        specified caller hrn
-        """
-        if not isinstance(creds, list): creds = [creds]
-        caller_creds = []
-        for cred in creds:
-            try:
-                tmp_cred = Credential(string=cred)
-                if tmp_cred.get_gid_caller().get_hrn() == caller_hrn:
-                    caller_creds.append(cred)
-            except: pass
-        return caller_creds
-
-class Credential(object):
-
-    ##
-    # Create a Credential object
-    #
-    # @param create If true, create a blank x509 certificate
-    # @param subject If subject!=None, create an x509 cert with the subject name
-    # @param string If string!=None, load the credential from the string
-    # @param filename If filename!=None, load the credential from the file
-    # FIXME: create and subject are ignored!
-    def __init__(self, create=False, subject=None, string=None, filename=None):
-        self.gidCaller = None
-        self.gidObject = None
-        self.expiration = None
-        self.privileges = None
-        self.issuer_privkey = None
-        self.issuer_gid = None
-        self.issuer_pubkey = None
-        self.parent = None
-        self.signature = None
-        self.xml = None
-        self.refid = None
-        self.legacy = None
-
-        # Check if this is a legacy credential, translate it if so
-        if string or filename:
-            if string:                
-                str = string
-            elif filename:
-                str = file(filename).read()
-                self.filename=filename
-                
-            if str.strip().startswith("-----"):
-                self.legacy = CredentialLegacy(False,string=str)
-                self.translate_legacy(str)
-            else:
-                self.xml = str
-                self.decode()
-
-        # Find an xmlsec1 path
-        self.xmlsec_path = ''
-        paths = ['/usr/bin','/usr/local/bin','/bin','/opt/bin','/opt/local/bin']
-        for path in paths:
-            if os.path.isfile(path + '/' + 'xmlsec1'):
-                self.xmlsec_path = path + '/' + 'xmlsec1'
-                break
-
-    def get_subject(self):
-        if not self.gidObject:
-            self.decode()
-        return self.gidObject.get_subject()   
-
-    def get_signature(self):
-        if not self.signature:
-            self.decode()
-        return self.signature
-
-    def set_signature(self, sig):
-        self.signature = sig
-
-        
-    ##
-    # Translate a legacy credential into a new one
-    #
-    # @param String of the legacy credential
-
-    def translate_legacy(self, str):
-        legacy = CredentialLegacy(False,string=str)
-        self.gidCaller = legacy.get_gid_caller()
-        self.gidObject = legacy.get_gid_object()
-        lifetime = legacy.get_lifetime()
-        if not lifetime:
-            self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))
-        else:
-            self.set_expiration(int(lifetime))
-        self.lifeTime = legacy.get_lifetime()
-        self.set_privileges(legacy.get_privileges())
-        self.get_privileges().delegate_all_privileges(legacy.get_delegate())
-
-    ##
-    # Need the issuer's private key and name
-    # @param key Keypair object containing the private key of the issuer
-    # @param gid GID of the issuing authority
-
-    def set_issuer_keys(self, privkey, gid):
-        self.issuer_privkey = privkey
-        self.issuer_gid = gid
-
-
-    ##
-    # Set this credential's parent
-    def set_parent(self, cred):
-        self.parent = cred
-        self.updateRefID()
-
-    ##
-    # set the GID of the caller
-    #
-    # @param gid GID object of the caller
-
-    def set_gid_caller(self, gid):
-        self.gidCaller = gid
-        # gid origin caller is the caller's gid by default
-        self.gidOriginCaller = gid
-
-    ##
-    # get the GID of the object
-
-    def get_gid_caller(self):
-        if not self.gidCaller:
-            self.decode()
-        return self.gidCaller
-
-    ##
-    # set the GID of the object
-    #
-    # @param gid GID object of the object
-
-    def set_gid_object(self, gid):
-        self.gidObject = gid
-
-    ##
-    # get the GID of the object
-
-    def get_gid_object(self):
-        if not self.gidObject:
-            self.decode()
-        return self.gidObject
-
-
-            
-    ##
-    # Expiration: an absolute UTC time of expiration (as either an int or datetime)
-    # 
-    def set_expiration(self, expiration):
-        if isinstance(expiration, int):
-            self.expiration = datetime.datetime.fromtimestamp(expiration)
-        else:
-            self.expiration = expiration
-            
-
-    ##
-    # get the lifetime of the credential (in datetime format)
-
-    def get_expiration(self):
-        if not self.expiration:
-            self.decode()
-        return self.expiration
-
-    ##
-    # For legacy sake
-    def get_lifetime(self):
-        return self.get_expiration()
-    ##
-    # set the privileges
-    #
-    # @param privs either a comma-separated list of privileges of a Rights object
-
-    def set_privileges(self, privs):
-        if isinstance(privs, str):
-            self.privileges = Rights(string = privs)
-        else:
-            self.privileges = privs
-        
-
-    ##
-    # return the privileges as a Rights object
-
-    def get_privileges(self):
-        if not self.privileges:
-            self.decode()
-        return self.privileges
-
-    ##
-    # determine whether the credential allows a particular operation to be
-    # performed
-    #
-    # @param op_name string specifying name of operation ("lookup", "update", etc)
-
-    def can_perform(self, op_name):
-        rights = self.get_privileges()
-        
-        if not rights:
-            return False
-
-        return rights.can_perform(op_name)
-
-
-    ##
-    # Encode the attributes of the credential into an XML string    
-    # This should be done immediately before signing the credential.    
-    # WARNING:
-    # In general, a signed credential obtained externally should
-    # not be changed else the signature is no longer valid.  So, once
-    # you have loaded an existing signed credential, do not call encode() or sign() on it.
-
-    def encode(self):
-        # Create the XML document
-        doc = Document()
-        signed_cred = doc.createElement("signed-credential")
-        doc.appendChild(signed_cred)  
-        
-        # Fill in the <credential> bit        
-        cred = doc.createElement("credential")
-        cred.setAttribute("xml:id", self.get_refid())
-        signed_cred.appendChild(cred)
-        append_sub(doc, cred, "type", "privilege")
-        append_sub(doc, cred, "serial", "8")
-        append_sub(doc, cred, "owner_gid", self.gidCaller.save_to_string())
-        append_sub(doc, cred, "owner_urn", self.gidCaller.get_urn())
-        append_sub(doc, cred, "target_gid", self.gidObject.save_to_string())
-        append_sub(doc, cred, "target_urn", self.gidObject.get_urn())
-        append_sub(doc, cred, "uuid", "")
-        if not self.expiration:
-            self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))
-        self.expiration = self.expiration.replace(microsecond=0)
-        append_sub(doc, cred, "expires", self.expiration.isoformat())
-        privileges = doc.createElement("privileges")
-        cred.appendChild(privileges)
-
-        if self.privileges:
-            rights = self.get_privileges()
-            for right in rights.rights:
-                priv = doc.createElement("privilege")
-                append_sub(doc, priv, "name", right.kind)
-                append_sub(doc, priv, "can_delegate", str(right.delegate).lower())
-                privileges.appendChild(priv)
-
-        # Add the parent credential if it exists
-        if self.parent:
-            sdoc = parseString(self.parent.get_xml())
-            p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True)
-            p = doc.createElement("parent")
-            p.appendChild(p_cred)
-            cred.appendChild(p)
-
-
-        # Create the <signatures> tag
-        signatures = doc.createElement("signatures")
-        signed_cred.appendChild(signatures)
-
-        # Add any parent signatures
-        if self.parent:
-            for cur_cred in self.get_credential_list()[1:]:
-                sdoc = parseString(cur_cred.get_signature().get_xml())
-                ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)
-                signatures.appendChild(ele)
-                
-        # Get the finished product
-        self.xml = doc.toxml()
-
-
-    def save_to_random_tmp_file(self):       
-        fp, filename = mkstemp(suffix='cred', text=True)
-        fp = os.fdopen(fp, "w")
-        self.save_to_file(filename, save_parents=True, filep=fp)
-        return filename
-    
-    def save_to_file(self, filename, save_parents=True, filep=None):
-        if not self.xml:
-            self.encode()
-        if filep:
-            f = filep 
-        else:
-            f = open(filename, "w")
-        f.write(self.xml)
-        f.close()
-        self.filename=filename
-
-    def save_to_string(self, save_parents=True):
-        if not self.xml:
-            self.encode()
-        return self.xml
-
-    def get_refid(self):
-        if not self.refid:
-            self.refid = 'ref0'
-        return self.refid
-
-    def set_refid(self, rid):
-        self.refid = rid
-
-    ##
-    # Figure out what refids exist, and update this credential's id
-    # so that it doesn't clobber the others.  Returns the refids of
-    # the parents.
-    
-    def updateRefID(self):
-        if not self.parent:
-            self.set_refid('ref0')
-            return []
-        
-        refs = []
-
-        next_cred = self.parent
-        while next_cred:
-            refs.append(next_cred.get_refid())
-            if next_cred.parent:
-                next_cred = next_cred.parent
-            else:
-                next_cred = None
-
-        
-        # Find a unique refid for this credential
-        rid = self.get_refid()
-        while rid in refs:
-            val = int(rid[3:])
-            rid = "ref%d" % (val + 1)
-
-        # Set the new refid
-        self.set_refid(rid)
-
-        # Return the set of parent credential ref ids
-        return refs
-
-    def get_xml(self):
-        if not self.xml:
-            self.encode()
-        return self.xml
-
-    ##
-    # Sign the XML file created by encode()
-    #
-    # WARNING:
-    # In general, a signed credential obtained externally should
-    # not be changed else the signature is no longer valid.  So, once
-    # you have loaded an existing signed credential, do not call encode() or sign() on it.
-
-    def sign(self):
-        if not self.issuer_privkey or not self.issuer_gid:
-            return
-        doc = parseString(self.get_xml())
-        sigs = doc.getElementsByTagName("signatures")[0]
-
-        # Create the signature template to be signed
-        signature = Signature()
-        signature.set_refid(self.get_refid())
-        sdoc = parseString(signature.get_xml())        
-        sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)
-        sigs.appendChild(sig_ele)
-
-        self.xml = doc.toxml()
-
-
-        # Split the issuer GID into multiple certificates if it's a chain
-        chain = GID(filename=self.issuer_gid)
-        gid_files = []
-        while chain:
-            gid_files.append(chain.save_to_random_tmp_file(False))
-            if chain.get_parent():
-                chain = chain.get_parent()
-            else:
-                chain = None
-
-
-        # Call out to xmlsec1 to sign it
-        ref = 'Sig_%s' % self.get_refid()
-        filename = self.save_to_random_tmp_file()
-        signed = os.popen('%s --sign --node-id "%s" --privkey-pem %s,%s %s' \
-                 % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)).read()
-        os.remove(filename)
-
-        for gid_file in gid_files:
-            os.remove(gid_file)
-
-        self.xml = signed
-
-        # This is no longer a legacy credential
-        if self.legacy:
-            self.legacy = None
-
-        # Update signatures
-        self.decode()       
-
-        
-    ##
-    # Retrieve the attributes of the credential from the XML.
-    # This is automatically called by the various get_* methods of
-    # this class and should not need to be called explicitly.
-
-    def decode(self):
-        if not self.xml:
-            return
-        doc = parseString(self.xml)
-        sigs = []
-        signed_cred = doc.getElementsByTagName("signed-credential")
-
-        # Is this a signed-cred or just a cred?
-        if len(signed_cred) > 0:
-            cred = signed_cred[0].getElementsByTagName("credential")[0]
-            signatures = signed_cred[0].getElementsByTagName("signatures")
-            if len(signatures) > 0:
-                sigs = signatures[0].getElementsByTagName("Signature")
-        else:
-            cred = doc.getElementsByTagName("credential")[0]
-        
-
-        self.set_refid(cred.getAttribute("xml:id"))
-        self.set_expiration(dateutil.parser.parse(getTextNode(cred, "expires")))
-        self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))
-        self.gidObject = GID(string=getTextNode(cred, "target_gid"))   
-
-
-        # Process privileges
-        privs = cred.getElementsByTagName("privileges")[0]
-        rlist = Rights()
-        for priv in privs.getElementsByTagName("privilege"):
-            kind = getTextNode(priv, "name")
-            deleg = str2bool(getTextNode(priv, "can_delegate"))
-            if kind == '*':
-                # Convert * into the default privileges for the credential's type                
-                _ , type = urn_to_hrn(self.gidObject.get_urn())
-                rl = rlist.determine_rights(type, self.gidObject.get_urn())
-                for r in rl.rights:
-                    rlist.add(r)
-            else:
-                rlist.add(Right(kind.strip(), deleg))
-        self.set_privileges(rlist)
-
-
-        # Is there a parent?
-        parent = cred.getElementsByTagName("parent")
-        if len(parent) > 0:
-            parent_doc = parent[0].getElementsByTagName("credential")[0]
-            parent_xml = parent_doc.toxml()
-            self.parent = Credential(string=parent_xml)
-            self.updateRefID()
-
-        # Assign the signatures to the credentials
-        for sig in sigs:
-            Sig = Signature(string=sig.toxml())
-
-            for cur_cred in self.get_credential_list():
-                if cur_cred.get_refid() == Sig.get_refid():
-                    cur_cred.set_signature(Sig)
-                                    
-            
-    ##
-    # Verify
-    #   trusted_certs: A list of trusted GID filenames (not GID objects!) 
-    #                  Chaining is not supported within the GIDs by xmlsec1.
-    #    
-    # Verify that:
-    # . All of the signatures are valid and that the issuers trace back
-    #   to trusted roots (performed by xmlsec1)
-    # . The XML matches the credential schema
-    # . That the issuer of the credential is the authority in the target's urn
-    #    . In the case of a delegated credential, this must be true of the root
-    # . That all of the gids presented in the credential are valid
-    # . The credential is not expired
-    #
-    # -- For Delegates (credentials with parents)
-    # . The privileges must be a subset of the parent credentials
-    # . The privileges must have "can_delegate" set for each delegated privilege
-    # . The target gid must be the same between child and parents
-    # . The expiry time on the child must be no later than the parent
-    # . The signer of the child must be the owner of the parent
-    #
-    # -- Verify does *NOT*
-    # . ensure that an xmlrpc client's gid matches a credential gid, that
-    #   must be done elsewhere
-    #
-    # @param trusted_certs: The certificates of trusted CA certificates
-    # @param schema: The RelaxNG schema to validate the credential against 
-    def verify(self, trusted_certs, schema=None):
-        if not self.xml:
-            self.decode()        
-        
-        # validate against RelaxNG schema
-        if not self.legacy:
-            if schema and os.path.exists(schema):
-                tree = etree.parse(StringIO(self.xml))
-                schema_doc = etree.parse(schema)
-                xmlschema = etree.XMLSchema(schema_doc)
-                if not xmlschema.validate(tree):
-                    error = xmlschema.error_log.last_error
-                    message = "%s (line %s)" % (error.message, error.line)
-                    raise CredentialNotVerifiable(message) 
-            
-
-#       trusted_cert_objects = [GID(filename=f) for f in trusted_certs]
-        trusted_cert_objects = []
-        ok_trusted_certs = []
-        for f in trusted_certs:
-            try:
-                # Failures here include unreadable files
-                # or non PEM files
-                trusted_cert_objects.append(GID(filename=f))
-                ok_trusted_certs.append(f)
-            except Exception, exc:
-                sfa_logger().error("Failed to load trusted cert from %s: %r"%( f, exc))
-        trusted_certs = ok_trusted_certs
-
-        # Use legacy verification if this is a legacy credential
-        if self.legacy:
-            self.legacy.verify_chain(trusted_cert_objects)
-            if self.legacy.client_gid:
-                self.legacy.client_gid.verify_chain(trusted_cert_objects)
-            if self.legacy.object_gid:
-                self.legacy.object_gid.verify_chain(trusted_cert_objects)
-            return True
-
-        
-        # make sure it is not expired
-        if self.get_expiration() < datetime.datetime.utcnow():
-            raise CredentialNotVerifiable("Credential expired at %s" % self.expiration.isoformat())
-
-        # Verify the signatures
-        filename = self.save_to_random_tmp_file()
-        cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs])
-
-        # Verify the gids of this cred and of its parents
-        for cur_cred in self.get_credential_list():
-            cur_cred.get_gid_object().verify_chain(trusted_cert_objects)
-            cur_cred.get_gid_caller().verify_chain(trusted_cert_objects) 
-
-        refs = []
-        refs.append("Sig_%s" % self.get_refid())
-
-        parentRefs = self.updateRefID()
-        for ref in parentRefs:
-            refs.append("Sig_%s" % ref)
-
-        for ref in refs:
-            verified = os.popen('%s --verify --node-id "%s" %s %s 2>&1' \
-                            % (self.xmlsec_path, ref, cert_args, filename)).read()
-            if not verified.strip().startswith("OK"):
-                raise CredentialNotVerifiable("xmlsec1 error verifying cert: " + verified)
-        os.remove(filename)
-
-        # Verify the parents (delegation)
-        if self.parent:
-            self.verify_parent(self.parent)
-
-        # Make sure the issuer is the target's authority
-        self.verify_issuer()
-        return True
-
-    ##
-    # Creates a list of the credential and its parents, with the root 
-    # (original delegated credential) as the last item in the list
-    def get_credential_list(self):    
-        cur_cred = self
-        list = []
-        while cur_cred:
-            list.append(cur_cred)
-            if cur_cred.parent:
-                cur_cred = cur_cred.parent
-            else:
-                cur_cred = None
-        return list
-    
-    ##
-    # Make sure the credential's target gid was signed by (or is the same) the entity that signed
-    # the original credential or an authority over that namespace.
-    def verify_issuer(self):                
-        root_cred = self.get_credential_list()[-1]
-        root_target_gid = root_cred.get_gid_object()
-        root_cred_signer = root_cred.get_signature().get_issuer_gid()
-
-        if root_target_gid.is_signed_by_cert(root_cred_signer):
-            # cred signer matches target signer, return success
-            return
-
-        root_target_gid_str = root_target_gid.save_to_string()
-        root_cred_signer_str = root_cred_signer.save_to_string()
-        if root_target_gid_str == root_cred_signer_str:
-            # cred signer is target, return success
-            return
-
-        # See if it the signer is an authority over the domain of the target
-        # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn())
-        root_cred_signer_type = root_cred_signer.get_type()
-        if (root_cred_signer_type == 'authority'):
-            #sfa_logger().debug('Cred signer is an authority')
-            # signer is an authority, see if target is in authority's domain
-            hrn = root_cred_signer.get_hrn()
-            if root_target_gid.get_hrn().startswith(hrn):
-                return
-
-        # We've required that the credential be signed by an authority
-        # for that domain. Reasonable and probably correct.
-        # A looser model would also allow the signer to be an authority
-        # in my control framework - eg My CA or CH. Even if it is not
-        # the CH that issued these, eg, user credentials.
-
-        # Give up, credential does not pass issuer verification
-
-        raise CredentialNotVerifiable("Could not verify credential owned by %s for object %s. Cred signer %s not the trusted authority for Cred target %s" % (self.gidCaller.get_urn(), self.gidObject.get_urn(), root_cred_signer.get_hrn(), root_target_gid.get_hrn()))
-
-
-    ##
-    # -- For Delegates (credentials with parents) verify that:
-    # . The privileges must be a subset of the parent credentials
-    # . The privileges must have "can_delegate" set for each delegated privilege
-    # . The target gid must be the same between child and parents
-    # . The expiry time on the child must be no later than the parent
-    # . The signer of the child must be the owner of the parent        
-    def verify_parent(self, parent_cred):
-        # make sure the rights given to the child are a subset of the
-        # parents rights (and check delegate bits)
-        if not parent_cred.get_privileges().is_superset(self.get_privileges()):
-            raise ChildRightsNotSubsetOfParent(
-                self.parent.get_privileges().save_to_string() + " " +
-                self.get_privileges().save_to_string())
-
-        # make sure my target gid is the same as the parent's
-        if not parent_cred.get_gid_object().save_to_string() == \
-           self.get_gid_object().save_to_string():
-            raise CredentialNotVerifiable("Target gid not equal between parent and child")
-
-        # make sure my expiry time is <= my parent's
-        if not parent_cred.get_expiration() >= self.get_expiration():
-            raise CredentialNotVerifiable("Delegated credential expires after parent")
-
-        # make sure my signer is the parent's caller
-        if not parent_cred.get_gid_caller().save_to_string(False) == \
-           self.get_signature().get_issuer_gid().save_to_string(False):
-            raise CredentialNotVerifiable("Delegated credential not signed by parent caller")
-                
-        # Recurse
-        if parent_cred.parent:
-            parent_cred.verify_parent(parent_cred.parent)
-
-
-    def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile):
-        """
-        Return a delegated copy of this credential, delegated to the 
-        specified gid's user.    
-        """
-        # get the gid of the object we are delegating
-        object_gid = self.get_gid_object()
-        object_hrn = object_gid.get_hrn()        
-        # the hrn of the user who will be delegated to
-        delegee_gid = GID(filename=delegee_gidfile)
-        delegee_hrn = delegee_gid.get_hrn()
-  
-        #user_key = Keypair(filename=keyfile)
-        #user_hrn = self.get_gid_caller().get_hrn()
-        subject_string = "%s delegated to %s" % (object_hrn, delegee_hrn)
-        dcred = Credential(subject=subject_string)
-        dcred.set_gid_caller(delegee_gid)
-        dcred.set_gid_object(object_gid)
-        dcred.set_parent(self)
-        dcred.set_expiration(self.get_expiration())
-        dcred.set_privileges(self.get_privileges())
-        dcred.get_privileges().delegate_all_privileges(True)
-        #dcred.set_issuer_keys(keyfile, delegee_gidfile)
-        dcred.set_issuer_keys(caller_keyfile, caller_gidfile)
-        dcred.encode()
-        dcred.sign()
-
-        return dcred 
-
-    # only informative
-    def get_filename(self):
-        return getattr(self,'filename',None)
-
-    # @param dump_parents If true, also dump the parent certificates
-    def dump (self, *args, **kwargs):
-        print self.dump_string(*args, **kwargs)
-
-    def dump_string(self, dump_parents=False):
-        result=""
-        result += "CREDENTIAL %s\n" % self.get_subject() 
-        filename=self.get_filename()
-        if filename: result += "Filename %s\n"%filename
-        result += "      privs: %s\n" % self.get_privileges().save_to_string()
-        gidCaller = self.get_gid_caller()
-        if gidCaller:
-            result += "  gidCaller:\n"
-            result += gidCaller.dump_string(8, dump_parents)
-
-        gidObject = self.get_gid_object()
-        if gidObject:
-            result += "  gidObject:\n"
-            result += gidObject.dump_string(8, dump_parents)
-
-        if self.parent and dump_parents:
-            result += "PARENT"
-            result += self.parent.dump_string(dump_parents)
-        return result
-
+#----------------------------------------------------------------------\r
+# Copyright (c) 2008 Board of Trustees, Princeton University\r
+#\r
+# Permission is hereby granted, free of charge, to any person obtaining\r
+# a copy of this software and/or hardware specification (the "Work") to\r
+# deal in the Work without restriction, including without limitation the\r
+# rights to use, copy, modify, merge, publish, distribute, sublicense,\r
+# and/or sell copies of the Work, and to permit persons to whom the Work\r
+# is furnished to do so, subject to the following conditions:\r
+#\r
+# The above copyright notice and this permission notice shall be\r
+# included in all copies or substantial portions of the Work.\r
+#\r
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS \r
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF \r
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND \r
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT \r
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, \r
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, \r
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS \r
+# IN THE WORK.\r
+#----------------------------------------------------------------------\r
+##\r
+# Implements SFA Credentials\r
+#\r
+# Credentials are signed XML files that assign a subject gid privileges to an object gid\r
+##\r
+\r
+import os\r
+from types import StringTypes\r
+import datetime\r
+from StringIO import StringIO\r
+from tempfile import mkstemp\r
+from xml.dom.minidom import Document, parseString\r
+\r
+HAVELXML = False\r
+try:\r
+    from lxml import etree\r
+    HAVELXML = True\r
+except:\r
+    pass\r
+\r
+from sfa.util.faults import *\r
+from sfa.util.sfalogging import logger\r
+from sfa.util.sfatime import utcparse\r
+from sfa.trust.certificate import Keypair\r
+from sfa.trust.credential_legacy import CredentialLegacy\r
+from sfa.trust.rights import Right, Rights, determine_rights\r
+from sfa.trust.gid import GID\r
+from sfa.util.xrn import urn_to_hrn, hrn_authfor_hrn\r
+\r
+# 2 weeks, in seconds \r
+DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14\r
+\r
+\r
+# TODO:\r
+# . make privs match between PG and PL\r
+# . Need to add support for other types of credentials, e.g. tickets\r
+# . add namespaces to signed-credential element?\r
+\r
+signature_template = \\r
+'''\r
+<Signature xml:id="Sig_%s" xmlns="http://www.w3.org/2000/09/xmldsig#">\r
+  <SignedInfo>\r
+    <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>\r
+    <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>\r
+    <Reference URI="#%s">\r
+      <Transforms>\r
+        <Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />\r
+      </Transforms>\r
+      <DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>\r
+      <DigestValue></DigestValue>\r
+    </Reference>\r
+  </SignedInfo>\r
+  <SignatureValue />\r
+  <KeyInfo>\r
+    <X509Data>\r
+      <X509SubjectName/>\r
+      <X509IssuerSerial/>\r
+      <X509Certificate/>\r
+    </X509Data>\r
+    <KeyValue />\r
+  </KeyInfo>\r
+</Signature>\r
+'''\r
+\r
+# PG formats the template (whitespace) slightly differently.\r
+# Note that they don't include the xmlns in the template, but add it later.\r
+# Otherwise the two are equivalent.\r
+#signature_template_as_in_pg = \\r
+#'''\r
+#<Signature xml:id="Sig_%s" >\r
+# <SignedInfo>\r
+#  <CanonicalizationMethod      Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>\r
+#  <SignatureMethod      Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>\r
+#  <Reference URI="#%s">\r
+#    <Transforms>\r
+#      <Transform         Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />\r
+#    </Transforms>\r
+#    <DigestMethod        Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>\r
+#    <DigestValue></DigestValue>\r
+#    </Reference>\r
+# </SignedInfo>\r
+# <SignatureValue />\r
+# <KeyInfo>\r
+#  <X509Data >\r
+#   <X509SubjectName/>\r
+#   <X509IssuerSerial/>\r
+#   <X509Certificate/>\r
+#  </X509Data>\r
+#  <KeyValue />\r
+# </KeyInfo>\r
+#</Signature>\r
+#'''\r
+\r
+##\r
+# Convert a string into a bool\r
+# used to convert an xsd:boolean to a Python boolean\r
+def str2bool(str):\r
+    if str.lower() in ['true','1']:\r
+        return True\r
+    return False\r
+\r
+\r
+##\r
+# Utility function to get the text of an XML element\r
+\r
+def getTextNode(element, subele):\r
+    sub = element.getElementsByTagName(subele)[0]\r
+    if len(sub.childNodes) > 0:            \r
+        return sub.childNodes[0].nodeValue\r
+    else:\r
+        return None\r
+        \r
+##\r
+# Utility function to set the text of an XML element\r
+# It creates the element, adds the text to it,\r
+# and then appends it to the parent.\r
+\r
+def append_sub(doc, parent, element, text):\r
+    ele = doc.createElement(element)\r
+    ele.appendChild(doc.createTextNode(text))\r
+    parent.appendChild(ele)\r
+\r
+##\r
+# Signature contains information about an xmlsec1 signature\r
+# for a signed-credential\r
+#\r
+\r
+class Signature(object):\r
+   \r
+    def __init__(self, string=None):\r
+        self.refid = None\r
+        self.issuer_gid = None\r
+        self.xml = None\r
+        if string:\r
+            self.xml = string\r
+            self.decode()\r
+\r
+\r
+    def get_refid(self):\r
+        if not self.refid:\r
+            self.decode()\r
+        return self.refid\r
+\r
+    def get_xml(self):\r
+        if not self.xml:\r
+            self.encode()\r
+        return self.xml\r
+\r
+    def set_refid(self, id):\r
+        self.refid = id\r
+\r
+    def get_issuer_gid(self):\r
+        if not self.gid:\r
+            self.decode()\r
+        return self.gid        \r
+\r
+    def set_issuer_gid(self, gid):\r
+        self.gid = gid\r
+\r
+    def decode(self):\r
+        try:\r
+            doc = parseString(self.xml)\r
+        except ExpatError,e:\r
+            logger.log_exc ("Failed to parse credential, %s"%self.xml)\r
+            raise\r
+        sig = doc.getElementsByTagName("Signature")[0]\r
+        self.set_refid(sig.getAttribute("xml:id").strip("Sig_"))\r
+        keyinfo = sig.getElementsByTagName("X509Data")[0]\r
+        szgid = getTextNode(keyinfo, "X509Certificate")\r
+        szgid = "-----BEGIN CERTIFICATE-----\n%s\n-----END CERTIFICATE-----" % szgid\r
+        self.set_issuer_gid(GID(string=szgid))        \r
+        \r
+    def encode(self):\r
+        self.xml = signature_template % (self.get_refid(), self.get_refid())\r
+\r
+\r
+##\r
+# A credential provides a caller gid with privileges to an object gid.\r
+# A signed credential is signed by the object's authority.\r
+#\r
+# Credentials are encoded in one of two ways.  The legacy style places\r
+# it in the subjectAltName of an X509 certificate.  The new credentials\r
+# are placed in signed XML.\r
+#\r
+# WARNING:\r
+# In general, a signed credential obtained externally should\r
+# not be changed else the signature is no longer valid.  So, once\r
+# you have loaded an existing signed credential, do not call encode() or sign() on it.\r
+\r
+def filter_creds_by_caller(creds, caller_hrn_list):\r
+        """\r
+        Returns a list of creds who's gid caller matches the\r
+        specified caller hrn\r
+        """\r
+        if not isinstance(creds, list): creds = [creds]\r
+        if not isinstance(caller_hrn_list, list): \r
+            caller_hrn_list = [caller_hrn_list]\r
+        caller_creds = []\r
+        for cred in creds:\r
+            try:\r
+                tmp_cred = Credential(string=cred)\r
+                if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list:\r
+                    caller_creds.append(cred)\r
+            except: pass\r
+        return caller_creds\r
+\r
+class Credential(object):\r
+\r
+    ##\r
+    # Create a Credential object\r
+    #\r
+    # @param create If true, create a blank x509 certificate\r
+    # @param subject If subject!=None, create an x509 cert with the subject name\r
+    # @param string If string!=None, load the credential from the string\r
+    # @param filename If filename!=None, load the credential from the file\r
+    # FIXME: create and subject are ignored!\r
+    def __init__(self, create=False, subject=None, string=None, filename=None):\r
+        self.gidCaller = None\r
+        self.gidObject = None\r
+        self.expiration = None\r
+        self.privileges = None\r
+        self.issuer_privkey = None\r
+        self.issuer_gid = None\r
+        self.issuer_pubkey = None\r
+        self.parent = None\r
+        self.signature = None\r
+        self.xml = None\r
+        self.refid = None\r
+        self.legacy = None\r
+\r
+        # Check if this is a legacy credential, translate it if so\r
+        if string or filename:\r
+            if string:                \r
+                str = string\r
+            elif filename:\r
+                str = file(filename).read()\r
+                \r
+            if str.strip().startswith("-----"):\r
+                self.legacy = CredentialLegacy(False,string=str)\r
+                self.translate_legacy(str)\r
+            else:\r
+                self.xml = str\r
+                self.decode()\r
+\r
+        # Find an xmlsec1 path\r
+        self.xmlsec_path = ''\r
+        paths = ['/usr/bin','/usr/local/bin','/bin','/opt/bin','/opt/local/bin']\r
+        for path in paths:\r
+            if os.path.isfile(path + '/' + 'xmlsec1'):\r
+                self.xmlsec_path = path + '/' + 'xmlsec1'\r
+                break\r
+\r
+    def get_subject(self):\r
+        if not self.gidObject:\r
+            self.decode()\r
+        return self.gidObject.get_printable_subject()\r
+\r
+    def get_summary_tostring(self):\r
+        if not self.gidObject:\r
+            self.decode()\r
+        obj = self.gidObject.get_printable_subject()\r
+        caller = self.gidCaller.get_printable_subject()\r
+        exp = self.get_expiration()\r
+        # Summarize the rights too? The issuer?\r
+        return "[ Grant %s rights on %s until %s ]" % (caller, obj, exp)\r
+\r
+    def get_signature(self):\r
+        if not self.signature:\r
+            self.decode()\r
+        return self.signature\r
+\r
+    def set_signature(self, sig):\r
+        self.signature = sig\r
+\r
+        \r
+    ##\r
+    # Translate a legacy credential into a new one\r
+    #\r
+    # @param String of the legacy credential\r
+\r
+    def translate_legacy(self, str):\r
+        legacy = CredentialLegacy(False,string=str)\r
+        self.gidCaller = legacy.get_gid_caller()\r
+        self.gidObject = legacy.get_gid_object()\r
+        lifetime = legacy.get_lifetime()\r
+        if not lifetime:\r
+            self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))\r
+        else:\r
+            self.set_expiration(int(lifetime))\r
+        self.lifeTime = legacy.get_lifetime()\r
+        self.set_privileges(legacy.get_privileges())\r
+        self.get_privileges().delegate_all_privileges(legacy.get_delegate())\r
+\r
+    ##\r
+    # Need the issuer's private key and name\r
+    # @param key Keypair object containing the private key of the issuer\r
+    # @param gid GID of the issuing authority\r
+\r
+    def set_issuer_keys(self, privkey, gid):\r
+        self.issuer_privkey = privkey\r
+        self.issuer_gid = gid\r
+\r
+\r
+    ##\r
+    # Set this credential's parent\r
+    def set_parent(self, cred):\r
+        self.parent = cred\r
+        self.updateRefID()\r
+\r
+    ##\r
+    # set the GID of the caller\r
+    #\r
+    # @param gid GID object of the caller\r
+\r
+    def set_gid_caller(self, gid):\r
+        self.gidCaller = gid\r
+        # gid origin caller is the caller's gid by default\r
+        self.gidOriginCaller = gid\r
+\r
+    ##\r
+    # get the GID of the object\r
+\r
+    def get_gid_caller(self):\r
+        if not self.gidCaller:\r
+            self.decode()\r
+        return self.gidCaller\r
+\r
+    ##\r
+    # set the GID of the object\r
+    #\r
+    # @param gid GID object of the object\r
+\r
+    def set_gid_object(self, gid):\r
+        self.gidObject = gid\r
+\r
+    ##\r
+    # get the GID of the object\r
+\r
+    def get_gid_object(self):\r
+        if not self.gidObject:\r
+            self.decode()\r
+        return self.gidObject\r
+\r
+\r
+            \r
+    ##\r
+    # Expiration: an absolute UTC time of expiration (as either an int or string or datetime)\r
+    # \r
+    def set_expiration(self, expiration):\r
+        if isinstance(expiration, (int, float)):\r
+            self.expiration = datetime.datetime.fromtimestamp(expiration)\r
+        elif isinstance (expiration, datetime.datetime):\r
+            self.expiration = expiration\r
+        elif isinstance (expiration, StringTypes):\r
+            self.expiration = utcparse (expiration)\r
+        else:\r
+            logger.error ("unexpected input type in Credential.set_expiration")\r
+\r
+\r
+    ##\r
+    # get the lifetime of the credential (always in datetime format)\r
+\r
+    def get_expiration(self):\r
+        if not self.expiration:\r
+            self.decode()\r
+        # at this point self.expiration is normalized as a datetime - DON'T call utcparse again\r
+        return self.expiration\r
+\r
+    ##\r
+    # For legacy sake\r
+    def get_lifetime(self):\r
+        return self.get_expiration()\r
\r
+    ##\r
+    # set the privileges\r
+    #\r
+    # @param privs either a comma-separated list of privileges of a Rights object\r
+\r
+    def set_privileges(self, privs):\r
+        if isinstance(privs, str):\r
+            self.privileges = Rights(string = privs)\r
+        else:\r
+            self.privileges = privs\r
+        \r
+\r
+    ##\r
+    # return the privileges as a Rights object\r
+\r
+    def get_privileges(self):\r
+        if not self.privileges:\r
+            self.decode()\r
+        return self.privileges\r
+\r
+    ##\r
+    # determine whether the credential allows a particular operation to be\r
+    # performed\r
+    #\r
+    # @param op_name string specifying name of operation ("lookup", "update", etc)\r
+\r
+    def can_perform(self, op_name):\r
+        rights = self.get_privileges()\r
+        \r
+        if not rights:\r
+            return False\r
+\r
+        return rights.can_perform(op_name)\r
+\r
+\r
+    ##\r
+    # Encode the attributes of the credential into an XML string    \r
+    # This should be done immediately before signing the credential.    \r
+    # WARNING:\r
+    # In general, a signed credential obtained externally should\r
+    # not be changed else the signature is no longer valid.  So, once\r
+    # you have loaded an existing signed credential, do not call encode() or sign() on it.\r
+\r
+    def encode(self):\r
+        # Create the XML document\r
+        doc = Document()\r
+        signed_cred = doc.createElement("signed-credential")\r
+\r
+# Declare namespaces\r
+# Note that credential/policy.xsd are really the PG schemas\r
+# in a PL namespace.\r
+# Note that delegation of credentials between the 2 only really works\r
+# cause those schemas are identical.\r
+# Also note these PG schemas talk about PG tickets and CM policies.\r
+        signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")\r
+        signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.planet-lab.org/resources/sfa/credential.xsd")\r
+        signed_cred.setAttribute("xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd")\r
+\r
+# PG says for those last 2:\r
+#        signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")\r
+#        signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")\r
+\r
+        doc.appendChild(signed_cred)  \r
+        \r
+        # Fill in the <credential> bit        \r
+        cred = doc.createElement("credential")\r
+        cred.setAttribute("xml:id", self.get_refid())\r
+        signed_cred.appendChild(cred)\r
+        append_sub(doc, cred, "type", "privilege")\r
+        append_sub(doc, cred, "serial", "8")\r
+        append_sub(doc, cred, "owner_gid", self.gidCaller.save_to_string())\r
+        append_sub(doc, cred, "owner_urn", self.gidCaller.get_urn())\r
+        append_sub(doc, cred, "target_gid", self.gidObject.save_to_string())\r
+        append_sub(doc, cred, "target_urn", self.gidObject.get_urn())\r
+        append_sub(doc, cred, "uuid", "")\r
+        if not self.expiration:\r
+            self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME))\r
+        self.expiration = self.expiration.replace(microsecond=0)\r
+        append_sub(doc, cred, "expires", self.expiration.isoformat())\r
+        privileges = doc.createElement("privileges")\r
+        cred.appendChild(privileges)\r
+\r
+        if self.privileges:\r
+            rights = self.get_privileges()\r
+            for right in rights.rights:\r
+                priv = doc.createElement("privilege")\r
+                append_sub(doc, priv, "name", right.kind)\r
+                append_sub(doc, priv, "can_delegate", str(right.delegate).lower())\r
+                privileges.appendChild(priv)\r
+\r
+        # Add the parent credential if it exists\r
+        if self.parent:\r
+            sdoc = parseString(self.parent.get_xml())\r
+            # If the root node is a signed-credential (it should be), then\r
+            # get all its attributes and attach those to our signed_cred\r
+            # node.\r
+            # Specifically, PG and PLadd attributes for namespaces (which is reasonable),\r
+            # and we need to include those again here or else their signature\r
+            # no longer matches on the credential.\r
+            # We expect three of these, but here we copy them all:\r
+#        signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")\r
+# and from PG (PL is equivalent, as shown above):\r
+#        signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")\r
+#        signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")\r
+\r
+            # HOWEVER!\r
+            # PL now also declares these, with different URLs, so\r
+            # the code notices those attributes already existed with\r
+            # different values, and complains.\r
+            # This happens regularly on delegation now that PG and\r
+            # PL both declare the namespace with different URLs.\r
+            # If the content ever differs this is a problem,\r
+            # but for now it works - different URLs (values in the attributes)\r
+            # but the same actual schema, so using the PG schema\r
+            # on delegated-to-PL credentials works fine.\r
+\r
+            # Note: you could also not copy attributes\r
+            # which already exist. It appears that both PG and PL\r
+            # will actually validate a slicecred with a parent\r
+            # signed using PG namespaces and a child signed with PL\r
+            # namespaces over the whole thing. But I don't know\r
+            # if that is a bug in xmlsec1, an accident since\r
+            # the contents of the schemas are the same,\r
+            # or something else, but it seems odd. And this works.\r
+            parentRoot = sdoc.documentElement\r
+            if parentRoot.tagName == "signed-credential" and parentRoot.hasAttributes():\r
+                for attrIx in range(0, parentRoot.attributes.length):\r
+                    attr = parentRoot.attributes.item(attrIx)\r
+                    # returns the old attribute of same name that was\r
+                    # on the credential\r
+                    # Below throws InUse exception if we forgot to clone the attribute first\r
+                    oldAttr = signed_cred.setAttributeNode(attr.cloneNode(True))\r
+                    if oldAttr and oldAttr.value != attr.value:\r
+                        msg = "Delegating cred from owner %s to %s over %s replaced attribute %s value '%s' with '%s'" % (self.parent.gidCaller.get_urn(), self.gidCaller.get_urn(), self.gidObject.get_urn(), oldAttr.name, oldAttr.value, attr.value)\r
+                        logger.warn(msg)\r
+                        #raise CredentialNotVerifiable("Can't encode new valid delegated credential: %s" % msg)\r
+\r
+            p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True)\r
+            p = doc.createElement("parent")\r
+            p.appendChild(p_cred)\r
+            cred.appendChild(p)\r
+        # done handling parent credential\r
+\r
+        # Create the <signatures> tag\r
+        signatures = doc.createElement("signatures")\r
+        signed_cred.appendChild(signatures)\r
+\r
+        # Add any parent signatures\r
+        if self.parent:\r
+            for cur_cred in self.get_credential_list()[1:]:\r
+                sdoc = parseString(cur_cred.get_signature().get_xml())\r
+                ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)\r
+                signatures.appendChild(ele)\r
+                \r
+        # Get the finished product\r
+        self.xml = doc.toxml()\r
+\r
+\r
+    def save_to_random_tmp_file(self):       \r
+        fp, filename = mkstemp(suffix='cred', text=True)\r
+        fp = os.fdopen(fp, "w")\r
+        self.save_to_file(filename, save_parents=True, filep=fp)\r
+        return filename\r
+    \r
+    def save_to_file(self, filename, save_parents=True, filep=None):\r
+        if not self.xml:\r
+            self.encode()\r
+        if filep:\r
+            f = filep \r
+        else:\r
+            f = open(filename, "w")\r
+        f.write(self.xml)\r
+        f.close()\r
+\r
+    def save_to_string(self, save_parents=True):\r
+        if not self.xml:\r
+            self.encode()\r
+        return self.xml\r
+\r
+    def get_refid(self):\r
+        if not self.refid:\r
+            self.refid = 'ref0'\r
+        return self.refid\r
+\r
+    def set_refid(self, rid):\r
+        self.refid = rid\r
+\r
+    ##\r
+    # Figure out what refids exist, and update this credential's id\r
+    # so that it doesn't clobber the others.  Returns the refids of\r
+    # the parents.\r
+    \r
+    def updateRefID(self):\r
+        if not self.parent:\r
+            self.set_refid('ref0')\r
+            return []\r
+        \r
+        refs = []\r
+\r
+        next_cred = self.parent\r
+        while next_cred:\r
+            refs.append(next_cred.get_refid())\r
+            if next_cred.parent:\r
+                next_cred = next_cred.parent\r
+            else:\r
+                next_cred = None\r
+\r
+        \r
+        # Find a unique refid for this credential\r
+        rid = self.get_refid()\r
+        while rid in refs:\r
+            val = int(rid[3:])\r
+            rid = "ref%d" % (val + 1)\r
+\r
+        # Set the new refid\r
+        self.set_refid(rid)\r
+\r
+        # Return the set of parent credential ref ids\r
+        return refs\r
+\r
+    def get_xml(self):\r
+        if not self.xml:\r
+            self.encode()\r
+        return self.xml\r
+\r
+    ##\r
+    # Sign the XML file created by encode()\r
+    #\r
+    # WARNING:\r
+    # In general, a signed credential obtained externally should\r
+    # not be changed else the signature is no longer valid.  So, once\r
+    # you have loaded an existing signed credential, do not call encode() or sign() on it.\r
+\r
+    def sign(self):\r
+        if not self.issuer_privkey or not self.issuer_gid:\r
+            return\r
+        doc = parseString(self.get_xml())\r
+        sigs = doc.getElementsByTagName("signatures")[0]\r
+\r
+        # Create the signature template to be signed\r
+        signature = Signature()\r
+        signature.set_refid(self.get_refid())\r
+        sdoc = parseString(signature.get_xml())        \r
+        sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True)\r
+        sigs.appendChild(sig_ele)\r
+\r
+        self.xml = doc.toxml()\r
+\r
+\r
+        # Split the issuer GID into multiple certificates if it's a chain\r
+        chain = GID(filename=self.issuer_gid)\r
+        gid_files = []\r
+        while chain:\r
+            gid_files.append(chain.save_to_random_tmp_file(False))\r
+            if chain.get_parent():\r
+                chain = chain.get_parent()\r
+            else:\r
+                chain = None\r
+\r
+\r
+        # Call out to xmlsec1 to sign it\r
+        ref = 'Sig_%s' % self.get_refid()\r
+        filename = self.save_to_random_tmp_file()\r
+        signed = os.popen('%s --sign --node-id "%s" --privkey-pem %s,%s %s' \\r
+                 % (self.xmlsec_path, ref, self.issuer_privkey, ",".join(gid_files), filename)).read()\r
+        os.remove(filename)\r
+\r
+        for gid_file in gid_files:\r
+            os.remove(gid_file)\r
+\r
+        self.xml = signed\r
+\r
+        # This is no longer a legacy credential\r
+        if self.legacy:\r
+            self.legacy = None\r
+\r
+        # Update signatures\r
+        self.decode()       \r
+\r
+        \r
+    ##\r
+    # Retrieve the attributes of the credential from the XML.\r
+    # This is automatically called by the various get_* methods of\r
+    # this class and should not need to be called explicitly.\r
+\r
+    def decode(self):\r
+        if not self.xml:\r
+            return\r
+        doc = parseString(self.xml)\r
+        sigs = []\r
+        signed_cred = doc.getElementsByTagName("signed-credential")\r
+\r
+        # Is this a signed-cred or just a cred?\r
+        if len(signed_cred) > 0:\r
+            creds = signed_cred[0].getElementsByTagName("credential")\r
+            signatures = signed_cred[0].getElementsByTagName("signatures")\r
+            if len(signatures) > 0:\r
+                sigs = signatures[0].getElementsByTagName("Signature")\r
+        else:\r
+            creds = doc.getElementsByTagName("credential")\r
+        \r
+        if creds is None or len(creds) == 0:\r
+            # malformed cred file\r
+            raise CredentialNotVerifiable("Malformed XML: No credential tag found")\r
+\r
+        # Just take the first cred if there are more than one\r
+        cred = creds[0]\r
+\r
+        self.set_refid(cred.getAttribute("xml:id"))\r
+        self.set_expiration(utcparse(getTextNode(cred, "expires")))\r
+        self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))\r
+        self.gidObject = GID(string=getTextNode(cred, "target_gid"))   \r
+\r
+\r
+        # Process privileges\r
+        privs = cred.getElementsByTagName("privileges")[0]\r
+        rlist = Rights()\r
+        for priv in privs.getElementsByTagName("privilege"):\r
+            kind = getTextNode(priv, "name")\r
+            deleg = str2bool(getTextNode(priv, "can_delegate"))\r
+            if kind == '*':\r
+                # Convert * into the default privileges for the credential's type\r
+                # Each inherits the delegatability from the * above\r
+                _ , type = urn_to_hrn(self.gidObject.get_urn())\r
+                rl = determine_rights(type, self.gidObject.get_urn())\r
+                for r in rl.rights:\r
+                    r.delegate = deleg\r
+                    rlist.add(r)\r
+            else:\r
+                rlist.add(Right(kind.strip(), deleg))\r
+        self.set_privileges(rlist)\r
+\r
+\r
+        # Is there a parent?\r
+        parent = cred.getElementsByTagName("parent")\r
+        if len(parent) > 0:\r
+            parent_doc = parent[0].getElementsByTagName("credential")[0]\r
+            parent_xml = parent_doc.toxml()\r
+            self.parent = Credential(string=parent_xml)\r
+            self.updateRefID()\r
+\r
+        # Assign the signatures to the credentials\r
+        for sig in sigs:\r
+            Sig = Signature(string=sig.toxml())\r
+\r
+            for cur_cred in self.get_credential_list():\r
+                if cur_cred.get_refid() == Sig.get_refid():\r
+                    cur_cred.set_signature(Sig)\r
+                                    \r
+            \r
+    ##\r
+    # Verify\r
+    #   trusted_certs: A list of trusted GID filenames (not GID objects!) \r
+    #                  Chaining is not supported within the GIDs by xmlsec1.\r
+    #\r
+    #   trusted_certs_required: Should usually be true. Set False means an\r
+    #                 empty list of trusted_certs would still let this method pass.\r
+    #                 It just skips xmlsec1 verification et al. Only used by some utils\r
+    #    \r
+    # Verify that:\r
+    # . All of the signatures are valid and that the issuers trace back\r
+    #   to trusted roots (performed by xmlsec1)\r
+    # . The XML matches the credential schema\r
+    # . That the issuer of the credential is the authority in the target's urn\r
+    #    . In the case of a delegated credential, this must be true of the root\r
+    # . That all of the gids presented in the credential are valid\r
+    #    . Including verifying GID chains, and includ the issuer\r
+    # . The credential is not expired\r
+    #\r
+    # -- For Delegates (credentials with parents)\r
+    # . The privileges must be a subset of the parent credentials\r
+    # . The privileges must have "can_delegate" set for each delegated privilege\r
+    # . The target gid must be the same between child and parents\r
+    # . The expiry time on the child must be no later than the parent\r
+    # . The signer of the child must be the owner of the parent\r
+    #\r
+    # -- Verify does *NOT*\r
+    # . ensure that an xmlrpc client's gid matches a credential gid, that\r
+    #   must be done elsewhere\r
+    #\r
+    # @param trusted_certs: The certificates of trusted CA certificates\r
+    def verify(self, trusted_certs=None, schema=None, trusted_certs_required=True):\r
+        if not self.xml:\r
+            self.decode()\r
+\r
+        # validate against RelaxNG schema\r
+        if HAVELXML and not self.legacy:\r
+            if schema and os.path.exists(schema):\r
+                tree = etree.parse(StringIO(self.xml))\r
+                schema_doc = etree.parse(schema)\r
+                xmlschema = etree.XMLSchema(schema_doc)\r
+                if not xmlschema.validate(tree):\r
+                    error = xmlschema.error_log.last_error\r
+                    message = "%s: %s (line %s)" % (self.get_summary_tostring(), error.message, error.line)\r
+                    raise CredentialNotVerifiable(message)\r
+\r
+        if trusted_certs_required and trusted_certs is None:\r
+            trusted_certs = []\r
+\r
+#        trusted_cert_objects = [GID(filename=f) for f in trusted_certs]\r
+        trusted_cert_objects = []\r
+        ok_trusted_certs = []\r
+        # If caller explicitly passed in None that means skip cert chain validation.\r
+        # Strange and not typical\r
+        if trusted_certs is not None:\r
+            for f in trusted_certs:\r
+                try:\r
+                    # Failures here include unreadable files\r
+                    # or non PEM files\r
+                    trusted_cert_objects.append(GID(filename=f))\r
+                    ok_trusted_certs.append(f)\r
+                except Exception, exc:\r
+                    logger.error("Failed to load trusted cert from %s: %r", f, exc)\r
+            trusted_certs = ok_trusted_certs\r
+\r
+        # Use legacy verification if this is a legacy credential\r
+        if self.legacy:\r
+            self.legacy.verify_chain(trusted_cert_objects)\r
+            if self.legacy.client_gid:\r
+                self.legacy.client_gid.verify_chain(trusted_cert_objects)\r
+            if self.legacy.object_gid:\r
+                self.legacy.object_gid.verify_chain(trusted_cert_objects)\r
+            return True\r
+        \r
+        # make sure it is not expired\r
+        if self.get_expiration() < datetime.datetime.utcnow():\r
+            raise CredentialNotVerifiable("Credential %s expired at %s" % (self.get_summary_tostring(), self.expiration.isoformat()))\r
+\r
+        # Verify the signatures\r
+        filename = self.save_to_random_tmp_file()\r
+        if trusted_certs is not None:\r
+            cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs])\r
+\r
+        # If caller explicitly passed in None that means skip cert chain validation.\r
+        # - Strange and not typical\r
+        if trusted_certs is not None:\r
+            # Verify the gids of this cred and of its parents\r
+            for cur_cred in self.get_credential_list():\r
+                cur_cred.get_gid_object().verify_chain(trusted_cert_objects)\r
+                cur_cred.get_gid_caller().verify_chain(trusted_cert_objects)\r
+\r
+        refs = []\r
+        refs.append("Sig_%s" % self.get_refid())\r
+\r
+        parentRefs = self.updateRefID()\r
+        for ref in parentRefs:\r
+            refs.append("Sig_%s" % ref)\r
+\r
+        for ref in refs:\r
+            # If caller explicitly passed in None that means skip xmlsec1 validation.\r
+            # Strange and not typical\r
+            if trusted_certs is None:\r
+                break\r
+\r
+#            print "Doing %s --verify --node-id '%s' %s %s 2>&1" % \\r
+#                (self.xmlsec_path, ref, cert_args, filename)\r
+            verified = os.popen('%s --verify --node-id "%s" %s %s 2>&1' \\r
+                            % (self.xmlsec_path, ref, cert_args, filename)).read()\r
+            if not verified.strip().startswith("OK"):\r
+                # xmlsec errors have a msg= which is the interesting bit.\r
+                mstart = verified.find("msg=")\r
+                msg = ""\r
+                if mstart > -1 and len(verified) > 4:\r
+                    mstart = mstart + 4\r
+                    mend = verified.find('\\', mstart)\r
+                    msg = verified[mstart:mend]\r
+                raise CredentialNotVerifiable("xmlsec1 error verifying cred %s using Signature ID %s: %s %s" % (self.get_summary_tostring(), ref, msg, verified.strip()))\r
+        os.remove(filename)\r
+\r
+        # Verify the parents (delegation)\r
+        if self.parent:\r
+            self.verify_parent(self.parent)\r
+\r
+        # Make sure the issuer is the target's authority, and is\r
+        # itself a valid GID\r
+        self.verify_issuer(trusted_cert_objects)\r
+        return True\r
+\r
+    ##\r
+    # Creates a list of the credential and its parents, with the root \r
+    # (original delegated credential) as the last item in the list\r
+    def get_credential_list(self):    \r
+        cur_cred = self\r
+        list = []\r
+        while cur_cred:\r
+            list.append(cur_cred)\r
+            if cur_cred.parent:\r
+                cur_cred = cur_cred.parent\r
+            else:\r
+                cur_cred = None\r
+        return list\r
+    \r
+    ##\r
+    # Make sure the credential's target gid (a) was signed by or (b)\r
+    # is the same as the entity that signed the original credential,\r
+    # or (c) is an authority over the target's namespace.\r
+    # Also ensure that the credential issuer / signer itself has a valid\r
+    # GID signature chain (signed by an authority with namespace rights).\r
+    def verify_issuer(self, trusted_gids):\r
+        root_cred = self.get_credential_list()[-1]\r
+        root_target_gid = root_cred.get_gid_object()\r
+        root_cred_signer = root_cred.get_signature().get_issuer_gid()\r
+\r
+        # Case 1:\r
+        # Allow non authority to sign target and cred about target.\r
+        #\r
+        # Why do we need to allow non authorities to sign?\r
+        # If in the target gid validation step we correctly\r
+        # checked that the target is only signed by an authority,\r
+        # then this is just a special case of case 3.\r
+        # This short-circuit is the common case currently -\r
+        # and cause GID validation doesn't check 'authority',\r
+        # this allows users to generate valid slice credentials.\r
+        if root_target_gid.is_signed_by_cert(root_cred_signer):\r
+            # cred signer matches target signer, return success\r
+            return\r
+\r
+        # Case 2:\r
+        # Allow someone to sign credential about themeselves. Used?\r
+        # If not, remove this.\r
+        #root_target_gid_str = root_target_gid.save_to_string()\r
+        #root_cred_signer_str = root_cred_signer.save_to_string()\r
+        #if root_target_gid_str == root_cred_signer_str:\r
+        #    # cred signer is target, return success\r
+        #    return\r
+\r
+        # Case 3:\r
+\r
+        # root_cred_signer is not the target_gid\r
+        # So this is a different gid that we have not verified.\r
+        # xmlsec1 verified the cert chain on this already, but\r
+        # it hasn't verified that the gid meets the HRN namespace\r
+        # requirements.\r
+        # Below we'll ensure that it is an authority.\r
+        # But we haven't verified that it is _signed by_ an authority\r
+        # We also don't know if xmlsec1 requires that cert signers\r
+        # are marked as CAs.\r
+\r
+        # Note that if verify() gave us no trusted_gids then this\r
+        # call will fail. So skip it if we have no trusted_gids\r
+        if trusted_gids and len(trusted_gids) > 0:\r
+            root_cred_signer.verify_chain(trusted_gids)\r
+        else:\r
+            logger.debug("No trusted gids. Cannot verify that cred signer is signed by a trusted authority. Skipping that check.")\r
+\r
+        # See if the signer is an authority over the domain of the target.\r
+        # There are multiple types of authority - accept them all here\r
+        # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn())\r
+        root_cred_signer_type = root_cred_signer.get_type()\r
+        if (root_cred_signer_type.find('authority') == 0):\r
+            #logger.debug('Cred signer is an authority')\r
+            # signer is an authority, see if target is in authority's domain\r
+            signerhrn = root_cred_signer.get_hrn()\r
+            if hrn_authfor_hrn(signerhrn, root_target_gid.get_hrn()):\r
+                return\r
+\r
+        # We've required that the credential be signed by an authority\r
+        # for that domain. Reasonable and probably correct.\r
+        # A looser model would also allow the signer to be an authority\r
+        # in my control framework - eg My CA or CH. Even if it is not\r
+        # the CH that issued these, eg, user credentials.\r
+\r
+        # Give up, credential does not pass issuer verification\r
+\r
+        raise CredentialNotVerifiable("Could not verify credential owned by %s for object %s. Cred signer %s not the trusted authority for Cred target %s" % (self.gidCaller.get_urn(), self.gidObject.get_urn(), root_cred_signer.get_hrn(), root_target_gid.get_hrn()))\r
+\r
+\r
+    ##\r
+    # -- For Delegates (credentials with parents) verify that:\r
+    # . The privileges must be a subset of the parent credentials\r
+    # . The privileges must have "can_delegate" set for each delegated privilege\r
+    # . The target gid must be the same between child and parents\r
+    # . The expiry time on the child must be no later than the parent\r
+    # . The signer of the child must be the owner of the parent        \r
+    def verify_parent(self, parent_cred):\r
+        # make sure the rights given to the child are a subset of the\r
+        # parents rights (and check delegate bits)\r
+        if not parent_cred.get_privileges().is_superset(self.get_privileges()):\r
+            raise ChildRightsNotSubsetOfParent(("Parent cred ref %s rights " % parent_cred.get_refid()) +\r
+                self.parent.get_privileges().save_to_string() + (" not superset of delegated cred %s ref %s rights " % (self.get_summary_tostring(), self.get_refid())) +\r
+                self.get_privileges().save_to_string())\r
+\r
+        # make sure my target gid is the same as the parent's\r
+        if not parent_cred.get_gid_object().save_to_string() == \\r
+           self.get_gid_object().save_to_string():\r
+            raise CredentialNotVerifiable("Delegated cred %s: Target gid not equal between parent and child. Parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
+\r
+        # make sure my expiry time is <= my parent's\r
+        if not parent_cred.get_expiration() >= self.get_expiration():\r
+            raise CredentialNotVerifiable("Delegated credential %s expires after parent %s" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
+\r
+        # make sure my signer is the parent's caller\r
+        if not parent_cred.get_gid_caller().save_to_string(False) == \\r
+           self.get_signature().get_issuer_gid().save_to_string(False):\r
+            raise CredentialNotVerifiable("Delegated credential %s not signed by parent %s's caller" % (self.get_summary_tostring(), parent_cred.get_summary_tostring()))\r
+                \r
+        # Recurse\r
+        if parent_cred.parent:\r
+            parent_cred.verify_parent(parent_cred.parent)\r
+\r
+\r
+    def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile):\r
+        """\r
+        Return a delegated copy of this credential, delegated to the \r
+        specified gid's user.    \r
+        """\r
+        # get the gid of the object we are delegating\r
+        object_gid = self.get_gid_object()\r
+        object_hrn = object_gid.get_hrn()        \r
\r
+        # the hrn of the user who will be delegated to\r
+        delegee_gid = GID(filename=delegee_gidfile)\r
+        delegee_hrn = delegee_gid.get_hrn()\r
+  \r
+        #user_key = Keypair(filename=keyfile)\r
+        #user_hrn = self.get_gid_caller().get_hrn()\r
+        subject_string = "%s delegated to %s" % (object_hrn, delegee_hrn)\r
+        dcred = Credential(subject=subject_string)\r
+        dcred.set_gid_caller(delegee_gid)\r
+        dcred.set_gid_object(object_gid)\r
+        dcred.set_parent(self)\r
+        dcred.set_expiration(self.get_expiration())\r
+        dcred.set_privileges(self.get_privileges())\r
+        dcred.get_privileges().delegate_all_privileges(True)\r
+        #dcred.set_issuer_keys(keyfile, delegee_gidfile)\r
+        dcred.set_issuer_keys(caller_keyfile, caller_gidfile)\r
+        dcred.encode()\r
+        dcred.sign()\r
+\r
+        return dcred\r
+\r
+    # only informative\r
+    def get_filename(self):\r
+        return getattr(self,'filename',None)\r
+\r
+    ##\r
+    # Dump the contents of a credential to stdout in human-readable format\r
+    #\r
+    # @param dump_parents If true, also dump the parent certificates\r
+    def dump (self, *args, **kwargs):\r
+        print self.dump_string(*args, **kwargs)\r
+\r
+\r
+    def dump_string(self, dump_parents=False):\r
+        result=""\r
+        result += "CREDENTIAL %s\n" % self.get_subject()\r
+        filename=self.get_filename()\r
+        if filename: result += "Filename %s\n"%filename\r
+        result += "      privs: %s\n" % self.get_privileges().save_to_string()\r
+        gidCaller = self.get_gid_caller()\r
+        if gidCaller:\r
+            result += "  gidCaller:\n"\r
+            result += gidCaller.dump_string(8, dump_parents)\r
+\r
+        if self.get_signature():\r
+            print "  gidIssuer:"\r
+            self.get_signature().get_issuer_gid().dump(8, dump_parents)\r
+\r
+        gidObject = self.get_gid_object()\r
+        if gidObject:\r
+            result += "  gidObject:\n"\r
+            result += gidObject.dump_string(8, dump_parents)\r
+\r
+        if self.parent and dump_parents:\r
+            result += "\nPARENT"\r
+            result += self.parent.dump_string(True)\r
+\r
+        return result\r
index c115211..a57b94c 100644 (file)
@@ -7,7 +7,7 @@
   
 -->
 <!--
-  ProtoGENI credential and privilege specification. The key points:
+  PlanetLab credential specification. The key points:
   
   * A credential is a set of privileges or a Ticket, each with a flag
     to indicate delegation is permitted.
@@ -17,7 +17,7 @@
     blob will be signed. So, there will be multiple signatures in the
     document, each with a reference to the credential it signs.
   
-  default namespace = "http://www.protogeni.net/resources/credential/0.1"
+  default namespace = "http://www.planet-lab.org/resources/ext/credential/1"
 -->
 <xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema" elementFormDefault="qualified" xmlns:sig="http://www.w3.org/2000/09/xmldsig#">
   <xs:include schemaLocation="protogeni-rspec-common.xsd"/>
index 8ba90b2..dda7096 100644 (file)
@@ -5,9 +5,6 @@
 # certificate that stores a tuple of parameters.
 ##
 
-### $Id: credential.py 17477 2010-03-25 16:49:34Z jkarlin $
-### $URL: svn+ssh://svn.planet-lab.org/svn/sfa/branches/geni-api/sfa/trust/credential.py $
-
 import xmlrpclib
 
 from sfa.util.faults import *
index 94240cd..15ad6bf 100644 (file)
 import xmlrpclib
 import uuid
 
-from sfa.util.sfalogging import sfa_logger
 from sfa.trust.certificate import Certificate
-from sfa.util.xrn import hrn_to_urn, urn_to_hrn
+
+from sfa.util.faults import *
+from sfa.util.sfalogging import logger
+from sfa.util.xrn import hrn_to_urn, urn_to_hrn, hrn_authfor_hrn
 
 ##
 # Create a new uuid. Returns the UUID as a string.
@@ -75,12 +77,13 @@ class GID(Certificate):
     # @param subject If subject!=None, create the X509 cert and set the subject name
     # @param string If string!=None, load the GID from a string
     # @param filename If filename!=None, load the GID from a file
+    # @param lifeDays life of GID in days - default is 1825==5 years
 
-    def __init__(self, create=False, subject=None, string=None, filename=None, uuid=None, hrn=None, urn=None):
+    def __init__(self, create=False, subject=None, string=None, filename=None, uuid=None, hrn=None, urn=None, lifeDays=1825):
         
-        Certificate.__init__(self, create, subject, string, filename)
+        Certificate.__init__(self, lifeDays, create, subject, string, filename)
         if subject:
-            sfa_logger().debug("Creating GID for subject: %s" % subject)
+            logger.debug("Creating GID for subject: %s" % subject)
         if uuid:
             self.uuid = int(uuid)
         if hrn:
@@ -180,7 +183,7 @@ class GID(Certificate):
         print self.dump_string(*args,**kwargs)
 
     def dump_string(self, indent=0, dump_parents=False):
-        result="GID\n"
+        result=" "*(indent-2) + "GID\n"
         result += " "*indent + "hrn:" + str(self.get_hrn()) +"\n"
         result += " "*indent + "urn:" + str(self.get_urn()) +"\n"
         result += " "*indent + "uuid:" + str(self.get_uuid()) + "\n"
@@ -196,7 +199,7 @@ class GID(Certificate):
     # Verify the chain of authenticity of the GID. First perform the checks
     # of the certificate class (verifying that each parent signs the child,
     # etc). In addition, GIDs also confirm that the parent's HRN is a prefix
-    # of the child's HRN.
+    # of the child's HRN, and the parent is of type 'authority'.
     #
     # Verifying these prefixes prevents a rogue authority from signing a GID
     # for a principal that is not a member of that authority. For example,
@@ -208,8 +211,17 @@ class GID(Certificate):
        
         if self.parent:
             # make sure the parent's hrn is a prefix of the child's hrn
-            if not self.get_hrn().startswith(self.parent.get_hrn()):
-                raise GidParentHrn("This cert HRN %s doesnt start with parent HRN %s" % (self.get_hrn(), self.parent.get_hrn()))
+            if not hrn_authfor_hrn(self.parent.get_hrn(), self.get_hrn()):
+                raise GidParentHrn("This cert HRN %s isn't in the namespace for parent HRN %s" % (self.get_hrn(), self.parent.get_hrn()))
+
+            # Parent must also be an authority (of some type) to sign a GID
+            # There are multiple types of authority - accept them all here
+            if not self.parent.get_type().find('authority') == 0:
+                raise GidInvalidParentHrn("This cert %s's parent %s is not an authority (is a %s)" % (self.get_hrn(), self.parent.get_hrn(), self.parent.get_type()))
+
+            # Then recurse up the chain - ensure the parent is a trusted
+            # root or is in the namespace of a trusted root
+            self.parent.verify_chain(trusted_certs)
         else:
             # make sure that the trusted root's hrn is a prefix of the child's
             trusted_gid = GID(string=trusted_root.save_to_string())
@@ -218,7 +230,11 @@ class GID(Certificate):
             #if trusted_type == 'authority':
             #    trusted_hrn = trusted_hrn[:trusted_hrn.rindex('.')]
             cur_hrn = self.get_hrn()
-            if not self.get_hrn().startswith(trusted_hrn):
-                raise GidParentHrn("Trusted roots HRN %s isnt start of this cert %s" % (trusted_hrn, cur_hrn))
+            if not hrn_authfor_hrn(trusted_hrn, cur_hrn):
+                raise GidParentHrn("Trusted root with HRN %s isn't a namespace authority for this cert %s" % (trusted_hrn, cur_hrn))
+
+            # There are multiple types of authority - accept them all here
+            if not trusted_type.find('authority') == 0:
+                raise GidInvalidParentHrn("This cert %s's trusted root signer %s is not an authority (is a %s)" % (self.get_hrn(), trusted_hrn, trusted_type))
 
         return
index 5d7db6f..6323436 100644 (file)
@@ -15,7 +15,7 @@
 import os
 
 from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn, urn_to_hrn
 from sfa.trust.certificate import Keypair
 from sfa.trust.credential import Credential
@@ -33,7 +33,6 @@ class AuthInfo:
     gid_filename = None
     privkey_filename = None
     dbinfo_filename = None
-
     ##
     # Initialize and authority object.
     #
@@ -159,7 +158,7 @@ class Hierarchy:
 
     def create_auth(self, xrn, create_parents=False):
         hrn, type = urn_to_hrn(xrn)
-        sfa_logger().debug("Hierarchy: creating authority: %s"% hrn)
+        logger.debug("Hierarchy: creating authority: %s"% hrn)
 
         # create the parent authority if necessary
         parent_hrn = get_authority(hrn)
@@ -179,7 +178,7 @@ class Hierarchy:
                 pass
 
         if os.path.exists(privkey_filename):
-            sfa_logger().debug("using existing key %r for authority %r"%(privkey_filename,hrn))
+            logger.debug("using existing key %r for authority %r"%(privkey_filename,hrn))
             pkey = Keypair(filename = privkey_filename)
         else:
             pkey = Keypair(create = True)
@@ -205,7 +204,7 @@ class Hierarchy:
     def get_auth_info(self, xrn):
         hrn, type = urn_to_hrn(xrn)
         if not self.auth_exists(hrn):
-            sfa_logger().warning("Hierarchy: mising authority - xrn=%s, hrn=%s"%(xrn,hrn))
+            logger.warning("Hierarchy: mising authority - xrn=%s, hrn=%s"%(xrn,hrn))
             raise MissingAuthority(hrn)
 
         (directory, gid_filename, privkey_filename, dbinfo_filename) = \
@@ -230,15 +229,28 @@ class Hierarchy:
     # @param uuid the unique identifier to store in the GID
     # @param pkey the public key to store in the GID
 
-    def create_gid(self, xrn, uuid, pkey):
+    def create_gid(self, xrn, uuid, pkey, CA=False):
         hrn, type = urn_to_hrn(xrn)
+        parent_hrn = get_authority(hrn)
         # Using hrn_to_urn() here to make sure the urn is in the right format
         # If xrn was a hrn instead of a urn, then the gid's urn will be
         # of type None 
         urn = hrn_to_urn(hrn, type)
         gid = GID(subject=hrn, uuid=uuid, hrn=hrn, urn=urn)
 
-        parent_hrn = get_authority(hrn)
+        # is this a CA cert
+        if hrn == self.config.SFA_INTERFACE_HRN or not parent_hrn:
+            # root or sub authority  
+            gid.set_intermediate_ca(True)
+        elif type and 'authority' in type:
+            # authority type
+            gid.set_intermediate_ca(True)
+        elif CA:
+            gid.set_intermediate_ca(True)
+        else:
+            gid.set_intermediate_ca(False)
+
+        # set issuer
         if not parent_hrn or hrn == self.config.SFA_INTERFACE_HRN:
             # if there is no parent hrn, then it must be self-signed. this
             # is where we terminate the recursion
@@ -248,7 +260,6 @@ class Hierarchy:
             parent_auth_info = self.get_auth_info(parent_hrn)
             gid.set_issuer(parent_auth_info.get_pkey_object(), parent_auth_info.hrn)
             gid.set_parent(parent_auth_info.get_gid_object())
-            gid.set_intermediate_ca(True)
 
         gid.set_pubkey(pkey)
         gid.encode()
index ff1ac2d..db88123 100644 (file)
@@ -60,7 +60,7 @@ def determine_rights(type, name):
     elif type in ["sa", "authority+sa"]:
         rl.add("authority")
         rl.add("sa")
-    elif type in ["ma", "authority+ma", "cm", "authority+cm"]:
+    elif type in ["ma", "authority+ma", "cm", "authority+cm", "sm", "authority+sm"]:
         rl.add("authority")
         rl.add("ma")
     elif type == "authority":
@@ -220,6 +220,7 @@ class Rights:
             for my_right in self.rights:
                 if my_right.is_superset(child_right):
                     allowed = True
+                    break
             if not allowed:
                 return False
         return True
@@ -245,47 +246,3 @@ class Rights:
                 return False
         return True
 
-
-
-    ##
-    # Determine the rights that an object should have. The rights are entirely
-    # dependent on the type of the object. For example, users automatically
-    # get "refresh", "resolve", and "info".
-    #
-    # @param type the type of the object (user | sa | ma | slice | node)
-    # @param name human readable name of the object (not used at this time)
-    #
-    # @return Rights object containing rights
-
-    def determine_rights(self, type, name):
-        rl = Rights()
-
-        # rights seem to be somewhat redundant with the type of the credential.
-        # For example, a "sa" credential implies the authority right, because
-        # a sa credential cannot be issued to a user who is not an owner of
-        # the authority
-
-        if type == "user":
-            rl.add("refresh")
-            rl.add("resolve")
-            rl.add("info")
-        elif type in ["sa", "authority+sa"]:
-            rl.add("authority")
-            rl.add("sa")
-        elif type in ["ma", "authority+ma", "cm", "authority+cm"]:
-            rl.add("authority")
-            rl.add("ma")
-        elif type == "authority":
-            rl.add("authority")
-            rl.add("sa")
-            rl.add("ma")
-        elif type == "slice":
-            rl.add("refresh")
-            rl.add("embed")
-            rl.add("bind")
-            rl.add("control")
-            rl.add("info")
-        elif type == "component":
-            rl.add("operator")
-
-        return rl
diff --git a/sfa/trust/trustedroot.py b/sfa/trust/trustedroot.py
deleted file mode 100644 (file)
index ec8d2f0..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-### $Id$
-### $URL$
-
-import os
-
-from sfa.trust.gid import *
-
-class TrustedRootList:
-    def __init__(self, dir):
-        self.basedir = dir
-        
-        # create the directory to hold the files
-        try:
-            os.makedirs(self.basedir)
-        # if the path already exists then pass
-        except OSError, (errno, strerr):
-            if errno == 17:
-                pass
-
-    def add_gid(self, gid):
-        fn = os.path.join(self.basedir, gid.get_hrn() + ".gid")
-
-        gid.save_to_file(fn)
-
-    def get_list(self):
-        gid_list = []
-        file_list = os.listdir(self.basedir)
-        for gid_file in file_list:
-            fn = os.path.join(self.basedir, gid_file)
-            if os.path.isfile(fn):
-                gid = GID(filename = fn)
-                gid_list.append(gid)
-        return gid_list
-
-    def get_file_list(self):
-        gid_file_list = []
-        
-        file_list = os.listdir(self.basedir)
-        for gid_file in file_list:
-            fn = os.path.join(self.basedir, gid_file)
-            if os.path.isfile(fn):
-                gid_file_list.append(fn)
-
-        return gid_file_list
diff --git a/sfa/trust/trustedroots.py b/sfa/trust/trustedroots.py
new file mode 100644 (file)
index 0000000..a505aea
--- /dev/null
@@ -0,0 +1,43 @@
+import os.path
+import glob
+
+from sfa.trust.gid import GID
+from sfa.util.sfalogging import logger
+
+class TrustedRoots:
+    
+    # we want to avoid reading all files in the directory
+    # this is because it's common to have backups of all kinds
+    # e.g. *~, *.hide, *-00, *.bak and the like
+    supported_extensions= [ 'gid', 'cert', 'pem' ]
+
+    def __init__(self, dir):
+        self.basedir = dir
+        # create the directory to hold the files, if not existing
+        if not os.path.isdir (self.basedir):
+            os.makedirs(self.basedir)
+
+    def add_gid(self, gid):
+        fn = os.path.join(self.basedir, gid.get_hrn() + ".gid")
+        gid.save_to_file(fn)
+
+    def get_list(self):
+        gid_list = [GID(filename=cert_file) for cert_file in self.get_file_list()]
+        return gid_list
+
+    def get_file_list(self):
+        file_list  = []
+        pattern=os.path.join(self.basedir,"*")
+        for cert_file in glob.glob(pattern):
+            if os.path.isfile(cert_file):
+                if self.has_supported_extension(cert_file):
+                    file_list.append(cert_file) 
+                else:
+                    logger.warning("File %s ignored - supported extensions are %r"%\
+                                       (cert_file,TrustedRoots.supported_extensions))
+        return file_list
+
+    def has_supported_extension (self,path):
+        (_,ext)=os.path.splitext(path)
+        ext=ext.replace('.','').lower()
+        return ext in TrustedRoots.supported_extensions
index 39cde57..19cd4d0 100644 (file)
@@ -22,7 +22,7 @@ try: import pgdb
 except: print >> sys.stderr, "WARNING, could not import pgdb"
 
 from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 if not psycopg2:
     is8bit = re.compile("[\x80-\xff]").search
@@ -183,21 +183,21 @@ class PostgreSQL:
 
             if not params:
                 if self.debug:
-                    sfa_logger().debug('execute0 %r'%query)
+                    logger.debug('execute0 %r'%query)
                 cursor.execute(query)
             elif isinstance(params,dict):
                 if self.debug:
-                    sfa_logger().debug('execute-dict: params=[%r] query=[%r]'%(params,query%params))
+                    logger.debug('execute-dict: params=[%r] query=[%r]'%(params,query%params))
                 cursor.execute(query,params)
             elif isinstance(params,tuple) and len(params)==1:
                 if self.debug:
-                    sfa_logger().debug('execute-tuple %r'%(query%params[0]))
+                    logger.debug('execute-tuple %r'%(query%params[0]))
                 cursor.execute(query,params[0])
             else:
                 param_seq=(params,)
                 if self.debug:
                     for params in param_seq:
-                        sfa_logger().debug('executemany %r'%(query%params))
+                        logger.debug('executemany %r'%(query%params))
                 cursor.executemany(query, param_seq)
             (self.rowcount, self.description, self.lastrowid) = \
                             (cursor.rowcount, cursor.description, cursor.lastrowid)
@@ -207,11 +207,11 @@ class PostgreSQL:
             except:
                 pass
             uuid = commands.getoutput("uuidgen")
-            sfa_logger().error("Database error %s:" % uuid)
-            sfa_logger().error("Exception=%r"%e)
-            sfa_logger().error("Query=%r"%query)
-            sfa_logger().error("Params=%r"%pformat(params))
-            sfa_logger().log_exc("PostgreSQL.execute caught exception")
+            logger.error("Database error %s:" % uuid)
+            logger.error("Exception=%r"%e)
+            logger.error("Query=%r"%query)
+            logger.error("Params=%r"%pformat(params))
+            logger.log_exc("PostgreSQL.execute caught exception")
             raise SfaDBError("Please contact support: %s" % str(e))
 
         return cursor
index f9aed1c..39d6538 100644 (file)
@@ -8,13 +8,18 @@ import traceback
 import string
 import xmlrpclib
 
-from sfa.util.sfalogging import sfa_logger
-from sfa.trust.auth import Auth
-from sfa.util.config import *
 from sfa.util.faults import *
+from sfa.util.config import *
+import sfa.util.xmlrpcprotocol as xmlrpcprotocol
+from sfa.util.sfalogging import logger
+from sfa.trust.auth import Auth
+from sfa.util.cache import Cache
 from sfa.trust.credential import *
 from sfa.trust.certificate import *
 
+# this is wrong all right, but temporary 
+from sfa.managers.import_manager import import_manager
+
 # See "2.2 Characters" in the XML specification:
 #
 # #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD]
@@ -94,7 +99,7 @@ class ManagerWrapper:
     """
     This class acts as a wrapper around an SFA interface manager module, but
     can be used with any python module. The purpose of this class is raise a 
-    SfaNotImplemented exception if the a someone attepmts to use an attribute 
+    SfaNotImplemented exception if someone attempts to use an attribute 
     (could be a callable) thats not available in the library by checking the
     library using hasattr. This helps to communicate better errors messages 
     to the users and developers in the event that a specifiec operation 
@@ -106,19 +111,17 @@ class ManagerWrapper:
         self.interface = interface
         
     def __getattr__(self, method):
-        
         if not hasattr(self.manager, method):
             raise SfaNotImplemented(method, self.interface)
         return getattr(self.manager, method)
         
 class BaseAPI:
 
-    cache = None
     protocol = None
   
     def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", 
                  methods='sfa.methods', peer_cert = None, interface = None, 
-                 key_file = None, cert_file = None, cache = cache):
+                 key_file = None, cert_file = None, cache = None):
 
         self.encoding = encoding
         
@@ -129,7 +132,6 @@ class BaseAPI:
         # Better just be documenting the API
         if config is None:
             return
-        
         # Load configuration
         self.config = Config(config)
         self.auth = Auth(peer_cert)
@@ -140,18 +142,20 @@ class BaseAPI:
         self.cert_file = cert_file
         self.cert = Certificate(filename=self.cert_file)
         self.cache = cache
+        if self.cache is None:
+            self.cache = Cache()
         self.credential = None
         self.source = None 
         self.time_format = "%Y-%m-%d %H:%M:%S"
-        self.logger=sfa_logger
-        
+        self.logger = logger
         # load registries
         from sfa.server.registry import Registries
-        self.registries = Registries(self
+        self.registries = Registries() 
 
         # load aggregates
         from sfa.server.aggregate import Aggregates
-        self.aggregates = Aggregates(self)
+        self.aggregates = Aggregates()
 
 
     def get_interface_manager(self, manager_base = 'sfa.managers'):
@@ -159,23 +163,19 @@ class BaseAPI:
         Returns the appropriate manager module for this interface.
         Modules are usually found in sfa/managers/
         """
-        
+        manager=None
         if self.interface in ['registry']:
-            mgr_type = self.config.SFA_REGISTRY_TYPE
-            manager_module = manager_base + ".registry_manager_%s" % mgr_type
+            manager=import_manager ("registry",  self.config.SFA_REGISTRY_TYPE)
         elif self.interface in ['aggregate']:
-            mgr_type = self.config.SFA_AGGREGATE_TYPE
-            manager_module = manager_base + ".aggregate_manager_%s" % mgr_type 
+            manager=import_manager ("aggregate", self.config.SFA_AGGREGATE_TYPE)
         elif self.interface in ['slicemgr', 'sm']:
-            mgr_type = self.config.SFA_SM_TYPE
-            manager_module = manager_base + ".slice_manager_%s" % mgr_type
+            manager=import_manager ("slice",     self.config.SFA_SM_TYPE)
         elif self.interface in ['component', 'cm']:
-            mgr_type = self.config.SFA_CM_TYPE
-            manager_module = manager_base + ".component_manager_%s" % mgr_type
-        else:
+            manager=import_manager ("component", self.config.SFA_CM_TYPE)
+        if not manager:
             raise SfaAPIError("No manager for interface: %s" % self.interface)  
-        manager = __import__(manager_module, fromlist=[manager_base])
-        # this isnt necessary but will hlep to produce better error messages
+            
+        # this isnt necessary but will help to produce better error messages
         # if someone tries to access an operation this manager doesn't implement  
         manager = ManagerWrapper(manager, self.interface)
 
@@ -238,7 +238,7 @@ class BaseAPI:
         except SfaFault, fault:
             result = fault 
         except Exception, fault:
-            sfa_logger().log_exc("BaseAPI.handle has caught Exception")
+            logger.log_exc("BaseAPI.handle has caught Exception")
             result = SfaAPIError(fault)
 
 
@@ -266,3 +266,14 @@ class BaseAPI:
                 raise result 
             
         return response
+
+    def get_cached_server_version(self, server):
+        cache_key = server.url + "-version"
+        server_version = None
+        if self.cache:
+            server_version = self.cache.get(cache_key)
+        if not server_version:
+            server_version = server.GetVersion()
+            # cache version for 24 hours
+            self.cache.add(cache_key, server_version, ttl= 60*60*24)
+        return server_version
index 904c42b..ead60bb 100644 (file)
@@ -3,7 +3,7 @@
 import threading
 import time
 
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 """
 Callids: a simple mechanism to remember the call ids served so fas
@@ -33,15 +33,15 @@ class _call_ids_impl (dict):
         if not call_id: return False
         has_lock=False
         for attempt in range(_call_ids_impl.retries):
-            if debug: sfa_logger().debug("Waiting for lock (%d)"%attempt)
+            if debug: logger.debug("Waiting for lock (%d)"%attempt)
             if self._lock.acquire(False): 
                 has_lock=True
-                if debug: sfa_logger().debug("got lock (%d)"%attempt)
+                if debug: logger.debug("got lock (%d)"%attempt)
                 break
             time.sleep(float(_call_ids_impl.wait_ms)/1000)
         # in the unlikely event where we can't get the lock
         if not has_lock:
-            sfa_logger().warning("_call_ids_impl.should_handle_call_id: could not acquire lock")
+            logger.warning("_call_ids_impl.should_handle_call_id: could not acquire lock")
             return False
         # we're good to go
         if self.has_key(call_id):
@@ -51,7 +51,7 @@ class _call_ids_impl (dict):
         self[call_id]=time.time()
         self._purge()
         self._lock.release()
-        if debug: sfa_logger().debug("released lock")
+        if debug: logger.debug("released lock")
         return False
         
     def _purge(self):
@@ -60,11 +60,11 @@ class _call_ids_impl (dict):
         for (k,v) in self.iteritems():
             if (now-v) >= _call_ids_impl.purge_timeout: o_keys.append(k)
         for k in o_keys: 
-            if debug: sfa_logger().debug("Purging call_id %r (%s)"%(k,time.strftime("%H:%M:%S",time.localtime(self[k]))))
+            if debug: logger.debug("Purging call_id %r (%s)"%(k,time.strftime("%H:%M:%S",time.localtime(self[k]))))
             del self[k]
         if debug:
-            sfa_logger().debug("AFTER PURGE")
-            for (k,v) in self.iteritems(): sfa_logger().debug("%s -> %s"%(k,time.strftime("%H:%M:%S",time.localtime(v))))
+            logger.debug("AFTER PURGE")
+            for (k,v) in self.iteritems(): logger.debug("%s -> %s"%(k,time.strftime("%H:%M:%S",time.localtime(v))))
         
 def Callids ():
     if not _call_ids_impl._instance:
index c12104b..98373ec 100644 (file)
@@ -6,9 +6,6 @@
 # TODO: investigate ways to combine this with existing PLC server?
 ##
 
-### $Id$
-### $URL$
-
 import sys
 import traceback
 import threading
@@ -19,7 +16,7 @@ import SimpleHTTPServer
 import SimpleXMLRPCServer
 from OpenSSL import SSL
 
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import *
 from sfa.util.faults import *
@@ -74,7 +71,7 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
             # internal error, report as HTTP server error
             self.send_response(500)
             self.end_headers()
-            sfa_logger().log_exc("componentserver.SecureXMLRpcRequestHandler.do_POST")
+            logger.log_exc("componentserver.SecureXMLRpcRequestHandler.do_POST")
         else:
             # got a valid XML RPC response
             self.send_response(200)
index 0f513ec..cf2cca9 100644 (file)
@@ -14,9 +14,6 @@
 # Note that SFA does not access any of the PLC databases directly via
 # a mysql connection; All PLC databases are accessed via PLCAPI.
 
-### $Id$
-### $URL$
-
 import os.path
 import traceback
 
@@ -69,6 +66,9 @@ class Config:
             if not hasattr(self, 'SFA_CM_TYPE'):
                 self.SFA_COMPONENT_TYPE='pl'
 
+            if not hasattr(self, 'SFA_MAX_SLICE_RENEW'):
+                self.SFA_MAX_SLICE_RENEW=60
+
             # create the data directory if it doesnt exist
             if not os.path.isdir(self.SFA_DATA_DIR):
                 try:
diff --git a/sfa/util/enumeration.py b/sfa/util/enumeration.py
new file mode 100644 (file)
index 0000000..4e508bd
--- /dev/null
@@ -0,0 +1,13 @@
+
+class Enum(set):
+    def __init__(self, *args, **kwds):
+        set.__init__(self)
+        enums = dict(zip(args, [object() for i in range(len(args))]), **kwds)
+        for (key, value) in enums.items():
+            setattr(self, key, value)
+            self.add(eval('self.%s' % key))
+
+
+#def Enum2(*args, **kwds):
+#    enums = dict(zip(sequential, range(len(sequential))), **named)
+#    return type('Enum', (), enums)
index 07335b7..91e5300 100644 (file)
@@ -270,6 +270,30 @@ class InvalidRSpec(SfaFault):
     def __str__(self):
         return repr(self.value)
 
+class InvalidRSpecElement(SfaFault):
+    def __init__(self, value, extra = None):
+        self.value = value
+        faultString = "Invalid RSpec Element: %(value)s" % locals()
+        SfaFault.__init__(self, 108, faultString, extra)
+    def __str__(self):
+        return repr(self.value)
+
+class InvalidXML(SfaFault):
+    def __init__(self, value, extra = None):
+        self.value = value
+        faultString = "Invalid XML Document: %(value)s" % locals()
+        SfaFault.__init__(self, 108, faultString, extra)
+    def __str__(self):
+        return repr(self.value)
+
+class InvalidXMLElement(SfaFault):
+    def __init__(self, value, extra = None):
+        self.value = value
+        faultString = "Invalid XML Element: %(value)s" % locals()
+        SfaFault.__init__(self, 108, faultString, extra)
+    def __str__(self):
+        return repr(self.value)
+
 class AccountNotEnabled(SfaFault):
     def __init__(self,  extra = None):
         faultString = "Account Disabled"
diff --git a/sfa/util/httpsProtocol.py b/sfa/util/httpsProtocol.py
new file mode 100644 (file)
index 0000000..e6c6be1
--- /dev/null
@@ -0,0 +1,51 @@
+import httplib
+import socket
+import sys
+
+
+def is_python26():
+    return False
+    #return sys.version_info[0] == 2 and sys.version_info[1] == 6
+
+# wrapper around standartd https modules. Properly supports timeouts.  
+
+class HTTPSConnection(httplib.HTTPSConnection):
+    def __init__(self, host, port=None, key_file=None, cert_file=None,
+                 strict=None, timeout = None):
+        httplib.HTTPSConnection.__init__(self, host, port, key_file, cert_file, strict)
+        if timeout:
+            timeout = float(timeout)
+        self.timeout = timeout
+
+    def connect(self):
+        """Connect to a host on a given (SSL) port."""
+        if is_python26():
+            from sfa.util.ssl_socket import SSLSocket
+            sock = socket.create_connection((self.host, self.port), self.timeout)
+            if self._tunnel_host:
+                self.sock = sock
+                self._tunnel()
+            self.sock = SSLSocket(sock, self.key_file, self.cert_file)
+        else:
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.settimeout(self.timeout)
+            sock.connect((self.host, self.port))
+            ssl = socket.ssl(sock, self.key_file, self.cert_file)
+            self.sock = httplib.FakeSocket(sock, ssl)
+
+class HTTPS(httplib.HTTPS):
+    def __init__(self, host='', port=None, key_file=None, cert_file=None,
+                     strict=None, timeout = None):
+        # urf. compensate for bad input.
+        if port == 0:
+            port = None
+        self._setup(HTTPSConnection(host, port, key_file, cert_file, strict, timeout))
+
+        # we never actually use these for anything, but we keep them
+        # here for compatibility with post-1.5.2 CVS.
+        self.key_file = key_file
+        self.cert_file = cert_file
+    
+    def set_timeout(self, timeout):
+        if is_python26():
+            self._conn.timeout = timeout
index 43b589c..4c37c67 100644 (file)
@@ -11,7 +11,7 @@ import textwrap
 import xmlrpclib
 
 
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 from sfa.util.faults import * 
 from sfa.util.parameter import Parameter, Mixed, python_type, xmlrpc_type
 from sfa.trust.auth import Auth
@@ -80,12 +80,12 @@ class Method:
                 self.type_check(name, value, expected, args)
 
             if self.api.config.SFA_API_DEBUG:
-                sfa_logger().debug("method.__call__ [%s] : BEG %s"%(self.api.interface,methodname))
+                logger.debug("method.__call__ [%s] : BEG %s"%(self.api.interface,methodname))
             result = self.call(*args, **kwds)
 
             runtime = time.time() - start
             if self.api.config.SFA_API_DEBUG or hasattr(self, 'message'):
-                sfa_logger().debug("method.__call__ [%s] : END %s in %02f s (%s)"%\
+                logger.debug("method.__call__ [%s] : END %s in %02f s (%s)"%\
                                        (self.api.interface,methodname,runtime,getattr(self,'message',"[no-msg]")))
 
             return result
@@ -97,7 +97,7 @@ class Method:
             # Prepend caller and method name to expected faults
             fault.faultString = caller + ": " +  self.name + ": " + fault.faultString
             runtime = time.time() - start
-            sfa_logger().log_exc("Method %s raised an exception"%self.name) 
+            logger.log_exc("Method %s raised an exception"%self.name) 
             raise fault
 
 
index 00d9319..7e38419 100644 (file)
@@ -4,11 +4,6 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id$
-#
-
-### $Id$
-### $URL$
 
 from types import *
 from sfa.util.faults import *
index 5580c44..f292823 100644 (file)
@@ -17,7 +17,8 @@ def hrn_to_pl_login_base (hrn):
     return PlXrn(xrn=hrn,type='slice').pl_login_base()
 def hrn_to_pl_authname (hrn):
     return PlXrn(xrn=hrn,type='any').pl_authname()
-
+def xrn_to_hostname(hrn):
+    return Xrn.unescape(PlXrn(xrn=hrn, type='node').get_leaf())
 
 class PlXrn (Xrn):
 
@@ -25,7 +26,7 @@ class PlXrn (Xrn):
     def site_hrn (auth, login_base):
         return '.'.join([auth,login_base])
 
-    def __init__ (self, auth=None, hostname=None, slicename=None, email=None, **kwargs):
+    def __init__ (self, auth=None, hostname=None, slicename=None, email=None, interface=None, **kwargs):
         #def hostname_to_hrn(auth_hrn, login_base, hostname):
         if hostname is not None:
             self.type='node'
@@ -47,6 +48,10 @@ class PlXrn (Xrn):
             # keep only the part before '@' and replace special chars into _
             self.hrn='.'.join([auth,email.split('@')[0].replace(".", "_").replace("+", "_")])
             self.hrn_to_urn()
+        elif interface is not None:
+            self.type = 'interface'
+            self.hrn = auth + '.' + interface
+            self.hrn_to_urn()
         else:
             Xrn.__init__ (self,**kwargs)
 
@@ -62,6 +67,10 @@ class PlXrn (Xrn):
         self._normalize()
         return self.authority[-1]
 
+    def interface_name(self):
+        self._normalize()
+        return self.leaf
+
     #def hrn_to_pl_login_base(hrn):
     def pl_login_base (self):
         self._normalize()
index aa68f43..196c71e 100644 (file)
@@ -1,6 +1,3 @@
-### $Id$
-### $URL$
-
 import os
 
 from sfa.util.storage import *
index a539f62..7cfc752 100644 (file)
@@ -4,17 +4,14 @@
 # TODO: Use existing PLC database methods? or keep this separate?
 ##
 
-### $Id$
-### $URL$
-
 from types import StringTypes
 
 from sfa.trust.gid import *
 
-from sfa.util.rspec import *
 from sfa.util.parameter import *
 from sfa.util.xrn import get_authority
 from sfa.util.row import Row
+from sfa.util.xml import XML 
 
 class SfaRecord(Row):
     """ 
@@ -207,6 +204,22 @@ class SfaRecord(Row):
         """
         return GID(string=self.gid)
 
+    ##
+    # Returns the value of a field
+
+    def get_field(self, fieldname, default=None):
+        # sometimes records act like classes, and sometimes they act like dicts
+        try:
+            return getattr(self, fieldname)
+        except AttributeError:
+            try:
+                 return self[fieldname]
+            except KeyError:
+                 if default != None:
+                     return default
+                 else:
+                     raise
+
     ##
     # Returns a list of field names in this record. 
 
@@ -288,10 +301,9 @@ class SfaRecord(Row):
         """
         recorddict = self.as_dict()
         filteredDict = dict([(key, val) for (key, val) in recorddict.iteritems() if key in self.fields.keys()])
-        record = RecordSpec()
-        record.parseDict(filteredDict)
+        record = XML('<record/>')
+        record.root.attrib.update(filteredDict)
         str = record.toxml()
-        #str = xmlrpclib.dumps((dict,), allow_none=True)
         return str
 
     ##
@@ -305,11 +317,8 @@ class SfaRecord(Row):
         """
         #dict = xmlrpclib.loads(str)[0][0]
         
-        record = RecordSpec()
-        record.parseString(str)
-        record_dict = record.toDict()
-        sfa_dict = record_dict['record']
-        self.load_from_dict(sfa_dict)
+        record = XML(str)
+        self.load_from_dict(record.todict())
 
     ##
     # Dump the record to stdout
diff --git a/sfa/util/rspec.py b/sfa/util/rspec.py
deleted file mode 100644 (file)
index ffc816c..0000000
+++ /dev/null
@@ -1,417 +0,0 @@
-import sys
-import pprint
-import os
-from StringIO import StringIO
-from types import StringTypes, ListType
-import httplib
-from xml.dom import minidom
-from lxml import etree
-
-from sfa.util.sfalogging import sfa_logger
-
-class RSpec:
-
-    def __init__(self, xml = None, xsd = None, NSURL = None):
-        '''
-        Class to manipulate RSpecs.  Reads and parses rspec xml into python dicts
-        and reads python dicts and writes rspec xml
-
-        self.xsd = # Schema.  Can be local or remote file.
-        self.NSURL = # If schema is remote, Name Space URL to query (full path minus filename)
-        self.rootNode = # root of the DOM
-        self.dict = # dict of the RSpec.
-        self.schemaDict = {} # dict of the Schema
-        '''
-        self.xsd = xsd
-        self.rootNode = None
-        self.dict = {}
-        self.schemaDict = {}
-        self.NSURL = NSURL 
-        if xml:
-            if type(xml) == file:
-                self.parseFile(xml)
-            if type(xml) in StringTypes:
-                self.parseString(xml)
-            self.dict = self.toDict() 
-        if xsd:
-            self._parseXSD(self.NSURL + self.xsd)
-
-
-    def _getText(self, nodelist):
-        rc = ""
-        for node in nodelist:
-            if node.nodeType == node.TEXT_NODE:
-                rc = rc + node.data
-        return rc
-  
-    # The rspec is comprised of 2 parts, and 1 reference:
-    # attributes/elements describe individual resources
-    # complexTypes are used to describe a set of attributes/elements
-    # complexTypes can include a reference to other complexTypes.
-  
-  
-    def _getName(self, node):
-        '''Gets name of node. If tag has no name, then return tag's localName'''
-        name = None
-        if not node.nodeName.startswith("#"):
-            if node.localName:
-                name = node.localName
-            elif node.attributes.has_key("name"):
-                name = node.attributes.get("name").value
-        return name     
-    # Attribute.  {name : nameofattribute, {items: values})
-    def _attributeDict(self, attributeDom):
-        '''Traverse single attribute node.  Create a dict {attributename : {name: value,}]}'''
-        node = {} # parsed dict
-        for attr in attributeDom.attributes.keys():
-            node[attr] = attributeDom.attributes.get(attr).value
-        return node
-  
-    def appendToDictOrCreate(self, dict, key, value):
-        if (dict.has_key(key)):
-            dict[key].append(value)
-        else:
-            dict[key]=[value]
-        return dict
-
-    def toGenDict(self, nodeDom=None, parentdict=None, siblingdict={}, parent=None):
-        """
-        convert an XML to a nested dict:
-          * Non-terminal nodes (elements with string children and attributes) are simple dictionaries
-          * Terminal nodes (the rest) are nested dictionaries
-        """
-
-        if (not nodeDom):
-            nodeDom=self.rootNode
-
-        curNodeName = nodeDom.localName
-
-        if (nodeDom.hasChildNodes()):
-            childdict={}
-            for attribute in nodeDom.attributes.keys():
-                childdict = self.appendToDictOrCreate(childdict, attribute, nodeDom.getAttribute(attribute))
-            for child in nodeDom.childNodes[:-1]:
-                if (child.nodeValue):
-                    siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, child.nodeValue)
-                else:
-                    childdict = self.toGenDict(child, None, childdict, curNodeName)
-
-            child = nodeDom.childNodes[-1]
-            if (child.nodeValue):
-                siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, child.nodeValue)
-                if (childdict):
-                    siblingdict = self.appendToDictOrCreate(siblingdict, curNodeName, childdict)
-            else:
-                siblingdict = self.toGenDict(child, siblingdict, childdict, curNodeName)
-        else:
-            childdict={}
-            for attribute in nodeDom.attributes.keys():
-                childdict = self.appendToDictOrCreate(childdict, attribute, nodeDom.getAttribute(attribute))
-
-            self.appendToDictOrCreate(siblingdict, curNodeName, childdict)
-            
-        if (parentdict is not None):
-            parentdict = self.appendToDictOrCreate(parentdict, parent, siblingdict)
-            return parentdict
-        else:
-            return siblingdict
-
-
-
-    def toDict(self, nodeDom = None):
-        """
-        convert this rspec to a dict and return it.
-        """
-        node = {}
-        if not nodeDom:
-             nodeDom = self.rootNode
-  
-        elementName = nodeDom.nodeName
-        if elementName and not elementName.startswith("#"):
-            # attributes have tags and values.  get {tag: value}, else {type: value}
-            node[elementName] = self._attributeDict(nodeDom)
-            # resolve the child nodes.
-            if nodeDom.hasChildNodes():
-                for child in nodeDom.childNodes:
-                    childName = self._getName(child)
-                    
-                    # skip null children
-                    if not childName: continue
-
-                    # initialize the possible array of children
-                    if not node[elementName].has_key(childName): node[elementName][childName] = []
-
-                    if isinstance(child, minidom.Text):
-                        # add if data is not empty
-                        if child.data.strip():
-                            node[elementName][childName].append(nextchild.data)
-                    elif child.hasChildNodes() and isinstance(child.childNodes[0], minidom.Text):
-                        for nextchild in child.childNodes:  
-                            node[elementName][childName].append(nextchild.data)
-                    else:
-                        childdict = self.toDict(child)
-                        for value in childdict.values():
-                            node[elementName][childName].append(value)
-
-        return node
-
-  
-    def toxml(self):
-        """
-        convert this rspec to an xml string and return it.
-        """
-        return self.rootNode.toxml()
-
-  
-    def toprettyxml(self):
-        """
-        print this rspec in xml in a pretty format.
-        """
-        return self.rootNode.toprettyxml()
-
-  
-    def __removeWhitespaceNodes(self, parent):
-        for child in list(parent.childNodes):
-            if child.nodeType == minidom.Node.TEXT_NODE and child.data.strip() == '':
-                parent.removeChild(child)
-            else:
-                self.__removeWhitespaceNodes(child)
-
-    def parseFile(self, filename):
-        """
-        read a local xml file and store it as a dom object.
-        """
-        dom = minidom.parse(filename)
-        self.__removeWhitespaceNodes(dom)
-        self.rootNode = dom.childNodes[0]
-
-
-    def parseString(self, xml):
-        """
-        read an xml string and store it as a dom object.
-        """
-        dom = minidom.parseString(xml)
-        self.__removeWhitespaceNodes(dom)
-        self.rootNode = dom.childNodes[0]
-
-    def _httpGetXSD(self, xsdURI):
-        # split the URI into relevant parts
-        host = xsdURI.split("/")[2]
-        if xsdURI.startswith("https"):
-            conn = httplib.HTTPSConnection(host,
-                httplib.HTTPSConnection.default_port)
-        elif xsdURI.startswith("http"):
-            conn = httplib.HTTPConnection(host,
-                httplib.HTTPConnection.default_port)
-        conn.request("GET", xsdURI)
-        # If we can't download the schema, raise an exception
-        r1 = conn.getresponse()
-        if r1.status != 200: 
-            raise Exception
-        return r1.read().replace('\n', '').replace('\t', '').strip() 
-
-
-    def _parseXSD(self, xsdURI):
-        """
-        Download XSD from URL, or if file, read local xsd file and set
-        schemaDict.
-        
-        Since the schema definiton is a global namespace shared by and
-        agreed upon by others, this should probably be a URL.  Check
-        for URL, download xsd, parse, or if local file, use that.
-        """
-        schemaDom = None
-        if xsdURI.startswith("http"):
-            try: 
-                schemaDom = minidom.parseString(self._httpGetXSD(xsdURI))
-            except Exception, e:
-                # logging.debug("%s: web file not found" % xsdURI)
-                # logging.debug("Using local file %s" % self.xsd")
-                sfa_logger().log_exc("rspec.parseXSD: can't find %s on the web. Continuing." % xsdURI)
-        if not schemaDom:
-            if os.path.exists(xsdURI):
-                # logging.debug("using local copy.")
-                sfa_logger().debug("rspec.parseXSD: Using local %s" % xsdURI)
-                schemaDom = minidom.parse(xsdURI)
-            else:
-                raise Exception("rspec.parseXSD: can't find xsd locally")
-        self.schemaDict = self.toDict(schemaDom.childNodes[0])
-
-
-    def dict2dom(self, rdict, include_doc = False):
-        """
-        convert a dict object into a dom object.
-        """
-     
-        def elementNode(tagname, rd):
-            element = minidom.Element(tagname)
-            for key in rd.keys():
-                if isinstance(rd[key], StringTypes) or isinstance(rd[key], int):
-                    element.setAttribute(key, str(rd[key]))
-                elif isinstance(rd[key], dict):
-                    child = elementNode(key, rd[key])
-                    element.appendChild(child)
-                elif isinstance(rd[key], list):
-                    for item in rd[key]:
-                        if isinstance(item, dict):
-                            child = elementNode(key, item)
-                            element.appendChild(child)
-                        elif isinstance(item, StringTypes) or isinstance(item, int):
-                            child = minidom.Element(key)
-                            text = minidom.Text()
-                            text.data = item
-                            child.appendChild(text)
-                            element.appendChild(child) 
-            return element
-        
-        # Minidom does not allow documents to have more then one
-        # child, but elements may have many children. Because of
-        # this, the document's root node will be the first key/value
-        # pair in the dictionary.  
-        node = elementNode(rdict.keys()[0], rdict.values()[0])
-        if include_doc:
-            rootNode = minidom.Document()
-            rootNode.appendChild(node)
-        else:
-            rootNode = node
-        return rootNode
-
-    def parseDict(self, rdict, include_doc = True):
-        """
-        Convert a dictionary into a dom object and store it.
-        """
-        self.rootNode = self.dict2dom(rdict, include_doc).childNodes[0]
-    def getDictsByTagName(self, tagname, dom = None):
-        """
-        Search the dom for all elements with the specified tagname
-        and return them as a list of dicts
-        """
-        if not dom:
-            dom = self.rootNode
-        dicts = []
-        doms = dom.getElementsByTagName(tagname)
-        dictlist = [self.toDict(d) for d in doms]
-        for item in dictlist:
-            for value in item.values():
-                dicts.append(value)
-        return dicts
-
-    def getDictByTagNameValue(self, tagname, value, dom = None):
-        """
-        Search the dom for the first element with the specified tagname
-        and value and return it as a dict.
-        """
-        tempdict = {}
-        if not dom:
-            dom = self.rootNode
-        dicts = self.getDictsByTagName(tagname, dom)
-        
-        for rdict in dicts:
-            if rdict.has_key('name') and rdict['name'] in [value]:
-                return rdict
-              
-        return tempdict
-
-
-    def filter(self, tagname, attribute, blacklist = [], whitelist = [], dom = None):
-        """
-        Removes all elements where:
-        1. tagname matches the element tag
-        2. attribute matches the element attribte
-        3. attribute value is in valuelist  
-        """
-
-        tempdict = {}
-        if not dom:
-            dom = self.rootNode
-       
-        if dom.localName in [tagname] and dom.attributes.has_key(attribute):
-            if whitelist and dom.attributes.get(attribute).value not in whitelist:
-                dom.parentNode.removeChild(dom)
-            if blacklist and dom.attributes.get(attribute).value in blacklist:
-                dom.parentNode.removeChild(dom)
-           
-        if dom.hasChildNodes():
-            for child in dom.childNodes:
-                self.filter(tagname, attribute, blacklist, whitelist, child) 
-
-
-    def merge(self, rspecs, tagname, dom=None):
-        """
-        Merge this rspec with the requested rspec based on the specified 
-        starting tag name. The start tag (and all of its children) will be merged  
-        """
-        tempdict = {}
-        if not dom:
-            dom = self.rootNode
-
-        whitelist = []
-        blacklist = []
-            
-        if dom.localName in [tagname] and dom.attributes.has_key(attribute):
-            if whitelist and dom.attributes.get(attribute).value not in whitelist:
-                dom.parentNode.removeChild(dom)
-            if blacklist and dom.attributes.get(attribute).value in blacklist:
-                dom.parentNode.removeChild(dom)
-
-        if dom.hasChildNodes():
-            for child in dom.childNodes:
-                self.filter(tagname, attribute, blacklist, whitelist, child) 
-
-    def validateDicts(self):
-        types = {
-            'EInt' : int,
-            'EString' : str,
-            'EByteArray' : list,
-            'EBoolean' : bool,
-            'EFloat' : float,
-            'EDate' : date}
-
-
-    def pprint(self, r = None, depth = 0):
-        """
-        Pretty print the dict
-        """
-        line = ""
-        if r == None: r = self.dict
-        # Set the dept
-        for tab in range(0,depth): line += "    "
-        # check if it's nested
-        if type(r) == dict:
-            for i in r.keys():
-                print line + "%s:" % i
-                self.pprint(r[i], depth + 1)
-        elif type(r) in (tuple, list):
-            for j in r: self.pprint(j, depth + 1)
-        # not nested so just print.
-        else:
-            print line + "%s" %  r
-    
-
-
-class RecordSpec(RSpec):
-
-    root_tag = 'record'
-    def parseDict(self, rdict, include_doc = False):
-        """
-        Convert a dictionary into a dom object and store it.
-        """
-        self.rootNode = self.dict2dom(rdict, include_doc)
-
-    def dict2dom(self, rdict, include_doc = False):
-        record_dict = rdict
-        if not len(rdict.keys()) == 1:
-            record_dict = {self.root_tag : rdict}
-        return RSpec.dict2dom(self, record_dict, include_doc)
-
-        
-# vim:ts=4:expandtab
-    
index 1ccc984..89f15af 100755 (executable)
@@ -8,7 +8,7 @@ from StringIO import StringIO
 from optparse import OptionParser
 
 from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 def merge_rspecs(rspecs):
     """
@@ -24,13 +24,13 @@ def merge_rspecs(rspecs):
         try:
             known_networks[network.get('name')]=True
         except:
-            sfa_logger().error("merge_rspecs: cannot register network with no name in rspec")
+            logger.error("merge_rspecs: cannot register network with no name in rspec")
             pass
     def is_registered_network (network):
         try:
             return network.get('name') in known_networks
         except:
-            sfa_logger().error("merge_rspecs: cannot retrieve network with no name in rspec")
+            logger.error("merge_rspecs: cannot retrieve network with no name in rspec")
             return False
 
     # the resulting tree
@@ -42,13 +42,13 @@ def merge_rspecs(rspecs):
             tree = etree.parse(StringIO(input_rspec))
         except etree.XMLSyntaxError:
             # consider failing silently here
-            sfa_logger().log_exc("merge_rspecs, parse error")
+            logger.log_exc("merge_rspecs, parse error")
             message = str(sys.exc_info()[1]) + ' with ' + input_rspec
             raise InvalidRSpec(message)
 
         root = tree.getroot()
         if not root.get("type") in ["SFA"]:
-            sfa_logger().error("merge_rspecs: unexpected type for rspec root, %s"%root.get('type'))
+            logger.error("merge_rspecs: unexpected type for rspec root, %s"%root.get('type'))
             continue
         if rspec == None:
             # we scan the first input, register all networks
index b4fd2ff..c3ae718 100644 (file)
@@ -18,13 +18,13 @@ import SimpleXMLRPCServer
 from OpenSSL import SSL
 
 from sfa.trust.certificate import Keypair, Certificate
-from sfa.trust.trustedroot import TrustedRootList
+from sfa.trust.trustedroots import TrustedRoots
 from sfa.util.config import Config
 from sfa.trust.credential import *
 from sfa.util.faults import *
 from sfa.plc.api import SfaAPI
 from sfa.util.cache import Cache 
-from sfa.util.sfalogging import sfa_logger
+from sfa.util.sfalogging import logger
 
 ##
 # Verification callback for pyOpenSSL. We do our own checking of keys because
@@ -110,7 +110,7 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
         except Exception, fault:
             # This should only happen if the module is buggy
             # internal error, report as HTTP server error
-            sfa_logger().log_exc("server.do_POST")
+            logger.log_exc("server.do_POST")
             response = self.api.prepare_response(fault)
             #self.send_response(500)
             #self.end_headers()
@@ -134,7 +134,7 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR
 
         It it very similar to SimpleXMLRPCServer but it uses HTTPS for transporting XML data.
         """
-        sfa_logger().debug("SecureXMLRPCServer.__init__, server_address=%s, cert_file=%s"%(server_address,cert_file))
+        logger.debug("SecureXMLRPCServer.__init__, server_address=%s, cert_file=%s"%(server_address,cert_file))
         self.logRequests = logRequests
         self.interface = None
         self.key_file = key_file
@@ -154,7 +154,7 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR
         # If you wanted to verify certs against known CAs.. this is how you would do it
         #ctx.load_verify_locations('/etc/sfa/trusted_roots/plc.gpo.gid')
         config = Config()
-        trusted_cert_files = TrustedRootList(config.get_trustedroots_dir()).get_file_list()
+        trusted_cert_files = TrustedRoots(config.get_trustedroots_dir()).get_file_list()
         for cert_file in trusted_cert_files:
             ctx.load_verify_locations(cert_file)
         ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback)
@@ -171,7 +171,7 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR
     # the client.
 
     def _dispatch(self, method, params):
-        sfa_logger().debug("SecureXMLRPCServer._dispatch, method=%s"%method)
+        logger.debug("SecureXMLRPCServer._dispatch, method=%s"%method)
         try:
             return SimpleXMLRPCServer.SimpleXMLRPCDispatcher._dispatch(self, method, params)
         except:
@@ -287,7 +287,7 @@ class SfaServer(threading.Thread):
         self.server.interface=interface
         self.trusted_cert_list = None
         self.register_functions()
-        sfa_logger().info("Starting SfaServer, interface=%s"%interface)
+        logger.info("Starting SfaServer, interface=%s"%interface)
 
     ##
     # Register functions that will be served by the XMLRPC server. This
old mode 100755 (executable)
new mode 100644 (file)
index 42e2e67..75229b3
@@ -15,24 +15,42 @@ class _SfaLogger:
     def __init__ (self,logfile=None,loggername=None,level=logging.INFO):
         # default is to locate loggername from the logfile if avail.
         if not logfile:
-            loggername='console'
-            handler=logging.StreamHandler()
-            handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
-        else:
-            if not loggername:
-                loggername=os.path.basename(logfile)
-            try:
-                handler=logging.handlers.RotatingFileHandler(logfile,maxBytes=1000000, backupCount=5) 
-            except IOError:
-                # This is usually a permissions error becaue the file is
-                # owned by root, but httpd is trying to access it.
-                tmplogfile=os.getenv("TMPDIR", "/tmp") + os.path.sep + os.path.basename(logfile)
+            #loggername='console'
+            #handler=logging.StreamHandler()
+            #handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
+            logfile = "/var/log/sfa.log"
+
+        if not loggername:
+            loggername=os.path.basename(logfile)
+        try:
+            handler=logging.handlers.RotatingFileHandler(logfile,maxBytes=1000000, backupCount=5) 
+        except IOError:
+            # This is usually a permissions error becaue the file is
+            # owned by root, but httpd is trying to access it.
+            tmplogfile=os.getenv("TMPDIR", "/tmp") + os.path.sep + os.path.basename(logfile)
+            # In strange uses, 2 users on same machine might use same code,
+            # meaning they would clobber each others files
+            # We could (a) rename the tmplogfile, or (b)
+            # just log to the console in that case.
+            # Here we default to the console.
+            if os.path.exists(tmplogfile) and not os.access(tmplogfile,os.W_OK):
+                loggername = loggername + "-console"
+                handler = logging.StreamHandler()
+            else:
                 handler=logging.handlers.RotatingFileHandler(tmplogfile,maxBytes=1000000, backupCount=5) 
-            handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
-
+        handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
         self.logger=logging.getLogger(loggername)
         self.logger.setLevel(level)
-        self.logger.addHandler(handler)
+        # check if logger already has the handler we're about to add
+        handler_exists = False
+        for l_handler in self.logger.handlers:
+            if l_handler.baseFilename == handler.baseFilename and \
+               l_handler.level == handler.level:
+                handler_exists = True 
+
+        if not handler_exists:
+            self.logger.addHandler(handler)
+
         self.loggername=loggername
 
     def setLevel(self,level):
@@ -54,24 +72,25 @@ class _SfaLogger:
             self.logger.setLevel(logging.DEBUG)
 
     ####################
-    def wrap(fun):
-        def wrapped(self,msg,*args,**kwds):
-            native=getattr(self.logger,fun.__name__)
-            return native(msg,*args,**kwds)
-        #wrapped.__doc__=native.__doc__
-        return wrapped
-
-    @wrap
-    def critical(): pass
-    @wrap
-    def error(): pass
-    @wrap
-    def warning(): pass
-    @wrap
-    def info(): pass
-    @wrap
-    def debug(): pass
-    
+    def info(self, msg):
+        self.logger.info(msg)
+
+    def debug(self, msg):
+        self.logger.debug(msg)
+        
+    def warn(self, msg):
+        self.logger.warn(msg)
+
+    # some code is using logger.warn(), some is using logger.warning()
+    def warning(self, msg):
+        self.logger.warning(msg)
+   
+    def error(self, msg):
+        self.logger.error(msg)    
+    def critical(self, msg):
+        self.logger.critical(msg)
+
     # logs an exception - use in an except statement
     def log_exc(self,message):
         self.error("%s BEG TRACEBACK"%message+"\n"+traceback.format_exc().strip("\n"))
@@ -84,34 +103,23 @@ class _SfaLogger:
     # for investigation purposes, can be placed anywhere
     def log_stack(self,message):
         to_log="".join(traceback.format_stack())
-        self.debug("%s BEG STACK"%message+"\n"+to_log)
-        self.debug("%s END STACK"%message)
-
-####################
-# import-related operations go in this file
-_import_logger=_SfaLogger(logfile='/var/log/sfa_import.log')
-# servers log into /var/log/sfa.log
-_server_logger=_SfaLogger(logfile='/var/log/sfa.log')
-# clients use the console
-_console_logger=_SfaLogger()
-
-# default is to use the server-side logger
-_the_logger=_server_logger
-
-# clients would change the default by issuing one of these call
-def sfa_logger_goes_to_console():
-    current_module=sys.modules[globals()['__name__']]
-    current_module._the_logger=_console_logger
-
-# clients would change the default by issuing one of these call
-def sfa_logger_goes_to_import():
-    current_module=sys.modules[globals()['__name__']]
-    current_module._the_logger=_import_logger
-
-# this is how to retrieve the 'right' logger
-def sfa_logger():
-    return _the_logger
+        self.info("%s BEG STACK"%message+"\n"+to_log)
+        self.info("%s END STACK"%message)
+
+    def enable_console(self, stream=sys.stdout):
+        formatter = logging.Formatter("%(message)s")
+        handler = logging.StreamHandler(stream)
+        handler.setFormatter(formatter)
+        self.logger.addHandler(handler)
 
+
+info_logger = _SfaLogger(loggername='info', level=logging.INFO)
+debug_logger = _SfaLogger(loggername='debug', level=logging.DEBUG)
+warn_logger = _SfaLogger(loggername='warning', level=logging.WARNING)
+error_logger = _SfaLogger(loggername='error', level=logging.ERROR)
+critical_logger = _SfaLogger(loggername='critical', level=logging.CRITICAL)
+logger = info_logger
+sfi_logger = _SfaLogger(logfile=os.path.expanduser("~/.sfi/")+'sfi.log',loggername='sfilog', level=logging.DEBUG)
 ########################################
 import time
 
@@ -139,25 +147,29 @@ def profile(logger):
 
 if __name__ == '__main__': 
     print 'testing sfalogging into logger.log'
-    logger=_SfaLogger('logger.log')
-    logger.critical("logger.critical")
-    logger.error("logger.error")
-    logger.warning("logger.warning")
-    logger.info("logger.info")
-    logger.debug("logger.debug")
-    logger.setLevel(logging.DEBUG)
-    logger.debug("logger.debug again")
+    logger1=_SfaLogger('logger.log', loggername='std(info)')
+    logger2=_SfaLogger('logger.log', loggername='error', level=logging.ERROR)
+    logger3=_SfaLogger('logger.log', loggername='debug', level=logging.DEBUG)
+    
+    for (logger,msg) in [ (logger1,"std(info)"),(logger2,"error"),(logger3,"debug")]:
+        
+        print "====================",msg, logger.logger.handlers
+   
+        logger.enable_console()
+        logger.critical("logger.critical")
+        logger.error("logger.error")
+        logger.warn("logger.warning")
+        logger.info("logger.info")
+        logger.debug("logger.debug")
+        logger.setLevel(logging.DEBUG)
+        logger.debug("logger.debug again")
     
-    sfa_logger_goes_to_console()
-    my_logger=sfa_logger()
-    my_logger.info("redirected to console")
-
-    @profile(my_logger)
-    def sleep(seconds = 1):
-        time.sleep(seconds)
-
-    my_logger.info('console.info')
-    sleep(0.5)
-    my_logger.setLevel(logging.DEBUG)
-    sleep(0.25)
+        @profile(logger)
+        def sleep(seconds = 1):
+            time.sleep(seconds)
+
+        logger.info('console.info')
+        sleep(0.5)
+        logger.setLevel(logging.DEBUG)
+        sleep(0.25)
 
index 901b4e0..11cc566 100644 (file)
@@ -1,10 +1,26 @@
+from types import StringTypes
 import dateutil.parser
+import datetime
 
-def utcparse(str):
+from sfa.util.sfalogging import logger
+
+def utcparse(input):
     """ Translate a string into a time using dateutil.parser.parse but make sure it's in UTC time and strip
-    the timezone, so that it's compatible with normal datetime.datetime objects"""
+the timezone, so that it's compatible with normal datetime.datetime objects.
+
+For safety this can also handle inputs that are either timestamps, or datetimes
+"""
     
-    t = dateutil.parser.parse(str)
-    if not t.utcoffset() is None:
-        t = t.utcoffset() + t.replace(tzinfo=None)
-    return t
+    if isinstance (input, datetime.datetime):
+        logger.warn ("argument to utcparse already a datetime - doing nothing")
+        return input
+    elif isinstance (input, StringTypes):
+        t = dateutil.parser.parse(input)
+        if t.utcoffset() is not None:
+            t = t.utcoffset() + t.replace(tzinfo=None)
+        return t
+    elif isinstance (input, (int,float)):
+        return datetime.datetime.fromtimestamp(input)
+    else:
+        logger.error("Unexpected type in utcparse [%s]"%type(input))
+
index 7a6ff97..2138e87 100644 (file)
@@ -10,9 +10,6 @@
 # SpecDict.plc_fields defines a one to one mapping of plc attribute to rspec 
 # attribute
 
-### $Id$
-### $URL$
-
 from types import StringTypes, ListType
 
 class SpecDict(dict):
diff --git a/sfa/util/ssl_socket.py b/sfa/util/ssl_socket.py
new file mode 100644 (file)
index 0000000..d221da3
--- /dev/null
@@ -0,0 +1,76 @@
+from ssl import SSLSocket
+
+import textwrap
+
+import _ssl             # if we can't import it, let the error propagate
+
+from _ssl import SSLError
+from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
+from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
+from _ssl import RAND_status, RAND_egd, RAND_add
+from _ssl import \
+     SSL_ERROR_ZERO_RETURN, \
+     SSL_ERROR_WANT_READ, \
+     SSL_ERROR_WANT_WRITE, \
+     SSL_ERROR_WANT_X509_LOOKUP, \
+     SSL_ERROR_SYSCALL, \
+     SSL_ERROR_SSL, \
+     SSL_ERROR_WANT_CONNECT, \
+     SSL_ERROR_EOF, \
+     SSL_ERROR_INVALID_ERROR_CODE
+
+from socket import socket, _fileobject
+from socket import getnameinfo as _getnameinfo
+import base64        # for DER-to-PEM translation
+
+class SSLSocket(SSLSocket, socket):
+
+    """This class implements a subtype of socket.socket that wraps
+    the underlying OS socket in an SSL context when necessary, and
+    provides read and write methods over that channel."""
+
+    def __init__(self, sock, keyfile=None, certfile=None,
+                 server_side=False, cert_reqs=CERT_NONE,
+                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
+                 do_handshake_on_connect=True,
+                 suppress_ragged_eofs=True):
+        socket.__init__(self, _sock=sock._sock)
+        # the initializer for socket trashes the methods (tsk, tsk), so...
+        self.send = lambda data, flags=0: SSLSocket.send(self, data, flags)
+        self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
+        self.recv = lambda buflen=1024, flags=0: SSLSocket.recv(self, buflen, flags)
+        self.recvfrom = lambda addr, buflen=1024, flags=0: SSLSocket.recvfrom(self, addr, buflen, flags)
+        self.recv_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recv_into(self, buffer, nbytes, flags)
+        self.recvfrom_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recvfrom_into(self, buffer, nbytes, flags)
+
+        if certfile and not keyfile:
+            keyfile = certfile
+        # see if it's connected
+        try:
+            socket.getpeername(self)
+        except:
+            # no, no connection yet
+            self._sslobj = None
+        else:
+            # yes, create the SSL object
+            self._sslobj = _ssl.sslwrap(self._sock, server_side,
+                                        keyfile, certfile,
+                                        cert_reqs, ssl_version, ca_certs)
+            if do_handshake_on_connect:
+                timeout = self.gettimeout()
+                try:
+                    if timeout == 0:
+                        self.settimeout(None)
+                    self.do_handshake()
+                finally:
+                    self.settimeout(timeout)
+        self.keyfile = keyfile
+        self.certfile = certfile
+        self.cert_reqs = cert_reqs
+        self.ssl_version = ssl_version
+        self.ca_certs = ca_certs
+        self.do_handshake_on_connect = do_handshake_on_connect
+        self.suppress_ragged_eofs = suppress_ragged_eofs
+        self._makefile_refs = 0
+
+
index 5d91539..ee2e41b 100644 (file)
@@ -1,7 +1,5 @@
 import os
-
-from sfa.util.rspec import RecordSpec
-
+from sfa.util.xml import XML
 class SimpleStorage(dict):
     """
     Handles storing and loading python dictionaries. The storage file created
@@ -46,10 +44,9 @@ class XmlStorage(SimpleStorage):
         """
         Parse an xml file and store it as a dict
         """ 
-        data = RecordSpec()
         if os.path.exists(self.db_filename) and os.path.isfile(self.db_filename):
-            data.parseFile(self.db_filename)
-            dict.__init__(self, data.toDict())
+            xml = XML(self.db_filename)
+            dict.__init__(self, xml.todict())
         elif os.path.exists(self.db_filename) and not os.path.isfile(self.db_filename):
             raise IOError, '%s exists but is not a file. please remove it and try again' \
                            % self.db_filename
@@ -58,8 +55,8 @@ class XmlStorage(SimpleStorage):
             self.load()
 
     def write(self):
-        data = RecordSpec()
-        data.parseDict(self)
+        xml = XML()
+        xml.parseDict(self)
         db_file = open(self.db_filename, 'w')
         db_file.write(data.toprettyxml())
         db_file.close()
old mode 100755 (executable)
new mode 100644 (file)
index 331f847..b47b818
@@ -2,6 +2,7 @@ import threading
 import traceback
 import time
 from Queue import Queue
+from sfa.util.sfalogging import logger
 
 def ThreadedMethod(callable, results, errors):
     """
@@ -15,6 +16,7 @@ def ThreadedMethod(callable, results, errors):
                 try:
                     results.put(callable(*args, **kwds))
                 except Exception, e:
+                    logger.log_exc('ThreadManager: Error in thread: ')
                     errors.put(traceback.format_exc())
                     
         thread = ThreadInstance()
@@ -29,9 +31,11 @@ class ThreadManager:
     ThreadManager executes a callable in a thread and stores the result
     in a thread safe queue. 
     """
-    results = Queue()
-    errors = Queue()
-    threads = []
+
+    def __init__(self):
+        self.results = Queue()
+        self.errors = Queue()
+        self.threads = []
 
     def run (self, method, *args, **kwds):
         """
diff --git a/sfa/util/xml.py b/sfa/util/xml.py
new file mode 100755 (executable)
index 0000000..91a1d95
--- /dev/null
@@ -0,0 +1,225 @@
+#!/usr/bin/python 
+from lxml import etree
+from StringIO import StringIO
+from datetime import datetime, timedelta
+from sfa.util.xrn import *
+from sfa.util.plxrn import hostname_to_urn
+from sfa.util.faults import SfaNotImplemented, InvalidXML
+
+class XpathFilter:
+    @staticmethod
+    def xpath(filter={}):
+        xpath = ""
+        if filter:
+            filter_list = []
+            for (key, value) in filter.items():
+                if key == 'text':
+                    key = 'text()'
+                else:
+                    key = '@'+key
+                if isinstance(value, str):
+                    filter_list.append('%s="%s"' % (key, value))
+                elif isinstance(value, list):
+                    filter_list.append('contains("%s", %s)' % (' '.join(map(str, value)), key))
+            if filter_list:
+                xpath = ' and '.join(filter_list)
+                xpath = '[' + xpath + ']'
+        return xpath
+
+class XML:
+    def __init__(self, xml=None):
+        self.root = None
+        self.namespaces = None
+        self.default_namespace = None
+        self.schema = None
+        if isinstance(xml, basestring):
+            self.parse_xml(xml)
+        elif isinstance(xml, etree._ElementTree):
+            self.root = xml.getroot()
+        elif isinstance(xml, etree._Element):
+            self.root = xml 
+
+    def parse_xml(self, xml):
+        """
+        parse rspec into etree
+        """
+        parser = etree.XMLParser(remove_blank_text=True)
+        try:
+            tree = etree.parse(xml, parser)
+        except IOError:
+            # 'rspec' file doesnt exist. 'rspec' is proably an xml string
+            try:
+                tree = etree.parse(StringIO(xml), parser)
+            except Exception, e:
+                raise InvalidXML(str(e))
+        self.root = tree.getroot()
+        # set namespaces map
+        self.namespaces = dict(self.root.nsmap)
+        # If the 'None' exist, then it's pointing to the default namespace. This makes 
+        # it hard for us to write xpath queries for the default naemspace because lxml 
+        # wont understand a None prefix. We will just associate the default namespeace 
+        # with a key named 'default'.     
+        if None in self.namespaces:
+            default_namespace = self.namespaces.pop(None)
+            self.namespaces['default'] = default_namespace
+
+        # set schema 
+        for key in self.root.attrib.keys():
+            if key.endswith('schemaLocation'):
+                # schema location should be at the end of the list
+                schema_parts  = self.root.attrib[key].split(' ')
+                self.schema = schema_parts[1]    
+                namespace, schema  = schema_parts[0], schema_parts[1]
+                break
+
+    def parse_dict(self, d, root_tag_name='xml', element = None):
+        if element is None: 
+            self.parse_xml('<%s/>' % root_tag_name)
+            element = self.root
+
+        if 'text' in d:
+            text = d.pop('text')
+            element.text = text
+
+        # handle repeating fields
+        for (key, value) in d.items():
+            if isinstance(value, list):
+                value = d.pop(key)
+                for val in value:
+                    if isinstance(val, dict):
+                        child_element = etree.SubElement(element, key)
+                        self.parse_dict(val, key, child_element) 
+        
+        element.attrib.update(d)
+
+    def validate(self, schema):
+        """
+        Validate against rng schema
+        """
+        relaxng_doc = etree.parse(schema)
+        relaxng = etree.RelaxNG(relaxng_doc)
+        if not relaxng(self.root):
+            error = relaxng.error_log.last_error
+            message = "%s (line %s)" % (error.message, error.line)
+            raise InvalidXML(message)
+        return True
+
+    def xpath(self, xpath, namespaces=None):
+        if not namespaces:
+            namespaces = self.namespaces
+        return self.root.xpath(xpath, namespaces=namespaces)
+
+    def set(self, key, value):
+        return self.root.set(key, value)
+
+    def add_attribute(self, elem, name, value):
+        """
+        Add attribute to specified etree element    
+        """
+        opt = etree.SubElement(elem, name)
+        opt.text = value
+
+    def add_element(self, name, attrs={}, parent=None, text=""):
+        """
+        Generic wrapper around etree.SubElement(). Adds an element to 
+        specified parent node. Adds element to root node is parent is 
+        not specified. 
+        """
+        if parent == None:
+            parent = self.root
+        element = etree.SubElement(parent, name)
+        if text:
+            element.text = text
+        if isinstance(attrs, dict):
+            for attr in attrs:
+                element.set(attr, attrs[attr])  
+        return element
+
+    def remove_attribute(self, elem, name, value):
+        """
+        Removes an attribute from an element
+        """
+        if elem is not None:
+            opts = elem.iterfind(name)
+            if opts is not None:
+                for opt in opts:
+                    if opt.text == value:
+                        elem.remove(opt)
+
+    def remove_element(self, element_name, root_node = None):
+        """
+        Removes all occurences of an element from the tree. Start at 
+        specified root_node if specified, otherwise start at tree's root.   
+        """
+        if not root_node:
+            root_node = self.root
+
+        if not element_name.startswith('//'):
+            element_name = '//' + element_name
+
+        elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
+        for element in elements:
+            parent = element.getparent()
+            parent.remove(element)
+
+    def attributes_list(self, elem):
+        # convert a list of attribute tags into list of tuples
+        # (tagnme, text_value)
+        opts = []
+        if elem is not None:
+            for e in elem:
+                opts.append((e.tag, str(e.text).strip()))
+        return opts
+
+    def get_element_attributes(self, elem=None, depth=0):
+        if elem == None:
+            elem = self.root_node
+        if not hasattr(elem, 'attrib'):
+            # this is probably not an element node with attribute. could be just and an
+            # attribute, return it
+            return elem
+        attrs = dict(elem.attrib)
+        attrs['text'] = str(elem.text).strip()
+        attrs['parent'] = elem.getparent()
+        if isinstance(depth, int) and depth > 0:
+            for child_elem in list(elem):
+                key = str(child_elem.tag)
+                if key not in attrs:
+                    attrs[key] = [self.get_element_attributes(child_elem, depth-1)]
+                else:
+                    attrs[key].append(self.get_element_attributes(child_elem, depth-1))
+        else:
+            attrs['child_nodes'] = list(elem)
+        return attrs
+
+    def merge(self, in_xml):
+        pass
+
+    def __str__(self):
+        return self.toxml()
+
+    def toxml(self):
+        return etree.tostring(self.root, pretty_print=True)  
+    
+    def todict(self, elem=None):
+        if elem is None:
+            elem = self.root
+        d = {}
+        d.update(elem.attrib)
+        d['text'] = elem.text
+        for child in elem.iterchildren():
+            if child.tag not in d:
+                d[child.tag] = []
+            d[child.tag].append(self.todict(child))
+        return d            
+        
+    def save(self, filename):
+        f = open(filename, 'w')
+        f.write(self.toxml())
+        f.close()
+if __name__ == '__main__':
+    rspec = RSpec('/tmp/resources.rspec')
+    print rspec
+
index 61e16fe..25e7b76 100644 (file)
@@ -1,10 +1,9 @@
 # XMLRPC-specific code for SFA Client
 
-import httplib
 import xmlrpclib
-
-from sfa.util.sfalogging import sfa_logger
-
+#from sfa.util.httpsProtocol import HTTPS, HTTPSConnection
+from httplib import HTTPS, HTTPSConnection
+from sfa.util.sfalogging import logger
 ##
 # ServerException, ExceptionUnmarshaller
 #
@@ -35,16 +34,42 @@ class ExceptionUnmarshaller(xmlrpclib.Unmarshaller):
 need_HTTPSConnection=hasattr(xmlrpclib.Transport().make_connection('localhost'),'getresponse')
 
 class XMLRPCTransport(xmlrpclib.Transport):
-    key_file = None
-    cert_file = None
+    
+    def __init__(self, key_file=None, cert_file=None, timeout=None):
+        xmlrpclib.Transport.__init__(self)
+        self.timeout=timeout
+        self.key_file = key_file
+        self.cert_file = cert_file
+        
     def make_connection(self, host):
         # create a HTTPS connection object from a host descriptor
         # host may be a string, or a (host, x509-dict) tuple
         host, extra_headers, x509 = self.get_host_info(host)
         if need_HTTPSConnection:
-            return httplib.HTTPSConnection(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
+            #conn = HTTPSConnection(host, None, key_file=self.key_file, cert_file=self.cert_file, timeout=self.timeout) #**(x509 or {}))
+            conn = HTTPSConnection(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
         else:
-            return httplib.HTTPS(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
+            #conn = HTTPS(host, None, key_file=self.key_file, cert_file=self.cert_file, timeout=self.timeout) #**(x509 or {}))
+            conn = HTTPS(host, None, key_file=self.key_file, cert_file=self.cert_file) #**(x509 or {}))
+
+        if hasattr(conn, 'set_timeout'):
+            conn.set_timeout(self.timeout)
+
+        # Some logic to deal with timeouts. It appears that some (or all) versions
+        # of python don't set the timeout after the socket is created. We'll do it
+        # ourselves by forcing the connection to connect, finding the socket, and
+        # calling settimeout() on it. (tested with python 2.6)
+        if self.timeout:
+            if hasattr(conn, "_conn"):
+                # HTTPS is a wrapper around HTTPSConnection
+                real_conn = conn._conn
+            else:
+                real_conn = conn
+            conn.connect()
+            if hasattr(real_conn, "sock") and hasattr(real_conn.sock, "settimeout"):
+                real_conn.sock.settimeout(float(self.timeout))
+
+        return conn
 
     def getparser(self):
         unmarshaller = ExceptionUnmarshaller()
@@ -52,24 +77,16 @@ class XMLRPCTransport(xmlrpclib.Transport):
         return parser, unmarshaller
 
 class XMLRPCServerProxy(xmlrpclib.ServerProxy):
-    def __init__(self, url, transport, allow_none=True, options=None):
+    def __init__(self, url, transport, allow_none=True, verbose=False):
         # remember url for GetVersion
         self.url=url
-        verbose = False
-        if options and options.debug:
-            verbose = True
-#        sfa_logger().debug ("xmlrpcprotocol.XMLRPCServerProxy.__init__ %s (with verbose=%s)"%(url,verbose))
         xmlrpclib.ServerProxy.__init__(self, url, transport, allow_none=allow_none, verbose=verbose)
 
     def __getattr__(self, attr):
-        sfa_logger().debug ("xml-rpc %s method:%s"%(self.url,attr))
+        logger.debug ("xml-rpc %s method:%s"%(self.url,attr))
         return xmlrpclib.ServerProxy.__getattr__(self, attr)
 
-
-def get_server(url, key_file, cert_file, options=None):
-    transport = XMLRPCTransport()
-    transport.key_file = key_file
-    transport.cert_file = cert_file
-
-    return XMLRPCServerProxy(url, transport, allow_none=True, options=options)
+def get_server(url, key_file, cert_file, timeout=None, verbose=False):
+    transport = XMLRPCTransport(key_file, cert_file, timeout)
+    return XMLRPCServerProxy(url, transport, allow_none=True, verbose=verbose)
 
index d5ad2d4..3dc87b6 100644 (file)
@@ -1,13 +1,39 @@
+#----------------------------------------------------------------------
+# Copyright (c) 2008 Board of Trustees, Princeton University
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and/or hardware specification (the "Work") to
+# deal in the Work without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Work, and to permit persons to whom the Work
+# is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Work.
+#
+# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
+# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS 
+# IN THE WORK.
+#----------------------------------------------------------------------
+
 import re
 
 from sfa.util.faults import *
-from sfa.util.sfalogging import sfa_logger
 
 # for convenience and smoother translation - we should get rid of these functions eventually 
 def get_leaf(hrn): return Xrn(hrn).get_leaf()
 def get_authority(hrn): return Xrn(hrn).get_authority_hrn()
 def urn_to_hrn(urn): xrn=Xrn(urn); return (xrn.hrn, xrn.type)
 def hrn_to_urn(hrn,type): return Xrn(hrn, type=type).urn
+def hrn_authfor_hrn(parenthrn, hrn): return Xrn.hrn_is_auth_for_hrn(parenthrn, hrn)
+
+def urn_to_sliver_id(urn, slice_id, node_id, index=0):
+    return ":".join(map(str, [urn, slice_id, node_id, index]))
 
 class Xrn:
 
@@ -34,10 +60,38 @@ class Xrn:
     # e.g. escape ('a.b') -> 'a\.b'
     @staticmethod
     def escape(token): return re.sub(r'([^\\])\.', r'\1\.', token)
+
     # e.g. unescape ('a\.b') -> 'a.b'
     @staticmethod
     def unescape(token): return token.replace('\\.','.')
-        
+
+    # Return the HRN authority chain from top to bottom.
+    # e.g. hrn_auth_chain('a\.b.c.d') -> ['a\.b', 'a\.b.c']
+    @staticmethod
+    def hrn_auth_chain(hrn):
+        parts = Xrn.hrn_auth_list(hrn)
+        chain = []
+        for i in range(len(parts)):
+            chain.append('.'.join(parts[:i+1]))
+        # Include the HRN itself?
+        #chain.append(hrn)
+        return chain
+
+    # Is the given HRN a true authority over the namespace of the other
+    # child HRN?
+    # A better alternative than childHRN.startswith(parentHRN)
+    # e.g. hrn_is_auth_for_hrn('a\.b', 'a\.b.c.d') -> True,
+    # but hrn_is_auth_for_hrn('a', 'a\.b.c.d') -> False
+    # Also hrn_is_uauth_for_hrn('a\.b.c.d', 'a\.b.c.d') -> True
+    @staticmethod
+    def hrn_is_auth_for_hrn(parenthrn, hrn):
+        if parenthrn == hrn:
+            return True
+        for auth in Xrn.hrn_auth_chain(hrn):
+            if parenthrn == auth:
+                return True
+        return False
+
     URN_PREFIX = "urn:publicid:IDN"
 
     ########## basic tools on URNs
@@ -74,7 +128,7 @@ class Xrn:
             self.hrn_to_urn()
 # happens all the time ..
 #        if not type:
-#            sfa_logger().debug("type-less Xrn's are not safe")
+#            debug_logger.debug("type-less Xrn's are not safe")
 
     def get_urn(self): return self.urn
     def get_hrn(self): return self.hrn
diff --git a/tests/client/README b/tests/client/README
deleted file mode 100644 (file)
index 6d4ae3d..0000000
+++ /dev/null
@@ -1 +0,0 @@
-these files used to be in geniwrapper/cmdline
index 4d49d08..d25484c 100755 (executable)
@@ -13,7 +13,6 @@ from sfa.util.config import *
 from sfa.trust.certificate import *
 from sfa.trust.credential import *
 from sfa.util.sfaticket import *
-from sfa.util.rspec import *
 from sfa.client import sfi
 
 def random_string(size):
index 931cb31..bac4b12 100644 (file)
@@ -1,43 +1,43 @@
-#!/usr/bin/env python\r
-#-------------------------------------------------------------------------------\r
-import os\r
-import sys\r
-import glob\r
-import os.path\r
-from setuptools import setup\r
-#from distutils.core import setup\r
-#-------------------------------------------------------------------------------\r
-if 'upload' in sys.argv:\r
-    # for .pypirc file\r
-    try:\r
-        os.environ['HOME']\r
-    except KeyError:\r
-        os.environ['HOME'] = '..\\'\r
-#-------------------------------------------------------------------------------\r
-fpath = lambda x : os.path.join(*x.split('/'))\r
-#-------------------------------------------------------------------------------\r
-PYPI_URL = 'http://pypi.python.org/pypi/xmlbuilder'\r
-ld = open(fpath('xmlbuilder/docs/long_descr.rst')).read()\r
-ld = ld.replace('&','&amp;').replace('<','&lt;').replace('>','&gt;')\r
-setup(\r
-    name = "xmlbuilder",\r
-    fullname = "xmlbuilder",\r
-    version = "0.9",\r
-    packages = ["xmlbuilder"],\r
-    package_dir = {'xmlbuilder':'xmlbuilder'},\r
-    author = "koder",\r
-    author_email = "koder_dot_mail@gmail_dot_com",\r
-    maintainer = 'koder',\r
-    maintainer_email = "koder_dot_mail@gmail_dot_com",\r
-    description = "Pythonic way to create xml files",\r
-    license = "MIT",\r
-    keywords = "xml",\r
-    test_suite = "xml_buider.tests",\r
-    url = PYPI_URL,\r
-    download_url = PYPI_URL,\r
-    long_description = ld,\r
-    #include_package_data = True,\r
-    #package_data = {'xmlbuilder':["docs/*.rst"]},\r
-    #data_files = [('', ['xmlbuilder/docs/long_descr.rst'])]\r
-)\r
-#-------------------------------------------------------------------------------\r
+#!/usr/bin/env python
+#-------------------------------------------------------------------------------
+import os
+import sys
+import glob
+import os.path
+from setuptools import setup
+#from distutils.core import setup
+#-------------------------------------------------------------------------------
+if 'upload' in sys.argv:
+    # for .pypirc file
+    try:
+        os.environ['HOME']
+    except KeyError:
+        os.environ['HOME'] = '..\\'
+#-------------------------------------------------------------------------------
+fpath = lambda x : os.path.join(*x.split('/'))
+#-------------------------------------------------------------------------------
+PYPI_URL = 'http://pypi.python.org/pypi/xmlbuilder'
+ld = open(fpath('xmlbuilder/docs/long_descr.rst')).read()
+ld = ld.replace('&','&amp;').replace('<','&lt;').replace('>','&gt;')
+setup(
+    name = "xmlbuilder",
+    fullname = "xmlbuilder",
+    version = "0.9",
+    packages = ["xmlbuilder"],
+    package_dir = {'xmlbuilder':'xmlbuilder'},
+    author = "koder",
+    author_email = "koder_dot_mail@gmail_dot_com",
+    maintainer = 'koder',
+    maintainer_email = "koder_dot_mail@gmail_dot_com",
+    description = "Pythonic way to create xml files",
+    license = "MIT",
+    keywords = "xml",
+    test_suite = "xml_buider.tests",
+    url = PYPI_URL,
+    download_url = PYPI_URL,
+    long_description = ld,
+    #include_package_data = True,
+    #package_data = {'xmlbuilder':["docs/*.rst"]},
+    #data_files = [('', ['xmlbuilder/docs/long_descr.rst'])]
+)
+#-------------------------------------------------------------------------------
index 1be17b0..24ce7a5 100644 (file)
-#!/usr/bin/env python\r
-#-------------------------------------------------------------------------------\r
-from __future__ import with_statement\r
-#-------------------------------------------------------------------------------\r
-from xml.etree.ElementTree import TreeBuilder,tostring\r
-#-------------------------------------------------------------------------------\r
-__all__ = ["XMLBuilder"]\r
-__doc__ = """\r
-XMLBuilder is simple library build on top of ElementTree.TreeBuilder to\r
-simplify xml files creation as much as possible. Althow it can produce\r
-structured result with identated child tags. `XMLBuilder` use python `with`\r
-statement to define xml tag levels and `<<` operator for simple cases -\r
-text and tag without childs.\r
-\r
-from __future__ import with_statement\r
-from xmlbuilder import XMLBuilder\r
-x = XMLBuilder(format=True)\r
-with x.root(a = 1):\r
-    with x.data:\r
-        [x << ('node',{'val':i}) for i in range(10)]\r
-\r
-etree_node = ~x\r
-print str(x)\r
-"""\r
-#-------------------------------------------------------------------------------\r
-class _XMLNode(object):\r
-    """Class for internal usage"""\r
-    def __init__(self,parent,name,builder):\r
-        self.builder = builder\r
-        self.name = name\r
-        self.text = []\r
-        self.attrs = {}\r
-        self.entered = False\r
-        self.parent = parent\r
-    def __call__(self,*dt,**mp):\r
-        text = "".join(dt)\r
-        if self.entered:\r
-            self.builder.data(text)\r
-        else:\r
-            self.text.append(text)\r
-        if self.entered:\r
-            raise ValueError("Can't add attributes to already opened element")\r
-        smp = dict((k,str(v)) for k,v in mp.items())\r
-        self.attrs.update(smp)\r
-        return self\r
-    def __enter__(self):\r
-        self.parent += 1\r
-        self.builder.start(self.name,self.attrs)\r
-        self.builder.data("".join(self.text))\r
-        self.entered = True\r
-        return self\r
-    def __exit__(self,x,y,z):\r
-        self.parent -= 1\r
-        self.builder.end(self.name)\r
-        return False\r
-#-------------------------------------------------------------------------------\r
-class XMLBuilder(object):\r
-    """XmlBuilder(encoding = 'utf-8', # result xml file encoding\r
-            builder = None, #etree.TreeBuilder or compatible class\r
-            tab_level = None, #current tabulation level - string\r
-            format = False,   # make formatted output\r
-            tab_step = " " * 4) # tabulation step\r
-    use str(builder) or unicode(builder) to get xml text or\r
-    ~builder to obtaine etree.ElementTree\r
-    """\r
-    def __init__(self,encoding = 'utf-8',\r
-                      builder = None,\r
-                      tab_level = None,\r
-                      format = False,\r
-                      tab_step = " " * 4):\r
-        self.__builder = builder or TreeBuilder()\r
-        self.__encoding = encoding \r
-        if format :\r
-            if tab_level is None:\r
-                tab_level = ""\r
-        if tab_level is not None:\r
-            if not format:\r
-                raise ValueError("format is False, but tab_level not None")\r
-        self.__tab_level = tab_level # current format level\r
-        self.__tab_step = tab_step   # format step\r
-        self.__has_sub_tag = False   # True, if current tag had childrens\r
-        self.__node = None\r
-    # called from _XMLNode when tag opened\r
-    def __iadd__(self,val):\r
-        self.__has_sub_tag = False\r
-        if self.__tab_level is not None:\r
-            self.__builder.data("\n" + self.__tab_level)\r
-            self.__tab_level += self.__tab_step\r
-        return self\r
-    # called from XMLNode when tag closed\r
-    def __isub__(self,val):\r
-        if self.__tab_level is not None:\r
-            self.__tab_level = self.__tab_level[:-len(self.__tab_step)]\r
-            if self.__has_sub_tag:\r
-                self.__builder.data("\n" + self.__tab_level)\r
-        self.__has_sub_tag = True\r
-        return self\r
-    def __getattr__(self,name):\r
-        return _XMLNode(self,name,self.__builder)\r
-    def __call__(self,name,*dt,**mp):\r
-        x = _XMLNode(self,name,self.__builder)\r
-        x(*dt,**mp)\r
-        return x\r
-    #create new tag or add text\r
-    #possible shift values\r
-    #string - text\r
-    #tuple(string1,string2,dict) - new tag with name string1,attrs = dict,and text string2\r
-    #dict and string2 are optional\r
-    def __lshift__(self,val):\r
-        if isinstance(val,basestring):\r
-            self.__builder.data(val)\r
-        else:\r
-            self.__has_sub_tag = True\r
-            assert hasattr(val,'__len__'),\\r
-                'Shifted value should be tuple or list like object not %r' % val\r
-            assert hasattr(val,'__getitem__'),\\r
-                'Shifted value should be tuple or list like object not %r' % val\r
-            name = val[0]\r
-            if len(val) == 3:\r
-                text = val[1]\r
-                attrs = val[2]\r
-            elif len(val) == 1:\r
-                text = ""\r
-                attrs = {}\r
-            elif len(val) == 2:\r
-                if isinstance(val[1],basestring):\r
-                    text = val[1]\r
-                    attrs = {}\r
-                else:\r
-                    text = ""\r
-                    attrs = val[1]\r
-            if self.__tab_level is not None:\r
-                self.__builder.data("\n" + self.__tab_level)\r
-            self.__builder.start(name,\r
-                                 dict((k,str(v)) for k,v in attrs.items()))\r
-            if text:\r
-                self.__builder.data(text)\r
-            self.__builder.end(name)\r
-        return self # to allow xml << some1 << some2 << some3\r
-    #close builder\r
-    def __invert__(self):\r
-        if self.__node is not None:\r
-            return self.__node\r
-        self.__node = self.__builder.close()\r
-        return self.__node\r
-    def __str__(self):\r
-        """return generated xml"""\r
-        return tostring(~self,self.__encoding)\r
-    def __unicode__(self):\r
-        """return generated xml"""\r
-        res = tostring(~self,self.__encoding)\r
-        return res.decode(self.__encoding)\r
-#-------------------------------------------------------------------------------\r
+#!/usr/bin/env python
+#-------------------------------------------------------------------------------
+from __future__ import with_statement
+#-------------------------------------------------------------------------------
+from xml.etree.ElementTree import TreeBuilder,tostring
+#-------------------------------------------------------------------------------
+__all__ = ["XMLBuilder"]
+__doc__ = """
+XMLBuilder is simple library build on top of ElementTree.TreeBuilder to
+simplify xml files creation as much as possible. Althow it can produce
+structured result with identated child tags. `XMLBuilder` use python `with`
+statement to define xml tag levels and `<<` operator for simple cases -
+text and tag without childs.
+
+from __future__ import with_statement
+from xmlbuilder import XMLBuilder
+x = XMLBuilder(format=True)
+with x.root(a = 1):
+    with x.data:
+        [x << ('node',{'val':i}) for i in range(10)]
+
+etree_node = ~x
+print str(x)
+"""
+#-------------------------------------------------------------------------------
+class _XMLNode(object):
+    """Class for internal usage"""
+    def __init__(self,parent,name,builder):
+        self.builder = builder
+        self.name = name
+        self.text = []
+        self.attrs = {}
+        self.entered = False
+        self.parent = parent
+    def __call__(self,*dt,**mp):
+        text = "".join(dt)
+        if self.entered:
+            self.builder.data(text)
+        else:
+            self.text.append(text)
+        if self.entered:
+            raise ValueError("Can't add attributes to already opened element")
+        smp = dict((k,str(v)) for k,v in mp.items())
+        self.attrs.update(smp)
+        return self
+    def __enter__(self):
+        self.parent += 1
+        self.builder.start(self.name,self.attrs)
+        self.builder.data("".join(self.text))
+        self.entered = True
+        return self
+    def __exit__(self,x,y,z):
+        self.parent -= 1
+        self.builder.end(self.name)
+        return False
+#-------------------------------------------------------------------------------
+class XMLBuilder(object):
+    """XmlBuilder(encoding = 'utf-8', # result xml file encoding
+            builder = None, #etree.TreeBuilder or compatible class
+            tab_level = None, #current tabulation level - string
+            format = False,   # make formatted output
+            tab_step = " " * 4) # tabulation step
+    use str(builder) or unicode(builder) to get xml text or
+    ~builder to obtaine etree.ElementTree
+    """
+    def __init__(self,encoding = 'utf-8',
+                      builder = None,
+                      tab_level = None,
+                      format = False,
+                      tab_step = " " * 4):
+        self.__builder = builder or TreeBuilder()
+        self.__encoding = encoding 
+        if format :
+            if tab_level is None:
+                tab_level = ""
+        if tab_level is not None:
+            if not format:
+                raise ValueError("format is False, but tab_level not None")
+        self.__tab_level = tab_level # current format level
+        self.__tab_step = tab_step   # format step
+        self.__has_sub_tag = False   # True, if current tag had childrens
+        self.__node = None
+    # called from _XMLNode when tag opened
+    def __iadd__(self,val):
+        self.__has_sub_tag = False
+        if self.__tab_level is not None:
+            self.__builder.data("\n" + self.__tab_level)
+            self.__tab_level += self.__tab_step
+        return self
+    # called from XMLNode when tag closed
+    def __isub__(self,val):
+        if self.__tab_level is not None:
+            self.__tab_level = self.__tab_level[:-len(self.__tab_step)]
+            if self.__has_sub_tag:
+                self.__builder.data("\n" + self.__tab_level)
+        self.__has_sub_tag = True
+        return self
+    def __getattr__(self,name):
+        return _XMLNode(self,name,self.__builder)
+    def __call__(self,name,*dt,**mp):
+        x = _XMLNode(self,name,self.__builder)
+        x(*dt,**mp)
+        return x
+    #create new tag or add text
+    #possible shift values
+    #string - text
+    #tuple(string1,string2,dict) - new tag with name string1,attrs = dict,and text string2
+    #dict and string2 are optional
+    def __lshift__(self,val):
+        if isinstance(val,basestring):
+            self.__builder.data(val)
+        else:
+            self.__has_sub_tag = True
+            assert hasattr(val,'__len__'),\
+                'Shifted value should be tuple or list like object not %r' % val
+            assert hasattr(val,'__getitem__'),\
+                'Shifted value should be tuple or list like object not %r' % val
+            name = val[0]
+            if len(val) == 3:
+                text = val[1]
+                attrs = val[2]
+            elif len(val) == 1:
+                text = ""
+                attrs = {}
+            elif len(val) == 2:
+                if isinstance(val[1],basestring):
+                    text = val[1]
+                    attrs = {}
+                else:
+                    text = ""
+                    attrs = val[1]
+            if self.__tab_level is not None:
+                self.__builder.data("\n" + self.__tab_level)
+            self.__builder.start(name,
+                                 dict((k,str(v)) for k,v in attrs.items()))
+            if text:
+                self.__builder.data(text)
+            self.__builder.end(name)
+        return self # to allow xml << some1 << some2 << some3
+    #close builder
+    def __invert__(self):
+        if self.__node is not None:
+            return self.__node
+        self.__node = self.__builder.close()
+        return self.__node
+    def __str__(self):
+        """return generated xml"""
+        return tostring(~self,self.__encoding)
+    def __unicode__(self):
+        """return generated xml"""
+        res = tostring(~self,self.__encoding)
+        return res.decode(self.__encoding)
+#-------------------------------------------------------------------------------
index 67eaa67..43a67b1 100644 (file)
@@ -1,99 +1,99 @@
-#!/usr/bin/env python\r
-from __future__ import with_statement\r
-#-------------------------------------------------------------------------------\r
-import unittest\r
-from xml.etree.ElementTree import fromstring\r
-#-------------------------------------------------------------------------------\r
-from xmlbuilder import XMLBuilder\r
-#-------------------------------------------------------------------------------\r
-def xmlStructureEqual(xml1,xml2):\r
-    tree1 = fromstring(xml1)\r
-    tree2 = fromstring(xml2)\r
-    return _xmlStructureEqual(tree1,tree2)\r
-#-------------------------------------------------------------------------------\r
-def _xmlStructureEqual(tree1,tree2):\r
-    if tree1.tag != tree2.tag:\r
-        return False\r
-    attr1 = list(tree1.attrib.items())\r
-    attr1.sort()\r
-    attr2 = list(tree2.attrib.items())\r
-    attr2.sort()\r
-    if attr1 != attr2:\r
-        return False\r
-    return tree1.getchildren() == tree2.getchildren()\r
-#-------------------------------------------------------------------------------\r
-result1 = \\r
-"""\r
-<root>\r
-    <array />\r
-    <array len="10">\r
-        <el val="0" />\r
-        <el val="1">xyz</el>\r
-        <el val="2">abc</el>\r
-        <el val="3" />\r
-        <el val="4" />\r
-        <el val="5" />\r
-        <sup-el val="23">test  </sup-el>\r
-    </array>\r
-</root>\r
-""".strip()\r
-#-------------------------------------------------------------------------------\r
-class TestXMLBuilder(unittest.TestCase):\r
-    def testShift(self):\r
-        xml = (XMLBuilder() << ('root',))\r
-        self.assertEqual(str(xml),"<root />")\r
-        \r
-        xml = XMLBuilder()\r
-        xml << ('root',"some text")\r
-        self.assertEqual(str(xml),"<root>some text</root>")\r
-        \r
-        xml = XMLBuilder()\r
-        xml << ('root',{'x':1,'y':'2'})\r
-        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>some text</root>"))\r
-        \r
-        xml = XMLBuilder()\r
-        xml << ('root',{'x':1,'y':'2'})\r
-        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'></root>"))\r
-\r
-        xml = XMLBuilder()\r
-        xml << ('root',{'x':1,'y':'2'})\r
-        self.assert_(not xmlStructureEqual(str(xml),"<root x='2' y='2'></root>"))\r
-\r
-        \r
-        xml = XMLBuilder()\r
-        xml << ('root',"gonduras.ua",{'x':1,'y':'2'})\r
-        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.ua</root>"))\r
-        \r
-        xml = XMLBuilder()\r
-        xml << ('root',"gonduras.ua",{'x':1,'y':'2'})\r
-        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.com</root>"))\r
-    #---------------------------------------------------------------------------\r
-    def testWith(self):\r
-        xml = XMLBuilder()\r
-        with xml.root(lenght = 12):\r
-            pass\r
-        self.assertEqual(str(xml),'<root lenght="12" />')\r
-        \r
-        xml = XMLBuilder()\r
-        with xml.root():\r
-            xml << "text1" << "text2" << ('some_node',)\r
-        self.assertEqual(str(xml),"<root>text1text2<some_node /></root>")\r
-    #---------------------------------------------------------------------------\r
-    def testFormat(self):\r
-        x = XMLBuilder('utf-8',format = True)\r
-        with x.root():\r
-            x << ('array',)\r
-            with x.array(len = 10):\r
-                with x.el(val = 0):\r
-                    pass\r
-                with x.el('xyz',val = 1):\r
-                    pass\r
-                x << ("el","abc",{'val':2}) << ('el',dict(val=3))\r
-                x << ('el',dict(val=4)) << ('el',dict(val='5'))\r
-                with x('sup-el',val = 23):\r
-                    x << "test  "\r
-        self.assertEqual(str(x),result1)\r
-#-------------------------------------------------------------------------------\r
-if __name__ == '__main__':\r
-    unittest.main()\r
-#-------------------------------------------------------------------------------\r
+#!/usr/bin/env python
+from __future__ import with_statement
+#-------------------------------------------------------------------------------
+import unittest
+from xml.etree.ElementTree import fromstring
+#-------------------------------------------------------------------------------
+from xmlbuilder import XMLBuilder
+#-------------------------------------------------------------------------------
+def xmlStructureEqual(xml1,xml2):
+    tree1 = fromstring(xml1)
+    tree2 = fromstring(xml2)
+    return _xmlStructureEqual(tree1,tree2)
+#-------------------------------------------------------------------------------
+def _xmlStructureEqual(tree1,tree2):
+    if tree1.tag != tree2.tag:
+        return False
+    attr1 = list(tree1.attrib.items())
+    attr1.sort()
+    attr2 = list(tree2.attrib.items())
+    attr2.sort()
+    if attr1 != attr2:
+        return False
+    return tree1.getchildren() == tree2.getchildren()
+#-------------------------------------------------------------------------------
+result1 = \
+"""
+<root>
+    <array />
+    <array len="10">
+        <el val="0" />
+        <el val="1">xyz</el>
+        <el val="2">abc</el>
+        <el val="3" />
+        <el val="4" />
+        <el val="5" />
+        <sup-el val="23">test  </sup-el>
+    </array>
+</root>
+""".strip()
+#-------------------------------------------------------------------------------
+class TestXMLBuilder(unittest.TestCase):
+    def testShift(self):
+        xml = (XMLBuilder() << ('root',))
+        self.assertEqual(str(xml),"<root />")
+        
+        xml = XMLBuilder()
+        xml << ('root',"some text")
+        self.assertEqual(str(xml),"<root>some text</root>")
+        
+        xml = XMLBuilder()
+        xml << ('root',{'x':1,'y':'2'})
+        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>some text</root>"))
+        
+        xml = XMLBuilder()
+        xml << ('root',{'x':1,'y':'2'})
+        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'></root>"))
+
+        xml = XMLBuilder()
+        xml << ('root',{'x':1,'y':'2'})
+        self.assert_(not xmlStructureEqual(str(xml),"<root x='2' y='2'></root>"))
+
+        
+        xml = XMLBuilder()
+        xml << ('root',"gonduras.ua",{'x':1,'y':'2'})
+        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.ua</root>"))
+        
+        xml = XMLBuilder()
+        xml << ('root',"gonduras.ua",{'x':1,'y':'2'})
+        self.assert_(xmlStructureEqual(str(xml),"<root x='1' y='2'>gonduras.com</root>"))
+    #---------------------------------------------------------------------------
+    def testWith(self):
+        xml = XMLBuilder()
+        with xml.root(lenght = 12):
+            pass
+        self.assertEqual(str(xml),'<root lenght="12" />')
+        
+        xml = XMLBuilder()
+        with xml.root():
+            xml << "text1" << "text2" << ('some_node',)
+        self.assertEqual(str(xml),"<root>text1text2<some_node /></root>")
+    #---------------------------------------------------------------------------
+    def testFormat(self):
+        x = XMLBuilder('utf-8',format = True)
+        with x.root():
+            x << ('array',)
+            with x.array(len = 10):
+                with x.el(val = 0):
+                    pass
+                with x.el('xyz',val = 1):
+                    pass
+                x << ("el","abc",{'val':2}) << ('el',dict(val=3))
+                x << ('el',dict(val=4)) << ('el',dict(val='5'))
+                with x('sup-el',val = 23):
+                    x << "test  "
+        self.assertEqual(str(x),result1)
+#-------------------------------------------------------------------------------
+if __name__ == '__main__':
+    unittest.main()
+#-------------------------------------------------------------------------------