first commit
This commit is contained in:
305
utp/conn.go
Normal file
305
utp/conn.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package utp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
initRTO = 1000 // ms
|
||||
retransCheck = 500 // ms
|
||||
defaultWnd = 64 * 1024
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
sock *Socket
|
||||
addr *net.UDPAddr
|
||||
recvID uint16
|
||||
sendID uint16
|
||||
|
||||
in chan packet
|
||||
reads chan readReq
|
||||
writes chan writeReq
|
||||
closeReq chan struct{}
|
||||
ready chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
seq uint16
|
||||
ack uint16
|
||||
unacked []sent
|
||||
reorder map[uint16][]byte
|
||||
rest []byte
|
||||
pending *readReq
|
||||
rto uint32
|
||||
}
|
||||
|
||||
type readReq struct {
|
||||
p []byte
|
||||
reply chan readResp
|
||||
}
|
||||
|
||||
type readResp struct {
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
type writeReq struct {
|
||||
data []byte
|
||||
reply chan error
|
||||
}
|
||||
|
||||
type sent struct {
|
||||
seq uint16
|
||||
data []byte
|
||||
at time.Time
|
||||
}
|
||||
|
||||
func (c *Conn) flush() {
|
||||
if c.pending == nil || len(c.rest) == 0 {
|
||||
return
|
||||
}
|
||||
n := copy(c.pending.p, c.rest)
|
||||
c.rest = c.rest[n:]
|
||||
c.pending.reply <- readResp{n, nil}
|
||||
c.pending = nil
|
||||
}
|
||||
|
||||
func (c *Conn) deliver(payload []byte) {
|
||||
c.rest = append(c.rest, payload...)
|
||||
c.flush()
|
||||
}
|
||||
|
||||
func (c *Conn) send(typ uint8, payload []byte) error {
|
||||
c.seq++
|
||||
h := &header{
|
||||
typ: typ,
|
||||
ver: 1,
|
||||
connID: c.sendID,
|
||||
timestamp: timestamp(),
|
||||
wnd: defaultWnd,
|
||||
seq: c.seq,
|
||||
ack: c.ack,
|
||||
}
|
||||
buf := make([]byte, headerSize+len(payload))
|
||||
encode(h, buf)
|
||||
copy(buf[headerSize:], payload)
|
||||
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
|
||||
return err
|
||||
}
|
||||
if typ == Data || typ == Syn || typ == Fin {
|
||||
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) sendState() error {
|
||||
h := &header{
|
||||
typ: State,
|
||||
ver: 1,
|
||||
connID: c.sendID,
|
||||
timestamp: timestamp(),
|
||||
wnd: defaultWnd,
|
||||
seq: c.seq,
|
||||
ack: c.ack,
|
||||
}
|
||||
buf := make([]byte, headerSize)
|
||||
encode(h, buf)
|
||||
_, err := c.sock.conn.WriteToUDP(buf, c.addr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) sendSyn() error {
|
||||
c.seq++
|
||||
h := &header{
|
||||
typ: Syn,
|
||||
ver: 1,
|
||||
connID: c.recvID,
|
||||
timestamp: timestamp(),
|
||||
wnd: defaultWnd,
|
||||
seq: c.seq,
|
||||
ack: c.ack,
|
||||
}
|
||||
buf := make([]byte, headerSize)
|
||||
encode(h, buf)
|
||||
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
|
||||
return err
|
||||
}
|
||||
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) retrans() error {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range c.unacked {
|
||||
p := &c.unacked[i]
|
||||
if now.Sub(p.at) > time.Duration(c.rto)*time.Millisecond {
|
||||
if _, err := c.sock.conn.WriteToUDP(p.data, c.addr); err != nil {
|
||||
return err
|
||||
}
|
||||
p.at = now
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) recv(p packet) (bool, error) {
|
||||
switch p.hdr.typ {
|
||||
case Syn:
|
||||
c.ack = p.hdr.seq
|
||||
if err := c.sendState(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
case State:
|
||||
select {
|
||||
case <-c.ready:
|
||||
default:
|
||||
c.ack = p.hdr.seq - 1
|
||||
close(c.ready)
|
||||
}
|
||||
var remaining []sent
|
||||
for _, s := range c.unacked {
|
||||
if seqLess(p.hdr.ack, s.seq) {
|
||||
remaining = append(remaining, s)
|
||||
}
|
||||
}
|
||||
c.unacked = remaining
|
||||
case Data:
|
||||
seq, payload := p.hdr.seq, p.payload
|
||||
if seq == c.ack+1 {
|
||||
c.deliver(payload)
|
||||
c.ack++
|
||||
for {
|
||||
if p, ok := c.reorder[c.ack+1]; ok {
|
||||
c.deliver(p)
|
||||
delete(c.reorder, c.ack+1)
|
||||
c.ack++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if seqLess(c.ack+1, seq) {
|
||||
c.reorder[seq] = payload
|
||||
}
|
||||
if err := c.sendState(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
case Fin, Reset:
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Conn) shutdown() {
|
||||
if c.pending != nil {
|
||||
c.pending.reply <- readResp{0, io.EOF}
|
||||
}
|
||||
select {
|
||||
case <-c.ready:
|
||||
c.send(Fin, nil)
|
||||
default:
|
||||
}
|
||||
c.unacked = nil
|
||||
c.cancel()
|
||||
c.sock.removeConn(c.recvID)
|
||||
}
|
||||
|
||||
func (c *Conn) run() {
|
||||
c.reorder = make(map[uint16][]byte)
|
||||
c.rto = initRTO
|
||||
|
||||
if c.recvID < c.sendID {
|
||||
c.seq = 1
|
||||
if err := c.sendSyn(); err != nil {
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(retransCheck * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case p := <-c.in:
|
||||
done, err := c.recv(p)
|
||||
if err != nil {
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
if done {
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
case req := <-c.reads:
|
||||
if c.pending != nil {
|
||||
req.reply <- readResp{0, io.ErrNoProgress}
|
||||
continue
|
||||
}
|
||||
c.pending = &req
|
||||
c.flush()
|
||||
case req := <-c.writes:
|
||||
if err := c.send(Data, req.data); err != nil {
|
||||
req.reply <- err
|
||||
c.shutdown()
|
||||
return
|
||||
}
|
||||
req.reply <- nil
|
||||
case <-c.closeReq:
|
||||
c.shutdown()
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.retrans()
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(p []byte) (int, error) {
|
||||
reply := make(chan readResp, 1)
|
||||
select {
|
||||
case c.reads <- readReq{p, reply}:
|
||||
r := <-reply
|
||||
return r.n, r.err
|
||||
case <-c.ctx.Done():
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
reply := make(chan error, 1)
|
||||
select {
|
||||
case c.writes <- writeReq{p, reply}:
|
||||
err := <-reply
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
case <-c.ctx.Done():
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
select {
|
||||
case c.closeReq <- struct{}{}:
|
||||
case <-c.ctx.Done():
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func timestamp() uint32 {
|
||||
return uint32(time.Now().UnixMicro() & 0xFFFFFFFF)
|
||||
}
|
||||
|
||||
func seqLess(a, b uint16) bool {
|
||||
return int16(a-b) < 0
|
||||
}
|
||||
211
utp/socket.go
Normal file
211
utp/socket.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package utp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
Data = 0
|
||||
Fin = 1
|
||||
State = 2
|
||||
Reset = 3
|
||||
Syn = 4
|
||||
)
|
||||
|
||||
const headerSize = 20
|
||||
|
||||
type header struct {
|
||||
typ uint8
|
||||
ver uint8
|
||||
ext uint8
|
||||
connID uint16
|
||||
timestamp uint32
|
||||
timeDiff uint32
|
||||
wnd uint32
|
||||
seq uint16
|
||||
ack uint16
|
||||
}
|
||||
|
||||
func encode(h *header, buf []byte) {
|
||||
buf[0] = (h.typ << 4) | (h.ver & 0x0F)
|
||||
buf[1] = h.ext
|
||||
binary.BigEndian.PutUint16(buf[2:], h.connID)
|
||||
binary.BigEndian.PutUint32(buf[4:], h.timestamp)
|
||||
binary.BigEndian.PutUint32(buf[8:], h.timeDiff)
|
||||
binary.BigEndian.PutUint32(buf[12:], h.wnd)
|
||||
binary.BigEndian.PutUint16(buf[16:], h.seq)
|
||||
binary.BigEndian.PutUint16(buf[18:], h.ack)
|
||||
}
|
||||
|
||||
func decode(buf []byte) header {
|
||||
return header{
|
||||
typ: (buf[0] >> 4) & 0x0F,
|
||||
ver: buf[0] & 0x0F,
|
||||
ext: buf[1],
|
||||
connID: binary.BigEndian.Uint16(buf[2:]),
|
||||
timestamp: binary.BigEndian.Uint32(buf[4:]),
|
||||
timeDiff: binary.BigEndian.Uint32(buf[8:]),
|
||||
wnd: binary.BigEndian.Uint32(buf[12:]),
|
||||
seq: binary.BigEndian.Uint16(buf[16:]),
|
||||
ack: binary.BigEndian.Uint16(buf[18:]),
|
||||
}
|
||||
}
|
||||
|
||||
type Socket struct {
|
||||
conn *net.UDPConn
|
||||
mu sync.RWMutex
|
||||
conns map[uint16]*Conn
|
||||
accepts chan *Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type packet struct {
|
||||
hdr header
|
||||
payload []byte
|
||||
addr *net.UDPAddr
|
||||
}
|
||||
|
||||
func New(addr string) (*Socket, error) {
|
||||
a, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &Socket{
|
||||
conn: conn,
|
||||
conns: make(map[uint16]*Conn),
|
||||
accepts: make(chan *Conn, 16),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Socket) Start() {
|
||||
go s.reader()
|
||||
}
|
||||
|
||||
func (s *Socket) reader() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := s.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n < headerSize {
|
||||
continue
|
||||
}
|
||||
hdr := decode(buf)
|
||||
payload := make([]byte, n-headerSize)
|
||||
copy(payload, buf[headerSize:n])
|
||||
|
||||
if hdr.typ == Syn {
|
||||
s.handleSyn(hdr, payload, addr)
|
||||
} else {
|
||||
s.dispatch(hdr, payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) {
|
||||
c := s.newConn(hdr.connID, addr, false)
|
||||
|
||||
s.mu.Lock()
|
||||
s.conns[c.recvID] = c
|
||||
s.mu.Unlock()
|
||||
|
||||
go c.run()
|
||||
c.in <- packet{hdr, payload, addr}
|
||||
|
||||
select {
|
||||
case s.accepts <- c:
|
||||
default:
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) dispatch(hdr header, payload []byte) {
|
||||
s.mu.RLock()
|
||||
c := s.conns[hdr.connID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if c != nil {
|
||||
select {
|
||||
case c.in <- packet{hdr, payload, nil}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) removeConn(id uint16) {
|
||||
s.mu.Lock()
|
||||
delete(s.conns, id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn {
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
c := &Conn{
|
||||
sock: s,
|
||||
addr: addr,
|
||||
in: make(chan packet, 16),
|
||||
reads: make(chan readReq),
|
||||
writes: make(chan writeReq),
|
||||
closeReq: make(chan struct{}),
|
||||
ready: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
if initiator {
|
||||
c.recvID = randUint16()
|
||||
c.sendID = c.recvID + 1
|
||||
} else {
|
||||
c.recvID = peerID + 1
|
||||
c.sendID = peerID
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *Socket) DialContext(ctx context.Context, addr *net.UDPAddr) (*Conn, error) {
|
||||
c := s.newConn(0, addr, true)
|
||||
|
||||
s.mu.Lock()
|
||||
s.conns[c.recvID] = c
|
||||
s.mu.Unlock()
|
||||
|
||||
go c.run()
|
||||
|
||||
select {
|
||||
case <-c.ready:
|
||||
return c, nil
|
||||
case <-ctx.Done():
|
||||
c.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) Accept() *Conn {
|
||||
return <-s.accepts
|
||||
}
|
||||
|
||||
func (s *Socket) Close() {
|
||||
s.conn.Close()
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
func randUint16() uint16 {
|
||||
var b [2]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return binary.BigEndian.Uint16(b[:])
|
||||
}
|
||||
82
utp/utp_test.go
Normal file
82
utp/utp_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package utp
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
h header
|
||||
}{
|
||||
{
|
||||
name: "syn",
|
||||
h: header{
|
||||
typ: Syn,
|
||||
ver: 1,
|
||||
ext: 0,
|
||||
connID: 1234,
|
||||
timestamp: 12345678,
|
||||
timeDiff: 1000,
|
||||
wnd: 65535,
|
||||
seq: 1,
|
||||
ack: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "data",
|
||||
h: header{
|
||||
typ: Data,
|
||||
ver: 1,
|
||||
connID: 5678,
|
||||
timestamp: 99999999,
|
||||
timeDiff: 500,
|
||||
wnd: 32768,
|
||||
seq: 100,
|
||||
ack: 99,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "state",
|
||||
h: header{
|
||||
typ: State,
|
||||
ver: 1,
|
||||
connID: 1234,
|
||||
wnd: 65535,
|
||||
seq: 1,
|
||||
ack: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := make([]byte, headerSize)
|
||||
encode(&tt.h, buf)
|
||||
got := decode(buf)
|
||||
if got != tt.h {
|
||||
t.Errorf("got %+v, want %+v", got, tt.h)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeqLess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a, b uint16
|
||||
want bool
|
||||
}{
|
||||
{"1 < 2", 1, 2, true},
|
||||
{"2 > 1", 2, 1, false},
|
||||
{"wrap 0xFFFF < 0", 0xFFFF, 0, true},
|
||||
{"wrap 0 > 0xFFFF", 0, 0xFFFF, false},
|
||||
{"half 0x7FFF > 0", 0x7FFF, 0, false},
|
||||
{"half 0 < 0x7FFF", 0, 0x7FFF, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := seqLess(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("seqLess(%d, %d) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user