first commit

This commit is contained in:
2026-01-19 21:13:01 +09:00
commit c70d24be5c
28 changed files with 3674 additions and 0 deletions

228
dht/dht.go Normal file
View File

@@ -0,0 +1,228 @@
package dht
import (
"context"
"fmt"
"net"
"slices"
"strconv"
"sync"
"time"
)
// Tunables for outgoing KRPC traffic and the iterative peer lookup.
const (
	// queryTimeout bounds how long a single query waits for its reply.
	queryTimeout = 5 * time.Second

	// alpha is the number of get_peers queries kept in flight at once.
	alpha = 3

	// maxQueries caps the total number of queries one GetPeers call may send.
	maxQueries = 32
)
// DHT is a minimal client for the BitTorrent DHT: it owns one UDP socket,
// sends KRPC queries over it and matches replies to waiting callers.
type DHT struct {
	// conn is the UDP socket used for both sending queries and reading replies.
	conn *net.UDPConn
	// id is this node's own 20-byte identifier, sent with every query.
	id [20]byte
	// nodes holds the nodes learned while bootstrapping; lookups start from it.
	// NOTE(review): the code visible in this file reads and appends to nodes
	// without holding mu.
	nodes []node
	// pending maps the transaction id of an outstanding query to the channel
	// on which its reply is delivered.
	pending map[string]chan *resp
	// mu guards pending.
	mu sync.Mutex
}
// New opens a UDP socket on the given port, creates a DHT with a freshly
// generated node id and starts the background goroutine that reads replies.
// It returns the error from net.ListenUDP unchanged if the socket cannot be
// opened. The read goroutine stops once reading from the socket fails, which
// is what Close triggers.
func New(port int) (*DHT, error) {
	sock, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
	if err != nil {
		return nil, err
	}

	dht := new(DHT)
	dht.conn = sock
	dht.id = genNodeID()
	dht.pending = map[string]chan *resp{}

	go dht.run()
	return dht, nil
}
// Close closes the DHT's UDP socket. The read loop in run returns as soon as
// its read on the socket reports an error, so closing the socket is also what
// stops that goroutine.
func (d *DHT) Close() {
	// The close error is deliberately dropped: Close has no return value.
	_ = d.conn.Close()
}
// run is the receive loop started by New. It reads datagrams from the socket,
// decodes them as KRPC messages and hands each reply to the query that is
// waiting on the matching transaction id.
//
// Only replies ("r" and "e" messages) may complete a pending query. Incoming
// queries from remote nodes carry transaction ids chosen by the sender, so
// they could otherwise collide with one of ours and be mistaken for its
// answer. An error reply has no response body, so the waiter receives a
// zero-valued resp for it.
//
// The loop ends on the first read error, which is what happens once the
// socket has been closed.
func (d *DHT) run() {
	buf := make([]byte, 65535)
	for {
		n, _, err := d.conn.ReadFromUDP(buf)
		if err != nil {
			return
		}

		m, err := decodeMsg(buf[:n])
		if err != nil {
			// Undecodable datagrams are dropped.
			continue
		}
		if m.Y != response && m.Y != error_ {
			// Not a reply: must not be matched against our pending queries.
			continue
		}

		d.mu.Lock()
		ch, ok := d.pending[m.T]
		if ok {
			// Remove the entry first so at most one reply is ever sent on
			// the channel.
			delete(d.pending, m.T)
		}
		d.mu.Unlock()

		if ok {
			// The channel has capacity 1 and receives exactly one value, so
			// this send never blocks the receive loop.
			ch <- &m.R
		}
	}
}
// xorDistance returns the byte-wise XOR of two 20-byte identifiers, which is
// the distance metric used to order nodes relative to a target.
func xorDistance(a, b [20]byte) [20]byte {
	out := a // arrays are values, so this is a copy of a
	for i, bv := range b {
		out[i] ^= bv
	}
	return out
}
// compareDist orders two 20-byte distances as big-endian numbers: it returns
// -1 if a is smaller, 1 if a is larger and 0 if they are equal. The first
// differing byte, scanning from index 0, decides the result.
func compareDist(a, b [20]byte) int {
	for i, av := range a {
		switch bv := b[i]; {
		case av < bv:
			return -1
		case av > bv:
			return 1
		}
	}
	return 0
}
// query sends one KRPC query of type q with arguments a to addr and waits for
// the matching reply.
//
// The wait is bounded by queryTimeout and by ctx, whichever ends first. On
// success the reply's response body is returned; otherwise the error is the
// encode error, the socket write error, or ctx's error on timeout or
// cancellation. a.ID is overwritten with this node's id before sending.
//
// Transaction ids are only two random bytes, so the id is chosen while
// holding the lock and re-drawn if it is already in use. That way a new query
// can never overwrite the reply channel of another one still in flight.
func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) {
	ctx, cancel := context.WithTimeout(ctx, queryTimeout)
	defer cancel()

	// Capacity 1 lets the receive loop deliver the reply without blocking
	// even if this call has already given up waiting.
	reply := make(chan *resp, 1)

	const maxTxIDAttempts = 16
	var txid string
	d.mu.Lock()
	for attempt := 1; ; attempt++ {
		txid = genTxID()
		if _, taken := d.pending[txid]; !taken {
			break
		}
		if attempt == maxTxIDAttempts {
			d.mu.Unlock()
			return nil, fmt.Errorf("no free transaction id")
		}
	}
	d.pending[txid] = reply
	d.mu.Unlock()

	defer func() {
		d.mu.Lock()
		// Only remove our own registration: once the receive loop has
		// delivered our reply, the same id may already belong to a newer
		// query.
		if d.pending[txid] == reply {
			delete(d.pending, txid)
		}
		d.mu.Unlock()
	}()

	a.ID = d.id
	m := &msg{
		T: txid,
		Y: query,
		Q: q,
		A: a,
	}
	data, err := m.Encode()
	if err != nil {
		return nil, err
	}

	if _, err := d.conn.WriteToUDP(data, addr); err != nil {
		return nil, err
	}

	select {
	case r := <-reply:
		return r, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Bootstrap seeds the node list by sending a find_node query for our own id
// to the bootstrap node at addr, given as "host:port".
//
// The host is resolved with ctx, so cancelling ctx also aborts the DNS
// lookup. Every resolved address is tried in order until one answers, so a
// single unreachable record (for example an IPv6 address on an IPv4-only
// machine) does not fail the whole bootstrap. The nodes from the first reply
// are appended to d.nodes.
//
// It returns an error if addr is malformed, the host cannot be resolved, or
// no resolved address produced a reply; in that last case the error of the
// final attempt is returned.
func (d *DHT) Bootstrap(ctx context.Context, addr string) error {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return err
	}
	// Validate the port before doing any network work.
	p, err := strconv.Atoi(port)
	if err != nil {
		return fmt.Errorf("invalid port: %w", err)
	}

	ips, err := net.DefaultResolver.LookupHost(ctx, host)
	if err != nil {
		return err
	}
	if len(ips) == 0 {
		return fmt.Errorf("no bootstrap nodes")
	}

	lastErr := fmt.Errorf("no usable bootstrap address for %q", host)
	for _, s := range ips {
		ip := net.ParseIP(s)
		if ip == nil {
			// Not a plain IP literal (e.g. it carries a zone); skip it.
			continue
		}
		uaddr := &net.UDPAddr{IP: ip, Port: p}

		r, err := d.query(ctx, uaddr, findNode, args{Target: d.id})
		if err != nil {
			lastErr = err
			if ctx.Err() != nil {
				// The caller gave up; further attempts would fail at once.
				break
			}
			continue
		}

		d.nodes = append(d.nodes, r.Nodes...)
		return nil
	}
	return lastErr
}
// sortByDist sorts nodes in place so that the node whose id has the smallest
// XOR distance to target comes first.
func sortByDist(nodes []node, target [20]byte) {
	byCloseness := func(x, y node) int {
		dx := xorDistance(x.ID, target)
		dy := xorDistance(y.ID, target)
		return compareDist(dx, dy)
	}
	slices.SortFunc(nodes, byCloseness)
}
// GetPeers runs an iterative get_peers lookup for info hash h and returns the
// peer addresses reported by the queried nodes.
//
// The lookup starts from a copy of d.nodes sorted by XOR distance to h. It
// keeps up to alpha queries in flight, always querying the closest node not
// yet queried, and merges the nodes from every reply back into the candidate
// list. At most maxQueries queries are sent in total.
//
// Every query that was started is waited for and its reply is processed,
// including those still in flight when the query budget runs out or ctx is
// cancelled. Peers reported by several nodes are returned once.
//
// It returns an error if no peer was found. When ctx has ended, that error
// wraps ctx's error.
func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) {
	queried := make(map[string]bool)
	candidates := make([]node, len(d.nodes))
	copy(candidates, d.nodes)
	sortByDist(candidates, h)

	results := make(chan *resp)
	inflight := 0
	queryCount := 0

	seen := make(map[string]bool) // compact peer strings already collected
	var peers []*net.TCPAddr

	for {
		// Top up the in-flight queries while budget, candidates and ctx allow.
		for ctx.Err() == nil && inflight < alpha && queryCount < maxQueries {
			n, i := nextCandidate(candidates, queried)
			if i < 0 {
				break
			}
			candidates = slices.Delete(candidates, i, i+1)
			queried[n.Addr.String()] = true
			inflight++
			queryCount++
			go func(addr *net.UDPAddr) {
				// A failed query is reported as a nil result; the error
				// itself is not needed by the lookup.
				r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h})
				results <- r
			}(n.Addr)
		}

		// Nothing running and nothing could be started: the lookup is over.
		if inflight == 0 {
			break
		}

		r := <-results
		inflight--
		if r == nil {
			continue
		}

		for _, p := range r.Peers {
			if seen[p] {
				continue
			}
			seen[p] = true
			if a := decodePeer(p); a != nil {
				peers = append(peers, a)
			}
		}
		for _, c := range r.Nodes {
			if !queried[c.Addr.String()] {
				candidates = append(candidates, c)
			}
		}
		sortByDist(candidates, h)
	}

	if len(peers) > 0 {
		return peers, nil
	}
	if err := ctx.Err(); err != nil {
		return nil, fmt.Errorf("no peers found: %w", err)
	}
	return nil, fmt.Errorf("no peers found")
}
// nextCandidate returns the first node in candidates whose address is not
// marked in queried, together with its index. If every candidate has already
// been queried it returns the zero node and -1.
func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
	for idx := range candidates {
		if key := candidates[idx].Addr.String(); queried[key] {
			continue
		}
		return candidates[idx], idx
	}
	return node{}, -1
}

246
dht/msg.go Normal file
View File

@@ -0,0 +1,246 @@
package dht
// https://www.bittorrent.org/beps/bep_0005.html
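// Peer: a client/server listening on a TCP port that implements the BitTorrent protocol.
// Node: a client/server listening on a UDP port that implements the DHT protocol.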
// BEP 5: DHT Protocol - KRPC over UDP
// Message: bencode dict with keys:
// t: transaction id (2 bytes)
// y: type "q" (query), "r" (response), "e" (error)
// q: query name (ping, find_node, get_peers, announce_peer)
// a: query args dict
// r: response dict
// e: error [code, message]
// Compact formats:
// peer: 6 bytes (4 ip + 2 port)
// node: 26 bytes (20 id + 4 ip + 2 port)
import (
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"storrent/bencode"
)
// Values of the "y" key, which tells what kind of KRPC message this is.
const (
	// query marks a request; the message then carries "q" and "a".
	query = "q"
	// response marks a successful reply; the message then carries "r".
	response = "r"
	// error_ marks an error reply; the message then carries "e". The trailing
	// underscore avoids shadowing the built-in error type.
	error_ = "e"
)
// Values of the "q" key: the names of the four queries defined by BEP 5.
const (
	ping         = "ping"
	findNode     = "find_node"
	getPeers     = "get_peers"
	announcePeer = "announce_peer"
)
// node is one DHT participant: its 20-byte id plus the UDP address it is
// reached at.
type node struct {
	// ID is the node's 20-byte identifier.
	ID [20]byte
	// Addr is the UDP endpoint queries to this node are sent to.
	Addr *net.UDPAddr
}
// args holds the arguments of a query (the "a" dict). Which fields are
// actually put on the wire depends on the query name; see (*msg).Encode.
type args struct {
	// ID is the sender's node id, Target the id searched for by find_node and
	// InfoHash the torrent hash used by get_peers and announce_peer.
	ID, Target, InfoHash [20]byte

	// Port, Token and ImpliedPort are only sent with announce_peer.
	Port        int
	Token       string
	ImpliedPort int
}
// resp holds the body of a response (the "r" dict).
type resp struct {
	// ID is the id of the node that answered.
	ID [20]byte
	// Nodes are the nodes decoded from the compact "nodes" string.
	Nodes []node
	// Peers are the entries of the "values" list, each kept as the raw
	// string that was received; decodePeer turns one into an address.
	Peers []string
	// Token is the "token" value of the reply, if any.
	Token string
}
// errResp holds the body of an error message (the "e" list), which consists
// of a numeric code followed by a human-readable message.
type errResp struct {
	// Code is the first element of the list.
	Code int
	// Msg is the second element of the list.
	Msg string
}
// msg is one KRPC message. T and Y are always used; of the remaining fields
// only those belonging to the message type in Y are meaningful.
type msg struct {
	// T is the transaction id, Y the message type (query, response or
	// error_) and Q the query name, set for queries only.
	T, Y, Q string

	// A holds the query arguments (Y == query).
	A args
	// R holds the response body (Y == response).
	R resp
	// E holds the error details (Y == error_).
	E errResp
}
// genTxID returns a new transaction id: two bytes from crypto/rand, returned
// as a string. It panics if the random source fails.
func genTxID() string {
	var raw [2]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		panic(err)
	}
	return string(raw[:])
}
// genNodeID returns a new node id: 20 bytes from crypto/rand. It panics if
// the random source fails.
func genNodeID() [20]byte {
	raw := make([]byte, 20)
	_, err := rand.Read(raw)
	if err != nil {
		panic(err)
	}
	return [20]byte(raw)
}
// Encode serialises the message as a bencoded KRPC dict.
//
// "t" and "y" are always written. Depending on m.Y, one more key is added:
//   - query:    "q" with the query name and "a" with the arguments. "id" is
//     always included; the other arguments depend on the query name.
//   - response: "r" with "id" plus "nodes", "values" and "token" when they
//     are non-empty.
//   - error_:   "e" as the list [code, message].
//
// Any other m.Y yields a dict with just "t" and "y". The returned error is
// whatever bencode.Encode reports.
func (m *msg) Encode() ([]byte, error) {
	top := map[string]any{
		"t": m.T,
		"y": m.Y,
	}

	switch m.Y {
	case query:
		body := map[string]any{"id": string(m.A.ID[:])}
		switch m.Q {
		case findNode:
			body["target"] = string(m.A.Target[:])
		case getPeers:
			body["info_hash"] = string(m.A.InfoHash[:])
		case announcePeer:
			body["info_hash"] = string(m.A.InfoHash[:])
			body["port"] = int64(m.A.Port)
			body["token"] = m.A.Token
			// implied_port is only written when it is set.
			if m.A.ImpliedPort != 0 {
				body["implied_port"] = int64(m.A.ImpliedPort)
			}
		}
		top["q"] = m.Q
		top["a"] = body

	case response:
		body := map[string]any{"id": string(m.R.ID[:])}
		if len(m.R.Nodes) > 0 {
			body["nodes"] = encodeNodes(m.R.Nodes)
		}
		if count := len(m.R.Peers); count > 0 {
			// The encoder is handed a []any, so the strings are copied over
			// one by one.
			values := make([]any, 0, count)
			for _, peer := range m.R.Peers {
				values = append(values, peer)
			}
			body["values"] = values
		}
		if m.R.Token != "" {
			body["token"] = m.R.Token
		}
		top["r"] = body

	case error_:
		top["e"] = []any{int64(m.E.Code), m.E.Msg}
	}

	return bencode.Encode(top)
}
// decodeMsg parses a bencoded KRPC message.
//
// It fails if the data cannot be bencode-decoded or if the top-level value is
// not a dict. Beyond that it is lenient: a key that is missing or has an
// unexpected type leaves the corresponding field at its zero value, and a
// 20-byte id field is only filled in when the string is exactly 20 bytes
// long. Only the body that belongs to the message type in "y" is read.
//
// The type assertions assume that the decoder yields string, int64, []any and
// map[string]any values.
func decodeMsg(data []byte) (*msg, error) {
	decoded, _, err := bencode.Decode(data)
	if err != nil {
		return nil, err
	}
	top, isDict := decoded.(map[string]any)
	if !isDict {
		return nil, fmt.Errorf("invalid krpc message")
	}

	// str yields v as a string, or "" when v is absent or not a string.
	str := func(v any) string {
		s, _ := v.(string)
		return s
	}
	// fillID copies v into dst only if v is a string of exactly 20 bytes.
	fillID := func(dst *[20]byte, v any) {
		if s, ok := v.(string); ok && len(s) == 20 {
			copy(dst[:], s)
		}
	}

	out := &msg{T: str(top["t"]), Y: str(top["y"])}

	switch out.Y {
	case query:
		out.Q = str(top["q"])
		body, ok := top["a"].(map[string]any)
		if !ok {
			break
		}
		fillID(&out.A.ID, body["id"])
		fillID(&out.A.Target, body["target"])
		fillID(&out.A.InfoHash, body["info_hash"])
		if port, ok := body["port"].(int64); ok {
			out.A.Port = int(port)
		}
		out.A.Token = str(body["token"])
		if implied, ok := body["implied_port"].(int64); ok {
			out.A.ImpliedPort = int(implied)
		}

	case response:
		body, ok := top["r"].(map[string]any)
		if !ok {
			break
		}
		fillID(&out.R.ID, body["id"])
		if packed, ok := body["nodes"].(string); ok {
			out.R.Nodes = decodeNodes(packed)
		}
		if list, ok := body["values"].([]any); ok {
			// Entries that are not strings are skipped.
			for _, item := range list {
				if peer, ok := item.(string); ok {
					out.R.Peers = append(out.R.Peers, peer)
				}
			}
		}
		out.R.Token = str(body["token"])

	case error_:
		if pair, ok := top["e"].([]any); ok && len(pair) >= 2 {
			if code, ok := pair[0].(int64); ok {
				out.E.Code = int(code)
			}
			out.E.Msg = str(pair[1])
		}
	}

	return out, nil
}
// encodeNodes packs nodes into the compact node format: 26 bytes per node,
// made of the 20-byte id, the 4-byte IPv4 address and the 2-byte big-endian
// port.
//
// The format has room for IPv4 addresses only. Nodes without an address or
// with an address that is not IPv4 cannot be represented and are left out,
// rather than being written with an all-zero address.
func encodeNodes(nodes []node) string {
	buf := make([]byte, 0, 26*len(nodes))
	for _, n := range nodes {
		if n.Addr == nil {
			continue
		}
		ip4 := n.Addr.IP.To4()
		if ip4 == nil {
			continue
		}
		buf = append(buf, n.ID[:]...)
		buf = append(buf, ip4...)
		buf = binary.BigEndian.AppendUint16(buf, uint16(n.Addr.Port))
	}
	return string(buf)
}
// decodeNodes unpacks a compact node string into nodes. Each 26-byte entry
// holds a 20-byte id, a 4-byte IPv4 address and a 2-byte big-endian port.
// If the length of s is not a multiple of 26 the whole string is rejected and
// nil is returned; an empty string yields an empty, non-nil slice.
func decodeNodes(s string) []node {
	const entryLen = 26

	raw := []byte(s)
	if len(raw)%entryLen != 0 {
		return nil
	}

	nodes := make([]node, 0, len(raw)/entryLen)
	for off := 0; off < len(raw); off += entryLen {
		entry := raw[off:]

		var n node
		copy(n.ID[:], entry) // copies exactly len(n.ID) == 20 bytes
		n.Addr = &net.UDPAddr{
			IP:   net.IP(entry[20:24]),
			Port: int(binary.BigEndian.Uint16(entry[24:])),
		}
		nodes = append(nodes, n)
	}
	return nodes
}
// decodePeer converts one compact peer entry into a TCP address. An entry is
// exactly 6 bytes: a 4-byte IPv4 address followed by a 2-byte big-endian
// port. Any other length yields nil.
func decodePeer(s string) *net.TCPAddr {
	const peerLen = 6
	if len(s) != peerLen {
		return nil
	}

	raw := []byte(s)
	addr := new(net.TCPAddr)
	addr.IP = net.IP(raw[:4])
	addr.Port = int(binary.BigEndian.Uint16(raw[4:]))
	return addr
}

67
dht/msg_test.go Normal file
View File

@@ -0,0 +1,67 @@
package dht
import (
"net"
"strings"
"testing"
)
// nid builds a 20-byte node id for tests by repeating c and keeping the first
// 20 bytes, so nid("a") is twenty 'a' bytes.
func nid(c string) [20]byte {
	pattern := strings.Repeat(c, 20)

	var out [20]byte
	copy(out[:], pattern)
	return out
}
// TestEncodeDecodeNodes checks that encodeNodes and decodeNodes are inverses
// for IPv4 nodes. It verifies the encoded size, compares every decoded field
// with the input and confirms that re-encoding reproduces the same bytes.
// The nodes in the "multiple" case use different ids, so that mixing up
// entries would be detected.
func TestEncodeDecodeNodes(t *testing.T) {
	tests := []struct {
		name  string
		nodes []node
	}{
		{
			name:  "empty",
			nodes: []node{},
		},
		{
			name: "single",
			nodes: []node{
				{
					ID: nid("a"),
					Addr: &net.UDPAddr{
						IP:   net.IPv4(127, 0, 0, 1),
						Port: 6881,
					},
				},
			},
		},
		{
			name: "multiple",
			nodes: []node{
				{
					ID: nid("a"),
					Addr: &net.UDPAddr{
						IP:   net.IPv4(127, 0, 0, 1),
						Port: 6881,
					},
				},
				{
					ID: nid("b"),
					Addr: &net.UDPAddr{
						IP:   net.IPv4(192, 168, 1, 1),
						Port: 8080,
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			data := encodeNodes(tt.nodes)
			if want := 26 * len(tt.nodes); len(data) != want {
				t.Fatalf("encoded length: got %d, want %d", len(data), want)
			}

			got := decodeNodes(data)
			if len(got) != len(tt.nodes) {
				t.Fatalf("decoded %d nodes, want %d", len(got), len(tt.nodes))
			}
			for i, want := range tt.nodes {
				if got[i].ID != want.ID {
					t.Errorf("node %d: id got %x, want %x", i, got[i].ID, want.ID)
				}
				if !got[i].Addr.IP.Equal(want.Addr.IP) {
					t.Errorf("node %d: ip got %v, want %v", i, got[i].Addr.IP, want.Addr.IP)
				}
				if got[i].Addr.Port != want.Addr.Port {
					t.Errorf("node %d: port got %d, want %d", i, got[i].Addr.Port, want.Addr.Port)
				}
			}

			if data2 := encodeNodes(got); data != data2 {
				t.Errorf("roundtrip mismatch: got %x, want %x", data2, data)
			}
		})
	}
}