package utp

import (
	"context"
	"io"
	"net"
	"time"
)

const (
	initRTO      = 1000      // initial retransmission timeout, ms
	retransCheck = 500       // interval between retransmission scans, ms
	defaultWnd   = 64 * 1024 // window size advertised in every outgoing header
)

// Conn is one uTP connection multiplexed over a shared Socket.
//
// All protocol state (seq, ack, unacked, reorder, rest, pending, rto) is owned
// by the goroutine executing run. Read, Write and Close never touch it
// directly; they hand requests to that goroutine over channels.
type Conn struct {
	sock   *Socket      // shared socket whose UDP conn carries our packets
	addr   *net.UDPAddr // remote peer address
	recvID uint16       // connection ID expected on inbound packets
	sendID uint16       // connection ID stamped on outbound packets

	in       chan packet   // inbound packets for this connection
	reads    chan readReq  // requests from Read
	writes   chan writeReq // requests from Write
	closeReq chan struct{} // requests from Close
	ready    chan struct{} // closed by recv on the first State packet

	ctx    context.Context // cancelled by shutdown (possibly also by the owner)
	cancel context.CancelFunc

	seq     uint16            // last sequence number sent
	ack     uint16            // last in-order sequence number received
	unacked []sent            // sent packets awaiting acknowledgement
	reorder map[uint16][]byte // out-of-order payloads keyed by sequence number
	rest    []byte            // delivered bytes not yet returned by Read
	pending *readReq          // Read currently waiting for data, if any
	rto     uint32            // retransmission timeout, ms
}

// readReq asks the run goroutine to fill p; the answer arrives on reply.
type readReq struct {
	p     []byte
	reply chan readResp
}

// readResp is the run goroutine's answer to a readReq.
type readResp struct {
	n   int
	err error
}

// writeReq asks the run goroutine to send data; the send error arrives on reply.
type writeReq struct {
	data  []byte
	reply chan error
}

// sent is a transmitted packet kept for retransmission until acknowledged.
type sent struct {
	seq  uint16
	data []byte    // the complete encoded packet, resent verbatim
	at   time.Time // last (re)transmission time
}

// isInitiator reports whether this side opened the connection (sent the SYN).
// run and recv both rely on this test.
func (c *Conn) isInitiator() bool {
	return c.recvID < c.sendID
}

// flush completes the pending Read, if any, with as many buffered bytes as fit
// in its slice. It does nothing when no Read is waiting or nothing is buffered.
func (c *Conn) flush() {
	if c.pending == nil || len(c.rest) == 0 {
		return
	}
	n := copy(c.pending.p, c.rest)
	c.rest = c.rest[n:]
	if len(c.rest) == 0 {
		// Release the drained backing array instead of appending onto its tail.
		c.rest = nil
	}
	c.pending.reply <- readResp{n, nil}
	c.pending = nil
}

// failPending answers the pending Read, if any, with io.EOF.
func (c *Conn) failPending() {
	if c.pending == nil {
		return
	}
	c.pending.reply <- readResp{0, io.EOF}
	c.pending = nil
}

// deliver appends an in-order payload to the read buffer and wakes a waiting Read.
func (c *Conn) deliver(payload []byte) {
	c.rest = append(c.rest, payload...)
	c.flush()
}

// makePacket encodes a header (current ack, fresh timestamp, default window)
// followed by payload into a newly allocated buffer.
func (c *Conn) makePacket(typ uint8, connID, seq uint16, payload []byte) []byte {
	h := &header{
		typ:       typ,
		ver:       1,
		connID:    connID,
		timestamp: timestamp(),
		wnd:       defaultWnd,
		seq:       seq,
		ack:       c.ack,
	}
	buf := make([]byte, headerSize+len(payload))
	encode(h, buf)
	copy(buf[headerSize:], payload)
	return buf
}

// transmit writes one encoded packet to the peer.
func (c *Conn) transmit(buf []byte) error {
	_, err := c.sock.conn.WriteToUDP(buf, c.addr)
	return err
}

// send assigns the next sequence number to a packet of type typ and transmits
// it. Data, Syn and Fin packets are kept in unacked for retransmission until
// the peer acknowledges them.
func (c *Conn) send(typ uint8, payload []byte) error {
	c.seq++
	buf := c.makePacket(typ, c.sendID, c.seq, payload)
	if err := c.transmit(buf); err != nil {
		return err
	}
	if typ == Data || typ == Syn || typ == Fin {
		c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
	}
	return nil
}

// sendState sends a State (acknowledgement-only) packet without consuming a
// sequence number.
//
// Its seq field carries the NEXT sequence number this side will use
// (c.seq+1, because send pre-increments), as in BEP 29. The initiator's
// handshake in recv relies on that: it sets ack = seq-1, so the acceptor's
// first Data packet (c.seq+1) is exactly ack+1. Advertising c.seq instead
// left the initiator one behind and stalled all acceptor-to-initiator data.
func (c *Conn) sendState() error {
	return c.transmit(c.makePacket(State, c.sendID, c.seq+1, nil))
}

// sendSyn opens the connection. It behaves like send(Syn, nil) except that the
// SYN is stamped with recvID rather than sendID (BEP 29: the SYN carries the
// initiator's receive ID).
func (c *Conn) sendSyn() error {
	c.seq++
	buf := c.makePacket(Syn, c.recvID, c.seq, nil)
	if err := c.transmit(buf); err != nil {
		return err
	}
	c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
	return nil
}

// retrans resends every unacknowledged packet whose last transmission is older
// than the RTO, and restarts its timer. It returns the first write error.
//
// NOTE(review): rto is never adapted and there is no retry limit, so a
// vanished peer is retried forever until Close or cancellation.
func (c *Conn) retrans() error {
	if c.ctx.Err() != nil {
		return nil
	}
	rto := time.Duration(c.rto) * time.Millisecond
	now := time.Now()
	for i := range c.unacked {
		s := &c.unacked[i]
		if now.Sub(s.at) <= rto {
			continue
		}
		if err := c.transmit(s.data); err != nil {
			return err
		}
		s.at = now
	}
	return nil
}

// pruneAcked drops every unacked packet at or before ack (a cumulative
// acknowledgement). It filters in place and zeroes the tail so the dropped
// buffers can be collected.
func (c *Conn) pruneAcked(ack uint16) {
	kept := c.unacked[:0]
	for _, s := range c.unacked {
		if seqLess(ack, s.seq) {
			kept = append(kept, s)
		}
	}
	for i := len(kept); i < len(c.unacked); i++ {
		c.unacked[i] = sent{}
	}
	c.unacked = kept
}

// acceptData handles an inbound Data payload.
//   - The next expected packet is delivered, followed by any buffered successors.
//   - A packet ahead of that is buffered in reorder.
//   - A packet at or behind ack is a duplicate and is dropped.
//
// NOTE(review): payload is stored in reorder without copying. This assumes the
// caller does not reuse the packet buffer; confirm against Socket.
func (c *Conn) acceptData(seq uint16, payload []byte) {
	next := c.ack + 1
	if seq != next {
		if seqLess(next, seq) {
			c.reorder[seq] = payload
		}
		return
	}
	// An early copy may have been buffered while ack was still unsettled;
	// drop it so it cannot linger (and be replayed after sequence wrap).
	delete(c.reorder, seq)
	c.deliver(payload)
	c.ack = seq
	for {
		buffered, ok := c.reorder[c.ack+1]
		if !ok {
			return
		}
		delete(c.reorder, c.ack+1)
		c.deliver(buffered)
		c.ack++
	}
}

// recv processes one inbound packet. It returns true when the peer ended the
// connection (Fin or Reset), and a non-nil error if sending our acknowledgement
// failed.
func (c *Conn) recv(p packet) (bool, error) {
	switch p.hdr.typ {
	case Syn:
		c.ack = p.hdr.seq
		return false, c.sendState()
	case State:
		select {
		case <-c.ready:
		default:
			// The first State is the handshake reply. Only the initiator
			// learns the peer's starting sequence from it (seq is the peer's
			// next seq). The acceptor's ack was already fixed by the SYN, and
			// overwriting it here would move it backwards or skip data.
			if c.isInitiator() {
				c.ack = p.hdr.seq - 1
			}
			close(c.ready)
		}
		c.pruneAcked(p.hdr.ack)
	case Data:
		// Data packets also carry a cumulative ack; honour it so piggybacked
		// acknowledgements stop retransmissions too.
		c.pruneAcked(p.hdr.ack)
		c.acceptData(p.hdr.seq, p.payload)
		return false, c.sendState()
	case Fin, Reset:
		return true, nil
	}
	return false, nil
}

// shutdown tears the connection down from inside the run goroutine.
//   - It fails any pending Read with io.EOF.
//   - It sends a FIN if the connection was established (ready closed).
//   - It cancels ctx and unregisters from the socket.
//
// NOTE(review): bytes still buffered in rest are discarded, because Read
// returns io.EOF as soon as ctx is cancelled.
func (c *Conn) shutdown() {
	c.failPending()
	select {
	case <-c.ready:
		// Best effort: the connection is going away regardless, and the FIN
		// is never retransmitted because unacked is cleared just below.
		_ = c.send(Fin, nil)
	default:
	}
	c.unacked = nil
	c.cancel()
	c.sock.removeConn(c.recvID)
}

// run is the connection's event loop and the sole owner of its protocol state.
// The initiator sends the SYN first. The loop then serves inbound packets,
// Read/Write/Close requests and periodic retransmission until the connection
// ends or ctx is cancelled.
func (c *Conn) run() {
	c.reorder = make(map[uint16][]byte)
	c.rto = initRTO
	if c.isInitiator() {
		c.seq = 1 // sendSyn pre-increments, so the SYN carries seq 2
		if err := c.sendSyn(); err != nil {
			c.shutdown()
			return
		}
	}
	ticker := time.NewTicker(retransCheck * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case p := <-c.in:
			done, err := c.recv(p)
			if err != nil || done {
				c.shutdown()
				return
			}
		case req := <-c.reads:
			if c.pending != nil {
				req.reply <- readResp{0, io.ErrNoProgress}
				continue
			}
			c.pending = &req
			c.flush()
		case req := <-c.writes:
			err := c.send(Data, req.data)
			req.reply <- err
			if err != nil {
				c.shutdown()
				return
			}
		case <-c.closeReq:
			c.shutdown()
			return
		case <-ticker.C:
			// A failed retransmission is a socket write failure, handled
			// the same way as a failed Write.
			if err := c.retrans(); err != nil {
				c.shutdown()
				return
			}
		case <-c.ctx.Done():
			// Cancelled from outside shutdown. A Read that was already
			// accepted is blocked on its reply, so it must still be answered.
			c.failPending()
			return
		}
	}
}

// Read implements io.Reader. It blocks until delivered data is available and
// returns at most len(p) bytes. Once the connection has shut down it returns
// io.EOF. Only one Read may be outstanding; a concurrent second Read fails
// with io.ErrNoProgress.
func (c *Conn) Read(p []byte) (int, error) {
	// Buffered so the run goroutine never blocks when answering.
	reply := make(chan readResp, 1)
	select {
	case c.reads <- readReq{p, reply}:
		r := <-reply
		return r.n, r.err
	case <-c.ctx.Done():
		return 0, io.EOF
	}
}

// Write sends p as a single Data packet. It returns len(p) once the packet has
// been handed to the socket, not when it is acknowledged. After shutdown it
// returns io.ErrClosedPipe.
//
// NOTE(review): p is not split into MTU-sized packets.
func (c *Conn) Write(p []byte) (int, error) {
	reply := make(chan error, 1)
	select {
	case c.writes <- writeReq{p, reply}:
		if err := <-reply; err != nil {
			return 0, err
		}
		return len(p), nil
	case <-c.ctx.Done():
		return 0, io.ErrClosedPipe
	}
}

// Close asks the run goroutine to shut the connection down. It returns once the
// request has been handed over or the connection is already closed, and it
// always returns nil.
func (c *Conn) Close() error {
	select {
	case c.closeReq <- struct{}{}:
	case <-c.ctx.Done():
	}
	return nil
}

// timestamp returns the current Unix time in microseconds, truncated to the
// 32 bits of the header's timestamp field.
func timestamp() uint32 {
	return uint32(time.Now().UnixMicro())
}

// seqLess reports whether a precedes b in 16-bit wrap-around sequence space.
func seqLess(a, b uint16) bool {
	return int16(a-b) < 0
}