storrent/bencode/bencode.go
2026-01-19 21:13:01 +09:00

174 lines
3.4 KiB
Go

package bencode
// Bencode value types and their wire formats:
//
//	string  <len>:<data>   e.g. "4:spam"
//	int     i<n>e          e.g. "i42e"
//	list    l<items>e      e.g. "l4:spam4:eggse"
//	dict    d<kv pairs>e   e.g. "d3:cow3:mooe"
import (
"fmt"
"sort"
"strconv"
)
// Decode parses the single bencoded value found at the start of data.
//
// The first byte selects the kind of value: 'i' for an integer, 'l' for a
// list, 'd' for a dictionary, and an ASCII digit for a string. The result is
// the decoded value (int64, []any, map[string]any or string respectively),
// the number of bytes of data that the value occupied, and an error.
//
// Bytes following the decoded value are not inspected. On error the returned
// value should be ignored.
func Decode(data []byte) (any, int, error) {
	if len(data) == 0 {
		return nil, 0, fmt.Errorf("empty data")
	}

	lead := data[0]
	switch {
	case lead == 'i':
		return decodeInt(data)
	case lead == 'l':
		return decodeList(data)
	case lead == 'd':
		return decodeDict(data)
	case '0' <= lead && lead <= '9':
		// A string starts with its decimal length prefix.
		return DecodeString(data)
	}
	return nil, 0, fmt.Errorf("invalid bencode: %q", lead)
}
// DecodeString parses a bencoded byte string of the form "<len>:<data>" from
// the start of data.
//
// It returns the payload, the total number of input bytes consumed (length
// prefix, ':' separator and payload), and an error. An error is returned when
// no ':' is present, when the length prefix is empty, contains anything other
// than ASCII digits, or does not fit in an int, and when fewer than <len>
// bytes follow the separator.
func DecodeString(data []byte) (string, int, error) {
	i := 0
	for i < len(data) && data[i] != ':' {
		i++
	}
	if i >= len(data) {
		return "", 0, fmt.Errorf("bencode: missing ':' in string at %d", i)
	}

	prefix := data[:i]
	// strconv.Atoi accepts a leading sign. A negative length would pass the
	// bounds check below and then panic when slicing, so only bare digits are
	// allowed here.
	for _, c := range prefix {
		if c < '0' || c > '9' {
			return "", 0, fmt.Errorf("bencode: invalid string length %q", prefix)
		}
	}
	n, err := strconv.Atoi(string(prefix))
	if err != nil {
		return "", 0, fmt.Errorf("bencode: invalid string length %q: %w", prefix, err)
	}

	i++ // step over the ':' separator
	// Compare against the bytes that remain instead of computing i+n, which
	// could overflow for a very large declared length.
	if n > len(data)-i {
		return "", 0, fmt.Errorf("bencode: truncated string at %d", i)
	}
	return string(data[i : i+n]), i + n, nil
}
// decodeInt parses a bencoded integer of the form "i<n>e" from the start of
// data. The caller is expected to have checked that data begins with 'i';
// the first byte is skipped without being examined.
//
// It returns the value and the number of bytes consumed, including the
// leading 'i' and the terminating 'e'. An error is returned when data is
// shorter than three bytes, when no 'e' terminator is found, or when the text
// between the markers is not a valid bencode integer. Forms the bencode
// specification rules out are rejected: an explicit '+' sign, leading zeros
// such as "i03e", and negative zero "i-0e".
func decodeInt(data []byte) (int64, int, error) {
	if len(data) < 3 {
		return 0, 0, fmt.Errorf("bencode: int too short at 0")
	}
	end := 1
	for end < len(data) && data[end] != 'e' {
		end++
	}
	if end >= len(data) {
		return 0, 0, fmt.Errorf("bencode: missing 'e' in int at %d", end)
	}

	body := data[1:end]
	digits := body
	if len(digits) > 0 && digits[0] == '-' {
		digits = digits[1:]
	}
	// strconv.ParseInt is more permissive than bencode, so the spec-invalid
	// spellings are filtered out before parsing.
	switch {
	case len(body) > 0 && body[0] == '+':
		return 0, 0, fmt.Errorf("bencode: invalid int %q: explicit '+' sign", body)
	case len(digits) > 1 && digits[0] == '0':
		return 0, 0, fmt.Errorf("bencode: invalid int %q: leading zero", body)
	case string(body) == "-0":
		return 0, 0, fmt.Errorf("bencode: invalid int %q: negative zero", body)
	}

	val, err := strconv.ParseInt(string(body), 10, 64)
	if err != nil {
		return 0, 0, fmt.Errorf("bencode: invalid int %q: %w", body, err)
	}
	return val, end + 1, nil
}
// decodeList parses a bencoded list of the form "l<items>e" from the start of
// data. The first byte is skipped without being examined.
//
// Elements are decoded one after another with Decode until an 'e' terminator
// is reached. It returns the elements in input order and the number of bytes
// consumed, including the terminator. An empty list ("le") yields a nil
// slice. An error from any element aborts decoding, and running out of input
// before the terminator is reported as a truncated list.
func decodeList(data []byte) ([]any, int, error) {
	if len(data) == 0 {
		return nil, 0, fmt.Errorf("bencode: empty list at 0")
	}

	var items []any
	pos := 1 // skip the leading marker byte
	for {
		if pos >= len(data) {
			return nil, 0, fmt.Errorf("bencode: truncated list at %d", pos)
		}
		if data[pos] == 'e' {
			return items, pos + 1, nil
		}

		elem, used, err := Decode(data[pos:])
		if err != nil {
			return nil, 0, err
		}
		items = append(items, elem)
		pos += used
	}
}
// decodeDict parses a bencoded dictionary of the form "d<kv pairs>e" from the
// start of data. The first byte is skipped without being examined.
//
// Each entry is a key decoded with DecodeString followed by a value decoded
// with Decode. It returns the entries and the number of bytes consumed,
// including the 'e' terminator. Key order is not checked, and when a key
// occurs more than once the last value wins. An error from any key or value
// aborts decoding, and running out of input before the terminator is reported
// as a truncated dict.
func decodeDict(data []byte) (map[string]any, int, error) {
	if len(data) == 0 {
		return nil, 0, fmt.Errorf("bencode: empty dict at 0")
	}

	entries := make(map[string]any)
	pos := 1 // skip the leading marker byte
	for {
		if pos >= len(data) {
			return nil, 0, fmt.Errorf("bencode: truncated dict at %d", pos)
		}
		if data[pos] == 'e' {
			return entries, pos + 1, nil
		}

		key, keyLen, err := DecodeString(data[pos:])
		if err != nil {
			return nil, 0, err
		}
		pos += keyLen

		val, valLen, err := Decode(data[pos:])
		if err != nil {
			return nil, 0, err
		}
		pos += valLen

		entries[key] = val
	}
}
func Encode(v any) ([]byte, error) {
switch v := v.(type) {
case string:
return encodeString(v), nil
case []byte:
return encodeString(string(v)), nil
case int:
return encodeInt(int64(v)), nil
case int64:
return encodeInt(v), nil
case []any:
return encodeList(v)
case map[string]any:
return encodeDict(v)
}
return nil, fmt.Errorf("cannot encode %T", v)
}
// encodeString returns s as a bencoded byte string: the decimal byte length
// of s, a ':' separator, then the bytes of s unchanged.
func encodeString(s string) []byte {
	out := fmt.Appendf(nil, "%d:", len(s))
	return append(out, s...)
}
// encodeInt returns n as a bencoded integer: 'i', the base-10 digits of n
// (with a leading '-' when negative), then 'e'.
func encodeInt(n int64) []byte {
	out := fmt.Appendf([]byte{'i'}, "%d", n)
	return append(out, 'e')
}
// encodeList returns list as a bencoded list: 'l', the encoding of each
// element in order, then 'e'. The first element that Encode cannot handle
// aborts the call and its error is returned unchanged.
func encodeList(list []any) ([]byte, error) {
	out := []byte("l")
	for idx := range list {
		item, err := Encode(list[idx])
		if err != nil {
			return nil, err
		}
		out = append(out, item...)
	}
	out = append(out, 'e')
	return out, nil
}
func encodeDict(d map[string]any) ([]byte, error) {
keys := make([]string, 0, len(d))
for k := range d {
keys = append(keys, k)
}
sort.Strings(keys)
buf := []byte{'d'}
for _, k := range keys {
buf = append(buf, encodeString(k)...)
enc, err := Encode(d[k])
if err != nil {
return nil, err
}
buf = append(buf, enc...)
}
return append(buf, 'e'), nil
}