235 lines
4.0 KiB
Go
235 lines
4.0 KiB
Go
package utp
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"net"
|
|
"sync"
|
|
)
|
|
|
|
// Packet types. encode stores the type in the high nibble of the first
// header byte and decode reads it back from there.
const (
	Data  = iota // 0: carries payload
	Fin          // 1
	State        // 2
	Reset        // 3
	Syn          // 4: connection request; routed to handleSyn by reader
)

// headerSize is the length in bytes of the fixed header that encode writes
// and decode reads.
const headerSize = 20
|
|
|
|
// header is the in-memory form of the fixed headerSize-byte packet header
// that encode writes and decode reads. Fields are listed in wire order.
type header struct {
	typ, ver, ext uint8  // typ/ver are the high/low nibbles of byte 0; ext is byte 1
	connID        uint16 // bytes 2-3
	timestamp     uint32 // bytes 4-7; NOTE(review): units not visible here — presumably microseconds, confirm
	timeDiff      uint32 // bytes 8-11
	wnd           uint32 // bytes 12-15
	seq, ack      uint16 // bytes 16-17 and 18-19
}
|
|
|
|
// encode serializes h into the first headerSize bytes of buf, big-endian.
// Byte 0 packs the packet type into the high nibble and the version (masked
// to 4 bits) into the low nibble. encode panics if buf is shorter than
// headerSize.
func encode(h *header, buf []byte) {
	be := binary.BigEndian

	buf[0] = h.typ<<4 | h.ver&0x0F
	buf[1] = h.ext
	be.PutUint16(buf[2:], h.connID)
	be.PutUint32(buf[4:], h.timestamp)
	be.PutUint32(buf[8:], h.timeDiff)
	be.PutUint32(buf[12:], h.wnd)
	be.PutUint16(buf[16:], h.seq)
	be.PutUint16(buf[18:], h.ack)
}
|
|
|
|
func decode(buf []byte) header {
|
|
return header{
|
|
typ: (buf[0] >> 4) & 0x0F,
|
|
ver: buf[0] & 0x0F,
|
|
ext: buf[1],
|
|
connID: binary.BigEndian.Uint16(buf[2:]),
|
|
timestamp: binary.BigEndian.Uint32(buf[4:]),
|
|
timeDiff: binary.BigEndian.Uint32(buf[8:]),
|
|
wnd: binary.BigEndian.Uint32(buf[12:]),
|
|
seq: binary.BigEndian.Uint16(buf[16:]),
|
|
ack: binary.BigEndian.Uint16(buf[18:]),
|
|
}
|
|
}
|
|
|
|
type Socket struct {
|
|
conn *net.UDPConn
|
|
dhtCh chan<- Packet
|
|
mu sync.RWMutex
|
|
conns map[uint16]*Conn
|
|
accepts chan *Conn
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
type Packet struct {
|
|
Hdr header
|
|
Payload []byte
|
|
Raw []byte
|
|
Addr *net.UDPAddr
|
|
}
|
|
|
|
func New(conn *net.UDPConn, dhtCh chan<- Packet) *Socket {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
s := &Socket{
|
|
conn: conn,
|
|
dhtCh: dhtCh,
|
|
conns: make(map[uint16]*Conn),
|
|
accepts: make(chan *Conn, 16),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *Socket) Start() {
|
|
go s.reader()
|
|
}
|
|
|
|
func (s *Socket) reader() {
|
|
buf := make([]byte, 65535)
|
|
for {
|
|
n, addr, err := s.conn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if n == 0 {
|
|
continue
|
|
}
|
|
|
|
data := append([]byte(nil), buf[:n]...)
|
|
|
|
if data[0] == 'd' {
|
|
select {
|
|
case s.dhtCh <- Packet{Raw: data, Addr: addr}:
|
|
default:
|
|
}
|
|
continue
|
|
}
|
|
|
|
if n < headerSize {
|
|
continue
|
|
}
|
|
|
|
pkt := NewPacket(data, addr)
|
|
if pkt.Hdr.typ == Syn {
|
|
s.handleSyn(pkt)
|
|
} else {
|
|
s.dispatch(pkt)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Socket) handleSyn(pkt Packet) {
|
|
c := s.newConn(pkt.Hdr.connID, pkt.Addr, false)
|
|
|
|
s.mu.Lock()
|
|
s.conns[c.recvID] = c
|
|
s.mu.Unlock()
|
|
|
|
go c.run()
|
|
c.in <- pkt
|
|
|
|
select {
|
|
case s.accepts <- c:
|
|
default:
|
|
c.Close()
|
|
}
|
|
}
|
|
|
|
func (s *Socket) dispatch(pkt Packet) {
|
|
s.mu.RLock()
|
|
c := s.conns[pkt.Hdr.connID]
|
|
s.mu.RUnlock()
|
|
|
|
if c != nil {
|
|
select {
|
|
case c.in <- pkt:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Socket) removeConn(id uint16) {
|
|
s.mu.Lock()
|
|
delete(s.conns, id)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// newConn builds a Conn to addr but does not register or start it. Its context
// is derived from the Socket's, so cancelling the Socket (Close) cancels it too.
//
// Connection IDs (uint16 arithmetic, so +1 wraps):
//   - initiator: receive on a random ID and send on that ID + 1; peerID is ignored.
//   - acceptor:  receive on peerID + 1 and send on peerID.
func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn {
	ctx, cancel := context.WithCancel(s.ctx)

	recv, send := peerID+1, peerID
	if initiator {
		recv = randUint16()
		send = recv + 1
	}

	return &Conn{
		sock:     s,
		addr:     addr,
		recvID:   recv,
		sendID:   send,
		in:       make(chan Packet, 256),
		reads:    make(chan readReq),
		writes:   make(chan writeReq),
		closeReq: make(chan struct{}),
		ready:    make(chan struct{}),
		ctx:      ctx,
		cancel:   cancel,
	}
}
|
|
|
|
// DialContext opens an outbound connection to addr as the initiator. It
// registers the connection, starts its run loop, and waits for one of:
//   - the connection signals ready: the connection is returned;
//   - ctx is done: the half-open connection is closed and ctx.Err() is returned;
//   - the Socket is closed: the connection is unregistered and the Socket's
//     context error is returned, instead of blocking forever.
//
// If ctx is already done, nothing is registered or started.
func (s *Socket) DialContext(ctx context.Context, addr *net.UDPAddr) (*Conn, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	c := s.newConn(0, addr, true)

	s.mu.Lock()
	// newConn picks a random receive ID. Re-roll on collision so we never
	// overwrite, and silently steal traffic from, a live connection's entry.
	// sendID = recvID + 1 mirrors newConn's initiator convention.
	for s.conns[c.recvID] != nil {
		c.recvID = randUint16()
		c.sendID = c.recvID + 1
	}
	s.conns[c.recvID] = c
	s.mu.Unlock()

	go c.run()

	select {
	case <-c.ready:
		return c, nil
	case <-ctx.Done():
		c.Close()
		return nil, ctx.Err()
	case <-s.ctx.Done():
		// The Socket was closed. c's context is derived from s.ctx, so the
		// handshake cannot complete. Drop our routing entry, but only if it
		// is still ours.
		// NOTE(review): c.Close() is not called here because its behaviour
		// after run has exited is not visible from this file — confirm.
		s.mu.Lock()
		if s.conns[c.recvID] == c {
			delete(s.conns, c.recvID)
		}
		s.mu.Unlock()
		return nil, s.ctx.Err()
	}
}
|
|
|
|
func (s *Socket) Accept() *Conn {
|
|
return <-s.accepts
|
|
}
|
|
|
|
func (s *Socket) Close() {
|
|
s.conn.Close()
|
|
s.cancel()
|
|
}
|
|
|
|
// LocalPort reports the local UDP port the Socket's connection is bound to.
// It panics if the connection's local address is not a *net.UDPAddr.
func (s *Socket) LocalPort() int {
	local := s.conn.LocalAddr().(*net.UDPAddr)
	return local.Port
}
|
|
|
|
// randUint16 returns a random 16-bit value built from two bytes of
// crypto/rand, interpreted big-endian. It panics if the system random source
// fails, since connection IDs cannot be chosen safely without it.
func randUint16() uint16 {
	var raw [2]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		panic(err)
	}
	return binary.BigEndian.Uint16(raw[:])
}
|
|
|
|
// NewPacket wraps a received datagram from addr. If data holds at least
// headerSize bytes, the header is decoded and Payload is set to the remaining
// bytes, which share data's backing array. Otherwise Hdr is the zero header
// and Payload is nil. Raw is always data itself.
func NewPacket(data []byte, addr *net.UDPAddr) Packet {
	if len(data) < headerSize {
		return Packet{Raw: data, Addr: addr}
	}
	return Packet{
		Hdr:     decode(data),
		Payload: data[headerSize:],
		Raw:     data,
		Addr:    addr,
	}
}
|