c89f1d88d514df2f456ba1835bd24255da0e7e88
[nepi.git] / src / nepi / util / proxy.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import nepi.core.execute
6 from nepi.core.attributes import AttributesMap, Attribute
7 from nepi.util import server, validation
8 from nepi.util.constants import TIME_NOW, ATTR_NEPI_TESTBED_ENVIRONMENT_SETUP, DeploymentConfiguration as DC
9 import getpass
10 import cPickle
11 import sys
12 import time
13 import tempfile
14 import shutil
15 import functools
16
# PROTOCOL REPLIES
# First field of every reply string ("<code>|<payload>").
OK = 0
ERROR = 1

# PROTOCOL INSTRUCTION MESSAGES
# First field of every command string ("<code>|<arg>|<arg>...").
# NOTE: value 3 is not assigned to any instruction.
XML = 2
TRACE = 4
FINISHED = 5
START = 6
STOP = 7
SHUTDOWN = 8
CONFIGURE = 9
CREATE = 10
CREATE_SET = 11
FACTORY_SET = 12
CONNECT = 13
CROSS_CONNECT = 14
ADD_TRACE = 15
ADD_ADDRESS = 16
ADD_ROUTE = 17
DO_SETUP = 18
DO_CREATE = 19
DO_CONNECT_INIT = 20
DO_CONFIGURE = 21
DO_CROSS_CONNECT_INIT = 22
GET = 23
SET = 24
ACTION = 25
STATUS = 26
GUIDS = 27
GET_ROUTE = 28
GET_ADDRESS = 29
RECOVER = 30
DO_PRECONFIGURE = 31
GET_ATTRIBUTE_LIST = 32
DO_CONNECT_COMPL = 33
DO_CROSS_CONNECT_COMPL = 34
TESTBED_ID = 35
TESTBED_VERSION = 36
DO_PRESTART = 37
GET_FACTORY_ID = 38
GET_TESTBED_ID = 39
GET_TESTBED_VERSION = 40
TRACES_INFO = 41
EXEC_XML = 42

# Human-readable name of every protocol code above; only used for logging.
# Every code must have an entry here: a missing one makes the log helpers
# hit a KeyError and drop the log line for that instruction.
instruction_text = {
    OK: "OK",
    ERROR: "ERROR",
    XML: "XML",
    EXEC_XML: "EXEC_XML",
    TRACE: "TRACE",
    FINISHED: "FINISHED",
    START: "START",
    STOP: "STOP",
    RECOVER: "RECOVER",
    SHUTDOWN: "SHUTDOWN",
    CONFIGURE: "CONFIGURE",
    CREATE: "CREATE",
    CREATE_SET: "CREATE_SET",
    FACTORY_SET: "FACTORY_SET",
    CONNECT: "CONNECT",
    CROSS_CONNECT: "CROSS_CONNECT",
    ADD_TRACE: "ADD_TRACE",
    ADD_ADDRESS: "ADD_ADDRESS",
    ADD_ROUTE: "ADD_ROUTE",
    DO_SETUP: "DO_SETUP",
    DO_CREATE: "DO_CREATE",
    DO_CONNECT_INIT: "DO_CONNECT_INIT",
    DO_CONNECT_COMPL: "DO_CONNECT_COMPL",
    DO_CONFIGURE: "DO_CONFIGURE",
    DO_PRECONFIGURE: "DO_PRECONFIGURE",
    DO_PRESTART: "DO_PRESTART",
    DO_CROSS_CONNECT_INIT: "DO_CROSS_CONNECT_INIT",
    DO_CROSS_CONNECT_COMPL: "DO_CROSS_CONNECT_COMPL",
    GET: "GET",
    SET: "SET",
    GET_ROUTE: "GET_ROUTE",
    GET_ADDRESS: "GET_ADDRESS",
    GET_ATTRIBUTE_LIST: "GET_ATTRIBUTE_LIST",
    GET_FACTORY_ID: "GET_FACTORY_ID",
    GET_TESTBED_ID: "GET_TESTBED_ID",
    GET_TESTBED_VERSION: "GET_TESTBED_VERSION",
    ACTION: "ACTION",
    STATUS: "STATUS",
    GUIDS: "GUIDS",
    TESTBED_ID: "TESTBED_ID",
    TESTBED_VERSION: "TESTBED_VERSION",
    TRACES_INFO: "TRACES_INFO",
}
106
def log_msg(server, params):
    """
    Log an incoming protocol message at debug level.

    params is the message already split into fields: params[0] is the
    numeric instruction code and the remaining fields are its arguments.
    The line is written through server.log_debug.

    Logging is best effort: any ordinary error (unknown code, malformed
    field, logger failure) is ignored.
    """
    try:
        instr = int(params[0])
        instr_txt = instruction_text[instr]
        server.log_debug("%s - msg: %s [%s]" % (server.__class__.__name__,
            instr_txt, ", ".join(map(str, params[1:]))))
    except Exception:
        # don't die for logging - but let KeyboardInterrupt/SystemExit
        # through, which a bare except would swallow
        pass
116
def log_reply(server, reply):
    """
    Log an outgoing protocol reply at debug level.

    reply is the raw "<code>|<payload>" string. The payload is shown
    base64-decoded when it decodes, verbatim otherwise. If the reply cannot
    be interpreted at all, the raw reply is logged instead.
    """
    try:
        res = reply.split("|")
        code = int(res[0])
        code_txt = instruction_text[code]
        try:
            txt = base64.b64decode(res[1])
        except Exception:
            # payload is not base64 (eg: plain integers), show it as is
            txt = res[1]
        server.log_debug("%s - reply: %s %s" % (server.__class__.__name__,
                code_txt, txt))
    except Exception:
        # don't die for logging - fall back to the raw reply.
        # KeyboardInterrupt/SystemExit are deliberately not caught here.
        server.log_debug("%s - reply: %s" % (server.__class__.__name__,
                reply))
133
def to_server_log_level(log_level):
    """
    Translate a deployment configuration log level into a server log level.

    Only DC.DEBUG_LEVEL maps to server.DEBUG_LEVEL; every other value
    maps to server.ERROR_LEVEL.
    """
    if log_level == DC.DEBUG_LEVEL:
        return server.DEBUG_LEVEL
    return server.ERROR_LEVEL
140
def get_access_config_params(access_config):
    """
    Extract the connection parameters stored in an access configuration.

    Returns the tuple
    (root_dir, log_level, user, host, port, key, agent, environment_setup),
    where log_level is already converted to a server log level.
    user, host, port, key and agent are only read when the configured
    communication is DC.ACCESS_SSH and are None otherwise;
    environment_setup is None when the attribute is not present.
    """
    value_of = access_config.get_attribute_value

    root_dir = value_of(DC.ROOT_DIRECTORY)
    log_level = to_server_log_level(value_of(DC.LOG_LEVEL))
    communication = value_of(DC.DEPLOYMENT_COMMUNICATION)

    if access_config.has_attribute(DC.DEPLOYMENT_ENVIRONMENT_SETUP):
        environment_setup = value_of(DC.DEPLOYMENT_ENVIRONMENT_SETUP)
    else:
        environment_setup = None

    if communication == DC.ACCESS_SSH:
        user = value_of(DC.DEPLOYMENT_USER)
        host = value_of(DC.DEPLOYMENT_HOST)
        port = value_of(DC.DEPLOYMENT_PORT)
        agent = value_of(DC.USE_AGENT)
        key = value_of(DC.DEPLOYMENT_KEY)
    else:
        user = host = port = agent = key = None

    return (root_dir, log_level, user, host, port, key, agent, environment_setup)
159
class AccessConfiguration(AttributesMap):
    """
    Attribute map holding the deployment attributes declared in
    Metadata.DEPLOYMENT_ATTRIBUTES.

    params, when given, is a mapping attribute name -> value; each value
    is run through the parser registered for the attribute's type before
    being stored.
    """
    def __init__(self, params = None):
        super(AccessConfiguration, self).__init__()

        # NOTE(review): imported here rather than at module level,
        # presumably to avoid a circular import - confirm before moving.
        from nepi.core.metadata import Metadata

        # Declare every known deployment attribute
        for _name, attr_info in Metadata.DEPLOYMENT_ATTRIBUTES.iteritems():
            self.add_attribute(**attr_info)

        if not params:
            return

        # Apply the initial values supplied by the caller
        for attr_name, raw_value in params.iteritems():
            attr_type = self.get_attribute_type(attr_name)
            parse = Attribute.type_parsers[attr_type]
            self.set_attribute_value(attr_name, parse(raw_value))
174
class TempDir(object):
    """
    Owner of a freshly created temporary directory.

    The directory is created on construction (its location is available
    as the 'path' attribute) and removed, with all its contents, when the
    object is finalized.
    """
    def __init__(self):
        self.path = tempfile.mkdtemp()

    def __del__(self):
        # 'path' is missing if mkdtemp() raised in __init__
        path = getattr(self, 'path', None)
        if path is not None:
            # Finalizers must not raise: the directory may already have
            # been removed by someone else, so removal errors are ignored.
            shutil.rmtree(path, ignore_errors=True)
181
class PermDir(object):
    """
    Plain holder for a caller-supplied directory path.

    Exposes the same 'path' attribute as TempDir, but neither creates
    nor removes anything.
    """
    def __init__(self, path):
        self.path = path
185
def create_experiment_controller(xml, access_config = None):
    """
    Build an experiment controller for the given experiment XML.

    Without an access configuration, or in DC.MODE_SINGLE_PROCESS, an
    in-process ExperimentController is returned; it uses the configured
    root directory, or a temporary one when none is configured.
    In DC.MODE_DAEMON an ExperimentControllerProxy is returned instead.

    Raises ValueError when recovery is requested for an in-process
    controller, and RuntimeError for any other deployment mode.
    """
    if access_config:
        mode = access_config.get_attribute_value(DC.DEPLOYMENT_MODE)
        launch = not access_config.get_attribute_value(DC.RECOVER)
    else:
        mode = None
        launch = True

    in_process = not mode or mode == DC.MODE_SINGLE_PROCESS
    if in_process:
        if not launch:
            raise ValueError("Unsupported instantiation mode: %s with lanch=False" % (mode,))

        from nepi.core.execute import ExperimentController

        if access_config and access_config.has_attribute(DC.ROOT_DIRECTORY):
            workdir = PermDir(access_config.get_attribute_value(DC.ROOT_DIRECTORY))
        else:
            workdir = TempDir()
        controller = ExperimentController(xml, workdir.path)

        # Keep the directory holder alive as long as the controller is:
        # a TempDir removes its directory when it is finalized.
        controller._tempdir = workdir

        return controller

    if mode == DC.MODE_DAEMON:
        (root_dir, log_level, user, host, port, key, agent, environment_setup) = \
                get_access_config_params(access_config)
        return ExperimentControllerProxy(root_dir, log_level,
                experiment_xml = xml,
                host = host,
                port = port,
                user = user,
                ident_key = key,
                agent = agent,
                launch = launch,
                environment_setup = environment_setup)

    raise RuntimeError("Unsupported access configuration '%s'" % mode)
216
def create_testbed_controller(testbed_id, testbed_version, access_config):
    """
    Build a testbed controller for the given testbed id and version.

    Without an access configuration, or in DC.MODE_SINGLE_PROCESS, the
    testbed's own TestbedController is instantiated in this process.
    In DC.MODE_DAEMON a TestbedControllerProxy is returned instead.

    Raises ValueError when recovery is requested for an in-process
    controller, and RuntimeError for any other deployment mode.
    """
    if access_config:
        mode = access_config.get_attribute_value(DC.DEPLOYMENT_MODE)
        launch = not access_config.get_attribute_value(DC.RECOVER)
    else:
        mode = None
        launch = True

    in_process = not mode or mode == DC.MODE_SINGLE_PROCESS
    if in_process:
        if not launch:
            raise ValueError("Unsupported instantiation mode: %s with lanch=False" % (mode,))
        return _build_testbed_controller(testbed_id, testbed_version)

    if mode == DC.MODE_DAEMON:
        (root_dir, log_level, user, host, port, key, agent, environment_setup) = \
                get_access_config_params(access_config)
        return TestbedControllerProxy(root_dir, log_level,
                testbed_id = testbed_id,
                testbed_version = testbed_version,
                host = host,
                port = port,
                ident_key = key,
                user = user,
                agent = agent,
                launch = launch,
                environment_setup = environment_setup)

    raise RuntimeError("Unsupported access configuration '%s'" % mode)
234
235 def _build_testbed_controller(testbed_id, testbed_version):
236     mod_name = "nepi.testbeds.%s" % (testbed_id.lower())
237     if not mod_name in sys.modules:
238         __import__(mod_name)
239     module = sys.modules[mod_name]
240     return module.TestbedController(testbed_version)
241
242 # Just a namespace class
class Marshalling:
    """
    Namespace for the wire (un)marshalling helpers of the proxy protocol.

    It groups:
      - Decoders: string -> value converters, also exposed directly as
        Marshalling.<name> so they can be used as "types" in the decorators
      - Encoders: the matching value -> string converters
      - the args / retval / retvoid / handles decorators used to declare
        server-side command handlers
    """
    # Converters from the string found on the wire to a python value
    class Decoders:
        @staticmethod
        def pickled_data(sdata):
            # NOTE(review): this unpickles data received from the peer;
            # pickle is only safe when the channel is trusted.
            return cPickle.loads(base64.b64decode(sdata))
        
        @staticmethod
        def base64_data(sdata):
            return base64.b64decode(sdata)
        
        @staticmethod
        def nullint(sdata):
            # integer that may also be None, sent as the literal "None"
            return None if sdata == "None" else int(sdata)
        
        @staticmethod
        def bool(sdata):
            # anything other than the exact string 'True' decodes as False
            return sdata == 'True'
        
    # Converters from a python value to its wire representation.
    # Each one mirrors the Decoders entry with the same name.
    class Encoders:
        @staticmethod
        def pickled_data(data):
            return base64.b64encode(cPickle.dumps(data))
        
        @staticmethod
        def base64_data(data):
            return base64.b64encode(data)
        
        @staticmethod
        def nullint(data):
            return "None" if data is None else int(data)
        
        @staticmethod
        def bool(data):
            # 'bool' here is the builtin: class-level names are not visible
            # from inside method bodies
            return str(bool(data))
           
    # import into Marshalling all the decoders
    # they act as types
    locals().update([
        (typname, typ)
        for typname, typ in vars(Decoders).iteritems()
        if not typname.startswith('_')
    ])

    _TYPE_ENCODERS = dict([
        # type name -> (<encoding_function>, <formatting_string>)
        # one entry for every decoder that has a same-named encoder
        (typname, (getattr(Encoders,typname),"%s"))
        for typname in vars(Decoders)
        if not typname.startswith('_')
           and hasattr(Encoders,typname)
    ])

    # Builtins
    _TYPE_ENCODERS["float"] = (float, "%r")
    _TYPE_ENCODERS["int"] = (int, "%d")
    _TYPE_ENCODERS["long"] = (int, "%d")
    _TYPE_ENCODERS["str"] = (str, "%s")
    _TYPE_ENCODERS["unicode"] = (str, "%s")
    
    # Generic encoder, used for any type whose name is not listed above
    _TYPE_ENCODERS[None] = (str, "%s")
    
    @staticmethod
    def args(*types):
        """
        Decorator that converts the given function into one that takes
        a single "params" list, with each parameter marshalled according
        to the given factory callable (type constructors are accepted).
        
        The first argument (self) is left untouched, and the first element
        of "params" (the command code) is skipped.
        
        The decorated function also gets the attributes _argtypes,
        _argencoders and _retval attached.
        
        eg:
        
        @Marshalling.args(int,int,str,base64_data)
        def somefunc(self, someint, otherint, somestr, someb64):
           return someretval
        """
        def decor(f):
            @functools.wraps(f)
            def rv(self, params):
                # zip() stops at the shortest sequence, so missing trailing
                # params are simply not passed to f
                return f(self, *[ ctor(val)
                                  for ctor,val in zip(types, params[1:]) ])
            
            rv._argtypes = types
            
            # Derive type encoders by looking up types in _TYPE_ENCODERS
            # make_proxy will use it to encode arguments in command strings
            argencoders = []
            TYPE_ENCODERS = Marshalling._TYPE_ENCODERS
            for typ in types:
                if typ.__name__ in TYPE_ENCODERS:
                    argencoders.append(TYPE_ENCODERS[typ.__name__])
                else:
                    # generic encoder
                    argencoders.append(TYPE_ENCODERS[None])
            
            rv._argencoders = tuple(argencoders)
            
            # keep the return type declared by an inner retval/retvoid
            rv._retval = getattr(f, '_retval', None)
            return rv
        return decor

    @staticmethod
    def retval(typ=Decoders.base64_data):
        """
        Decorator that converts the given function into one that 
        returns a properly encoded return string, given that the undecorated
        function returns suitable input for the encoding function.
        
        The encoded string has the form "<OK>|<encoded value>".
        
        The optional typ argument specifies a type.
        For the default of base64_data, return values should be strings.
        The return value of the encoding method should be a string always.
        
        eg:
        
        @Marshalling.args(int,int,str,base64_data)
        @Marshalling.retval(str)
        def somefunc(self, someint, otherint, somestr, someb64):
           return someint
        """
        # encoders are looked up by type name, falling back to the generic one
        encode, fmt = Marshalling._TYPE_ENCODERS.get(
            typ.__name__,
            Marshalling._TYPE_ENCODERS[None])
        fmt = "%d|"+fmt
        
        def decor(f):
            @functools.wraps(f)
            def rv(self, *p, **kw):
                data = f(self, *p, **kw)
                return fmt % (
                    OK,
                    encode(data)
                )
            rv._retval = typ
            # keep the argument metadata declared by an inner args decorator
            rv._argtypes = getattr(f, '_argtypes', None)
            rv._argencoders = getattr(f, '_argencoders', None)
            return rv
        return decor
    
    @staticmethod
    def retvoid(f):
        """
        Decorator that converts the given function into one that 
        always returns an OK reply with an empty payload ("<OK>|").
        
        Useful for null-returning functions: whatever the undecorated
        function returns is discarded.
        """
        OKRV = "%d|" % (OK,)
        
        @functools.wraps(f)
        def rv(self, *p, **kw):
            f(self, *p, **kw)
            return OKRV
        
        rv._retval = None
        # keep the argument metadata declared by an inner args decorator
        rv._argtypes = getattr(f, '_argtypes', None)
        rv._argencoders = getattr(f, '_argencoders', None)
        return rv
    
    @staticmethod
    def handles(whichcommand):
        """
        Associates the method with a given command code for servers,
        by storing the code in the method's _handles_command attribute.
        It should always be the topmost decorator.
        """
        def decor(f):
            f._handles_command = whichcommand
            return f
        return decor
411
class BaseServer(server.Server):
    """
    Server that dispatches protocol messages to its command handlers.

    Handlers are the public methods defined directly on the concrete
    class that carry a _handles_command attribute (see
    Marshalling.handles).
    """
    def reply_action(self, msg):
        """
        Process one "<instruction>|<arg>|..." message and return the reply
        string, "<OK or ERROR>|<payload>".

        An ERROR reply (with a base64-encoded description) is produced for
        an empty message, a malformed or unknown instruction code, and for
        any exception raised by the handler.
        """
        if not msg:
            result = base64.b64encode("Invalid command line")
            reply = "%d|%s" % (ERROR, result)
        else:
            params = msg.split("|")
            log_msg(self, params)
            try:
                # Parsed inside the try block so that a non-numeric
                # instruction yields an ERROR reply instead of an exception
                # escaping to the caller.
                instruction = int(params[0])
                for mname,meth in vars(self.__class__).iteritems():
                    if not mname.startswith('_'):
                        cmd = getattr(meth, '_handles_command', None)
                        if cmd == instruction:
                            # fetch the bound method and let it build the reply
                            meth = getattr(self, mname)
                            reply = meth(params)
                            break
                else:
                    # no handler declared for this instruction
                    error = "Invalid instruction %s" % instruction
                    self.log_error(error)
                    result = base64.b64encode(error)
                    reply = "%d|%s" % (ERROR, result)
            except:
                # Deliberately broad: whatever the handler raises is
                # reported back to the client as an ERROR reply.
                # NOTE(review): log_error() is assumed to return the error
                # text that it logs - confirm in nepi.util.server.
                error = self.log_error()
                result = base64.b64encode(error)
                reply = "%d|%s" % (ERROR, result)
        log_reply(self, reply)
        return reply
440
class TestbedControllerServer(BaseServer):
    """
    Server side of the testbed controller proxy.

    Every handler unmarshals its arguments, forwards the call to the
    wrapped testbed controller and marshals the result back. The testbed
    controller itself is only built in post_daemonize().
    """
    def __init__(self, root_dir, log_level, testbed_id, testbed_version):
        super(TestbedControllerServer, self).__init__(root_dir, log_level)
        self._testbed_id = testbed_id
        self._testbed_version = testbed_version
        # created in post_daemonize()
        self._testbed = None

    def post_daemonize(self):
        self._testbed = _build_testbed_controller(self._testbed_id, 
                self._testbed_version)

    # ---- testbed identification ----

    @Marshalling.handles(GUIDS)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def guids(self):
        return self._testbed.guids

    @Marshalling.handles(TESTBED_ID)
    @Marshalling.args()
    @Marshalling.retval()
    def testbed_id(self):
        return str(self._testbed.testbed_id)

    @Marshalling.handles(TESTBED_VERSION)
    @Marshalling.args()
    @Marshalling.retval()
    def testbed_version(self):
        return str(self._testbed.testbed_version)

    # ---- forwarded testbed operations ----

    @Marshalling.handles(CREATE)
    @Marshalling.args(int, str)
    @Marshalling.retvoid
    def defer_create(self, guid, factory_id):
        self._testbed.defer_create(guid, factory_id)

    @Marshalling.handles(TRACE)
    @Marshalling.args(int, str, Marshalling.base64_data)
    @Marshalling.retval()
    def trace(self, guid, trace_id, attribute):
        return self._testbed.trace(guid, trace_id, attribute)

    @Marshalling.handles(TRACES_INFO)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def traces_info(self):
        return self._testbed.traces_info()

    @Marshalling.handles(START)
    @Marshalling.args()
    @Marshalling.retvoid
    def start(self):
        self._testbed.start()

    @Marshalling.handles(STOP)
    @Marshalling.args()
    @Marshalling.retvoid
    def stop(self):
        self._testbed.stop()

    @Marshalling.handles(SHUTDOWN)
    @Marshalling.args()
    @Marshalling.retvoid
    def shutdown(self):
        self._testbed.shutdown()

    @Marshalling.handles(CONFIGURE)
    @Marshalling.args(Marshalling.base64_data, Marshalling.pickled_data)
    @Marshalling.retvoid
    def defer_configure(self, name, value):
        self._testbed.defer_configure(name, value)

    @Marshalling.handles(CREATE_SET)
    @Marshalling.args(int, Marshalling.base64_data, Marshalling.pickled_data)
    @Marshalling.retvoid
    def defer_create_set(self, guid, name, value):
        self._testbed.defer_create_set(guid, name, value)

    @Marshalling.handles(FACTORY_SET)
    @Marshalling.args(Marshalling.base64_data, Marshalling.pickled_data)
    @Marshalling.retvoid
    def defer_factory_set(self, name, value):
        self._testbed.defer_factory_set(name, value)

    @Marshalling.handles(CONNECT)
    @Marshalling.args(int, str, int, str)
    @Marshalling.retvoid
    def defer_connect(self, guid1, connector_type_name1, guid2, connector_type_name2):
        self._testbed.defer_connect(guid1, connector_type_name1, guid2, 
            connector_type_name2)

    @Marshalling.handles(CROSS_CONNECT)
    @Marshalling.args(int, str, int, int, str, str, str)
    @Marshalling.retvoid
    def defer_cross_connect(self, 
            guid, connector_type_name,
            cross_guid, cross_testbed_guid,
            cross_testbed_id, cross_factory_id,
            cross_connector_type_name):
        self._testbed.defer_cross_connect(guid, connector_type_name, cross_guid, 
            cross_testbed_guid, cross_testbed_id, cross_factory_id, 
            cross_connector_type_name)

    @Marshalling.handles(ADD_TRACE)
    @Marshalling.args(int, str)
    @Marshalling.retvoid
    def defer_add_trace(self, guid, trace_id):
        self._testbed.defer_add_trace(guid, trace_id)

    @Marshalling.handles(ADD_ADDRESS)
    @Marshalling.args(int, str, int, str)
    @Marshalling.retvoid
    def defer_add_address(self, guid, address, netprefix, broadcast):
        self._testbed.defer_add_address(guid, address, netprefix,
                broadcast)

    @Marshalling.handles(ADD_ROUTE)
    @Marshalling.args(int, str, int, str)
    @Marshalling.retvoid
    def defer_add_route(self, guid, destination, netprefix, nexthop):
        self._testbed.defer_add_route(guid, destination, netprefix, nexthop)

    @Marshalling.handles(DO_SETUP)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_setup(self):
        self._testbed.do_setup()

    @Marshalling.handles(DO_CREATE)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_create(self):
        self._testbed.do_create()

    @Marshalling.handles(DO_CONNECT_INIT)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_connect_init(self):
        self._testbed.do_connect_init()

    @Marshalling.handles(DO_CONNECT_COMPL)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_connect_compl(self):
        self._testbed.do_connect_compl()

    @Marshalling.handles(DO_CONFIGURE)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_configure(self):
        self._testbed.do_configure()

    @Marshalling.handles(DO_PRECONFIGURE)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_preconfigure(self):
        self._testbed.do_preconfigure()

    @Marshalling.handles(DO_PRESTART)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_prestart(self):
        self._testbed.do_prestart()

    @Marshalling.handles(DO_CROSS_CONNECT_INIT)
    @Marshalling.args( Marshalling.Decoders.pickled_data )
    @Marshalling.retvoid
    def do_cross_connect_init(self, cross_data):
        self._testbed.do_cross_connect_init(cross_data)

    @Marshalling.handles(DO_CROSS_CONNECT_COMPL)
    @Marshalling.args( Marshalling.Decoders.pickled_data )
    @Marshalling.retvoid
    def do_cross_connect_compl(self, cross_data):
        self._testbed.do_cross_connect_compl(cross_data)

    @Marshalling.handles(GET)
    @Marshalling.args(int, Marshalling.base64_data, str)
    @Marshalling.retval( Marshalling.pickled_data )
    def get(self, guid, name, time):
        return self._testbed.get(guid, name, time)

    @Marshalling.handles(SET)
    @Marshalling.args(int, Marshalling.base64_data, Marshalling.pickled_data, str)
    @Marshalling.retvoid
    def set(self, guid, name, value, time):
        self._testbed.set(guid, name, value, time)

    @Marshalling.handles(GET_ADDRESS)
    @Marshalling.args(int, int, Marshalling.base64_data)
    @Marshalling.retval()
    def get_address(self, guid, index, attribute):
        return str(self._testbed.get_address(guid, index, attribute))

    @Marshalling.handles(GET_ROUTE)
    @Marshalling.args(int, int, Marshalling.base64_data)
    @Marshalling.retval()
    def get_route(self, guid, index, attribute):
        return str(self._testbed.get_route(guid, index, attribute))

    @Marshalling.handles(ACTION)
    @Marshalling.args(str, int, Marshalling.base64_data)
    @Marshalling.retvoid
    def action(self, time, guid, command):
        self._testbed.action(time, guid, command)

    @Marshalling.handles(STATUS)
    @Marshalling.args(Marshalling.nullint)
    @Marshalling.retval(int)
    def status(self, guid):
        return self._testbed.status(guid)

    # filter_flags defaults to None, which a plain int argument can neither
    # encode nor decode; nullint produces the same wire format for integers
    # and additionally carries None.
    @Marshalling.handles(GET_ATTRIBUTE_LIST)
    @Marshalling.args(int, Marshalling.nullint, Marshalling.bool)
    @Marshalling.retval( Marshalling.pickled_data )
    def get_attribute_list(self, guid, filter_flags = None, exclude = False):
        return self._testbed.get_attribute_list(guid, filter_flags, exclude)

    @Marshalling.handles(GET_FACTORY_ID)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_factory_id(self, guid):
        return self._testbed.get_factory_id(guid)
663
class ExperimentControllerServer(BaseServer):
    """
    Server side of the experiment controller proxy.

    Every handler unmarshals its arguments, forwards the call to the
    wrapped ExperimentController and marshals the result back. The
    controller itself is only built in post_daemonize().
    """
    def __init__(self, root_dir, log_level, experiment_xml):
        super(ExperimentControllerServer, self).__init__(root_dir, log_level)
        self._experiment_xml = experiment_xml
        # created in post_daemonize()
        self._experiment = None

    def post_daemonize(self):
        from nepi.core.execute import ExperimentController
        self._experiment = ExperimentController(self._experiment_xml, 
            root_dir = self._root_dir)

    @Marshalling.handles(GUIDS)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def guids(self):
        return self._experiment.guids

    @Marshalling.handles(XML)
    @Marshalling.args()
    @Marshalling.retval()
    def experiment_design_xml(self):
        return self._experiment.experiment_design_xml

    @Marshalling.handles(EXEC_XML)
    @Marshalling.args()
    @Marshalling.retval()
    def experiment_execute_xml(self):
        return self._experiment.experiment_execute_xml

    @Marshalling.handles(TRACE)
    @Marshalling.args(int, str, Marshalling.base64_data)
    @Marshalling.retval()
    def trace(self, guid, trace_id, attribute):
        # the default retval encoder (base64) needs a string
        return str(self._experiment.trace(guid, trace_id, attribute))

    @Marshalling.handles(TRACES_INFO)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def traces_info(self):
        return self._experiment.traces_info()

    @Marshalling.handles(FINISHED)
    @Marshalling.args(int)
    @Marshalling.retval(Marshalling.bool)
    def is_finished(self, guid):
        return self._experiment.is_finished(guid)

    @Marshalling.handles(GET)
    @Marshalling.args(int, Marshalling.base64_data, str)
    @Marshalling.retval( Marshalling.pickled_data )
    def get(self, guid, name, time):
        return self._experiment.get(guid, name, time)

    @Marshalling.handles(SET)
    @Marshalling.args(int, Marshalling.base64_data, Marshalling.pickled_data, str)
    @Marshalling.retvoid
    def set(self, guid, name, value, time):
        self._experiment.set(guid, name, value, time)

    @Marshalling.handles(START)
    @Marshalling.args()
    @Marshalling.retvoid
    def start(self):
        self._experiment.start()

    @Marshalling.handles(STOP)
    @Marshalling.args()
    @Marshalling.retvoid
    def stop(self):
        self._experiment.stop()

    @Marshalling.handles(RECOVER)
    @Marshalling.args()
    @Marshalling.retvoid
    def recover(self):
        self._experiment.recover()

    @Marshalling.handles(SHUTDOWN)
    @Marshalling.args()
    @Marshalling.retvoid
    def shutdown(self):
        self._experiment.shutdown()

    @Marshalling.handles(GET_TESTBED_ID)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_testbed_id(self, guid):
        return self._experiment.get_testbed_id(guid)

    @Marshalling.handles(GET_FACTORY_ID)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_factory_id(self, guid):
        return self._experiment.get_factory_id(guid)

    @Marshalling.handles(GET_TESTBED_VERSION)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_testbed_version(self, guid):
        return self._experiment.get_testbed_version(guid)
764
765 class BaseProxy(object):
766     _ServerClass = None
767     _ServerClassModule = "nepi.util.proxy"
768     
769     def __init__(self, 
770             ctor_args, root_dir, 
771             launch = True, host = None, 
772             port = None, user = None, ident_key = None, agent = None,
773             environment_setup = ""):
774         if launch:
775             # ssh
776             if host != None:
777                 python_code = (
778                     "from %(classmodule)s import %(classname)s;"
779                     "s = %(classname)s%(ctor_args)r;"
780                     "s.run()" 
781                 % dict(
782                     classname = self._ServerClass.__name__,
783                     classmodule = self._ServerClassModule,
784                     ctor_args = ctor_args
785                 ) )
786                 proc = server.popen_ssh_subprocess(python_code, host = host,
787                     port = port, user = user, agent = agent,
788                     ident_key = ident_key,
789                     environment_setup = environment_setup,
790                     waitcommand = True)
791                 if proc.poll():
792                     err = proc.stderr.read()
793                     raise RuntimeError, "Server could not be executed: %s" % (err,)
794             else:
795                 # launch daemon
796                 s = self._ServerClass(*ctor_args)
797                 s.run()
798
799         # connect client to server
800         self._client = server.Client(root_dir, host = host, port = port, 
801                 user = user, agent = agent, 
802                 environment_setup = environment_setup)
803     
804     @staticmethod
805     def _make_message(argtypes, argencoders, command, methname, classname, *args):
806         if len(argtypes) != len(argencoders):
807             raise ValueError, "Invalid arguments for _make_message: "\
808                 "in stub method %s of class %s "\
809                 "argtypes and argencoders must match in size" % (
810                     methname, classname )
811         if len(argtypes) != len(args):
812             raise ValueError, "Invalid arguments for _make_message: "\
813                 "in stub method %s of class %s "\
814                 "expected %d arguments, got %d" % (
815                     methname, classname,
816                     len(argtypes), len(args))
817         
818         buf = []
819         for argnum, (typ, (encode, fmt), val) in enumerate(zip(argtypes, argencoders, args)):
820             try:
821                 buf.append(fmt % encode(val))
822             except:
823                 import traceback
824                 raise TypeError, "Argument %d of stub method %s of class %s "\
825                     "requires a value of type %s, but got %s - nested error: %s" % (
826                         argnum, methname, classname,
827                         getattr(typ, '__name__', typ), type(val),
828                         traceback.format_exc()
829                 )
830         
831         return "%d|%s" % (command, '|'.join(buf))
832     
833     @staticmethod
834     def _parse_reply(rvtype, methname, classname, reply):
835         if not reply:
836             raise RuntimeError, "Invalid reply: %r "\
837                 "for stub method %s of class %s" % (
838                     reply,
839                     methname,
840                     classname)
841         
842         try:
843             result = reply.split("|")
844             code = int(result[0])
845             text = result[1]
846         except:
847             import traceback
848             raise TypeError, "Return value of stub method %s of class %s "\
849                 "cannot be parsed: must be of type %s, got %r - nested error: %s" % (
850                     methname, classname,
851                     getattr(rvtype, '__name__', rvtype), reply,
852                     traceback.format_exc()
853             )
854         if code == ERROR:
855             text = base64.b64decode(text)
856             raise RuntimeError(text)
857         elif code == OK:
858             try:
859                 if rvtype is None:
860                     return
861                 else:
862                     return rvtype(text)
863             except:
864                 import traceback
865                 raise TypeError, "Return value of stub method %s of class %s "\
866                     "cannot be parsed: must be of type %s - nested error: %s" % (
867                         methname, classname,
868                         getattr(rvtype, '__name__', rvtype),
869                         traceback.format_exc()
870                 )
871         else:
872             raise RuntimeError, "Invalid reply: %r "\
873                 "for stub method %s of class %s - unknown code" % (
874                     reply,
875                     methname,
876                     classname)
877     
878     @staticmethod
879     def _make_stubs(server_class, template_class):
880         """
881         Returns a dictionary method_name -> method
882         with stub methods.
883         
884         Usage:
885         
886             class SomeProxy(BaseProxy):
887                ...
888                
889                locals().update( BaseProxy._make_stubs(
890                     ServerClass,
891                     TemplateClass
892                ) )
893         
894         ServerClass is the corresponding Server class, as
895         specified in the _ServerClass class method (_make_stubs
896         is static and can't access the method), and TemplateClass
897         is the ultimate implementation class behind the server,
898         from which argument names and defaults are taken, to
899         maintain meaningful interfaces.
900         """
901         rv = {}
902         
903         class NONE: pass
904         
905         import os.path
906         func_template_path = os.path.join(
907             os.path.dirname(__file__),
908             'proxy_stub.tpl')
909         func_template_file = open(func_template_path, "r")
910         func_template = func_template_file.read()
911         func_template_file.close()
912         
913         for methname in vars(template_class).copy():
914             if methname.endswith('_deferred'):
915                 # cannot wrap deferreds...
916                 continue
917             dmethname = methname+'_deferred'
918             if hasattr(server_class, methname) and not methname.startswith('_'):
919                 template_meth = getattr(template_class, methname)
920                 server_meth = getattr(server_class, methname)
921                 
922                 command = getattr(server_meth, '_handles_command', None)
923                 argtypes = getattr(server_meth, '_argtypes', None)
924                 argencoders = getattr(server_meth, '_argencoders', None)
925                 rvtype = getattr(server_meth, '_retval', None)
926                 doprop = False
927                 
928                 if hasattr(template_meth, 'fget'):
929                     # property getter
930                     template_meth = template_meth.fget
931                     doprop = True
932                 
933                 if command is not None and argtypes is not None and argencoders is not None:
934                     # We have an interface method...
935                     code = template_meth.func_code
936                     argnames = code.co_varnames[:code.co_argcount]
937                     argdefaults = ( (NONE,) * (len(argnames) - len(template_meth.func_defaults or ()))
938                                   + (template_meth.func_defaults or ()) )
939                     
940                     func_globals = dict(
941                         BaseProxy = BaseProxy,
942                         argtypes = argtypes,
943                         argencoders = argencoders,
944                         rvtype = rvtype,
945                         functools = functools,
946                     )
947                     context = dict()
948                     
949                     func_text = func_template % dict(
950                         self = argnames[0],
951                         args = '%s' % (','.join(argnames[1:])),
952                         argdefs = ','.join([
953                             argname if argdef is NONE
954                             else "%s=%r" % (argname, argdef)
955                             for argname, argdef in zip(argnames[1:], argdefaults[1:])
956                         ]),
957                         command = command,
958                         methname = methname,
959                         classname = server_class.__name__
960                     )
961                     
962                     func_text = compile(
963                         func_text,
964                         func_template_path,
965                         'exec')
966                     
967                     exec func_text in func_globals, context
968                     
969                     if doprop:
970                         rv[methname] = property(context[methname])
971                         rv[dmethname] = property(context[dmethname])
972                     else:
973                         rv[methname] = context[methname]
974                         rv[dmethname] = context[dmethname]
975                     
976                     # inject _deferred into core classes
977                     if hasattr(template_class, methname) and not hasattr(template_class, dmethname):
978                         def freezename(methname, dmethname):
979                             def dmeth(self, *p, **kw): 
980                                 return getattr(self, methname)(*p, **kw)
981                             dmeth.__name__ = dmethname
982                             return dmeth
983                         dmeth = freezename(methname, dmethname)
984                         setattr(template_class, dmethname, dmeth)
985         
986         return rv
987                         
class TestbedControllerProxy(BaseProxy):
    """
    Proxy paired with TestbedControllerServer.
    
    Its public methods are stubs generated by BaseProxy._make_stubs
    from TestbedControllerServer, with signatures taken from
    nepi.core.execute.TestbedController.
    """
    
    _ServerClass = TestbedControllerServer
    
    def __init__(self, root_dir, log_level, testbed_id = None, 
            testbed_version = None, launch = True, host = None, 
            port = None, user = None, ident_key = None, agent = None,
            environment_setup = ""):
        """
        Forwards all parameters to BaseProxy.__init__, passing
        (root_dir, log_level, testbed_id, testbed_version) as ctor_args.
        
        Raises RuntimeError if launch is true and testbed_id or
        testbed_version is None.
        """
        if launch and (testbed_id is None or testbed_version is None):
            raise RuntimeError("To launch a TestbedControllerServer a "
                    "testbed_id and testbed_version are required")
        super(TestbedControllerProxy,self).__init__(
            ctor_args = (root_dir, log_level, testbed_id, testbed_version),
            root_dir = root_dir,
            launch = launch, host = host, port = port, user = user,
            ident_key = ident_key, agent = agent, 
            environment_setup = environment_setup)

    # Add the generated stub methods to the class namespace
    locals().update( BaseProxy._make_stubs(
        server_class = TestbedControllerServer,
        template_class = nepi.core.execute.TestbedController,
    ) )
    
    # Shutdown stops the serverside...
    # The generated shutdown stub is captured as the default value of
    # _stub before this definition replaces it in the class namespace.
    def shutdown(self, _stub = shutdown):
        """
        Calls the generated shutdown stub, then sends a stop through
        the client and waits for its reply. Returns the stub's result.
        """
        rv = _stub(self)
        self._client.send_stop()
        self._client.read_reply() # wait for it
        return rv
1017     
1018
class ExperimentControllerProxy(BaseProxy):
    """
    Proxy paired with ExperimentControllerServer.
    
    Its public methods are stubs generated by BaseProxy._make_stubs
    from ExperimentControllerServer, with signatures taken from
    nepi.core.execute.ExperimentController.
    """
    _ServerClass = ExperimentControllerServer
    
    def __init__(self, root_dir, log_level, experiment_xml = None, 
            launch = True, host = None, port = None, user = None, 
            ident_key = None, agent = None, environment_setup = ""):
        """
        Forwards all parameters to BaseProxy.__init__, passing
        (root_dir, log_level, experiment_xml) as ctor_args.
        
        Raises RuntimeError if launch is true and experiment_xml is None.
        """
        if launch and experiment_xml is None:
            # Adjacent literals instead of a backslash continuation inside
            # the string, which would embed the source indentation in the
            # message.
            raise RuntimeError("To launch a ExperimentControllerServer a "
                    "xml description of the experiment is required")
        super(ExperimentControllerProxy,self).__init__(
            ctor_args = (root_dir, log_level, experiment_xml),
            root_dir = root_dir,
            launch = launch, host = host, port = port, user = user,
            ident_key = ident_key, agent = agent, 
            environment_setup = environment_setup)

    # Add the generated stub methods to the class namespace
    locals().update( BaseProxy._make_stubs(
        server_class = ExperimentControllerServer,
        template_class = nepi.core.execute.ExperimentController,
    ) )

    
    # Shutdown stops the serverside...
    # The generated shutdown stub is captured as the default value of
    # _stub before this definition replaces it in the class namespace.
    def shutdown(self, _stub = shutdown):
        """
        Calls the generated shutdown stub, then sends a stop through
        the client and waits for its reply. Returns the stub's result.
        """
        rv = _stub(self)
        self._client.send_stop()
        self._client.read_reply() # wait for it
        return rv
1047