python: Workaround UNIX socket path length limits
[sliver-openvswitch.git] / python / ovs / socket_util.py
1 # Copyright (c) 2010, 2012 Nicira, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17 import select
18 import socket
19 import sys
20
21 import ovs.fatal_signal
22 import ovs.poller
23 import ovs.vlog
24
25 vlog = ovs.vlog.Vlog("socket_util")
26
27
28 def make_unix_socket(style, nonblock, bind_path, connect_path):
29     """Creates a Unix domain socket in the given 'style' (either
30     socket.SOCK_DGRAM or socket.SOCK_STREAM) that is bound to 'bind_path' (if
31     'bind_path' is not None) and connected to 'connect_path' (if 'connect_path'
32     is not None).  If 'nonblock' is true, the socket is made non-blocking.
33
34     Returns (error, socket): on success 'error' is 0 and 'socket' is a new
35     socket object, on failure 'error' is a positive errno value and 'socket' is
36     None."""
37
38     try:
39         sock = socket.socket(socket.AF_UNIX, style)
40     except socket.error, e:
41         return get_exception_errno(e), None
42
43     try:
44         if nonblock:
45             set_nonblocking(sock)
46         if bind_path is not None:
47             # Delete bind_path but ignore ENOENT.
48             try:
49                 os.unlink(bind_path)
50             except OSError, e:
51                 if e.errno != errno.ENOENT:
52                     return e.errno, None
53
54             ovs.fatal_signal.add_file_to_unlink(bind_path)
55             sock.bind(bind_path)
56
57             try:
58                 if sys.hexversion >= 0x02060000:
59                     os.fchmod(sock.fileno(), 0700)
60                 else:
61                     os.chmod("/dev/fd/%d" % sock.fileno(), 0700)
62             except OSError, e:
63                 pass
64         if connect_path is not None:
65             try:
66                 sock.connect(connect_path)
67             except socket.error, e:
68                 if get_exception_errno(e) != errno.EINPROGRESS:
69                     raise
70         return 0, sock
71     except socket.error, e:
72         sock.close()
73         if (bind_path is not None and
74             os.path.exists(bind_path)):
75             ovs.fatal_signal.unlink_file_now(bind_path)
76         eno = ovs.socket_util.get_exception_errno(e)
77         if (eno == "AF_UNIX path too long" and
78             os.uname()[0] == "Linux"):
79             short_connect_path = None
80             short_bind_path = None
81             connect_dirfd = None
82             bind_dirfd = None
83             # Try workaround using /proc/self/fd
84             if connect_path is not None:
85                 dirname = os.path.dirname(connect_path)
86                 basename = os.path.basename(connect_path)
87                 try:
88                     connect_dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
89                 except OSError, err:
90                     return get_exception_errno(e), None
91                 short_connect_path = "/proc/self/fd/%d/%s" % (connect_dirfd, basename)
92
93             if bind_path is not None:
94                 dirname = os.path.dirname(bind_path)
95                 basename = os.path.basename(bind_path)
96                 try:
97                     bind_dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
98                 except OSError, err:
99                     return get_exception_errno(e), None
100                 short_bind_path = "/proc/self/fd/%d/%s" % (bind_dirfd, basename)
101
102             try:
103                 return make_unix_socket(style, nonblock, short_bind_path, short_connect_path)
104             finally:
105                 if connect_dirfd is not None:
106                     os.close(connect_dirfd)
107                 if bind_dirfd is not None:
108                     os.close(bind_dirfd)
109         else:
110             return get_exception_errno(e), None
111
112
113 def check_connection_completion(sock):
114     p = ovs.poller.SelectPoll()
115     p.register(sock, ovs.poller.POLLOUT)
116     pfds = p.poll(0)
117     if len(pfds) == 1:
118         revents = pfds[0][1]
119         if revents & ovs.poller.POLLERR:
120             try:
121                 # The following should raise an exception.
122                 socket.send("\0", socket.MSG_DONTWAIT)
123
124                 # (Here's where we end up if it didn't.)
125                 # XXX rate-limit
126                 vlog.err("poll return POLLERR but send succeeded")
127                 return errno.EPROTO
128             except socket.error, e:
129                 return get_exception_errno(e)
130         else:
131             return 0
132     else:
133         return errno.EAGAIN
134
135
136 def inet_parse_active(target, default_port):
137     address = target.split(":")
138     host_name = address[0]
139     if not host_name:
140         raise ValueError("%s: bad peer name format" % target)
141     if len(address) >= 2:
142         port = int(address[1])
143     elif default_port:
144         port = default_port
145     else:
146         raise ValueError("%s: port number must be specified" % target)
147     return (host_name, port)
148
149
150 def inet_open_active(style, target, default_port, dscp):
151     address = inet_parse_active(target, default_port)
152     try:
153         sock = socket.socket(socket.AF_INET, style, 0)
154     except socket.error, e:
155         return get_exception_errno(e), None
156
157     try:
158         set_nonblocking(sock)
159         set_dscp(sock, dscp)
160         try:
161             sock.connect(address)
162         except socket.error, e:
163             if get_exception_errno(e) != errno.EINPROGRESS:
164                 raise
165         return 0, sock
166     except socket.error, e:
167         sock.close()
168         return get_exception_errno(e), None
169
170
171 def get_exception_errno(e):
172     """A lot of methods on Python socket objects raise socket.error, but that
173     exception is documented as having two completely different forms of
174     arguments: either a string or a (errno, string) tuple.  We only want the
175     errno."""
176     if type(e.args) == tuple:
177         return e.args[0]
178     else:
179         return errno.EPROTO
180
181
182 null_fd = -1
183
184
185 def get_null_fd():
186     """Returns a readable and writable fd for /dev/null, if successful,
187     otherwise a negative errno value.  The caller must not close the returned
188     fd (because the same fd will be handed out to subsequent callers)."""
189     global null_fd
190     if null_fd < 0:
191         try:
192             null_fd = os.open("/dev/null", os.O_RDWR)
193         except OSError, e:
194             vlog.err("could not open /dev/null: %s" % os.strerror(e.errno))
195             return -e.errno
196     return null_fd
197
198
199 def write_fully(fd, buf):
200     """Returns an (error, bytes_written) tuple where 'error' is 0 on success,
201     otherwise a positive errno value, and 'bytes_written' is the number of
202     bytes that were written before the error occurred.  'error' is 0 if and
203     only if 'bytes_written' is len(buf)."""
204     bytes_written = 0
205     if len(buf) == 0:
206         return 0, 0
207     while True:
208         try:
209             retval = os.write(fd, buf)
210             assert retval >= 0
211             if retval == len(buf):
212                 return 0, bytes_written + len(buf)
213             elif retval == 0:
214                 vlog.warn("write returned 0")
215                 return errno.EPROTO, bytes_written
216             else:
217                 bytes_written += retval
218                 buf = buf[:retval]
219         except OSError, e:
220             return e.errno, bytes_written
221
222
223 def set_nonblocking(sock):
224     try:
225         sock.setblocking(0)
226     except socket.error, e:
227         vlog.err("could not set nonblocking mode on socket: %s"
228                  % os.strerror(get_exception_errno(e)))
229
230
231 def set_dscp(sock, dscp):
232     if dscp > 63:
233         raise ValueError("Invalid dscp %d" % dscp)
234     val = dscp << 2
235     sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, val)