storrent/utp/conn.go
2026-01-19 21:13:01 +09:00

306 lines
5.2 KiB
Go

package utp
import (
"context"
"io"
"net"
"time"
)
// Tuning constants for the connection state machine.
const (
// initRTO is the value run stores in Conn.rto at startup; retrans
// interprets rto as a number of milliseconds.
initRTO = 1000 // ms
// retransCheck is the period, in milliseconds, of the ticker in run that
// triggers a retransmission scan.
retransCheck = 500 // ms
// defaultWnd is the value written into the wnd field of every outgoing
// header (presumably a byte count — confirm against the header encoding).
defaultWnd = 64 * 1024
)
// Conn is a single connection multiplexed over a shared Socket.
//
// The protocol state (seq, ack, unacked, reorder, rest, pending, rto) is
// read and written by run and the helpers it calls; Read, Write and Close
// talk to run through the reads, writes and closeReq channels.
type Conn struct {
// sock is the owning socket: packets are written through sock.conn and
// shutdown unregisters the connection with sock.removeConn.
sock *Socket
// addr is the remote UDP address every packet is written to.
addr *net.UDPAddr
// recvID is the connection ID placed in the SYN header and passed to
// sock.removeConn on shutdown.
recvID uint16
// sendID is the connection ID placed in the header of every non-SYN packet.
sendID uint16
// in carries incoming packets to run.
in chan packet
// reads carries Read requests to run.
reads chan readReq
// writes carries Write requests to run.
writes chan writeReq
// closeReq carries the Close signal to run.
closeReq chan struct{}
// ready is closed by recv when the first State packet arrives; shutdown
// only sends a Fin once it is closed.
ready chan struct{}
// ctx is done once the connection is finished; Read, Write and Close
// stop waiting when it is.
ctx context.Context
// cancel cancels ctx; shutdown calls it.
cancel context.CancelFunc
// seq is the sequence number of the most recent packet built by send or
// sendSyn (both increment it before use).
seq uint16
// ack is the sequence number of the last in-order packet accepted from
// the peer; it is echoed in the ack field of outgoing headers.
ack uint16
// unacked holds sent Syn/Data/Fin packets the peer has not acknowledged
// yet; retrans resends them after rto.
unacked []sent
// reorder buffers Data payloads that arrived ahead of ack+1, keyed by
// their sequence number.
reorder map[uint16][]byte
// rest holds received bytes not yet copied out to a reader.
rest []byte
// pending is the Read request currently waiting for data, or nil.
pending *readReq
// rto is the retransmission timeout in milliseconds.
rto uint32
}
// readReq is a Read call handed to the run loop.
type readReq struct {
// p is the caller's destination buffer; flush copies received bytes into it.
p []byte
// reply receives exactly one readResp with the outcome of the read.
reply chan readResp
}
// readResp is the result of a readReq, mirroring Read's return values.
type readResp struct {
// n is the number of bytes copied into the request's buffer.
n int
// err is nil on success, or the error Read should return.
err error
}
// writeReq is a Write call handed to the run loop.
type writeReq struct {
// data is the payload to send as a single Data packet.
data []byte
// reply receives the send error, or nil once the packet was written.
reply chan error
}
// sent records a transmitted packet that is awaiting acknowledgement.
type sent struct {
// seq is the sequence number the packet was sent with.
seq uint16
// data is the full encoded packet (header plus payload) as written to the
// socket; retrans resends it unchanged.
data []byte
// at is the time of the most recent transmission; retrans updates it.
at time.Time
}
// flush completes the pending read, if any, with as many buffered bytes as
// fit in its buffer. It is a no-op when no read is waiting or when nothing
// has been buffered.
func (c *Conn) flush() {
	req := c.pending
	if req == nil {
		return
	}
	if len(c.rest) == 0 {
		return
	}

	copied := copy(req.p, c.rest)
	// Keep whatever did not fit for the next read.
	c.rest = c.rest[copied:]
	req.reply <- readResp{n: copied, err: nil}
	c.pending = nil
}
// deliver appends an in-order payload to the receive buffer and then tries
// to satisfy the pending read from it.
func (c *Conn) deliver(payload []byte) {
	buffered := append(c.rest, payload...)
	c.rest = buffered
	c.flush()
}
// send builds a packet of the given type around payload, assigns it the next
// sequence number and writes it to the peer.
//
// Syn, Data and Fin packets are remembered in c.unacked after a successful
// write so that retrans can resend them. The sequence number is advanced
// even when the write fails. The returned error is the one reported by the
// UDP write, if any.
func (c *Conn) send(typ uint8, payload []byte) error {
	c.seq++

	hdr := header{
		typ:       typ,
		ver:       1,
		connID:    c.sendID,
		timestamp: timestamp(),
		wnd:       defaultWnd,
		seq:       c.seq,
		ack:       c.ack,
	}
	pkt := make([]byte, headerSize+len(payload))
	encode(&hdr, pkt)
	copy(pkt[headerSize:], payload)

	_, err := c.sock.conn.WriteToUDP(pkt, c.addr)
	if err != nil {
		return err
	}

	switch typ {
	case Data, Syn, Fin:
		// These types must be acknowledged by the peer.
		c.unacked = append(c.unacked, sent{seq: c.seq, data: pkt, at: time.Now()})
	}
	return nil
}
// sendState writes a header-only State packet carrying the current seq and
// ack values. It does not advance c.seq and the packet is not tracked for
// retransmission. The returned error is the one reported by the UDP write.
func (c *Conn) sendState() error {
	hdr := header{
		typ:       State,
		ver:       1,
		connID:    c.sendID,
		timestamp: timestamp(),
		wnd:       defaultWnd,
		seq:       c.seq,
		ack:       c.ack,
	}
	pkt := make([]byte, headerSize)
	encode(&hdr, pkt)

	if _, err := c.sock.conn.WriteToUDP(pkt, c.addr); err != nil {
		return err
	}
	return nil
}
// sendSyn writes the header-only Syn packet that opens a connection.
//
// Unlike send, the header carries c.recvID as its connection ID. The
// sequence number is advanced first, and after a successful write the packet
// is queued in c.unacked so retrans can resend it. The returned error is the
// one reported by the UDP write.
func (c *Conn) sendSyn() error {
	c.seq++

	hdr := header{
		typ:       Syn,
		ver:       1,
		connID:    c.recvID,
		timestamp: timestamp(),
		wnd:       defaultWnd,
		seq:       c.seq,
		ack:       c.ack,
	}
	pkt := make([]byte, headerSize)
	encode(&hdr, pkt)

	_, err := c.sock.conn.WriteToUDP(pkt, c.addr)
	if err != nil {
		return err
	}
	c.unacked = append(c.unacked, sent{seq: c.seq, data: pkt, at: time.Now()})
	return nil
}
// retrans resends every unacknowledged packet whose last transmission is
// older than the retransmission timeout (c.rto milliseconds) and stamps it
// with the current time. It does nothing once the connection context is
// done. The first UDP write error aborts the scan and is returned.
func (c *Conn) retrans() error {
	if c.ctx.Err() != nil {
		return nil
	}

	timeout := time.Duration(c.rto) * time.Millisecond
	now := time.Now()
	for i := 0; i < len(c.unacked); i++ {
		entry := &c.unacked[i]
		if now.Sub(entry.at) <= timeout {
			continue
		}
		_, err := c.sock.conn.WriteToUDP(entry.data, c.addr)
		if err != nil {
			return err
		}
		entry.at = now
	}
	return nil
}
// recv applies one incoming packet to the connection state.
//
// It returns done=true when the peer ended the connection (Fin or Reset),
// and a non-nil error when sending the State reply failed.
//
//   - Syn: adopt the peer's sequence number as ack and reply with State.
//   - State: on the first one, derive ack from the peer's seq and close
//     c.ready; then drop every entry from c.unacked that the packet's ack
//     field covers.
//   - Data: deliver in-order payloads (draining any directly following
//     buffered ones), buffer payloads that are ahead, ignore older ones,
//     and reply with State in every case.
//
// NOTE(review): a duplicated or late Syn overwrites c.ack unconditionally;
// confirm whether the caller filters those before they reach this method.
func (c *Conn) recv(p packet) (bool, error) {
	hdr := p.hdr
	switch hdr.typ {
	case Fin, Reset:
		return true, nil

	case Syn:
		c.ack = hdr.seq
		return false, c.sendState()

	case State:
		select {
		case <-c.ready:
			// Already established.
		default:
			c.ack = hdr.seq - 1
			close(c.ready)
		}
		var outstanding []sent
		for i := range c.unacked {
			if entry := c.unacked[i]; seqLess(hdr.ack, entry.seq) {
				outstanding = append(outstanding, entry)
			}
		}
		c.unacked = outstanding
		return false, nil

	case Data:
		expected := c.ack + 1
		switch {
		case hdr.seq == expected:
			c.deliver(p.payload)
			c.ack++
			// Drain payloads that were waiting for this gap to close.
			for {
				buffered, ok := c.reorder[c.ack+1]
				if !ok {
					break
				}
				c.deliver(buffered)
				delete(c.reorder, c.ack+1)
				c.ack++
			}
		case seqLess(expected, hdr.seq):
			c.reorder[hdr.seq] = p.payload
		}
		return false, c.sendState()
	}
	return false, nil
}
// shutdown tears the connection down: it fails the pending read with io.EOF,
// sends a Fin if the connection had become ready, forgets unacknowledged
// packets, cancels the context and unregisters the connection from its
// socket.
func (c *Conn) shutdown() {
	if waiting := c.pending; waiting != nil {
		waiting.reply <- readResp{n: 0, err: io.EOF}
	}

	select {
	case <-c.ready:
		// Best effort: the connection is going away regardless of whether
		// the Fin could be written.
		_ = c.send(Fin, nil)
	default:
	}

	// Dropping the queue also stops the Fin from being retransmitted.
	c.unacked = nil
	c.cancel()
	c.sock.removeConn(c.recvID)
}
// run is the connection's event loop. It owns all protocol state and
// serialises incoming packets, Read/Write/Close requests and the periodic
// retransmission scan. It returns when the connection is shut down or its
// context is cancelled.
func (c *Conn) run() {
	c.reorder = make(map[uint16][]byte)
	c.rto = initRTO

	// A connection whose sendID follows its recvID opens with a Syn. The
	// comparison is done modulo 2^16 so that the pair (0xFFFF, 0) is still
	// recognised; a plain "<" would get that case backwards.
	// NOTE(review): assumes IDs are assigned as sendID = recvID+1 on the
	// initiating side — confirm against the code that creates Conn.
	if seqLess(c.recvID, c.sendID) {
		c.seq = 1
		if err := c.sendSyn(); err != nil {
			c.shutdown()
			return
		}
	}

	ticker := time.NewTicker(retransCheck * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case p := <-c.in:
			// A send failure and a peer-initiated close both end the
			// connection the same way.
			if done, err := c.recv(p); err != nil || done {
				c.shutdown()
				return
			}
		case req := <-c.reads:
			if c.pending != nil {
				// Only one read may be outstanding at a time.
				req.reply <- readResp{0, io.ErrNoProgress}
				continue
			}
			c.pending = &req
			c.flush()
		case req := <-c.writes:
			if err := c.send(Data, req.data); err != nil {
				req.reply <- err
				c.shutdown()
				return
			}
			req.reply <- nil
		case <-c.closeReq:
			c.shutdown()
			return
		case <-ticker.C:
			// Best effort: a failed resend is retried on the next tick.
			_ = c.retrans()
		case <-c.ctx.Done():
			// The context was cancelled from outside shutdown. Release a
			// reader that is blocked waiting for its reply; otherwise its
			// Read call would never return.
			if c.pending != nil {
				c.pending.reply <- readResp{0, io.EOF}
				c.pending = nil
			}
			return
		}
	}
}
// Read hands p to the connection's event loop and blocks until the loop
// reports how many bytes were copied into it. It returns io.EOF if the
// connection's context is already done before the request is accepted.
func (c *Conn) Read(p []byte) (int, error) {
	result := make(chan readResp, 1)
	request := readReq{p: p, reply: result}

	select {
	case <-c.ctx.Done():
		return 0, io.EOF
	case c.reads <- request:
	}

	resp := <-result
	return resp.n, resp.err
}
// Write hands p to the connection's event loop, which sends it as a single
// Data packet, and blocks until the loop reports the outcome. It returns
// len(p) on success, the send error on failure, and io.ErrClosedPipe if the
// connection's context is already done before the request is accepted.
func (c *Conn) Write(p []byte) (int, error) {
	result := make(chan error, 1)
	request := writeReq{data: p, reply: result}

	select {
	case <-c.ctx.Done():
		return 0, io.ErrClosedPipe
	case c.writes <- request:
	}

	if err := <-result; err != nil {
		return 0, err
	}
	return len(p), nil
}
// Close asks the event loop to shut the connection down. It returns once the
// request has been accepted or the connection's context is already done, and
// always reports a nil error.
func (c *Conn) Close() error {
	var signal struct{}
	select {
	case <-c.ctx.Done():
		// Already finished; nothing to request.
	case c.closeReq <- signal:
	}
	return nil
}
// timestamp returns the low 32 bits of the current Unix time in
// microseconds. The value wraps roughly every 71.6 minutes.
func timestamp() uint32 {
	micros := time.Now().UnixMicro()
	// Converting to uint32 keeps only the low 32 bits.
	return uint32(micros)
}
// seqLess reports whether sequence number a comes before b in 16-bit
// wrap-around arithmetic, i.e. whether a-b (mod 2^16) lies in the upper half
// of the range. Equal values are not less than each other.
func seqLess(a, b uint16) bool {
	diff := a - b
	return diff >= 0x8000
}