306 lines
5.2 KiB
Go
306 lines
5.2 KiB
Go
package utp
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
)
|
|
|
|
// Protocol timing and window defaults.
const (
	// initRTO is the retransmission timeout, in milliseconds, that run
	// installs as c.rto on every new connection.
	initRTO = 1_000

	// retransCheck is the period, in milliseconds, of the ticker in run that
	// triggers a retrans pass.
	retransCheck = 500

	// defaultWnd is the window value placed in every outgoing header
	// (presumably a byte count — TODO confirm against the header definition).
	defaultWnd = 1 << 16
)
|
|
|
|
type Conn struct {
|
|
sock *Socket
|
|
addr *net.UDPAddr
|
|
recvID uint16
|
|
sendID uint16
|
|
|
|
in chan packet
|
|
reads chan readReq
|
|
writes chan writeReq
|
|
closeReq chan struct{}
|
|
ready chan struct{}
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
|
|
seq uint16
|
|
ack uint16
|
|
unacked []sent
|
|
reorder map[uint16][]byte
|
|
rest []byte
|
|
pending *readReq
|
|
rto uint32
|
|
}
|
|
|
|
// readReq is a Read call handed to the run loop, which parks it in
// Conn.pending until buffered data is available.
type readReq struct {
	// p is the caller's destination buffer; flush copies buffered bytes into it.
	p []byte

	// reply carries the result back to the blocked Read. Read allocates it
	// with a one-slot buffer, so the run loop's send does not block.
	reply chan readResp
}
|
|
|
|
// readResp is the run loop's answer to a readReq. It is constructed with
// unkeyed literals elsewhere, so field order matters.
type readResp struct {
	// n is the number of bytes copied into readReq.p.
	n int

	// err is nil on success; the visible code also sends io.EOF (connection
	// shut down) and io.ErrNoProgress (another Read already outstanding).
	err error
}
|
|
|
|
// writeReq is a Write call handed to the run loop. It is constructed with
// unkeyed literals elsewhere, so field order matters.
type writeReq struct {
	// data is sent as the payload of a single Data packet.
	data []byte

	// reply receives the result of sending that packet (nil on success).
	reply chan error
}
|
|
|
|
// sent records a transmitted packet that still awaits acknowledgement. It is
// constructed with unkeyed literals elsewhere, so field order matters.
type sent struct {
	// seq is the sequence number the packet was sent with; acks are matched
	// against it.
	seq uint16

	// data is the complete encoded packet (header plus payload), so retrans
	// can resend it verbatim.
	data []byte

	// at is when the packet was last (re)transmitted; compared with the RTO.
	at time.Time
}
|
|
|
|
func (c *Conn) flush() {
|
|
if c.pending == nil || len(c.rest) == 0 {
|
|
return
|
|
}
|
|
n := copy(c.pending.p, c.rest)
|
|
c.rest = c.rest[n:]
|
|
c.pending.reply <- readResp{n, nil}
|
|
c.pending = nil
|
|
}
|
|
|
|
func (c *Conn) deliver(payload []byte) {
|
|
c.rest = append(c.rest, payload...)
|
|
c.flush()
|
|
}
|
|
|
|
func (c *Conn) send(typ uint8, payload []byte) error {
|
|
c.seq++
|
|
h := &header{
|
|
typ: typ,
|
|
ver: 1,
|
|
connID: c.sendID,
|
|
timestamp: timestamp(),
|
|
wnd: defaultWnd,
|
|
seq: c.seq,
|
|
ack: c.ack,
|
|
}
|
|
buf := make([]byte, headerSize+len(payload))
|
|
encode(h, buf)
|
|
copy(buf[headerSize:], payload)
|
|
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
|
|
return err
|
|
}
|
|
if typ == Data || typ == Syn || typ == Fin {
|
|
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) sendState() error {
|
|
h := &header{
|
|
typ: State,
|
|
ver: 1,
|
|
connID: c.sendID,
|
|
timestamp: timestamp(),
|
|
wnd: defaultWnd,
|
|
seq: c.seq,
|
|
ack: c.ack,
|
|
}
|
|
buf := make([]byte, headerSize)
|
|
encode(h, buf)
|
|
_, err := c.sock.conn.WriteToUDP(buf, c.addr)
|
|
return err
|
|
}
|
|
|
|
func (c *Conn) sendSyn() error {
|
|
c.seq++
|
|
h := &header{
|
|
typ: Syn,
|
|
ver: 1,
|
|
connID: c.recvID,
|
|
timestamp: timestamp(),
|
|
wnd: defaultWnd,
|
|
seq: c.seq,
|
|
ack: c.ack,
|
|
}
|
|
buf := make([]byte, headerSize)
|
|
encode(h, buf)
|
|
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
|
|
return err
|
|
}
|
|
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) retrans() error {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return nil
|
|
default:
|
|
}
|
|
now := time.Now()
|
|
for i := range c.unacked {
|
|
p := &c.unacked[i]
|
|
if now.Sub(p.at) > time.Duration(c.rto)*time.Millisecond {
|
|
if _, err := c.sock.conn.WriteToUDP(p.data, c.addr); err != nil {
|
|
return err
|
|
}
|
|
p.at = now
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) recv(p packet) (bool, error) {
|
|
switch p.hdr.typ {
|
|
case Syn:
|
|
c.ack = p.hdr.seq
|
|
if err := c.sendState(); err != nil {
|
|
return false, err
|
|
}
|
|
case State:
|
|
select {
|
|
case <-c.ready:
|
|
default:
|
|
c.ack = p.hdr.seq - 1
|
|
close(c.ready)
|
|
}
|
|
var remaining []sent
|
|
for _, s := range c.unacked {
|
|
if seqLess(p.hdr.ack, s.seq) {
|
|
remaining = append(remaining, s)
|
|
}
|
|
}
|
|
c.unacked = remaining
|
|
case Data:
|
|
seq, payload := p.hdr.seq, p.payload
|
|
if seq == c.ack+1 {
|
|
c.deliver(payload)
|
|
c.ack++
|
|
for {
|
|
if p, ok := c.reorder[c.ack+1]; ok {
|
|
c.deliver(p)
|
|
delete(c.reorder, c.ack+1)
|
|
c.ack++
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
} else if seqLess(c.ack+1, seq) {
|
|
c.reorder[seq] = payload
|
|
}
|
|
if err := c.sendState(); err != nil {
|
|
return false, err
|
|
}
|
|
case Fin, Reset:
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (c *Conn) shutdown() {
|
|
if c.pending != nil {
|
|
c.pending.reply <- readResp{0, io.EOF}
|
|
}
|
|
select {
|
|
case <-c.ready:
|
|
c.send(Fin, nil)
|
|
default:
|
|
}
|
|
c.unacked = nil
|
|
c.cancel()
|
|
c.sock.removeConn(c.recvID)
|
|
}
|
|
|
|
func (c *Conn) run() {
|
|
c.reorder = make(map[uint16][]byte)
|
|
c.rto = initRTO
|
|
|
|
if c.recvID < c.sendID {
|
|
c.seq = 1
|
|
if err := c.sendSyn(); err != nil {
|
|
c.shutdown()
|
|
return
|
|
}
|
|
}
|
|
|
|
ticker := time.NewTicker(retransCheck * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case p := <-c.in:
|
|
done, err := c.recv(p)
|
|
if err != nil {
|
|
c.shutdown()
|
|
return
|
|
}
|
|
if done {
|
|
c.shutdown()
|
|
return
|
|
}
|
|
case req := <-c.reads:
|
|
if c.pending != nil {
|
|
req.reply <- readResp{0, io.ErrNoProgress}
|
|
continue
|
|
}
|
|
c.pending = &req
|
|
c.flush()
|
|
case req := <-c.writes:
|
|
if err := c.send(Data, req.data); err != nil {
|
|
req.reply <- err
|
|
c.shutdown()
|
|
return
|
|
}
|
|
req.reply <- nil
|
|
case <-c.closeReq:
|
|
c.shutdown()
|
|
return
|
|
case <-ticker.C:
|
|
c.retrans()
|
|
case <-c.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Conn) Read(p []byte) (int, error) {
|
|
reply := make(chan readResp, 1)
|
|
select {
|
|
case c.reads <- readReq{p, reply}:
|
|
r := <-reply
|
|
return r.n, r.err
|
|
case <-c.ctx.Done():
|
|
return 0, io.EOF
|
|
}
|
|
}
|
|
|
|
func (c *Conn) Write(p []byte) (int, error) {
|
|
reply := make(chan error, 1)
|
|
select {
|
|
case c.writes <- writeReq{p, reply}:
|
|
err := <-reply
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return len(p), nil
|
|
case <-c.ctx.Done():
|
|
return 0, io.ErrClosedPipe
|
|
}
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
select {
|
|
case c.closeReq <- struct{}{}:
|
|
case <-c.ctx.Done():
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// timestamp returns the current wall-clock time in microseconds since the
// Unix epoch, truncated to its low 32 bits, so it wraps about every 71.6
// minutes.
func timestamp() uint32 {
	micros := time.Now().UnixMicro()
	return uint32(micros) // conversion keeps exactly the low 32 bits
}
|
|
|
|
// seqLess reports whether sequence number a comes before b in the wrapping
// 16-bit sequence space, i.e. whether b is one of the 32768 values that
// follow a (modulo 2^16).
func seqLess(a, b uint16) bool {
	diff := int16(a - b)
	return diff < 0
}
|