Ported code from create_topo_attributes.py
[sfa.git] / sfa / rspecs / aggregates / rspec_manager_vini.py
1 from sfa.util.faults import *
2 from sfa.util.misc import *
3 from sfa.util.rspec import Rspec
4 from sfa.server.registry import Registries
5 from sfa.plc.nodes import *
6 import sys
7 import socket
8
9 SFA_VINI_DEFAULT_RSPEC = '/etc/sfa/vini.rspec'
10
11 class Node:
12     def __init__(self, node):
13         self.id = node['node_id']
14         self.hostname = node['hostname']
15         self.shortname = self.hostname.replace('.vini-veritas.net', '')
16         self.site_id = node['site_id']
17         self.ipaddr = socket.gethostbyname(self.hostname)
18
19     def get_link_id(self, remote):
20         if self.id < remote.id:
21             link = (self.id<<7) + remote.id
22         else:
23             link = (remote.id<<7) + self.id
24         return link
25         
26     def get_iface_id(self, remote):
27         if self.id < remote.id:
28             iface = 1
29         else:
30             iface = 2
31         return iface
32     
33     def get_virt_ip(self, remote):
34         link = self.get_link_id(remote)
35         iface = self.get_iface_id(remote)
36         first = link >> 6
37         second = ((link & 0x3f)<<2) + iface
38         return "192.168.%d.%d" % (first, second)
39
40     def get_virt_net(self, remote):
41         link = self.get_link_id(remote)
42         first = link >> 6
43         second = (link & 0x3f)<<2
44         return "192.168.%d.%d/30" % (first, second)
45         
46     def get_site(self, sites):
47         return sites[self.site_id]
48             
49     def adjacent_nodes(self, sites, nodes, node_ids):
50         mysite = self.get_site(sites)
51         adj_ids = mysite.adj_node_ids.intersection(node_ids)
52         adj_nodes = []
53         for id in adj_ids:
54             adj_nodes.append(nodes[id])
55         return adj_nodes
56     
57     def init_rspecs(self):
58         self.rspecs = []
59         
60     def add_rspec(self, remote):
61         my_ip = self.get_virt_ip(remote)
62         remote_ip = remote.get_virt_ip(self)
63         net = self.get_virt_net(remote)
64         rspec = remote.id, remote.ipaddr, "1Mbit", my_ip, remote_ip, net
65         self.rspecs.append(rspec)
66
67         
68 class Site:
69     def __init__(self, site):
70         self.id = site['site_id']
71         self.node_ids = site['node_ids']
72         self.adj_site_ids = set()
73         self.adj_node_ids = set()
74
75     def get_sitenodes(self, nodes):
76         n = []
77         for i in self.node_ids:
78             n.append(nodes[i])
79         return n
80     
81     def add_adjacency(self, site):
82         self.adj_site_ids.add(site.id)
83         for n in site.node_ids:
84             self.adj_node_ids.add(n)
85         
86     
87 class Slice:
88     def __init__(self, slice):
89         self.id = slice['slice_id']
90         self.name = slice['name']
91         self.node_ids = set(slice['node_ids'])
92         self.slice_tag_ids = slice['slice_tag_ids']
93     
94     def get_tag(self, tagname, slicetags, node = None):
95         for i in self.slice_tag_ids:
96             tag = slicetags[i]
97             if tag.tagname == tagname:
98                 if (not node) or (node.id == tag.node_id):
99                     return tag
100         else:
101             return None
102         
103     def get_nodes(self, nodes):
104         n = []
105         for id in self.node_ids:
106             n.append(nodes[id])
107         return n
108              
109     
110     # Add a new slice tag   
111     def add_tag(self, tagname, value, slicetags, node = None):
112         record = {'slice_tag_id':None, 'slice_id':self.id, 'tagname':tagname, 'value':value}
113         if node:
114             record['node_id'] = node.id
115         else:
116             record['node_id'] = None
117         tag = Slicetag(record)
118         slicetags[tag.id] = tag
119         self.slice_tag_ids.append(tag.id)
120         tag.changed = True       
121         tag.updated = True
122         return tag
123     
124     # Update a slice tag if it exists, else add it             
125     def update_tag(self, tagname, value, slicetags, node = None):
126         tag = self.get_tag(tagname, slicetags, node)
127         if tag and tag.value == value:
128             value = "no change"
129         elif tag:
130             tag.value = value
131             tag.changed = True
132         else:
133             tag = self.add_tag(tagname, value, slicetags, node)
134         tag.updated = True
135             
136     def assign_egre_key(self, slicetags):
137         if not self.get_tag('egre_key', slicetags):
138             try:
139                 key = free_egre_key(slicetags)
140                 self.update_tag('egre_key', key, slicetags)
141             except:
142                 # Should handle this case...
143                 pass
144         return
145             
146     def turn_on_netns(self, slicetags):
147         tag = self.get_tag('netns', slicetags)
148         if (not tag) or (tag.value != '1'):
149             self.update_tag('netns', '1', slicetags)
150         return
151    
152     def turn_off_netns(self, slicetags):
153         tag = self.get_tag('netns', slicetags)
154         if tag and (tag.value != '0'):
155             tag.delete()
156         return
157     
158     def add_cap_net_admin(self, slicetags):
159         tag = self.get_tag('capabilities', slicetags)
160         if tag:
161             caps = tag.value.split(',')
162             for cap in caps:
163                 if cap == "CAP_NET_ADMIN":
164                     return
165             else:
166                 newcaps = "CAP_NET_ADMIN," + tag.value
167                 self.update_tag('capabilities', newcaps, slicetags)
168         else:
169             self.add_tag('capabilities', 'CAP_NET_ADMIN', slicetags)
170         return
171     
172     def remove_cap_net_admin(self, slicetags):
173         tag = self.get_tag('capabilities', slicetags)
174         if tag:
175             caps = tag.value.split(',')
176             newcaps = []
177             for cap in caps:
178                 if cap != "CAP_NET_ADMIN":
179                     newcaps.append(cap)
180             if newcaps:
181                 value = ','.join(newcaps)
182                 self.update_tag('capabilities', value, slicetags)
183             else:
184                 tag.delete()
185         return
186
187     # Update the vsys/setup-link and vsys/setup-nat slice tags.
188     def add_vsys_tags(self, slicetags):
189         link = nat = False
190         for i in self.slice_tag_ids:
191             tag = slicetags[i]
192             if tag.tagname == 'vsys':
193                 if tag.value == 'setup-link':
194                     link = True
195                 elif tag.value == 'setup-nat':
196                     nat = True
197         if not link:
198             self.add_tag('vsys', 'setup-link', slicetags)
199         if not nat:
200             self.add_tag('vsys', 'setup-nat', slicetags)
201         return
202
203
204 class Slicetag:
205     newid = -1 
206     def __init__(self, tag):
207         self.id = tag['slice_tag_id']
208         if not self.id:
209             # Make one up for the time being...
210             self.id = Slicetag.newid
211             Slicetag.newid -= 1
212         self.slice_id = tag['slice_id']
213         self.tagname = tag['tagname']
214         self.value = tag['value']
215         self.node_id = tag['node_id']
216         self.updated = False
217         self.changed = False
218         self.deleted = False
219     
220     # Mark a tag as deleted
221     def delete(self):
222         self.deleted = True
223         self.updated = True
224     
225     def write(self, slices, nodes, dryrun):
226         if not dryrun:
227             if self.changed:
228                 if int(self.id) > 0:
229                     UpdateSliceTag(self.id, self.value)
230                 else:
231                     AddSliceTag(self.slice_id, self.tagname, self.value, self.node_id)
232             elif self.deleted and int(self.id) > 0:
233                 DeleteSliceTag(self.id)
234         else:
235             slice = slices[self.slice_id].name
236             if self.node_id:
237                 node = nodes[tag.node_id].hostname
238             if self.updated:
239                 if self.deleted:
240                     self.value = "deleted"
241                 elif not self.changed:
242                     self.value = "no change"
243                 if int(self.id) < 0:
244                     self.id = "new"
245                 if self.node_id:
246                     print "[%s] %s: %s (%s, %s)" % (self.id, self.tagname, self.value, slice, node)
247                 else:
248                     print "[%s] %s: %s (%s)" % (self.id, self.tagname, self.value, slice)
249
250
251 """
252 Create a dictionary of site objects keyed by site ID
253 """
254 def get_sites():
255     tmp = []
256     for site in GetSites():
257         t = site['site_id'], Site(site)
258         tmp.append(t)
259     return dict(tmp)
260
261
262 """
263 Create a dictionary of node objects keyed by node ID
264 """
265 def get_nodes(api):
266     tmp = []
267     for node in api.plshell.GetNodes(api.plauth):
268         t = node['node_id'], Node(node)
269         tmp.append(t)
270     return dict(tmp)
271
272 """
273 Create a dictionary of slice objects keyed by slice ID
274 """
275 def get_slice(api, slicename):
276     slice = api.plshell.GetSlices(api.plauth, [slicename])
277     if slice:
278         return Slice(slice[0])
279     else:
280         return None
281
282 """
283 Create a dictionary of slicetag objects keyed by slice tag ID
284 """
285 def get_slice_tags(api):
286     tmp = []
287     for tag in api.plshell.GetSliceTags(api.plauth):
288         t = tag['slice_tag_id'], Slicetag(tag)
289         tmp.append(t)
290     return dict(tmp)
291     
292 """
293 Find a free EGRE key
294 """
295 def free_egre_key(slicetags):
296     used = set()
297     for i in slicetags:
298         tag = slicetags[i]
299         if tag.tagname == 'egre_key':
300             used.add(int(tag.value))
301                 
302     for i in range(1, 256):
303         if i not in used:
304             key = i
305             break
306     else:
307         raise KeyError("No more EGRE keys available")
308         
309     return "%s" % key
310    
311
312 """
313 Copied from create_slice_aggregate() in sfa.plc.slices
314 """
315 def create_slice_vini_aggregate(api, hrn, nodes):
316     # Get the slice record from geni
317     slice = {}
318     registries = Registries(api)
319     registry = registries[api.hrn]
320     credential = api.getCredential()
321     records = registry.resolve(credential, hrn)
322     for record in records:
323         if record.get_type() in ['slice']:
324             slice = record.as_dict()
325     if not slice:
326         raise RecordNotFound(hrn)   
327
328     # Make sure slice exists at plc, if it doesnt add it
329     slicename = hrn_to_pl_slicename(hrn)
330     slices = api.plshell.GetSlices(api.plauth, [slicename], ['node_ids'])
331     if not slices:
332         parts = slicename.split("_")
333         login_base = parts[0]
334         # if site doesnt exist add it
335         sites = api.plshell.GetSites(api.plauth, [login_base])
336         if not sites:
337             authority = get_authority(hrn)
338             site_records = registry.resolve(credential, authority)
339             site_record = {}
340             if not site_records:
341                 raise RecordNotFound(authority)
342             site_record = site_records[0]
343             site = site_record.as_dict()
344                 
345             # add the site
346             site.pop('site_id')
347             site_id = api.plshell.AddSite(api.plauth, site)
348         else:
349             site = sites[0]
350             
351         slice_fields = {}
352         slice_keys = ['name', 'url', 'description']
353         for key in slice_keys:
354             if key in slice and slice[key]:
355                 slice_fields[key] = slice[key]  
356         api.plshell.AddSlice(api.plauth, slice_fields)
357         slice = slice_fields
358         slice['node_ids'] = 0
359     else:
360         slice = slices[0]    
361
362     # get the list of valid slice users from the registry and make 
363     # they are added to the slice 
364     researchers = record.get('researcher', [])
365     for researcher in researchers:
366         person_record = {}
367         person_records = registry.resolve(credential, researcher)
368         for record in person_records:
369             if record.get_type() in ['user']:
370                 person_record = record
371         if not person_record:
372             pass
373         person_dict = person_record.as_dict()
374         persons = api.plshell.GetPersons(api.plauth, [person_dict['email']],
375                                          ['person_id', 'key_ids'])
376
377         # Create the person record 
378         if not persons:
379             person_id=api.plshell.AddPerson(api.plauth, person_dict)
380
381             # The line below enables the user account on the remote aggregate
382             # soon after it is created.
383             # without this the user key is not transfered to the slice
384             # (as GetSlivers returns key of only enabled users),
385             # which prevents the user from login to the slice.
386             # We may do additional checks before enabling the user.
387
388             api.plshell.UpdatePerson(api.plauth, person_id, {'enabled' : True})
389             key_ids = []
390         else:
391             key_ids = persons[0]['key_ids']
392
393         api.plshell.AddPersonToSlice(api.plauth, person_dict['email'],
394                                      slicename)        
395
396         # Get this users local keys
397         keylist = api.plshell.GetKeys(api.plauth, key_ids, ['key'])
398         keys = [key['key'] for key in keylist]
399
400         # add keys that arent already there 
401         for personkey in person_dict['keys']:
402             if personkey not in keys:
403                 key = {'key_type': 'ssh', 'key': personkey}
404                 api.plshell.AddPersonKey(api.plauth, person_dict['email'], key)
405
406     # find out where this slice is currently running
407     nodelist = api.plshell.GetNodes(api.plauth, slice['node_ids'],
408                                     ['hostname'])
409     hostnames = [node['hostname'] for node in nodelist]
410
411     # remove nodes not in rspec
412     deleted_nodes = list(set(hostnames).difference(nodes))
413     # add nodes from rspec
414     added_nodes = list(set(nodes).difference(hostnames))
415
416     api.plshell.AddSliceToNodes(api.plauth, slicename, added_nodes) 
417     api.plshell.DeleteSliceFromNodes(api.plauth, slicename, deleted_nodes)
418
419     return 1
420
421 def get_rspec(api, hrn):
422     # Get default rspec
423     default = Rspec()
424     default.parseFile(SFA_VINI_DEFAULT_RSPEC)
425     
426     if (hrn):
427         slicename = hrn_to_pl_slicename(hrn)
428         defaultrspec = default.toGenDict()
429         nodedict = get_nodedict(defaultrspec)
430
431         # call the default sfa.plc.nodes.get_rspec() method
432         nodes = Nodes(api)      
433         rspec = nodes.get_rspec(hrn)     
434
435         # Grab all the PLC info we'll need at once
436         slice = get_slice(api, slicename)
437         if slice:
438             nodes = get_nodes(api)
439             tags = get_slice_tags(api)
440
441             # Add the node tags from the Capacity statement to Node objects
442             for (k, v) in nodedict.iteritems():
443                 for id in nodes:
444                     if v == nodes[id].hostname:
445                         nodes[id].tag = k
446
447             endpoints = []
448             for node in slice.get_nodes(nodes):
449                 linktag = slice.get_tag('topo_rspec', tags, node)
450                 if linktag:
451                     l = eval(linktag.value)
452                     for (id, realip, bw, lvip, rvip, vnet) in l:
453                         endpoints.append((node.id, id, bw))
454             
455             if endpoints:
456                 linkspecs = []
457                 for (l, r, bw) in endpoints:
458                     if (r, l, bw) in endpoints:
459                         if l < r:
460                             edict = {}
461                             edict['endpoint'] = [nodes[l].tag, nodes[r].tag]
462                             edict['bw'] = [bw]
463                             linkspecs.append(edict)
464
465                 d = default.toDict()
466                 d['Rspec']['Request'][0]['NetSpec'][0]['LinkSpec'] = linkspecs
467                 d['Rspec']['Request'][0]['NetSpec'][0]['name'] = hrn
468                 new = Rspec()
469                 new.parseDict(d)
470                 rspec = new.toxml()
471     else:
472         # Return canned response for now...
473         rspec = default.toxml()
474
475     return rspec
476
477
478 def create_slice(api, hrn, xml):
479     r = Rspec()
480     r.parseString(xml)
481     rspec = r.toGenDict()
482
483     # Check request against current allocations
484     # Request OK
485
486     nodes = rspec_to_nodeset(rspec)
487     create_slice_vini_aggregate(api, hrn, nodes)
488
489     # Add VINI-specific topology attributes to slice here
490
491     return True
492
493 def get_nodedict(rspec):
494     nodedict = {}
495     try:    
496         sitespecs = rspec['Rspec'][0]['Capacity'][0]['NetSpec'][0]['SiteSpec']
497         for s in sitespecs:
498             for node in s['NodeSpec']:
499                 nodedict[node['name'][0]] = node['hostname'][0]
500     except KeyError:
501         pass
502
503     return nodedict
504
505         
506 def rspec_to_nodeset(rspec):
507     nodes = set()
508     try:
509         nodedict = get_nodedict(rspec)
510         linkspecs = rspec['Rspec'][0]['Request'][0]['NetSpec'][0]['LinkSpec']
511         for l in linkspecs:
512             for e in l['endpoint']:
513                 nodes.add(nodedict[e])
514     except KeyError:
515         # Bad Rspec
516         pass
517     
518     return nodes
519
520 def main():
521     r = Rspec()
522     r.parseFile(sys.argv[1])
523     rspec = r.toGenDict()
524     create_slice(None,'plc',rspec)
525     
526 if __name__ == "__main__":
527     main()