229 lines
3.9 KiB
Go
229 lines
3.9 KiB
Go
package dht
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"slices"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
const (
	// alpha is the maximum number of get_peers queries GetPeers keeps in
	// flight at the same time.
	alpha = 3

	// maxQueries caps the total number of queries a single GetPeers lookup
	// sends.
	maxQueries = 32

	// queryTimeout bounds how long query waits for the reply to one message
	// (in addition to any deadline already on the caller's context).
	queryTimeout = 5 * time.Second
)
|
|
|
|
type DHT struct {
|
|
conn *net.UDPConn
|
|
id [20]byte
|
|
nodes []node
|
|
pending map[string]chan *resp
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func New(port int) (*DHT, error) {
|
|
addr := &net.UDPAddr{Port: port}
|
|
conn, err := net.ListenUDP("udp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
d := &DHT{
|
|
conn: conn,
|
|
id: genNodeID(),
|
|
pending: make(map[string]chan *resp),
|
|
}
|
|
go d.run()
|
|
return d, nil
|
|
}
|
|
|
|
func (d *DHT) Close() {
|
|
d.conn.Close()
|
|
}
|
|
|
|
func (d *DHT) run() {
|
|
buf := make([]byte, 65535)
|
|
for {
|
|
n, _, err := d.conn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
msg, err := decodeMsg(buf[:n])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
d.mu.Lock()
|
|
ch, ok := d.pending[msg.T]
|
|
if ok {
|
|
delete(d.pending, msg.T)
|
|
}
|
|
d.mu.Unlock()
|
|
if ok {
|
|
ch <- &msg.R
|
|
}
|
|
}
|
|
}
|
|
|
|
// xorDistance returns the byte-wise XOR of two 20-byte IDs. sortByDist uses
// it as the distance between a node ID and a target.
func xorDistance(a, b [20]byte) (out [20]byte) {
	for idx := range out {
		out[idx] = a[idx] ^ b[idx]
	}
	return out
}
|
|
|
|
// compareDist orders two 20-byte distances as unsigned big-endian integers.
// Byte 0 is the most significant. It returns -1 if a < b, 0 if a == b and
// +1 if a > b, which is the comparator contract slices.SortFunc expects.
func compareDist(a, b [20]byte) int {
	// For fixed-width big-endian values, lexicographic byte order is the same
	// as numeric order, so the stdlib element-wise compare is exact.
	return slices.Compare(a[:], b[:])
}
|
|
|
|
func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, queryTimeout)
|
|
defer cancel()
|
|
|
|
txid := genTxID()
|
|
a.ID = d.id
|
|
m := &msg{
|
|
T: txid,
|
|
Y: query,
|
|
Q: q,
|
|
A: a,
|
|
}
|
|
data, err := m.Encode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reply := make(chan *resp, 1)
|
|
d.mu.Lock()
|
|
d.pending[txid] = reply
|
|
d.mu.Unlock()
|
|
|
|
defer func() {
|
|
d.mu.Lock()
|
|
delete(d.pending, txid)
|
|
d.mu.Unlock()
|
|
}()
|
|
|
|
if _, err := d.conn.WriteToUDP(data, addr); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
select {
|
|
case r := <-reply:
|
|
return r, nil
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (d *DHT) Bootstrap(ctx context.Context, addr string) error {
|
|
host, port, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ips, err := net.LookupHost(host)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(ips) == 0 {
|
|
return fmt.Errorf("no bootstrap nodes")
|
|
}
|
|
p, err := strconv.Atoi(port)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid port: %w", err)
|
|
}
|
|
uaddr := &net.UDPAddr{
|
|
IP: net.ParseIP(ips[0]),
|
|
Port: p,
|
|
}
|
|
resp, err := d.query(ctx, uaddr, findNode, args{Target: d.id})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
d.nodes = append(d.nodes, resp.Nodes...)
|
|
return nil
|
|
}
|
|
|
|
func sortByDist(nodes []node, target [20]byte) {
|
|
slices.SortFunc(nodes, func(a, b node) int {
|
|
return compareDist(xorDistance(a.ID, target), xorDistance(b.ID, target))
|
|
})
|
|
}
|
|
|
|
func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) {
|
|
queried := make(map[string]bool)
|
|
candidates := make([]node, len(d.nodes))
|
|
copy(candidates, d.nodes)
|
|
sortByDist(candidates, h)
|
|
|
|
results := make(chan *resp)
|
|
inflight := 0
|
|
queryCount := 0
|
|
var peers []*net.TCPAddr
|
|
|
|
for queryCount < maxQueries {
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
for inflight < alpha && queryCount < maxQueries {
|
|
n, i := nextCandidate(candidates, queried)
|
|
if i < 0 {
|
|
break
|
|
}
|
|
candidates = slices.Delete(candidates, i, i+1)
|
|
queried[n.Addr.String()] = true
|
|
inflight++
|
|
queryCount++
|
|
go func(addr *net.UDPAddr) {
|
|
r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h})
|
|
results <- r
|
|
}(n.Addr)
|
|
}
|
|
if inflight == 0 {
|
|
break
|
|
}
|
|
r := <-results
|
|
inflight--
|
|
if r == nil {
|
|
continue
|
|
}
|
|
for _, p := range r.Peers {
|
|
if a := decodePeer(p); a != nil {
|
|
peers = append(peers, a)
|
|
}
|
|
}
|
|
for _, c := range r.Nodes {
|
|
if !queried[c.Addr.String()] {
|
|
candidates = append(candidates, c)
|
|
}
|
|
}
|
|
sortByDist(candidates, h)
|
|
}
|
|
for range inflight {
|
|
<-results
|
|
}
|
|
if len(peers) > 0 {
|
|
return peers, nil
|
|
}
|
|
return nil, fmt.Errorf("no peers found")
|
|
}
|
|
|
|
func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
|
|
for i, c := range candidates {
|
|
if !queried[c.Addr.String()] {
|
|
return c, i
|
|
}
|
|
}
|
|
return node{}, -1
|
|
}
|