first commit
This commit is contained in:
228
dht/dht.go
Normal file
228
dht/dht.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
queryTimeout = 5 * time.Second
|
||||
alpha = 3
|
||||
maxQueries = 32
|
||||
)
|
||||
|
||||
type DHT struct {
|
||||
conn *net.UDPConn
|
||||
id [20]byte
|
||||
nodes []node
|
||||
pending map[string]chan *resp
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func New(port int) (*DHT, error) {
|
||||
addr := &net.UDPAddr{Port: port}
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &DHT{
|
||||
conn: conn,
|
||||
id: genNodeID(),
|
||||
pending: make(map[string]chan *resp),
|
||||
}
|
||||
go d.run()
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *DHT) Close() {
|
||||
d.conn.Close()
|
||||
}
|
||||
|
||||
func (d *DHT) run() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, _, err := d.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg, err := decodeMsg(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
d.mu.Lock()
|
||||
ch, ok := d.pending[msg.T]
|
||||
if ok {
|
||||
delete(d.pending, msg.T)
|
||||
}
|
||||
d.mu.Unlock()
|
||||
if ok {
|
||||
ch <- &msg.R
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func xorDistance(a, b [20]byte) [20]byte {
|
||||
var dist [20]byte
|
||||
for i := range 20 {
|
||||
dist[i] = a[i] ^ b[i]
|
||||
}
|
||||
return dist
|
||||
}
|
||||
|
||||
func compareDist(a, b [20]byte) int {
|
||||
for i := range 20 {
|
||||
if a[i] < b[i] {
|
||||
return -1
|
||||
}
|
||||
if a[i] > b[i] {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||
defer cancel()
|
||||
|
||||
txid := genTxID()
|
||||
a.ID = d.id
|
||||
m := &msg{
|
||||
T: txid,
|
||||
Y: query,
|
||||
Q: q,
|
||||
A: a,
|
||||
}
|
||||
data, err := m.Encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply := make(chan *resp, 1)
|
||||
d.mu.Lock()
|
||||
d.pending[txid] = reply
|
||||
d.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
d.mu.Lock()
|
||||
delete(d.pending, txid)
|
||||
d.mu.Unlock()
|
||||
}()
|
||||
|
||||
if _, err := d.conn.WriteToUDP(data, addr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-reply:
|
||||
return r, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DHT) Bootstrap(ctx context.Context, addr string) error {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ips, err := net.LookupHost(host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return fmt.Errorf("no bootstrap nodes")
|
||||
}
|
||||
p, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
uaddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(ips[0]),
|
||||
Port: p,
|
||||
}
|
||||
resp, err := d.query(ctx, uaddr, findNode, args{Target: d.id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.nodes = append(d.nodes, resp.Nodes...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortByDist(nodes []node, target [20]byte) {
|
||||
slices.SortFunc(nodes, func(a, b node) int {
|
||||
return compareDist(xorDistance(a.ID, target), xorDistance(b.ID, target))
|
||||
})
|
||||
}
|
||||
|
||||
func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) {
|
||||
queried := make(map[string]bool)
|
||||
candidates := make([]node, len(d.nodes))
|
||||
copy(candidates, d.nodes)
|
||||
sortByDist(candidates, h)
|
||||
|
||||
results := make(chan *resp)
|
||||
inflight := 0
|
||||
queryCount := 0
|
||||
var peers []*net.TCPAddr
|
||||
|
||||
for queryCount < maxQueries {
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
for inflight < alpha && queryCount < maxQueries {
|
||||
n, i := nextCandidate(candidates, queried)
|
||||
if i < 0 {
|
||||
break
|
||||
}
|
||||
candidates = slices.Delete(candidates, i, i+1)
|
||||
queried[n.Addr.String()] = true
|
||||
inflight++
|
||||
queryCount++
|
||||
go func(addr *net.UDPAddr) {
|
||||
r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h})
|
||||
results <- r
|
||||
}(n.Addr)
|
||||
}
|
||||
if inflight == 0 {
|
||||
break
|
||||
}
|
||||
r := <-results
|
||||
inflight--
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
for _, p := range r.Peers {
|
||||
if a := decodePeer(p); a != nil {
|
||||
peers = append(peers, a)
|
||||
}
|
||||
}
|
||||
for _, c := range r.Nodes {
|
||||
if !queried[c.Addr.String()] {
|
||||
candidates = append(candidates, c)
|
||||
}
|
||||
}
|
||||
sortByDist(candidates, h)
|
||||
}
|
||||
for range inflight {
|
||||
<-results
|
||||
}
|
||||
if len(peers) > 0 {
|
||||
return peers, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no peers found")
|
||||
}
|
||||
|
||||
func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
|
||||
for i, c := range candidates {
|
||||
if !queried[c.Addr.String()] {
|
||||
return c, i
|
||||
}
|
||||
}
|
||||
return node{}, -1
|
||||
}
|
||||
246
dht/msg.go
Normal file
246
dht/msg.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package dht
|
||||
|
||||
// https://www.bittorrent.org/beps/bep_0005.html
|
||||
// Peer: TCP
|
||||
// Node: UDP
|
||||
// BEP 5: DHT Protocol - KRPC over UDP
|
||||
// Message: bencode dict with keys:
|
||||
// t: transaction id (2 bytes)
|
||||
// y: type "q" (query), "r" (response), "e" (error)
|
||||
// q: query name (ping, find_node, get_peers, announce_peer)
|
||||
// a: query args dict
|
||||
// r: response dict
|
||||
// e: error [code, message]
|
||||
// Compact formats:
|
||||
// peer: 6 bytes (4 ip + 2 port)
|
||||
// node: 26 bytes (20 id + 4 ip + 2 port)
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"storrent/bencode"
|
||||
)
|
||||
|
||||
const (
|
||||
query = "q"
|
||||
response = "r"
|
||||
error_ = "e"
|
||||
)
|
||||
|
||||
const (
|
||||
ping = "ping"
|
||||
findNode = "find_node"
|
||||
getPeers = "get_peers"
|
||||
announcePeer = "announce_peer"
|
||||
)
|
||||
|
||||
type node struct {
|
||||
ID [20]byte
|
||||
Addr *net.UDPAddr
|
||||
}
|
||||
|
||||
type args struct {
|
||||
ID [20]byte
|
||||
Target [20]byte
|
||||
InfoHash [20]byte
|
||||
Port int
|
||||
Token string
|
||||
ImpliedPort int
|
||||
}
|
||||
|
||||
type resp struct {
|
||||
ID [20]byte
|
||||
Nodes []node
|
||||
Peers []string
|
||||
Token string
|
||||
}
|
||||
|
||||
type errResp struct {
|
||||
Code int
|
||||
Msg string
|
||||
}
|
||||
|
||||
type msg struct {
|
||||
T string
|
||||
Y string
|
||||
Q string
|
||||
A args
|
||||
R resp
|
||||
E errResp
|
||||
}
|
||||
|
||||
func genTxID() string {
|
||||
b := make([]byte, 2)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func genNodeID() [20]byte {
|
||||
var id [20]byte
|
||||
if _, err := rand.Read(id[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (m *msg) Encode() ([]byte, error) {
|
||||
d := make(map[string]any)
|
||||
d["t"] = m.T
|
||||
d["y"] = m.Y
|
||||
|
||||
switch m.Y {
|
||||
case query:
|
||||
d["q"] = m.Q
|
||||
a := make(map[string]any)
|
||||
a["id"] = string(m.A.ID[:])
|
||||
switch m.Q {
|
||||
case findNode:
|
||||
a["target"] = string(m.A.Target[:])
|
||||
case getPeers:
|
||||
a["info_hash"] = string(m.A.InfoHash[:])
|
||||
case announcePeer:
|
||||
a["info_hash"] = string(m.A.InfoHash[:])
|
||||
a["port"] = int64(m.A.Port)
|
||||
a["token"] = m.A.Token
|
||||
if m.A.ImpliedPort != 0 {
|
||||
a["implied_port"] = int64(m.A.ImpliedPort)
|
||||
}
|
||||
}
|
||||
d["a"] = a
|
||||
case response:
|
||||
r := make(map[string]any)
|
||||
r["id"] = string(m.R.ID[:])
|
||||
if len(m.R.Nodes) > 0 {
|
||||
r["nodes"] = encodeNodes(m.R.Nodes)
|
||||
}
|
||||
if len(m.R.Peers) > 0 {
|
||||
vals := make([]any, len(m.R.Peers))
|
||||
for i, p := range m.R.Peers {
|
||||
vals[i] = p
|
||||
}
|
||||
r["values"] = vals
|
||||
}
|
||||
if m.R.Token != "" {
|
||||
r["token"] = m.R.Token
|
||||
}
|
||||
d["r"] = r
|
||||
case error_:
|
||||
d["e"] = []any{int64(m.E.Code), m.E.Msg}
|
||||
}
|
||||
return bencode.Encode(d)
|
||||
}
|
||||
|
||||
func decodeMsg(data []byte) (*msg, error) {
|
||||
v, _, err := bencode.Decode(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid krpc message")
|
||||
}
|
||||
|
||||
m := &msg{}
|
||||
m.T, _ = d["t"].(string)
|
||||
m.Y, _ = d["y"].(string)
|
||||
switch m.Y {
|
||||
case query:
|
||||
m.Q, _ = d["q"].(string)
|
||||
if a, ok := d["a"].(map[string]any); ok {
|
||||
if id, ok := a["id"].(string); ok && len(id) == 20 {
|
||||
copy(m.A.ID[:], id)
|
||||
}
|
||||
if target, ok := a["target"].(string); ok && len(target) == 20 {
|
||||
copy(m.A.Target[:], target)
|
||||
}
|
||||
if ih, ok := a["info_hash"].(string); ok && len(ih) == 20 {
|
||||
copy(m.A.InfoHash[:], ih)
|
||||
}
|
||||
if port, ok := a["port"].(int64); ok {
|
||||
m.A.Port = int(port)
|
||||
}
|
||||
if token, ok := a["token"].(string); ok {
|
||||
m.A.Token = token
|
||||
}
|
||||
if iport, ok := a["implied_port"].(int64); ok {
|
||||
m.A.ImpliedPort = int(iport)
|
||||
}
|
||||
}
|
||||
|
||||
case response:
|
||||
if r, ok := d["r"].(map[string]any); ok {
|
||||
if id, ok := r["id"].(string); ok && len(id) == 20 {
|
||||
copy(m.R.ID[:], id)
|
||||
}
|
||||
if nodes, ok := r["nodes"].(string); ok {
|
||||
m.R.Nodes = decodeNodes(nodes)
|
||||
}
|
||||
if vals, ok := r["values"].([]any); ok {
|
||||
for _, v := range vals {
|
||||
if p, ok := v.(string); ok {
|
||||
m.R.Peers = append(m.R.Peers, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
if token, ok := r["token"].(string); ok {
|
||||
m.R.Token = token
|
||||
}
|
||||
}
|
||||
|
||||
case error_:
|
||||
if e, ok := d["e"].([]any); ok && len(e) >= 2 {
|
||||
if code, ok := e[0].(int64); ok {
|
||||
m.E.Code = int(code)
|
||||
}
|
||||
m.E.Msg, _ = e[1].(string)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// 26 = 20 id + 4 ip + 2 port
|
||||
func encodeNodes(nodes []node) string {
|
||||
buf := make([]byte, 26*len(nodes))
|
||||
for i, n := range nodes {
|
||||
off := i * 26
|
||||
copy(buf[off:], n.ID[:])
|
||||
copy(buf[off+20:], n.Addr.IP.To4())
|
||||
binary.BigEndian.PutUint16(buf[off+24:], uint16(n.Addr.Port))
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func decodeNodes(s string) []node {
|
||||
data := []byte(s)
|
||||
if len(data)%26 != 0 {
|
||||
return nil
|
||||
}
|
||||
n := len(data) / 26
|
||||
nodes := make([]node, n)
|
||||
for i := range nodes {
|
||||
off := i * 26
|
||||
copy(nodes[i].ID[:], data[off:])
|
||||
ip := net.IP(data[off+20 : off+24])
|
||||
port := binary.BigEndian.Uint16(data[off+24:])
|
||||
nodes[i].Addr = &net.UDPAddr{IP: ip, Port: int(port)}
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// 6 = 4 ip + 2 port
|
||||
func decodePeer(s string) *net.TCPAddr {
|
||||
if len(s) != 6 {
|
||||
return nil
|
||||
}
|
||||
data := []byte(s)
|
||||
return &net.TCPAddr{
|
||||
IP: net.IP(data[:4]),
|
||||
Port: int(binary.BigEndian.Uint16(data[4:])),
|
||||
}
|
||||
}
|
||||
67
dht/msg_test.go
Normal file
67
dht/msg_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func nid(c string) [20]byte {
|
||||
var id [20]byte
|
||||
copy(id[:], strings.Repeat(c, 20))
|
||||
return id
|
||||
}
|
||||
|
||||
func TestEncodeDecodeNodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nodes []node
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
nodes: []node{},
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
nodes: []node{
|
||||
{
|
||||
ID: nid("a"),
|
||||
Addr: &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 6881,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
nodes: []node{
|
||||
{
|
||||
ID: nid("a"),
|
||||
Addr: &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 6881,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: nid("a"),
|
||||
Addr: &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 1, 1),
|
||||
Port: 8080,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data := encodeNodes(tt.nodes)
|
||||
got := decodeNodes(data)
|
||||
data2 := encodeNodes(got)
|
||||
if data != data2 {
|
||||
t.Errorf("roundtrip mismatch: got %x, want %x", data2, data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user