From 60ec9c19ea0340b8b13415f361c277651a5cefc1 Mon Sep 17 00:00:00 2001 From: Hojun-Cho Date: Thu, 22 Jan 2026 10:42:55 +0900 Subject: [PATCH] dht: save tokens from get_peers for announce TODO: implement announce_peer --- client/dial.go | 3 ++- dht/dht.go | 41 ++++++++++++++++++++++++++--------------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/client/dial.go b/client/dial.go index e5f5f7a..6b5eedb 100644 --- a/client/dial.go +++ b/client/dial.go @@ -63,7 +63,8 @@ func (dp *dialPool) stop() { func (dp *dialPool) connectLoop(ctx context.Context) { dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout) - addrs, err := dp.dht.GetPeers(dhtCtx, dp.infoHash) + // TODO: handle tokens + addrs, _, err := dp.dht.GetPeers(dhtCtx, dp.infoHash) cancel() if err != nil { diff --git a/dht/dht.go b/dht/dht.go index 038e918..11d156d 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -160,17 +160,23 @@ func sortByDist(nodes []node, target [20]byte) { }) } -func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) { - queried := make(map[string]bool) - candidates := make([]node, len(d.nodes)) - copy(candidates, d.nodes) - sortByDist(candidates, h) +func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, map[string]string, error) { + type reply struct { + from *net.UDPAddr + data *resp + } - results := make(chan *resp) inflight := 0 queryCount := 0 + queried := make(map[string]bool) + candidates := make([]node, len(d.nodes)) + tokens := make(map[string]string) + replies := make(chan reply) var peers []*net.TCPAddr + copy(candidates, d.nodes) + sortByDist(candidates, h) + for queryCount < maxQueries { if ctx.Err() != nil { break @@ -186,23 +192,28 @@ func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) queryCount++ go func(addr *net.UDPAddr) { r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h}) - results <- r + replies <- reply{addr, r} }(n.Addr) } if inflight == 0 { break } - r := <-results + r := <-replies inflight-- - if r == nil { + if r.data == nil || r.data.Token == "" { continue } - for _, p := range r.Peers { + tokens[r.from.String()] = r.data.Token + for _, p := range r.data.Peers { if a := decodePeer(p); a != nil { - peers = append(peers, a) + if !slices.ContainsFunc(peers, func(p *net.TCPAddr) bool { + return p.String() == a.String() + }) { + peers = append(peers, a) + } } } - for _, c := range r.Nodes { + for _, c := range r.data.Nodes { if !queried[c.Addr.String()] { candidates = append(candidates, c) } @@ -210,12 +221,12 @@ func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) sortByDist(candidates, h) } for range inflight { - <-results + <-replies } if len(peers) > 0 { - return peers, nil + return peers, tokens, nil } - return nil, fmt.Errorf("no peers found") + return nil, nil, fmt.Errorf("no peers found") } func nextCandidate(candidates []node, queried map[string]bool) (node, int) {