storrent/dht/dht.go
2026-01-19 21:13:01 +09:00

229 lines
3.9 KiB
Go

package dht
import (
"context"
"fmt"
"net"
"slices"
"strconv"
"sync"
"time"
)
const (
queryTimeout = 5 * time.Second
alpha = 3
maxQueries = 32
)
// DHT sends queries over a single UDP socket and matches incoming replies to
// the outstanding queries by transaction ID.
type DHT struct {
	// conn is the socket used both to send queries and to receive replies.
	conn *net.UDPConn
	// id is this node's ID; it is attached to every outgoing query and used
	// as the lookup target when bootstrapping.
	id [20]byte
	// nodes holds the nodes returned by bootstrap queries; GetPeers starts
	// its lookup from a copy of this list.
	// NOTE(review): nodes is accessed by Bootstrap and GetPeers — confirm
	// whether those accesses hold mu when called concurrently.
	nodes []node
	// pending maps the transaction ID of each outstanding query to the
	// channel its reply is delivered on.
	pending map[string]chan *resp
	// mu guards pending.
	mu sync.Mutex
}
// New opens a UDP socket on the given local port and returns a DHT with a
// freshly generated node ID. The socket listens on all local addresses; a
// port of 0 lets the system choose one. The receive loop is started in the
// background and runs until Close is called.
func New(port int) (*DHT, error) {
	conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
	if err != nil {
		return nil, err
	}

	d := new(DHT)
	d.conn = conn
	d.id = genNodeID()
	d.pending = make(map[string]chan *resp)

	go d.run()
	return d, nil
}
// Close shuts the DHT down by closing its UDP socket. The receive loop's
// pending read then fails, which makes it return. Any error from closing the
// socket is discarded.
func (d *DHT) Close() {
	_ = d.conn.Close()
}
// run is the receive loop started by New. It reads datagrams from the socket,
// decodes them, and delivers each message whose transaction ID matches an
// outstanding query to that query's reply channel. Datagrams that fail to
// decode, and messages with an unknown transaction ID, are dropped. The loop
// returns on the first read error, such as the one produced after Close.
//
// NOTE(review): the sender address is ignored, so a reply is accepted from
// any host that presents a live transaction ID — confirm this is acceptable.
func (d *DHT) run() {
	packet := make([]byte, 65535)
	for {
		size, _, readErr := d.conn.ReadFromUDP(packet)
		if readErr != nil {
			return
		}

		m, decodeErr := decodeMsg(packet[:size])
		if decodeErr != nil {
			continue
		}

		// Claim the waiter and unregister it in one critical section so that
		// at most one reply is ever delivered per transaction ID.
		d.mu.Lock()
		waiter, found := d.pending[m.T]
		delete(d.pending, m.T)
		d.mu.Unlock()

		if !found {
			continue
		}
		// query creates reply channels with capacity 1, and the entry was just
		// removed from pending, so this single send cannot block.
		waiter <- &m.R
	}
}
// xorDistance returns the byte-wise XOR of two 20-byte IDs. This XOR value is
// the distance that sortByDist orders nodes by.
func xorDistance(a, b [20]byte) [20]byte {
	out := a // arrays are values, so this copies a
	for i := range out {
		out[i] ^= b[i]
	}
	return out
}
// compareDist orders two 20-byte distances as big-endian unsigned integers.
// It returns -1 if a < b, 0 if they are equal, and +1 if a > b.
func compareDist(a, b [20]byte) int {
	// Lexicographic comparison of equal-length byte sequences is exactly
	// big-endian numeric comparison.
	return slices.Compare(a[:], b[:])
}
// query sends one query of kind q with arguments a to addr and waits for the
// reply carrying the same transaction ID. a.ID is always overwritten with
// this node's ID. The wait is bounded by both ctx and queryTimeout.
//
// It returns the reply on success. It returns ctx.Err() if the deadline
// passes or ctx is cancelled first. Encoding and socket-write errors are
// returned unchanged.
//
// NOTE(review): if genTxID can return an ID that is already pending, the
// earlier query's channel is overwritten and one query's deferred delete will
// remove the other's entry — confirm genTxID's uniqueness guarantees.
func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) {
	ctx, cancel := context.WithTimeout(ctx, queryTimeout)
	defer cancel()

	a.ID = d.id
	txid := genTxID()

	payload, err := (&msg{T: txid, Y: query, Q: q, A: a}).Encode()
	if err != nil {
		return nil, err
	}

	// Register before sending so a fast reply cannot arrive unclaimed.
	// Capacity 1 lets run deliver without waiting on this goroutine.
	reply := make(chan *resp, 1)
	d.mu.Lock()
	d.pending[txid] = reply
	d.mu.Unlock()
	defer func() {
		d.mu.Lock()
		defer d.mu.Unlock()
		delete(d.pending, txid)
	}()

	if _, err = d.conn.WriteToUDP(payload, addr); err != nil {
		return nil, err
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case r := <-reply:
		return r, nil
	}
}
// Bootstrap seeds the node list from a well-known node.
//
// addr is a "host:port" string. The host is resolved using ctx, so a slow DNS
// lookup honours cancellation. A findNode query targeting this node's own ID
// is then sent to each resolved address in turn, until one answers. The nodes
// in that answer are appended to the list GetPeers starts from.
//
// Errors:
//   - an invalid addr or port is reported immediately;
//   - a resolution failure is returned wrapped;
//   - if no address answers, the error from the last attempt is returned,
//     wrapped so errors.Is still matches context errors.
func (d *DHT) Bootstrap(ctx context.Context, addr string) error {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return err
	}
	p, err := strconv.Atoi(port)
	if err != nil {
		return fmt.Errorf("invalid port: %w", err)
	}
	if p <= 0 || p > 65535 {
		return fmt.Errorf("invalid port: %d", p)
	}

	// net.LookupHost ignores ctx; the resolver method does not.
	ips, err := net.DefaultResolver.LookupHost(ctx, host)
	if err != nil {
		return fmt.Errorf("resolving bootstrap host %q: %w", host, err)
	}
	if len(ips) == 0 {
		return fmt.Errorf("no bootstrap nodes")
	}

	var lastErr error
	for _, ipStr := range ips {
		ip := net.ParseIP(ipStr)
		if ip == nil {
			// e.g. a zoned IPv6 literal; a nil IP would silently target the
			// unspecified address, so skip it instead.
			lastErr = fmt.Errorf("unusable bootstrap address %q", ipStr)
			continue
		}
		uaddr := &net.UDPAddr{IP: ip, Port: p}
		r, qerr := d.query(ctx, uaddr, findNode, args{Target: d.id})
		if qerr != nil {
			lastErr = fmt.Errorf("bootstrap query to %s: %w", uaddr, qerr)
			if ctx.Err() != nil {
				// No point trying further addresses once ctx is done.
				break
			}
			continue
		}
		// nodes is shared with GetPeers; guard the write.
		d.mu.Lock()
		d.nodes = append(d.nodes, r.Nodes...)
		d.mu.Unlock()
		return nil
	}
	return lastErr
}
// sortByDist sorts nodes in place by the XOR distance between each node's ID
// and target, closest first. The sort is not stable.
func sortByDist(nodes []node, target [20]byte) {
	byDistance := func(x, y node) int {
		dx := xorDistance(x.ID, target)
		dy := xorDistance(y.ID, target)
		return compareDist(dx, dy)
	}
	slices.SortFunc(nodes, byDistance)
}
// GetPeers runs an iterative getPeers lookup for the infohash h. It returns
// the peer addresses reported by the nodes it queried.
//
// How the lookup proceeds:
//   - It starts from a snapshot of the nodes learned via Bootstrap.
//   - It keeps up to alpha queries in flight. Each new query goes to the
//     unqueried candidate whose ID is XOR-closest to h.
//   - Nodes returned in each reply are merged into the candidate set.
//   - At most maxQueries queries are sent in total.
//
// The lookup ends once every sent query has answered or failed, and no new
// query can be started: the budget is spent, no candidates remain, or ctx is
// done. Replies that arrive after the budget is spent are still processed.
//
// Individual query failures are ignored, and peers are deduplicated by
// address. If no peers were found an error is returned. When ctx ended the
// lookup, that error wraps ctx.Err().
func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) {
	// Snapshot under the lock so a concurrent Bootstrap cannot race this read.
	d.mu.Lock()
	candidates := slices.Clone(d.nodes)
	d.mu.Unlock()
	sortByDist(candidates, h)

	// queried marks addresses that have been sent a query. known also covers
	// addresses still waiting in candidates, so a node reported by several
	// replies is queued only once.
	queried := make(map[string]bool)
	known := make(map[string]bool, len(candidates))
	for _, c := range candidates {
		known[c.Addr.String()] = true
	}

	var peers []*net.TCPAddr
	seenPeers := make(map[string]bool)

	// Unbuffered is safe: the loop below exits only when inflight is zero, so
	// every goroutine's send is received and none leaks.
	results := make(chan *resp)
	inflight, queryCount := 0, 0
	for {
		// Refill the window unless ctx is done or the budget is spent.
		for ctx.Err() == nil && inflight < alpha && queryCount < maxQueries {
			n, i := nextCandidate(candidates, queried)
			if i < 0 {
				break
			}
			candidates = slices.Delete(candidates, i, i+1)
			queried[n.Addr.String()] = true
			inflight++
			queryCount++
			go func(addr *net.UDPAddr) {
				// The error is deliberately dropped: a failed query simply
				// contributes nothing (r is nil) to the lookup.
				r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h})
				results <- r
			}(n.Addr)
		}
		if inflight == 0 {
			break
		}

		r := <-results
		inflight--
		if r == nil {
			continue
		}
		for _, p := range r.Peers {
			a := decodePeer(p)
			if a == nil {
				continue
			}
			if key := a.String(); !seenPeers[key] {
				seenPeers[key] = true
				peers = append(peers, a)
			}
		}
		added := false
		for _, c := range r.Nodes {
			if key := c.Addr.String(); !known[key] {
				known[key] = true
				candidates = append(candidates, c)
				added = true
			}
		}
		if added {
			sortByDist(candidates, h)
		}
	}

	if len(peers) > 0 {
		return peers, nil
	}
	if err := ctx.Err(); err != nil {
		return nil, fmt.Errorf("no peers found: %w", err)
	}
	return nil, fmt.Errorf("no peers found")
}
// nextCandidate returns the first node in candidates whose address is not in
// queried, together with its index. If every candidate has been queried, or
// candidates is empty, it returns the zero node and -1.
func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
	idx := slices.IndexFunc(candidates, func(c node) bool {
		return !queried[c.Addr.String()]
	})
	if idx < 0 {
		return node{}, -1
	}
	return candidates[idx], idx
}