Adding working UdpTunnel for Planetlab and Linux
[nepi.git] / src / nepi / resources / linux / udptunnel.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 from nepi.execution.attribute import Attribute, Flags, Types
21 from nepi.execution.resource import ResourceManager, clsinit_copy, ResourceState, \
22         reschedule_delay
23 from nepi.resources.linux.application import LinuxApplication
24 from nepi.util.timefuncs import tnow, tdiffsec
25
26 import os
27 import socket
28 import time
29
@clsinit_copy
class UdpTunnel(LinuxApplication):
    """ Resource manager for a UDP tunnel between two endpoint RMs.

    The tunnel is connected to two resources that expose a
    ``udp_connect_command`` method (see ``get_endpoints``). During
    provisioning a connect script is run on each endpoint's node. Each side
    writes its local port to a ``local_port`` file and receives the peer's
    port through a ``remote_port`` file. Completion is signalled through a
    ``ret_file``. All of these files live in the per-run directory returned
    by ``run_home``.
    """
    _rtype = "UdpTunnel"

    def __init__(self, ec, guid):
        super(UdpTunnel, self).__init__(ec, guid)
        # Directory name of this tunnel under each endpoint node's exp_home
        # (see app_home).
        self._home = "udp-tunnel-%s" % self.guid
        # pid/ppid of the connect process on endpoint 1 and endpoint 2,
        # as returned by node.wait_pid in udp_connect.
        self._pid1 = None
        self._ppid1 = None
        self._pid2 = None
        self._ppid2 = None

    def log_message(self, msg):
        """ Prefixes ``msg`` with the tunnel guid and both endpoint hostnames.

        An endpoint that is not (yet) connected is shown as ``None``.
        Formatting a log line therefore never raises AttributeError, which
        would otherwise mask the error being logged.
        """
        def hostname(endpoint):
            return endpoint.node.get("hostname") if endpoint else None

        return " guid %d - tunnel %s - %s - %s " % (self.guid,
                hostname(self.endpoint1),
                hostname(self.endpoint2),
                msg)

    def get_endpoints(self):
        """ Returns the list of connected RMs that can act as tunnel
        endpoints, i.e. those exposing a ``udp_connect_command`` method.
        """
        return [rm for rm in (self.ec.get_resource(guid)
                              for guid in self.connections)
                if hasattr(rm, "udp_connect_command")]

    @property
    def endpoint1(self):
        """ First tunnel endpoint RM, or None if there is none. """
        endpoints = self.get_endpoints()
        return endpoints[0] if endpoints else None

    @property
    def endpoint2(self):
        """ Second tunnel endpoint RM, or None if fewer than two exist. """
        endpoints = self.get_endpoints()
        return endpoints[1] if len(endpoints) > 1 else None

    def app_home(self, endpoint):
        """ Tunnel directory on ``endpoint``'s node (under its exp_home). """
        return os.path.join(endpoint.node.exp_home, self._home)

    def run_home(self, endpoint):
        """ Per-run subdirectory of app_home, named after ``ec.run_id``. """
        return os.path.join(self.app_home(endpoint), self.ec.run_id)

    def udp_connect(self, endpoint, remote_ip):
        """ Uploads and runs the UDP connect script on ``endpoint``'s node.

        :param endpoint: endpoint RM providing ``udp_connect_command``
        :param remote_ip: IP address of the peer endpoint's node
        :returns: tuple ``(local_port, pid, ppid)``. ``pid``/``ppid`` may be
            None if they could not be retrieved and no error was reported.
        :raises RuntimeError: if the script launch fails, the remote side
            reports errors, or the local port file is never produced
        """
        run_home = self.run_home(endpoint)

        # Get udp connect command
        local_port_file = os.path.join(run_home, "local_port")
        remote_port_file = os.path.join(run_home, "remote_port")
        ret_file = os.path.join(run_home, "ret_file")
        udp_connect_command = endpoint.udp_connect_command(
                remote_ip, local_port_file, remote_port_file,
                ret_file)

        # upload command to connect.sh script
        # NOTE(review): the script is stored in app_home (shared across runs)
        # and uploaded with overwrite = False. If upload skips existing files,
        # a script from an earlier run (pointing at that run's run_home) could
        # be reused. Confirm this is intended.
        shfile = os.path.join(self.app_home(endpoint), "udp-connect.sh")
        endpoint.node.upload(udp_connect_command,
                shfile,
                text = True,
                overwrite = False)

        # invoke connect script
        cmd = "bash %s" % shfile
        (out, err), proc = endpoint.node.run(cmd, run_home)

        # A truthy poll() value is treated as an execution error
        if proc.poll():
            msg = " Failed to connect endpoints "
            self.fail()
            self.error(msg, out, err)
            raise RuntimeError(msg)

        # Wait for pid file to be generated
        pid, ppid = endpoint.node.wait_pid(run_home)

        # If the process is not running, check for error information
        # on the remote machine
        if not pid or not ppid:
            (out, err), proc = endpoint.node.check_errors(run_home)
            if err:
                self.fail()
                # Report the command that was actually launched
                msg = " Failed to start command '%s' " % cmd
                self.error(msg, out, err)
                raise RuntimeError(msg)

        # wait until port is written to file
        port = self.wait_local_port(endpoint)
        return (port, pid, ppid)

    def provision(self):
        """ Connects both endpoints and exchanges their local ports.

        Steps:
        1. Create the run directory on each endpoint node.
        2. Run the connect script on each side, pointed at the other node's
           IP.
        3. Upload each side's local port to the peer as its remote port.
        4. Wait for both result files.

        Any failure propagates as an exception (deploy marks the RM failed).
        """
        endpoint1 = self.endpoint1
        endpoint2 = self.endpoint2

        # create run dir for tunnel on each node
        endpoint1.node.mkdir(self.run_home(endpoint1))
        endpoint2.node.mkdir(self.run_home(endpoint2))

        # Invoke connect script in endpoint 1. Hostnames are resolved by the
        # process running this code, not on the nodes.
        remote_ip1 = socket.gethostbyname(endpoint2.node.get("hostname"))
        (port1, self._pid1, self._ppid1) = self.udp_connect(endpoint1,
                remote_ip1)

        # Invoke connect script in endpoint 2
        remote_ip2 = socket.gethostbyname(endpoint1.node.get("hostname"))
        (port2, self._pid2, self._ppid2) = self.udp_connect(endpoint2,
                remote_ip2)

        # Each side learns the other side's local port
        self.upload_remote_port(endpoint1, port2)
        self.upload_remote_port(endpoint2, port1)

        # Wait for both sides to write their result file.
        # NOTE(review): wait_result only waits for ret_file to be non-empty;
        # the value it contains is not inspected here.
        self.wait_result(endpoint1)
        self.wait_result(endpoint2)

        self.info("Provisioning finished")

        self.debug("----- READY ---- ")
        self._provision_time = tnow()
        self._state = ResourceState.PROVISIONED

    def deploy(self):
        """ Deploys the tunnel once both endpoints are READY.

        Until two endpoints are connected and READY, this method reschedules
        itself after ``reschedule_delay``. After that it runs discover() and
        provision(), and any exception marks the RM failed and is re-raised.
        """
        endpoint1 = self.endpoint1
        endpoint2 = self.endpoint2

        if (not endpoint1 or endpoint1.state < ResourceState.READY) or \
                (not endpoint2 or endpoint2.state < ResourceState.READY):
            self.ec.schedule(reschedule_delay, self.deploy)
            return

        try:
            self.discover()
            self.provision()
        except:
            # Intentionally broad: mark failed on any error, then re-raise
            self.fail()
            raise

        self.debug("----- READY ---- ")
        self._ready_time = tnow()
        self._state = ResourceState.READY

    def start(self):
        """ Marks the tunnel as STARTED.

        The connection itself is established during provision(); this only
        records the start time.

        :raises RuntimeError: if the tunnel is not READY (the RM is marked
            failed first)
        """
        command = self.get("command")

        if self._state != ResourceState.READY:
            msg = " Failed to execute command '%s'" % command
            self.error(msg)
            self.fail()
            raise RuntimeError(msg)

        self.info("Starting command '%s'" % command)
        self._start_time = tnow()
        self._state = ResourceState.STARTED

    def stop(self):
        """ Stops the tunnel by killing the connect process on both endpoints.

        Only acts when the tunnel is STARTED. If the pids were not all
        retrieved, nothing is killed and the tunnel is simply marked STOPPED.
        If killing fails on either side, the RM is marked failed instead.
        """
        if self.state != ResourceState.STARTED:
            return

        self.info("Stopping tunnel")
        stopped = True

        # Only try to kill the processes if every pid and ppid was retrieved
        if self._pid1 and self._ppid1 and self._pid2 and self._ppid2:
            targets = ((self.endpoint1, self._pid1, self._ppid1),
                       (self.endpoint2, self._pid2, self._ppid2))

            # Kill on both sides before deciding, so one failure does not
            # leave the other side's process running.
            for endpoint, pid, ppid in targets:
                (out, err), proc = endpoint.node.kill(pid, ppid, sudo = True)
                if err or proc.poll():
                    self.error(" Failed to STOP tunnel", out, err)
                    stopped = False

            if not stopped:
                self.fail()

        if stopped:
            self._stop_time = tnow()
            self._state = ResourceState.STOPPED

    @property
    def state(self):
        """ Returns the state of the tunnel.

        While STARTED, remote state is polled at most every
        ``state_check_delay`` seconds. Remote error output on either endpoint
        marks the RM failed. Otherwise, if both recorded processes report
        FINISHED, the state becomes FINISHED.
        """
        if self._state == ResourceState.STARTED:
            # In order to avoid overwhelming the remote host and
            # the local processor with too many ssh queries, the state is only
            # requested every 'state_check_delay' seconds.
            # (_last_state_check is presumably initialised by
            # LinuxApplication.__init__ -- not set in this class.)
            state_check_delay = 0.5
            if tdiffsec(tnow(), self._last_state_check) > state_check_delay:
                endpoint1 = self.endpoint1
                endpoint2 = self.endpoint2

                # check if execution errors occurred
                (out1, err1), proc1 = endpoint1.node.check_errors(
                        self.run_home(endpoint1))

                (out2, err2), proc2 = endpoint2.node.check_errors(
                        self.run_home(endpoint2))

                if err1 or err2:
                    msg = " Failed to connect endpoints "
                    self.error(msg, err1, err2)
                    self.fail()

                elif self._pid1 and self._ppid1 and self._pid2 and self._ppid2:
                    # ProcStatus is not imported at module level in this file
                    from nepi.util.sshfuncs import ProcStatus

                    # No execution errors occurred. Make sure the background
                    # processes with the recorded pids are still running.
                    # The tunnel has no node of its own; each pid lives on
                    # its endpoint's node.
                    status1 = endpoint1.node.status(self._pid1, self._ppid1)
                    status2 = endpoint2.node.status(self._pid2, self._ppid2)

                    if status1 == ProcStatus.FINISHED and \
                            status2 == ProcStatus.FINISHED:
                        self._state = ResourceState.FINISHED

                self._last_state_check = tnow()

        return self._state

    def wait_local_port(self, endpoint):
        """ Waits until the local_port file for the endpoint is generated,
            and returns the port number (as a stripped string) """
        return self.wait_file(endpoint, "local_port")

    def wait_result(self, endpoint):
        """ Waits until the return code file for the endpoint is generated """
        return self.wait_file(endpoint, "ret_file")

    def wait_file(self, endpoint, filename):
        """ Polls ``filename`` in the endpoint's run directory.

        Makes up to 4 ``check_output`` attempts, sleeping 1s, 1.5s and then
        2.25s between them (no sleep after the last attempt).

        :returns: the stripped output once it is non-empty
        :raises RuntimeError: if every attempt returned empty output (the RM
            is marked failed first)
        """
        attempts = 4
        delay = 1.0

        for attempt in range(attempts):
            (out, err), proc = endpoint.node.check_output(
                    self.run_home(endpoint), filename)

            if out:
                return out.strip()

            # Back off before retrying, but don't sleep after the final try
            if attempt < attempts - 1:
                time.sleep(delay)
                delay *= 1.5

        msg = "Couldn't retrieve %s" % filename
        self.error(msg, out, err)
        self.fail()
        raise RuntimeError(msg)

    def upload_remote_port(self, endpoint, port):
        """ Writes the peer's port number (plus newline) to the endpoint's
        ``remote_port`` file in its run directory. """
        port = "%s\n" % port
        endpoint.node.upload(port,
                os.path.join(self.run_home(endpoint), "remote_port"),
                text = True,
                overwrite = False)

    def valid_connection(self, guid):
        """ Accepts every connection (no validation implemented yet). """
        # TODO: Validate!
        return True
304