python: Upgrade daemon module to argparse.
[sliver-openvswitch.git] / debian / ovs-monitor-ipsec
old mode 100755 (executable)
new mode 100644 (file)
index 019da23..0c1d6a8
 #    adding an interface to racoon.conf.
 
 
-import getopt
+import argparse
 import glob
-import logging, logging.handlers
+import logging
+import logging.handlers
 import os
 import socket
 import subprocess
 import sys
 
+import ovs.dirs
 from ovs.db import error
 from ovs.db import types
 import ovs.util
@@ -51,9 +53,12 @@ try:
 except socket.error, e:
     logging.basicConfig()
     s_log.warn("failed to connect to syslog (%s)" % e)
+s_log.addHandler(logging.StreamHandler())
 
+root_prefix = ''                # Prefix for absolute file names, for testing.
 setkey = "/usr/sbin/setkey"
 
+
 # Class to configure the racoon daemon, which handles IKE negotiation
 class Racoon:
     # Default locations for files
@@ -120,11 +125,12 @@ path certificate "%s";
         self.psk_hosts = {}
         self.cert_hosts = {}
 
-        if not os.path.isdir(self.cert_dir):
+        if not os.path.isdir(root_prefix + self.cert_dir):
             os.mkdir(self.cert_dir)
 
         # Clean out stale peer certs from previous runs
-        for ovs_cert in glob.glob("%s/ovs-*.pem" % self.cert_dir):
+        for ovs_cert in glob.glob("%s%s/ovs-*.pem"
+                                  % (root_prefix, self.cert_dir)):
             try:
                 os.remove(ovs_cert)
             except OSError:
@@ -134,20 +140,22 @@ path certificate "%s";
         self.commit()
 
     def reload(self):
-        exitcode = subprocess.call(["/etc/init.d/racoon", "reload"])
+        exitcode = subprocess.call([root_prefix + "/etc/init.d/racoon",
+                                    "reload"])
         if exitcode != 0:
             # Racoon is finicky about its configuration file and will
             # refuse to start if it sees something it doesn't like
             # (e.g., a certificate file doesn't exist).  Try restarting
             # the process before giving up.
             s_log.warning("attempting to restart racoon")
-            exitcode = subprocess.call(["/etc/init.d/racoon", "restart"])
+            exitcode = subprocess.call([root_prefix + "/etc/init.d/racoon",
+                                        "restart"])
             if exitcode != 0:
                 s_log.warning("couldn't reload racoon")
 
     def commit(self):
         # Rewrite the Racoon configuration file
-        conf_file = open(self.conf_file, 'w')
+        conf_file = open(root_prefix + self.conf_file, 'w')
         conf_file.write(Racoon.conf_header % (self.psk_file, self.cert_dir))
 
         for host, vals in self.cert_hosts.iteritems():
@@ -162,7 +170,7 @@ path certificate "%s";
 
         # Rewrite the pre-shared keys file; it must only be readable by root.
         orig_umask = os.umask(0077)
-        psk_file = open(Racoon.psk_file, 'w')
+        psk_file = open(root_prefix + Racoon.psk_file, 'w')
         os.umask(orig_umask)
 
         psk_file.write("# Generated by Open vSwitch...do not modify by hand!")
@@ -183,10 +191,10 @@ path certificate "%s";
     def _verify_certs(self, vals):
         # Racoon will refuse to start if the certificate files don't
         # exist, so verify that they're there.
-        if not os.path.isfile(vals["certificate"]):
+        if not os.path.isfile(root_prefix + vals["certificate"]):
             raise error.Error("'certificate' file does not exist: %s"
                     % vals["certificate"])
-        elif not os.path.isfile(vals["private_key"]):
+        elif not os.path.isfile(root_prefix + vals["private_key"]):
             raise error.Error("'private_key' file does not exist: %s"
                     % vals["private_key"])
 
@@ -196,14 +204,13 @@ path certificate "%s";
         if vals["peer_cert"].find("-----BEGIN CERTIFICATE-----") == -1:
             raise error.Error("'peer_cert' is not in valid PEM format")
 
-        cert = open(vals["certificate"]).read()
+        cert = open(root_prefix + vals["certificate"]).read()
         if cert.find("-----BEGIN CERTIFICATE-----") == -1:
             raise error.Error("'certificate' is not in valid PEM format")
 
-        cert = open(vals["private_key"]).read()
+        cert = open(root_prefix + vals["private_key"]).read()
         if cert.find("-----BEGIN RSA PRIVATE KEY-----") == -1:
             raise error.Error("'private_key' is not in valid PEM format")
-            
 
     def _add_cert(self, host, vals):
         if host in self.psk_hosts:
@@ -212,7 +219,7 @@ path certificate "%s";
         if vals["certificate"] == None:
             raise error.Error("'certificate' not defined for %s" % host)
         elif vals["private_key"] == None:
-            # Assume the private key is stored in the same PEM file as 
+            # Assume the private key is stored in the same PEM file as
             # the certificate.  We make a copy of "vals" so that we don't
             # modify the original "vals", which would cause the script
             # to constantly think that the configuration has changed
@@ -225,7 +232,7 @@ path certificate "%s";
         # The peer's certificate comes to us in PEM format as a string.
         # Write that string to a file for Racoon to use.
         peer_cert_file = "%s/ovs-%s.pem" % (self.cert_dir, host)
-        f = open(peer_cert_file, "w")
+        f = open(root_prefix + peer_cert_file, "w")
         f.write(vals["peer_cert"])
         f.close()
 
@@ -239,7 +246,7 @@ path certificate "%s";
         del self.cert_hosts[host]
         self.commit()
         try:
-            os.remove(peer_cert_file)
+            os.remove(root_prefix + peer_cert_file)
         except OSError:
             pass
 
@@ -269,10 +276,11 @@ class IPsec:
 
     def call_setkey(self, cmds):
         try:
-            p = subprocess.Popen([setkey, "-c"], stdin=subprocess.PIPE, 
-                    stdout=subprocess.PIPE)
+            p = subprocess.Popen([root_prefix + setkey, "-c"],
+                                 stdin=subprocess.PIPE,
+                                 stdout=subprocess.PIPE)
         except:
-            s_log.error("could not call setkey")
+            s_log.error("could not call %s%s" % (root_prefix, setkey))
             sys.exit(1)
 
         # xxx It is safer to pass the string into the communicate()
@@ -288,18 +296,18 @@ class IPsec:
         # older entry could be in a "dying" state.
         spi_list = []
         host_line = "%s %s" % (local_ip, remote_ip)
-        results = self.call_setkey("dump ;").split("\n")
+        results = self.call_setkey("dump ;\n").split("\n")
         for i in range(len(results)):
             if results[i].strip() == host_line:
                 # The SPI is in the line following the host pair
-                spi_line = results[i+1]
+                spi_line = results[i + 1]
                 if (spi_line[1:4] == proto):
                     spi = spi_line.split()[2]
                     spi_list.append(spi.split('(')[1].rstrip(')'))
         return spi_list
 
     def sad_flush(self):
-        self.call_setkey("flush;")
+        self.call_setkey("flush;\n")
 
     def sad_del(self, local_ip, remote_ip):
         # To delete all SAD entries, we should be able to use setkey's
@@ -321,18 +329,18 @@ class IPsec:
             self.call_setkey(cmds)
 
     def spd_flush(self):
-        self.call_setkey("spdflush;")
+        self.call_setkey("spdflush;\n")
 
     def spd_add(self, local_ip, remote_ip):
         cmds = ("spdadd %s %s gre -P out ipsec esp/transport//require;\n" %
                     (local_ip, remote_ip))
-        cmds += ("spdadd %s %s gre -P in ipsec esp/transport//require;" %
+        cmds += ("spdadd %s %s gre -P in ipsec esp/transport//require;\n" %
                     (remote_ip, local_ip))
         self.call_setkey(cmds)
 
     def spd_del(self, local_ip, remote_ip):
         cmds = "spddelete %s %s gre -P out;\n" % (local_ip, remote_ip)
-        cmds += "spddelete %s %s gre -P in;" % (remote_ip, local_ip)
+        cmds += "spddelete %s %s gre -P in;\n" % (remote_ip, local_ip)
         self.call_setkey(cmds)
 
     def add_entry(self, local_ip, remote_ip, vals):
@@ -345,7 +353,6 @@ class IPsec:
 
         self.entries.append(remote_ip)
 
-
     def del_entry(self, local_ip, remote_ip):
         if remote_ip in self.entries:
             self.racoon.del_entry(remote_ip)
@@ -375,15 +382,16 @@ def keep_table_columns(schema, table_name, column_types):
         new_columns[column_name] = column
     table.columns = new_columns
     return table
-def monitor_uuid_schema_cb(schema):
+
+
+def prune_schema(schema):
     string_type = types.Type(types.BaseType(types.StringType))
     optional_ssl_type = types.Type(types.BaseType(types.UuidType,
-                                                  ref_table='SSL'), None, 0, 1)
+        ref_table_name='SSL'), None, 0, 1)
     string_map_type = types.Type(types.BaseType(types.StringType),
                                  types.BaseType(types.StringType),
                                  0, sys.maxint)
+
     new_tables = {}
     new_tables["Interface"] = keep_table_columns(
         schema, "Interface", {"name": string_type,
@@ -396,14 +404,7 @@ def monitor_uuid_schema_cb(schema):
                         "private_key": string_type})
     schema.tables = new_tables
 
-def usage():
-    print "usage: %s [OPTIONS] DATABASE" % sys.argv[0]
-    print "where DATABASE is a socket on which ovsdb-server is listening."
-    ovs.daemon.usage()
-    print "Other options:"
-    print "  -h, --help               display this help message"
-    sys.exit(0)
+
 def update_ipsec(ipsec, interfaces, new_interfaces):
     for name, vals in interfaces.iteritems():
         if name not in new_interfaces:
@@ -424,38 +425,37 @@ def update_ipsec(ipsec, interfaces, new_interfaces):
         except error.Error, msg:
             s_log.warning("skipping ipsec config for %s: %s" % (name, msg))
 
+
 def get_ssl_cert(data):
-    for ovs_rec in data["Open_vSwitch"].itervalues():
-        if ovs_rec.ssl.as_list():
-            ssl_rec = data["SSL"][ovs_rec.ssl.as_scalar()]
-            return (ssl_rec.certificate.as_scalar(),
-                    ssl_rec.private_key.as_scalar())
+    for ovs_rec in data["Open_vSwitch"].rows.itervalues():
+        ssl = ovs_rec.ssl
+        if ssl and ssl.certificate and ssl.private_key:
+            return (ssl.certificate, ssl.private_key)
 
     return None
 
-def main(argv):
-    try:
-        options, args = getopt.gnu_getopt(
-            argv[1:], 'h', ['help'] + ovs.daemon.LONG_OPTIONS)
-    except getopt.GetoptError, geo:
-        sys.stderr.write("%s: %s\n" % (ovs.util.PROGRAM_NAME, geo.msg))
-        sys.exit(1)
-    for key, value in options:
-        if key in ['-h', '--help']:
-            usage()
-        elif not ovs.daemon.parse_opt(key, value):
-            sys.stderr.write("%s: unhandled option %s\n"
-                             % (ovs.util.PROGRAM_NAME, key))
-            sys.exit(1)
-    if len(args) != 1:
-        sys.stderr.write("%s: exactly one nonoption argument is required "
-                         "(use --help for help)\n" % ovs.util.PROGRAM_NAME)
-        sys.exit(1)
 
-    remote = args[0]
-    idl = ovs.db.idl.Idl(remote, "Open_vSwitch", monitor_uuid_schema_cb)
+def main():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("database", metavar="DATABASE",
+                        help="A socket on which ovsdb-server is listening.")
+    parser.add_argument("--root-prefix", metavar="DIR",
+                        help="Use DIR as alternate root directory"
+                        " (for testing).")
+
+    ovs.daemon.add_args(parser)
+    args = parser.parse_args()
+    ovs.daemon.handle_args(args)
+
+    global root_prefix
+    root_prefix = args.root_prefix
+
+    remote = args.database
+    schema_file = "%s/vswitch.ovsschema" % ovs.dirs.PKGDATADIR
+    schema = ovs.db.schema.DbSchema.from_json(ovs.json.from_file(schema_file))
+    prune_schema(schema)
+    idl = ovs.db.idl.Idl(remote, schema)
 
     ovs.daemon.daemonize()
 
@@ -469,27 +469,28 @@ def main(argv):
             poller.block()
             continue
 
-        ssl_cert = get_ssl_cert(idl.data)
+        ssl_cert = get_ssl_cert(idl.tables)
+
         new_interfaces = {}
-        for rec in idl.data["Interface"].itervalues():
-            if rec.type.as_scalar() == "ipsec_gre":
-                name = rec.name.as_scalar()
+        for rec in idl.tables["Interface"].rows.itervalues():
+            if rec.type == "ipsec_gre":
+                name = rec.name
+                options = rec.options
                 entry = {
-                    "remote_ip": rec.options.get("remote_ip"),
-                    "local_ip": rec.options.get("local_ip", "0.0.0.0/0"),
-                    "certificate": rec.options.get("certificate"),
-                    "private_key": rec.options.get("private_key"),
-                    "use_ssl_cert": rec.options.get("use_ssl_cert"),
-                    "peer_cert": rec.options.get("peer_cert"),
-                    "psk": rec.options.get("psk") }
+                    "remote_ip": options.get("remote_ip"),
+                    "local_ip": options.get("local_ip", "0.0.0.0/0"),
+                    "certificate": options.get("certificate"),
+                    "private_key": options.get("private_key"),
+                    "use_ssl_cert": options.get("use_ssl_cert"),
+                    "peer_cert": options.get("peer_cert"),
+                    "psk": options.get("psk")}
 
                 if entry["peer_cert"] and entry["psk"]:
-                    s_log.warning("both 'peer_cert' and 'psk' defined for %s" 
+                    s_log.warning("both 'peer_cert' and 'psk' defined for %s"
                             % name)
                     continue
                 elif not entry["peer_cert"] and not entry["psk"]:
-                    s_log.warning("no 'peer_cert' or 'psk' defined for %s" 
+                    s_log.warning("no 'peer_cert' or 'psk' defined for %s"
                             % name)
                     continue
 
@@ -504,14 +505,15 @@ def main(argv):
                     entry["private_key"] = ssl_cert[1]
 
                 new_interfaces[name] = entry
+
         if interfaces != new_interfaces:
             update_ipsec(ipsec, interfaces, new_interfaces)
             interfaces = new_interfaces
+
+
 if __name__ == '__main__':
     try:
-        main(sys.argv)
+        main()
     except SystemExit:
         # Let system.exit() calls complete normally
         raise