package utp import ( "context" "crypto/rand" "encoding/binary" "net" "sync" ) const ( Data = 0 Fin = 1 State = 2 Reset = 3 Syn = 4 ) const headerSize = 20 type header struct { typ uint8 ver uint8 ext uint8 connID uint16 timestamp uint32 timeDiff uint32 wnd uint32 seq uint16 ack uint16 } func encode(h *header, buf []byte) { buf[0] = (h.typ << 4) | (h.ver & 0x0F) buf[1] = h.ext binary.BigEndian.PutUint16(buf[2:], h.connID) binary.BigEndian.PutUint32(buf[4:], h.timestamp) binary.BigEndian.PutUint32(buf[8:], h.timeDiff) binary.BigEndian.PutUint32(buf[12:], h.wnd) binary.BigEndian.PutUint16(buf[16:], h.seq) binary.BigEndian.PutUint16(buf[18:], h.ack) } func decode(buf []byte) header { return header{ typ: (buf[0] >> 4) & 0x0F, ver: buf[0] & 0x0F, ext: buf[1], connID: binary.BigEndian.Uint16(buf[2:]), timestamp: binary.BigEndian.Uint32(buf[4:]), timeDiff: binary.BigEndian.Uint32(buf[8:]), wnd: binary.BigEndian.Uint32(buf[12:]), seq: binary.BigEndian.Uint16(buf[16:]), ack: binary.BigEndian.Uint16(buf[18:]), } } type Socket struct { conn *net.UDPConn dhtCh chan<- Packet mu sync.RWMutex conns map[uint16]*Conn accepts chan *Conn ctx context.Context cancel context.CancelFunc } type Packet struct { Hdr header Payload []byte Raw []byte Addr *net.UDPAddr } func New(conn *net.UDPConn, dhtCh chan<- Packet) *Socket { ctx, cancel := context.WithCancel(context.Background()) s := &Socket{ conn: conn, dhtCh: dhtCh, conns: make(map[uint16]*Conn), accepts: make(chan *Conn, 16), ctx: ctx, cancel: cancel, } return s } func (s *Socket) Start() { go s.reader() } func (s *Socket) reader() { buf := make([]byte, 65535) for { n, addr, err := s.conn.ReadFromUDP(buf) if err != nil { return } if n == 0 { continue } data := append([]byte(nil), buf[:n]...) if data[0] == 'd' { select { case s.dhtCh <- Packet{Raw: data, Addr: addr}: default: } continue } if n < headerSize { continue } pkt := NewPacket(data, addr) if pkt.Hdr.typ == Syn { s.handleSyn(pkt) } else { s.dispatch(pkt) } } } func (s *Socket) handleSyn(pkt Packet) { c := s.newConn(pkt.Hdr.connID, pkt.Addr, false) s.mu.Lock() s.conns[c.recvID] = c s.mu.Unlock() go c.run() c.in <- pkt select { case s.accepts <- c: default: c.Close() } } func (s *Socket) dispatch(pkt Packet) { s.mu.RLock() c := s.conns[pkt.Hdr.connID] s.mu.RUnlock() if c != nil { select { case c.in <- pkt: default: } } } func (s *Socket) removeConn(id uint16) { s.mu.Lock() delete(s.conns, id) s.mu.Unlock() } func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn { ctx, cancel := context.WithCancel(s.ctx) c := &Conn{ sock: s, addr: addr, in: make(chan Packet, 256), reads: make(chan readReq), writes: make(chan writeReq), closeReq: make(chan struct{}), ready: make(chan struct{}), ctx: ctx, cancel: cancel, } if initiator { c.recvID = randUint16() c.sendID = c.recvID + 1 } else { c.recvID = peerID + 1 c.sendID = peerID } return c } func (s *Socket) DialContext(ctx context.Context, addr *net.UDPAddr) (*Conn, error) { c := s.newConn(0, addr, true) s.mu.Lock() s.conns[c.recvID] = c s.mu.Unlock() go c.run() select { case <-c.ready: return c, nil case <-ctx.Done(): c.Close() return nil, ctx.Err() } } func (s *Socket) Accept() *Conn { return <-s.accepts } func (s *Socket) Close() { s.conn.Close() s.cancel() } func (s *Socket) LocalPort() int { return s.conn.LocalAddr().(*net.UDPAddr).Port } func randUint16() uint16 { var b [2]byte if _, err := rand.Read(b[:]); err != nil { panic(err) } return binary.BigEndian.Uint16(b[:]) } func NewPacket(data []byte, addr *net.UDPAddr) Packet { pkt := Packet{ Raw: data, Addr: addr, } if len(data) >= headerSize { pkt.Hdr = decode(data) pkt.Payload = data[headerSize:] } return pkt }