move clean_policy.py into monitor package
[monitor.git] / monitor / model.py
diff --git a/monitor/model.py b/monitor/model.py
new file mode 100755 (executable)
index 0000000..ce941f2
--- /dev/null
@@ -0,0 +1,768 @@
+#!/usr/bin/python
+
+from monitor import database
+
+from monitor.wrapper import plc, plccache
+from monitor.wrapper import mailer
+import time
+
+from monitor.const import *
+from monitor import util
+from monitor import config
+
+import time
+from datetime import datetime, timedelta
+import re
+
+def gethostlist(hostlist_file):
+       return util.file.getListFromFile(hostlist_file)
+
+def array_to_priority_map(array):
+       """ Create a mapping where each entry of array is given a priority equal
+       to its position in the array.  This is useful for subsequent use in the
+       cmpMap() function."""
+       map = {}
+       count = 0
+       for i in array:
+               map[i] = count
+               count += 1
+       return map
+
+def cmpValMap(v1, v2, map):
+       if v1 in map and v2 in map and map[v1] < map[v2]:
+               return 1
+       elif v1 in map and v2 in map and map[v1] > map[v2]:
+               return -1
+       elif v1 in map and v2 in map:
+               return 0
+       else:
+               raise Exception("No index %s or %s in map" % (v1, v2))
+
+def cmpCategoryVal(v1, v2):
+       # Terrible hack to manage migration to no more 'ALPHA' states.
+       if v1 == 'ALPHA': v1 = "PROD"
+       if v2 == 'ALPHA': v2 = "PROD"
+       #map = array_to_priority_map([ None, 'PROD', 'ALPHA', 'OLDBOOTCD', 'UNKNOWN', 'FORCED', 'ERROR', ])
+       map = array_to_priority_map([ None, 'ALPHA', 'PROD', 'OLDBOOTCD', 'UNKNOWN', 'FORCED', 'ERROR', ])
+       return cmpValMap(v1,v2,map)
+
+
+class PCU:
+       def __init__(self, hostname):
+               self.hostname = hostname
+
+       def reboot(self):
+               return True
+       def available(self):
+               return True
+       def previous_attempt(self):
+               return True
+       def setValidMapping(self):
+               pass
+
+class Penalty:
+       def __init__(self, key, valuepattern, action):
+               pass
+
+class PenaltyMap:
+       def __init__(self):
+               pass
+
+       # connect one penalty to another, in a FSM diagram.  After one
+       #       condition/penalty is applied, move to the next phase.
+
+
+class RT(object):
+       def __init__(self, ticket_id = None):
+               self.ticket_id = ticket_id
+               if self.ticket_id:
+                       print "getting ticket status",
+                       self.status = mailer.getTicketStatus(self.ticket_id)
+                       print self.status
+
+       def setTicketStatus(self, status):
+               mailer.setTicketStatus(self.ticket_id, status)
+               self.status = mailer.getTicketStatus(self.ticket_id)
+               return True
+       
+       def getTicketStatus(self):
+               if not self.status:
+                       self.status = mailer.getTicketStatus(self.ticket_id)
+               return self.status
+
+       def closeTicket(self):
+               mailer.closeTicketViaRT(self.ticket_id, "Ticket CLOSED automatically by SiteAssist.") 
+
+       def email(self, subject, body, to):
+               self.ticket_id = mailer.emailViaRT(subject, body, to, self.ticket_id)
+               return self.ticket_id
+
+class Message(object):
+       def __init__(self, subject, message, via_rt=True, ticket_id=None, **kwargs):
+               self.via_rt = via_rt
+               self.subject = subject
+               self.message = message
+               self.rt = RT(ticket_id)
+
+       def send(self, to):
+               if self.via_rt:
+                       return self.rt.email(self.subject, self.message, to)
+               else:
+                       return mailer.email(self.subject, self.message, to)
+
+class Recent(object):
+       def __init__(self, withintime):
+               self.withintime = withintime
+
+               try:
+                       self.time = self.__getattribute__('time')
+               except:
+                       self.time = time.time()- 7*24*60*60
+
+               #self.time = time.time()
+               #self.action_taken = False
+
+       def isRecent(self):
+               if self.time + self.withintime < time.time():
+                       self.action_taken = False
+
+               if self.time + self.withintime > time.time() and self.action_taken:
+                       return True
+               else:
+                       return False
+
+       def unsetRecent(self):
+               self.action_taken = False
+               self.time = time.time()
+               return True
+
+       def setRecent(self):
+               self.action_taken = True
+               self.time = time.time()
+               return True
+               
+class PersistFlags(Recent):
+       def __new__(typ, id, *args, **kwargs):
+               if 'db' in kwargs:
+                       db = kwargs['db']
+                       del kwargs['db']
+               else:
+                       db = "persistflags"
+
+               try:
+                       pm = database.dbLoad(db)
+               except:
+                       database.dbDump(db, {})
+                       pm = database.dbLoad(db)
+               #print pm
+               if id in pm:
+                       obj = pm[id]
+               else:
+                       obj = super(PersistFlags, typ).__new__(typ, *args, **kwargs)
+                       for key in kwargs.keys():
+                               obj.__setattr__(key, kwargs[key])
+                       obj.time = time.time()
+                       obj.action_taken = False
+
+               obj.db = db
+               return obj
+
+       def __init__(self, id, withintime, **kwargs):
+               self.id = id
+               Recent.__init__(self, withintime)
+
+       def save(self):
+               pm = database.dbLoad(self.db)
+               pm[self.id] = self
+               database.dbDump(self.db, pm)
+
+       def resetFlag(self, name):
+               self.__setattr__(name, False)
+
+       def setFlag(self, name):
+               self.__setattr__(name, True)
+               
+       def getFlag(self, name):
+               try:
+                       return self.__getattribute__(name)
+               except:
+                       self.__setattr__(name, False)
+                       return False
+
+       def resetRecentFlag(self, name):
+               self.resetFlag(name)
+               self.unsetRecent()
+
+       def setRecentFlag(self, name):
+               self.setFlag(name)
+               self.setRecent()
+
+       def getRecentFlag(self, name):
+               # if recent and flag set -> true
+               # else false
+               try:
+                       return self.isRecent() & self.__getattribute__(name)
+               except:
+                       self.__setattr__(name, False)
+                       return False
+
+       def checkattr(self, name):
+               try:
+                       x = self.__getattribute__(name)
+                       return True
+               except:
+                       return False
+               
+
+class PersistMessage(Message):
+       def __new__(typ, id, subject, message, via_rt, **kwargs):
+               if 'db' in kwargs:
+                       db = kwargs['db']
+               else:
+                       db = "persistmessages"
+
+               try:
+                       pm = database.dbLoad(db)
+               except:
+                       database.dbDump(db, {})
+                       pm = database.dbLoad(db)
+
+               #print pm
+               if id in pm:
+                       #print "Using existing object"
+                       obj = pm[id]
+               else:
+                       #print "creating new object"
+                       obj = super(PersistMessage, typ).__new__(typ, [id, subject, message, via_rt], **kwargs)
+                       obj.id = id
+                       obj.actiontracker = Recent(1*60*60*24)
+                       obj.ticket_id = None
+
+               if 'ticket_id' in kwargs and kwargs['ticket_id'] is not None:
+                       obj.ticket_id = kwargs['ticket_id']
+
+               obj.db = db
+               return obj
+
+       def __init__(self, id, subject, message, via_rt=True, **kwargs):
+               print "initializing object: %s" % self.ticket_id
+               self.id = id
+               Message.__init__(self, subject, message, via_rt, self.ticket_id)
+
+       def reset(self):
+               self.actiontracker.unsetRecent()
+
+       def save(self):
+               pm = database.dbLoad(self.db)
+               pm[self.id] = self
+               database.dbDump(self.db, pm)
+
+       def send(self, to):
+               if not self.actiontracker.isRecent():
+                       self.ticket_id = Message.send(self, to)
+                       self.actiontracker.setRecent()
+                       self.save()
+               else:
+                       # NOTE: only send a new message every week, regardless.
+                       # NOTE: can cause thank-you messages to be lost, for instance when node comes back online within window.
+                       print "Not sending to host b/c not within window of %s days" % (self.actiontracker.withintime // (60*60*24))
+
+class MonitorMessage(object):
+       def __new__(typ, id, *args, **kwargs):
+               if 'db' in kwargs:
+                       db = kwargs['db']
+               else:
+                       db = "monitormessages"
+
+               try:
+                       if 'reset' in kwargs and kwargs['reset'] == True:
+                               database.dbDump(db, {})
+                       pm = database.dbLoad(db)
+               except:
+                       database.dbDump(db, {})
+                       pm = database.dbLoad(db)
+
+               #print pm
+               if id in pm:
+                       print "Using existing object"
+                       obj = pm[id]
+               else:
+                       print "creating new object"
+                       obj = super(object, typ).__new__(typ, id, *args, **kwargs)
+                       obj.id = id
+                       obj.sp = PersistSitePenalty(id, 0)
+
+               obj.db = db
+               return obj
+
+       def __init__(self, id, message):
+               pass
+               
+
+class SitePenalty(object):
+       penalty_map = [] 
+       penalty_map.append( { 'name': 'noop',                   'enable'   : lambda host: None,
+                                                                                                       'disable'  : lambda host: None } )
+       penalty_map.append( { 'name': 'nocreate',               'enable'   : lambda host: plc.removeSliceCreation(host),
+                                                                                                       'disable'  : lambda host: plc.enableSliceCreation(host) } )
+       penalty_map.append( { 'name': 'suspendslices',  'enable'   : lambda host: plc.suspendSlices(host),
+                                                                                                       'disable'  : lambda host: plc.enableSlices(host) } )
+
+       #def __init__(self, index=0, **kwargs):
+       #       self.index = index
+
+       def get_penalties(self):
+               # TODO: get penalties actually applied to a node from PLC DB.
+               return [ n['name'] for n in SitePenalty.penalty_map ] 
+
+       def increase(self):
+               self.index = self.index + 1
+               if self.index > len(SitePenalty.penalty_map)-1: self.index = len(SitePenalty.penalty_map)-1
+               return True
+
+       def decrease(self):
+               self.index = self.index - 1
+               if self.index < 0: self.index = 0
+               return True
+
+       def apply(self, host):
+
+               for i in range(len(SitePenalty.penalty_map)-1,self.index,-1):
+                       print "\tdisabling %s on %s" % (SitePenalty.penalty_map[i]['name'], host)
+                       SitePenalty.penalty_map[i]['disable'](host)
+
+               for i in range(0,self.index+1):
+                       print "\tapplying %s on %s" % (SitePenalty.penalty_map[i]['name'], host)
+                       SitePenalty.penalty_map[i]['enable'](host)
+
+               return
+
+
+
+class PersistSitePenalty(SitePenalty):
+       def __new__(typ, id, index, **kwargs):
+               if 'db' in kwargs:
+                       db = kwargs['db']
+               else:
+                       db = "persistpenalties"
+
+               try:
+                       if 'reset' in kwargs and kwargs['reset'] == True:
+                               database.dbDump(db, {})
+                       pm = database.dbLoad(db)
+               except:
+                       database.dbDump(db, {})
+                       pm = database.dbLoad(db)
+
+               #print pm
+               if id in pm:
+                       print "Using existing object"
+                       obj = pm[id]
+               else:
+                       print "creating new object"
+                       obj = super(PersistSitePenalty, typ).__new__(typ, [index], **kwargs)
+                       obj.id = id
+                       obj.index = index
+
+               obj.db = db
+               return obj
+
+       def __init__(self, id, index, **kwargs):
+               self.id = id
+
+       def save(self):
+               pm = database.dbLoad(self.db)
+               pm[self.id] = self
+               database.dbDump(self.db, pm)
+
+
+class Target:
+       """
+               Each host has a target set of attributes.  Some may be set manually,
+               or others are set globally for the preferred target.
+
+               For instance:
+                       All nodes in the Alpha or Beta group would have constraints like:
+                               [ { 'state' : 'BOOT', 'kernel' : '2.6.22' } ]
+       """
+       def __init__(self, constraints):
+               self.constraints = constraints
+
+       def verify(self, data):
+               """
+                       self.constraints is a list of key, value pairs.
+                       # [ {... : ...}==AND , ... , ... , ] == OR
+               """
+               con_or_true = False
+               for con in self.constraints:
+                       #print "con: %s" % con
+                       con_and_true = True
+                       for key in con.keys():
+                               #print "looking at key: %s" % key
+                               if key in data: 
+                                       #print "%s %s" % (con[key], data[key])
+                                       con_and_true = con_and_true & (con[key] in data[key])
+                               elif key not in data:
+                                       print "missing key %s" % key
+                                       con_and_true = False
+
+                       con_or_true = con_or_true | con_and_true
+
+               return con_or_true
+
+class Record(object):
+
+       def __init__(self, hostname, data):
+               self.hostname = hostname
+               self.data = data
+               self.plcdb_hn2lb = plccache.plcdb_hn2lb
+               self.loginbase = self.plcdb_hn2lb[self.hostname]
+               return
+
+
+       def stageIswaitforever(self):
+               if 'waitforever' in self.data['stage']:
+                       return True
+               else:
+                       return False
+
+       def severity(self):
+               category = self.data['category']
+               prev_category = self.data['prev_category']
+               #print "SEVERITY: ", category, prev_category
+               val = cmpCategoryVal(category, prev_category)
+               return val 
+
+       def improved(self):
+               return self.severity() > 0
+       
+       def end_record(self):
+               return node_end_record(self.hostname)
+
+       def reset_stage(self):
+               self.data['stage'] = 'findbad'
+               return True
+       
+       def getCategory(self):
+               return self.data['category'].lower()
+
+       def getState(self):
+               return self.data['state'].lower()
+
+       def getDaysDown(cls, diag_record):
+               daysdown = -1
+               if diag_record['comonstats']['uptime'] != "null" and diag_record['comonstats']['uptime'] != "-1":
+                       daysdown = - int(float(diag_record['comonstats']['uptime'])) // (60*60*24)
+               #elif diag_record['comonstats']['sshstatus'] != "null":
+               #       daysdown = int(diag_record['comonstats']['sshstatus']) // (60*60*24)
+               #elif diag_record['comonstats']['lastcotop'] != "null":
+               #       daysdown = int(diag_record['comonstats']['lastcotop']) // (60*60*24)
+               else:
+                       now = time.time()
+                       last_contact = diag_record['plcnode']['last_contact']
+                       if last_contact == None:
+                               # the node has never been up, so give it a break
+                               daysdown = -1
+                       else:
+                               diff = now - last_contact
+                               daysdown = diff // (60*60*24)
+               return daysdown
+       getDaysDown = classmethod(getDaysDown)
+
+       def getStrDaysDown(cls, diag_record):
+               daysdown = "unknown"
+               last_contact = diag_record['plcnode']['last_contact']
+               date_created = diag_record['plcnode']['date_created']
+
+               if      diag_record['comonstats']['uptime'] != "null" and \
+                       diag_record['comonstats']['uptime'] != "-1":
+                       daysdown = int(float(diag_record['comonstats']['uptime'])) // (60*60*24)
+                       daysdown = "%d days up" % daysdown
+
+               elif last_contact is None:
+                       if date_created is not None:
+                               now = time.time()
+                               diff = now - date_created
+                               daysdown = diff // (60*60*24)
+                               daysdown = "Never contacted PLC, created %s days ago" % daysdown
+                       else:
+                               daysdown = "Never contacted PLC"
+               else:
+                       now = time.time()
+                       diff = now - last_contact
+                       daysdown = diff // (60*60*24)
+                       daysdown = "%s days down" % daysdown
+               return daysdown
+       getStrDaysDown = classmethod(getStrDaysDown)
+
+       def getSendEmailFlag(self):
+               if not config.mail:
+                       return False
+
+               # resend if open & created longer than 30 days ago.
+               if  'rt' in self.data and \
+                       'Status' in self.data['rt'] and \
+                       "open" in self.data['rt']['Status'] and \
+                       self.data['rt']['Created'] > int(time.time() - 60*60*24*30):
+                       # if created-time is greater than the thirty days ago from the current time
+                       return False
+
+               return True
+
+       def getMostRecentStage(self):
+               lastact = self.data['last_action_record']
+               return lastact.stage
+
+       def getMostRecentTime(self):
+               lastact = self.data['last_action_record']
+               return lastact.date_action_taken
+
+       def takeAction(self, index=0):
+               pp = PersistSitePenalty(self.hostname, 0, db='persistpenalty_hostnames')
+               if 'improvement' in self.data['stage'] or self.improved() or \
+                       'monitor-end-record' in self.data['stage']:
+                       print "takeAction: decreasing penalty for %s"%self.hostname
+                       pp.decrease()
+                       pp.decrease()
+               else:
+                       print "takeAction: increasing penalty for %s"%self.hostname
+                       pp.increase()
+               pp.index = index
+               pp.apply(self.hostname)
+               pp.save()
+
+       def _format_diaginfo(self):
+               info = self.data['info']
+               print "FORMAT : STAGE: ", self.data['stage']
+               if self.data['stage'] == 'monitor-end-record':
+                       if info[2] == "ALPHA": info = (info[0], info[1], "PROD")
+                       hlist = "    %s went from '%s' to '%s'\n" % (info[0], info[1], info[2]) 
+               else:
+                       hlist = "    %s %s - %s\n" % (info[0], info[2], info[1]) #(node,ver,daysdn)
+               return hlist
+       def saveAction(self):
+               if 'save_act_all' in self.data and self.data['save_act_all'] == True:
+                       return True
+               else:
+                       return False
+
+       def getMessage(self, ticket_id=None):
+               self.data['args']['hostname'] = self.hostname
+               self.data['args']['loginbase'] = self.loginbase
+               self.data['args']['hostname_list'] = self._format_diaginfo()
+               #print self.data['message']
+               if self.data['message']:
+                       message = PersistMessage(self.hostname, 
+                                                                self.data['message'][0] % self.data['args'],
+                                                                self.data['message'][1] % self.data['args'],
+                                                                True, db='monitor_persistmessages',
+                                                                ticket_id=ticket_id)
+                       if self.data['stage'] == "improvement":
+                               message.reset()
+                       return message
+               else:
+                       return None
+       
+       def getContacts(self):
+               roles = self.data['email']
+
+               if not config.mail and not config.debug and config.bcc:
+                       roles = ADMIN
+               if config.mail and config.debug:
+                       roles = ADMIN
+
+               # build targets
+               contacts = []
+               if ADMIN & roles:
+                       contacts += [config.email]
+               if TECH & roles:
+                       #contacts += [TECHEMAIL % self.loginbase]
+                       contacts += plc.getTechEmails(self.loginbase)
+               if PI & roles:
+                       #contacts += [PIEMAIL % self.loginbase]
+                       contacts += plc.getSliceUserEmails(self.loginbase)
+               if USER & roles:
+                       contacts += plc.getSliceUserEmails(self.loginbase)
+                       slices = plc.slices(self.loginbase)
+                       if len(slices) >= 1:
+                               #for slice in slices:
+                               #       contacts += [SLICEMAIL % slice]
+                               print "SLIC: %20s : %d slices" % (self.loginbase, len(slices))
+                       else:
+                               print "SLIC: %20s : 0 slices" % self.loginbase
+
+               return contacts
+
+
+class NodeRecord:
+       def __init__(self, hostname, target):
+               self.hostname = hostname
+               self.ticket = None
+               self.target = target
+
+class Action(MonRecord):
+       def __init__(self, host, data):
+               self.host = host
+               MonRecord.__init__(self, data)
+               return
+
+       def deltaDays(self, delta):
+               t = datetime.fromtimestamp(self.__dict__['time'])
+               d = t + timedelta(delta)
+               self.__dict__['time'] = time.mktime(d.timetuple())
+               
+def node_end_record(node):
+       act_all = database.dbLoad("act_all")
+       if node not in act_all:
+               del act_all
+               return False
+
+       if len(act_all[node]) == 0:
+               del act_all
+               return False
+
+       pm = database.dbLoad("monitor_persistmessages")
+       if node not in pm:
+               del pm
+               return False
+       else:
+               print "deleting node record"
+               del pm[node]
+               database.dbDump("monitor_persistmessages", pm)
+
+       a = Action(node, act_all[node][0])
+       a.delField('rt')
+       a.delField('found_rt_ticket')
+       a.delField('second-mail-at-oneweek')
+       a.delField('second-mail-at-twoweeks')
+       a.delField('first-found')
+       rec = a.get()
+       rec['action'] = ["close_rt"]
+       rec['category'] = "ALPHA"       # assume that it's up...
+       rec['stage'] = "monitor-end-record"
+       rec['ticket_id'] = None
+       rec['time'] = time.time() - 7*60*60*24
+       act_all[node].insert(0,rec)
+       database.dbDump("act_all", act_all)
+       del act_all
+       return True
+
+class MonRecord(object):
+       def __init__(self, data):
+               self.keys = data.keys()
+               self.keys.sort()
+               self.__dict__.update(data)
+               return
+
+       def get(self):
+               ret= {}
+               for k in self.keys:
+                       ret[k] = self.__dict__[k]
+               return ret
+
+       def __repr__(self):
+               str = ""
+               str += self.host + "\n"
+               for k in self.keys:
+                       if "message" in k or "msg" in k:
+                               continue
+                       if 'time' in k:
+                               s_time=time.strftime("%Y/%m/%d %H:%M:%S", 
+                                                       time.gmtime(self.__dict__[k]))
+                               str += "\t'%s' : %s\n" % (k, s_time)
+                       else:
+                               str += "\t'%s' : %s\n" % (k, self.__dict__[k])
+               str += "\t--"
+               return str
+
+       def delField(self, field):
+               if field in self.__dict__:
+                       del self.__dict__[field]
+               
+               if field in self.keys:
+                       for i in range(0,len(self.keys)):
+                               if self.keys[i] == field:
+                                       del self.keys[i]
+                                       break
+
+class LogRoll:
+       def __init__(self, list=None):
+               self.list = list
+               if self.list == None:
+                       self.list = {}
+
+       def find(self, host, filter, timerange):
+               if host not in self.list:
+                       return None
+
+               host_log_list = self.list[host]
+               for log in host_log_list:
+                       for key in filter.keys():
+                               #print "searching key %s in log keys" % key
+                               if key in log.keys:
+                                       #print "%s in log.keys" % key
+                                       cmp = re.compile(filter[key])
+                                       res = cmp.search(log.__getattribute__(key))
+                                       if res != None:
+                                               #print "found match in log: %s  %s ~=~ %s" % (log, key, filter[key])
+                                               if log.time > time.time() - timerange:
+                                                       print "returning log b/c it occured within time."
+                                                       return log
+               return None
+               
+
+       def get(self, host):
+               if host in self.list:
+                       return self.list[host][0]
+               else:
+                       return None
+
+       def add(self, log):
+               if log.host not in self.list:
+                       self.list[log.host] = []
+
+               self.list[log.host].insert(0,log)
+
+class Log(MonRecord):
+       def __init__(self, host, data):
+               self.host = host
+               MonRecord.__init__(self, data)
+               return
+
+       def __repr__(self):
+               str = " "
+               str += self.host + " : { "
+               for k in self.keys:
+                       if "message" in k or "msg" in k:
+                               continue
+                       if 'time' in k:
+                               s_time=time.strftime("%Y/%m/%d %H:%M:%S", 
+                                                       time.gmtime(self.__dict__[k]))
+                               #str += " '%s' : %s, " % (k, s_time)
+                       elif 'action' in k:
+                               str += "'%s' : %s, " % (k, self.__dict__[k])
+               str += "}"
+               return str
+       
+
+class Diagnose(MonRecord):
+       def __init__(self, host):
+               self.host = host
+               MonRecord.__init__(self, data)
+               return
+
+
+
+if __name__ == "__main__":
+       #r = RT()
+       #r.email("test", "body of test message", ['database@cs.princeton.edu'])
+       #from emailTxt import mailtxt
+       print "loaded"
+       #database.dbDump("persistmessages", {});
+       #args = {'url_list': 'http://www.planet-lab.org/bootcds/planet1.usb\n','hostname': 'planet1','hostname_list': ' blahblah -  days down\n'}
+       #m = PersistMessage("blue", "test 1", mailtxt.newdown_one[1] % args, True)
+       #m.send(['soltesz@cs.utk.edu'])
+       #m = PersistMessage("blue", "test 1 - part 2", mailtxt.newalphacd_one[1] % args, True)
+       # TRICK timer to thinking some time has passed.
+       #m.actiontracker.time = time.time() - 6*60*60*24
+       #m.send(['soltesz@cs.utk.edu'])