Merge branch 'mainstream'
[sliver-openvswitch.git] / python / ovs / jsonrpc.py
1 # Copyright (c) 2010, 2011, 2012, 2013 Nicira, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17
18 import ovs.json
19 import ovs.poller
20 import ovs.reconnect
21 import ovs.stream
22 import ovs.timeval
23 import ovs.util
24 import ovs.vlog
25
# Module-wide logger for JSON-RPC diagnostics (send/receive errors, debug
# tracing of every message sent and received).
vlog = ovs.vlog.Vlog("jsonrpc")

# Alias of ovs.util.EOF.  Connection.recv() returns it, and Connection.error()
# latches it as the connection status, when the peer closes the stream.
EOF = ovs.util.EOF
28
29
class Message(object):
    """A single JSON-RPC message: a request, notification, reply or error.

    Which of the members 'method', 'params', 'result', 'error' and 'id' must
    be present (non-None) depends on the message type; see is_valid()."""

    T_REQUEST = 0               # Request.
    T_NOTIFY = 1                # Notification.
    T_REPLY = 2                 # Successful reply.
    T_ERROR = 3                 # Error reply.

    __types = {T_REQUEST: "request",
               T_NOTIFY: "notification",
               T_REPLY: "reply",
               T_ERROR: "error"}

    # Text string types accepted for "method": (str, unicode) on Python 2 and
    # (str, str) on Python 3.  type(u"") avoids naming 'unicode' directly,
    # which does not exist on Python 3.
    _STRING_TYPES = (str, type(u""))

    def __init__(self, type_, method, params, result, error, id):
        self.type = type_
        self.method = method
        self.params = params
        self.result = result
        self.error = error
        self.id = id

    # Class-wide counter used to hand out request ids.
    _next_id = 0

    @staticmethod
    def _create_id():
        """Returns a fresh integer request id (0, 1, 2, ... per process)."""
        this_id = Message._next_id
        Message._next_id += 1
        return this_id

    @staticmethod
    def create_request(method, params):
        """Returns a new request for 'method' with a freshly allocated id."""
        return Message(Message.T_REQUEST, method, params, None, None,
                       Message._create_id())

    @staticmethod
    def create_notify(method, params):
        """Returns a new notification (a request without an id)."""
        return Message(Message.T_NOTIFY, method, params, None, None,
                       None)

    @staticmethod
    def create_reply(result, id):
        """Returns a successful reply carrying 'result' for request 'id'."""
        return Message(Message.T_REPLY, None, None, result, None, id)

    @staticmethod
    def create_error(error, id):
        """Returns an error reply carrying 'error' for request 'id'."""
        return Message(Message.T_ERROR, None, None, None, error, id)

    @staticmethod
    def type_to_string(type_):
        """Returns the human-readable name of message type 'type_'.

        Raises KeyError if 'type_' is not one of the T_* constants."""
        return Message.__types[type_]

    def __validate_arg(self, value, name, must_have):
        """Checks the presence of one member.

        'must_have' is nonzero if the member is required and zero if it is
        forbidden.  Returns None if 'value' agrees with that, otherwise a
        string describing the violation."""
        if (value is not None) == (must_have != 0):
            return None
        else:
            type_name = Message.type_to_string(self.type)
            if must_have:
                verb = "must"
            else:
                verb = "must not"
            return "%s %s have \"%s\"" % (type_name, verb, name)

    def is_valid(self):
        """Returns None if this message is well-formed, otherwise a string
        describing the first problem found."""
        if self.params is not None and not isinstance(self.params, list):
            return "\"params\" must be JSON array"

        # Each hex digit says whether a member is required (1) or forbidden
        # (0), in the order: method, params, result, error, id.
        pattern = {Message.T_REQUEST: 0x11001,
                   Message.T_NOTIFY:  0x11000,
                   Message.T_REPLY:   0x00101,
                   Message.T_ERROR:   0x00011}.get(self.type)
        if pattern is None:
            return "invalid JSON-RPC message type %s" % self.type

        return (
            self.__validate_arg(self.method, "method", pattern & 0x10000) or
            self.__validate_arg(self.params, "params", pattern & 0x1000) or
            self.__validate_arg(self.result, "result", pattern & 0x100) or
            self.__validate_arg(self.error, "error", pattern & 0x10) or
            self.__validate_arg(self.id, "id", pattern & 0x1))

    @staticmethod
    def from_json(json):
        """Parses a decoded JSON object into a Message.

        Returns the Message on success, or a string describing why 'json' is
        not a valid JSON-RPC message.  The message type is inferred from
        which members are non-null: "result" => reply, else "error" => error,
        else "id" => request, else notification."""
        if not isinstance(json, dict):
            return "message is not a JSON object"

        # Make a copy to avoid modifying the caller's dict.
        json = dict(json)

        if "method" in json:
            method = json.pop("method")
            if not isinstance(method, Message._STRING_TYPES):
                return "method is not a JSON string"
        else:
            method = None

        params = json.pop("params", None)
        result = json.pop("result", None)
        error = json.pop("error", None)
        id_ = json.pop("id", None)
        if len(json):
            return "message has unexpected member \"%s\"" % json.popitem()[0]

        if result is not None:
            msg_type = Message.T_REPLY
        elif error is not None:
            msg_type = Message.T_ERROR
        elif id_ is not None:
            msg_type = Message.T_REQUEST
        else:
            msg_type = Message.T_NOTIFY

        msg = Message(msg_type, method, params, result, error, id_)
        validation_error = msg.is_valid()
        if validation_error is not None:
            return validation_error
        else:
            return msg

    def to_json(self):
        """Returns this message as a dict ready for JSON serialization.

        Replies always carry "error" (null on success), error replies always
        carry "result" (null), and notifications always carry "id" (null)."""
        json = {}

        if self.method is not None:
            json["method"] = self.method

        if self.params is not None:
            json["params"] = self.params

        if self.result is not None or self.type == Message.T_ERROR:
            json["result"] = self.result

        if self.error is not None or self.type == Message.T_REPLY:
            json["error"] = self.error

        if self.id is not None or self.type == Message.T_NOTIFY:
            json["id"] = self.id

        return json

    def __str__(self):
        s = [Message.type_to_string(self.type)]
        if self.method is not None:
            s.append("method=\"%s\"" % self.method)
        if self.params is not None:
            s.append("params=" + ovs.json.to_string(self.params))
        if self.result is not None:
            s.append("result=" + ovs.json.to_string(self.result))
        if self.error is not None:
            s.append("error=" + ovs.json.to_string(self.error))
        if self.id is not None:
            s.append("id=" + ovs.json.to_string(self.id))
        return ", ".join(s)
179
180
class Connection(object):
    """A JSON-RPC connection layered on top of a stream.

    Outgoing messages are serialized into an output buffer that run()
    drains.  Incoming bytes accumulate in an input buffer that recv() feeds
    to an incremental JSON parser.  The first error is latched in 'status',
    after which the connection performs no further I/O."""

    def __init__(self, stream):
        self.name = stream.name
        self.stream = stream
        self.status = 0           # 0, or the first errno/EOF latched by error().
        self.input = ""           # Received data not yet fed to the parser.
        self.output = ""          # Serialized messages not yet sent.
        self.parser = None        # Parser for a partially received message.
        self.received_bytes = 0   # Total bytes ever received.

    def close(self):
        """Closes the underlying stream.  Safe to call more than once."""
        if self.stream is not None:
            self.stream.close()
            self.stream = None

    def run(self):
        """Sends as much queued output as the stream will accept.

        stream.send() returns the number of bytes sent or a negated errno;
        any error other than EAGAIN is latched via error()."""
        if self.status:
            return

        while len(self.output):
            retval = self.stream.send(self.output)
            if retval >= 0:
                self.output = self.output[retval:]
            else:
                if retval != -errno.EAGAIN:
                    vlog.warn("%s: send error: %s" %
                              (self.name, os.strerror(-retval)))
                    self.error(-retval)
                break

    def wait(self, poller):
        """Arranges for 'poller' to wake up when run() has work to do."""
        if not self.status:
            self.stream.run_wait(poller)
            if len(self.output):
                self.stream.send_wait(poller)

    def get_status(self):
        """Returns 0 if the connection is healthy, else the latched error."""
        return self.status

    def get_backlog(self):
        """Returns the number of queued, unsent bytes (0 after an error)."""
        if self.status != 0:
            return 0
        else:
            return len(self.output)

    def get_received_bytes(self):
        """Returns the total number of bytes received so far."""
        return self.received_bytes

    def __log_msg(self, title, msg):
        if vlog.dbg_is_enabled():
            vlog.dbg("%s: %s %s" % (self.name, title, msg))

    def send(self, msg):
        """Queues Message 'msg' for sending, trying to send immediately if
        nothing else was queued.  Returns the connection status (0 if OK)."""
        if self.status:
            return self.status

        self.__log_msg("send", msg)

        was_empty = len(self.output) == 0
        self.output += ovs.json.to_string(msg.to_json())
        if was_empty:
            self.run()
        return self.status

    def send_block(self, msg):
        """Sends 'msg' and blocks until all queued output has been sent or
        an error occurs.  Returns the connection status."""
        error = self.send(msg)
        if error:
            return error

        while True:
            self.run()
            if not self.get_backlog() or self.get_status():
                return self.status

            poller = ovs.poller.Poller()
            self.wait(poller)
            poller.block()

    def recv(self):
        """Tries to receive one complete message without blocking.

        Returns (0, msg) on success, (errno.EAGAIN, None) if no complete
        message is available yet, or (status, None) after an error or EOF."""
        if self.status:
            return self.status, None

        while True:
            if not self.input:
                error, data = self.stream.recv(4096)
                if error:
                    if error == errno.EAGAIN:
                        return error, None
                    else:
                        # XXX rate-limit
                        vlog.warn("%s: receive error: %s"
                                  % (self.name, os.strerror(error)))
                        self.error(error)
                        return self.status, None
                elif not data:
                    # Empty read: the peer closed the connection.
                    self.error(EOF)
                    return EOF, None
                else:
                    self.input += data
                    self.received_bytes += len(data)
            else:
                if self.parser is None:
                    self.parser = ovs.json.Parser()
                # feed() reports how much input it consumed; keep the rest,
                # which may belong to the next message.
                self.input = self.input[self.parser.feed(self.input):]
                if self.parser.is_done():
                    msg = self.__process_msg()
                    if msg:
                        return 0, msg
                    else:
                        return self.status, None

    def recv_block(self):
        """Blocks until a message arrives or an error other than EAGAIN
        occurs.  Returns (error, msg) as recv() does."""
        while True:
            error, msg = self.recv()
            if error != errno.EAGAIN:
                return error, msg

            self.run()

            poller = ovs.poller.Poller()
            self.wait(poller)
            self.recv_wait(poller)
            poller.block()

    def transact_block(self, request):
        """Sends 'request' and blocks until the reply or error reply with the
        same id arrives; other messages received meanwhile are discarded.
        Returns (error, reply)."""
        id_ = request.id

        error = self.send(request)
        reply = None
        while not error:
            error, reply = self.recv_block()
            if (reply
                and (reply.type == Message.T_REPLY
                     or reply.type == Message.T_ERROR)
                and reply.id == id_):
                break
        return error, reply

    def __process_msg(self):
        """Finishes the current parse and converts it into a Message.

        Returns the Message, or None after latching EPROTO if the input was
        not valid JSON or not a valid JSON-RPC message."""
        json = self.parser.finish()
        self.parser = None
        # A string result from finish() is treated as a parse error message.
        # type(u"") is unicode on Python 2 and str on Python 3.
        if isinstance(json, (str, type(u""))):
            # XXX rate-limit
            vlog.warn("%s: error parsing stream: %s" % (self.name, json))
            self.error(errno.EPROTO)
            return

        msg = Message.from_json(json)
        if not isinstance(msg, Message):
            # XXX rate-limit
            vlog.warn("%s: received bad JSON-RPC message: %s"
                      % (self.name, msg))
            self.error(errno.EPROTO)
            return

        self.__log_msg("received", msg)
        return msg

    def recv_wait(self, poller):
        """Arranges for 'poller' to wake up when recv() may make progress."""
        if self.status or self.input:
            poller.immediate_wake()
        else:
            self.stream.recv_wait(poller)

    def error(self, error):
        """Latches 'error' as the status (only the first error is kept),
        closes the stream if it is still open, and drops unsent output."""
        if self.status == 0:
            self.status = error
            if self.stream is not None:
                self.stream.close()
            self.output = ""
349
350
class Session(object):
    """A JSON-RPC session with reconnection."""

    def __init__(self, reconnect, rpc):
        self.reconnect = reconnect   # ovs.reconnect.Reconnect state machine.
        self.rpc = rpc               # Connection while connected, else None.
        self.stream = None           # Active stream still being connected.
        self.pstream = None          # Listening stream (passive sessions).
        self.seqno = 0               # Incremented on connect and disconnect.

    @staticmethod
    def open(name):
        """Creates and returns a Session that maintains a JSON-RPC session to
        'name', which should be a string acceptable to ovs.stream.Stream or
        ovs.stream.PassiveStream's initializer.

        If 'name' is an active connection method, e.g. "tcp:127.1.2.3", the new
        session connects and reconnects, with back-off, to 'name'.

        If 'name' is a passive connection method, e.g. "ptcp:", the new session
        listens for connections to 'name'.  It maintains at most one connection
        at any given time.  Any new connection causes the previous one (if any)
        to be dropped."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_name(name)
        reconnect.enable(ovs.timeval.msec())

        if ovs.stream.PassiveStream.is_valid_name(name):
            reconnect.set_passive(True, ovs.timeval.msec())

        # Turn inactivity probing off only for [p]streams that do not need
        # probes to detect a dropped connection; keep it for those that do.
        if not ovs.stream.stream_or_pstream_needs_probes(name):
            reconnect.set_probe_interval(0)

        return Session(reconnect, None)

    @staticmethod
    def open_unreliably(jsonrpc):
        """Wraps the already-connected Connection 'jsonrpc' in a Session whose
        reconnect policy allows no further connection attempts (max tries 0),
        so is_alive() becomes False once the connection is lost."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_quiet(True)
        reconnect.set_name(jsonrpc.name)
        reconnect.set_max_tries(0)
        reconnect.connected(ovs.timeval.msec())
        return Session(reconnect, jsonrpc)

    def close(self):
        """Closes the connection, any pending stream and any listener."""
        if self.rpc is not None:
            self.rpc.close()
            self.rpc = None
        if self.stream is not None:
            self.stream.close()
            self.stream = None
        if self.pstream is not None:
            self.pstream.close()
            self.pstream = None

    def __disconnect(self):
        """Drops the current connection (or pending stream), if any.  The
        listening pstream, if any, is kept."""
        if self.rpc is not None:
            self.rpc.error(EOF)
            self.rpc.close()
            self.rpc = None
            self.seqno += 1
        elif self.stream is not None:
            self.stream.close()
            self.stream = None
            self.seqno += 1

    def __connect(self):
        """Starts a new connection attempt (active) or starts listening
        (passive), reporting the outcome to the reconnect state machine."""
        self.__disconnect()

        name = self.reconnect.get_name()
        if not self.reconnect.is_passive():
            error, self.stream = ovs.stream.Stream.open(name)
            if not error:
                self.reconnect.connecting(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
        elif self.pstream is None:
            # Passive session with no listener yet: open one.  An existing
            # listener is kept rather than overwritten (which would leak it).
            error, self.pstream = ovs.stream.PassiveStream.open(name)
            if not error:
                self.reconnect.listening(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)

        self.seqno += 1

    def run(self):
        """Does periodic work: accepts a new connection on the listener,
        flushes queued output, completes a pending active connection, and
        performs the action requested by the reconnect state machine."""
        if self.pstream is not None:
            error, stream = self.pstream.accept()
            if error == 0:
                if self.rpc or self.stream:
                    # XXX rate-limit
                    vlog.info("%s: new connection replacing active "
                              "connection" % self.reconnect.get_name())
                    self.__disconnect()
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(stream)
            elif error != errno.EAGAIN:
                self.reconnect.listen_error(ovs.timeval.msec(), error)
                self.pstream.close()
                self.pstream = None

        if self.rpc:
            backlog = self.rpc.get_backlog()
            self.rpc.run()
            if self.rpc.get_backlog() < backlog:
                # Data previously caught in a queue was successfully sent (or
                # there's an error, which we'll catch below).
                #
                # We don't count data that is successfully sent immediately as
                # activity, because there's a lot of queuing downstream from
                # us, which means that we can push a lot of data into a
                # connection that has stalled and won't ever recover.
                self.reconnect.activity(ovs.timeval.msec())

            error = self.rpc.get_status()
            if error != 0:
                self.reconnect.disconnected(ovs.timeval.msec(), error)
                self.__disconnect()
        elif self.stream is not None:
            self.stream.run()
            error = self.stream.connect()
            if error == 0:
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(self.stream)
                self.stream = None
            elif error != errno.EAGAIN:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
                self.stream.close()
                self.stream = None

        action = self.reconnect.run(ovs.timeval.msec())
        if action == ovs.reconnect.CONNECT:
            self.__connect()
        elif action == ovs.reconnect.DISCONNECT:
            self.reconnect.disconnected(ovs.timeval.msec(), 0)
            self.__disconnect()
        elif action == ovs.reconnect.PROBE:
            if self.rpc:
                # The fixed id "echo" lets recv() recognize and suppress the
                # peer's reply to this probe.
                request = Message.create_request("echo", [])
                request.id = "echo"
                self.rpc.send(request)
        else:
            assert action is None

    def wait(self, poller):
        """Arranges for 'poller' to wake up when run() has work to do."""
        if self.rpc is not None:
            self.rpc.wait(poller)
        elif self.stream is not None:
            self.stream.run_wait(poller)
            self.stream.connect_wait(poller)
        if self.pstream is not None:
            self.pstream.wait(poller)
        self.reconnect.wait(poller, ovs.timeval.msec())

    def get_backlog(self):
        """Returns the number of queued, unsent bytes (0 if disconnected)."""
        if self.rpc is not None:
            return self.rpc.get_backlog()
        else:
            return 0

    def get_name(self):
        return self.reconnect.get_name()

    def send(self, msg):
        """Sends 'msg' on the current connection.  Returns the connection
        status, or errno.ENOTCONN if there is no connection."""
        if self.rpc is not None:
            return self.rpc.send(msg)
        else:
            return errno.ENOTCONN

    def recv(self):
        """Returns the next received message, or None if none is available.

        Echo requests are answered automatically and replies to our own echo
        probes are suppressed; neither is returned to the caller."""
        if self.rpc is not None:
            received_bytes = self.rpc.get_received_bytes()
            error, msg = self.rpc.recv()
            if received_bytes != self.rpc.get_received_bytes():
                # Data was successfully received.
                #
                # Previously we only counted receiving a full message as
                # activity, but with large messages or a slow connection that
                # policy could time out the session mid-message.
                self.reconnect.activity(ovs.timeval.msec())

            if not error:
                if msg.type == Message.T_REQUEST and msg.method == "echo":
                    # Echo request.  Send reply.
                    self.send(Message.create_reply(msg.params, msg.id))
                elif msg.type == Message.T_REPLY and msg.id == "echo":
                    # It's a reply to our echo request.  Suppress it.
                    pass
                else:
                    return msg
        return None

    def recv_wait(self, poller):
        """Arranges for 'poller' to wake up when recv() may make progress."""
        if self.rpc is not None:
            self.rpc.recv_wait(poller)

    def is_alive(self):
        """Returns True while connected or connecting, or while the reconnect
        policy still permits attempts (max tries unlimited or positive)."""
        if self.rpc is not None or self.stream is not None:
            return True
        else:
            max_tries = self.reconnect.get_max_tries()
            return max_tries is None or max_tries > 0

    def is_connected(self):
        return self.rpc is not None

    def get_seqno(self):
        """Returns a number that changes whenever the session connects or
        disconnects."""
        return self.seqno

    def force_reconnect(self):
        self.reconnect.force_reconnect(ovs.timeval.msec())