fix bug in GetVersion(). Use sfa.plc.api.get_server() to establish a connection the...
[sfa.git] / sfa / managers / aggregate_manager_eucalyptus.py
index 26e5742..ea8f2af 100644 (file)
@@ -1,7 +1,7 @@
 from __future__ import with_statement 
 
 import sys
 from __future__ import with_statement 
 
 import sys
-import os
+import os, errno
 import logging
 import datetime
 
 import logging
 import datetime
 
@@ -14,15 +14,20 @@ from lxml import etree as ET
 from sqlobject import *
 
 from sfa.util.faults import *
 from sqlobject import *
 
 from sfa.util.faults import *
-from sfa.util.xrn import urn_to_hrn
+from sfa.util.xrn import urn_to_hrn, Xrn
 from sfa.util.rspec import RSpec
 from sfa.server.registry import Registries
 from sfa.trust.credential import Credential
 from sfa.plc.api import SfaAPI
 from sfa.util.rspec import RSpec
 from sfa.server.registry import Registries
 from sfa.trust.credential import Credential
 from sfa.plc.api import SfaAPI
+from sfa.plc.aggregate import Aggregate
+from sfa.plc.slices import *
 from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn
 from sfa.util.callids import Callids
 from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn
 from sfa.util.callids import Callids
+from sfa.util.sfalogging import logger
+from sfa.rspecs.sfa_rspec import sfa_rspec_version
+from sfa.util.version import version_core
 
 
-from threading import Thread
+from multiprocessing import Process
 from time import sleep
 
 ##
 from time import sleep
 
 ##
@@ -113,6 +118,7 @@ def init_server():
     fileHandler = logging.FileHandler('/var/log/euca.log')
     fileHandler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
     logger.addHandler(fileHandler)
     fileHandler = logging.FileHandler('/var/log/euca.log')
     fileHandler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
     logger.addHandler(fileHandler)
+    fileHandler.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
 
     configParser = ConfigParser()
     logger.setLevel(logging.DEBUG)
 
     configParser = ConfigParser()
@@ -156,11 +162,11 @@ def init_server():
     sqlhub.processConnection = conn
     Slice.createTable(ifNotExists=True)
     EucaInstance.createTable(ifNotExists=True)
     sqlhub.processConnection = conn
     Slice.createTable(ifNotExists=True)
     EucaInstance.createTable(ifNotExists=True)
-    IP.createTable(ifNotExists=True)
+    Meta.createTable(ifNotExists=True)
 
 
-    # Start the update thread to keep track of the meta data
+    # Start the update process to keep track of the meta data
     # about Eucalyptus instance.
     # about Eucalyptus instance.
-    Thread(target=updateMeta).start()
+    Process(target=updateMeta).start()
 
     # Make sure the schema exists.
     if not os.path.exists(EUCALYPTUS_RSPEC_SCHEMA):
 
     # Make sure the schema exists.
     if not os.path.exists(EUCALYPTUS_RSPEC_SCHEMA):
@@ -215,37 +221,31 @@ def getEucaConnection():
 # @param sliceHRN The hunman readable name of the slice.
 # @return sting()
 #
 # @param sliceHRN The hunman readable name of the slice.
 # @return sting()
 #
-def getKeysForSlice(sliceHRN):
-    logger = logging.getLogger('EucaAggregate')
-    try:
-        # convert hrn to slice name
-        plSliceName = hrn_to_pl_slicename(sliceHRN)
-    except IndexError, e:
-        logger.error('Invalid slice name (%s)' % sliceHRN)
-        return []
-
-    # Get the slice's information
-    sliceData = api.plshell.GetSlices(api.plauth, {'name':plSliceName})
-    if not sliceData:
-        logger.warn('Cannot get any data for slice %s' % plSliceName)
+# This method is no longer needed because the user keys are passed into
+# CreateSliver
+#
+def getKeysForSlice(api, sliceHRN):
+    logger   = logging.getLogger('EucaAggregate')
+    cred     = api.getCredential()
+    registry = api.registries[api.hrn]
+    keys     = []
+
+    # Get the slice record
+    records = registry.Resolve(sliceHRN, cred)
+    if not records:
+        logging.warn('Cannot find any record for slice %s' % sliceHRN)
         return []
 
         return []
 
-    # It should only return a list with len = 1
-    sliceData = sliceData[0]
+    # Find who can log into this slice
+    persons = records[0]['persons']
 
 
-    keys = []
-    person_ids = sliceData['person_ids']
-    if not person_ids: 
-        logger.warn('No users in slice %s' % sliceHRN)
-        return []
+    # Extract the keys from persons records
+    for p in persons:
+        sliceUser = registry.Resolve(p, cred)
+        userKeys = sliceUser[0]['keys']
+        keys += userKeys
 
 
-    persons = api.plshell.GetPersons(api.plauth, person_ids)
-    for person in persons:
-        pkeys = api.plshell.GetKeys(api.plauth, person['key_ids'])
-        for key in pkeys:
-            keys.append(key['key'])
-    return ''.join(keys)
+    return '\n'.join(keys)
 
 ##
 # A class that builds the RSpec for Eucalyptus.
 
 ##
 # A class that builds the RSpec for Eucalyptus.
@@ -374,7 +374,7 @@ class EucaRSpecBuilder(object):
         xml = self.eucaRSpec
         cloud = self.cloudInfo
         with xml.RSpec(type='eucalyptus'):
         xml = self.eucaRSpec
         cloud = self.cloudInfo
         with xml.RSpec(type='eucalyptus'):
-            with xml.cloud(id=cloud['name']):
+            with xml.network(name=cloud['name']):
                 with xml.ipv4:
                     xml << cloud['ip']
                 #self.__keyPairsXML(cloud['keypairs'])
                 with xml.ipv4:
                     xml << cloud['ip']
                 #self.__keyPairsXML(cloud['keypairs'])
@@ -520,12 +520,21 @@ def ListResources(api, creds, options, call_id):
 """
 Hook called via 'sfi.py create'
 """
 """
 Hook called via 'sfi.py create'
 """
-def CreateSliver(api, xrn, creds, xml, users, call_id):
+def CreateSliver(api, slice_xrn, creds, xml, users, call_id):
     if Callids().already_handled(call_id): return ""
 
     global cloud
     if Callids().already_handled(call_id): return ""
 
     global cloud
-    hrn = urn_to_hrn(xrn)[0]
     logger = logging.getLogger('EucaAggregate')
     logger = logging.getLogger('EucaAggregate')
+    logger.debug("In CreateSliver")
+
+    aggregate = Aggregate(api)
+    slices = Slices(api)
+    (hrn, type) = urn_to_hrn(slice_xrn)
+    peer = slices.get_peer(hrn)
+    sfa_peer = slices.get_sfa_peer(hrn)
+    slice_record=None
+    if users:
+        slice_record = users[0].get('slice_record', {})
 
     conn = getEucaConnection()
     if not conn:
 
     conn = getEucaConnection()
     if not conn:
@@ -536,12 +545,27 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     schemaXML = ET.parse(EUCALYPTUS_RSPEC_SCHEMA)
     rspecValidator = ET.RelaxNG(schemaXML)
     rspecXML = ET.XML(xml)
     schemaXML = ET.parse(EUCALYPTUS_RSPEC_SCHEMA)
     rspecValidator = ET.RelaxNG(schemaXML)
     rspecXML = ET.XML(xml)
+    for network in rspecXML.iterfind("./network"):
+        if network.get('name') != cloud['name']:
+            # Throw away everything except my own RSpec
+            # sfa_logger().error("CreateSliver: deleting %s from rspec"%network.get('id'))
+            network.getparent().remove(network)
     if not rspecValidator(rspecXML):
         error = rspecValidator.error_log.last_error
         message = '%s (line %s)' % (error.message, error.line) 
     if not rspecValidator(rspecXML):
         error = rspecValidator.error_log.last_error
         message = '%s (line %s)' % (error.message, error.line) 
-        # XXX: InvalidRSpec is new. Currently, I am not working with Trunk code.
-        #raise InvalidRSpec(message)
-        raise Exception(message)
+        raise InvalidRSpec(message)
+
+    """
+    Create the sliver[s] (slice) at this aggregate.
+    Verify HRN and initialize the slice record in PLC if necessary.
+    """
+
+    # ensure site record exists
+    site = slices.verify_site(hrn, slice_record, peer, sfa_peer)
+    # ensure slice record exists
+    slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer)
+    # ensure person records exists
+    persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer)
 
     # Get the slice from db or create one.
     s = Slice.select(Slice.q.slice_hrn == hrn).getOne(None)
 
     # Get the slice from db or create one.
     s = Slice.select(Slice.q.slice_hrn == hrn).getOne(None)
@@ -552,24 +576,30 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     pendingRmInst = []
     for sliceInst in s.instances:
         pendingRmInst.append(sliceInst.instance_id)
     pendingRmInst = []
     for sliceInst in s.instances:
         pendingRmInst.append(sliceInst.instance_id)
-    existingInstGroup = rspecXML.findall('.//euca_instances')
+    existingInstGroup = rspecXML.findall(".//euca_instances")
     for instGroup in existingInstGroup:
         for existingInst in instGroup:
             if existingInst.get('id') in pendingRmInst:
                 pendingRmInst.remove(existingInst.get('id'))
     for inst in pendingRmInst:
     for instGroup in existingInstGroup:
         for existingInst in instGroup:
             if existingInst.get('id') in pendingRmInst:
                 pendingRmInst.remove(existingInst.get('id'))
     for inst in pendingRmInst:
-        logger.debug('Instance %s will be terminated' % inst)
         dbInst = EucaInstance.select(EucaInstance.q.instance_id == inst).getOne(None)
         dbInst = EucaInstance.select(EucaInstance.q.instance_id == inst).getOne(None)
-        # Only change the state but do not remove the entry from the DB.
-        dbInst.meta.state = 'deleted'
-        #dbInst.destroySelf()
-    conn.terminate_instances(pendingRmInst)
+        if dbInst.meta.state != 'deleted':
+            logger.debug('Instance %s will be terminated' % inst)
+            # Terminate instances one at a time for robustness
+            conn.terminate_instances([inst])
+            # Only change the state but do not remove the entry from the DB.
+            dbInst.meta.state = 'deleted'
+            #dbInst.destroySelf()
 
     # Process new instance requests
 
     # Process new instance requests
-    requests = rspecXML.findall('.//request')
+    requests = rspecXML.findall(".//request")
     if requests:
         # Get all the public keys associate with slice.
     if requests:
         # Get all the public keys associate with slice.
-        pubKeys = getKeysForSlice(s.slice_hrn)
+        keys = []
+        for user in users:
+            keys += user['keys']
+            logger.debug("Keys: %s" % user['keys'])
+        pubKeys = '\n'.join(keys)
         logger.debug('Passing the following keys to the instance:\n%s' % pubKeys)
     for req in requests:
         vmTypeElement = req.getparent()
         logger.debug('Passing the following keys to the instance:\n%s' % pubKeys)
     for req in requests:
         vmTypeElement = req.getparent()
@@ -601,12 +631,46 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     return xml
 
 ##
     return xml
 
 ##
-# A thread that will update the meta data.
+# Return information on the IP addresses bound to each slice's instances
+#
+def dumpInstanceInfo():
+    logger = logging.getLogger('EucaMeta')
+    outdir = "/var/www/html/euca/"
+    outfile = outdir + "instances.txt"
+
+    try:
+        os.makedirs(outdir)
+    except OSError, e:
+        if e.errno != errno.EEXIST:
+            raise
+
+    dbResults = Meta.select(
+        AND(Meta.q.pri_addr != None,
+            Meta.q.state    == 'running')
+        )
+    dbResults = list(dbResults)
+    f = open(outfile, "w")
+    for r in dbResults:
+        instId = r.instance.instance_id
+        ipaddr = r.pri_addr
+        hrn = r.instance.slice.slice_hrn
+        logger.debug('[dumpInstanceInfo] %s %s %s' % (instId, ipaddr, hrn))
+        f.write("%s %s %s\n" % (instId, ipaddr, hrn))
+    f.close()
+
+##
+# A separate process that will update the meta data.
 #
 def updateMeta():
 #
 def updateMeta():
-    logger = logging.getLogger('EucaAggregate')
+    logger = logging.getLogger('EucaMeta')
+    fileHandler = logging.FileHandler('/var/log/euca_meta.log')
+    fileHandler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+    logger.addHandler(fileHandler)
+    fileHandler.setLevel(logging.DEBUG)
+    logger.setLevel(logging.DEBUG)
+
     while True:
     while True:
-        sleep(120)
+        sleep(30)
 
         # Get IDs of the instances that don't have IPs yet.
         dbResults = Meta.select(
 
         # Get IDs of the instances that don't have IPs yet.
         dbResults = Meta.select(
@@ -614,11 +678,13 @@ def updateMeta():
                           Meta.q.state    != 'deleted')
                     )
         dbResults = list(dbResults)
                           Meta.q.state    != 'deleted')
                     )
         dbResults = list(dbResults)
-        logger.debug('[update thread] dbResults: %s' % dbResults)
+        logger.debug('[update process] dbResults: %s' % dbResults)
         instids = []
         for r in dbResults:
         instids = []
         for r in dbResults:
+            if not r.instance:
+                continue
             instids.append(r.instance.instance_id)
             instids.append(r.instance.instance_id)
-        logger.debug('[update thread] Instance Id: %s' % ', '.join(instids))
+        logger.debug('[update process] Instance Id: %s' % ', '.join(instids))
 
         # Get instance information from Eucalyptus
         conn = getEucaConnection()
 
         # Get instance information from Eucalyptus
         conn = getEucaConnection()
@@ -630,18 +696,33 @@ def updateMeta():
         # Check the IPs
         instIPs = [ {'id':i.id, 'pri_addr':i.private_dns_name, 'pub_addr':i.public_dns_name}
                     for i in vmInstances if i.private_dns_name != '0.0.0.0' ]
         # Check the IPs
         instIPs = [ {'id':i.id, 'pri_addr':i.private_dns_name, 'pub_addr':i.public_dns_name}
                     for i in vmInstances if i.private_dns_name != '0.0.0.0' ]
-        logger.debug('[update thread] IP dict: %s' % str(instIPs))
+        logger.debug('[update process] IP dict: %s' % str(instIPs))
 
         # Update the local DB
         for ipData in instIPs:
             dbInst = EucaInstance.select(EucaInstance.q.instance_id == ipData['id']).getOne(None)
             if not dbInst:
 
         # Update the local DB
         for ipData in instIPs:
             dbInst = EucaInstance.select(EucaInstance.q.instance_id == ipData['id']).getOne(None)
             if not dbInst:
-                logger.info('[update thread] Could not find %s in DB' % ipData['id'])
+                logger.info('[update process] Could not find %s in DB' % ipData['id'])
                 continue
             dbInst.meta.pri_addr = ipData['pri_addr']
             dbInst.meta.pub_addr = ipData['pub_addr']
             dbInst.meta.state    = 'running'
 
                 continue
             dbInst.meta.pri_addr = ipData['pri_addr']
             dbInst.meta.pub_addr = ipData['pub_addr']
             dbInst.meta.state    = 'running'
 
+        dumpInstanceInfo()
+
+def GetVersion(api):
+    xrn=Xrn(api.hrn)
+    request_rspec_versions = [dict(sfa_rspec_version)]
+    ad_rspec_versions = [dict(sfa_rspec_version)]
+    version_more = {'interface':'aggregate',
+                    'testbed':'myplc',
+                    'hrn':xrn.get_hrn(),
+                    'request_rspec_versions': request_rspec_versions,
+                    'ad_rspec_versions': ad_rspec_versions,
+                    'default_ad_rspec': dict(sfa_rspec_version)
+                    }
+    return version_core(version_more)
+
 def main():
     init_server()
 
 def main():
     init_server()
 
@@ -652,7 +733,11 @@ def main():
 
     #rspec = ListResources('euca', 'planetcloud.pc.test', 'planetcloud.pc.marcoy', 'test_euca')
     #print rspec
 
     #rspec = ListResources('euca', 'planetcloud.pc.test', 'planetcloud.pc.marcoy', 'test_euca')
     #print rspec
-    print getKeysForSlice('gc.gc.test1')
+
+    server_key_file = '/var/lib/sfa/authorities/server.key'
+    server_cert_file = '/var/lib/sfa/authorities/server.cert'
+    api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file, interface='aggregate')
+    print getKeysForSlice(api, 'gc.gc.test1')
 
 if __name__ == "__main__":
     main()
 
 if __name__ == "__main__":
     main()