storrent/dht/msg.go
2026-01-19 21:13:01 +09:00

247 lines
4.7 KiB
Go

package dht
// https://www.bittorrent.org/beps/bep_0005.html
// Peers are contacted over TCP; DHT nodes are contacted over UDP.
// BEP 5: DHT Protocol - KRPC messages sent over UDP.
// Each message is a bencoded dictionary with the keys:
// t: transaction ID (this package generates 2-byte IDs)
// y: message type: "q" (query), "r" (response) or "e" (error)
// q: query method name (ping, find_node, get_peers, announce_peer)
// a: query arguments dictionary
// r: response values dictionary
// e: error, as a list [code, message]
// Compact formats:
// peer: 6 bytes (4-byte IPv4 address + 2-byte big-endian port)
// node: 26 bytes (20-byte node ID + 4-byte IPv4 address + 2-byte big-endian port)
import (
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"storrent/bencode"
)
// KRPC message types, stored under the "y" key of every message.
const query = "q"    // a request sent to another node
const response = "r" // a successful reply to a query
const error_ = "e"   // a failed reply to a query

// KRPC query method names, stored under the "q" key of a query.
const ping = "ping"
const findNode = "find_node"
const getPeers = "get_peers"
const announcePeer = "announce_peer"
// node is a DHT node as carried in compact node info: a 20-byte node ID
// plus the UDP address the node is reached at.
type node struct {
ID [20]byte
Addr *net.UDPAddr
}
// args holds the arguments of a KRPC query (the "a" dictionary). Which
// fields are encoded depends on the query method.
type args struct {
// ID is sent as "id" with every query.
ID [20]byte
// Target is sent as "target" for find_node.
Target [20]byte
// InfoHash is sent as "info_hash" for get_peers and announce_peer.
InfoHash [20]byte
// Port is sent as "port" for announce_peer.
Port int
// Token is sent as "token" for announce_peer.
Token string
// ImpliedPort is sent as "implied_port" for announce_peer, only when non-zero.
ImpliedPort int
}
// resp holds the values of a KRPC response (the "r" dictionary).
type resp struct {
// ID is sent as "id" in every response.
ID [20]byte
// Nodes is carried as compact node info under "nodes"; omitted when empty.
Nodes []node
// Peers holds the raw strings of the "values" list, presumably 6-byte
// compact peer info each; lengths are not validated on decode. Omitted
// when empty.
Peers []string
// Token is sent as "token"; omitted when empty.
Token string
}
// errResp is a KRPC error, carried as the "e" list [code, message].
type errResp struct {
// Code is the numeric error code (first list element).
Code int
// Msg is the human-readable error message (second list element).
Msg string
}
// msg is a KRPC message. Only the payload that matches Y is encoded or
// decoded: A for query, R for response, E for error_.
type msg struct {
// T is the transaction ID ("t").
T string
// Y is the message type ("y"): query, response or error_.
Y string
// Q is the query method name ("q"); used only when Y is query.
Q string
// A holds the query arguments ("a").
A args
// R holds the response values ("r").
R resp
// E holds the error payload ("e").
E errResp
}
// genTxID returns a fresh random 2-byte KRPC transaction ID, read from
// crypto/rand. It panics if the random source fails.
func genTxID() string {
	var buf [2]byte
	_, err := rand.Read(buf[:])
	if err != nil {
		panic(err)
	}
	return string(buf[:])
}
// genNodeID returns a random 20-byte DHT node ID read from crypto/rand.
// It panics if the random source fails.
func genNodeID() [20]byte {
	id := [20]byte{}
	_, err := rand.Read(id[:])
	if err != nil {
		panic(err)
	}
	return id
}
// Encode serializes m as a bencoded KRPC dictionary.
//
// The "t" and "y" keys are always written. Depending on m.Y, exactly one
// payload is added:
//   - query: "q" holds m.Q and "a" holds "id", plus "target" (find_node),
//     "info_hash" (get_peers), or "info_hash", "port", "token" and, when
//     non-zero, "implied_port" (announce_peer).
//   - response: "r" holds "id", plus "nodes", "values" and "token" when
//     they are non-empty.
//   - error_: "e" holds the list [code, message].
//
// It returns an error if m.Y is not one of those message types, or if
// bencoding fails.
func (m *msg) Encode() ([]byte, error) {
	d := map[string]any{
		"t": m.T,
		"y": m.Y,
	}
	switch m.Y {
	case query:
		d["q"] = m.Q
		a := map[string]any{"id": string(m.A.ID[:])}
		switch m.Q {
		case findNode:
			a["target"] = string(m.A.Target[:])
		case getPeers:
			a["info_hash"] = string(m.A.InfoHash[:])
		case announcePeer:
			a["info_hash"] = string(m.A.InfoHash[:])
			a["port"] = int64(m.A.Port)
			a["token"] = m.A.Token
			// implied_port is optional; omit it unless it is set.
			if m.A.ImpliedPort != 0 {
				a["implied_port"] = int64(m.A.ImpliedPort)
			}
		}
		d["a"] = a
	case response:
		r := map[string]any{"id": string(m.R.ID[:])}
		if len(m.R.Nodes) > 0 {
			r["nodes"] = encodeNodes(m.R.Nodes)
		}
		if len(m.R.Peers) > 0 {
			// Lists are handed to bencode as []any, matching how
			// decodeMsg reads them back.
			vals := make([]any, len(m.R.Peers))
			for i, p := range m.R.Peers {
				vals[i] = p
			}
			r["values"] = vals
		}
		if m.R.Token != "" {
			r["token"] = m.R.Token
		}
		d["r"] = r
	case error_:
		d["e"] = []any{int64(m.E.Code), m.E.Msg}
	default:
		// Without a recognized "y" the message would lack its q/a, r or e
		// payload; refuse to emit a malformed message.
		return nil, fmt.Errorf("encoding krpc message: unknown message type %q", m.Y)
	}
	b, err := bencode.Encode(d)
	if err != nil {
		return nil, fmt.Errorf("encoding krpc message: %w", err)
	}
	return b, nil
}
// decodeMsg parses a bencoded KRPC message.
//
// It returns an error in any of these cases:
//   - data is not valid bencode.
//   - data is not a dictionary.
//   - data lacks a string "t" (transaction ID) or "y" (message type).
//   - "y" is not query, response or error_.
//
// Inside the payload, missing or malformed fields are tolerated. IDs,
// targets and info hashes that are not exactly 20 bytes are left zeroed,
// and values of unexpected types are ignored.
func decodeMsg(data []byte) (*msg, error) {
	// The middle return value of bencode.Decode (presumably the number of
	// bytes consumed) is ignored, so trailing bytes are not rejected.
	// TODO(review): confirm against the bencode package.
	v, _, err := bencode.Decode(data)
	if err != nil {
		return nil, fmt.Errorf("decoding krpc message: %w", err)
	}
	d, ok := v.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid krpc message")
	}
	m := &msg{}
	// "t" and "y" are required in every KRPC message. Without them the
	// message can be neither matched to a transaction nor dispatched.
	if m.T, ok = d["t"].(string); !ok {
		return nil, fmt.Errorf("invalid krpc message: missing transaction id")
	}
	if m.Y, ok = d["y"].(string); !ok {
		return nil, fmt.Errorf("invalid krpc message: missing message type")
	}

	// set20 copies v into dst only when v is a string of exactly 20 bytes.
	set20 := func(dst *[20]byte, v any) {
		if s, ok := v.(string); ok && len(s) == len(dst) {
			copy(dst[:], s)
		}
	}

	switch m.Y {
	case query:
		m.Q, _ = d["q"].(string)
		if a, ok := d["a"].(map[string]any); ok {
			set20(&m.A.ID, a["id"])
			set20(&m.A.Target, a["target"])
			set20(&m.A.InfoHash, a["info_hash"])
			if port, ok := a["port"].(int64); ok {
				m.A.Port = int(port)
			}
			if token, ok := a["token"].(string); ok {
				m.A.Token = token
			}
			if iport, ok := a["implied_port"].(int64); ok {
				m.A.ImpliedPort = int(iport)
			}
		}
	case response:
		if r, ok := d["r"].(map[string]any); ok {
			set20(&m.R.ID, r["id"])
			if nodes, ok := r["nodes"].(string); ok {
				m.R.Nodes = decodeNodes(nodes)
			}
			if vals, ok := r["values"].([]any); ok {
				for _, v := range vals {
					if p, ok := v.(string); ok {
						m.R.Peers = append(m.R.Peers, p)
					}
				}
			}
			if token, ok := r["token"].(string); ok {
				m.R.Token = token
			}
		}
	case error_:
		if e, ok := d["e"].([]any); ok && len(e) >= 2 {
			if code, ok := e[0].(int64); ok {
				m.E.Code = int(code)
			}
			m.E.Msg, _ = e[1].(string)
		}
	default:
		return nil, fmt.Errorf("invalid krpc message: unknown message type %q", m.Y)
	}
	return m, nil
}
// encodeNodes packs nodes into compact node info. Each entry is 26 bytes:
// a 20-byte node ID, a 4-byte IPv4 address and a 2-byte big-endian port.
// Entries are concatenated with no separators.
//
// Compact node info can only carry IPv4 addresses. Nodes with a nil Addr
// or a non-IPv4 IP are therefore skipped. Before this check, a nil Addr
// caused a panic and an IPv6 address was written as a bogus 0.0.0.0 entry.
func encodeNodes(nodes []node) string {
	const nodeLen = 26
	buf := make([]byte, 0, nodeLen*len(nodes))
	for _, n := range nodes {
		if n.Addr == nil {
			continue
		}
		ip4 := n.Addr.IP.To4()
		if ip4 == nil {
			continue
		}
		var port [2]byte
		binary.BigEndian.PutUint16(port[:], uint16(n.Addr.Port))
		buf = append(buf, n.ID[:]...)
		buf = append(buf, ip4...)
		buf = append(buf, port[:]...)
	}
	return string(buf)
}
// decodeNodes parses compact node info: a concatenation of 26-byte entries.
// Each entry is a 20-byte node ID, a 4-byte IPv4 address and a 2-byte
// big-endian port. It returns nil if len(s) is not a multiple of 26. An
// empty s yields an empty, non-nil slice.
//
// Each node gets its own copy of its IP. A node kept by a caller therefore
// does not pin the whole decoded buffer in memory, and nodes never share
// backing storage.
func decodeNodes(s string) []node {
	const nodeLen = 26
	if len(s)%nodeLen != 0 {
		return nil
	}
	data := []byte(s)
	nodes := make([]node, len(data)/nodeLen)
	for i := range nodes {
		entry := data[i*nodeLen : (i+1)*nodeLen]
		copy(nodes[i].ID[:], entry[:20])
		ip := make(net.IP, net.IPv4len)
		copy(ip, entry[20:24])
		nodes[i].Addr = &net.UDPAddr{
			IP:   ip,
			Port: int(binary.BigEndian.Uint16(entry[24:26])),
		}
	}
	return nodes
}
// decodePeer parses 6-byte compact peer info: a 4-byte IPv4 address
// followed by a 2-byte big-endian port. It returns nil for any other length.
func decodePeer(s string) *net.TCPAddr {
	const compactPeerLen = 6
	if len(s) == compactPeerLen {
		raw := []byte(s)
		addr := new(net.TCPAddr)
		addr.IP = net.IP(raw[0:4])
		addr.Port = int(binary.BigEndian.Uint16(raw[4:6]))
		return addr
	}
	return nil
}