|
@@ -1599,12 +1599,21 @@ static void flush_stack(struct sock **stack, unsigned int count,
|
|
|
kfree_skb(skb1);
|
|
|
}
|
|
|
|
|
|
-static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
|
|
|
+/* For TCP sockets, sk_rx_dst is protected by socket lock
|
|
|
+ * For UDP, we use sk_dst_lock to guard against concurrent changes.
|
|
|
+ */
|
|
|
+static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
|
|
|
{
|
|
|
- struct dst_entry *dst = skb_dst(skb);
|
|
|
+ struct dst_entry *old;
|
|
|
|
|
|
- dst_hold(dst);
|
|
|
- sk->sk_rx_dst = dst;
|
|
|
+ spin_lock(&sk->sk_dst_lock);
|
|
|
+ old = sk->sk_rx_dst;
|
|
|
+ if (likely(old != dst)) {
|
|
|
+ dst_hold(dst);
|
|
|
+ sk->sk_rx_dst = dst;
|
|
|
+ dst_release(old);
|
|
|
+ }
|
|
|
+ spin_unlock(&sk->sk_dst_lock);
|
|
|
}
|
|
|
|
|
|
/*
|
|
@@ -1737,10 +1746,11 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
|
|
|
|
|
|
sk = skb_steal_sock(skb);
|
|
|
if (sk) {
|
|
|
+ struct dst_entry *dst = skb_dst(skb);
|
|
|
int ret;
|
|
|
|
|
|
- if (unlikely(sk->sk_rx_dst == NULL))
|
|
|
- udp_sk_rx_dst_set(sk, skb);
|
|
|
+ if (unlikely(sk->sk_rx_dst != dst))
|
|
|
+ udp_sk_rx_dst_set(sk, dst);
|
|
|
|
|
|
ret = udp_queue_rcv_skb(sk, skb);
|
|
|
sock_put(sk);
|