Add announcePeer feature.

TODO:
	fix sync write
	congestion control
	Request timeout
This commit is contained in:
Hojun-Cho 2026-01-22 13:55:29 +09:00
parent 60ec9c19ea
commit aa0c61139f
6 changed files with 102 additions and 68 deletions

View File

@ -63,10 +63,13 @@ func (dp *dialPool) stop() {
func (dp *dialPool) connectLoop(ctx context.Context) { func (dp *dialPool) connectLoop(ctx context.Context) {
dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout) dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout)
// TODO: handle tokens addrs, tokens, err := dp.dht.GetPeers(dhtCtx, dp.infoHash)
addrs, _, err := dp.dht.GetPeers(dhtCtx, dp.infoHash)
cancel() cancel()
if len(tokens) > 0 {
dp.dht.AnnouncePeer(ctx, dp.infoHash, tokens)
}
if err != nil { if err != nil {
select { select {
case dp.results <- dialResult{done: true, err: err}: case dp.results <- dialResult{done: true, err: err}:

View File

@ -26,7 +26,7 @@ const (
const ( const (
maxPeers = 8 maxPeers = 8
maxPending = 5 maxPending = 50
dialTimeout = 30 * time.Second dialTimeout = 30 * time.Second
retryInterval = 30 * time.Second retryInterval = 30 * time.Second
dhtTimeout = 5 * time.Minute dhtTimeout = 5 * time.Minute

View File

@ -8,6 +8,8 @@ import (
"strconv" "strconv"
"sync" "sync"
"time" "time"
"storrent/utp"
) )
const ( const (
@ -18,39 +20,27 @@ const (
type DHT struct { type DHT struct {
conn *net.UDPConn conn *net.UDPConn
in <-chan utp.Packet
id [20]byte id [20]byte
nodes []node nodes []node
pending map[string]chan *resp pending map[string]chan *resp
mu sync.Mutex mu sync.Mutex
} }
func New(port int) (*DHT, error) { func New(conn *net.UDPConn, in <-chan utp.Packet) *DHT {
addr := &net.UDPAddr{Port: port}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return nil, err
}
d := &DHT{ d := &DHT{
conn: conn, conn: conn,
in: in,
id: genNodeID(), id: genNodeID(),
pending: make(map[string]chan *resp), pending: make(map[string]chan *resp),
} }
go d.run() go d.run()
return d, nil return d
}
func (d *DHT) Close() {
d.conn.Close()
} }
func (d *DHT) run() { func (d *DHT) run() {
buf := make([]byte, 65535) for pkt := range d.in {
for { msg, err := decodeMsg(pkt.Raw)
n, _, err := d.conn.ReadFromUDP(buf)
if err != nil {
return
}
msg, err := decodeMsg(buf[:n])
if err != nil { if err != nil {
continue continue
} }
@ -237,3 +227,23 @@ func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
} }
return node{}, -1 return node{}, -1
} }
func (d *DHT) AnnouncePeer(ctx context.Context, h [20]byte, tokens map[string]string) {
var wg sync.WaitGroup
for addrStr, token := range tokens {
addr, err := net.ResolveUDPAddr("udp", addrStr)
if err != nil {
continue
}
wg.Add(1)
go func(addr *net.UDPAddr, token string) {
defer wg.Done()
d.query(ctx, addr, announcePeer, args{
InfoHash: h,
ImpliedPort: 1,
Token: token,
})
}(addr, token)
}
wg.Wait()
}

24
main.go
View File

@ -3,8 +3,8 @@ package main
import ( import (
"context" "context"
"flag" "flag"
"fmt"
"log" "log"
"net"
"github.com/knusbaum/go9p" "github.com/knusbaum/go9p"
@ -17,27 +17,25 @@ import (
func main() { func main() {
addr := flag.String("addr", ":5640", "9P listen address") addr := flag.String("addr", ":5640", "9P listen address")
dir := flag.String("dir", "./download", "download directory") dir := flag.String("dir", "./download", "download directory")
dhtPort := flag.Int("dht", 6881, "DHT port") udpPort := flag.Int("port", 6881, "UDP port (DHT + uTP)")
btPort := flag.Int("bt", 0, "BT listen port") btPort := flag.Int("bt", 0, "BT listen port (TCP)")
utpPort := flag.Int("utp", 6882, "uTP port")
bootstrap := flag.String("bootstrap", "router.bittorrent.com:6881", "DHT bootstrap node") bootstrap := flag.String("bootstrap", "router.bittorrent.com:6881", "DHT bootstrap node")
flag.Parse() flag.Parse()
d, err := dht.New(*dhtPort) conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: *udpPort})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if err := d.Bootstrap(context.Background(), *bootstrap); err != nil {
dhtCh := make(chan utp.Packet, 64)
sock := utp.New(conn, dhtCh)
dht := dht.New(conn, dhtCh)
sock.Start()
if err := dht.Bootstrap(context.Background(), *bootstrap); err != nil {
log.Fatal(err) log.Fatal(err)
} }
utpSock, err := utp.New(fmt.Sprintf(":%d", *utpPort)) m, err := client.NewManager(*dir, dht, *btPort, sock)
if err != nil {
log.Fatal(err)
}
utpSock.Start()
m, err := client.NewManager(*dir, d, *btPort, utpSock)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -19,7 +19,7 @@ type Conn struct {
recvID uint16 recvID uint16
sendID uint16 sendID uint16
in chan packet in chan Packet
reads chan readReq reads chan readReq
writes chan writeReq writes chan writeReq
closeReq chan struct{} closeReq chan struct{}
@ -150,10 +150,10 @@ func (c *Conn) retrans() error {
return nil return nil
} }
func (c *Conn) recv(p packet) (bool, error) { func (c *Conn) recv(p Packet) (bool, error) {
switch p.hdr.typ { switch p.Hdr.typ {
case Syn: case Syn:
c.ack = p.hdr.seq c.ack = p.Hdr.seq
if err := c.sendState(); err != nil { if err := c.sendState(); err != nil {
return false, err return false, err
} }
@ -161,18 +161,18 @@ func (c *Conn) recv(p packet) (bool, error) {
select { select {
case <-c.ready: case <-c.ready:
default: default:
c.ack = p.hdr.seq - 1 c.ack = p.Hdr.seq - 1
close(c.ready) close(c.ready)
} }
var remaining []sent var remaining []sent
for _, s := range c.unacked { for _, s := range c.unacked {
if seqLess(p.hdr.ack, s.seq) { if seqLess(p.Hdr.ack, s.seq) {
remaining = append(remaining, s) remaining = append(remaining, s)
} }
} }
c.unacked = remaining c.unacked = remaining
case Data: case Data:
seq, payload := p.hdr.seq, p.payload seq, payload := p.Hdr.seq, p.Payload
if seq == c.ack+1 { if seq == c.ack+1 {
c.deliver(payload) c.deliver(payload)
c.ack++ c.ack++

View File

@ -57,6 +57,7 @@ func decode(buf []byte) header {
type Socket struct { type Socket struct {
conn *net.UDPConn conn *net.UDPConn
dhtCh chan<- Packet
mu sync.RWMutex mu sync.RWMutex
conns map[uint16]*Conn conns map[uint16]*Conn
accepts chan *Conn accepts chan *Conn
@ -64,30 +65,24 @@ type Socket struct {
cancel context.CancelFunc cancel context.CancelFunc
} }
type packet struct { type Packet struct {
hdr header Hdr header
payload []byte Payload []byte
addr *net.UDPAddr Raw []byte
Addr *net.UDPAddr
} }
func New(addr string) (*Socket, error) { func New(conn *net.UDPConn, dhtCh chan<- Packet) *Socket {
a, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", a)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &Socket{ s := &Socket{
conn: conn, conn: conn,
dhtCh: dhtCh,
conns: make(map[uint16]*Conn), conns: make(map[uint16]*Conn),
accepts: make(chan *Conn, 16), accepts: make(chan *Conn, 16),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
return s, nil return s
} }
func (s *Socket) Start() { func (s *Socket) Start() {
@ -101,30 +96,42 @@ func (s *Socket) reader() {
if err != nil { if err != nil {
return return
} }
if n == 0 {
continue
}
data := append([]byte(nil), buf[:n]...)
if data[0] == 'd' {
select {
case s.dhtCh <- Packet{Raw: data, Addr: addr}:
default:
}
continue
}
if n < headerSize { if n < headerSize {
continue continue
} }
hdr := decode(buf)
payload := make([]byte, n-headerSize)
copy(payload, buf[headerSize:n])
if hdr.typ == Syn { pkt := NewPacket(data, addr)
s.handleSyn(hdr, payload, addr) if pkt.Hdr.typ == Syn {
s.handleSyn(pkt)
} else { } else {
s.dispatch(hdr, payload) s.dispatch(pkt)
} }
} }
} }
func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) { func (s *Socket) handleSyn(pkt Packet) {
c := s.newConn(hdr.connID, addr, false) c := s.newConn(pkt.Hdr.connID, pkt.Addr, false)
s.mu.Lock() s.mu.Lock()
s.conns[c.recvID] = c s.conns[c.recvID] = c
s.mu.Unlock() s.mu.Unlock()
go c.run() go c.run()
c.in <- packet{hdr, payload, addr} c.in <- pkt
select { select {
case s.accepts <- c: case s.accepts <- c:
@ -133,14 +140,14 @@ func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) {
} }
} }
func (s *Socket) dispatch(hdr header, payload []byte) { func (s *Socket) dispatch(pkt Packet) {
s.mu.RLock() s.mu.RLock()
c := s.conns[hdr.connID] c := s.conns[pkt.Hdr.connID]
s.mu.RUnlock() s.mu.RUnlock()
if c != nil { if c != nil {
select { select {
case c.in <- packet{hdr, payload, nil}: case c.in <- pkt:
default: default:
} }
} }
@ -157,7 +164,7 @@ func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn
c := &Conn{ c := &Conn{
sock: s, sock: s,
addr: addr, addr: addr,
in: make(chan packet, 16), in: make(chan Packet, 256),
reads: make(chan readReq), reads: make(chan readReq),
writes: make(chan writeReq), writes: make(chan writeReq),
closeReq: make(chan struct{}), closeReq: make(chan struct{}),
@ -202,6 +209,10 @@ func (s *Socket) Close() {
s.cancel() s.cancel()
} }
func (s *Socket) LocalPort() int {
return s.conn.LocalAddr().(*net.UDPAddr).Port
}
func randUint16() uint16 { func randUint16() uint16 {
var b [2]byte var b [2]byte
if _, err := rand.Read(b[:]); err != nil { if _, err := rand.Read(b[:]); err != nil {
@ -209,3 +220,15 @@ func randUint16() uint16 {
} }
return binary.BigEndian.Uint16(b[:]) return binary.BigEndian.Uint16(b[:])
} }
func NewPacket(data []byte, addr *net.UDPAddr) Packet {
pkt := Packet{
Raw: data,
Addr: addr,
}
if len(data) >= headerSize {
pkt.Hdr = decode(data)
pkt.Payload = data[headerSize:]
}
return pkt
}