From 3b3de758ec898e47aef609fbf16d78e97dac2000 Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld" <Jason@zx2c4.com>
Date: Thu, 7 Jan 2021 17:00:21 +0100
Subject: [PATCH] conn: linux: do not allow ReceiveIPvX to race with Close

If Close is called after ReceiveIPvX, then ReceiveIPvX will block on an
invalid or potentially reused fd.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
---
 conn/conn_linux.go | 49 ++++++++++++++++++++++++++++++----------------
 1 file changed, 32 insertions(+), 17 deletions(-)

diff --git a/conn/conn_linux.go b/conn/conn_linux.go
index ef98100..ef5c0ba 100644
--- a/conn/conn_linux.go
+++ b/conn/conn_linux.go
@@ -18,10 +18,6 @@ import (
 	"golang.org/x/sys/unix"
 )
 
-const (
-	FD_ERR = -1
-)
-
 type IPv4Source struct {
 	Src     [4]byte
 	Ifindex int32
@@ -63,6 +59,7 @@ type nativeBind struct {
 	sock4    int
 	sock6    int
 	lastMark uint32
+	closing  sync.RWMutex
 }
 
 var _ Endpoint = (*NativeEndpoint)(nil)
@@ -129,7 +126,7 @@ func createBind(port uint16) (Bind, uint16, error) {
 		port = newPort
 	}
 
-	if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
+	if bind.sock4 == -1 && bind.sock6 == -1 {
 		return nil, 0, errors.New("ipv4 and ipv6 not supported")
 	}
 
@@ -141,6 +138,9 @@ func (bind *nativeBind) LastMark() uint32 {
 }
 
 func (bind *nativeBind) SetMark(value uint32) error {
+	bind.closing.RLock()
+	defer bind.closing.RUnlock()
+
 	if bind.sock6 != -1 {
 		err := unix.SetsockoptInt(
 			bind.sock6,
@@ -171,20 +171,26 @@ func (bind *nativeBind) SetMark(value uint32) error {
 	return nil
 }
 
-func closeUnblock(fd int) error {
-	// shutdown to unblock readers and writers
-	unix.Shutdown(fd, unix.SHUT_RDWR)
-	return unix.Close(fd)
-}
-
 func (bind *nativeBind) Close() error {
 	var err1, err2 error
+	bind.closing.RLock()
 	if bind.sock6 != -1 {
-		err1 = closeUnblock(bind.sock6)
+		unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
 	}
 	if bind.sock4 != -1 {
-		err2 = closeUnblock(bind.sock4)
+		unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
 	}
+	bind.closing.RUnlock()
+	bind.closing.Lock()
+	if bind.sock6 != -1 {
+		err1 = unix.Close(bind.sock6)
+		bind.sock6 = -1
+	}
+	if bind.sock4 != -1 {
+		err2 = unix.Close(bind.sock4)
+		bind.sock4 = -1
+	}
+	bind.closing.Unlock()
 
 	if err1 != nil {
 		return err1
@@ -193,6 +199,9 @@ func (bind *nativeBind) Close() error {
 }
 
 func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+	bind.closing.RLock()
+	defer bind.closing.RUnlock()
+
 	var end NativeEndpoint
 	if bind.sock6 == -1 {
 		return 0, nil, syscall.EAFNOSUPPORT
@@ -206,6 +215,9 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
 }
 
 func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+	bind.closing.RLock()
+	defer bind.closing.RUnlock()
+
 	var end NativeEndpoint
 	if bind.sock4 == -1 {
 		return 0, nil, syscall.EAFNOSUPPORT
@@ -219,6 +231,9 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 }
 
 func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
+	bind.closing.RLock()
+	defer bind.closing.RUnlock()
+
 	nend := end.(*NativeEndpoint)
 	if !nend.isV6 {
 		if bind.sock4 == -1 {
@@ -316,7 +331,7 @@ func create4(port uint16) (int, uint16, error) {
 	)
 
 	if err != nil {
-		return FD_ERR, 0, err
+		return -1, 0, err
 	}
 
 	addr := unix.SockaddrInet4{
@@ -338,7 +353,7 @@ func create4(port uint16) (int, uint16, error) {
 		return unix.Bind(fd, &addr)
 	}(); err != nil {
 		unix.Close(fd)
-		return FD_ERR, 0, err
+		return -1, 0, err
 	}
 
 	sa, err := unix.Getsockname(fd)
@@ -360,7 +375,7 @@ func create6(port uint16) (int, uint16, error) {
 	)
 
 	if err != nil {
-		return FD_ERR, 0, err
+		return -1, 0, err
 	}
 
 	// set sockopts and bind
@@ -392,7 +407,7 @@ func create6(port uint16) (int, uint16, error) {
 
 	}(); err != nil {
 		unix.Close(fd)
-		return FD_ERR, 0, err
+		return -1, 0, err
 	}
 
 	sa, err := unix.Getsockname(fd)