foreign_xrefs should be a list and not a dict anymore
[plcapi.git] / PLC / Sites.py
index 733c7c1..a30c49b 100644 (file)
@@ -3,12 +3,14 @@ import string
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Slices import Slice, Slices
 from PLC.PCUs import PCU, PCUs
 from PLC.Nodes import Node, Nodes
 from PLC.NodeGroups import NodeGroup, NodeGroups
+from PLC.Addresses import Address, Addresses
 import PLC.Persons
 
 class Site(Row):
@@ -18,37 +20,54 @@ class Site(Row):
     dict. Commit to the database with sync().
     """
 
+    table_name = 'sites'
+    primary_key = 'site_id'
     fields = {
         'site_id': Parameter(int, "Site identifier"),
         'name': Parameter(str, "Full site name", max = 254),
         'abbreviated_name': Parameter(str, "Abbreviated site name", max = 50),
         'login_base': Parameter(str, "Site slice prefix", max = 20),
         'is_public': Parameter(bool, "Publicly viewable site"),
-        'latitude': Parameter(float, "Decimal latitude of the site", min = -90.0, max = 90.0),
-        'longitude': Parameter(float, "Decimal longitude of the site", min = -180.0, max = 180.0),
-        'url': Parameter(str, "URL of a page that describes the site", max = 254),
-        'date_created': Parameter(str, "Date and time when site entry was created"),        
-        'last_updated': Parameter(str, "Date and time when site entry was last updated"),        
-        'deleted': Parameter(bool, "Has been deleted"),
+        'latitude': Parameter(float, "Decimal latitude of the site", min = -90.0, max = 90.0, nullok = True),
+        'longitude': Parameter(float, "Decimal longitude of the site", min = -180.0, max = 180.0, nullok = True),
+        'url': Parameter(str, "URL of a page that describes the site", max = 254, nullok = True),
+        'date_created': Parameter(int, "Date and time when site entry was created, in seconds since UNIX epoch", ro = True),
+        'last_updated': Parameter(int, "Date and time when site entry was last updated, in seconds since UNIX epoch", ro = True),
         'max_slices': Parameter(int, "Maximum number of slices that the site is able to create"),
+        'max_slivers': Parameter(int, "Maximum number of slivers that the site is able to create"),
         'person_ids': Parameter([int], "List of account identifiers"),
-        # 'slice_ids': Parameter([int], "List of slice identifiers"),
-        # 'pcu_ids': Parameter([int], "List of PCU identifiers"),
+        'slice_ids': Parameter([int], "List of slice identifiers"),
+        'address_ids': Parameter([int], "List of address identifiers"),
+        'pcu_ids': Parameter([int], "List of PCU identifiers"),
         'node_ids': Parameter([int], "List of site node identifiers"),
+        'peer_id': Parameter(int, "Peer at which this slice was created", nullok = True),
         }
 
-    def __init__(self, api, fields):
-        Row.__init__(self, fields)
-        self.api = api
+    # for Cache
+    class_key = 'login_base'
+    foreign_fields = ['abbreviated_name', 'name', 'is_public', 'latitude', 'longitude',
+                     'url', 'date_created', 'last_updated', 'max_slices', 'max_slivers',
+                     ]
+    foreign_xrefs = []
+
+    def validate_name(self, name):
+        if not len(name):
+            raise PLCInvalidArgument, "Name must be specified"
+
+        return name
+
+    validate_abbreviated_name = validate_name
 
     def validate_login_base(self, login_base):
-        if not set(login_base).issubset(string.ascii_letters):
-            raise PLCInvalidArgument, "Login base must consist only of ASCII letters"
+        if not len(login_base):
+            raise PLCInvalidArgument, "Login base must be specified"
+
+        if not set(login_base).issubset(string.ascii_letters.lower()):
+            raise PLCInvalidArgument, "Login base must consist only of lowercase ASCII letters"
 
-        login_base = login_base.lower()
         conflicts = Sites(self.api, [login_base])
-        for site_id, site in conflicts.iteritems():
-            if not site['deleted'] and ('site_id' not in self or self['site_id'] != site_id):
+        for site in conflicts:
+            if 'site_id' not in self or self['site_id'] != site['site_id']:
                 raise PLCInvalidArgument, "login_base already in use"
 
         return login_base
@@ -67,6 +86,12 @@ class Site(Row):
 
         return longitude
 
+    # timestamps
+    def validate_date_created (self, timestamp):
+       return self.validate_timestamp (timestamp)
+    def validate_last_updated (self, timestamp):
+       return self.validate_timestamp (timestamp)
+
     def add_person(self, person, commit = True):
         """
         Add person to existing site.
@@ -78,17 +103,18 @@ class Site(Row):
 
         site_id = self['site_id']
         person_id = person['person_id']
-        self.api.db.do("INSERT INTO person_site (person_id, site_id)" \
-                       " VALUES(%(person_id)d, %(site_id)d)",
-                       locals())
 
-        if commit:
-            self.api.db.commit()
+        if person_id not in self['person_ids']:
+            assert site_id not in person['site_ids']
 
-        if 'person_ids' in self and person_id not in self['person_ids']:
-            self['person_ids'].append(person_id)
+            self.api.db.do("INSERT INTO person_site (person_id, site_id)" \
+                           " VALUES(%(person_id)d, %(site_id)d)",
+                           locals())
+
+            if commit:
+                self.api.db.commit()
 
-        if 'site_ids' in person and site_id not in person['site_ids']:
+            self['person_ids'].append(person_id)
             person['site_ids'].append(site_id)
 
     def remove_person(self, person, commit = True):
@@ -102,69 +128,65 @@ class Site(Row):
 
         site_id = self['site_id']
         person_id = person['person_id']
-        self.api.db.do("DELETE FROM person_site" \
-                       " WHERE person_id = %(person_id)d" \
-                       " AND site_id = %(site_id)d",
-                       locals())
 
-        if commit:
-            self.api.db.commit()
+        if person_id in self['person_ids']:
+            assert site_id in person['site_ids']
 
-        if 'person_ids' in self and person_id in self['person_ids']:
-            self['person_ids'].remove(person_id)
+            self.api.db.do("DELETE FROM person_site" \
+                           " WHERE person_id = %(person_id)d" \
+                           " AND site_id = %(site_id)d",
+                           locals())
 
-        if 'site_ids' in person and site_id in person['site_ids']:
+            if commit:
+                self.api.db.commit()
+
+            self['person_ids'].remove(person_id)
             person['site_ids'].remove(site_id)
 
-    def sync(self, commit = True):
+    def add_address(self, address, commit = True):
+        """
+        Add address to existing site.
+        """
+
+        assert 'site_id' in self
+        assert isinstance(address, Address)
+        assert 'address_id' in address
+
+        site_id = self['site_id']
+        address_id = address['address_id']
+
+        if address_id not in self['address_ids']:
+            self.api.db.do("INSERT INTO site_address (address_id, site_id)" \
+                           " VALUES(%(address_id)d, %(site_id)d)",
+                           locals())
+
+            if commit:
+                self.api.db.commit()
+
+            self['address_ids'].append(address_id)
+
+    def remove_address(self, address, commit = True):
         """
-        Flush changes back to the database.
+        Remove address from existing site.
         """
 
-        self.validate()
-
-        try:
-            if not self['name'] or \
-               not self['abbreviated_name'] or \
-               not self['login_base']:
-                raise KeyError
-        except KeyError:
-            raise PLCInvalidArgument, "name, abbreviated_name, and login_base must all be specified"
-
-        # Fetch a new site_id if necessary
-        if 'site_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('sites_site_id_seq') AS site_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new site_id"
-            self['site_id'] = rows[0]['site_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        sites_fields = self.api.db.fields('sites')
-        fields = dict(filter(lambda (key, value): key in sites_fields,
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in sites table
-            sql = "INSERT INTO sites (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in sites table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE sites SET " + \
-                  ", ".join(columns) + \
-                  " WHERE site_id = %(site_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
+        assert 'site_id' in self
+        assert isinstance(address, Address)
+        assert 'address_id' in address
+
+        site_id = self['site_id']
+        address_id = address['address_id']
+
+        if address_id in self['address_ids']:
+            self.api.db.do("DELETE FROM site_address" \
+                           " WHERE address_id = %(address_id)d" \
+                           " AND site_id = %(site_id)d",
+                           locals())
+
+            if commit:
+                self.api.db.commit()
+
+            self['address_ids'].remove(address_id)
 
     def delete(self, commit = True):
         """
@@ -176,32 +198,36 @@ class Site(Row):
         # Delete accounts of all people at the site who are not
         # members of at least one other non-deleted site.
         persons = PLC.Persons.Persons(self.api, self['person_ids'])
-        for person_id, person in persons.iteritems():
+        for person in persons:
             delete = True
 
             person_sites = Sites(self.api, person['site_ids'])
-            for person_site_id, person_site in person_sites.iteritems():
-                if person_site_id != self['site_id'] and \
-                   not person_site['deleted']:
+            for person_site in person_sites:
+                if person_site['site_id'] != self['site_id']:
                     delete = False
                     break
 
             if delete:
                 person.delete(commit = False)
 
+        # Delete all site addresses
+        addresses = Addresses(self.api, self['address_ids'])
+        for address in addresses:
+            address.delete(commit = False)
+
         # Delete all site slices
-        slices = Slices(self.api, self['slice_ids'])
-        # for slice in slices.values():
-        #    slice.delete(commit = False)
+        slices = Slices(self.api, self['slice_ids'])
+        for slice in slices:
+            slice.delete(commit = False)
 
         # Delete all site PCUs
-        pcus = PCUs(self.api, self['pcu_ids'])
-        # for pcu in pcus.values():
-        #    pcu.delete(commit = False)
+        pcus = PCUs(self.api, self['pcu_ids'])
+        for pcu in pcus:
+            pcu.delete(commit = False)
 
         # Delete all site nodes
         nodes = Nodes(self.api, self['node_ids'])
-        for node in nodes.values():
+        for node in nodes:
             node.delete(commit = False)
 
         # Clean up miscellaneous join tables
@@ -217,36 +243,24 @@ class Site(Row):
 class Sites(Table):
     """
     Representation of row(s) from the sites table in the
-    database. Specify extra_fields to be able to view and modify extra
-    fields.
+    database.
     """
 
-    def __init__(self, api, site_id_or_login_base_list = None, fields = Site.fields):
-        self.api = api
+    def __init__(self, api, site_filter = None, columns = None):
+        Table.__init__(self, api, Site, columns)
 
         sql = "SELECT %s FROM view_sites WHERE deleted IS False" % \
-              ", ".join(fields)
-
-        if site_id_or_login_base_list:
-            # Separate the list into integers and strings
-            site_ids = filter(lambda site_id: isinstance(site_id, (int, long)),
-                              site_id_or_login_base_list)
-            login_bases = filter(lambda login_base: isinstance(login_base, StringTypes),
-                                 site_id_or_login_base_list)
-            sql += " AND (False"
-            if site_ids:
-                sql += " OR site_id IN (%s)" % ", ".join(map(str, site_ids))
-            if login_bases:
-                sql += " OR login_base IN (%s)" % ", ".join(api.db.quote(login_bases))
-            sql += ")"
-
-        rows = self.api.db.selectall(sql)
-
-        for row in rows:
-            self[row['site_id']] = site = Site(api, row)
-            for aggregate in ['person_ids', 'slice_ids',
-                              'defaultattribute_ids', 'pcu_ids', 'node_ids']:
-                if not site.has_key(aggregate) or site[aggregate] is None:
-                    site[aggregate] = []
-                else:
-                    site[aggregate] = map(int, site[aggregate].split(','))
+              ", ".join(self.columns)
+
+        if site_filter is not None:
+            if isinstance(site_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), site_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), site_filter)
+                site_filter = Filter(Site.fields, {'site_id': ints, 'login_base': strs})
+                sql += " AND (%s)" % site_filter.sql(api, "OR")
+            elif isinstance(site_filter, dict):
+                site_filter = Filter(Site.fields, site_filter)
+                sql += " AND (%s)" % site_filter.sql(api, "AND")
+
+        self.selectall(sql)