X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=net%2Fnetlink%2Faf_netlink.c;h=87d8074b7757a534bd133bc84282abeefbd0b717;hb=refs%2Fheads%2Fvserver;hp=2cb02877f00c6c3dfc3fd379c82161204f7791ef;hpb=76828883507a47dae78837ab5dec5a5b4513c667;p=linux-2.6.git diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 2cb02877f..87d8074b7 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -21,7 +21,6 @@ * mandatory if CONFIG_NET=y these days */ -#include #include #include @@ -56,6 +55,7 @@ #include #include #include +#include #include #include #include @@ -64,7 +64,6 @@ #include #include -#define Nprintk(a...) #define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) struct netlink_sock { @@ -109,6 +108,7 @@ struct nl_pid_hash { struct netlink_table { struct nl_pid_hash hash; struct hlist_head mc_list; + unsigned long *listeners; unsigned int nl_nonroot; unsigned int groups; struct module *module; @@ -125,7 +125,7 @@ static void netlink_destroy_callback(struct netlink_callback *cb); static DEFINE_RWLOCK(nl_table_lock); static atomic_t nl_table_users = ATOMIC_INIT(0); -static struct notifier_block *netlink_chain; +static ATOMIC_NOTIFIER_HEAD(netlink_chain); static u32 netlink_group_mask(u32 group) { @@ -159,7 +159,7 @@ static void netlink_sock_destruct(struct sock *sk) static void netlink_table_grab(void) { - write_lock_bh(&nl_table_lock); + write_lock_irq(&nl_table_lock); if (atomic_read(&nl_table_users)) { DECLARE_WAITQUEUE(wait, current); @@ -169,9 +169,9 @@ static void netlink_table_grab(void) set_current_state(TASK_UNINTERRUPTIBLE); if (atomic_read(&nl_table_users) == 0) break; - write_unlock_bh(&nl_table_lock); + write_unlock_irq(&nl_table_lock); schedule(); - write_lock_bh(&nl_table_lock); + write_lock_irq(&nl_table_lock); } __set_current_state(TASK_RUNNING); @@ -181,7 +181,7 @@ static void netlink_table_grab(void) static __inline__ void netlink_table_ungrab(void) { - write_unlock_bh(&nl_table_lock); + write_unlock_irq(&nl_table_lock); wake_up(&nl_table_wait); } @@ -299,6 +299,24 @@ static inline int nl_pid_hash_dilute(struct nl_pid_hash *hash, int len) static const struct proto_ops netlink_ops; +static void +netlink_update_listeners(struct sock *sk) +{ + struct netlink_table *tbl = &nl_table[sk->sk_protocol]; + struct hlist_node *node; + unsigned long mask; + unsigned int i; + + for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) { + mask = 0; + sk_for_each_bound(sk, node, &tbl->mc_list) + mask |= nlk_sk(sk)->groups[i]; + tbl->listeners[i] = mask; + } + /* this function is only called with the netlink table "grabbed", which + * makes sure updates are visible before bind or setsockopt return. 
*/ +} + static int netlink_insert(struct sock *sk, u32 pid) { struct nl_pid_hash *hash = &nl_table[sk->sk_protocol].hash; @@ -453,18 +471,20 @@ static int netlink_release(struct socket *sock) .protocol = sk->sk_protocol, .pid = nlk->pid, }; - notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n); + atomic_notifier_call_chain(&netlink_chain, + NETLINK_URELEASE, &n); } - if (nlk->module) - module_put(nlk->module); + module_put(nlk->module); + netlink_table_grab(); if (nlk->flags & NETLINK_KERNEL_SOCKET) { - netlink_table_grab(); + kfree(nl_table[sk->sk_protocol].listeners); nl_table[sk->sk_protocol].module = NULL; nl_table[sk->sk_protocol].registered = 0; - netlink_table_ungrab(); - } + } else if (nlk->subscriptions) + netlink_update_listeners(sk); + netlink_table_ungrab(); kfree(nlk->groups); nlk->groups = NULL; @@ -544,10 +564,9 @@ static int netlink_alloc_groups(struct sock *sk) if (err) return err; - nlk->groups = kmalloc(NLGRPSZ(groups), GFP_KERNEL); + nlk->groups = kzalloc(NLGRPSZ(groups), GFP_KERNEL); if (nlk->groups == NULL) return -ENOMEM; - memset(nlk->groups, 0, NLGRPSZ(groups)); nlk->ngroups = groups; return 0; } @@ -592,6 +611,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len hweight32(nladdr->nl_groups) - hweight32(nlk->groups[0])); nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; + netlink_update_listeners(sk); netlink_table_ungrab(); return 0; @@ -681,7 +701,7 @@ static struct sock *netlink_getsockbypid(struct sock *ssk, u32 pid) struct sock *netlink_getsockbyfilp(struct file *filp) { - struct inode *inode = filp->f_dentry->d_inode; + struct inode *inode = filp->f_path.dentry->d_inode; struct sock *sock; if (!S_ISSOCK(inode->i_mode)) @@ -810,6 +830,17 @@ retry: return netlink_sendskb(sk, skb, ssk->sk_protocol); } +int netlink_has_listeners(struct sock *sk, unsigned int group) +{ + int res = 0; + + BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET)); + if (group - 1 < nl_table[sk->sk_protocol].groups) + res = test_bit(group - 1, nl_table[sk->sk_protocol].listeners); + return res; +} +EXPORT_SYMBOL_GPL(netlink_has_listeners); + static __inline__ int netlink_broadcast_deliver(struct sock *sk, struct sk_buff *skb) { struct netlink_sock *nlk = nlk_sk(sk); @@ -1014,6 +1045,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, else __clear_bit(val - 1, nlk->groups); netlink_update_subscriptions(sk, subscriptions); + netlink_update_listeners(sk); netlink_table_ungrab(); err = 0; break; @@ -1045,8 +1077,9 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname, return -EINVAL; len = sizeof(int); val = nlk->flags & NETLINK_RECV_PKTINFO ? 1 : 0; - put_user(len, optlen); - put_user(val, optval); + if (put_user(len, optlen) || + put_user(val, optval)) + return -EFAULT; err = 0; break; default: @@ -1122,9 +1155,9 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, goto out; NETLINK_CB(skb).pid = nlk->pid; - NETLINK_CB(skb).dst_pid = dst_pid; NETLINK_CB(skb).dst_group = dst_group; NETLINK_CB(skb).loginuid = audit_get_loginuid(current->audit_context); + selinux_get_task_sid(current, &(NETLINK_CB(skb).sid)); memcpy(NETLINK_CREDS(skb), &siocb->scm->creds, sizeof(struct ucred)); /* What can I do? 
Netlink is asynchronous, so that @@ -1240,9 +1273,9 @@ netlink_kernel_create(int unit, unsigned int groups, struct socket *sock; struct sock *sk; struct netlink_sock *nlk; + unsigned long *listeners = NULL; - if (!nl_table) - return NULL; + BUG_ON(!nl_table); if (unit<0 || unit>=MAX_LINKS) return NULL; @@ -1253,6 +1286,13 @@ netlink_kernel_create(int unit, unsigned int groups, if (__netlink_create(sock, unit) < 0) goto out_sock_release; + if (groups < 32) + groups = 32; + + listeners = kzalloc(NLGRPSZ(groups), GFP_KERNEL); + if (!listeners) + goto out_sock_release; + sk = sock->sk; sk->sk_data_ready = netlink_data_ready; if (input) @@ -1265,7 +1305,8 @@ netlink_kernel_create(int unit, unsigned int groups, nlk->flags |= NETLINK_KERNEL_SOCKET; netlink_table_grab(); - nl_table[unit].groups = groups < 32 ? 32 : groups; + nl_table[unit].groups = groups; + nl_table[unit].listeners = listeners; nl_table[unit].module = module; nl_table[unit].registered = 1; netlink_table_ungrab(); @@ -1273,6 +1314,7 @@ netlink_kernel_create(int unit, unsigned int groups, return sk; out_sock_release: + kfree(listeners); sock_release(sock); return NULL; } @@ -1301,19 +1343,18 @@ static int netlink_dump(struct sock *sk) struct netlink_callback *cb; struct sk_buff *skb; struct nlmsghdr *nlh; - int len; + int len, err = -ENOBUFS; skb = sock_rmalloc(sk, NLMSG_GOODSIZE, 0, GFP_KERNEL); if (!skb) - return -ENOBUFS; + goto errout; spin_lock(&nlk->cb_lock); cb = nlk->cb; if (cb == NULL) { - spin_unlock(&nlk->cb_lock); - kfree_skb(skb); - return -EINVAL; + err = -EINVAL; + goto errout_skb; } len = cb->dump(skb, cb); @@ -1325,8 +1366,12 @@ static int netlink_dump(struct sock *sk) return 0; } - nlh = NLMSG_NEW_ANSWER(skb, cb, NLMSG_DONE, sizeof(len), NLM_F_MULTI); - memcpy(NLMSG_DATA(nlh), &len, sizeof(len)); + nlh = nlmsg_put_answer(skb, cb, NLMSG_DONE, sizeof(len), NLM_F_MULTI); + if (!nlh) + goto errout_skb; + + memcpy(nlmsg_data(nlh), &len, sizeof(len)); + skb_queue_tail(&sk->sk_receive_queue, skb); sk->sk_data_ready(sk, skb->len); @@ -1338,8 +1383,11 @@ static int netlink_dump(struct sock *sk) netlink_destroy_callback(cb); return 0; -nlmsg_failure: - return -ENOBUFS; +errout_skb: + spin_unlock(&nlk->cb_lock); + kfree_skb(skb); +errout: + return err; } int netlink_dump_start(struct sock *ssk, struct sk_buff *skb, @@ -1351,11 +1399,10 @@ int netlink_dump_start(struct sock *ssk, struct sk_buff *skb, struct sock *sk; struct netlink_sock *nlk; - cb = kmalloc(sizeof(*cb), GFP_KERNEL); + cb = kzalloc(sizeof(*cb), GFP_KERNEL); if (cb == NULL) return -ENOBUFS; - memset(cb, 0, sizeof(*cb)); cb->dump = dump; cb->done = done; cb->nlh = nlh; @@ -1389,14 +1436,13 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err) struct sk_buff *skb; struct nlmsghdr *rep; struct nlmsgerr *errmsg; - int size; + size_t payload = sizeof(*errmsg); - if (err == 0) - size = NLMSG_SPACE(sizeof(struct nlmsgerr)); - else - size = NLMSG_SPACE(4 + NLMSG_ALIGN(nlh->nlmsg_len)); + /* error messages get the original request appened */ + if (err) + payload += nlmsg_len(nlh); - skb = alloc_skb(size, GFP_KERNEL); + skb = nlmsg_new(payload, GFP_KERNEL); if (!skb) { struct sock *sk; @@ -1412,16 +1458,15 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err) rep = __nlmsg_put(skb, NETLINK_CB(in_skb).pid, nlh->nlmsg_seq, NLMSG_ERROR, sizeof(struct nlmsgerr), 0); - errmsg = NLMSG_DATA(rep); + errmsg = nlmsg_data(rep); errmsg->error = err; - memcpy(&errmsg->msg, nlh, err ? 
nlh->nlmsg_len : sizeof(struct nlmsghdr)); + memcpy(&errmsg->msg, nlh, err ? nlh->nlmsg_len : sizeof(*nlh)); netlink_unicast(in_skb->sk, skb, NETLINK_CB(in_skb).pid, MSG_DONTWAIT); } static int netlink_rcv_skb(struct sk_buff *skb, int (*cb)(struct sk_buff *, struct nlmsghdr *, int *)) { - unsigned int total_len; struct nlmsghdr *nlh; int err; @@ -1431,8 +1476,6 @@ static int netlink_rcv_skb(struct sk_buff *skb, int (*cb)(struct sk_buff *, if (nlh->nlmsg_len < NLMSG_HDRLEN || skb->len < nlh->nlmsg_len) return 0; - total_len = min(NLMSG_ALIGN(nlh->nlmsg_len), skb->len); - if (cb(skb, nlh, &err) < 0) { /* Not an error, but we have to interrupt processing * here. Note: that in this case we do not pull @@ -1444,7 +1487,7 @@ static int netlink_rcv_skb(struct sk_buff *skb, int (*cb)(struct sk_buff *, } else if (nlh->nlmsg_flags & NLM_F_ACK) netlink_ack(skb, nlh, 0); - skb_pull(skb, total_len); + netlink_queue_skip(nlh, skb); } return 0; @@ -1507,6 +1550,38 @@ void netlink_queue_skip(struct nlmsghdr *nlh, struct sk_buff *skb) skb_pull(skb, msglen); } +/** + * nlmsg_notify - send a notification netlink message + * @sk: netlink socket to use + * @skb: notification message + * @pid: destination netlink pid for reports or 0 + * @group: destination multicast group or 0 + * @report: 1 to report back, 0 to disable + * @flags: allocation flags + */ +int nlmsg_notify(struct sock *sk, struct sk_buff *skb, u32 pid, + unsigned int group, int report, gfp_t flags) +{ + int err = 0; + + if (group) { + int exclude_pid = 0; + + if (report) { + atomic_inc(&skb->users); + exclude_pid = pid; + } + + /* errors reported via destination sk->sk_err */ + nlmsg_multicast(sk, skb, exclude_pid, group, flags); + } + + if (report) + err = nlmsg_unicast(sk, skb, pid); + + return err; +} + #ifdef CONFIG_PROC_FS struct nl_seq_iter { int link; @@ -1626,7 +1701,7 @@ static int netlink_seq_open(struct inode *inode, struct file *file) struct nl_seq_iter *iter; int err; - iter = kmalloc(sizeof(*iter), GFP_KERNEL); + iter = kzalloc(sizeof(*iter), GFP_KERNEL); if (!iter) return -ENOMEM; @@ -1636,7 +1711,6 @@ static int netlink_seq_open(struct inode *inode, struct file *file) return err; } - memset(iter, 0, sizeof(*iter)); seq = file->private_data; seq->private = iter; return 0; @@ -1654,12 +1728,12 @@ static struct file_operations netlink_seq_fops = { int netlink_register_notifier(struct notifier_block *nb) { - return notifier_chain_register(&netlink_chain, nb); + return atomic_notifier_chain_register(&netlink_chain, nb); } int netlink_unregister_notifier(struct notifier_block *nb) { - return notifier_chain_unregister(&netlink_chain, nb); + return atomic_notifier_chain_unregister(&netlink_chain, nb); } static const struct proto_ops netlink_ops = { @@ -1689,8 +1763,6 @@ static struct net_proto_family netlink_family_ops = { .owner = THIS_MODULE, /* for consistency 8) */ }; -extern void netlink_skb_parms_too_large(void); - static int __init netlink_proto_init(void) { struct sk_buff *dummy_skb; @@ -1702,17 +1774,11 @@ static int __init netlink_proto_init(void) if (err != 0) goto out; - if (sizeof(struct netlink_skb_parms) > sizeof(dummy_skb->cb)) - netlink_skb_parms_too_large(); - - nl_table = kmalloc(sizeof(*nl_table) * MAX_LINKS, GFP_KERNEL); - if (!nl_table) { -enomem: - printk(KERN_CRIT "netlink_init: Cannot allocate nl_table\n"); - return -ENOMEM; - } + BUILD_BUG_ON(sizeof(struct netlink_skb_parms) > sizeof(dummy_skb->cb)); - memset(nl_table, 0, sizeof(*nl_table) * MAX_LINKS); + nl_table = kcalloc(MAX_LINKS, sizeof(*nl_table), 
GFP_KERNEL); + if (!nl_table) + goto panic; if (num_physpages >= (128 * 1024)) max = num_physpages >> (21 - PAGE_SHIFT); @@ -1732,7 +1798,7 @@ enomem: nl_pid_hash_free(nl_table[i].hash.table, 1 * sizeof(*hash->table)); kfree(nl_table); - goto enomem; + goto panic; } memset(hash->table, 0, 1 * sizeof(*hash->table)); hash->max_shift = order; @@ -1749,6 +1815,8 @@ enomem: rtnetlink_init(); out: return err; +panic: + panic("netlink_init: Cannot allocate nl_table\n"); } core_initcall(netlink_proto_init); @@ -1764,4 +1832,4 @@ EXPORT_SYMBOL(netlink_set_err); EXPORT_SYMBOL(netlink_set_nonroot); EXPORT_SYMBOL(netlink_unicast); EXPORT_SYMBOL(netlink_unregister_notifier); - +EXPORT_SYMBOL(nlmsg_notify);
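
Usage sketch (illustrative, not part of the patch): the netlink_has_listeners() and nlmsg_notify() helpers added above are meant to be used together by in-kernel event producers. The per-protocol listeners bitmap, kept current under the grabbed netlink table by netlink_update_listeners(), lets a producer test for subscribers cheaply and skip building a message when nobody is bound to the multicast group. In the sketch below, my_notify_event(), my_nlsk (a kernel-side socket created with netlink_kernel_create()) and MY_NLGRP_EVENTS are hypothetical names standing in for a real subsystem's function, socket and group; they do not come from this patch.

#include <linux/skbuff.h>
#include <net/sock.h>
#include <net/netlink.h>

#define MY_NLGRP_EVENTS	1	/* hypothetical multicast group number */

/*
 * Hedged example of the intended usage pattern of the APIs added above;
 * not taken from the patch itself.
 */
static int my_notify_event(struct sock *my_nlsk, u32 pid, gfp_t flags)
{
	struct sk_buff *skb;

	/*
	 * Cheap test against the per-protocol listeners bitmap; only valid
	 * on a kernel socket (netlink_has_listeners() BUG()s otherwise).
	 * Skip the allocation entirely when nobody subscribed to the group
	 * and no explicit report to a requester was asked for.
	 */
	if (!pid && !netlink_has_listeners(my_nlsk, MY_NLGRP_EVENTS))
		return 0;

	skb = nlmsg_new(NLMSG_GOODSIZE, flags);
	if (!skb)
		return -ENOBUFS;

	/* ... build the notification here, e.g. with nlmsg_put() ... */

	/*
	 * Multicast to the group and, when pid != 0, also unicast a report
	 * back to the requester; multicast delivery errors are reported via
	 * the receivers' sk->sk_err, as noted in nlmsg_notify() above.
	 */
	return nlmsg_notify(my_nlsk, skb, pid, MY_NLGRP_EVENTS, pid != 0, flags);
}

Because nlmsg_notify() excludes the reporting pid from the multicast and sends it a unicast copy instead, the requester sees exactly one copy of the message while all other group members receive the broadcast.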