X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=net%2Fatm%2Fcommon.c;h=ae002220fa99096aa00edc4ce0fd8df3e5c8694b;hb=43bc926fffd92024b46cafaf7350d669ba9ca884;hp=859b727395d7022f770da0356f2d25651115cb9d;hpb=5273a3df6485dc2ad6aa7ddd441b9a21970f003b;p=linux-2.6.git diff --git a/net/atm/common.c b/net/atm/common.c index 859b72739..ae002220f 100644 --- a/net/atm/common.c +++ b/net/atm/common.c @@ -12,7 +12,7 @@ #include /* SOL_SOCKET */ #include /* error codes */ #include -#include /* verify_area */ +#include #include #include /* struct timeval */ #include @@ -39,14 +39,14 @@ #endif struct hlist_head vcc_hash[VCC_HTABLE_SIZE]; -rwlock_t vcc_sklist_lock = RW_LOCK_UNLOCKED; +DEFINE_RWLOCK(vcc_sklist_lock); -void __vcc_insert_socket(struct sock *sk) +static void __vcc_insert_socket(struct sock *sk) { struct atm_vcc *vcc = atm_sk(sk); struct hlist_head *head = &vcc_hash[vcc->vci & (VCC_HTABLE_SIZE - 1)]; - sk->sk_hashent = vcc->vci & (VCC_HTABLE_SIZE - 1); + sk->sk_hash = vcc->vci & (VCC_HTABLE_SIZE - 1); sk_add_node(sk, head); } @@ -57,7 +57,7 @@ void vcc_insert_socket(struct sock *sk) write_unlock_irq(&vcc_sklist_lock); } -void vcc_remove_socket(struct sock *sk) +static void vcc_remove_socket(struct sock *sk) { write_lock_irq(&vcc_sklist_lock); sk_del_node_init(sk); @@ -68,17 +68,18 @@ void vcc_remove_socket(struct sock *sk) static struct sk_buff *alloc_tx(struct atm_vcc *vcc,unsigned int size) { struct sk_buff *skb; + struct sock *sk = sk_atm(vcc); - if (atomic_read(&vcc->sk->sk_wmem_alloc) && !atm_may_send(vcc, size)) { + if (atomic_read(&sk->sk_wmem_alloc) && !atm_may_send(vcc, size)) { DPRINTK("Sorry: wmem_alloc = %d, size = %d, sndbuf = %d\n", - atomic_read(&vcc->sk->sk_wmem_alloc), size, - vcc->sk->sk_sndbuf); + atomic_read(&sk->sk_wmem_alloc), size, + sk->sk_sndbuf); return NULL; } while (!(skb = alloc_skb(size,GFP_KERNEL))) schedule(); - DPRINTK("AlTx %d += %d\n", atomic_read(&vcc->sk->sk_wmem_alloc), + DPRINTK("AlTx %d += %d\n", atomic_read(&sk->sk_wmem_alloc), skb->truesize); - atomic_add(skb->truesize, &vcc->sk->sk_wmem_alloc); + atomic_add(skb->truesize, &sk->sk_wmem_alloc); return skb; } @@ -86,19 +87,14 @@ static struct sk_buff *alloc_tx(struct atm_vcc *vcc,unsigned int size) EXPORT_SYMBOL(vcc_hash); EXPORT_SYMBOL(vcc_sklist_lock); EXPORT_SYMBOL(vcc_insert_socket); -EXPORT_SYMBOL(vcc_remove_socket); static void vcc_sock_destruct(struct sock *sk) { - struct atm_vcc *vcc = atm_sk(sk); - - if (atomic_read(&vcc->sk->sk_rmem_alloc)) + if (atomic_read(&sk->sk_rmem_alloc)) printk(KERN_DEBUG "vcc_sock_destruct: rmem leakage (%d bytes) detected.\n", atomic_read(&sk->sk_rmem_alloc)); - if (atomic_read(&vcc->sk->sk_wmem_alloc)) + if (atomic_read(&sk->sk_wmem_alloc)) printk(KERN_DEBUG "vcc_sock_destruct: wmem leakage (%d bytes) detected.\n", atomic_read(&sk->sk_wmem_alloc)); - - kfree(sk->sk_protinfo); } static void vcc_def_wakeup(struct sock *sk) @@ -131,6 +127,11 @@ static void vcc_write_space(struct sock *sk) read_unlock(&sk->sk_callback_lock); } +static struct proto vcc_proto = { + .name = "VCC", + .owner = THIS_MODULE, + .obj_size = sizeof(struct atm_vcc), +}; int vcc_create(struct socket *sock, int protocol, int family) { @@ -140,35 +141,26 @@ int vcc_create(struct socket *sock, int protocol, int family) sock->sk = NULL; if (sock->type == SOCK_STREAM) return -EINVAL; - sk = sk_alloc(family, GFP_KERNEL, 1, NULL); + sk = sk_alloc(family, GFP_KERNEL, &vcc_proto, 1); if (!sk) return -ENOMEM; sock_init_data(sock, sk); - sk_set_owner(sk, THIS_MODULE); sk->sk_state_change = vcc_def_wakeup; sk->sk_write_space = vcc_write_space; - vcc = sk->sk_protinfo = kmalloc(sizeof(*vcc), GFP_KERNEL); - if (!vcc) { - sk_free(sk); - return -ENOMEM; - } - - memset(vcc, 0, sizeof(*vcc)); - vcc->sk = sk; + vcc = atm_sk(sk); vcc->dev = NULL; memset(&vcc->local,0,sizeof(struct sockaddr_atmsvc)); memset(&vcc->remote,0,sizeof(struct sockaddr_atmsvc)); vcc->qos.txtp.max_sdu = 1 << 16; /* for meta VCs */ - atomic_set(&vcc->sk->sk_wmem_alloc, 0); - atomic_set(&vcc->sk->sk_rmem_alloc, 0); + atomic_set(&sk->sk_wmem_alloc, 0); + atomic_set(&sk->sk_rmem_alloc, 0); vcc->push = NULL; vcc->pop = NULL; vcc->push_oam = NULL; vcc->vpi = vcc->vci = 0; /* no VCI/VPI yet */ vcc->atm_options = vcc->aal_options = 0; sk->sk_destruct = vcc_sock_destruct; - sock->sk = sk; return 0; } @@ -178,6 +170,7 @@ static void vcc_destroy_socket(struct sock *sk) struct atm_vcc *vcc = atm_sk(sk); struct sk_buff *skb; + set_bit(ATM_VF_CLOSE, &vcc->flags); clear_bit(ATM_VF_READY, &vcc->flags); if (vcc->dev) { if (vcc->dev->ops->close) @@ -185,9 +178,7 @@ static void vcc_destroy_socket(struct sock *sk) if (vcc->push) vcc->push(vcc, NULL); /* atmarpd has no push */ - vcc_remove_socket(sk); /* no more receive */ - - while ((skb = skb_dequeue(&vcc->sk->sk_receive_queue))) { + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { atm_return(vcc,skb->truesize); kfree_skb(skb); } @@ -195,6 +186,8 @@ static void vcc_destroy_socket(struct sock *sk) module_put(vcc->dev->ops->owner); atm_dev_put(vcc->dev); } + + vcc_remove_socket(sk); } @@ -215,16 +208,42 @@ int vcc_release(struct socket *sock) void vcc_release_async(struct atm_vcc *vcc, int reply) { + struct sock *sk = sk_atm(vcc); + set_bit(ATM_VF_CLOSE, &vcc->flags); - vcc->sk->sk_err = -reply; + sk->sk_shutdown |= RCV_SHUTDOWN; + sk->sk_err = -reply; clear_bit(ATM_VF_WAITING, &vcc->flags); - vcc->sk->sk_state_change(vcc->sk); + sk->sk_state_change(sk); } EXPORT_SYMBOL(vcc_release_async); +void atm_dev_release_vccs(struct atm_dev *dev) +{ + int i; + + write_lock_irq(&vcc_sklist_lock); + for (i = 0; i < VCC_HTABLE_SIZE; i++) { + struct hlist_head *head = &vcc_hash[i]; + struct hlist_node *node, *tmp; + struct sock *s; + struct atm_vcc *vcc; + + sk_for_each_safe(s, node, tmp, head) { + vcc = atm_sk(s); + if (vcc->dev == dev) { + vcc_release_async(vcc, -EPIPE); + sk_del_node_init(s); + } + } + } + write_unlock_irq(&vcc_sklist_lock); +} + + static int adjust_tp(struct atm_trafprm *tp,unsigned char aal) { int max_sdu; @@ -327,6 +346,7 @@ static int find_ci(struct atm_vcc *vcc, short *vpi, int *vci) static int __vcc_connect(struct atm_vcc *vcc, struct atm_dev *dev, short vpi, int vci) { + struct sock *sk = sk_atm(vcc); int error; if ((vpi != ATM_VPI_UNSPEC && vpi != ATM_VPI_ANY && @@ -335,18 +355,19 @@ static int __vcc_connect(struct atm_vcc *vcc, struct atm_dev *dev, short vpi, return -EINVAL; if (vci > 0 && vci < ATM_NOT_RSV_VCI && !capable(CAP_NET_BIND_SERVICE)) return -EPERM; - error = 0; + error = -ENODEV; if (!try_module_get(dev->ops->owner)) - return -ENODEV; + return error; vcc->dev = dev; write_lock_irq(&vcc_sklist_lock); - if ((error = find_ci(vcc, &vpi, &vci))) { + if (test_bit(ATM_DF_REMOVED, &dev->flags) || + (error = find_ci(vcc, &vpi, &vci))) { write_unlock_irq(&vcc_sklist_lock); goto fail_module_put; } vcc->vpi = vpi; vcc->vci = vci; - __vcc_insert_socket(vcc->sk); + __vcc_insert_socket(sk); write_unlock_irq(&vcc_sklist_lock); switch (vcc->qos.aal) { case ATM_AAL0: @@ -385,7 +406,7 @@ static int __vcc_connect(struct atm_vcc *vcc, struct atm_dev *dev, short vpi, return 0; fail: - vcc_remove_socket(vcc->sk); + vcc_remove_socket(sk); fail_module_put: module_put(dev->ops->owner); /* ensure we get dev module ref count correct */ @@ -426,33 +447,23 @@ int vcc_connect(struct socket *sock, int itf, short vpi, int vci) if (vcc->qos.txtp.traffic_class == ATM_ANYCLASS || vcc->qos.rxtp.traffic_class == ATM_ANYCLASS) return -EINVAL; - if (itf != ATM_ITF_ANY) { - dev = atm_dev_lookup(itf); - if (!dev) - return -ENODEV; - error = __vcc_connect(vcc, dev, vpi, vci); - if (error) { - atm_dev_put(dev); - return error; - } + if (likely(itf != ATM_ITF_ANY)) { + dev = try_then_request_module(atm_dev_lookup(itf), "atm-device-%d", itf); } else { - struct list_head *p, *next; - dev = NULL; - spin_lock(&atm_dev_lock); - list_for_each_safe(p, next, &atm_devs) { - dev = list_entry(p, struct atm_dev, dev_list); + mutex_lock(&atm_dev_mutex); + if (!list_empty(&atm_devs)) { + dev = list_entry(atm_devs.next, struct atm_dev, dev_list); atm_dev_hold(dev); - spin_unlock(&atm_dev_lock); - if (!__vcc_connect(vcc, dev, vpi, vci)) - break; - atm_dev_put(dev); - dev = NULL; - spin_lock(&atm_dev_lock); } - spin_unlock(&atm_dev_lock); - if (!dev) - return -ENODEV; + mutex_unlock(&atm_dev_mutex); + } + if (!dev) + return -ENODEV; + error = __vcc_connect(vcc, dev, vpi, vci); + if (error) { + atm_dev_put(dev); + return error; } if (vpi == ATM_VPI_UNSPEC || vci == ATM_VCI_UNSPEC) set_bit(ATM_VF_PARTIAL,&vcc->flags); @@ -494,7 +505,7 @@ int vcc_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, if (error) return error; sock_recv_timestamp(msg, sk, skb); - DPRINTK("RcvM %d -= %d\n", atomic_read(&vcc->sk->rmem_alloc), skb->truesize); + DPRINTK("RcvM %d -= %d\n", atomic_read(&sk->rmem_alloc), skb->truesize); atm_return(vcc, skb->truesize); skb_free_datagram(sk, skb); return copied; @@ -509,7 +520,7 @@ int vcc_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *m, struct atm_vcc *vcc; struct sk_buff *skb; int eff,error; - const void *buff; + const void __user *buff; int size; lock_sock(sk); @@ -543,7 +554,7 @@ int vcc_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *m, error = -EMSGSIZE; goto out; } - /* verify_area is done by net/socket.c */ + eff = (size+3) & ~3; /* align to word boundary */ prepare_to_wait(sk->sk_sleep, &wait, TASK_INTERRUPTIBLE); error = 0; @@ -614,7 +625,7 @@ unsigned int vcc_poll(struct file *file, struct socket *sock, poll_table *wait) return mask; if (vcc->qos.txtp.traffic_class != ATM_NONE && - vcc_writable(vcc->sk)) + vcc_writable(sk)) mask |= POLLOUT | POLLWRNORM | POLLWRBAND; return mask; @@ -637,7 +648,7 @@ static int atm_change_qos(struct atm_vcc *vcc,struct atm_qos *qos) if (!error) error = adjust_tp(&qos->rxtp,qos->aal); if (error) return error; if (!vcc->dev->ops->change_qos) return -EOPNOTSUPP; - if (vcc->sk->sk_family == AF_ATMPVC) + if (sk_atm(vcc)->sk_family == AF_ATMPVC) return vcc->dev->ops->change_qos(vcc,qos,ATM_MF_SET); return svc_change_qos(vcc,qos); } @@ -676,7 +687,7 @@ static int check_qos(struct atm_qos *qos) } int vcc_setsockopt(struct socket *sock, int level, int optname, - char *optval, int optlen) + char __user *optval, int optlen) { struct atm_vcc *vcc; unsigned long value; @@ -704,7 +715,7 @@ int vcc_setsockopt(struct socket *sock, int level, int optname, return 0; } case SO_SETCLP: - if (get_user(value,(unsigned long *) optval)) + if (get_user(value,(unsigned long __user *)optval)) return -EFAULT; if (value) vcc->atm_options |= ATM_ATMOPT_CLP; else vcc->atm_options &= ~ATM_ATMOPT_CLP; @@ -719,7 +730,7 @@ int vcc_setsockopt(struct socket *sock, int level, int optname, int vcc_getsockopt(struct socket *sock, int level, int optname, - char *optval, int *optlen) + char __user *optval, int __user *optlen) { struct atm_vcc *vcc; int len; @@ -738,7 +749,7 @@ int vcc_getsockopt(struct socket *sock, int level, int optname, -EFAULT : 0; case SO_SETCLP: return put_user(vcc->atm_options & ATM_ATMOPT_CLP ? 1 : - 0,(unsigned long *) optval) ? -EFAULT : 0; + 0,(unsigned long __user *)optval) ? -EFAULT : 0; case SO_ATMPVC: { struct sockaddr_atmpvc pvc; @@ -761,43 +772,34 @@ int vcc_getsockopt(struct socket *sock, int level, int optname, return vcc->dev->ops->getsockopt(vcc, level, optname, optval, len); } - -#if defined(CONFIG_ATM_LANE) || defined(CONFIG_ATM_LANE_MODULE) -#if defined(CONFIG_BRIDGE) || defined(CONFIG_BRIDGE_MODULE) -struct net_bridge; -struct net_bridge_fdb_entry *(*br_fdb_get_hook)(struct net_bridge *br, - unsigned char *addr) = NULL; -void (*br_fdb_put_hook)(struct net_bridge_fdb_entry *ent) = NULL; -#if defined(CONFIG_ATM_LANE_MODULE) || defined(CONFIG_BRIDGE_MODULE) -EXPORT_SYMBOL(br_fdb_get_hook); -EXPORT_SYMBOL(br_fdb_put_hook); -#endif /* defined(CONFIG_ATM_LANE_MODULE) || defined(CONFIG_BRIDGE_MODULE) */ -#endif /* defined(CONFIG_BRIDGE) || defined(CONFIG_BRIDGE_MODULE) */ -#endif /* defined(CONFIG_ATM_LANE) || defined(CONFIG_ATM_LANE_MODULE) */ - - static int __init atm_init(void) { int error; + if ((error = proto_register(&vcc_proto, 0)) < 0) + goto out; + if ((error = atmpvc_init()) < 0) { printk(KERN_ERR "atmpvc_init() failed with %d\n", error); - goto failure; + goto out_unregister_vcc_proto; } if ((error = atmsvc_init()) < 0) { printk(KERN_ERR "atmsvc_init() failed with %d\n", error); - goto failure; + goto out_atmpvc_exit; } if ((error = atm_proc_init()) < 0) { printk(KERN_ERR "atm_proc_init() failed with %d\n",error); - goto failure; + goto out_atmsvc_exit; } - return 0; - -failure: - atmsvc_exit(); - atmpvc_exit(); +out: return error; +out_atmsvc_exit: + atmsvc_exit(); +out_atmpvc_exit: + atmsvc_exit(); +out_unregister_vcc_proto: + proto_unregister(&vcc_proto); + goto out; } static void __exit atm_exit(void) @@ -805,6 +807,7 @@ static void __exit atm_exit(void) atm_proc_exit(); atmsvc_exit(); atmpvc_exit(); + proto_unregister(&vcc_proto); } module_init(atm_init);