package main import ( "crypto/hmac" "crypto/rand" "errors" "golang.org/x/crypto/blake2s" "net" "sync" "time" ) type MACStateDevice struct { mutex sync.RWMutex refreshed time.Time secret [blake2s.Size]byte keyMAC1 [blake2s.Size]byte keyMAC2 [blake2s.Size]byte } type MACStatePeer struct { mutex sync.RWMutex cookieSet time.Time cookie [blake2s.Size128]byte lastMAC1 [blake2s.Size128]byte keyMAC1 [blake2s.Size]byte keyMAC2 [blake2s.Size]byte } /* Methods for verifing MAC fields * and creating/consuming cookies replies * (per device) */ func (state *MACStateDevice) Init(pk NoisePublicKey) { state.mutex.Lock() defer state.mutex.Unlock() func() { hsh, _ := blake2s.New256(nil) hsh.Write([]byte(WGLabelMAC1)) hsh.Write(pk[:]) hsh.Sum(state.keyMAC1[:0]) }() func() { hsh, _ := blake2s.New256(nil) hsh.Write([]byte(WGLabelCookie)) hsh.Write(pk[:]) hsh.Sum(state.keyMAC2[:0]) }() state.refreshed = time.Time{} } func (state *MACStateDevice) CheckMAC1(msg []byte) bool { size := len(msg) startMac1 := size - (blake2s.Size128 * 2) startMac2 := size - blake2s.Size128 var mac1 [blake2s.Size128]byte func() { mac, _ := blake2s.New128(state.keyMAC1[:]) mac.Write(msg[:startMac1]) mac.Sum(mac1[:0]) }() return hmac.Equal(mac1[:], msg[startMac1:startMac2]) } func (state *MACStateDevice) CheckMAC2(msg []byte, addr *net.UDPAddr) bool { state.mutex.RLock() defer state.mutex.RUnlock() if time.Now().Sub(state.refreshed) > CookieRefreshTime { return false } // derive cookie key var cookie [blake2s.Size128]byte func() { port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)} mac, _ := blake2s.New128(state.secret[:]) mac.Write(addr.IP) mac.Write(port[:]) mac.Sum(cookie[:0]) }() // calculate mac of packet start := len(msg) - blake2s.Size128 var mac2 [blake2s.Size128]byte func() { mac, _ := blake2s.New128(cookie[:]) mac.Write(msg[:start]) mac.Sum(mac2[:0]) }() return hmac.Equal(mac2[:], msg[start:]) } func (device *Device) CreateMessageCookieReply( msg []byte, receiver uint32, addr *net.UDPAddr, ) (*MessageCookieReply, error) { state := &device.mac state.mutex.RLock() // refresh cookie secret if time.Now().Sub(state.refreshed) > CookieRefreshTime { state.mutex.RUnlock() state.mutex.Lock() _, err := rand.Read(state.secret[:]) if err != nil { state.mutex.Unlock() return nil, err } state.refreshed = time.Now() state.mutex.Unlock() state.mutex.RLock() } // derive cookie key var cookie [blake2s.Size128]byte func() { port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)} mac, _ := blake2s.New128(state.secret[:]) mac.Write(addr.IP) mac.Write(port[:]) mac.Sum(cookie[:0]) }() // encrypt cookie size := len(msg) startMac1 := size - (blake2s.Size128 * 2) startMac2 := size - blake2s.Size128 mac1 := msg[startMac1:startMac2] reply := new(MessageCookieReply) reply.Type = MessageCookieReplyType reply.Receiver = receiver _, err := rand.Read(reply.Nonce[:]) if err != nil { state.mutex.RUnlock() return nil, err } XChaCha20Poly1305Encrypt( reply.Cookie[:0], &reply.Nonce, cookie[:], mac1, &state.keyMAC2, ) state.mutex.RUnlock() return reply, nil } func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool { if msg.Type != MessageCookieReplyType { return false } // lookup peer lookup := device.indices.Lookup(msg.Receiver) if lookup.handshake == nil { return false } // decrypt and store cookie var cookie [blake2s.Size128]byte state := &lookup.peer.mac state.mutex.Lock() defer state.mutex.Unlock() _, err := XChaCha20Poly1305Decrypt( cookie[:0], &msg.Nonce, msg.Cookie[:], state.lastMAC1[:], &state.keyMAC2, ) if err != nil { return false } state.cookieSet = time.Now() state.cookie = cookie return true } /* Methods for generating the MAC fields * (per peer) */ func (state *MACStatePeer) Init(pk NoisePublicKey) { state.mutex.Lock() defer state.mutex.Unlock() func() { hsh, _ := blake2s.New256(nil) hsh.Write([]byte(WGLabelMAC1)) hsh.Write(pk[:]) hsh.Sum(state.keyMAC1[:0]) }() func() { hsh, _ := blake2s.New256(nil) hsh.Write([]byte(WGLabelCookie)) hsh.Write(pk[:]) hsh.Sum(state.keyMAC2[:0]) }() state.cookieSet = time.Time{} // never } func (state *MACStatePeer) AddMacs(msg []byte) { size := len(msg) if size < blake2s.Size128*2 { panic(errors.New("bug: message too short")) } startMac1 := size - (blake2s.Size128 * 2) startMac2 := size - blake2s.Size128 mac1 := msg[startMac1 : startMac1+blake2s.Size128] mac2 := msg[startMac2 : startMac2+blake2s.Size128] state.mutex.Lock() defer state.mutex.Unlock() // set mac1 func() { mac, _ := blake2s.New128(state.keyMAC1[:]) mac.Write(msg[:startMac1]) mac.Sum(mac1[:0]) }() copy(state.lastMAC1[:], mac1) // set mac2 if state.cookieSet.IsZero() { return } if time.Now().Sub(state.cookieSet) > CookieRefreshTime { state.cookieSet = time.Time{} return } func() { mac, _ := blake2s.New128(state.cookie[:]) mac.Write(msg[:startMac2]) mac.Sum(mac2[:0]) }() }