From: Thierry Parmentelat Date: Wed, 16 Nov 2011 13:44:14 +0000 (+0100) Subject: Merge branch 'upstreammaster' X-Git-Tag: sfa-2.1-24~34 X-Git-Url: http://git.onelab.eu/?p=sfa.git;a=commitdiff_plain;h=69fb221c274eb0b6e9f6ff6f895e5e6f90b17230;hp=7691d2227f76cac186cb98167ac425bd9b5c0d5e Merge branch 'upstreammaster' --- diff --git a/Makefile b/Makefile index 72fa41ac..701c3295 100644 --- a/Makefile +++ b/Makefile @@ -146,5 +146,10 @@ else $(SSHCOMMAND) exec service sfa restart endif +# 99% of the time this is enough +fastsync: + +$(RSYNC) ./sfa/ $(SSHURL)/usr/lib\*/python2.\*/site-packages/sfa/ + $(SSHCOMMAND) exec service sfa restart + .PHONY: sync ########## diff --git a/sfa/client/sfiAddAttribute.py b/sfa/client/sfiAddAttribute.py index f22e63e8..6fa51b40 100755 --- a/sfa/client/sfiAddAttribute.py +++ b/sfa/client/sfiAddAttribute.py @@ -1,9 +1,12 @@ #! /usr/bin/env python import sys + +from sfa.util.sfalogging import logger from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec +logger.enable_console() command = Commands(usage="%prog [options] [node1 node2...]", description="Add sliver attributes to the RSpec. " + "This command reads in an RSpec and outputs a modified " + @@ -33,12 +36,12 @@ if command.opts.infile: try: rspec.version.add_default_sliver_attribute(name, value) except: - print >> sys.stderr, "FAILED: on all nodes: %s=%s" % (name, value) + logger.log_exc("sfiAddAttribute FAILED on all nodes: %s=%s" % (name, value)) else: for node in nodes: try: rspec.version.add_sliver_attribute(node, name, value) except: - print >> sys.stderr, "FAILED: on node %s: %s=%s" % (node, name, value) + logger.log_exc ("sfiAddAttribute FAILED on node %s: %s=%s" % (node, name, value)) print rspec.toxml() diff --git a/sfa/client/sfiAddLinks.py b/sfa/client/sfiAddLinks.py index f5b28888..2e667b1b 100755 --- a/sfa/client/sfiAddLinks.py +++ b/sfa/client/sfiAddLinks.py @@ -1,10 +1,13 @@ #! /usr/bin/env python import sys + +from sfa.util.sfalogging import logger from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec from sfa.rspecs.version_manager import VersionManager +logger.enable_console() command = Commands(usage="%prog [options] node1 node2...", description="Add links to the RSpec. " + "This command reads in an RSpec and outputs a modified " + @@ -38,8 +41,7 @@ try: request_rspec.version.merge(ad_rspec) request_rspec.version.add_link_requests(link_tuples) except: - print >> sys.stderr, "FAILED: %s" % links - raise + logger.log_exc("sfiAddLinks FAILED with links %s" % links) sys.exit(1) print >>outfile, request_rspec.toxml() sys.exit(0) diff --git a/sfa/client/sfiAddSliver.py b/sfa/client/sfiAddSliver.py index c72dee34..84ffa8b0 100755 --- a/sfa/client/sfiAddSliver.py +++ b/sfa/client/sfiAddSliver.py @@ -1,10 +1,13 @@ #! /usr/bin/env python import sys + +from sfa.util.sfalogging import logger from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec from sfa.rspecs.version_manager import VersionManager +logger.enable_console() command = Commands(usage="%prog [options] node1 node2...", description="Add slivers to the RSpec. " + "This command reads in an RSpec and outputs a modified " + @@ -33,12 +36,10 @@ try: version_num = ad_rspec.version.version request_version = version_manager._get_version(type, version_num, 'request') request_rspec = RSpec(version=request_version) - slivers = [{'hostname': node} for node in nodes] request_rspec.version.merge(ad_rspec) - request_rspec.version.add_slivers(slivers) + request_rspec.version.add_slivers(nodes) except: - print >> sys.stderr, "FAILED: %s" % nodes - raise + logger.log_exc("sfiAddSliver failed with nodes %s" % nodes) sys.exit(1) print >>outfile, request_rspec.toxml() sys.exit(0) diff --git a/sfa/client/sfiDeleteAttribute.py b/sfa/client/sfiDeleteAttribute.py index 53b2542d..7e6a5aeb 100755 --- a/sfa/client/sfiDeleteAttribute.py +++ b/sfa/client/sfiDeleteAttribute.py @@ -1,9 +1,12 @@ #! /usr/bin/env python import sys + +from sfa.util.sfalogging import logger from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec +logger.enable_console() command = Commands(usage="%prog [options] [node1 node2...]", description="Delete sliver attributes from the RSpec. " + "This command reads in an RSpec and outputs a modified " + @@ -33,12 +36,12 @@ if command.opts.infile: try: rspec.version.remove_default_sliver_attribute(name, value) except: - print >> sys.stderr, "FAILED: on all nodes: %s=%s" % (name, value) + logger.log_exc("sfiDeleteAttribute FAILED on all nodes: %s=%s" % (name, value)) else: for node in nodes: try: rspec.version.remove_sliver_attribute(node, name, value) except: - print >> sys.stderr, "FAILED: on node %s: %s=%s" % (node, name, value) + logger.log_exc("sfiDeleteAttribute FAILED on node %s: %s=%s" % (node, name, value)) print rspec.toxml() diff --git a/sfa/client/sfiDeleteSliver.py b/sfa/client/sfiDeleteSliver.py index be10f0b6..3dc50e65 100755 --- a/sfa/client/sfiDeleteSliver.py +++ b/sfa/client/sfiDeleteSliver.py @@ -1,9 +1,12 @@ #! /usr/bin/env python import sys + +from sfa.util.sfalogging import logger from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec +logger.enable_console() command = Commands(usage="%prog [options] node1 node2...", description="Delete slivers from the RSpec. " + "This command reads in an RSpec and outputs a modified " + @@ -24,7 +27,7 @@ if command.opts.infile: rspec.version.remove_slivers(slivers) print rspec.toxml() except: - print >> sys.stderr, "FAILED: %s" % nodes + logger.log_exc("sfiDeleteSliver FAILED with nodes %s" % nodes) diff --git a/sfa/client/sfiListNodes.py b/sfa/client/sfiListNodes.py index f9e794b8..3de1657f 100755 --- a/sfa/client/sfiListNodes.py +++ b/sfa/client/sfiListNodes.py @@ -2,7 +2,8 @@ import sys from sfa.client.sfi_commands import Commands -from sfa.rspecs.rspec import RSpec +from sfa.rspecs.rspec import RSpec +from sfa.util.plxrn import xrn_to_hostname command = Commands(usage="%prog [options]", description="List all nodes in the RSpec. " + @@ -17,7 +18,11 @@ if command.opts.infile: sys.stdout = open(command.opts.outfile, 'w') for node in nodes: - print node + hostname = None + if node.get('component_id'): + hostname = xrn_to_hostname(node['component_id']) + if hostname: + print hostname diff --git a/sfa/client/sfiListSlivers.py b/sfa/client/sfiListSlivers.py index cf76823f..c869f614 100755 --- a/sfa/client/sfiListSlivers.py +++ b/sfa/client/sfiListSlivers.py @@ -3,6 +3,7 @@ import sys from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec +from sfa.util.plxrn import xrn_to_hostname command = Commands(usage="%prog [options]", description="List all slivers in the RSpec. " + @@ -23,10 +24,14 @@ if command.opts.infile: print " %s: %s" % (name, value) for node in nodes: - print node - if command.opts.showatt: - atts = rspec.version.get_sliver_attributes(node) - for (name, value) in atts: - print " %s: %s" % (name, value) + hostname = None + if node.get('component_id'): + hostname = xrn_to_hostname(node['component_id']) + if hostname: + print hostname + if command.opts.showatt: + atts = rspec.version.get_sliver_attributes(hostname) + for (name, value) in atts: + print " %s: %s" % (name, value) diff --git a/sfa/generic/__init__.py b/sfa/generic/__init__.py index ea6ce059..0f9043b2 100644 --- a/sfa/generic/__init__.py +++ b/sfa/generic/__init__.py @@ -83,7 +83,7 @@ class Generic: classname = "%s_manager_class"%interface try: module = getattr(self,classname)() - logger.info("%s : %s"%(message,module)) + logger.debug("%s : %s"%(message,module)) return module except: logger.log_exc_critical(message) @@ -99,7 +99,7 @@ class Generic: classname = "driver_class" try: class_obj = getattr(self,classname)() - logger.info("%s : %s"%(message,class_obj)) + logger.debug("%s : %s"%(message,class_obj)) return class_obj(config) except: logger.log_exc_critical(message) diff --git a/sfa/generic/pl.py b/sfa/generic/pl.py index 098a27a3..2c036199 100644 --- a/sfa/generic/pl.py +++ b/sfa/generic/pl.py @@ -14,7 +14,7 @@ class pl (Generic): # the manager classes for the server-side services def registry_manager_class (self) : - return sfa.managers.registry_manager + return sfa.managers.registry_manager.RegistryManager def slicemgr_manager_class (self) : return sfa.managers.slice_manager.SliceManager def aggregate_manager_class (self) : diff --git a/sfa/managers/aggregate_manager.py b/sfa/managers/aggregate_manager.py index 14ec7d0f..30ebd925 100644 --- a/sfa/managers/aggregate_manager.py +++ b/sfa/managers/aggregate_manager.py @@ -2,6 +2,7 @@ import datetime import time import sys +from sfa.util.sfalogging import logger from sfa.util.faults import RecordNotFound, SliverDoesNotExist from sfa.util.xrn import get_authority, hrn_to_urn, urn_to_hrn, Xrn, urn_to_sliver_id from sfa.util.plxrn import slicename_to_hrn, hrn_to_pl_slicename @@ -187,7 +188,7 @@ class AggregateManager: slices.verify_slice_attributes(slice, requested_attributes) # add/remove slice from nodes - requested_slivers = [node['component_name'] for node in rspec.version.get_nodes_with_slivers()] + requested_slivers = [node.get('component_name') for node in rspec.version.get_nodes_with_slivers()] slices.verify_slice_nodes(slice, requested_slivers, peer) # add/remove links links diff --git a/sfa/managers/registry_manager.py b/sfa/managers/registry_manager.py index 5888b2b9..56062b9f 100644 --- a/sfa/managers/registry_manager.py +++ b/sfa/managers/registry_manager.py @@ -15,425 +15,429 @@ from sfa.trust.credential import Credential from sfa.trust.certificate import Certificate, Keypair, convert_public_key from sfa.trust.gid import create_uuid -# The GENI GetVersion call -def GetVersion(api): - peers =dict ([ (peername,v._ServerProxy__host) for (peername,v) in api.registries.iteritems() - if peername != api.hrn]) - xrn=Xrn(api.hrn) - return version_core({'interface':'registry', - 'hrn':xrn.get_hrn(), - 'urn':xrn.get_urn(), - 'peers':peers}) - -def get_credential(api, xrn, type, is_self=False): - # convert xrn to hrn - if type: - hrn = urn_to_hrn(xrn)[0] - else: - hrn, type = urn_to_hrn(xrn) +class RegistryManager: + + def __init__ (self): pass + + # The GENI GetVersion call + def GetVersion(self, api): + peers = dict ( [ (hrn,interface._ServerProxy__host) for (hrn,interface) in api.registries.iteritems() + if hrn != api.hrn]) + xrn=Xrn(api.hrn) + return version_core({'interface':'registry', + 'hrn':xrn.get_hrn(), + 'urn':xrn.get_urn(), + 'peers':peers}) + + def get_credential(self, api, xrn, type, is_self=False): + # convert xrn to hrn + if type: + hrn = urn_to_hrn(xrn)[0] + else: + hrn, type = urn_to_hrn(xrn) + + # Is this a root or sub authority + auth_hrn = api.auth.get_authority(hrn) + if not auth_hrn or hrn == api.config.SFA_INTERFACE_HRN: + auth_hrn = hrn + # get record info + auth_info = api.auth.get_auth_info(auth_hrn) + table = SfaTable() + records = table.findObjects({'type': type, 'hrn': hrn}) + if not records: + raise RecordNotFound(hrn) + record = records[0] + + # verify_cancreate_credential requires that the member lists + # (researchers, pis, etc) be filled in + api.driver.fill_record_info(record, api.aggregates) + if record['type']=='user': + if not record['enabled']: + raise AccountNotEnabled(": PlanetLab account %s is not enabled. Please contact your site PI" %(record['email'])) + + # get the callers gid + # if this is a self cred the record's gid is the caller's gid + if is_self: + caller_hrn = hrn + caller_gid = record.get_gid_object() + else: + caller_gid = api.auth.client_cred.get_gid_caller() + caller_hrn = caller_gid.get_hrn() - # Is this a root or sub authority - auth_hrn = api.auth.get_authority(hrn) - if not auth_hrn or hrn == api.config.SFA_INTERFACE_HRN: - auth_hrn = hrn - # get record info - auth_info = api.auth.get_auth_info(auth_hrn) - table = SfaTable() - records = table.findObjects({'type': type, 'hrn': hrn}) - if not records: - raise RecordNotFound(hrn) - record = records[0] - - # verify_cancreate_credential requires that the member lists - # (researchers, pis, etc) be filled in - api.driver.fill_record_info(record, api.aggregates) - if record['type']=='user': - if not record['enabled']: - raise AccountNotEnabled(": PlanetLab account %s is not enabled. Please contact your site PI" %(record['email'])) - - # get the callers gid - # if this is a self cred the record's gid is the caller's gid - if is_self: - caller_hrn = hrn - caller_gid = record.get_gid_object() - else: - caller_gid = api.auth.client_cred.get_gid_caller() - caller_hrn = caller_gid.get_hrn() - - object_hrn = record.get_gid_object().get_hrn() - rights = api.auth.determine_user_rights(caller_hrn, record) - # make sure caller has rights to this object - if rights.is_empty(): - raise PermissionError(caller_hrn + " has no rights to " + record['name']) - - object_gid = GID(string=record['gid']) - new_cred = Credential(subject = object_gid.get_subject()) - new_cred.set_gid_caller(caller_gid) - new_cred.set_gid_object(object_gid) - new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename()) - #new_cred.set_pubkey(object_gid.get_pubkey()) - new_cred.set_privileges(rights) - new_cred.get_privileges().delegate_all_privileges(True) - if 'expires' in record: - new_cred.set_expiration(int(record['expires'])) - auth_kind = "authority,ma,sa" - # Parent not necessary, verify with certs - #new_cred.set_parent(api.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind)) - new_cred.encode() - new_cred.sign() - - return new_cred.save_to_string(save_parents=True) - - -def resolve(api, xrns, type=None, full=True): - - # load all known registry names into a prefix tree and attempt to find - # the longest matching prefix - if not isinstance(xrns, types.ListType): - if not type: - type = Xrn(xrns).get_type() - xrns = [xrns] - hrns = [urn_to_hrn(xrn)[0] for xrn in xrns] - # create a dict where key is a registry hrn and its value is a - # hrns at that registry (determined by the known prefix tree). - xrn_dict = {} - registries = api.registries - tree = prefixTree() - registry_hrns = registries.keys() - tree.load(registry_hrns) - for xrn in xrns: - registry_hrn = tree.best_match(urn_to_hrn(xrn)[0]) - if registry_hrn not in xrn_dict: - xrn_dict[registry_hrn] = [] - xrn_dict[registry_hrn].append(xrn) + object_hrn = record.get_gid_object().get_hrn() + rights = api.auth.determine_user_rights(caller_hrn, record) + # make sure caller has rights to this object + if rights.is_empty(): + raise PermissionError(caller_hrn + " has no rights to " + record['name']) + + object_gid = GID(string=record['gid']) + new_cred = Credential(subject = object_gid.get_subject()) + new_cred.set_gid_caller(caller_gid) + new_cred.set_gid_object(object_gid) + new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename()) + #new_cred.set_pubkey(object_gid.get_pubkey()) + new_cred.set_privileges(rights) + new_cred.get_privileges().delegate_all_privileges(True) + if 'expires' in record: + new_cred.set_expiration(int(record['expires'])) + auth_kind = "authority,ma,sa" + # Parent not necessary, verify with certs + #new_cred.set_parent(api.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind)) + new_cred.encode() + new_cred.sign() + + return new_cred.save_to_string(save_parents=True) + + + def resolve(self, api, xrns, type=None, full=True): + + # load all known registry names into a prefix tree and attempt to find + # the longest matching prefix + if not isinstance(xrns, types.ListType): + if not type: + type = Xrn(xrns).get_type() + xrns = [xrns] + hrns = [urn_to_hrn(xrn)[0] for xrn in xrns] + # create a dict where key is a registry hrn and its value is a + # hrns at that registry (determined by the known prefix tree). + xrn_dict = {} + registries = api.registries + tree = prefixTree() + registry_hrns = registries.keys() + tree.load(registry_hrns) + for xrn in xrns: + registry_hrn = tree.best_match(urn_to_hrn(xrn)[0]) + if registry_hrn not in xrn_dict: + xrn_dict[registry_hrn] = [] + xrn_dict[registry_hrn].append(xrn) + + records = [] + for registry_hrn in xrn_dict: + # skip the hrn without a registry hrn + # XX should we let the user know the authority is unknown? + if not registry_hrn: + continue + + # if the best match (longest matching hrn) is not the local registry, + # forward the request + xrns = xrn_dict[registry_hrn] + if registry_hrn != api.hrn: + credential = api.getCredential() + interface = api.registries[registry_hrn] + server = api.server_proxy(interface, credential) + peer_records = server.Resolve(xrns, credential) + records.extend([SfaRecord(dict=record).as_dict() for record in peer_records]) + + # try resolving the remaining unfound records at the local registry + remaining_hrns = set(hrns).difference([record['hrn'] for record in records]) + # convert set to list + remaining_hrns = [hrn for hrn in remaining_hrns] + table = SfaTable() + local_records = table.findObjects({'hrn': remaining_hrns}) + if full: + api.driver.fill_record_info(local_records, api.aggregates) - records = [] - for registry_hrn in xrn_dict: - # skip the hrn without a registry hrn - # XX should we let the user know the authority is unknown? + # convert local record objects to dicts + records.extend([dict(record) for record in local_records]) + if not records: + raise RecordNotFound(str(hrns)) + + if type: + records = filter(lambda rec: rec['type'] in [type], records) + + return records + + def list(self, api, xrn, origin_hrn=None): + hrn, type = urn_to_hrn(xrn) + # load all know registry names into a prefix tree and attempt to find + # the longest matching prefix + records = [] + registries = api.registries + registry_hrns = registries.keys() + tree = prefixTree() + tree.load(registry_hrns) + registry_hrn = tree.best_match(hrn) + + #if there was no match then this record belongs to an unknow registry if not registry_hrn: - continue - + raise MissingAuthority(xrn) # if the best match (longest matching hrn) is not the local registry, # forward the request - xrns = xrn_dict[registry_hrn] + records = [] if registry_hrn != api.hrn: credential = api.getCredential() interface = api.registries[registry_hrn] server = api.server_proxy(interface, credential) - peer_records = server.Resolve(xrns, credential) - records.extend([SfaRecord(dict=record).as_dict() for record in peer_records]) - - # try resolving the remaining unfound records at the local registry - remaining_hrns = set(hrns).difference([record['hrn'] for record in records]) - # convert set to list - remaining_hrns = [hrn for hrn in remaining_hrns] - table = SfaTable() - local_records = table.findObjects({'hrn': remaining_hrns}) - if full: - api.driver.fill_record_info(local_records, api.aggregates) - - # convert local record objects to dicts - records.extend([dict(record) for record in local_records]) - if not records: - raise RecordNotFound(str(hrns)) - - if type: - records = filter(lambda rec: rec['type'] in [type], records) - - return records - -def list(api, xrn, origin_hrn=None): - hrn, type = urn_to_hrn(xrn) - # load all know registry names into a prefix tree and attempt to find - # the longest matching prefix - records = [] - registries = api.registries - registry_hrns = registries.keys() - tree = prefixTree() - tree.load(registry_hrns) - registry_hrn = tree.best_match(hrn) - - #if there was no match then this record belongs to an unknow registry - if not registry_hrn: - raise MissingAuthority(xrn) - # if the best match (longest matching hrn) is not the local registry, - # forward the request - records = [] - if registry_hrn != api.hrn: - credential = api.getCredential() - interface = api.registries[registry_hrn] - server = api.server_proxy(interface, credential) - record_list = server.List(xrn, credential) - records = [SfaRecord(dict=record).as_dict() for record in record_list] - - # if we still have not found the record yet, try the local registry - if not records: - if not api.auth.hierarchy.auth_exists(hrn): - raise MissingAuthority(hrn) - + record_list = server.List(xrn, credential) + records = [SfaRecord(dict=record).as_dict() for record in record_list] + + # if we still have not found the record yet, try the local registry + if not records: + if not api.auth.hierarchy.auth_exists(hrn): + raise MissingAuthority(hrn) + + table = SfaTable() + records = table.find({'authority': hrn}) + + return records + + + def create_gid(self, api, xrn, cert): + # get the authority + authority = Xrn(xrn=xrn).get_authority_hrn() + auth_info = api.auth.get_auth_info(authority) + if not cert: + pkey = Keypair(create=True) + else: + certificate = Certificate(string=cert) + pkey = certificate.get_pubkey() + gid = api.auth.hierarchy.create_gid(xrn, create_uuid(), pkey) + return gid.save_to_string(save_parents=True) + + def register(self, api, record): + + hrn, type = record['hrn'], record['type'] + urn = hrn_to_urn(hrn,type) + # validate the type + if type not in ['authority', 'slice', 'node', 'user']: + raise UnknownSfaType(type) + + # check if record already exists table = SfaTable() - records = table.find({'authority': hrn}) - - return records - - -def create_gid(api, xrn, cert): - # get the authority - authority = Xrn(xrn=xrn).get_authority_hrn() - auth_info = api.auth.get_auth_info(authority) - if not cert: - pkey = Keypair(create=True) - else: - certificate = Certificate(string=cert) - pkey = certificate.get_pubkey() - gid = api.auth.hierarchy.create_gid(xrn, create_uuid(), pkey) - return gid.save_to_string(save_parents=True) - -def register(api, record): - - hrn, type = record['hrn'], record['type'] - urn = hrn_to_urn(hrn,type) - # validate the type - if type not in ['authority', 'slice', 'node', 'user']: - raise UnknownSfaType(type) - - # check if record already exists - table = SfaTable() - existing_records = table.find({'type': type, 'hrn': hrn}) - if existing_records: - raise ExistingRecord(hrn) - - record = SfaRecord(dict = record) - record['authority'] = get_authority(record['hrn']) - type = record['type'] - hrn = record['hrn'] - auth_info = api.auth.get_auth_info(record['authority']) - pub_key = None - # make sure record has a gid - if 'gid' not in record: - uuid = create_uuid() - pkey = Keypair(create=True) - if 'key' in record and record['key']: - if isinstance(record['key'], types.ListType): - pub_key = record['key'][0] + existing_records = table.find({'type': type, 'hrn': hrn}) + if existing_records: + raise ExistingRecord(hrn) + + record = SfaRecord(dict = record) + record['authority'] = get_authority(record['hrn']) + type = record['type'] + hrn = record['hrn'] + auth_info = api.auth.get_auth_info(record['authority']) + pub_key = None + # make sure record has a gid + if 'gid' not in record: + uuid = create_uuid() + pkey = Keypair(create=True) + if 'key' in record and record['key']: + if isinstance(record['key'], types.ListType): + pub_key = record['key'][0] + else: + pub_key = record['key'] + pkey = convert_public_key(pub_key) + + gid_object = api.auth.hierarchy.create_gid(urn, uuid, pkey) + gid = gid_object.save_to_string(save_parents=True) + record['gid'] = gid + record.set_gid(gid) + + if type in ["authority"]: + # update the tree + if not api.auth.hierarchy.auth_exists(hrn): + api.auth.hierarchy.create_auth(hrn_to_urn(hrn,'authority')) + + # get the GID from the newly created authority + gid = auth_info.get_gid_object() + record.set_gid(gid.save_to_string(save_parents=True)) + pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record) + sites = api.driver.GetSites([pl_record['login_base']]) + if not sites: + pointer = api.driver.AddSite(pl_record) else: - pub_key = record['key'] - pkey = convert_public_key(pub_key) - - gid_object = api.auth.hierarchy.create_gid(urn, uuid, pkey) - gid = gid_object.save_to_string(save_parents=True) - record['gid'] = gid - record.set_gid(gid) - - if type in ["authority"]: - # update the tree - if not api.auth.hierarchy.auth_exists(hrn): - api.auth.hierarchy.create_auth(hrn_to_urn(hrn,'authority')) - - # get the GID from the newly created authority - gid = auth_info.get_gid_object() - record.set_gid(gid.save_to_string(save_parents=True)) - pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record) - sites = api.driver.GetSites([pl_record['login_base']]) - if not sites: - pointer = api.driver.AddSite(pl_record) - else: - pointer = sites[0]['site_id'] - - record.set_pointer(pointer) + pointer = sites[0]['site_id'] + + record.set_pointer(pointer) + record['pointer'] = pointer + + elif (type == "slice"): + acceptable_fields=['url', 'instantiation', 'name', 'description'] + pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record) + for key in pl_record.keys(): + if key not in acceptable_fields: + pl_record.pop(key) + slices = api.driver.GetSlices([pl_record['name']]) + if not slices: + pointer = api.driver.AddSlice(pl_record) + else: + pointer = slices[0]['slice_id'] + record.set_pointer(pointer) + record['pointer'] = pointer + + elif (type == "user"): + persons = api.driver.GetPersons([record['email']]) + if not persons: + pointer = api.driver.AddPerson(dict(record)) + else: + pointer = persons[0]['person_id'] + + if 'enabled' in record and record['enabled']: + api.driver.UpdatePerson(pointer, {'enabled': record['enabled']}) + # add this persons to the site only if he is being added for the first + # time by sfa and doesont already exist in plc + if not persons or not persons[0]['site_ids']: + login_base = get_leaf(record['authority']) + api.driver.AddPersonToSite(pointer, login_base) + + # What roles should this user have? + api.driver.AddRoleToPerson('user', pointer) + # Add the user's key + if pub_key: + api.driver.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key}) + + elif (type == "node"): + pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record) + login_base = hrn_to_pl_login_base(record['authority']) + nodes = api.driver.GetNodes([pl_record['hostname']]) + if not nodes: + pointer = api.driver.AddNode(login_base, pl_record) + else: + pointer = nodes[0]['node_id'] + record['pointer'] = pointer - - elif (type == "slice"): - acceptable_fields=['url', 'instantiation', 'name', 'description'] - pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record) - for key in pl_record.keys(): - if key not in acceptable_fields: - pl_record.pop(key) - slices = api.driver.GetSlices([pl_record['name']]) - if not slices: - pointer = api.driver.AddSlice(pl_record) - else: - pointer = slices[0]['slice_id'] record.set_pointer(pointer) - record['pointer'] = pointer - - elif (type == "user"): - persons = api.driver.GetPersons([record['email']]) - if not persons: - pointer = api.driver.AddPerson(dict(record)) + record_id = table.insert(record) + record['record_id'] = record_id + + # update membership for researchers, pis, owners, operators + api.driver.update_membership(None, record) + + return record.get_gid_object().save_to_string(save_parents=True) + + def update(self, api, record_dict): + new_record = SfaRecord(dict = record_dict) + type = new_record['type'] + hrn = new_record['hrn'] + urn = hrn_to_urn(hrn,type) + table = SfaTable() + # make sure the record exists + records = table.findObjects({'type': type, 'hrn': hrn}) + if not records: + raise RecordNotFound(hrn) + record = records[0] + record['last_updated'] = time.gmtime() + + # Update_membership needs the membership lists in the existing record + # filled in, so it can see if members were added or removed + api.driver.fill_record_info(record, api.aggregates) + + # Use the pointer from the existing record, not the one that the user + # gave us. This prevents the user from inserting a forged pointer + pointer = record['pointer'] + # update the PLC information that was specified with the record + + if (type == "authority"): + api.driver.UpdateSite(pointer, new_record) + + elif type == "slice": + pl_record=api.driver.sfa_fields_to_pl_fields(type, hrn, new_record) + if 'name' in pl_record: + pl_record.pop('name') + api.driver.UpdateSlice(pointer, pl_record) + + elif type == "user": + # SMBAKER: UpdatePerson only allows a limited set of fields to be + # updated. Ideally we should have a more generic way of doing + # this. I copied the field names from UpdatePerson.py... + update_fields = {} + all_fields = new_record + for key in all_fields.keys(): + if key in ['first_name', 'last_name', 'title', 'email', + 'password', 'phone', 'url', 'bio', 'accepted_aup', + 'enabled']: + update_fields[key] = all_fields[key] + api.driver.UpdatePerson(pointer, update_fields) + + if 'key' in new_record and new_record['key']: + # must check this key against the previous one if it exists + persons = api.driver.GetPersons([pointer], ['key_ids']) + person = persons[0] + keys = person['key_ids'] + keys = api.driver.GetKeys(person['key_ids']) + key_exists = False + if isinstance(new_record['key'], types.ListType): + new_key = new_record['key'][0] + else: + new_key = new_record['key'] + + # Delete all stale keys + for key in keys: + if new_record['key'] != key['key']: + api.driver.DeleteKey(key['key_id']) + else: + key_exists = True + if not key_exists: + api.driver.AddPersonKey(pointer, {'key_type': 'ssh', 'key': new_key}) + + # update the openssl key and gid + pkey = convert_public_key(new_key) + uuid = create_uuid() + gid_object = api.auth.hierarchy.create_gid(urn, uuid, pkey) + gid = gid_object.save_to_string(save_parents=True) + record['gid'] = gid + record = SfaRecord(dict=record) + table.update(record) + + elif type == "node": + api.driver.UpdateNode(pointer, new_record) + else: - pointer = persons[0]['person_id'] - - if 'enabled' in record and record['enabled']: - api.driver.UpdatePerson(pointer, {'enabled': record['enabled']}) - # add this persons to the site only if he is being added for the first - # time by sfa and doesont already exist in plc - if not persons or not persons[0]['site_ids']: - login_base = get_leaf(record['authority']) - api.driver.AddPersonToSite(pointer, login_base) - - # What roles should this user have? - api.driver.AddRoleToPerson('user', pointer) - # Add the user's key - if pub_key: - api.driver.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key}) - - elif (type == "node"): - pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record) - login_base = hrn_to_pl_login_base(record['authority']) - nodes = api.driver.GetNodes([pl_record['hostname']]) - if not nodes: - pointer = api.driver.AddNode(login_base, pl_record) + raise UnknownSfaType(type) + + # update membership for researchers, pis, owners, operators + api.driver.update_membership(record, new_record) + + return 1 + + # expecting an Xrn instance + def remove(self, api, xrn, origin_hrn=None): + + table = SfaTable() + filter = {'hrn': xrn.get_hrn()} + hrn=xrn.get_hrn() + type=xrn.get_type() + if type and type not in ['all', '*']: + filter['type'] = type + + records = table.find(filter) + if not records: raise RecordNotFound(hrn) + record = records[0] + type = record['type'] + + credential = api.getCredential() + registries = api.registries + + # Try to remove the object from the PLCDB of federated agg. + # This is attempted before removing the object from the local agg's PLCDB and sfa table + if hrn.startswith(api.hrn) and type in ['user', 'slice', 'authority']: + for registry in registries: + if registry not in [api.hrn]: + try: + result=registries[registry].remove_peer_object(credential, record, origin_hrn) + except: + pass + if type == "user": + persons = api.driver.GetPersons(record['pointer']) + # only delete this person if he has site ids. if he doesnt, it probably means + # he was just removed from a site, not actually deleted + if persons and persons[0]['site_ids']: + api.driver.DeletePerson(record['pointer']) + elif type == "slice": + if api.driver.GetSlices(record['pointer']): + api.driver.DeleteSlice(record['pointer']) + elif type == "node": + if api.driver.GetNodes(record['pointer']): + api.driver.DeleteNode(record['pointer']) + elif type == "authority": + if api.driver.GetSites(record['pointer']): + api.driver.DeleteSite(record['pointer']) else: - pointer = nodes[0]['node_id'] - - record['pointer'] = pointer - record.set_pointer(pointer) - record_id = table.insert(record) - record['record_id'] = record_id - - # update membership for researchers, pis, owners, operators - api.driver.update_membership(None, record) - - return record.get_gid_object().save_to_string(save_parents=True) - -def update(api, record_dict): - new_record = SfaRecord(dict = record_dict) - type = new_record['type'] - hrn = new_record['hrn'] - urn = hrn_to_urn(hrn,type) - table = SfaTable() - # make sure the record exists - records = table.findObjects({'type': type, 'hrn': hrn}) - if not records: - raise RecordNotFound(hrn) - record = records[0] - record['last_updated'] = time.gmtime() - - # Update_membership needs the membership lists in the existing record - # filled in, so it can see if members were added or removed - api.driver.fill_record_info(record, api.aggregates) - - # Use the pointer from the existing record, not the one that the user - # gave us. This prevents the user from inserting a forged pointer - pointer = record['pointer'] - # update the PLC information that was specified with the record - - if (type == "authority"): - api.driver.UpdateSite(pointer, new_record) - - elif type == "slice": - pl_record=api.driver.sfa_fields_to_pl_fields(type, hrn, new_record) - if 'name' in pl_record: - pl_record.pop('name') - api.driver.UpdateSlice(pointer, pl_record) - - elif type == "user": - # SMBAKER: UpdatePerson only allows a limited set of fields to be - # updated. Ideally we should have a more generic way of doing - # this. I copied the field names from UpdatePerson.py... - update_fields = {} - all_fields = new_record - for key in all_fields.keys(): - if key in ['first_name', 'last_name', 'title', 'email', - 'password', 'phone', 'url', 'bio', 'accepted_aup', - 'enabled']: - update_fields[key] = all_fields[key] - api.driver.UpdatePerson(pointer, update_fields) - - if 'key' in new_record and new_record['key']: - # must check this key against the previous one if it exists - persons = api.driver.GetPersons([pointer], ['key_ids']) - person = persons[0] - keys = person['key_ids'] - keys = api.driver.GetKeys(person['key_ids']) - key_exists = False - if isinstance(new_record['key'], types.ListType): - new_key = new_record['key'][0] - else: - new_key = new_record['key'] - - # Delete all stale keys - for key in keys: - if new_record['key'] != key['key']: - api.driver.DeleteKey(key['key_id']) - else: - key_exists = True - if not key_exists: - api.driver.AddPersonKey(pointer, {'key_type': 'ssh', 'key': new_key}) - - # update the openssl key and gid - pkey = convert_public_key(new_key) - uuid = create_uuid() - gid_object = api.auth.hierarchy.create_gid(urn, uuid, pkey) - gid = gid_object.save_to_string(save_parents=True) - record['gid'] = gid - record = SfaRecord(dict=record) - table.update(record) - - elif type == "node": - api.driver.UpdateNode(pointer, new_record) - - else: - raise UnknownSfaType(type) - - # update membership for researchers, pis, owners, operators - api.driver.update_membership(record, new_record) + raise UnknownSfaType(type) - return 1 - -# expecting an Xrn instance -def remove(api, xrn, origin_hrn=None): - - table = SfaTable() - filter = {'hrn': xrn.get_hrn()} - hrn=xrn.get_hrn() - type=xrn.get_type() - if type and type not in ['all', '*']: - filter['type'] = type - - records = table.find(filter) - if not records: raise RecordNotFound(hrn) - record = records[0] - type = record['type'] - - credential = api.getCredential() - registries = api.registries - - # Try to remove the object from the PLCDB of federated agg. - # This is attempted before removing the object from the local agg's PLCDB and sfa table - if hrn.startswith(api.hrn) and type in ['user', 'slice', 'authority']: - for registry in registries: - if registry not in [api.hrn]: - try: - result=registries[registry].remove_peer_object(credential, record, origin_hrn) - except: - pass - if type == "user": - persons = api.driver.GetPersons(record['pointer']) - # only delete this person if he has site ids. if he doesnt, it probably means - # he was just removed from a site, not actually deleted - if persons and persons[0]['site_ids']: - api.driver.DeletePerson(record['pointer']) - elif type == "slice": - if api.driver.GetSlices(record['pointer']): - api.driver.DeleteSlice(record['pointer']) - elif type == "node": - if api.driver.GetNodes(record['pointer']): - api.driver.DeleteNode(record['pointer']) - elif type == "authority": - if api.driver.GetSites(record['pointer']): - api.driver.DeleteSite(record['pointer']) - else: - raise UnknownSfaType(type) - - table.remove(record) - - return 1 - -def remove_peer_object(api, record, origin_hrn=None): - pass - -def register_peer_object(api, record, origin_hrn=None): - pass + table.remove(record) + + return 1 + + def remove_peer_object(self, api, record, origin_hrn=None): + pass + + def register_peer_object(self, api, record, origin_hrn=None): + pass diff --git a/sfa/managers/slice_manager.py b/sfa/managers/slice_manager.py index 3d6c0a67..95655777 100644 --- a/sfa/managers/slice_manager.py +++ b/sfa/managers/slice_manager.py @@ -96,12 +96,12 @@ class SliceManager: if stats_tags: stats_tag = stats_tags[0] else: - stats_tag = etree.SubElement(rspec.xml.root, "statistics", call=callname) + stats_tag = rspec.xml.root.add_element("statistics", call=callname) - stat_tag = etree.SubElement(stats_tag, "aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status)) + stat_tag = stats_tag.add_element("aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status)) if exc_info: - exc_tag = etree.SubElement(stat_tag, "exc_info", name=str(exc_info[1])) + exc_tag = stat_tag.add_element("exc_info", name=str(exc_info[1])) # formats the traceback as one big text blob #exc_tag.text = "\n".join(traceback.format_exception(exc_info[0], exc_info[1], exc_info[2])) @@ -109,7 +109,7 @@ class SliceManager: # formats the traceback as a set of xml elements tb = traceback.extract_tb(exc_info[2]) for item in tb: - exc_frame = etree.SubElement(exc_tag, "tb_frame", filename=str(item[0]), line=str(item[1]), func=str(item[2]), code=str(item[3])) + exc_frame = exc_tag.add_element("tb_frame", filename=str(item[0]), line=str(item[1]), func=str(item[2]), code=str(item[3])) except Exception, e: logger.warn("add_slicemgr_stat failed on %s: %s" %(aggname, str(e))) diff --git a/sfa/methods/ResolveGENI.py b/sfa/methods/ResolveGENI.py index d9781971..7c9cc144 100644 --- a/sfa/methods/ResolveGENI.py +++ b/sfa/methods/ResolveGENI.py @@ -15,13 +15,4 @@ class ResolveGENI(Method): returns = Parameter(bool, "Success or Failure") def call(self, xrn): - - manager_base = 'sfa.managers' - - if self.api.interface in ['registry']: - mgr_type = self.api.config.SFA_REGISTRY_TYPE - manager_module = manager_base + ".registry_manager_%s" % mgr_type - manager = __import__(manager_module, fromlist=[manager_base]) - return manager.Resolve(self.api, xrn, '') - - return {} + return self.api.manager.Resolve(self.api, xrn, '') diff --git a/sfa/plc/aggregate.py b/sfa/plc/aggregate.py index bda3bbcc..c1f008fc 100644 --- a/sfa/plc/aggregate.py +++ b/sfa/plc/aggregate.py @@ -6,6 +6,7 @@ from sfa.rspecs.rspec import RSpec from sfa.rspecs.elements.hardware_type import HardwareType from sfa.rspecs.elements.node import Node from sfa.rspecs.elements.link import Link +from sfa.rspecs.elements.sliver import Sliver from sfa.rspecs.elements.login import Login from sfa.rspecs.elements.location import Location from sfa.rspecs.elements.interface import Interface @@ -14,6 +15,7 @@ from sfa.rspecs.elements.pltag import PLTag from sfa.util.topology import Topology from sfa.rspecs.version_manager import VersionManager from sfa.plc.vlink import get_tc_rate +from sfa.util.sfatime import epochparse class Aggregate: @@ -38,7 +40,8 @@ class Aggregate: iface['interface_id'] = interface['interface_id'] iface['node_id'] = interface['node_id'] iface['ipv4'] = interface['ip'] - iface['bwlimit'] = interface['bwlimit'] + if interface['bwlimit']: + iface['bwlimit'] = str(int(interface['bwlimit'])/1000) interfaces[iface['interface_id']] = iface return interfaces @@ -58,28 +61,27 @@ class Aggregate: # get hrns site1_hrn = self.api.hrn + '.' + site1['login_base'] site2_hrn = self.api.hrn + '.' + site2['login_base'] - # get the first node - node1 = self.nodes[site1['node_ids'][0]] - node2 = self.nodes[site2['node_ids'][0]] - - # set interfaces - # just get first interface of the first node - if1_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node1['node_id'])) - if1_ipv4 = self.interfaces[node1['interface_ids'][0]]['ip'] - if2_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (node2['node_id'])) - if2_ipv4 = self.interfaces[node2['interface_ids'][0]]['ip'] - - if1 = Interface({'component_id': if1_xrn.urn, 'ipv4': if1_ipv4} ) - if2 = Interface({'component_id': if2_xrn.urn, 'ipv4': if2_ipv4} ) - - # set link - link = Link({'capacity': '1000000', 'latency': '0', 'packet_loss': '0', 'type': 'ipv4'}) - link['interface1'] = if1 - link['interface2'] = if2 - link['component_name'] = "%s:%s" % (site1['login_base'], site2['login_base']) - link['component_id'] = PlXrn(auth=self.api.hrn, interface=link['component_name']).get_urn() - link['component_manager_id'] = hrn_to_urn(self.api.hrn, 'authority+am') - links[link['component_name']] = link + + for s1_node in self.nodes[site1['node_ids']]: + for s2_node in self.nodes[site2['node_ids']]: + # set interfaces + # just get first interface of the first node + if1_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (s1_node['node_id'])) + if1_ipv4 = self.interfaces[node1['interface_ids'][0]]['ip'] + if2_xrn = PlXrn(auth=self.api.hrn, interface='node%s:eth0' % (s2_node['node_id'])) + if2_ipv4 = self.interfaces[node2['interface_ids'][0]]['ip'] + + if1 = Interface({'component_id': if1_xrn.urn, 'ipv4': if1_ipv4} ) + if2 = Interface({'component_id': if2_xrn.urn, 'ipv4': if2_ipv4} ) + + # set link + link = Link({'capacity': '1000000', 'latency': '0', 'packet_loss': '0', 'type': 'ipv4'}) + link['interface1'] = if1 + link['interface2'] = if2 + link['component_name'] = "%s:%s" % (site1['login_base'], site2['login_base']) + link['component_id'] = PlXrn(auth=self.api.hrn, interface=link['component_name']).get_urn() + link['component_manager_id'] = hrn_to_urn(self.api.hrn, 'authority+am') + links[link['component_name']] = link return links @@ -105,7 +107,7 @@ class Aggregate: slice = None if not slice_xrn: return (slice, slivers) - slice_urn = hrn_to_urn(slice_xrn) + slice_urn = hrn_to_urn(slice_xrn, 'slice') slice_hrn, _ = urn_to_hrn(slice_xrn) slice_name = hrn_to_pl_slicename(slice_hrn) slices = self.api.driver.GetSlices(slice_name) @@ -116,7 +118,8 @@ class Aggregate: # sort slivers by node id for node_id in slice['node_ids']: sliver = Sliver({'sliver_id': urn_to_sliver_id(slice_urn, slice['slice_id'], node_id), - 'name': 'plab-vserver', + 'name': slice['name'], + 'type': 'plab-vserver', 'tags': []}) slivers[node_id]= sliver @@ -155,13 +158,6 @@ class Aggregate: sites_dict = self.get_sites({'site_id': site_ids}) # get interfaces interfaces = self.get_interfaces({'interface_id':interface_ids}) - # get slivers - # - # thierry: no get_slivers, we have slivers as a result of - # get_slice_and_slivers passed as an argument - # -# slivers = self.get_slivers(slice) - # get tags node_tags = self.get_node_tags(tags_filter) # get initscripts @@ -183,20 +179,26 @@ class Aggregate: rspec_node['authority_id'] = hrn_to_urn(PlXrn.site_hrn(self.api.hrn, site['login_base']), 'authority+sa') rspec_node['boot_state'] = node['boot_state'] rspec_node['exclusive'] = 'False' - rspec_node['hardware_types'].append(HardwareType({'name': 'plab-vserver'})) + rspec_node['hardware_types']= [HardwareType({'name': 'plab-pc'}), + HardwareType({'name': 'pc'})] # only doing this because protogeni rspec needs # to advertise available initscripts - rspec_node['pl_initscripts'] = pl_initscripts + rspec_node['pl_initscripts'] = pl_initscripts.values() # add site/interface info to nodes. # assumes that sites, interfaces and tags have already been prepared. site = sites_dict[node['site_id']] - location = Location({'longitude': site['longitude'], 'latitude': site['latitude']}) - rspec_node['location'] = location + if site['longitude'] and site['latitude']: + location = Location({'longitude': site['longitude'], 'latitude': site['latitude']}) + rspec_node['location'] = location rspec_node['interfaces'] = [] + if_count=0 for if_id in node['interface_ids']: interface = Interface(interfaces[if_id]) interface['ipv4'] = interface['ipv4'] + interface['component_id'] = PlXrn(auth=self.api.hrn, interface='node%s:eth%s' % (node['node_id'], if_count)).get_urn() rspec_node['interfaces'].append(interface) + if_count+=1 + tags = [PLTag(node_tags[tag_id]) for tag_id in node['node_tag_ids']] rspec_node['tags'] = tags if node['node_id'] in slivers: @@ -204,12 +206,12 @@ class Aggregate: sliver = slivers[node['node_id']] rspec_node['sliver_id'] = sliver['sliver_id'] rspec_node['client_id'] = node['hostname'] - rspec_node['slivers'] = [slivers[node['node_id']]] + rspec_node['slivers'] = [sliver] # slivers always provide the ssh service - login = Login({'authentication': 'ssh-keys', 'hostname': node['hostname'], port:'22'}) + login = Login({'authentication': 'ssh-keys', 'hostname': node['hostname'], 'port':'22'}) service = Services({'login': login}) - rspec_node['services'].append(service) + rspec_node['services'] = [service] rspec_nodes.append(rspec_node) return rspec_nodes @@ -225,7 +227,9 @@ class Aggregate: slice, slivers = self.get_slice_and_slivers(slice_xrn) rspec = RSpec(version=rspec_version, user_options=self.user_options) - rspec.version.add_nodes(self.get_nodes(slice), slivers) + if slice and 'expires' in slice: + rspec.xml.set('expires', epochparse(slice['expires'])) + rspec.version.add_nodes(self.get_nodes(slice, slivers)) rspec.version.add_links(self.get_links(slice)) # add sliver defaults diff --git a/sfa/plc/plshell.py b/sfa/plc/plshell.py index 863472fe..972a97ea 100644 --- a/sfa/plc/plshell.py +++ b/sfa/plc/plshell.py @@ -30,9 +30,9 @@ class PlShell: 'AuthString': config.SFA_PLC_PASSWORD} self.url = config.SFA_PLC_URL - self.plauth = {'Username': 'root@test.onelab.eu', - 'AuthMethod': 'password', - 'AuthString': 'test++'} + #self.plauth = {'Username': 'root@test.onelab.eu', + # 'AuthMethod': 'password', + # 'AuthString': 'test++'} self.proxy_server = xmlrpclib.Server(self.url, verbose = 0, allow_none = True) def __getattr__(self, name): diff --git a/sfa/rspecs/rspec_version.py b/sfa/rspecs/baseversion.py similarity index 100% rename from sfa/rspecs/rspec_version.py rename to sfa/rspecs/baseversion.py diff --git a/sfa/rspecs/elements/bwlimit.py b/sfa/rspecs/elements/bwlimit.py index 027bb5b3..6f75161c 100644 --- a/sfa/rspecs/elements/bwlimit.py +++ b/sfa/rspecs/elements/bwlimit.py @@ -1,8 +1,8 @@ from sfa.rspecs.elements.element import Element class BWlimit(Element): - fields = { - 'units': None, - 'value': None, - } + fields = [ + 'units', + 'value', + ] diff --git a/sfa/rspecs/elements/component_manager.py b/sfa/rspecs/elements/component_manager.py deleted file mode 100644 index ec9d85c9..00000000 --- a/sfa/rspecs/elements/component_manager.py +++ /dev/null @@ -1,7 +0,0 @@ -from sfa.rspecs.elements.element import Element - -class ComponentManager(Element): - fields = { - 'name': None, - } - diff --git a/sfa/rspecs/elements/element.py b/sfa/rspecs/elements/element.py index 6757f8a8..5789a9cc 100644 --- a/sfa/rspecs/elements/element.py +++ b/sfa/rspecs/elements/element.py @@ -2,8 +2,56 @@ class Element(dict): fields = {} - def __init__(self, fields={}, element=None): + def __init__(self, fields={}, element=None, keys=None): self.element = element - dict.__init__(self, self.fields) - self.update(fields) + dict.__init__(self, dict.fromkeys(self.fields)) + if not keys: + keys = fields.keys() + for key in keys: + if key in fields: + self[key] = fields[key] + @staticmethod + def get_elements(xml, xpath, element_class=None, fields=None): + """ + Search the specifed xml node for elements that match the + specified xpath query. + Returns a list of objects instanced by the specified element_class. + """ + if not element_class: + element_class = Element + if not fields and hasattr(element_class, 'fields'): + fields = element_class.fields + elems = xml.xpath(xpath) + objs = [] + for elem in elems: + if not fields: + obj = element_class(elem.attrib, elem) + else: + obj = element_class({}, elem) + for field in fields: + if field in elem.attrib: + obj[field] = elem.attrib[field] + objs.append(obj) + return objs + + @staticmethod + def add_elements(xml, name, objs, fields=None): + """ + Adds a child node to the specified xml node based on + the specified name , element class and object. + """ + if not isinstance(objs, list): + objs = [objs] + elems = [] + for obj in objs: + if not obj: + continue + if not fields: + fields = obj.keys() + elem = xml.add_element(name) + for field in fields: + if field in obj and obj[field]: + elem.set(field, unicode(obj[field])) + elems.append(elem) + return elems diff --git a/sfa/rspecs/elements/execute.py b/sfa/rspecs/elements/execute.py index 43e6e626..e7ee7067 100644 --- a/sfa/rspecs/elements/execute.py +++ b/sfa/rspecs/elements/execute.py @@ -1,7 +1,7 @@ from sfa.rspecs.elements.element import Element class Execute(Element): - fields = { - 'shell': None, - 'command': None, - } + fields = [ + 'shell', + 'command', + ] diff --git a/sfa/rspecs/elements/hardware_type.py b/sfa/rspecs/elements/hardware_type.py index 8dd959c2..5f20c9bb 100644 --- a/sfa/rspecs/elements/hardware_type.py +++ b/sfa/rspecs/elements/hardware_type.py @@ -2,6 +2,6 @@ from sfa.rspecs.elements.element import Element class HardwareType(Element): - fields = { - 'name': None, - } + fields = [ + 'name' + ] diff --git a/sfa/rspecs/elements/install.py b/sfa/rspecs/elements/install.py index 1df60b68..227a7972 100644 --- a/sfa/rspecs/elements/install.py +++ b/sfa/rspecs/elements/install.py @@ -1,8 +1,8 @@ from sfa.rspecs.elements.element import Element class Install(Element): - fields = { - 'file_type': None, - 'url': None, - 'install_path': None, - } + fields = [ + 'file_type', + 'url', + 'install_path', + ] diff --git a/sfa/rspecs/elements/interface.py b/sfa/rspecs/elements/interface.py index 7617ade0..11045df8 100644 --- a/sfa/rspecs/elements/interface.py +++ b/sfa/rspecs/elements/interface.py @@ -1,12 +1,12 @@ from sfa.rspecs.elements.element import Element class Interface(Element): - fields = {'component_id': None, - 'role': None, - 'client_id': None, - 'ipv4': None, - 'bwlimit': None, - 'node_id': None, - 'interface_id': None - - } + fields = ['component_id', + 'role', + 'client_id', + 'ipv4', + 'bwlimit', + 'node_id', + 'interface_id', + 'mac_address', + ] diff --git a/sfa/rspecs/elements/link.py b/sfa/rspecs/elements/link.py index 02a8d102..3bbfe2bb 100644 --- a/sfa/rspecs/elements/link.py +++ b/sfa/rspecs/elements/link.py @@ -1,16 +1,16 @@ from sfa.rspecs.elements.element import Element class Link(Element): - fields = { - 'client_id': None, - 'component_id': None, - 'component_name': None, - 'component_manager': None, - 'type': None, - 'interface1': None, - 'interface2': None, - 'capacity': None, - 'latency': None, - 'packet_loss': None, - 'description': None, - } + fields = [ + 'client_id', + 'component_id', + 'component_name', + 'component_manager', + 'type', + 'interface1', + 'interface2', + 'capacity', + 'latency', + 'packet_loss', + 'description', + ] diff --git a/sfa/rspecs/elements/link_type.py b/sfa/rspecs/elements/link_type.py deleted file mode 100644 index 882903d1..00000000 --- a/sfa/rspecs/elements/link_type.py +++ /dev/null @@ -1,6 +0,0 @@ -from sfa.rspecs.elements.element import Element - -class LinkType(Element): - fields = { - 'name': None, - } diff --git a/sfa/rspecs/elements/location.py b/sfa/rspecs/elements/location.py index a5a92603..57bfe0c1 100644 --- a/sfa/rspecs/elements/location.py +++ b/sfa/rspecs/elements/location.py @@ -2,8 +2,8 @@ from sfa.rspecs.elements.element import Element class Location(Element): - fields = { - 'country': None, - 'longitude': None, - 'latitude': None, - } + fields = [ + 'country', + 'longitude', + 'latitude', + ] diff --git a/sfa/rspecs/elements/login.py b/sfa/rspecs/elements/login.py index a64c7598..ae42641b 100644 --- a/sfa/rspecs/elements/login.py +++ b/sfa/rspecs/elements/login.py @@ -1,8 +1,8 @@ from sfa.rspecs.elements.element import Element class Login(Element): - fields = { - 'authentication': None, - 'hostname': None, - 'port': None - } + fields = [ + 'authentication', + 'hostname', + 'port' + ] diff --git a/sfa/rspecs/elements/network.py b/sfa/rspecs/elements/network.py deleted file mode 100644 index 362b9ffb..00000000 --- a/sfa/rspecs/elements/network.py +++ /dev/null @@ -1,9 +0,0 @@ -from sfa.rspecs.elements.element import Element - -class Network(Element): - - fields = { - 'name': None, - } - - diff --git a/sfa/rspecs/elements/node.py b/sfa/rspecs/elements/node.py index f90fff1b..7358ee03 100644 --- a/sfa/rspecs/elements/node.py +++ b/sfa/rspecs/elements/node.py @@ -2,25 +2,25 @@ from sfa.rspecs.elements.element import Element class Node(Element): - fields = { - 'component_id': None, - 'component_name': None, - 'component_manager_id': None, - 'client_id': None, - 'sliver_id': None, - 'authority_id': None, - 'exclusive': None, - 'location': None, - 'bw_unallocated': None, - 'bw_limit': None, - 'boot_state': None, - 'slivers': [], - 'hardware_types': [], - 'disk_images': [], - 'interfaces': [], - 'services': [], - 'tags': [], - 'pl_initscripts': [], - } + fields = [ + 'component_id', + 'component_name', + 'component_manager_id', + 'client_id', + 'sliver_id', + 'authority_id', + 'exclusive', + 'location', + 'bw_unallocated', + 'bw_limit', + 'boot_state', + 'slivers', + 'hardware_types', + 'disk_images', + 'interfaces', + 'services', + 'tags', + 'pl_initscripts', + ] diff --git a/sfa/rspecs/elements/pltag.py b/sfa/rspecs/elements/pltag.py index 51b1e765..0868a941 100644 --- a/sfa/rspecs/elements/pltag.py +++ b/sfa/rspecs/elements/pltag.py @@ -2,8 +2,8 @@ from sfa.rspecs.elements.element import Element class PLTag(Element): - fields = { - 'name': None, - 'value': None, - } + fields = [ + 'tagname', + 'value', + ] diff --git a/sfa/rspecs/elements/property.py b/sfa/rspecs/elements/property.py index 97a1ffcb..472dedeb 100644 --- a/sfa/rspecs/elements/property.py +++ b/sfa/rspecs/elements/property.py @@ -2,11 +2,11 @@ from sfa.rspecs.elements.element import Element class Property(Element): - fields = { - 'source_id': None, - 'dest_id': None, - 'capacity': None, - 'latency': None, - 'packet_loss': None, - } + fields = [ + 'source_id', + 'dest_id', + 'capacity', + 'latency', + 'packet_loss', + ] diff --git a/sfa/rspecs/elements/services.py b/sfa/rspecs/elements/services.py index a48be27a..df0546d4 100644 --- a/sfa/rspecs/elements/services.py +++ b/sfa/rspecs/elements/services.py @@ -2,9 +2,9 @@ from sfa.rspecs.elements.element import Element class Services(Element): - fields = { - 'install': [], - 'execute': [], - 'login': [], - } + fields = [ + 'install', + 'execute', + 'login', + ] diff --git a/sfa/rspecs/elements/sliver.py b/sfa/rspecs/elements/sliver.py index bf2cc1f7..8dd65425 100644 --- a/sfa/rspecs/elements/sliver.py +++ b/sfa/rspecs/elements/sliver.py @@ -1,9 +1,11 @@ from sfa.rspecs.elements.element import Element class Sliver(Element): - fields = { - 'sliver_id': None, - 'client_id': None, - 'name': None, - 'tags': [], - } + fields = [ + 'sliver_id', + 'component_id', + 'client_id', + 'name', + 'type', + 'tags', + ] diff --git a/sfa/rspecs/elements/tag.py b/sfa/rspecs/elements/tag.py deleted file mode 100644 index 8b137891..00000000 --- a/sfa/rspecs/elements/tag.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/sfa/rspecs/elements/versions/element_version.py b/sfa/rspecs/elements/versions/element_version.py deleted file mode 100644 index e69de29b..00000000 diff --git a/sfa/rspecs/elements/versions/pgv2Link.py b/sfa/rspecs/elements/versions/pgv2Link.py index db28f6ce..e496d166 100644 --- a/sfa/rspecs/elements/versions/pgv2Link.py +++ b/sfa/rspecs/elements/versions/pgv2Link.py @@ -1,81 +1,64 @@ -from lxml import etree from sfa.util.plxrn import PlXrn from sfa.util.xrn import Xrn +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.link import Link from sfa.rspecs.elements.interface import Interface -from sfa.rspecs.elements.link_type import LinkType -from sfa.rspecs.elements.component_manager import ComponentManager from sfa.rspecs.elements.property import Property -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements class PGv2Link: - elements = { - 'link': RSpecElement(RSpecElements.LINK, '//default:link | //link'), - 'component_manager': RSpecElement(RSpecElements.COMPONENT_MANAGER, './default:component_manager | ./component_manager'), - 'link_type': RSpecElement(RSpecElements.LINK_TYPE, './default:link_type | ./link_type'), - 'property': RSpecElement(RSpecElements.PROPERTY, './default:property | ./property'), - 'interface_ref': RSpecElement(RSpecElements.INTERFACE_REF, './default:interface_ref | ./interface_ref'), - } - @staticmethod def add_links(xml, links): for link in links: - link_elem = etree.SubElement(xml, 'link') - for attrib in ['component_name', 'component_id', 'client_id']: - if attrib in link and link[attrib] is not None: - link_elem.set(attrib, link[attrib]) + link_elems = Element.add(xml, 'link', link, ['component_name', 'component_id', 'client_id']) + link_elem = link_elems[0] + # set component manager element if 'component_manager' in link and link['component_manager']: - cm_element = etree.SubElement(link_elem, 'component_manager', name=link['component_manager']) + cm_element = link_elem.add_element('component_manager', name=link['component_manager']) + # set interface_ref elements for if_ref in [link['interface1'], link['interface2']]: - if_ref_elem = etree.SubElement(link_elem, 'interface_ref') - for attrib in Interface.fields: - if attrib in if_ref and if_ref[attrib]: - if_ref_elem.attrib[attrib] = if_ref[attrib] - prop1 = etree.SubElement(link_elem, 'property', source_id = link['interface1']['component_id'], + Element.add(link_elem, 'interface_ref', if_ref, Interface.fields) + # set property elements + prop1 = link_elem.add_element('property', source_id = link['interface1']['component_id'], dest_id = link['interface2']['component_id'], capacity=link['capacity'], latency=link['latency'], packet_loss=link['packet_loss']) - prop2 = etree.SubElement(link_elem, 'property', source_id = link['interface2']['component_id'], + prop2 = link_elem.add_element('property', source_id = link['interface2']['component_id'], dest_id = link['interface1']['component_id'], capacity=link['capacity'], latency=link['latency'], packet_loss=link['packet_loss']) - if 'type' in link and link['type']: - type_elem = etree.SubElement(link_elem, 'link_type', name=link['type']) + if link.get('type'): + type_elem = link_elem.add_element('link_type', name=link['type']) + @staticmethod def get_links(xml): links = [] - link_elems = xml.xpath(PGv2Link.elements['link'].path, namespaces=xml.namespaces) + link_elems = xml.xpath('//default:link | //link') for link_elem in link_elems: # set client_id, component_id, component_name link = Link(link_elem.attrib, link_elem) + # set component manager - cm = link_elem.xpath('./default:component_manager', namespaces=xml.namespaces) - if len(cm) > 0: - cm = cm[0] - if 'name' in cm.attrib: - link['component_manager'] = cm.attrib['name'] + component_managers = link_elem.xpath('./default:component_manager | ./component_manager') + if len(component_managers) > 0 and 'name' in component_managers[0].attrib: + link['component_manager'] = component_managers[0].attrib['name'] + # set link type - link_types = link_elem.xpath(PGv2Link.elements['link_type'].path, namespaces=xml.namespaces) - if len(link_types) > 0: - link_type = link_types[0] - if 'name' in link_type.attrib: - link['type'] = link_type.attrib['name'] + link_types = link_elem.xpath('./default:link_type | ./link_type') + if len(link_types) > 0 and 'name' in link_types[0].attrib: + link['type'] = link_types[0].attrib['name'] # get capacity, latency and packet_loss from first property - props = link_elem.xpath(PGv2Link.elements['property'].path, namespaces=xml.namespaces) - if len(props) > 0: - prop = props[0] + property_fields = ['capacity', 'latency', 'packet_loss'] + property_elems = link_elem.xpath('./default:property | ./property') + if len(propery_elems) > 0: + prop = property_elems[0] for attrib in ['capacity', 'latency', 'packet_loss']: - if attrib in prop.attrib: - link[attrib] = prop.attrib[attrib] + if attrib in prop: + link[attrib] = prop[attrib] - # get interfaces - if_elems = link_elem.xpath(PGv2Link.elements['interface_ref'].path, namespaces=xml.namespaces) - ifs = [] - for if_elem in if_elems: - if_ref = Interface(if_elem.attrib, if_elem) - ifs.append(if_ref) - if len(ifs) > 1: - link['interface1'] = ifs[0] - link['interface2'] = ifs[1] + # get interfaces + interfaces = Element.get(Interface, link_elem, './default:interface_ref | ./interface_ref') + if len(interfaces) > 1: + link['interface1'] = interfaces[0] + link['interface2'] = interfaces[1] links.append(link) return links diff --git a/sfa/rspecs/elements/versions/pgv2Node.py b/sfa/rspecs/elements/versions/pgv2Node.py index b293b049..e5ec58b4 100644 --- a/sfa/rspecs/elements/versions/pgv2Node.py +++ b/sfa/rspecs/elements/versions/pgv2Node.py @@ -1,140 +1,120 @@ - -from lxml import etree -from sfa.util.plxrn import PlXrn +from sfa.util.plxrn import PlXrn, xrn_to_hostname from sfa.util.xrn import Xrn +from sfa.util.xml import XpathFilter +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.node import Node from sfa.rspecs.elements.sliver import Sliver -from sfa.rspecs.elements.network import Network from sfa.rspecs.elements.location import Location from sfa.rspecs.elements.hardware_type import HardwareType from sfa.rspecs.elements.disk_image import DiskImage from sfa.rspecs.elements.interface import Interface from sfa.rspecs.elements.bwlimit import BWlimit from sfa.rspecs.elements.pltag import PLTag -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements -from sfa.rspecs.elements.versions.pgv2Service import PGv2Service +from sfa.rspecs.elements.versions.pgv2Services import PGv2Services +from sfa.rspecs.elements.versions.pgv2SliverType import PGv2SliverType class PGv2Node: - elements = { - 'node': RSpecElement(RSpecElements.NODE, '//default:node | //node'), - 'sliver': RSpecElement(RSpecElements.SLIVER, './default:sliver_type | ./sliver_type'), - 'interface': RSpecElement(RSpecElements.INTERFACE, './default:interface | ./interface'), - 'location': RSpecElement(RSpecElements.LOCATION, './default:location | ./location'), - 'hardware_type': RSpecElement(RSpecElements.HARDWARE_TYPE, './default:hardware_type | ./hardware_type'), - 'available': RSpecElement(RSpecElements.AVAILABLE, './default:available | ./available'), - } - @staticmethod def add_nodes(xml, nodes): node_elems = [] for node in nodes: - node_elem = etree.SubElement(xml, 'node') + node_fields = ['component_manager_id', 'component_id', 'client_id', 'sliver_id', 'exclusive'] + elems = Element.add_elements(xml, 'node', node, node_fields) + node_elem = elems[0] node_elems.append(node_elem) - if node.get('component_manager_id'): - node_elem.set('component_manager_id', node['component_manager_id']) + # set component name if node.get('component_id'): - node_elem.set('component_id', node['component_id']) - component_name = Xrn(node['component_id']).get_leaf() - node_elem.set('component_nama', component_name) - if node.get('client_id'): - node_elem.set('client_id', node['client_id']) - if node.get('sliver_id'): - node_elem.set('sliver_id', node['sliver_id']) - if node.get('exclusive'): - node_elem.set('exclusive', node['exclusive']) - hardware_types = node.get('hardware_type', []) - for hardware_type in hardware_types: - hw_type_elem = etree.SubElement(node_elem, 'hardware_type') - if hardware_type.get('name'): - hw_type_elem.set('name', hardware_type['name']) + component_name = xrn_to_hostname(node['component_id']) + node_elem.set('component_name', component_name) + # set hardware types + Element.add_elements(node_elem, 'hardware_type', node.get('hardware_types', []), HardwareType.fields) + # set location + location_elems = Element.add_elements(node_elem, 'location', node.get('location', []), Location.fields) + # set interfaces + interface_elems = Element.add_elements(node_elem, 'interface', node.get('interfaces', []), ['component_id', 'client_id', 'ipv4']) + # set available element if node.get('boot_state', '').lower() == 'boot': - available_elem = etree.SubElement(node_elem, 'available', now='True') + available_elem = node_elem.add_element('available', now='True') else: - available_elem = etree.SubElement(node_elem, 'available', now='False') - - if node.get('services'): - PGv2Services.add_services(node_elem, node.get('services')) - + available_elem = node_elem.add_element('available', now='False') + # add services + PGv2Services.add_services(node_elem, node.get('services', [])) + # add slivers slivers = node.get('slivers', []) - pl_initscripts = node.get('pl_initscripts', {}) - for sliver in slivers: - sliver_elem = etree.SubElement(node_elem, 'sliver_type') - if sliver.get('name'): - sliver_elem.set('name', sliver['name']) - if sliver.get('client_id'): - sliver_elem.set('client_id', sliver['client_id']) - for pl_initscript in pl_initscripts.values(): - etree.SubElement(sliver_elem, '{%s}initscript' % xml.namespaces['planetlab'], \ - name=pl_initscript['name']) - location = node.get('location') - #only add locaiton if long and lat are not null - if location.get('longitute') and location.get('latitude'): - location_elem = etree.SubElement(node_elem, country=location['country'], - latitude=location['latitude'], longitude=location['longiutde']) + if not slivers: + # we must still advertise the available sliver types + slivers = Sliver({'type': 'plab-vserver'}) + # we must also advertise the available initscripts + slivers['tags'] = [] + for initscript in node.get('pl_initscripts', []): + slivers['tags'].append({'name': 'initscript', 'value': initscript['name']}) + PGv2SliverType.add_slivers(node_elem, slivers) + return node_elems @staticmethod - def get_nodes(xml): + def get_nodes(xml, filter={}): + xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + node_elems = xml.xpath(xpath) + return PGv2Node.get_node_objs(node_elems) + + @staticmethod + def get_nodes_with_slivers(xml, filter={}): + xpath = '//node/sliver_type | //default:node/default:sliver_type' + node_elems = xml.xpath(xpath) + return PGv2Node.get_node_objs(node_elems) + + @staticmethod + def get_node_objs(node_elems): nodes = [] - node_elems = xml.xpath(PGv2Node.elements['node'].path) for node_elem in node_elems: node = Node(node_elem.attrib, node_elem) nodes.append(node) if 'component_id' in node_elem.attrib: node['authority_id'] = Xrn(node_elem.attrib['component_id']).get_authority_urn() - # set hardware type - node['hardware_types'] = [] - hardware_type_elems = node_elem.xpath(PGv2Node.elements['hardware_type'].path, xml.namespaces) - for hardware_type_elem in hardware_type_elems: - node['hardware_types'].append(HardwareType(hardware_type_elem.attrib, hardware_type_elem)) - - # set location - location_elems = node_elem.xpath(PGv2Node.elements['location'].path, xml.namespaces) + node['hardware_types'] = Element.get_elements(node_elem, './default:hardwate_type | ./hardware_type', HardwareType) + location_elems = Element.get_elements(node_elem, './default:location | ./location', Location) if len(location_elems) > 0: - node['location'] = Location(location_elems[0].attrib, location_elems[0]) - - # set services - services_elems = node_elem.xpath(PGv2Service.elements['services'].path, xml.namespaces) - node['services'] = [] - for services_elem in services_elems: - # services element has no useful info, but the child elements do - for child in services_elem.iterchildren(): - pass - - # set interfaces - interface_elems = node_elem.xpath(PGv2Node.elements['interface'].path, xml.namespaces) - node['interfaces'] = [] - for interface_elem in interface_elems: - node['interfaces'].append(Interface(interface_elem.attrib, interface_elem)) - - # set available - available = node_elem.xpath(PGv2Node.elements['available'].path, xml.namespaces) - if len(available) > 0: - if available[0].attrib.get('now', '').lower() == 'true': + node['location'] = location_elems[0] + node['interfaces'] = Element.get_elements(node_elem, './default:interface | ./interface', Interface) + node['services'] = PGv2Services.get_services(node_elem) + node['slivers'] = PGv2SliverType.get_slivers(node_elem) + available_elem = Element.get_elements(node_elem, './default:available | ./available', fields=['now']) + if len(available_elem) > 0 and 'name' in available_elem[0]: + if available_elem[0].get('now', '').lower() == 'true': node['boot_state'] = 'boot' else: node['boot_state'] = 'disabled' - - # set the slivers - sliver_elems = node_elem.xpath(PGv2Node.elements['sliver'].path, xml.namespaces) - node['slivers'] = [] - for sliver_elem in sliver_elems: - node['slivers'].append(Sliver(sliver_elem.attrib, sliver_elem)) - return nodes @staticmethod def add_slivers(xml, slivers): - pass - - @staticmethod - def get_nodes_with_slivers(xml): - nodes = PGv2Node.get_nodes(xml) - nodes_with_slivers = [node for node in nodes if node['slivers']] - return nodes_with_slivers + component_ids = [] + for sliver in slivers: + filter = {} + if isinstance(sliver, str): + filter['component_id'] = '*%s*' % sliver + sliver = {} + elif 'component_id' in sliver and sliver['component_id']: + filter['component_id'] = '*%s*' % sliver['component_id'] + if not filter: + continue + nodes = PGv2Node.get_nodes(xml, filter) + if not nodes: + continue + node = nodes[0] + PGv2SliverType.add_slivers(node, sliver) + @staticmethod + def remove_slivers(xml, hostnames): + for hostname in hostnames: + nodes = PGv2Node.get_nodes(xml, {'component_id': '*%s*' % hostname}) + for node in nodes: + slivers = PGv2SliverType.get_slivers(node.element) + for sliver in slivers: + node.element.remove(sliver.element) if __name__ == '__main__': from sfa.rspecs.rspec import RSpec import pdb diff --git a/sfa/rspecs/elements/versions/pgv2Services.py b/sfa/rspecs/elements/versions/pgv2Services.py index fe946b3a..9eb1ed9a 100644 --- a/sfa/rspecs/elements/versions/pgv2Services.py +++ b/sfa/rspecs/elements/versions/pgv2Services.py @@ -1,65 +1,28 @@ -from lxml import etree -from sfa.util.plxrn import PlXrn -from sfa.util.xrn import Xrn +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.execute import Execute from sfa.rspecs.elements.install import Install from sfa.rspecs.elements.login import Login -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements class PGv2Services: - elements = { - 'services': RSpecElement(RSpecElements.SERVICES, '//default:services | //services'), - 'install': RSpecElement(RSpecElements.INSTALL, './default:install | ./install'), - 'execute': RSpecElement(RSpecElements.EXECUTE, './default:execute | ./execute'), - 'login': RSpecElement(RSpecElements.LOGIN, './default:login | ./login'), - } - @staticmethod def add_services(xml, services): - for service in services: - service_elem = etree.SubElement(xml, 'service') - for install in service.get('install', []): - install_elem = etree.SubElement(service_elem, 'install') - for field in Install.fields: - if field in install: - install_elem.set(field, install[field]) - for execute in service.get('execute', []): - execute_elem = etree.SubElement(service_elem, 'execute') - for field in Execute.fields: - if field in execute: - execute_elem.set(field, execute[field]) - for login in service.get('login', []): - login_elem = etree.SubElement(service_elem, 'login') - for field in Login.fields: - if field in login: - login_elem.set(field, login[field]) + if not services: + return + for service in services: + service_elem = xml.add_element('services') + Element.add_elements(service_elem, 'install', service.get('install', []), Install.fields) + Element.add_elements(service_elem, 'execute', service.get('execute', []), Execute.fields) + Element.add_elements(service_elem, 'login', service.get('login', []), Login.fields) @staticmethod def get_services(xml): services = [] - for services_elem in xml.xpath(PGv2Services.elements['services'].path): + for services_elem in xml.xpath('./default:services | ./services'): service = Services(services_elem.attrib, services_elem) - - # get install elements - service['install'] = [] - for install_elem in xml.xpath(PGv2Services.elements['install'].path): - install = Install(install_elem.attrib, install_elem) - service['install'].append(install) - - # get execute elements - service['execute'] = [] - for execute_elem in xml.xpath(PGv2Services.elements['execute'].path): - execute = Execute(execute_elem.attrib, execute_elem) - service['execute'].append(execute) - - # get login elements - service['login'] = [] - for login_elem in xml.xpath(PGv2Services.elements['login'].path): - login = Login(login_elem.attrib, login_elem) - service['login'].append(login) - + service['install'] = Element.get_elements(service_elem, './default:install | ./install', Install) + service['execute'] = Element.get_elements(service_elem, './default:execute | ./execute', Execute) + service['login'] = Element.get_elements(service_elem, './default:login | ./login', Login) services.append(service) - return services diff --git a/sfa/rspecs/elements/versions/pgv2SliverType.py b/sfa/rspecs/elements/versions/pgv2SliverType.py index e69de29b..c0715321 100644 --- a/sfa/rspecs/elements/versions/pgv2SliverType.py +++ b/sfa/rspecs/elements/versions/pgv2SliverType.py @@ -0,0 +1,40 @@ +from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.sliver import Sliver + +class PGv2SliverType: + + @staticmethod + def add_slivers(xml, slivers): + if not slivers: + return + if not isinstance(slivers, list): + slivers = [slivers] + for sliver in slivers: + sliver_elem = Element.add_elements(xml, 'sliver_type', sliver, ['type', 'client_id']) + PGv2SliverType.add_sliver_attributes(sliver_elem, sliver.get('pl_tags', [])) + + @staticmethod + def add_sliver_attributes(xml, attributes): + for attribute in attributes: + if attribute['name'] == 'initscript': + xml.add_element('{%s}initscript' % xml.namespaces['planetlab'], name=attribute['value']) + elif tag['tagname'] == 'flack_info': + attrib_elem = xml.add_element('{%s}info' % self.namespaces['flack']) + attrib_dict = eval(tag['value']) + for (key, value) in attrib_dict.items(): + attrib_elem.set(key, value) + @staticmethod + def get_slivers(xml, filter={}): + xpath = './default:sliver_type | ./sliver_type' + sliver_elems = xml.xpath(xpath) + slivers = [] + for sliver_elem in sliver_elems: + sliver = Sliver(sliver_elem.attrib,sliver_elem) + if 'component_id' in xml.attrib: + sliver['component_id'] = xml.attrib['component_id'] + slivers.append(sliver) + return slivers + + @staticmethod + def get_sliver_attributes(xml, filter={}): + return [] diff --git a/sfa/rspecs/elements/versions/sfav1Network.py b/sfa/rspecs/elements/versions/sfav1Network.py deleted file mode 100644 index b529ad57..00000000 --- a/sfa/rspecs/elements/versions/sfav1Network.py +++ /dev/null @@ -1,32 +0,0 @@ - - -from lxml import etree -from sfa.util.plxrn import PlXrn -from sfa.util.xrn import Xrn -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements - -class SFAv1Network: - elements = { - 'network': RSpecElement(RSpecElements.NETWORK, '//network'), - } - - @staticmethod - def add_network(xml, network): - found = False - network_objs = SFAv1Network.get_networks(xml) - for network_obj in network_objs: - if network_obj['name'] == network['name']: - found = True - network_elem = network_obj.element - if not found: - network_elem = etree.SubElement(xml, 'network', name = network['name']) - return network_elem - - @staticmethod - def get_networks(xml): - networks = [] - network_elems = xml.xpath(SFAv1Network.elements['network'].path) - for network_elem in network_elems: - network = Network({'name': network_elem.attrib.get('name', None)}, network_elem) - networks.append(network) - return networks diff --git a/sfa/rspecs/elements/versions/sfav1Node.py b/sfa/rspecs/elements/versions/sfav1Node.py index 0b67c91d..f62cc718 100644 --- a/sfa/rspecs/elements/versions/sfav1Node.py +++ b/sfa/rspecs/elements/versions/sfav1Node.py @@ -1,142 +1,140 @@ - -from lxml import etree -from sfa.util.plxrn import PlXrn +from sfa.util.sfalogging import logger +from sfa.util.xml import XpathFilter +from sfa.util.plxrn import PlXrn, xrn_to_hostname from sfa.util.xrn import Xrn +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.node import Node from sfa.rspecs.elements.sliver import Sliver -from sfa.rspecs.elements.network import Network from sfa.rspecs.elements.location import Location from sfa.rspecs.elements.hardware_type import HardwareType from sfa.rspecs.elements.disk_image import DiskImage from sfa.rspecs.elements.interface import Interface from sfa.rspecs.elements.bwlimit import BWlimit from sfa.rspecs.elements.pltag import PLTag -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements -from sfa.rspecs.elements.versions.sfav1Network import SFAv1Network +from sfa.rspecs.elements.versions.sfav1Sliver import SFAv1Sliver +from sfa.rspecs.elements.versions.sfav1PLTag import SFAv1PLTag from sfa.rspecs.elements.versions.pgv2Services import PGv2Services class SFAv1Node: - elements = { - 'node': RSpecElement(RSpecElements.NODE, '//default:node | //node'), - 'sliver': RSpecElement(RSpecElements.SLIVER, './default:sliver | ./sliver'), - 'interface': RSpecElement(RSpecElements.INTERFACE, './default:interface | ./interface'), - 'location': RSpecElement(RSpecElements.LOCATION, './default:location | ./location'), - 'bw_limit': RSpecElement(RSpecElements.BWLIMIT, './default:bw_limit | ./bw_limit'), - } - @staticmethod def add_nodes(xml, nodes): - network_elems = SFAv1Network.get_networks(xml) + network_elems = Element.get_elements(xml, '//network', fields=['name']) if len(network_elems) > 0: network_elem = network_elems[0] elif len(nodes) > 0 and nodes[0].get('component_manager_id'): - network_elem = SFAv1Network.add_network(xml.root, {'name': nodes[0]['component_manager_id']}) - + network_urn = nodes[0]['component_manager_id'] + network_elems = Element.add_elements(xml, 'network', {'name': Xrn(network_urn).get_hrn()}) + network_elem = network_elems[0] node_elems = [] for node in nodes: - node_elem = etree.SubElement(network_elem, 'node') + node_fields = ['component_manager_id', 'component_id', 'boot_state'] + elems = Element.add_elements(network_elem, 'node', node, node_fields) + node_elem = elems[0] node_elems.append(node_elem) - network = None + + # determine network hrn + network_hrn = None if 'component_manager_id' in node and node['component_manager_id']: - node_elem.set('component_manager_id', node['component_manager_id']) - network = Xrn(node['component_manager_id']).get_hrn() + network_hrn = Xrn(node['component_manager_id']).get_hrn() + + # set component_name attribute and hostname element if 'component_id' in node and node['component_id']: - node_elem.set('component_id', node['component_id']) - xrn = Xrn(node['component_id']) - node_elem.set('component_name', xrn.get_leaf()) - hostname_tag = etree.SubElement(node_elem, 'hostname').text = xrn.get_leaf() + component_name = xrn_to_hostname(node['component_id']) + node_elem.set('component_name', component_name) + hostname_tag = node_elem.add_element('hostname') + hostname_tag.set_text(component_name) + + # set site id if 'authority_id' in node and node['authority_id']: node_elem.set('site_id', node['authority_id']) - if 'boot_state' in node and node['boot_state']: - node_elem.set('boot_state', node['boot_state']) - if 'location' in node and node['location']: - location_elem = etree.SubElement(node_elem, 'location') - for field in Location.fields: - if field in node['location'] and node['location'][field]: - location_elem.set(field, node['location'][field]) - if 'interfaces' in node and node['interfaces']: - i = 0 - for interface in node['interfaces']: - if 'bwlimit' in interface and interface['bwlimit']: - bwlimit = etree.SubElement(node_elem, 'bw_limit', units='kbps').text = str(interface['bwlimit']/1000) - comp_id = PlXrn(auth=network, interface='node%s:eth%s' % (interface['node_id'], i)).get_urn() - ipaddr = interface['ipv4'] - interface_elem = etree.SubElement(node_elem, 'interface', component_id=comp_id, ipv4=ipaddr) - i+=1 - if 'bw_unallocated' in node and node['bw_unallocated']: - bw_unallocated = etree.SubElement(node_elem, 'bw_unallocated', units='kbps').text = str(int(node['bw_unallocated'])/1000) - - if node.get('services'): - PGv2Services.add_services(node_elem, node.get('services')) - if 'tags' in node: - for tag in node['tags']: - # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category - if tag['name'] in ['fcdistro', 'arch']: - tag_element = etree.SubElement(node_elem, tag['name']).text=tag['value'] + location_elems = Element.add_elements(node_elem, 'location', + node.get('location', []), Location.fields) + interface_elems = Element.add_elements(node_elem, 'interface', + node.get('interfaces', []), ['component_id', 'client_id', 'ipv4']) + + #if 'bw_unallocated' in node and node['bw_unallocated']: + # bw_unallocated = etree.SubElement(node_elem, 'bw_unallocated', units='kbps').text = str(int(node['bw_unallocated'])/1000) - if node.get('slivers'): - for sliver in node['slivers']: - sliver_elem = etree.SubElement(node_elem, 'sliver') - if sliver.get('sliver_id'): - sliver_id_leaf = Xrn(sliver.get('sliver_id')).get_leaf() - sliver_id_parts = sliver_id_leaf.split(':') - name = sliver_id_parts[0] - sliver_elem.set('name', name) + PGv2Services.add_services(node_elem, node.get('services', [])) + SFAv1PLTag.add_pl_tags(node_elem, node.get('tags', [])) + SFAv1Sliver.add_slivers(node_elem, node.get('slivers', [])) @staticmethod def add_slivers(xml, slivers): - pass + component_ids = [] + for sliver in slivers: + filter = {} + if isinstance(sliver, str): + filter['component_id'] = '*%s*' % sliver + sliver = {} + elif 'component_id' in sliver and sliver['component_id']: + filter['component_id'] = '*%s*' % sliver['component_id'] + if not fliter: + continue + nodes = SFAv1Node.get_nodes(xml, filter) + if not nodes: + continue + node = nodes[0] + SFAv1Sliver.add_slivers(node, sliver) @staticmethod - def get_nodes(xml): - nodes = [] - node_elems = xml.xpath(SFAv1Node.elements['node'].path) + def remove_slivers(xml, hostnames): + for hostname in hostnames: + nodes = SFAv1Node.get_nodes(xml, {'component_id': '*%s*' % hostname}) + for node in nodes: + slivers = SFAv1Slivers.get_slivers(node.element) + for sliver in slivers: + node.element.remove(sliver.element) + + @staticmethod + def get_nodes(xml, filter={}): + xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + node_elems = xml.xpath(xpath) + return SFAv1Node.get_node_objs(node_elems) + + # xxx Thierry : an ugly hack to get the tests to pass again + # probably this needs to be trashed + # the original code returned the tag, + # but we prefer the father node instead as it already holds data + # initially this was to preserve the nodename... + # xxx I don't get the ' | //default:node/default:sliver' ... + @staticmethod + def get_nodes_with_slivers_thierry(xml): + # dropping the '' + xpath = '//node[count (sliver)>0]' + node_elems = xml.xpath(xpath) + # we need to check/recompute the node data + + return node_elems + + @staticmethod + def get_nodes_with_slivers(xml): + xpath = '//node/sliver | //default:node/default:sliver' + node_elems = xml.xpath(xpath) + return SFAv1Node.get_node_objs(node_elems) + + + @staticmethod + def get_node_objs(node_elems): + nodes = [] for node_elem in node_elems: node = Node(node_elem.attrib, node_elem) if 'site_id' in node_elem.attrib: node['authority_id'] = node_elem.attrib['site_id'] - if 'authority_id' in node_elem.attrib: - node['authority_id'] = node_elem.attrib['authority_id'] - - # set the location - location_elems = node_elem.xpath(SFAv1Node.elements['location'].path, xml.namespaces) - if len(location_elems) > 0: - node['location'] = Location(location_elems[0].attrib, location_elems[0]) - - # set the bwlimit - bwlimit_elems = node_elem.xpath(SFAv1Node.elements['bw_limit'].path, xml.namespaces) - if len(bwlimit_elems) > 0: - bwlimit = BWlimit(bwlimit_elems[0].attrib, bwlimit_elems[0]) - node['bwlimit'] = bwlimit - - # set the interfaces - interface_elems = node_elem.xpath(SFAv1Node.elements['interface'].path, xml.namespaces) - node['interfaces'] = [] - for interface_elem in interface_elems: - node['interfaces'].append(Interface(interface_elem.attrib, interface_elem)) - - # set the slivers - sliver_elems = node_elem.xpath(SFAv1Node.elements['sliver'].path, xml.namespaces) - node['slivers'] = [] - for sliver_elem in sliver_elems: - node['slivers'].append(Sliver(sliver_elem.attrib, sliver_elem)) - - # set tags - node['tags'] = [] - for child in node_elem.iterchildren(): - if child.tag not in SFAv1Node.elements: - tag = PLTag({'name': child.tag, 'value': child.text}, child) - node['tags'].append(tag) + location_objs = Element.get_elements(node_elem, './default:location | ./location', Location) + if len(location_objs) > 0: + node['location'] = location_objs[0] + bwlimit_objs = Element.get_elements(node_elem, './default:bw_limit | ./bw_limit', BWlimit) + if len(bwlimit_objs) > 0: + node['bwlimit'] = bwlimit_objs[0] + node['interfaces'] = Element.get_elements(node_elem, './default:interface | ./interface', Interface) + node['services'] = PGv2Services.get_services(node_elem) + node['slivers'] = SFAv1Sliver.get_slivers(node_elem) +#thierry node['tags'] = SFAv1PLTag.get_pl_tags(node_elem, ignore=Node.fields.keys()) + node['tags'] = SFAv1PLTag.get_pl_tags(node_elem, ignore=Node.fields) nodes.append(node) - return nodes - - @staticmethod - def get_nodes_with_slivers(xml): - nodes = SFAv1Node.get_nodes(xml) - nodes_with_slivers = [node for node in nodes if node['slivers']] - return nodes_with_slivers - - + return nodes + diff --git a/sfa/rspecs/elements/versions/sfav1PLTag.py b/sfa/rspecs/elements/versions/sfav1PLTag.py new file mode 100644 index 00000000..8e04d511 --- /dev/null +++ b/sfa/rspecs/elements/versions/sfav1PLTag.py @@ -0,0 +1,19 @@ +from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.pltag import PLTag + +class SFAv1PLTag: + @staticmethod + def add_pl_tags(xml, pl_tags): + for pl_tag in pl_tags: + pl_tag_elem = xml.add_element(pl_tag['tagname']) + pl_tag_elem.set_text(pl_tag['value']) + + @staticmethod + def get_pl_tags(xml, ignore=[]): + pl_tags = [] + for elem in xml.iterchildren(): + if elem.tag not in ignore: + pl_tag = PLTag({'tagname': elem.tag, 'value': elem.text}) + pl_tags.append(pl_tag) + return pl_tags + diff --git a/sfa/rspecs/elements/versions/sfav1Sliver.py b/sfa/rspecs/elements/versions/sfav1Sliver.py index f12c9776..d0a7592b 100644 --- a/sfa/rspecs/elements/versions/sfav1Sliver.py +++ b/sfa/rspecs/elements/versions/sfav1Sliver.py @@ -1,18 +1,42 @@ - -from lxml import etree - +from sfa.util.xrn import Xrn +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.sliver import Sliver +from sfa.rspecs.elements.versions.sfav1PLTag import SFAv1PLTag -from sfa.util.xrn import Xrn -from sfa.util.plxrn import PlXrn class SFAv1Sliver: @staticmethod def add_slivers(xml, slivers): + if not slivers: + return + if not isinstance(slivers, list): + slivers = [slivers] for sliver in slivers: - sliver_elem = etree.SubElement(xml, 'sliver') - if sliver.get('component_id'): - name_full = Xrn(sliver.get('component_id')).get_leaf() - name = name_full.split(':') + sliver_elem = Element.add_elements(xml, 'sliver', sliver, ['name'])[0] + SFAv1Sliver.add_sliver_attributes(sliver_elem, sliver.get('tags', [])) + if sliver.get('sliver_id'): + sliver_id_leaf = Xrn(sliver.get('sliver_id')).get_leaf() + sliver_id_parts = sliver_id_leaf.split(':') + name = sliver_id_parts[0] sliver_elem.set('name', name) - + + @staticmethod + def add_sliver_attributes(xml, attributes): + SFAv1PLTag.add_pl_tags(xml, attributes) + + @staticmethod + def get_slivers(xml, filter={}): + xpath = './default:sliver | ./sliver' + sliver_elems = xml.xpath(xpath) + slivers = [] + for sliver_elem in sliver_elems: + sliver = Sliver(sliver_elem.attrib,sliver_elem) + if 'component_id' in xml.attrib: + sliver['component_id'] = xml.attrib['component_id'] + sliver['tags'] = SFAv1Sliver.get_sliver_attributes(sliver_elem) + slivers.append(sliver) + return slivers + + @staticmethod + def get_sliver_attributes(xml, filter={}): + return SFAv1PLTag.get_pl_tags(xml, ignore=Sliver.fields) diff --git a/sfa/rspecs/version_manager.py b/sfa/rspecs/version_manager.py index f53ec6f9..27eba932 100644 --- a/sfa/rspecs/version_manager.py +++ b/sfa/rspecs/version_manager.py @@ -1,6 +1,6 @@ import os from sfa.util.faults import InvalidRSpec -from sfa.rspecs.rspec_version import BaseVersion +from sfa.rspecs.baseversion import BaseVersion from sfa.util.sfalogging import logger class VersionManager: diff --git a/sfa/rspecs/versions/pgv2.py b/sfa/rspecs/versions/pgv2.py index b57cd9bb..ceff971b 100644 --- a/sfa/rspecs/versions/pgv2.py +++ b/sfa/rspecs/versions/pgv2.py @@ -1,11 +1,11 @@ -from lxml import etree from copy import deepcopy from StringIO import StringIO from sfa.util.xrn import urn_to_sliver_id from sfa.util.plxrn import hostname_to_urn, xrn_to_hostname -from sfa.rspecs.rspec_version import BaseVersion -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements +from sfa.rspecs.baseversion import BaseVersion from sfa.rspecs.elements.versions.pgv2Link import PGv2Link +from sfa.rspecs.elements.versions.pgv2Node import PGv2Node +from sfa.rspecs.elements.versions.pgv2SliverType import PGv2SliverType class PGv2(BaseVersion): type = 'ProtoGENI' @@ -18,8 +18,8 @@ class PGv2(BaseVersion): 'planetlab': "http://www.planet-lab.org/resources/sfa/ext/planetlab/1", } namespaces = dict(extensions.items() + [('default', namespace)]) - elements = [] + # Networks def get_network(self): network = None nodes = self.xml.xpath('//default:node[@component_manager_id][1]', namespaces=self.namespaces) @@ -31,45 +31,38 @@ class PGv2(BaseVersion): networks = self.xml.xpath('//default:node[@component_manager_id]/@component_manager_id', namespaces=self.namespaces) return set(networks) - def get_node_element(self, hostname, network=None): - nodes = self.xml.xpath('//default:node[@component_id[contains(., "%s")]] | node[@component_id[contains(., "%s")]]' % (hostname, hostname), namespaces=self.namespaces) - if isinstance(nodes,list) and nodes: - return nodes[0] - else: - return None - - def get_node_elements(self, network=None): - nodes = self.xml.xpath('//default:node | //node', namespaces=self.namespaces) - return nodes - - - def get_nodes(self, network=None): - xpath = '//default:node[@component_name]/@component_id | //node[@component_name]/@component_id' - nodes = self.xml.xpath(xpath, namespaces=self.namespaces) - nodes = [xrn_to_hostname(node) for node in nodes] - return nodes - - def get_nodes_with_slivers(self, network=None): - if network: - nodes = self.xml.xpath('//default:node[@component_manager_id="%s"][sliver_type]/@component_id' % network, namespaces=self.namespaces) - else: - nodes = self.xml.xpath('//default:node[default:sliver_type]/@component_id', namespaces=self.namespaces) - nodes = [xrn_to_hostname(node) for node in nodes] - return nodes - - def get_nodes_without_slivers(self, network=None): - return [] + + # Nodes + + def get_nodes(self, filter=None): + return PGv2Node.get_nodes(self.xml, filter) + + def get_nodes_with_slivers(self): + return PGv2Node.get_nodes_with_slivers(self.xml) + + def add_nodes(self, nodes, check_for_dupes=False): + return PGv2Node.add_nodes(self.xml, nodes) + + def merge_node(self, source_node_tag): + # this is untested + self.xml.root.append(deepcopy(source_node_tag)) + # Slivers + def get_sliver_attributes(self, hostname, network=None): - node = self.get_node_element(hostname, network) - sliver = node.xpath('./default:sliver_type', namespaces=self.namespaces) - if sliver is not None and isinstance(sliver, list): - sliver = sliver[0] - return self.attributes_list(sliver) + nodes = self.get_nodes({'component_id': '*%s*' %hostname}) + attribs = [] + if nodes is not None and isinstance(nodes, list) and len(nodes) > 0: + node = nodes[0] + sliver = node.xpath('./default:sliver_type', namespaces=self.namespaces) + if sliver is not None and isinstance(sliver, list) and len(sliver) > 0: + sliver = sliver[0] + #attribs = self.attributes_list(sliver) + return attribs def get_slice_attributes(self, network=None): slice_attributes = [] - nodes_with_slivers = self.get_nodes_with_slivers(network) + nodes_with_slivers = self.get_nodes_with_slivers() # TODO: default sliver attributes in the PG rspec? default_ns_prefix = self.namespaces['default'] for node in nodes_with_slivers: @@ -92,18 +85,6 @@ class PGv2(BaseVersion): return slice_attributes - def get_links(self, network=None): - return PGv2Link.get_links(self.xml) - - def get_link_requests(self): - return PGv2Link.get_link_requests(self.xml) - - def add_links(self, links): - PGv2Link.add_links(self.xml.root, links) - - def add_link_requests(self, link_tuples, append=False): - PGv2Link.add_link_requests(self.xml.root, link_tuples, append) - def attributes_list(self, elem): opts = [] if elem is not None: @@ -117,115 +98,70 @@ class PGv2(BaseVersion): def add_default_sliver_attribute(self, name, value, network=None): pass - def add_nodes(self, nodes, check_for_dupes=False): - if not isinstance(nodes, list): - nodes = [nodes] - for node in nodes: - urn = "" - if check_for_dupes and \ - self.xml.xpath('//default:node[@component_uuid="%s"]' % urn, namespaces=self.namespaces): - # node already exists + def add_slivers(self, hostnames, attributes=[], sliver_urn=None, append=False): + # all nodes hould already be present in the rspec. Remove all + # nodes that done have slivers + for hostname in hostnames: + node_elems = self.get_nodes({'component_id': '*%s*' % hostname}) + if not node_elems: continue + node_elem = node_elems[0] + + # determine sliver types for this node + valid_sliver_types = ['emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode'] + requested_sliver_type = None + for sliver_type in node_elem.get('slivers', []): + if sliver_type.get('type') in valid_sliver_types: + requested_sliver_type = sliver_type['type'] + + if not requested_sliver_type: + continue + sliver = {'name': requested_sliver_type, + 'pl_tags': attributes} + + # remove existing sliver_type tags + for sliver_type in node_elem.get('slivers', []): + node_elem.element.remove(sliver_type.element) + + # set the client id + node_elem.element.set('client_id', hostname) + if sliver_urn: + pass + # TODO + # set the sliver id + #slice_id = sliver_info.get('slice_id', -1) + #node_id = sliver_info.get('node_id', -1) + #sliver_id = urn_to_sliver_id(sliver_urn, slice_id, node_id) + #node_elem.set('sliver_id', sliver_id) + + # add the sliver type elemnt + PGv2SliverType.add_slivers(node_elem.element, sliver) + + # remove all nodes without slivers + if not append: + for node_elem in self.get_nodes(): + if not node_elem['client_id']: + parent = node_elem.element.getparent() + parent.remove(node_elem.element) - node_tag = etree.SubElement(self.xml.root, 'node', exclusive='false') - if 'network_urn' in node: - node_tag.set('component_manager_id', node['network_urn']) - if 'urn' in node: - node_tag.set('component_id', node['urn']) - if 'hostname' in node: - node_tag.set('component_name', node['hostname']) - # TODO: should replace plab-pc with pc model - node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='plab-pc') - node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='pc') - available_tag = etree.SubElement(node_tag, 'available', now='true') - sliver_type_tag = etree.SubElement(node_tag, 'sliver_type', name='plab-vserver') - - pl_initscripts = node.get('pl_initscripts', {}) - for pl_initscript in pl_initscripts.values(): - etree.SubElement(sliver_type_tag, '{%s}initscript' % self.namespaces['planetlab'], name=pl_initscript['name']) - - # protogeni uses the tag to identify the types of - # vms available at the node. - # only add location tag if longitude and latitude are not null - if 'site' in node: - longitude = node['site'].get('longitude', None) - latitude = node['site'].get('latitude', None) - if longitude and latitude: - location_tag = etree.SubElement(node_tag, 'location', country="us", \ - longitude=str(longitude), latitude=str(latitude)) - - def merge_node(self, source_node_tag): - # this is untested - self.xml.root.append(deepcopy(source_node_tag)) + def remove_slivers(self, slivers, network=None, no_dupes=False): + PGv2Node.remove_slivers(self.xml, slivers) - def add_slivers(self, slivers, sliver_urn=None, no_dupes=False, append=False): + # Links - # all nodes hould already be present in the rspec. Remove all - # nodes that done have slivers - slivers_dict = {} - for sliver in slivers: - if isinstance(sliver, basestring): - slivers_dict[sliver] = {'hostname': sliver} - elif isinstance(sliver, dict): - slivers_dict[sliver['hostname']] = sliver - - nodes = self.get_node_elements() - for node in nodes: - urn = node.get('component_id') - hostname = xrn_to_hostname(urn) - if hostname not in slivers_dict and not append: - parent = node.getparent() - parent.remove(node) - else: - sliver_info = slivers_dict[hostname] - sliver_type_elements = node.xpath('./default:sliver_type', namespaces=self.namespaces) - available_sliver_types = [element.attrib['name'] for element in sliver_type_elements] - valid_sliver_types = ['emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode'] - requested_sliver_type = None - for valid_sliver_type in valid_sliver_types: - if valid_sliver_type in available_sliver_types: - requested_sliver_type = valid_sliver_type - if requested_sliver_type: - # remove existing sliver_type tags,it needs to be recreated - sliver_elem = node.xpath('./default:sliver_type | ./sliver_type', namespaces=self.namespaces) - if sliver_elem and isinstance(sliver_elem, list): - sliver_elem = sliver_elem[0] - node.remove(sliver_elem) - # set the client id - node.set('client_id', hostname) - if sliver_urn: - # set the sliver id - slice_id = sliver_info.get('slice_id', -1) - node_id = sliver_info.get('node_id', -1) - sliver_id = urn_to_sliver_id(sliver_urn, slice_id, node_id) - node.set('sliver_id', sliver_id) - - # add the sliver element - sliver_elem = etree.SubElement(node, 'sliver_type', name=requested_sliver_type) - for tag in sliver_info.get('tags', []): - if tag['tagname'] == 'flack_info': - e = etree.SubElement(sliver_elem, '{%s}info' % self.namespaces['flack'], attrib=eval(tag['value'])) - elif tag['tagname'] == 'initscript': - e = etree.SubElement(sliver_elem, '{%s}initscript' % self.namespaces['planetlab'], attrib={'name': tag['value']}) - else: - # node isn't usable. just remove it from the request - parent = node.getparent() - parent.remove(node) + def get_links(self, network=None): + return PGv2Link.get_links(self.xml) - + def get_link_requests(self): + return PGv2Link.get_link_requests(self.xml) - def remove_slivers(self, slivers, network=None, no_dupes=False): - for sliver in slivers: - node_elem = self.get_node_element(sliver['hostname']) - sliver_elem = node_elem.xpath('./default:sliver_type', self.namespaces) - if sliver_elem != None and sliver_elem != []: - node_elem.remove(sliver_elem[0]) + def add_links(self, links): + PGv2Link.add_links(self.xml.root, links) - def add_default_sliver_attribute(self, name, value, network=None): - pass + def add_link_requests(self, link_tuples, append=False): + PGv2Link.add_link_requests(self.xml.root, link_tuples, append) - def add_interfaces(self, interfaces, no_dupes=False): - pass + # Utility def merge(self, in_rspec): """ @@ -235,9 +171,9 @@ class PGv2(BaseVersion): # just copy over all the child elements under the root element if isinstance(in_rspec, RSpec): in_rspec = in_rspec.toxml() - tree = etree.parse(StringIO(in_rspec)) - root = tree.getroot() - for child in root.getchildren(): + + rspec = RSpec(in_rspec) + for child in rspec.xml.iterchildren(): self.xml.root.append(child) def cleanup(self): diff --git a/sfa/rspecs/versions/sfav1.py b/sfa/rspecs/versions/sfav1.py index 3917b39b..85aa86e6 100644 --- a/sfa/rspecs/versions/sfav1.py +++ b/sfa/rspecs/versions/sfav1.py @@ -1,10 +1,14 @@ from copy import deepcopy from lxml import etree + +from sfa.util.sfalogging import logger from sfa.util.xrn import hrn_to_urn, urn_to_hrn from sfa.util.plxrn import PlXrn -from sfa.rspecs.rspec_version import BaseVersion -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements +from sfa.rspecs.baseversion import BaseVersion +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.versions.pgv2Link import PGv2Link +from sfa.rspecs.elements.versions.sfav1Node import SFAv1Node +from sfa.rspecs.elements.versions.sfav1Sliver import SFAv1Sliver class SFAv1(BaseVersion): enabled = True @@ -15,60 +19,45 @@ class SFAv1(BaseVersion): namespace = None extensions = {} namespaces = None - elements = [] template = '' % type - def get_network_elements(self): - return self.xml.xpath('//network') - + # Network def get_networks(self): - return self.xml.xpath('//network[@name]/@name') + return Element.get_elements(self.xml, '//network', Element) - def get_node_element(self, hostname, network=None): - if network: - names = self.xml.xpath('//network[@name="%s"]//node/hostname' % network) + def add_network(self, network): + network_tags = self.xml.xpath('//network[@name="%s"]' % network) + if not network_tags: + network_tag = etree.SubElement(self.xml.root, 'network', name=network) else: - names = self.xml.xpath('//node/hostname') - for name in names: - if str(name.text).strip() == hostname: - return name.getparent() - return None + network_tag = network_tags[0] + return network_tag - def get_node_elements(self, network=None): - if network: - return self.xml.xpath('//network[@name="%s"]//node' % network) - else: - return self.xml.xpath('//node') - def get_nodes(self, network=None): - if network == None: - nodes = self.xml.xpath('//node/hostname/text()') - else: - nodes = self.xml.xpath('//network[@name="%s"]//node/hostname/text()' % network) + # Nodes + + def get_nodes(self, filter=None): + return SFAv1Node.get_nodes(self.xml, filter) - nodes = [node.strip() for node in nodes] - return nodes + def get_nodes_with_slivers(self): + return SFAv1Node.get_nodes_with_slivers(self.xml) - def get_nodes_with_slivers(self, network = None): - if network: - nodes = self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network) - else: - nodes = self.xml.xpath('//node[sliver]/hostname/text()') + def add_nodes(self, nodes, network = None, no_dupes=False): + SFAv1Node.add_nodes(self.xml, nodes) - nodes = [node.strip() for node in nodes] - return nodes + def merge_node(self, source_node_tag, network, no_dupes=False): + if no_dupes and self.get_node_element(node['hostname']): + # node already exists + return - def get_nodes_without_slivers(self, network=None): - xpath_nodes_without_slivers = '//node[not(sliver)]/hostname/text()' - xpath_nodes_without_slivers_in_network = '//network[@name="%s"]//node[not(sliver)]/hostname/text()' - if network: - return self.xml.xpath('//network[@name="%s"]//node[not(sliver)]/hostname/text()' % network) - else: - return self.xml.xpath('//node[not(sliver)]/hostname/text()') + network_tag = self.add_network(network) + network_tag.append(deepcopy(source_node_tag)) + # Slivers + def attributes_list(self, elem): # convert a list of attribute tags into list of tuples - # (tagnme, text_value) + # (tagname, text_value) opts = [] if elem is not None: for e in elem: @@ -80,216 +69,65 @@ class SFAv1(BaseVersion): defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network) else: defaults = self.xml.xpath("//sliver_defaults") - if isinstance(defaults, list) and defaults: - defaults = defaults[0] - return self.attributes_list(defaults) + if not defaults: return [] + return self.attributes_list_thierry(defaults) def get_sliver_attributes(self, hostname, network=None): - attributes = [] - node = self.get_node_element(hostname, network) - #sliver = node.find("sliver") - slivers = node.xpath('./sliver') - if isinstance(slivers, list) and slivers: - attributes = self.attributes_list(slivers[0]) - return attributes + nodes = self.get_nodes({'component_id': '*%s*' %hostname}) + attribs = [] + if nodes is not None and isinstance(nodes, list) and len(nodes) > 0: + node = nodes[0] + sliver = node.xpath('./default:sliver', namespaces=self.namespaces) + if sliver is not None and isinstance(sliver, list) and len(sliver) > 0: + sliver = sliver[0] + #attribs = self.attributes_list(sliver) + return attribs def get_slice_attributes(self, network=None): slice_attributes = [] - nodes_with_slivers = self.get_nodes_with_slivers(network) + nodes_with_slivers = self.get_nodes_with_slivers() for default_attribute in self.get_default_sliver_attributes(network): - attribute = {'name': str(default_attribute[0]), 'value': str(default_attribute[1]), 'node_id': None} + attribute = {'name': str(default_attribute[0]), + 'value': str(default_attribute[1]), + 'node_id': None} slice_attributes.append(attribute) for node in nodes_with_slivers: - sliver_attributes = self.get_sliver_attributes(node, network) + nodename=node.get('component_name') + sliver_attributes = self.get_sliver_attributes(nodename, network) for sliver_attribute in sliver_attributes: attribute = {'name': str(sliver_attribute[0]), 'value': str(sliver_attribute[1]), 'node_id': node} slice_attributes.append(attribute) return slice_attributes - def get_site_nodes(self, siteid, network=None): - if network: - nodes = self.xml.xpath('//network[@name="%s"]/site[@id="%s"]/node/hostname/text()' % \ - (network, siteid)) - else: - nodes = self.xml.xpath('//site[@id="%s"]/node/hostname/text()' % siteid) - return nodes - - def get_links(self, network=None): - return PGv2Link.get_links(self.xml) - - def get_link_requests(self): - return PGv2Link.get_link_requests(self.xml) - - def get_link(self, fromnode, tonode, network=None): - fromsite = fromnode.getparent() - tosite = tonode.getparent() - fromid = fromsite.get("id") - toid = tosite.get("id") - if network: - query = "//network[@name='%s']" % network + "/link[@endpoints = '%s %s']" - else: - query = "//link[@endpoints = '%s %s']" - - results = self.rspec.xpath(query % (fromid, toid)) - if not results: - results = self.rspec.xpath(query % (toid, fromid)) - return results - - def query_links(self, fromnode, tonode, network=None): - return get_link(fromnode, tonode, network) - - def get_vlinks(self, network=None): - vlinklist = [] - if network: - vlinks = self.xml.xpath("//network[@name='%s']//vlink" % network) - else: - vlinks = self.xml.xpath("//vlink") - for vlink in vlinks: - endpoints = vlink.get("endpoints") - (end1, end2) = endpoints.split() - if network: - node1 = self.xml.xpath('//network[@name="%s"]//node[@id="%s"]/hostname/text()' % \ - (network, end1))[0] - node2 = self.xml.xpath('//network[@name="%s"]//node[@id="%s"]/hostname/text()' % \ - (network, end2))[0] - else: - node1 = self.xml.xpath('//node[@id="%s"]/hostname/text()' % end1)[0] - node2 = self.xml.xpath('//node[@id="%s"]/hostname/text()' % end2)[0] - desc = "%s <--> %s" % (node1, node2) - kbps = vlink.find("kbps") - vlinklist.append((endpoints, desc, kbps.text)) - return vlinklist - - def get_vlink(self, endponts, network=None): - if network: - query = "//network[@name='%s']//vlink[@endpoints = '%s']" % (network, endpoints) - else: - query = "//vlink[@endpoints = '%s']" % (network, endpoints) - results = self.rspec.xpath(query) - return results - - def query_vlinks(self, endpoints, network=None): - return get_vlink(endpoints,network) - - - ################## - # Builder - ################## - - def add_network(self, network): - network_tags = self.xml.xpath('//network[@name="%s"]' % network) - if not network_tags: - network_tag = etree.SubElement(self.xml.root, 'network', name=network) - else: - network_tag = network_tags[0] - return network_tag - - def add_nodes(self, nodes, network = None, no_dupes=False): - if not isinstance(nodes, list): - nodes = [nodes] - for node in nodes: - if no_dupes and \ - self.get_node_element(node['hostname']): - # node already exists - continue - - network_tag = self.xml.root - if 'network' in node: - network = node['network'] - network_tag = self.add_network(network) - - node_tag = etree.SubElement(network_tag, 'node') - if 'network' in node: - node_tag.set('component_manager_id', hrn_to_urn(network, 'authority+sa')) - if 'urn' in node: - node_tag.set('component_id', node['urn']) - if 'site_urn' in node: - node_tag.set('site_id', node['site_urn']) - if 'node_id' in node: - node_tag.set('node_id', 'n'+str(node['node_id'])) - if 'boot_state' in node: - node_tag.set('boot_state', node['boot_state']) - if 'hostname' in node: - node_tag.set('component_name', node['hostname']) - hostname_tag = etree.SubElement(node_tag, 'hostname').text = node['hostname'] - if 'interfaces' in node: - i = 0 - for interface in node['interfaces']: - if 'bwlimit' in interface and interface['bwlimit']: - bwlimit = etree.SubElement(node_tag, 'bw_limit', units='kbps').text = str(interface['bwlimit']/1000) - comp_id = PlXrn(auth=network, interface='node%s:eth%s' % (node['node_id'], i)).get_urn() - ipaddr = interface['ip'] - interface_tag = etree.SubElement(node_tag, 'interface', component_id=comp_id, ipv4=ipaddr) - i+=1 - if 'bw_unallocated' in node: - bw_unallocated = etree.SubElement(node_tag, 'bw_unallocated', units='kbps').text = str(node['bw_unallocated']/1000) - if 'tags' in node: - for tag in node['tags']: - # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category - if tag['tagname'] in ['fcdistro', 'arch'] or 'sfa' in tag['category'].split('/'): - tag_element = etree.SubElement(node_tag, tag['tagname']).text=tag['value'] - - if 'site' in node: - longitude = str(node['site']['longitude']) - latitude = str(node['site']['latitude']) - location = etree.SubElement(node_tag, 'location', country='unknown', \ - longitude=longitude, latitude=latitude) - - def merge_node(self, source_node_tag, network, no_dupes=False): - if no_dupes and self.get_node_element(node['hostname']): - # node already exists - return - - network_tag = self.add_network(network) - network_tag.append(deepcopy(source_node_tag)) - - def add_interfaces(self, interfaces): - pass - - def add_links(self, links): - networks = self.get_network_elements() - if len(networks) > 0: - xml = networks[0] - else: - xml = self.xml - PGv2Link.add_links(xml, links) - - def add_link_requests(self, links): - PGv2Link.add_link_requests(self.xml, links) - - def add_slivers(self, slivers, network=None, sliver_urn=None, no_dupes=False, append=False): + def add_slivers(self, hostnames, attributes=[], sliver_urn=None, append=False): # add slice name to network tag network_tags = self.xml.xpath('//network') if network_tags: network_tag = network_tags[0] network_tag.set('slice', urn_to_hrn(sliver_urn)[0]) - all_nodes = self.get_nodes() - nodes_with_slivers = [sliver['hostname'] for sliver in slivers] - nodes_without_slivers = set(all_nodes).difference(nodes_with_slivers) - # add slivers - for sliver in slivers: - node_elem = self.get_node_element(sliver['hostname'], network) - if not node_elem: continue - sliver_elem = etree.SubElement(node_elem, 'sliver') - if 'tags' in sliver: - for tag in sliver['tags']: - etree.SubElement(sliver_elem, tag['tagname']).text = value=tag['value'] - + sliver = {'name':sliver_urn, + 'pl_tags': attributes} + for hostname in hostnames: + if sliver_urn: + sliver['name'] = sliver_urn + node_elems = self.get_nodes({'component_id': '*%s*' % hostname}) + if not node_elems: + continue + node_elem = node_elems[0] + SFAv1Sliver.add_slivers(node_elem.element, sliver) + # remove all nodes without slivers if not append: - for node in nodes_without_slivers: - node_elem = self.get_node_element(node) - parent = node_elem.getparent() - parent.remove(node_elem) + for node_elem in self.get_nodes(): + if not node_elem['slivers']: + parent = node_elem.element.getparent() + parent.remove(node_elem.element) + def remove_slivers(self, slivers, network=None, no_dupes=False): - for sliver in slivers: - node_elem = self.get_node_element(sliver['hostname'], network) - sliver_elem = node_elem.find('sliver') - if sliver_elem != None: - node_elem.remove(sliver_elem) + SFAv1Node.remove_slivers(self.xml, slivers) def add_default_sliver_attribute(self, name, value, network=None): if network: @@ -322,24 +160,26 @@ class SFAv1(BaseVersion): sliver = node.find("sliver") self.xml.remove_attribute(sliver, name, value) - def add_vlink(self, fromhost, tohost, kbps, network=None): - fromnode = self.get_node_element(fromhost, network) - tonode = self.get_node_element(tohost, network) - links = self.get_link(fromnode, tonode, network) + # Links - for link in links: - vlink = etree.SubElement(link, "vlink") - fromid = fromnode.get("id") - toid = tonode.get("id") - vlink.set("endpoints", "%s %s" % (fromid, toid)) - self.xml.add_attribute(vlink, "kbps", kbps) + def get_links(self, network=None): + return PGv2Link.get_links(self.xml) + def get_link_requests(self): + return PGv2Link.get_link_requests(self.xml) - def remove_vlink(self, endpoints, network=None): - vlinks = self.query_vlinks(endpoints, network) - for vlink in vlinks: - vlink.getparent().remove(vlink) + def add_links(self, links): + networks = self.get_networks() + if len(networks) > 0: + xml = networks[0] + else: + xml = self.xml + PGv2Link.add_links(xml, links) + + def add_link_requests(self, links): + PGv2Link.add_link_requests(self.xml, links) + # utility def merge(self, in_rspec): """ @@ -358,11 +198,11 @@ class SFAv1(BaseVersion): # just copy over all networks current_networks = self.get_networks() - networks = rspec.version.get_network_elements() + networks = rspec.version.get_networks() for network in networks: current_network = network.get('name') if current_network and current_network not in current_networks: - self.xml.root.append(network) + self.xml.append(network.element) current_networks.append(current_network) if __name__ == '__main__': diff --git a/sfa/util/sfatime.py b/sfa/util/sfatime.py index 11cc566b..c5c6a557 100644 --- a/sfa/util/sfatime.py +++ b/sfa/util/sfatime.py @@ -1,6 +1,7 @@ from types import StringTypes import dateutil.parser import datetime +import time from sfa.util.sfalogging import logger @@ -24,3 +25,5 @@ For safety this can also handle inputs that are either timestamps, or datetimes else: logger.error("Unexpected type in utcparse [%s]"%type(input)) +def epochparse(input): + return time.strftime("%Y-%d-%m-T%H:%M:%SZ", time.localtime(input)) diff --git a/sfa/util/xml.py b/sfa/util/xml.py index b2aea13b..bb298a3f 100755 --- a/sfa/util/xml.py +++ b/sfa/util/xml.py @@ -40,19 +40,31 @@ class XpathFilter: class XmlNode: def __init__(self, node, namespaces): self.node = node + self.text = node.text self.namespaces = namespaces self.attrib = node.attrib + def xpath(self, xpath, namespaces=None): if not namespaces: namespaces = self.namespaces - return self.node.xpath(xpath, namespaces=namespaces) + elems = self.node.xpath(xpath, namespaces=namespaces) + return [XmlNode(elem, namespaces) for elem in elems] - def add_element(name, *args, **kwds): - element = etree.SubElement(name, args, kwds) + def add_element(self, tagname, **kwds): + element = etree.SubElement(self.node, tagname, **kwds) return XmlNode(element, self.namespaces) - def remove_elements(name): + def append(self, elem): + if isinstance(elem, XmlNode): + self.node.append(elem.node) + else: + self.node.append(elem) + + def getparent(self): + return XmlNode(self.node.getparent(), self.namespaces) + + def remove_elements(self, name): """ Removes all occurences of an element from the tree. Start at specified root_node if specified, otherwise start at tree's root. @@ -65,6 +77,17 @@ class XmlNode: parent = element.getparent() parent.remove(element) + def remove(self, element): + if isinstance(element, XmlNode): + self.node.remove(element.node) + else: + self.node.remove(element) + + def get(self, key, *args): + return self.node.get(key, *args) + + def items(self): return self.node.items() + def set(self, key, value): self.node.set(key, value) @@ -73,7 +96,10 @@ class XmlNode: def unset(self, key): del self.node.attrib[key] - + + def iterchildren(self): + return self.node.iterchildren() + def toxml(self): return etree.tostring(self.node, encoding='UTF-8', pretty_print=True) @@ -91,7 +117,7 @@ class XML: self.parse_xml(xml) if isinstance(xml, XmlNode): self.root = xml - self.namespces = xml.namespaces + self.namespaces = xml.namespaces elif isinstance(xml, etree._ElementTree) or isinstance(xml, etree._Element): self.parse_xml(etree.tostring(xml)) @@ -194,22 +220,15 @@ class XML: node = self.root node.remove_attribute(name) - - def add_element(self, name, attrs={}, parent=None, text=""): + def add_element(self, name, **kwds): """ Wrapper around etree.SubElement(). Adds an element to specified parent node. Adds element to root node is parent is not specified. """ - if parent == None: - parent = self.root - element = etree.SubElement(parent, name) - if text: - element.text = text - if isinstance(attrs, dict): - for attr in attrs: - element.set(attr, attrs[attr]) - return XmlNode(element, self.namespaces) + parent = self.root + xmlnode = parent.add_element(name, *kwds) + return xmlnode def remove_elements(self, name, node = None): """ @@ -251,6 +270,12 @@ class XML: attrs['child_nodes'] = list(elem) return attrs + def append(self, elem): + return self.root.append(elem) + + def iterchildren(self): + return self.root.iterchildren() + def merge(self, in_xml): pass @@ -258,7 +283,7 @@ class XML: return self.toxml() def toxml(self): - return etree.tostring(self.root, encoding='UTF-8', pretty_print=True) + return etree.tostring(self.root.node, encoding='UTF-8', pretty_print=True) # XXX smbaker, for record.load_from_string def todict(self, elem=None):