diff --git a/include/net/raw.h b/include/net/raw.h index 719d3556fc0a..d81eeeb8f1e6 100644 --- a/include/net/raw.h +++ b/include/net/raw.h @@ -33,9 +33,18 @@ int raw_rcv(struct sock *, struct sk_buff *); struct raw_hashinfo { rwlock_t lock; - struct hlist_head ht[RAW_HTABLE_SIZE]; + struct hlist_nulls_head ht[RAW_HTABLE_SIZE]; }; +static inline void raw_hashinfo_init(struct raw_hashinfo *hashinfo) +{ + int i; + + rwlock_init(&hashinfo->lock); + for (i = 0; i < RAW_HTABLE_SIZE; i++) + INIT_HLIST_NULLS_HEAD(&hashinfo->ht[i], i); +} + #ifdef CONFIG_PROC_FS int raw_proc_init(void); void raw_proc_exit(void); diff --git a/include/net/rawv6.h b/include/net/rawv6.h index c48c1298699a..bc70909625f6 100644 --- a/include/net/rawv6.h +++ b/include/net/rawv6.h @@ -3,6 +3,7 @@ #define _NET_RAWV6_H #include +#include extern struct raw_hashinfo raw_v6_hashinfo; bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num, diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 30e0e8992085..da81f56fdd1c 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -1929,6 +1929,8 @@ static int __init inet_init(void) sock_skb_cb_check_size(sizeof(struct inet_skb_parm)); + raw_hashinfo_init(&raw_v4_hashinfo); + rc = proto_register(&tcp_prot, 1); if (rc) goto out; diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c index 05e0de4a7c7f..d28bf0b901a2 100644 --- a/net/ipv4/raw.c +++ b/net/ipv4/raw.c @@ -85,20 +85,19 @@ struct raw_frag_vec { int hlen; }; -struct raw_hashinfo raw_v4_hashinfo = { - .lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock), -}; +struct raw_hashinfo raw_v4_hashinfo; EXPORT_SYMBOL_GPL(raw_v4_hashinfo); int raw_hash_sk(struct sock *sk) { struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; - struct hlist_head *head; + struct hlist_nulls_head *hlist; - head = &h->ht[inet_sk(sk)->inet_num & (RAW_HTABLE_SIZE - 1)]; + hlist = &h->ht[inet_sk(sk)->inet_num & (RAW_HTABLE_SIZE - 1)]; write_lock_bh(&h->lock); - sk_add_node(sk, head); + hlist_nulls_add_head_rcu(&sk->sk_nulls_node, hlist); + sock_set_flag(sk, SOCK_RCU_FREE); write_unlock_bh(&h->lock); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); @@ -111,7 +110,7 @@ void raw_unhash_sk(struct sock *sk) struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; write_lock_bh(&h->lock); - if (sk_del_node_init(sk)) + if (__sk_nulls_del_node_init_rcu(sk)) sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); write_unlock_bh(&h->lock); } @@ -164,17 +163,16 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb) static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) { struct net *net = dev_net(skb->dev); + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; int sdif = inet_sdif(skb); int dif = inet_iif(skb); - struct hlist_head *head; int delivered = 0; struct sock *sk; - head = &raw_v4_hashinfo.ht[hash]; - if (hlist_empty(head)) - return 0; - read_lock(&raw_v4_hashinfo.lock); - sk_for_each(sk, head) { + hlist = &raw_v4_hashinfo.ht[hash]; + rcu_read_lock(); + hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) { if (!raw_v4_match(net, sk, iph->protocol, iph->saddr, iph->daddr, dif, sdif)) continue; @@ -189,7 +187,7 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) raw_rcv(sk, clone); } } - read_unlock(&raw_v4_hashinfo.lock); + rcu_read_unlock(); return delivered; } @@ -265,25 +263,26 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info) { struct net *net = dev_net(skb->dev);; + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; int dif = skb->dev->ifindex; int sdif = inet_sdif(skb); - struct hlist_head *head; const struct iphdr *iph; struct sock *sk; int hash; hash = protocol & (RAW_HTABLE_SIZE - 1); - head = &raw_v4_hashinfo.ht[hash]; + hlist = &raw_v4_hashinfo.ht[hash]; - read_lock(&raw_v4_hashinfo.lock); - sk_for_each(sk, head) { + rcu_read_lock(); + hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) { iph = (const struct iphdr *)skb->data; if (!raw_v4_match(net, sk, iph->protocol, iph->saddr, iph->daddr, dif, sdif)) continue; raw_err(sk, skb, info); } - read_unlock(&raw_v4_hashinfo.lock); + rcu_read_unlock(); } static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb) @@ -944,44 +943,41 @@ struct proto raw_prot = { }; #ifdef CONFIG_PROC_FS -static struct sock *raw_get_first(struct seq_file *seq) +static struct sock *raw_get_first(struct seq_file *seq, int bucket) { - struct sock *sk; struct raw_hashinfo *h = pde_data(file_inode(seq->file)); struct raw_iter_state *state = raw_seq_private(seq); + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; + struct sock *sk; - for (state->bucket = 0; state->bucket < RAW_HTABLE_SIZE; + for (state->bucket = bucket; state->bucket < RAW_HTABLE_SIZE; ++state->bucket) { - sk_for_each(sk, &h->ht[state->bucket]) + hlist = &h->ht[state->bucket]; + hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) { if (sock_net(sk) == seq_file_net(seq)) - goto found; + return sk; + } } - sk = NULL; -found: - return sk; + return NULL; } static struct sock *raw_get_next(struct seq_file *seq, struct sock *sk) { - struct raw_hashinfo *h = pde_data(file_inode(seq->file)); struct raw_iter_state *state = raw_seq_private(seq); do { - sk = sk_next(sk); -try_again: - ; + sk = sk_nulls_next(sk); } while (sk && sock_net(sk) != seq_file_net(seq)); - if (!sk && ++state->bucket < RAW_HTABLE_SIZE) { - sk = sk_head(&h->ht[state->bucket]); - goto try_again; - } + if (!sk) + return raw_get_first(seq, state->bucket + 1); return sk; } static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos) { - struct sock *sk = raw_get_first(seq); + struct sock *sk = raw_get_first(seq, 0); if (sk) while (pos && (sk = raw_get_next(seq, sk)) != NULL) @@ -990,11 +986,9 @@ static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos) } void *raw_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(&h->lock) + __acquires(RCU) { - struct raw_hashinfo *h = pde_data(file_inode(seq->file)); - - read_lock(&h->lock); + rcu_read_lock(); return *pos ? raw_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; } EXPORT_SYMBOL_GPL(raw_seq_start); @@ -1004,7 +998,7 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos) struct sock *sk; if (v == SEQ_START_TOKEN) - sk = raw_get_first(seq); + sk = raw_get_first(seq, 0); else sk = raw_get_next(seq, v); ++*pos; @@ -1013,11 +1007,9 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos) EXPORT_SYMBOL_GPL(raw_seq_next); void raw_seq_stop(struct seq_file *seq, void *v) - __releases(&h->lock) + __releases(RCU) { - struct raw_hashinfo *h = pde_data(file_inode(seq->file)); - - read_unlock(&h->lock); + rcu_read_unlock(); } EXPORT_SYMBOL_GPL(raw_seq_stop); @@ -1079,6 +1071,7 @@ static __net_initdata struct pernet_operations raw_net_ops = { int __init raw_proc_init(void) { + return register_pernet_subsys(&raw_net_ops); } diff --git a/net/ipv4/raw_diag.c b/net/ipv4/raw_diag.c index b6d92dc7b051..5f208e840d85 100644 --- a/net/ipv4/raw_diag.c +++ b/net/ipv4/raw_diag.c @@ -57,31 +57,32 @@ static bool raw_lookup(struct net *net, struct sock *sk, static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r) { struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; struct sock *sk; int slot; if (IS_ERR(hashinfo)) return ERR_CAST(hashinfo); - read_lock(&hashinfo->lock); + rcu_read_lock(); for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) { - sk_for_each(sk, &hashinfo->ht[slot]) { + hlist = &hashinfo->ht[slot]; + hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) { if (raw_lookup(net, sk, r)) { /* * Grab it and keep until we fill - * diag meaage to be reported, so + * diag message to be reported, so * caller should call sock_put then. - * We can do that because we're keeping - * hashinfo->lock here. */ - sock_hold(sk); - goto out_unlock; + if (refcount_inc_not_zero(&sk->sk_refcnt)) + goto out_unlock; } } } sk = ERR_PTR(-ENOENT); out_unlock: - read_unlock(&hashinfo->lock); + rcu_read_unlock(); return sk; } @@ -141,6 +142,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); struct net *net = sock_net(skb->sk); struct inet_diag_dump_data *cb_data; + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; int num, s_num, slot, s_slot; struct sock *sk = NULL; struct nlattr *bc; @@ -157,7 +160,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) { num = 0; - sk_for_each(sk, &hashinfo->ht[slot]) { + hlist = &hashinfo->ht[slot]; + hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) { struct inet_sock *inet = inet_sk(sk); if (!net_eq(sock_net(sk), net)) diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index 70564ddccc46..658823e91eca 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -63,6 +63,7 @@ #include #include #include +#include #include #include @@ -1073,6 +1074,8 @@ static int __init inet6_init(void) goto out; } + raw_hashinfo_init(&raw_v6_hashinfo); + err = proto_register(&tcpv6_prot, 1); if (err) goto out; diff --git a/net/ipv6/raw.c b/net/ipv6/raw.c index c0f2e3475984..f6119998700e 100644 --- a/net/ipv6/raw.c +++ b/net/ipv6/raw.c @@ -61,9 +61,7 @@ #define ICMPV6_HDRLEN 4 /* ICMPv6 header, RFC 4443 Section 2.1 */ -struct raw_hashinfo raw_v6_hashinfo = { - .lock = __RW_LOCK_UNLOCKED(raw_v6_hashinfo.lock), -}; +struct raw_hashinfo raw_v6_hashinfo; EXPORT_SYMBOL_GPL(raw_v6_hashinfo); bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num, @@ -143,9 +141,10 @@ EXPORT_SYMBOL(rawv6_mh_filter_unregister); static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) { struct net *net = dev_net(skb->dev); + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; const struct in6_addr *saddr; const struct in6_addr *daddr; - struct hlist_head *head; struct sock *sk; bool delivered = false; __u8 hash; @@ -154,11 +153,9 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) daddr = saddr + 1; hash = nexthdr & (RAW_HTABLE_SIZE - 1); - head = &raw_v6_hashinfo.ht[hash]; - if (hlist_empty(head)) - return false; - read_lock(&raw_v6_hashinfo.lock); - sk_for_each(sk, head) { + hlist = &raw_v6_hashinfo.ht[hash]; + rcu_read_lock(); + hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) { int filtered; if (!raw_v6_match(net, sk, nexthdr, daddr, saddr, @@ -203,7 +200,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) } } } - read_unlock(&raw_v6_hashinfo.lock); + rcu_read_unlock(); return delivered; } @@ -337,14 +334,15 @@ void raw6_icmp_error(struct sk_buff *skb, int nexthdr, { const struct in6_addr *saddr, *daddr; struct net *net = dev_net(skb->dev); - struct hlist_head *head; + struct hlist_nulls_head *hlist; + struct hlist_nulls_node *hnode; struct sock *sk; int hash; hash = nexthdr & (RAW_HTABLE_SIZE - 1); - head = &raw_v6_hashinfo.ht[hash]; - read_lock(&raw_v6_hashinfo.lock); - sk_for_each(sk, head) { + hlist = &raw_v6_hashinfo.ht[hash]; + rcu_read_lock(); + hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) { /* Note: ipv6_hdr(skb) != skb->data */ const struct ipv6hdr *ip6h = (const struct ipv6hdr *)skb->data; saddr = &ip6h->saddr; @@ -355,7 +353,7 @@ void raw6_icmp_error(struct sk_buff *skb, int nexthdr, continue; rawv6_err(sk, skb, NULL, type, code, inner_offset, info); } - read_unlock(&raw_v6_hashinfo.lock); + rcu_read_unlock(); } static inline int rawv6_rcv_skb(struct sock *sk, struct sk_buff *skb)