first commit

This commit is contained in:
2026-01-19 21:13:01 +09:00
commit c70d24be5c
28 changed files with 3674 additions and 0 deletions

305
utp/conn.go Normal file
View File

@@ -0,0 +1,305 @@
package utp
import (
"context"
"io"
"net"
"time"
)
const (
initRTO = 1000 // ms
retransCheck = 500 // ms
defaultWnd = 64 * 1024
)
type Conn struct {
sock *Socket
addr *net.UDPAddr
recvID uint16
sendID uint16
in chan packet
reads chan readReq
writes chan writeReq
closeReq chan struct{}
ready chan struct{}
ctx context.Context
cancel context.CancelFunc
seq uint16
ack uint16
unacked []sent
reorder map[uint16][]byte
rest []byte
pending *readReq
rto uint32
}
type readReq struct {
p []byte
reply chan readResp
}
type readResp struct {
n int
err error
}
type writeReq struct {
data []byte
reply chan error
}
type sent struct {
seq uint16
data []byte
at time.Time
}
func (c *Conn) flush() {
if c.pending == nil || len(c.rest) == 0 {
return
}
n := copy(c.pending.p, c.rest)
c.rest = c.rest[n:]
c.pending.reply <- readResp{n, nil}
c.pending = nil
}
func (c *Conn) deliver(payload []byte) {
c.rest = append(c.rest, payload...)
c.flush()
}
func (c *Conn) send(typ uint8, payload []byte) error {
c.seq++
h := &header{
typ: typ,
ver: 1,
connID: c.sendID,
timestamp: timestamp(),
wnd: defaultWnd,
seq: c.seq,
ack: c.ack,
}
buf := make([]byte, headerSize+len(payload))
encode(h, buf)
copy(buf[headerSize:], payload)
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
return err
}
if typ == Data || typ == Syn || typ == Fin {
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
}
return nil
}
func (c *Conn) sendState() error {
h := &header{
typ: State,
ver: 1,
connID: c.sendID,
timestamp: timestamp(),
wnd: defaultWnd,
seq: c.seq,
ack: c.ack,
}
buf := make([]byte, headerSize)
encode(h, buf)
_, err := c.sock.conn.WriteToUDP(buf, c.addr)
return err
}
func (c *Conn) sendSyn() error {
c.seq++
h := &header{
typ: Syn,
ver: 1,
connID: c.recvID,
timestamp: timestamp(),
wnd: defaultWnd,
seq: c.seq,
ack: c.ack,
}
buf := make([]byte, headerSize)
encode(h, buf)
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
return err
}
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
return nil
}
func (c *Conn) retrans() error {
select {
case <-c.ctx.Done():
return nil
default:
}
now := time.Now()
for i := range c.unacked {
p := &c.unacked[i]
if now.Sub(p.at) > time.Duration(c.rto)*time.Millisecond {
if _, err := c.sock.conn.WriteToUDP(p.data, c.addr); err != nil {
return err
}
p.at = now
}
}
return nil
}
func (c *Conn) recv(p packet) (bool, error) {
switch p.hdr.typ {
case Syn:
c.ack = p.hdr.seq
if err := c.sendState(); err != nil {
return false, err
}
case State:
select {
case <-c.ready:
default:
c.ack = p.hdr.seq - 1
close(c.ready)
}
var remaining []sent
for _, s := range c.unacked {
if seqLess(p.hdr.ack, s.seq) {
remaining = append(remaining, s)
}
}
c.unacked = remaining
case Data:
seq, payload := p.hdr.seq, p.payload
if seq == c.ack+1 {
c.deliver(payload)
c.ack++
for {
if p, ok := c.reorder[c.ack+1]; ok {
c.deliver(p)
delete(c.reorder, c.ack+1)
c.ack++
} else {
break
}
}
} else if seqLess(c.ack+1, seq) {
c.reorder[seq] = payload
}
if err := c.sendState(); err != nil {
return false, err
}
case Fin, Reset:
return true, nil
}
return false, nil
}
func (c *Conn) shutdown() {
if c.pending != nil {
c.pending.reply <- readResp{0, io.EOF}
}
select {
case <-c.ready:
c.send(Fin, nil)
default:
}
c.unacked = nil
c.cancel()
c.sock.removeConn(c.recvID)
}
func (c *Conn) run() {
c.reorder = make(map[uint16][]byte)
c.rto = initRTO
if c.recvID < c.sendID {
c.seq = 1
if err := c.sendSyn(); err != nil {
c.shutdown()
return
}
}
ticker := time.NewTicker(retransCheck * time.Millisecond)
defer ticker.Stop()
for {
select {
case p := <-c.in:
done, err := c.recv(p)
if err != nil {
c.shutdown()
return
}
if done {
c.shutdown()
return
}
case req := <-c.reads:
if c.pending != nil {
req.reply <- readResp{0, io.ErrNoProgress}
continue
}
c.pending = &req
c.flush()
case req := <-c.writes:
if err := c.send(Data, req.data); err != nil {
req.reply <- err
c.shutdown()
return
}
req.reply <- nil
case <-c.closeReq:
c.shutdown()
return
case <-ticker.C:
c.retrans()
case <-c.ctx.Done():
return
}
}
}
func (c *Conn) Read(p []byte) (int, error) {
reply := make(chan readResp, 1)
select {
case c.reads <- readReq{p, reply}:
r := <-reply
return r.n, r.err
case <-c.ctx.Done():
return 0, io.EOF
}
}
func (c *Conn) Write(p []byte) (int, error) {
reply := make(chan error, 1)
select {
case c.writes <- writeReq{p, reply}:
err := <-reply
if err != nil {
return 0, err
}
return len(p), nil
case <-c.ctx.Done():
return 0, io.ErrClosedPipe
}
}
func (c *Conn) Close() error {
select {
case c.closeReq <- struct{}{}:
case <-c.ctx.Done():
}
return nil
}
func timestamp() uint32 {
return uint32(time.Now().UnixMicro() & 0xFFFFFFFF)
}
func seqLess(a, b uint16) bool {
return int16(a-b) < 0
}

211
utp/socket.go Normal file
View File

@@ -0,0 +1,211 @@
package utp
import (
"context"
"crypto/rand"
"encoding/binary"
"net"
"sync"
)
const (
Data = 0
Fin = 1
State = 2
Reset = 3
Syn = 4
)
const headerSize = 20
type header struct {
typ uint8
ver uint8
ext uint8
connID uint16
timestamp uint32
timeDiff uint32
wnd uint32
seq uint16
ack uint16
}
func encode(h *header, buf []byte) {
buf[0] = (h.typ << 4) | (h.ver & 0x0F)
buf[1] = h.ext
binary.BigEndian.PutUint16(buf[2:], h.connID)
binary.BigEndian.PutUint32(buf[4:], h.timestamp)
binary.BigEndian.PutUint32(buf[8:], h.timeDiff)
binary.BigEndian.PutUint32(buf[12:], h.wnd)
binary.BigEndian.PutUint16(buf[16:], h.seq)
binary.BigEndian.PutUint16(buf[18:], h.ack)
}
func decode(buf []byte) header {
return header{
typ: (buf[0] >> 4) & 0x0F,
ver: buf[0] & 0x0F,
ext: buf[1],
connID: binary.BigEndian.Uint16(buf[2:]),
timestamp: binary.BigEndian.Uint32(buf[4:]),
timeDiff: binary.BigEndian.Uint32(buf[8:]),
wnd: binary.BigEndian.Uint32(buf[12:]),
seq: binary.BigEndian.Uint16(buf[16:]),
ack: binary.BigEndian.Uint16(buf[18:]),
}
}
type Socket struct {
conn *net.UDPConn
mu sync.RWMutex
conns map[uint16]*Conn
accepts chan *Conn
ctx context.Context
cancel context.CancelFunc
}
type packet struct {
hdr header
payload []byte
addr *net.UDPAddr
}
func New(addr string) (*Socket, error) {
a, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", a)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
s := &Socket{
conn: conn,
conns: make(map[uint16]*Conn),
accepts: make(chan *Conn, 16),
ctx: ctx,
cancel: cancel,
}
return s, nil
}
func (s *Socket) Start() {
go s.reader()
}
func (s *Socket) reader() {
buf := make([]byte, 65535)
for {
n, addr, err := s.conn.ReadFromUDP(buf)
if err != nil {
return
}
if n < headerSize {
continue
}
hdr := decode(buf)
payload := make([]byte, n-headerSize)
copy(payload, buf[headerSize:n])
if hdr.typ == Syn {
s.handleSyn(hdr, payload, addr)
} else {
s.dispatch(hdr, payload)
}
}
}
func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) {
c := s.newConn(hdr.connID, addr, false)
s.mu.Lock()
s.conns[c.recvID] = c
s.mu.Unlock()
go c.run()
c.in <- packet{hdr, payload, addr}
select {
case s.accepts <- c:
default:
c.Close()
}
}
func (s *Socket) dispatch(hdr header, payload []byte) {
s.mu.RLock()
c := s.conns[hdr.connID]
s.mu.RUnlock()
if c != nil {
select {
case c.in <- packet{hdr, payload, nil}:
default:
}
}
}
func (s *Socket) removeConn(id uint16) {
s.mu.Lock()
delete(s.conns, id)
s.mu.Unlock()
}
func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn {
ctx, cancel := context.WithCancel(s.ctx)
c := &Conn{
sock: s,
addr: addr,
in: make(chan packet, 16),
reads: make(chan readReq),
writes: make(chan writeReq),
closeReq: make(chan struct{}),
ready: make(chan struct{}),
ctx: ctx,
cancel: cancel,
}
if initiator {
c.recvID = randUint16()
c.sendID = c.recvID + 1
} else {
c.recvID = peerID + 1
c.sendID = peerID
}
return c
}
func (s *Socket) DialContext(ctx context.Context, addr *net.UDPAddr) (*Conn, error) {
c := s.newConn(0, addr, true)
s.mu.Lock()
s.conns[c.recvID] = c
s.mu.Unlock()
go c.run()
select {
case <-c.ready:
return c, nil
case <-ctx.Done():
c.Close()
return nil, ctx.Err()
}
}
func (s *Socket) Accept() *Conn {
return <-s.accepts
}
func (s *Socket) Close() {
s.conn.Close()
s.cancel()
}
func randUint16() uint16 {
var b [2]byte
if _, err := rand.Read(b[:]); err != nil {
panic(err)
}
return binary.BigEndian.Uint16(b[:])
}

82
utp/utp_test.go Normal file
View File

@@ -0,0 +1,82 @@
package utp
import "testing"
func TestEncodeDecode(t *testing.T) {
tests := []struct {
name string
h header
}{
{
name: "syn",
h: header{
typ: Syn,
ver: 1,
ext: 0,
connID: 1234,
timestamp: 12345678,
timeDiff: 1000,
wnd: 65535,
seq: 1,
ack: 0,
},
},
{
name: "data",
h: header{
typ: Data,
ver: 1,
connID: 5678,
timestamp: 99999999,
timeDiff: 500,
wnd: 32768,
seq: 100,
ack: 99,
},
},
{
name: "state",
h: header{
typ: State,
ver: 1,
connID: 1234,
wnd: 65535,
seq: 1,
ack: 1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := make([]byte, headerSize)
encode(&tt.h, buf)
got := decode(buf)
if got != tt.h {
t.Errorf("got %+v, want %+v", got, tt.h)
}
})
}
}
func TestSeqLess(t *testing.T) {
tests := []struct {
name string
a, b uint16
want bool
}{
{"1 < 2", 1, 2, true},
{"2 > 1", 2, 1, false},
{"wrap 0xFFFF < 0", 0xFFFF, 0, true},
{"wrap 0 > 0xFFFF", 0, 0xFFFF, false},
{"half 0x7FFF > 0", 0x7FFF, 0, false},
{"half 0 < 0x7FFF", 0, 0x7FFF, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := seqLess(tt.a, tt.b)
if got != tt.want {
t.Errorf("seqLess(%d, %d) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}