Fixed blocking reader on closed socket

This commit is contained in:
Mathias Hall-Andersen 2017-11-11 23:26:44 +01:00
parent 892276aa64
commit 566269275e
4 changed files with 32 additions and 23 deletions

View file

@ -37,15 +37,14 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
/* Must hold device and net lock /* Must hold device and net lock
*/ */
func unsafeCloseUDPListener(device *Device) error { func unsafeCloseUDPListener(device *Device) error {
var err error
netc := &device.net netc := &device.net
if netc.bind != nil { if netc.bind != nil {
if err := netc.bind.Close(); err != nil { err = netc.bind.Close()
return err
}
netc.bind = nil netc.bind = nil
netc.update.Broadcast() netc.update.Add(1)
} }
return nil return err
} }
// must inform all listeners // must inform all listeners
@ -63,7 +62,7 @@ func UpdateUDPListener(device *Device) error {
return err return err
} }
// wait for reader // assumption: netc.update WaitGroup should be exactly 1
// open new sockets // open new sockets
@ -93,9 +92,10 @@ func UpdateUDPListener(device *Device) error {
peer.mutex.Unlock() peer.mutex.Unlock()
} }
// inform readers of updated bind // decrease waitgroup to 0
netc.update.Broadcast() device.log.Debug.Println("UDP bind has been updated")
netc.update.Done()
} }
return nil return nil

View file

@ -84,9 +84,15 @@ func (bind NativeBind) SetMark(value uint32) error {
) )
} }
func closeUnblock(fd int) error {
// shutdown to unblock readers
unix.Shutdown(fd, unix.SHUT_RD)
return unix.Close(fd)
}
func (bind NativeBind) Close() error { func (bind NativeBind) Close() error {
err1 := unix.Close(bind.sock6) err1 := closeUnblock(bind.sock6)
err2 := unix.Close(bind.sock4) err2 := closeUnblock(bind.sock4)
if err1 != nil { if err1 != nil {
return err1 return err1
} }
@ -125,13 +131,13 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
switch addr.Family { switch addr.Family {
case unix.AF_INET6: case unix.AF_INET6:
udpAddr.Port = int(addr.Port) udpAddr.Port = int(ntohs(addr.Port))
udpAddr.IP = addr.Addr[:] udpAddr.IP = addr.Addr[:]
return udpAddr.String() return udpAddr.String()
case unix.AF_INET: case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
udpAddr.Port = int(ptr.Port) udpAddr.Port = int(ntohs(ptr.Port))
udpAddr.IP = net.IPv4( udpAddr.IP = net.IPv4(
ptr.Addr[0], ptr.Addr[0],
ptr.Addr[1], ptr.Addr[1],

View file

@ -23,10 +23,10 @@ type Device struct {
} }
net struct { net struct {
mutex sync.RWMutex mutex sync.RWMutex
bind UDPBind // bind interface bind UDPBind // bind interface
port uint16 // listening port port uint16 // listening port
fwmark uint32 // mark value (0 = disabled) fwmark uint32 // mark value (0 = disabled)
update *sync.Cond // the bind was updated update sync.WaitGroup // the bind was updated (acting as a barrier)
} }
mutex sync.RWMutex mutex sync.RWMutex
privateKey NoisePrivateKey privateKey NoisePrivateKey
@ -167,7 +167,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.net.port = 0 device.net.port = 0
device.net.bind = nil device.net.bind = nil
device.net.update = sync.NewCond(&device.net.mutex) device.net.update.Add(1)
// start workers // start workers
@ -209,9 +209,11 @@ func (device *Device) RemoveAllPeers() {
} }
func (device *Device) Close() { func (device *Device) Close() {
device.log.Info.Println("Closing device")
device.RemoveAllPeers() device.RemoveAllPeers()
close(device.signal.stop) close(device.signal.stop)
CloseUDPListener(device) CloseUDPListener(device)
device.tun.device.Close()
} }
func (device *Device) WaitChannel() chan struct{} { func (device *Device) WaitChannel() chan struct{} {

View file

@ -95,23 +95,22 @@ func (device *Device) addToHandshakeQueue(
func (device *Device) RoutineReceiveIncomming(IPVersion int) { func (device *Device) RoutineReceiveIncomming(IPVersion int) {
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started") logDebug.Println("Routine, receive incomming, IP version:", IPVersion)
for { for {
// wait for bind // wait for bind
logDebug.Println("Waiting for udp bind") logDebug.Println("Waiting for UDP socket, IP version:", IPVersion)
device.net.mutex.Lock()
device.net.update.Wait() device.net.update.Wait()
device.net.mutex.RLock()
bind := device.net.bind bind := device.net.bind
device.net.mutex.Unlock() device.net.mutex.RUnlock()
if bind == nil { if bind == nil {
continue continue
} }
logDebug.Println("LISTEN\n\n\n")
// receive datagrams until conn is closed // receive datagrams until conn is closed
buffer := device.GetMessageBuffer() buffer := device.GetMessageBuffer()
@ -427,6 +426,8 @@ func (device *Device) RoutineHandshake() {
err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err == nil { if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal() peer.TimerAnyAuthenticatedPacketTraversal()
} else {
logError.Println("Failed to send response to:", peer.String(), err)
} }
case MessageResponseType: case MessageResponseType: