Support to add the ssh root pub key to the default administrator. This
[myplc.git] / db-config
1 #!/usr/bin/env /usr/bin/plcsh
2 #
3 # Bootstraps the PLC database with a default administrator account and
4 # a default site, defines default slice attribute types, and
5 # creates/updates default system slices.
6 #
7 # Mark Huang <mlhuang@cs.princeton.edu>
8 # Copyright (C) 2006 The Trustees of Princeton University
9 #
10 # $Id$
11 # $HeadURL$
12
13 from plc_config import PLCConfiguration
14 import sys, os
15 import resource
16
# Module-level store for this MyPLC's URL, read and written only through
# the two accessor functions below.
g_url = ""


def GetMyPLCURL():
    """Return the URL last stored by SetMyPLCURL(), or "" if none was set."""
    return g_url


def SetMyPLCURL(url):
    """Store *url* as the value later returned by GetMyPLCURL()."""
    global g_url
    g_url = url
22
# Sorted names of the tag types returned by GetTagTypes(); SetTagType()
# appends newly created names and re-sorts.
g_known_tag_types = sorted(known['tagname'] for known in GetTagTypes())


def SetTagType(tag_type):
    """Create *tag_type* if its 'tagname' is unknown, otherwise update it.

    tag_type -- dict handed unchanged to AddTagType() / UpdateTagType();
                must contain a 'tagname' key.
    """
    global g_known_tag_types
    name = tag_type['tagname']
    if name in g_known_tag_types:
        UpdateTagType(name, tag_type)
        return
    AddTagType(tag_type)
    g_known_tag_types.append(name)
    g_known_tag_types.sort()
36
# Existing conf files that are enabled and global (attached to no node and
# no node group), finally indexed by their 'dest' path. A list is built
# (rather than a lazy filter) because it is iterated twice below.
g_conf_files = [conf_file for conf_file in GetConfFiles()
                if conf_file['enabled']
                and not conf_file['node_ids']
                and not conf_file['nodegroup_ids']]
g_dests = [conf_file['dest'] for conf_file in g_conf_files]
g_conf_files = dict(zip(g_dests, g_conf_files))
45
# Existing initscripts: g_oldinitscript_names lists their names in the
# order GetInitScripts() returned them; g_oldinitscripts maps each name to
# its record (a later duplicate name overwrites an earlier one).
g_oldinitscript_names = []
g_oldinitscripts = {}
for script in GetInitScripts():
    g_oldinitscript_names.append(script['name'])
    g_oldinitscripts[script['name']] = script
50
def SetInitScript(initscript):
    """Create or update an initscript, identified by its 'name'.

    An unknown script is created with AddInitScript(). The returned id is
    stored into *initscript* under 'initscript_id', and the dict is cached
    in g_oldinitscripts. A known script is updated through
    UpdateInitScript() using the cached record's 'initscript_id'.
    """
    global g_oldinitscripts, g_oldinitscript_names
    name = initscript['name']
    if name in g_oldinitscript_names:
        known_id = g_oldinitscripts[name]['initscript_id']
        UpdateInitScript(known_id, initscript)
    else:
        new_id = AddInitScript(initscript)
        g_oldinitscript_names.append(name)
        initscript['initscript_id'] = new_id
        g_oldinitscripts[name] = initscript
62         
def SetConfFile(conf_file):
    """Create or update a global conf file, identified by its 'dest' path.

    If no cached conf file has the same 'dest', the file is created with
    AddConfFile() and recorded in g_dests / g_conf_files. A later call for
    the same 'dest' then updates it instead of adding a duplicate.
    Otherwise the cached record's 'conf_file_id' is passed to
    UpdateConfFile(). The caller's dict is not modified.
    """
    global g_conf_files, g_dests
    dest = conf_file['dest']
    if dest not in g_dests:
        # NOTE(review): assumes AddConfFile() returns the new conf_file_id,
        # as AddInitScript()'s return value is used as the new id — confirm.
        conf_file_id = AddConfFile(conf_file)
        cached = dict(conf_file)
        cached['conf_file_id'] = conf_file_id
        g_dests.append(dest)
        g_conf_files[dest] = cached
    else:
        conf_file_id = g_conf_files[dest]['conf_file_id']
        UpdateConfFile(conf_file_id, conf_file)
71
def SetSlice(slice, tags):
    """Create or update a slice, then make its slice-wide tags match *tags*.

    slice -- dict of slice fields; must contain 'name'. An 'expires' value
             is applied with a separate UpdateSlice() call after the slice
             is newly created. The caller's dict is not modified.
    tags  -- list of (tagname, value) pairs that should be the slice's
             complete set of slice-wide tags.

    Existing slice-wide tags (node_id None) that are not in *tags* are
    deleted, and pairs from *tags* that are missing are added. Sliver tags
    (tags with a node_id) are left untouched.
    """
    slice_name = slice['name']
    # Work on a copy so the caller's dict keeps its 'name' and 'expires'.
    fields = dict(slice)

    existing = GetSlices([slice_name])
    if len(existing) == 1:
        slice_id = existing[0]['slice_id']
        # The slice is addressed by id; 'name' is not sent as an update field.
        del fields['name']
        UpdateSlice(slice_id, fields)
    else:
        # 'expires' is withheld from AddSlice() and set right after creation.
        expires = fields.pop('expires', None)
        slice_id = AddSlice(fields)
        if expires is not None:
            UpdateSlice(slice_id, {'expires': expires})

    # Re-read the slice to get all fields, including slice_tag_ids.
    record = GetSlices([slice_name])[0]

    # Pairs already present on the slice that are also requested.
    # NOTE: update is not needed, since unspecified tags are deleted
    #       and new tags are added.
    kept = []
    if record['slice_tag_ids']:
        for slice_tag in GetSliceTags(record['slice_tag_ids']):
            # Ignore sliver tags, as those are custom/run-time values.
            if slice_tag['node_id'] is not None:
                continue
            pair = (slice_tag['tagname'], slice_tag['value'])
            if pair in tags:
                kept.append(pair)
            else:
                DeleteSliceTag(slice_tag['slice_tag_id'])

    # Add only the requested tags that are not present yet. Record each
    # addition so a pair listed twice in *tags* is added only once.
    for (name, value) in tags:
        if (name, value) not in kept:
            AddSliceTag(slice_name, name, value)
            kept.append((name, value))
116
def SetMessage(message):
    """Add *message*, or update the stored one with the same 'message_id'."""
    message_id = message['message_id']
    if GetMessages([message_id]):
        UpdateMessage(message_id, message)
    else:
        AddMessage(message)
123
# Models of the PCU types returned by GetPCUTypes(). A generator expression
# (rather than a list comprehension) keeps the loop variable out of module
# scope under Python 2. Otherwise it would stay visible to later module code,
# including the snippets main() runs with execfile().
g_pcu_models = list(pcu['model'] for pcu in GetPCUTypes())


def SetPCUType(pcu_type):
    """Add a PCU type, and its protocol types, if its 'model' is new.

    pcu_type -- dict of fields passed to AddPCUType(). It may carry an
                optional 'pcu_protocol_types' list; each entry is passed to
                AddPCUProtocolType() together with the new PCU type's id.

    A PCU type whose model is already known is left untouched. The caller's
    dict is not modified.
    """
    global g_pcu_models
    fields = dict(pcu_type)
    # 'pcu_protocol_types' is taken out of the struct before AddPCUType();
    # each protocol is added separately once the type exists.
    protocol_types = fields.pop('pcu_protocol_types', [])

    if fields['model'] in g_pcu_models:
        return

    pcu_type_id = AddPCUType(fields)
    # Remember the model so a repeated call does not add it again.
    g_pcu_models.append(fields['model'])
    for ptype in protocol_types:
        AddPCUProtocolType(pcu_type_id, ptype)
142
def _snippet_number(filename, ignored):
    """Return the leading integer of snippet *filename*, or None to skip it."""
    for suffix in ignored:
        if filename.endswith(suffix):
            return None
    parts = filename.split('-')
    if len(parts) < 2:
        return None
    try:
        return int(parts[0])
    except ValueError:
        return None


def GetSnippets(directory):
    """Return the db-config snippet file names in *directory*, in run order.

    A snippet name must look like "<number>-<rest>" (e.g. "10-foo").
    Snippets are ordered by that leading integer, then by file name among
    snippets sharing a number. Editor, backup and packaging leftovers
    (".bak", "~", ".rpmsave", ".rpmnew", ".orig") and names that do not
    match the pattern are skipped with a message on stdout.

    A directory that does not exist yields an empty list.

    Raises:
        Exception: if *directory* exists but cannot be listed.
    """
    filenames = []
    if os.path.exists(directory):
        try:
            filenames = os.listdir(directory)
        except OSError:
            # sys.exc_info() avoids version-specific "except ... , e" syntax.
            error = sys.exc_info()[1]
            raise Exception("Error when opening %s (%s)" % (directory, error))

    ignored = (".bak", "~", ".rpmsave", ".rpmnew", ".orig")
    numbered = []  # (number, filename) pairs
    for filename in filenames:
        number = _snippet_number(filename, ignored)
        if number is None:
            print("db-config: ignoring %s snippet" % filename)
        else:
            numbered.append((number, filename))

    # Sort by number, then by name, so ties do not depend on listdir order.
    numbered.sort()
    return [filename for (number, filename) in numbered]
185
def main():
    """Load the PLC configuration, then run every db-config snippet in order.

    Each configuration category becomes a module-level global dict named
    after its category id. Then each file returned by
    GetSnippets("/etc/planetlab/db-config.d") is executed with execfile().
    """
    cfg = PLCConfiguration()
    cfg.load()
    variables = cfg.variables()

    # Load variables into dictionaries
    # Each category id is bound as a global mapping the keys of that
    # category's variable list to the variables' 'value' fields —
    # presumably so snippets can read configuration through these globals.
    for category_id, (category, variablelist) in variables.iteritems():
        globals()[category_id] = dict(zip(variablelist.keys(),
                                          [variable['value'] for variable in variablelist.values()]))

    directory="/etc/planetlab/db-config.d"
    snippets = GetSnippets(directory)
    for snippet in snippets:
        fullpath = os.path.join(directory, snippet)
        # execfile() with no explicit namespaces runs the snippet with this
        # module's globals and main()'s locals. The snippet can therefore
        # call the Set* helpers and read the category globals set above.
        # NOTE(review): main()'s local names (cfg, variables, directory, ...)
        # are also visible to snippets; renaming them may break snippets.
        execfile(fullpath)

# Run only when executed as a script (the shebang uses plcsh), not on import.
if __name__ == '__main__':
    main()
204
205 # Local variables:
206 # tab-width: 4
207 # mode: python
208 # End: