Look up route for every peer
This commit is contained in:
parent
659106bd6d
commit
0fb14232fa
2
conn.go
2
conn.go
|
@ -123,7 +123,7 @@ func (device *Device) BindUpdate() error {
|
|||
|
||||
var err error
|
||||
netc := &device.net
|
||||
netc.bind, netc.port, err = CreateBind(netc.port)
|
||||
netc.bind, netc.port, err = CreateBind(netc.port, device)
|
||||
if err != nil {
|
||||
netc.bind = nil
|
||||
netc.port = 0
|
||||
|
|
|
@ -81,7 +81,7 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
|||
return conn, uaddr.Port, nil
|
||||
}
|
||||
|
||||
func CreateBind(uport uint16) (Bind, uint16, error) {
|
||||
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
|
||||
var err error
|
||||
var bind NativeBind
|
||||
|
||||
|
|
|
@ -58,7 +58,6 @@ type NativeBind struct {
|
|||
sock4 int
|
||||
sock6 int
|
||||
netlinkSock int
|
||||
lastEndpoint *NativeEndpoint
|
||||
lastMark uint32
|
||||
}
|
||||
|
||||
|
@ -118,7 +117,7 @@ func createNetlinkRouteSocket() (int, error) {
|
|||
|
||||
}
|
||||
|
||||
func CreateBind(port uint16) (*NativeBind, uint16, error) {
|
||||
func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
|
||||
var err error
|
||||
var bind NativeBind
|
||||
|
||||
|
@ -127,7 +126,7 @@ func CreateBind(port uint16) (*NativeBind, uint16, error) {
|
|||
return nil, 0, err
|
||||
}
|
||||
|
||||
go bind.routineRouteListener()
|
||||
go bind.routineRouteListener(device)
|
||||
|
||||
bind.sock6, port, err = create6(port)
|
||||
if err != nil {
|
||||
|
@ -171,8 +170,8 @@ func (bind *NativeBind) SetMark(value uint32) error {
|
|||
}
|
||||
|
||||
func closeUnblock(fd int) error {
|
||||
// shutdown to unblock readers
|
||||
unix.Shutdown(fd, unix.SHUT_RD)
|
||||
// shutdown to unblock readers and writers
|
||||
unix.Shutdown(fd, unix.SHUT_RDWR)
|
||||
return unix.Close(fd)
|
||||
}
|
||||
|
||||
|
@ -206,7 +205,6 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
|||
buff,
|
||||
&end,
|
||||
)
|
||||
bind.lastEndpoint = &end
|
||||
return n, &end, err
|
||||
}
|
||||
|
||||
|
@ -551,8 +549,8 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|||
return size, nil
|
||||
}
|
||||
|
||||
func (bind *NativeBind) routineRouteListener() {
|
||||
// TODO: this function doesn't lock the endpoint it modifies
|
||||
func (bind *NativeBind) routineRouteListener(device *Device) {
|
||||
var reqPeer map[uint32]*Peer
|
||||
|
||||
for msg := make([]byte, 1<<16); ; {
|
||||
msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
|
||||
|
@ -570,12 +568,7 @@ func (bind *NativeBind) routineRouteListener() {
|
|||
|
||||
switch hdr.Type {
|
||||
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||
|
||||
if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if hdr.Seq == 0xff {
|
||||
if hdr.Seq <= MaxPeers {
|
||||
if uint(len(remain)) < uint(hdr.Len) {
|
||||
break
|
||||
}
|
||||
|
@ -591,16 +584,46 @@ func (bind *NativeBind) routineRouteListener() {
|
|||
}
|
||||
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||
if uint32(bind.lastEndpoint.src4().ifindex) != ifidx {
|
||||
bind.lastEndpoint.ClearSrc()
|
||||
if reqPeer == nil {
|
||||
break
|
||||
}
|
||||
peer, ok := reqPeer[hdr.Seq]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
peer.mutex.RLock()
|
||||
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
|
||||
peer.mutex.RUnlock()
|
||||
break
|
||||
}
|
||||
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
|
||||
peer.mutex.RUnlock()
|
||||
break
|
||||
}
|
||||
if uint32(peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
|
||||
peer.mutex.RUnlock()
|
||||
break
|
||||
}
|
||||
peer.mutex.RUnlock()
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint.(*NativeEndpoint).ClearSrc()
|
||||
peer.mutex.Unlock()
|
||||
}
|
||||
attr = attr[attrhdr.Len:]
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
reqPeer = make(map[uint32]*Peer)
|
||||
go func() {
|
||||
device.peers.mutex.RLock()
|
||||
i := uint32(1)
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.mutex.RLock()
|
||||
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
|
||||
peer.mutex.RUnlock()
|
||||
continue
|
||||
}
|
||||
nlmsg := struct {
|
||||
hdr unix.NlMsghdr
|
||||
msg unix.RtMsg
|
||||
|
@ -614,7 +637,7 @@ func (bind *NativeBind) routineRouteListener() {
|
|||
unix.NlMsghdr{
|
||||
Type: uint16(unix.RTM_GETROUTE),
|
||||
Flags: unix.NLM_F_REQUEST,
|
||||
Seq: 0xff,
|
||||
Seq: i,
|
||||
},
|
||||
unix.RtMsg{
|
||||
Family: unix.AF_INET,
|
||||
|
@ -625,12 +648,12 @@ func (bind *NativeBind) routineRouteListener() {
|
|||
Len: 8,
|
||||
Type: unix.RTA_DST,
|
||||
},
|
||||
bind.lastEndpoint.dst4().Addr,
|
||||
peer.endpoint.(*NativeEndpoint).dst4().Addr,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_SRC,
|
||||
},
|
||||
bind.lastEndpoint.src4().src,
|
||||
peer.endpoint.(*NativeEndpoint).src4().src,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
|
||||
|
@ -638,8 +661,14 @@ func (bind *NativeBind) routineRouteListener() {
|
|||
uint32(bind.lastMark),
|
||||
}
|
||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||
reqPeer[i] = peer
|
||||
peer.mutex.RUnlock()
|
||||
i++
|
||||
unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||
}
|
||||
device.peers.mutex.RUnlock()
|
||||
}()
|
||||
}
|
||||
remain = remain[hdr.Len:]
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue