// Package dht implements a minimal DHT node that can bootstrap from a known
// node and look up peers for an info hash.
package dht

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net"
	"slices"
	"strconv"
	"sync"
	"time"
)

const (
	queryTimeout = 5 * time.Second // upper bound on waiting for a single reply
	alpha        = 3               // max concurrent queries during a lookup
	maxQueries   = 32              // max queries issued by one GetPeers lookup
)

// DHT is a node that issues queries over a single UDP socket. Incoming
// messages that do not answer one of its outstanding queries are dropped by
// the read loop, so it does not serve other nodes.
type DHT struct {
	conn  *net.UDPConn
	id    [20]byte // this node's ID, sent with every query
	nodes []node   // routing table seeded by Bootstrap; guarded by mu
	// pending maps a transaction ID to the channel its reply is delivered on;
	// guarded by mu.
	pending map[string]chan *resp
	mu      sync.Mutex
}

// New listens for UDP on port (on all local addresses; port 0 lets the
// system choose one) and starts the background read loop. Call Close to
// release the socket and stop the loop.
func New(port int) (*DHT, error) {
	conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
	if err != nil {
		return nil, fmt.Errorf("listening on UDP port %d: %w", port, err)
	}
	d := &DHT{
		conn:    conn,
		id:      genNodeID(),
		pending: make(map[string]chan *resp),
	}
	go d.run()
	return d, nil
}

// Close closes the socket, which makes the read loop exit. Queries still
// waiting for a reply are not woken early; they end at their timeout or when
// their context is done. The error from closing the socket is discarded.
func (d *DHT) Close() {
	d.conn.Close()
}

// run is the read loop started by New. It decodes each datagram and delivers
// it to the query waiting on its transaction ID; datagrams that fail to
// decode or match no pending query are dropped. It returns on the first read
// error, e.g. once Close has closed the socket.
func (d *DHT) run() {
	buf := make([]byte, 65535)
	for {
		n, _, err := d.conn.ReadFromUDP(buf)
		if err != nil {
			return
		}
		// NOTE(review): buf is reused for the next datagram; this assumes
		// decodeMsg does not retain sub-slices of its input — confirm.
		// NOTE(review): the sender address is not compared with the node
		// that was queried, so any packet with a matching transaction ID is
		// accepted as the reply.
		m, err := decodeMsg(buf[:n])
		if err != nil {
			continue
		}
		d.mu.Lock()
		ch, ok := d.pending[m.T]
		delete(d.pending, m.T) // no-op when the ID is unknown
		d.mu.Unlock()
		if ok {
			// query registers channels with capacity 1 and the entry was just
			// removed, so this is the only send and cannot block the loop.
			ch <- &m.R
		}
	}
}

// xorDistance returns the bitwise XOR of a and b, the distance metric used
// to order nodes relative to a target ID.
func xorDistance(a, b [20]byte) [20]byte {
	var dist [20]byte
	for i := range dist {
		dist[i] = a[i] ^ b[i]
	}
	return dist
}

// compareDist orders two distances as big-endian unsigned integers,
// returning -1, 0 or +1.
func compareDist(a, b [20]byte) int {
	return bytes.Compare(a[:], b[:])
}

// query sends the query q with arguments a to addr and waits for the reply
// carrying the same transaction ID. It gives up after queryTimeout or when
// ctx is done, whichever comes first. The ID field of the sent arguments is
// set to this node's ID. On error the returned *resp is nil.
func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) {
	ctx, cancel := context.WithTimeout(ctx, queryTimeout)
	defer cancel()

	txid := genTxID()
	a.ID = d.id
	m := &msg{
		T: txid,
		Y: query,
		Q: q,
		A: a,
	}
	data, err := m.Encode()
	if err != nil {
		return nil, fmt.Errorf("encoding %s query: %w", q, err)
	}

	// Register before sending so a fast reply cannot arrive before run can
	// route it. Capacity 1 lets run deliver without blocking even if this
	// call has already given up.
	reply := make(chan *resp, 1)
	d.mu.Lock()
	d.pending[txid] = reply
	d.mu.Unlock()
	defer func() {
		d.mu.Lock()
		delete(d.pending, txid)
		d.mu.Unlock()
	}()

	if _, err := d.conn.WriteToUDP(data, addr); err != nil {
		return nil, fmt.Errorf("sending %s query to %v: %w", q, addr, err)
	}

	select {
	case r := <-reply:
		return r, nil
	case <-ctx.Done():
		return nil, fmt.Errorf("%s query to %v: %w", q, addr, ctx.Err())
	}
}

// Bootstrap seeds the routing table by sending a findNode query for this
// node's own ID to addr ("host:port"). Each address host resolves to is
// tried in turn until one replies; the nodes in that reply are appended to
// the routing table. If every attempt fails, the attempts' errors are
// returned joined together.
func (d *DHT) Bootstrap(ctx context.Context, addr string) error {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return err
	}
	p, err := strconv.Atoi(port)
	if err != nil {
		return fmt.Errorf("invalid port: %w", err)
	}
	if p <= 0 || p > 65535 {
		return fmt.Errorf("invalid port: %d", p)
	}

	// The resolver's context-aware lookup lets ctx cancel DNS resolution,
	// and it returns parsed IPs (keeping any IPv6 zone) rather than strings.
	ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
	if err != nil {
		return err
	}
	if len(ips) == 0 {
		return fmt.Errorf("no bootstrap nodes")
	}

	var errs []error
	for _, ip := range ips {
		uaddr := &net.UDPAddr{IP: ip.IP, Port: p, Zone: ip.Zone}
		r, qerr := d.query(ctx, uaddr, findNode, args{Target: d.id})
		if qerr != nil {
			errs = append(errs, qerr)
			if ctx.Err() != nil {
				break // remaining attempts would fail immediately
			}
			continue
		}
		d.mu.Lock()
		d.nodes = append(d.nodes, r.Nodes...)
		d.mu.Unlock()
		return nil
	}
	return errors.Join(errs...)
}

// sortByDist sorts nodes in place by ascending XOR distance from target.
func sortByDist(nodes []node, target [20]byte) {
	slices.SortFunc(nodes, func(a, b node) int {
		return compareDist(xorDistance(a.ID, target), xorDistance(b.ID, target))
	})
}

// GetPeers runs an iterative getPeers lookup for h. It starts from the
// routing-table nodes closest to h and keeps up to alpha queries in flight.
// It sends at most maxQueries queries and feeds the nodes in each reply back
// into the closest-first candidate list. Peers from all replies are
// collected, deduplicated by address. It returns the peers found. If none
// were found, it returns an error, which wraps ctx.Err() if ctx ended.
func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) {
	// Snapshot the routing table; Bootstrap may append to it concurrently.
	d.mu.Lock()
	seed := slices.Clone(d.nodes)
	d.mu.Unlock()

	// known holds every address ever added to candidates, so a node reported
	// by several responders is enqueued once; queried is the subset that has
	// already been sent a query.
	known := make(map[string]bool, len(seed))
	queried := make(map[string]bool)
	var candidates []node
	for _, c := range seed {
		if k := c.Addr.String(); !known[k] {
			known[k] = true
			candidates = append(candidates, c)
		}
	}
	sortByDist(candidates, h)

	// Buffered to alpha, the most queries ever in flight, so query
	// goroutines never block on send.
	results := make(chan *resp, alpha)
	inflight, queryCount := 0, 0

	var peers []*net.TCPAddr
	seenPeers := make(map[string]bool)

	for {
		// Top up the window of concurrent queries unless the query budget is
		// spent or ctx is done.
		for ctx.Err() == nil && inflight < alpha && queryCount < maxQueries {
			n, i := nextCandidate(candidates, queried)
			if i < 0 {
				break
			}
			candidates = slices.Delete(candidates, i, i+1)
			queried[n.Addr.String()] = true
			inflight++
			queryCount++
			go func(addr *net.UDPAddr) {
				// Errors are deliberately dropped: a node that fails or times
				// out just contributes a nil result to the lookup.
				r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h})
				results <- r
			}(n.Addr)
		}

		// Stop only once nothing is in flight, not when the budget runs
		// out, so replies to the final queries are still processed.
		if inflight == 0 {
			break
		}
		r := <-results
		inflight--
		if r == nil {
			continue
		}

		for _, p := range r.Peers {
			a := decodePeer(p)
			if a == nil {
				continue
			}
			if k := a.String(); !seenPeers[k] {
				seenPeers[k] = true
				peers = append(peers, a)
			}
		}

		added := false
		for _, c := range r.Nodes {
			if k := c.Addr.String(); !known[k] {
				known[k] = true
				candidates = append(candidates, c)
				added = true
			}
		}
		if added {
			sortByDist(candidates, h)
		}
	}

	if len(peers) > 0 {
		return peers, nil
	}
	if err := ctx.Err(); err != nil {
		return nil, fmt.Errorf("no peers found: %w", err)
	}
	return nil, errors.New("no peers found")
}

// nextCandidate returns the first entry of candidates whose address is not
// in queried, together with its index, or (node{}, -1) if there is none.
// GetPeers keeps candidates sorted by distance to the target, so for it this
// is the closest unqueried node.
func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
	for i, c := range candidates {
		if !queried[c.Addr.String()] {
			return c, i
		}
	}
	return node{}, -1
}