fix some logic flaws
[util-vserver.git] / src / rebootmgr.c
1 // $Id: rebootmgr.c 923 2004-02-17 19:55:54Z ensc $
2
3 // Copyright (C) 2003 Enrico Scholz <enrico.scholz@informatik.tu-chemnitz.de>
4 // based on rebootmgr.cc by Jacques Gelinas
5 //  
6 // This program is free software; you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation; either version 2, or (at your option)
9 // any later version.
10 //  
11 // This program is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 // GNU General Public License for more details.
15 //  
16 // You should have received a copy of the GNU General Public License
17 // along with this program; if not, write to the Free Software
18 // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19
20 /*
        The reboot manager allows a virtual server administrator to request
        a complete restart of his vserver. This means that all services
        are terminated, all remaining processes are killed and then
        all services are started again.
25
26         This is done by issuing
27
28                 /usr/sbin/vserver vserver restart
29
30
        The rebootmgr installs a unix domain socket in each vserver
        and listens for the reboot messages. All other messages are discarded.
33
34         The unix domain socket is placed in /vservers/N/dev/reboot and is
35         turned immutable.
36
37         The vreboot utility is used to send the signal from the vserver
38         environment.
39 */
40 #ifdef HAVE_CONFIG_H
41 #  include <config.h>
42 #endif
43 #include "pathconfig.h"
44
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <alloca.h>
#include <syslog.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/un.h>
57
58 static void usage()
59 {
60         fprintf (stderr,"rebootmgr version %s\n",VERSION);
61         fprintf (stderr,"\n");
62         fprintf (stderr,"rebootmgr [--pidfile file ] vserver-name [ vserver-name ...]\n");
63 }
64
65 static int rebootmgr_opensocket (const char *vname)
66 {
67         int ret = -1;
68         char sockn[PATH_MAX];
69         int fd =  socket (AF_UNIX,SOCK_STREAM,0);
70         sprintf (sockn, DEFAULT_VSERVERDIR "/%s/dev/reboot",vname);
71         unlink (sockn);
72         if (fd == -1){
73                 fprintf (stderr,"Can't create a unix domain socket (%s)\n"
74                                 ,strerror(errno));
75         }else{
76                 struct sockaddr_un un;
77                 un.sun_family = AF_UNIX;
78                 strcpy (un.sun_path,sockn);
79                 if (bind(fd,(struct sockaddr*)&un,sizeof(un))==-1){
80                         fprintf (stderr,"Can't bind to file %s (%s)\n",sockn
81                                 ,strerror(errno));
82                 }else{
83                         int code;
84                         chmod (sockn,0600);
85                         code = listen (fd,10);
86                         if (code == -1){
87                                 fprintf (stderr,"Can't listen to file %s (%s)\n",sockn
88                                         ,strerror(errno));
89                         }else{
90                                 ret = fd;
91                         }       
92                 }
93         }
94         return ret;
95 }
96
97 static int rebootmgr_process (int fd, const char *vname)
98 {
99         int ret = -1;
100         char buf[100];
101         int len = read (fd,buf,sizeof(buf)-1);
102         // fprintf (stderr,"process %d %s len %d\n",fd,vname,len);
103         if (len > 0){
104                 buf[len] = '\0';
105                 if (strcmp(buf,"reboot\n")==0){
106                         char cmd[1000];
107                         syslog (LOG_NOTICE,"reboot vserver %s\n",vname);
108                         snprintf (cmd,sizeof(cmd)-1, SBINDIR "/vserver %s restart >>/var/log/boot.log 2>&1", vname);
109                         system (cmd);
110                         ret = 0;
111                 }else if (strcmp(buf,"halt\n")==0){
112                         char cmd[1000];
113                         syslog (LOG_NOTICE,"halt vserver %s\n",vname);
114                         snprintf (cmd,sizeof(cmd)-1, SBINDIR "/vserver %s stop >>/var/log/boot.log 2>&1", vname);
115                         system (cmd);
116                         ret = 0;
117                 }else{
118                         syslog (LOG_ERR,"Invalid request from vserver %s",vname);
119                 }
120         }
121         return ret;
122 }
123
124
/*
 * rebootmgr entry point.
 *
 * Usage: rebootmgr [--pidfile file] [--] vserver-name [vserver-name ...]
 *
 * Opens one listening socket per vserver name (rebootmgr_opensocket),
 * then loops on select(): new connections are accepted and remembered
 * together with the vserver they belong to, readable connections are
 * handed to rebootmgr_process() and closed when it returns -1.
 *
 * At most argc*2 connections are kept; when that limit is reached all of
 * them are closed to make room.
 *
 * The service loop only ends when select() fails, so the function always
 * returns -1.
 */
int main (int argc, char *argv[])
{
	int ret = -1;
	if (argc < 2){
		usage();
	}else{
		int error = 0;
		int start = 1;
		int i;
		int *sockets = alloca(argc * sizeof(int));

		openlog ("rebootmgr",LOG_PID,LOG_DAEMON);
		// argv[0] is the program name, option parsing starts behind it
		for (i=1; i<argc; i++){
			const char *arg = argv[i];
			if (strcmp(arg,"--pidfile")==0){
				if (i+1 >= argc){
					// Without this check argv[argc] (NULL) would be used
					fprintf (stderr,"Missing file name after --pidfile\n");
					syslog (LOG_ERR,"Missing file name after --pidfile");
					error = 1;
					break;
				}else{
					const char *pidfile = argv[i+1];
					FILE *fout = fopen (pidfile,"w");
					if (fout == NULL){
						// fprintf may change errno, keep the fopen value
						int saved_errno = errno;
						fprintf (stderr,"Can't open pidfile %s (%s)\n"
							,pidfile,strerror(saved_errno));
						syslog (LOG_ERR,"Can't open pidfile %s (%s)"
							,pidfile,strerror(saved_errno));
					}else{
						fprintf (fout,"%d\n",(int)getpid());
						fclose (fout);
					}
					start = i+2;
					i++;
				}
			}else if (strcmp(arg,"--")==0){
				start = i+1;
				break;
			}else if (arg[0] == '-'){
				fprintf (stderr,"Invalid argument %s\n",arg);
				syslog (LOG_ERR,"Invalid argument %s",arg);
				// Do not go on and treat the option as a vserver name
				error = 1;
			}
		}
		if (!error && start >= argc){
			// No vserver left: select() would block forever on nothing
			usage();
			error = 1;
		}
		if (!error){
			for (i=start; i<argc; i++){
				int fd = rebootmgr_opensocket (argv[i]);
				if (fd == -1){
					error = 1;
				}else if (fd >= FD_SETSIZE){
					// FD_SET() on such a descriptor is undefined
					fprintf (stderr,"Too many vservers, descriptor %d can't be used with select\n",fd);
					close (fd);
					error = 1;
					break;
				}else{
					sockets[i] = fd;
				}
			}
		}
		if (!error){
			int maxhandles = argc*2;
			struct {
				int handle;
				const char *vname;
			} handles[maxhandles];
			int nbhandles=0;
			while (1){
				int maxfd = 0;
				int ok;
				int dst = 0;
				fd_set fdin;

				// Watch every listening socket and every open connection
				FD_ZERO (&fdin);
				for (i=start; i<argc; i++){
					int fd = sockets[i];
					if (fd > maxfd) maxfd = fd;
					FD_SET (fd,&fdin);
				}
				for (i=0; i<nbhandles; i++){
					int fd = handles[i].handle;
					if (fd > maxfd) maxfd = fd;
					FD_SET (fd,&fdin);
				}
				ok = select (maxfd+1,&fdin,NULL,NULL,NULL);
				if (ok == -1 && errno == EINTR){
					// Interrupted by a signal: not an error, wait again
					continue;
				}
				if (ok <= 0){
					if (ok == -1){
						syslog (LOG_ERR,"select failed (%s)",strerror(errno));
					}
					break;
				}

				// Accept the pending connections
				for (i=start; i<argc; i++){
					int fd = sockets[i];
					if (FD_ISSET(fd,&fdin)){
						struct sockaddr_un unc;
						socklen_t len = sizeof(unc);
						fd = accept (fd,(struct sockaddr*)&unc,&len);
						if (fd == -1){
							continue;
						}
						if (fd >= FD_SETSIZE){
							syslog (LOG_ERR,"descriptor %d can't be used with select: connection refused",fd);
							close (fd);
							continue;
						}
						if (nbhandles == maxhandles){
							int j;
							// Overloaded, we close every handle
							syslog (LOG_ERR,"%d sockets opened: Overloaded\n",nbhandles);
							for (j=0; j<nbhandles; j++){
								close (handles[j].handle);
							}
							nbhandles = 0;
						}
						// The number of a handle closed above may be
						// reused by accept() while still marked ready
						// in fdin; the new connection has no data yet.
						FD_CLR (fd,&fdin);
						handles[nbhandles].handle = fd;
						handles[nbhandles].vname = argv[i];
						nbhandles++;
						// fprintf (stderr,"accept %d\n",nbhandles);
					}
				}

				// Serve the readable connections, dropping the finished
				// ones while compacting the table
				for (i=0; i<nbhandles; i++){
					int fd = handles[i].handle;
					if (FD_ISSET(fd,&fdin)
						&& rebootmgr_process (fd,handles[i].vname)==-1){
						close (fd);
					}else{
						handles[dst++] = handles[i];
					}
				}
				nbhandles = dst;
			}
		}
	}
	return ret;
}
245
246