Extend structs rather than embed, when possible

This commit is contained in:
Jason A. Donenfeld 2019-01-03 19:04:00 +01:00
parent dff424baf8
commit 89d2c5ed7a
16 changed files with 213 additions and 215 deletions

28
conn.go
View file

@ -78,8 +78,8 @@ func unsafeCloseBind(device *Device) error {
func (device *Device) BindSetMark(mark uint32) error { func (device *Device) BindSetMark(mark uint32) error {
device.net.mutex.Lock() device.net.Lock()
defer device.net.mutex.Unlock() defer device.net.Unlock()
// check if modified // check if modified
@ -98,23 +98,23 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses // clear cached source addresses
device.peers.mutex.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.mutex.Lock() peer.Lock()
defer peer.mutex.Unlock() defer peer.Unlock()
if peer.endpoint != nil { if peer.endpoint != nil {
peer.endpoint.ClearSrc() peer.endpoint.ClearSrc()
} }
} }
device.peers.mutex.RUnlock() device.peers.RUnlock()
return nil return nil
} }
func (device *Device) BindUpdate() error { func (device *Device) BindUpdate() error {
device.net.mutex.Lock() device.net.Lock()
defer device.net.mutex.Unlock() defer device.net.Unlock()
// close existing sockets // close existing sockets
@ -148,15 +148,15 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses // clear cached source addresses
device.peers.mutex.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.mutex.Lock() peer.Lock()
defer peer.mutex.Unlock() defer peer.Unlock()
if peer.endpoint != nil { if peer.endpoint != nil {
peer.endpoint.ClearSrc() peer.endpoint.ClearSrc()
} }
} }
device.peers.mutex.RUnlock() device.peers.RUnlock()
// start receiving routines // start receiving routines
@ -173,8 +173,8 @@ func (device *Device) BindUpdate() error {
} }
func (device *Device) BindClose() error { func (device *Device) BindClose() error {
device.net.mutex.Lock() device.net.Lock()
err := unsafeCloseBind(device) err := unsafeCloseBind(device)
device.net.mutex.Unlock() device.net.Unlock()
return err return err
} }

View file

@ -654,17 +654,17 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
if !ok { if !ok {
break break
} }
pePtr.peer.mutex.Lock() pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint { if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.mutex.Unlock() pePtr.peer.Unlock()
break break
} }
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
pePtr.peer.mutex.Unlock() pePtr.peer.Unlock()
break break
} }
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
pePtr.peer.mutex.Unlock() pePtr.peer.Unlock()
} }
attr = attr[attrhdr.Len:] attr = attr[attrhdr.Len:]
} }
@ -675,16 +675,16 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
reqPeer = make(map[uint32]peerEndpointPtr) reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock() reqPeerLock.Unlock()
go func() { go func() {
device.peers.mutex.RLock() device.peers.RLock()
i := uint32(1) i := uint32(1)
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.mutex.RLock() peer.RLock()
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
peer.mutex.RUnlock() peer.RUnlock()
continue continue
} }
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
peer.mutex.RUnlock() peer.RUnlock()
break break
} }
nlmsg := struct { nlmsg := struct {
@ -730,14 +730,14 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
endpoint: &peer.endpoint, endpoint: &peer.endpoint,
} }
reqPeerLock.Unlock() reqPeerLock.Unlock()
peer.mutex.RUnlock() peer.RUnlock()
i++ i++
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil { if err != nil {
break break
} }
} }
device.peers.mutex.RUnlock() device.peers.RUnlock()
}() }()
} }
remain = remain[hdr.Len:] remain = remain[hdr.Len:]

View file

@ -15,8 +15,8 @@ import (
) )
type CookieChecker struct { type CookieChecker struct {
mutex sync.RWMutex sync.RWMutex
mac1 struct { mac1 struct {
key [blake2s.Size]byte key [blake2s.Size]byte
} }
mac2 struct { mac2 struct {
@ -27,8 +27,8 @@ type CookieChecker struct {
} }
type CookieGenerator struct { type CookieGenerator struct {
mutex sync.RWMutex sync.RWMutex
mac1 struct { mac1 struct {
key [blake2s.Size]byte key [blake2s.Size]byte
} }
mac2 struct { mac2 struct {
@ -41,8 +41,8 @@ type CookieGenerator struct {
} }
func (st *CookieChecker) Init(pk NoisePublicKey) { func (st *CookieChecker) Init(pk NoisePublicKey) {
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
// mac1 state // mac1 state
@ -66,8 +66,8 @@ func (st *CookieChecker) Init(pk NoisePublicKey) {
} }
func (st *CookieChecker) CheckMAC1(msg []byte) bool { func (st *CookieChecker) CheckMAC1(msg []byte) bool {
st.mutex.RLock() st.RLock()
defer st.mutex.RUnlock() defer st.RUnlock()
size := len(msg) size := len(msg)
smac2 := size - blake2s.Size128 smac2 := size - blake2s.Size128
@ -83,8 +83,8 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
} }
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.mutex.RLock() st.RLock()
defer st.mutex.RUnlock() defer st.RUnlock()
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
return false return false
@ -119,21 +119,21 @@ func (st *CookieChecker) CreateReply(
src []byte, src []byte,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.mutex.RLock() st.RLock()
// refresh cookie secret // refresh cookie secret
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
st.mutex.RUnlock() st.RUnlock()
st.mutex.Lock() st.Lock()
_, err := rand.Read(st.mac2.secret[:]) _, err := rand.Read(st.mac2.secret[:])
if err != nil { if err != nil {
st.mutex.Unlock() st.Unlock()
return nil, err return nil, err
} }
st.mac2.secretSet = time.Now() st.mac2.secretSet = time.Now()
st.mutex.Unlock() st.Unlock()
st.mutex.RLock() st.RLock()
} }
// derive cookie // derive cookie
@ -158,21 +158,21 @@ func (st *CookieChecker) CreateReply(
_, err := rand.Read(reply.Nonce[:]) _, err := rand.Read(reply.Nonce[:])
if err != nil { if err != nil {
st.mutex.RUnlock() st.RUnlock()
return nil, err return nil, err
} }
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
st.mutex.RUnlock() st.RUnlock()
return reply, nil return reply, nil
} }
func (st *CookieGenerator) Init(pk NoisePublicKey) { func (st *CookieGenerator) Init(pk NoisePublicKey) {
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
func() { func() {
hash, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
@ -192,8 +192,8 @@ func (st *CookieGenerator) Init(pk NoisePublicKey) {
} }
func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
if !st.mac2.hasLastMAC1 { if !st.mac2.hasLastMAC1 {
return false return false
@ -223,8 +223,8 @@ func (st *CookieGenerator) AddMacs(msg []byte) {
mac1 := msg[smac1:smac2] mac1 := msg[smac1:smac2]
mac2 := msg[smac2:] mac2 := msg[smac2:]
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
// set mac1 // set mac1

View file

@ -29,7 +29,7 @@ type Device struct {
state struct { state struct {
starting sync.WaitGroup starting sync.WaitGroup
stopping sync.WaitGroup stopping sync.WaitGroup
mutex sync.Mutex sync.Mutex
changing AtomicBool changing AtomicBool
current bool current bool
} }
@ -37,20 +37,20 @@ type Device struct {
net struct { net struct {
starting sync.WaitGroup starting sync.WaitGroup
stopping sync.WaitGroup stopping sync.WaitGroup
mutex sync.RWMutex sync.RWMutex
bind Bind // bind interface bind Bind // bind interface
port uint16 // listening port port uint16 // listening port
fwmark uint32 // mark value (0 = disabled) fwmark uint32 // mark value (0 = disabled)
} }
staticIdentity struct { staticIdentity struct {
mutex sync.RWMutex sync.RWMutex
privateKey NoisePrivateKey privateKey NoisePrivateKey
publicKey NoisePublicKey publicKey NoisePublicKey
} }
peers struct { peers struct {
mutex sync.RWMutex sync.RWMutex
keyMap map[NoisePublicKey]*Peer keyMap map[NoisePublicKey]*Peer
} }
@ -93,7 +93,7 @@ type Device struct {
/* Converts the peer into a "zombie", which remains in the peer map, /* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table. * but processes no packets and does not exists in the routing table.
* *
* Must hold device.peers.mutex. * Must hold device.peers.Mutex
*/ */
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
@ -117,13 +117,13 @@ func deviceUpdateState(device *Device) {
// compare to current state of device // compare to current state of device
device.state.mutex.Lock() device.state.Lock()
newIsUp := device.isUp.Get() newIsUp := device.isUp.Get()
if newIsUp == device.state.current { if newIsUp == device.state.current {
device.state.changing.Set(false) device.state.changing.Set(false)
device.state.mutex.Unlock() device.state.Unlock()
return return
} }
@ -135,26 +135,26 @@ func deviceUpdateState(device *Device) {
device.isUp.Set(false) device.isUp.Set(false)
break break
} }
device.peers.mutex.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Start() peer.Start()
} }
device.peers.mutex.RUnlock() device.peers.RUnlock()
case false: case false:
device.BindClose() device.BindClose()
device.peers.mutex.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Stop() peer.Stop()
} }
device.peers.mutex.RUnlock() device.peers.RUnlock()
} }
// update state variables // update state variables
device.state.current = newIsUp device.state.current = newIsUp
device.state.changing.Set(false) device.state.changing.Set(false)
device.state.mutex.Unlock() device.state.Unlock()
// check for state change in the mean time // check for state change in the mean time
@ -199,11 +199,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources // lock required resources
device.staticIdentity.mutex.Lock() device.staticIdentity.Lock()
defer device.staticIdentity.mutex.Unlock() defer device.staticIdentity.Unlock()
device.peers.mutex.Lock() device.peers.Lock()
defer device.peers.mutex.Unlock() defer device.peers.Unlock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.handshake.mutex.RLock() peer.handshake.mutex.RLock()
@ -310,15 +310,15 @@ func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
} }
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.mutex.RLock() device.peers.RLock()
defer device.peers.mutex.RUnlock() defer device.peers.RUnlock()
return device.peers.keyMap[pk] return device.peers.keyMap[pk]
} }
func (device *Device) RemovePeer(key NoisePublicKey) { func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.mutex.Lock() device.peers.Lock()
defer device.peers.mutex.Unlock() defer device.peers.Unlock()
// stop peer and remove from routing // stop peer and remove from routing
@ -329,8 +329,8 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
} }
func (device *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
device.peers.mutex.Lock() device.peers.Lock()
defer device.peers.mutex.Unlock() defer device.peers.Unlock()
for key, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
unsafeRemovePeer(device, peer, key) unsafeRemovePeer(device, peer, key)
@ -367,8 +367,8 @@ func (device *Device) Close() {
device.log.Info.Println("Device closing") device.log.Info.Println("Device closing")
device.state.changing.Set(true) device.state.changing.Set(true)
device.state.mutex.Lock() device.state.Lock()
defer device.state.mutex.Unlock() defer device.state.Unlock()
device.tun.device.Close() device.tun.device.Close()
device.BindClose() device.BindClose()

View file

@ -18,7 +18,7 @@ type IndexTableEntry struct {
} }
type IndexTable struct { type IndexTable struct {
mutex sync.RWMutex sync.RWMutex
table map[uint32]IndexTableEntry table map[uint32]IndexTableEntry
} }
@ -29,20 +29,20 @@ func randUint32() (uint32, error) {
} }
func (table *IndexTable) Init() { func (table *IndexTable) Init() {
table.mutex.Lock() table.Lock()
defer table.mutex.Unlock() defer table.Unlock()
table.table = make(map[uint32]IndexTableEntry) table.table = make(map[uint32]IndexTableEntry)
} }
func (table *IndexTable) Delete(index uint32) { func (table *IndexTable) Delete(index uint32) {
table.mutex.Lock() table.Lock()
defer table.mutex.Unlock() defer table.Unlock()
delete(table.table, index) delete(table.table, index)
} }
func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
table.mutex.Lock() table.Lock()
defer table.mutex.Unlock() defer table.Unlock()
entry, ok := table.table[index] entry, ok := table.table[index]
if !ok { if !ok {
return return
@ -65,19 +65,19 @@ func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake)
// check if index used // check if index used
table.mutex.RLock() table.RLock()
_, ok := table.table[index] _, ok := table.table[index]
table.mutex.RUnlock() table.RUnlock()
if ok { if ok {
continue continue
} }
// check again while locked // check again while locked
table.mutex.Lock() table.Lock()
_, found := table.table[index] _, found := table.table[index]
if found { if found {
table.mutex.Unlock() table.Unlock()
continue continue
} }
table.table[index] = IndexTableEntry{ table.table[index] = IndexTableEntry{
@ -85,13 +85,13 @@ func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake)
handshake: handshake, handshake: handshake,
keypair: nil, keypair: nil,
} }
table.mutex.Unlock() table.Unlock()
return index, nil return index, nil
} }
} }
func (table *IndexTable) Lookup(id uint32) IndexTableEntry { func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
table.mutex.RLock() table.RLock()
defer table.mutex.RUnlock() defer table.RUnlock()
return table.table[id] return table.table[id]
} }

View file

@ -31,15 +31,15 @@ type Keypair struct {
} }
type Keypairs struct { type Keypairs struct {
mutex sync.RWMutex sync.RWMutex
current *Keypair current *Keypair
previous *Keypair previous *Keypair
next *Keypair next *Keypair
} }
func (kp *Keypairs) Current() *Keypair { func (kp *Keypairs) Current() *Keypair {
kp.mutex.RLock() kp.RLock()
defer kp.mutex.RUnlock() defer kp.RUnlock()
return kp.current return kp.current
} }

View file

@ -17,11 +17,11 @@ const (
) )
type AtomicBool struct { type AtomicBool struct {
flag int32 int32
} }
func (a *AtomicBool) Get() bool { func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.flag) == AtomicTrue return atomic.LoadInt32(&a.int32) == AtomicTrue
} }
func (a *AtomicBool) Swap(val bool) bool { func (a *AtomicBool) Swap(val bool) bool {
@ -29,7 +29,7 @@ func (a *AtomicBool) Swap(val bool) bool {
if val { if val {
flag = AtomicTrue flag = AtomicTrue
} }
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
} }
func (a *AtomicBool) Set(val bool) { func (a *AtomicBool) Set(val bool) {
@ -37,7 +37,7 @@ func (a *AtomicBool) Set(val bool) {
if val { if val {
flag = AtomicTrue flag = AtomicTrue
} }
atomic.StoreInt32(&a.flag, flag) atomic.StoreInt32(&a.int32, flag)
} }
func min(a, b uint) uint { func min(a, b uint) uint {

View file

@ -154,8 +154,8 @@ func init() {
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.mutex.RUnlock() defer device.staticIdentity.RUnlock()
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -241,8 +241,8 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
return nil return nil
} }
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.mutex.RUnlock() defer device.staticIdentity.RUnlock()
mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:]) mixHash(&hash, &hash, msg.Ephemeral[:])
@ -423,8 +423,8 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// lock private key for reading // lock private key for reading
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.mutex.RUnlock() defer device.staticIdentity.RUnlock()
// finish 3-way DH // finish 3-way DH
@ -554,8 +554,8 @@ func (peer *Peer) BeginSymmetricSession() error {
// rotate key pairs // rotate key pairs
keypairs := &peer.keypairs keypairs := &peer.keypairs
keypairs.mutex.Lock() keypairs.Lock()
defer keypairs.mutex.Unlock() defer keypairs.Unlock()
previous := keypairs.previous previous := keypairs.previous
next := keypairs.next next := keypairs.next
@ -586,8 +586,8 @@ func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
if keypairs.next != receivedKeypair { if keypairs.next != receivedKeypair {
return false return false
} }
keypairs.mutex.Lock() keypairs.Lock()
defer keypairs.mutex.Unlock() defer keypairs.Unlock()
if keypairs.next != receivedKeypair { if keypairs.next != receivedKeypair {
return false return false
} }

46
peer.go
View file

@ -19,7 +19,7 @@ const (
type Peer struct { type Peer struct {
isRunning AtomicBool isRunning AtomicBool
mutex sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs keypairs Keypairs
handshake Handshake handshake Handshake
device *Device device *Device
@ -57,10 +57,10 @@ type Peer struct {
} }
routines struct { routines struct {
mutex sync.Mutex // held when stopping / starting routines sync.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop stopping sync.WaitGroup // routines pending stop
stop chan struct{} // size 0, stop all go routines in peer stop chan struct{} // size 0, stop all go routines in peer
} }
cookieGenerator CookieGenerator cookieGenerator CookieGenerator
@ -74,11 +74,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// lock resources // lock resources
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.mutex.RUnlock() defer device.staticIdentity.RUnlock()
device.peers.mutex.Lock() device.peers.Lock()
defer device.peers.mutex.Unlock() defer device.peers.Unlock()
// check if over limit // check if over limit
@ -89,8 +89,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer // create peer
peer := new(Peer) peer := new(Peer)
peer.mutex.Lock() peer.Lock()
defer peer.mutex.Unlock() defer peer.Unlock()
peer.cookieGenerator.Init(pk) peer.cookieGenerator.Init(pk)
peer.device = device peer.device = device
@ -126,15 +126,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
} }
func (peer *Peer) SendBuffer(buffer []byte) error { func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock() peer.device.net.RLock()
defer peer.device.net.mutex.RUnlock() defer peer.device.net.RUnlock()
if peer.device.net.bind == nil { if peer.device.net.bind == nil {
return errors.New("no bind") return errors.New("no bind")
} }
peer.mutex.RLock() peer.RLock()
defer peer.mutex.RUnlock() defer peer.RUnlock()
if peer.endpoint == nil { if peer.endpoint == nil {
return errors.New("no known endpoint for peer") return errors.New("no known endpoint for peer")
@ -162,8 +162,8 @@ func (peer *Peer) Start() {
// prevent simultaneous start/stop operations // prevent simultaneous start/stop operations
peer.routines.mutex.Lock() peer.routines.Lock()
defer peer.routines.mutex.Unlock() defer peer.routines.Unlock()
if peer.isRunning.Get() { if peer.isRunning.Get() {
return return
@ -207,14 +207,14 @@ func (peer *Peer) ZeroAndFlushAll() {
// clear key pairs // clear key pairs
keypairs := &peer.keypairs keypairs := &peer.keypairs
keypairs.mutex.Lock() keypairs.Lock()
device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current) device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next) device.DeleteKeypair(keypairs.next)
keypairs.previous = nil keypairs.previous = nil
keypairs.current = nil keypairs.current = nil
keypairs.next = nil keypairs.next = nil
keypairs.mutex.Unlock() keypairs.Unlock()
// clear handshake state // clear handshake state
@ -237,8 +237,8 @@ func (peer *Peer) Stop() {
peer.routines.starting.Wait() peer.routines.starting.Wait()
peer.routines.mutex.Lock() peer.routines.Lock()
defer peer.routines.mutex.Unlock() defer peer.routines.Unlock()
peer.device.log.Debug.Println(peer, "- Stopping...") peer.device.log.Debug.Println(peer, "- Stopping...")
@ -264,7 +264,7 @@ func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
if roamingDisabled { if roamingDisabled {
return return
} }
peer.mutex.Lock() peer.Lock()
peer.endpoint = endpoint peer.endpoint = endpoint
peer.mutex.Unlock() peer.Unlock()
} }

View file

@ -20,21 +20,21 @@ const (
) )
type RatelimiterEntry struct { type RatelimiterEntry struct {
mutex sync.Mutex sync.Mutex
lastTime time.Time lastTime time.Time
tokens int64 tokens int64
} }
type Ratelimiter struct { type Ratelimiter struct {
mutex sync.RWMutex sync.RWMutex
stopReset chan struct{} stopReset chan struct{}
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
} }
func (rate *Ratelimiter) Close() { func (rate *Ratelimiter) Close() {
rate.mutex.Lock() rate.Lock()
defer rate.mutex.Unlock() defer rate.Unlock()
if rate.stopReset != nil { if rate.stopReset != nil {
close(rate.stopReset) close(rate.stopReset)
@ -42,8 +42,8 @@ func (rate *Ratelimiter) Close() {
} }
func (rate *Ratelimiter) Init() { func (rate *Ratelimiter) Init() {
rate.mutex.Lock() rate.Lock()
defer rate.mutex.Unlock() defer rate.Unlock()
// stop any ongoing garbage collection routine // stop any ongoing garbage collection routine
@ -71,23 +71,23 @@ func (rate *Ratelimiter) Init() {
} }
case <-ticker.C: case <-ticker.C:
func() { func() {
rate.mutex.Lock() rate.Lock()
defer rate.mutex.Unlock() defer rate.Unlock()
for key, entry := range rate.tableIPv4 { for key, entry := range rate.tableIPv4 {
entry.mutex.Lock() entry.Lock()
if time.Now().Sub(entry.lastTime) > garbageCollectTime { if time.Now().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key) delete(rate.tableIPv4, key)
} }
entry.mutex.Unlock() entry.Unlock()
} }
for key, entry := range rate.tableIPv6 { for key, entry := range rate.tableIPv6 {
entry.mutex.Lock() entry.Lock()
if time.Now().Sub(entry.lastTime) > garbageCollectTime { if time.Now().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key) delete(rate.tableIPv6, key)
} }
entry.mutex.Unlock() entry.Unlock()
} }
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 { if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
@ -109,7 +109,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
IPv4 := ip.To4() IPv4 := ip.To4()
IPv6 := ip.To16() IPv6 := ip.To16()
rate.mutex.RLock() rate.RLock()
if IPv4 != nil { if IPv4 != nil {
copy(keyIPv4[:], IPv4) copy(keyIPv4[:], IPv4)
@ -119,7 +119,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
entry = rate.tableIPv6[keyIPv6] entry = rate.tableIPv6[keyIPv6]
} }
rate.mutex.RUnlock() rate.RUnlock()
// make new entry if not found // make new entry if not found
@ -127,7 +127,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
entry = new(RatelimiterEntry) entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost entry.tokens = maxTokens - packetCost
entry.lastTime = time.Now() entry.lastTime = time.Now()
rate.mutex.Lock() rate.Lock()
if IPv4 != nil { if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@ -139,13 +139,13 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
rate.stopReset <- struct{}{} rate.stopReset <- struct{}{}
} }
} }
rate.mutex.Unlock() rate.Unlock()
return true return true
} }
// add tokens to entry // add tokens to entry
entry.mutex.Lock() entry.Lock()
now := time.Now() now := time.Now()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now entry.lastTime = now
@ -157,9 +157,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
if entry.tokens > packetCost { if entry.tokens > packetCost {
entry.tokens -= packetCost entry.tokens -= packetCost
entry.mutex.Unlock() entry.Unlock()
return true return true
} }
entry.mutex.Unlock() entry.Unlock()
return false return false
} }

View file

@ -26,8 +26,8 @@ type QueueHandshakeElement struct {
} }
type QueueInboundElement struct { type QueueInboundElement struct {
dropped int32 dropped int32
mutex sync.Mutex sync.Mutex
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
packet []byte packet []byte
counter uint64 counter uint64
@ -51,7 +51,7 @@ func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueIn
return true return true
default: default:
element.Drop() element.Drop()
element.mutex.Unlock() element.Unlock()
return false return false
} }
default: default:
@ -177,8 +177,8 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
elem.dropped = AtomicFalse elem.dropped = AtomicFalse
elem.endpoint = endpoint elem.endpoint = endpoint
elem.counter = 0 elem.counter = 0
elem.mutex = sync.Mutex{} elem.Mutex = sync.Mutex{}
elem.mutex.Lock() elem.Lock()
// add to decryption queues // add to decryption queues
@ -281,7 +281,7 @@ func (device *Device) RoutineDecryption() {
elem.Drop() elem.Drop()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
} }
elem.mutex.Unlock() elem.Unlock()
} }
} }
} }
@ -529,7 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// wait for decryption // wait for decryption
elem.mutex.Lock() elem.Lock()
if elem.IsDropped() { if elem.IsDropped() {
continue continue

View file

@ -8,15 +8,15 @@ package rwcancel
import "golang.org/x/sys/unix" import "golang.org/x/sys/unix"
type fdSet struct { type fdSet struct {
fdset unix.FdSet unix.FdSet
} }
func (fdset *fdSet) set(i int) { func (fdset *fdSet) set(i int) {
bits := 32 << (^uint(0) >> 63) bits := 32 << (^uint(0) >> 63)
fdset.fdset.Bits[i/bits] |= 1 << uint(i%bits) fdset.Bits[i/bits] |= 1 << uint(i%bits)
} }
func (fdset *fdSet) check(i int) bool { func (fdset *fdSet) check(i int) bool {
bits := 32 << (^uint(0) >> 63) bits := 32 << (^uint(0) >> 63)
return (fdset.fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0 return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
} }

View file

@ -59,7 +59,7 @@ func (rw *RWCancel) ReadyRead() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.fdset, nil, nil, nil) err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
if err != nil { if err != nil {
return false return false
} }
@ -74,7 +74,7 @@ func (rw *RWCancel) ReadyWrite() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.fdset, nil, nil) err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
if err != nil { if err != nil {
return false return false
} }

14
send.go
View file

@ -43,7 +43,7 @@ import (
type QueueOutboundElement struct { type QueueOutboundElement struct {
dropped int32 dropped int32
mutex sync.Mutex sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!) packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
@ -55,7 +55,7 @@ func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement() elem := device.GetOutboundElement()
elem.dropped = AtomicFalse elem.dropped = AtomicFalse
elem.buffer = device.GetMessageBuffer() elem.buffer = device.GetMessageBuffer()
elem.mutex = sync.Mutex{} elem.Mutex = sync.Mutex{}
elem.nonce = 0 elem.nonce = 0
elem.keypair = nil elem.keypair = nil
elem.peer = nil elem.peer = nil
@ -95,7 +95,7 @@ func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement,
default: default:
element.Drop() element.Drop()
element.peer.device.PutMessageBuffer(element.buffer) element.peer.device.PutMessageBuffer(element.buffer)
element.mutex.Unlock() element.Unlock()
} }
default: default:
element.peer.device.PutMessageBuffer(element.buffer) element.peer.device.PutMessageBuffer(element.buffer)
@ -442,7 +442,7 @@ func (peer *Peer) RoutineNonce() {
elem.keypair = keypair elem.keypair = keypair
elem.dropped = AtomicFalse elem.dropped = AtomicFalse
elem.mutex.Lock() elem.Lock()
// add to parallel and sequential queue // add to parallel and sequential queue
addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
@ -468,7 +468,7 @@ func (device *Device) RoutineEncryption() {
if ok && !elem.IsDropped() { if ok && !elem.IsDropped() {
elem.Drop() elem.Drop()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
elem.mutex.Unlock() elem.Unlock()
} }
default: default:
goto out goto out
@ -535,7 +535,7 @@ func (device *Device) RoutineEncryption() {
elem.packet, elem.packet,
nil, nil,
) )
elem.mutex.Unlock() elem.Unlock()
} }
} }
} }
@ -588,7 +588,7 @@ func (peer *Peer) RoutineSequentialSender() {
return return
} }
elem.mutex.Lock() elem.Lock()
if elem.IsDropped() { if elem.IsDropped() {
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
continue continue

View file

@ -19,7 +19,7 @@ import (
*/ */
type Timer struct { type Timer struct {
timer *time.Timer *time.Timer
modifyingLock sync.RWMutex modifyingLock sync.RWMutex
runningLock sync.Mutex runningLock sync.Mutex
isPending bool isPending bool
@ -27,7 +27,7 @@ type Timer struct {
func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
timer := &Timer{} timer := &Timer{}
timer.timer = time.AfterFunc(time.Hour, func() { timer.Timer = time.AfterFunc(time.Hour, func() {
timer.runningLock.Lock() timer.runningLock.Lock()
timer.modifyingLock.Lock() timer.modifyingLock.Lock()
@ -42,21 +42,21 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
expirationFunction(peer) expirationFunction(peer)
timer.runningLock.Unlock() timer.runningLock.Unlock()
}) })
timer.timer.Stop() timer.Stop()
return timer return timer
} }
func (timer *Timer) Mod(d time.Duration) { func (timer *Timer) Mod(d time.Duration) {
timer.modifyingLock.Lock() timer.modifyingLock.Lock()
timer.isPending = true timer.isPending = true
timer.timer.Reset(d) timer.Reset(d)
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
} }
func (timer *Timer) Del() { func (timer *Timer) Del() {
timer.modifyingLock.Lock() timer.modifyingLock.Lock()
timer.isPending = false timer.isPending = false
timer.timer.Stop() timer.Stop()
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
} }
@ -101,11 +101,11 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.mutex.Lock() peer.Lock()
if peer.endpoint != nil { if peer.endpoint != nil {
peer.endpoint.ClearSrc() peer.endpoint.ClearSrc()
} }
peer.mutex.Unlock() peer.Unlock()
peer.SendHandshakeInitiation(true) peer.SendHandshakeInitiation(true)
} }
@ -124,11 +124,11 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) { func expiredNewHandshake(peer *Peer) {
peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.mutex.Lock() peer.Lock()
if peer.endpoint != nil { if peer.endpoint != nil {
peer.endpoint.ClearSrc() peer.endpoint.ClearSrc()
} }
peer.mutex.Unlock() peer.Unlock()
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }

76
uapi.go
View file

@ -17,15 +17,15 @@ import (
) )
type IPCError struct { type IPCError struct {
Code int64 int64
} }
func (s *IPCError) Error() string { func (s *IPCError) Error() string {
return fmt.Sprintf("IPC error: %d", s.Code) return fmt.Sprintf("IPC error: %d", s.int64)
} }
func (s *IPCError) ErrorCode() int64 { func (s *IPCError) ErrorCode() int64 {
return s.Code return s.int64
} }
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
@ -43,14 +43,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// lock required resources // lock required resources
device.net.mutex.RLock() device.net.RLock()
defer device.net.mutex.RUnlock() defer device.net.RUnlock()
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.mutex.RUnlock() defer device.staticIdentity.RUnlock()
device.peers.mutex.RLock() device.peers.RLock()
defer device.peers.mutex.RUnlock() defer device.peers.RUnlock()
// serialize device related values // serialize device related values
@ -69,8 +69,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// serialize each peer state // serialize each peer state
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.mutex.RLock() peer.RLock()
defer peer.mutex.RUnlock() defer peer.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex())
@ -101,9 +101,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
for _, line := range lines { for _, line := range lines {
_, err := socket.WriteString(line + "\n") _, err := socket.WriteString(line + "\n")
if err != nil { if err != nil {
return &IPCError{ return &IPCError{ipcErrorIO}
Code: ipcErrorIO,
}
} }
} }
@ -130,7 +128,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
parts := strings.Split(line, "=") parts := strings.Split(line, "=")
if len(parts) != 2 { if len(parts) != 2 {
return &IPCError{Code: ipcErrorProtocol} return &IPCError{ipcErrorProtocol}
} }
key := parts[0] key := parts[0]
value := parts[1] value := parts[1]
@ -145,7 +143,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
err := sk.FromHex(value) err := sk.FromHex(value)
if err != nil { if err != nil {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
logDebug.Println("UAPI: Updating private key") logDebug.Println("UAPI: Updating private key")
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
@ -157,20 +155,20 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
port, err := strconv.ParseUint(value, 10, 16) port, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to parse listen_port:", err) logError.Println("Failed to parse listen_port:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
// update port and rebind // update port and rebind
logDebug.Println("UAPI: Updating listen port") logDebug.Println("UAPI: Updating listen port")
device.net.mutex.Lock() device.net.Lock()
device.net.port = uint16(port) device.net.port = uint16(port)
device.net.mutex.Unlock() device.net.Unlock()
if err := device.BindUpdate(); err != nil { if err := device.BindUpdate(); err != nil {
logError.Println("Failed to set listen_port:", err) logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse} return &IPCError{ipcErrorPortInUse}
} }
case "fwmark": case "fwmark":
@ -187,14 +185,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if err != nil { if err != nil {
logError.Println("Invalid fwmark", err) logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
logDebug.Println("UAPI: Updating fwmark") logDebug.Println("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(fwmark)); err != nil { if err := device.BindSetMark(uint32(fwmark)); err != nil {
logError.Println("Failed to update fwmark:", err) logError.Println("Failed to update fwmark:", err)
return &IPCError{Code: ipcErrorPortInUse} return &IPCError{ipcErrorPortInUse}
} }
case "public_key": case "public_key":
@ -205,14 +203,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "replace_peers": case "replace_peers":
if value != "true" { if value != "true" {
logError.Println("Failed to set replace_peers, invalid value:", value) logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
logDebug.Println("UAPI: Removing all peers") logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers() device.RemoveAllPeers()
default: default:
logError.Println("Invalid UAPI device key:", key) logError.Println("Invalid UAPI device key:", key)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
} }
@ -227,14 +225,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
err := publicKey.FromHex(value) err := publicKey.FromHex(value)
if err != nil { if err != nil {
logError.Println("Failed to get peer by public key:", err) logError.Println("Failed to get peer by public key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
// ignore peer with public key of device // ignore peer with public key of device
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
dummy = device.staticIdentity.publicKey.Equals(publicKey) dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.mutex.RUnlock() device.staticIdentity.RUnlock()
if dummy { if dummy {
peer = &Peer{} peer = &Peer{}
@ -246,7 +244,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer, err = device.NewPeer(publicKey) peer, err = device.NewPeer(publicKey)
if err != nil { if err != nil {
logError.Println("Failed to create new peer:", err) logError.Println("Failed to create new peer:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
logDebug.Println(peer, "- UAPI: Created") logDebug.Println(peer, "- UAPI: Created")
} }
@ -257,7 +255,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if value != "true" { if value != "true" {
logError.Println("Failed to set remove, invalid value:", value) logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
if !dummy { if !dummy {
logDebug.Println(peer, "- UAPI: Removing") logDebug.Println(peer, "- UAPI: Removing")
@ -278,7 +276,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if err != nil { if err != nil {
logError.Println("Failed to set preshared key:", err) logError.Println("Failed to set preshared key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
case "endpoint": case "endpoint":
@ -288,8 +286,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logDebug.Println(peer, "- UAPI: Updating endpoint") logDebug.Println(peer, "- UAPI: Updating endpoint")
err := func() error { err := func() error {
peer.mutex.Lock() peer.Lock()
defer peer.mutex.Unlock() defer peer.Unlock()
endpoint, err := CreateEndpoint(value) endpoint, err := CreateEndpoint(value)
if err != nil { if err != nil {
return err return err
@ -300,7 +298,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if err != nil { if err != nil {
logError.Println("Failed to set endpoint:", value) logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
@ -312,7 +310,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to set persistent keepalive interval:", err) logError.Println("Failed to set persistent keepalive interval:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
old := peer.persistentKeepaliveInterval old := peer.persistentKeepaliveInterval
@ -323,7 +321,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if old == 0 && secs != 0 { if old == 0 && secs != 0 {
if err != nil { if err != nil {
logError.Println("Failed to get tun device status:", err) logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO} return &IPCError{ipcErrorIO}
} }
if device.isUp.Get() && !dummy { if device.isUp.Get() && !dummy {
peer.SendKeepalive() peer.SendKeepalive()
@ -336,7 +334,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if value != "true" { if value != "true" {
logError.Println("Failed to replace allowedips, invalid value:", value) logError.Println("Failed to replace allowedips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
if dummy { if dummy {
@ -352,7 +350,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
logError.Println("Failed to set allowed ip:", err) logError.Println("Failed to set allowed ip:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
if dummy { if dummy {
@ -366,12 +364,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if value != "1" { if value != "1" {
logError.Println("Invalid protocol version:", value) logError.Println("Invalid protocol version:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
default: default:
logError.Println("Invalid UAPI peer key:", key) logError.Println("Invalid UAPI peer key:", key)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{ipcErrorInvalid}
} }
} }
} }