first commit
This commit is contained in:
commit
c70d24be5c
38
.gitignore
vendored
Normal file
38
.gitignore
vendored
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# If you prefer the allow list template instead of the deny list, see community template:
|
||||||
|
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
|
||||||
|
#
|
||||||
|
|
||||||
|
*.torrent
|
||||||
|
storrent
|
||||||
|
|
||||||
|
# Binaries for programs and plugins
|
||||||
|
*.exe
|
||||||
|
*.exe~
|
||||||
|
*.dll
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
|
||||||
|
# Test binary, built with `go test -c`
|
||||||
|
*.test
|
||||||
|
|
||||||
|
# Code coverage profiles and other test artifacts
|
||||||
|
*.out
|
||||||
|
coverage.*
|
||||||
|
*.coverprofile
|
||||||
|
profile.cov
|
||||||
|
|
||||||
|
# Dependency directories (remove the comment below to include it)
|
||||||
|
# vendor/
|
||||||
|
|
||||||
|
# Go workspace file
|
||||||
|
go.work
|
||||||
|
go.work.sum
|
||||||
|
|
||||||
|
# env file
|
||||||
|
.env
|
||||||
|
|
||||||
|
# Editor/IDE
|
||||||
|
# .idea/
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
|
||||||
64
README.md
Normal file
64
README.md
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# storrent
|
||||||
|
|
||||||
|
A BitTorrent client with a 9P filesystem interface.
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
- plan9port
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- DHT for peer discovery
|
||||||
|
- μTP transport
|
||||||
|
- 9P control interface
|
||||||
|
|
||||||
|
## Build
|
||||||
|
|
||||||
|
go build
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
./storrent -dht 6881 -utp 6882 -dir ./download -addr :5640 &
|
||||||
|
9pfuse localhost:5640 /mnt/storrent
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Interact via files:
|
||||||
|
|
||||||
|
echo 'add /path/to/file.torrent' > /mnt/storrent/ctl
|
||||||
|
cat /mnt/storrent/list
|
||||||
|
cat /mnt/storrent/torrents/0/progress
|
||||||
|
|
||||||
|
Torrent control commands (write to `torrents/<id>/ctl`):
|
||||||
|
|
||||||
|
start start downloading
|
||||||
|
stop stop torrent
|
||||||
|
seed start seeding
|
||||||
|
remove remove torrent
|
||||||
|
peer <addr> add peer manually
|
||||||
|
|
||||||
|
Status files in `torrents/<id>/`:
|
||||||
|
|
||||||
|
name torrent name
|
||||||
|
state current state
|
||||||
|
progress download progress
|
||||||
|
size total size
|
||||||
|
down bytes downloaded
|
||||||
|
up bytes uploaded
|
||||||
|
pieces piece info
|
||||||
|
peers connected peers
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
- μTP congestion control
|
||||||
|
- ReadAt for streaming and seeding while downloading
|
||||||
|
|
||||||
|
## Packages
|
||||||
|
|
||||||
|
bencode/ bencode encoding/decoding
|
||||||
|
bt/ BitTorrent protocol messages
|
||||||
|
client/ torrent management
|
||||||
|
dht/ DHT implementation
|
||||||
|
fs/ 9P filesystem
|
||||||
|
metainfo/ .torrent file parsing
|
||||||
|
utp/ uTP transport
|
||||||
173
bencode/bencode.go
Normal file
173
bencode/bencode.go
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
package bencode
|
||||||
|
|
||||||
|
// bencode types:
|
||||||
|
// string <len>:<data> "4:spam"
|
||||||
|
// int i<n>e "i42e"
|
||||||
|
// list l<items>e "l4:spam4:eggse"
|
||||||
|
// dict d<kv pairs>e "d3:cow3:mooe"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Decode(data []byte) (any, int, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, 0, fmt.Errorf("empty data")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch data[0] {
|
||||||
|
case 'i':
|
||||||
|
return decodeInt(data)
|
||||||
|
case 'l':
|
||||||
|
return decodeList(data)
|
||||||
|
case 'd':
|
||||||
|
return decodeDict(data)
|
||||||
|
default:
|
||||||
|
if data[0] >= '0' && data[0] <= '9' {
|
||||||
|
return DecodeString(data)
|
||||||
|
}
|
||||||
|
return nil, 0, fmt.Errorf("invalid bencode: %q", data[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeString(data []byte) (string, int, error) {
|
||||||
|
i := 0
|
||||||
|
for i < len(data) && data[i] != ':' {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i >= len(data) {
|
||||||
|
return "", 0, fmt.Errorf("bencode: missing ':' in string at %d", i)
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(string(data[:i]))
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
if i+n > len(data) {
|
||||||
|
return "", 0, fmt.Errorf("bencode: truncated string at %d", i)
|
||||||
|
}
|
||||||
|
return string(data[i : i+n]), i + n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeInt(data []byte) (int64, int, error) {
|
||||||
|
if len(data) < 3 {
|
||||||
|
return 0, 0, fmt.Errorf("bencode: int too short at 0")
|
||||||
|
}
|
||||||
|
i := 1
|
||||||
|
for i < len(data) && data[i] != 'e' {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i >= len(data) {
|
||||||
|
return 0, 0, fmt.Errorf("bencode: missing 'e' in int at %d", i)
|
||||||
|
}
|
||||||
|
val, err := strconv.ParseInt(string(data[1:i]), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
return val, i + 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeList(data []byte) ([]any, int, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, 0, fmt.Errorf("bencode: empty list at 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
var list []any
|
||||||
|
i := 1
|
||||||
|
for i < len(data) && data[i] != 'e' {
|
||||||
|
v, n, err := Decode(data[i:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
list = append(list, v)
|
||||||
|
i += n
|
||||||
|
}
|
||||||
|
if i >= len(data) {
|
||||||
|
return nil, 0, fmt.Errorf("bencode: truncated list at %d", i)
|
||||||
|
}
|
||||||
|
return list, i + 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeDict(data []byte) (map[string]any, int, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, 0, fmt.Errorf("bencode: empty dict at 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
d := make(map[string]any)
|
||||||
|
i := 1
|
||||||
|
for i < len(data) && data[i] != 'e' {
|
||||||
|
k, n, err := DecodeString(data[i:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
i += n
|
||||||
|
v, n, err := Decode(data[i:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
d[k] = v
|
||||||
|
i += n
|
||||||
|
}
|
||||||
|
if i >= len(data) {
|
||||||
|
return nil, 0, fmt.Errorf("bencode: truncated dict at %d", i)
|
||||||
|
}
|
||||||
|
return d, i + 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Encode(v any) ([]byte, error) {
|
||||||
|
switch v := v.(type) {
|
||||||
|
case string:
|
||||||
|
return encodeString(v), nil
|
||||||
|
case []byte:
|
||||||
|
return encodeString(string(v)), nil
|
||||||
|
case int:
|
||||||
|
return encodeInt(int64(v)), nil
|
||||||
|
case int64:
|
||||||
|
return encodeInt(v), nil
|
||||||
|
case []any:
|
||||||
|
return encodeList(v)
|
||||||
|
case map[string]any:
|
||||||
|
return encodeDict(v)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("cannot encode %T", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeString(s string) []byte {
|
||||||
|
return fmt.Appendf(nil, "%d:%s", len(s), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInt(n int64) []byte {
|
||||||
|
return fmt.Appendf(nil, "i%de", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeList(list []any) ([]byte, error) {
|
||||||
|
buf := []byte{'l'}
|
||||||
|
for _, v := range list {
|
||||||
|
enc, err := Encode(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
buf = append(buf, enc...)
|
||||||
|
}
|
||||||
|
return append(buf, 'e'), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeDict(d map[string]any) ([]byte, error) {
|
||||||
|
keys := make([]string, 0, len(d))
|
||||||
|
for k := range d {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
buf := []byte{'d'}
|
||||||
|
for _, k := range keys {
|
||||||
|
buf = append(buf, encodeString(k)...)
|
||||||
|
enc, err := Encode(d[k])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
buf = append(buf, enc...)
|
||||||
|
}
|
||||||
|
return append(buf, 'e'), nil
|
||||||
|
}
|
||||||
113
bencode/bencode_test.go
Normal file
113
bencode/bencode_test.go
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
package bencode
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDecode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
in string
|
||||||
|
want any
|
||||||
|
err bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
in: "4:spam",
|
||||||
|
want: "spam",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "0:",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "i42e",
|
||||||
|
want: int64(42),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "i-42e",
|
||||||
|
want: int64(-42),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "i0e",
|
||||||
|
want: int64(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "",
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.in, func(t *testing.T) {
|
||||||
|
got, _, err := Decode([]byte(tt.in))
|
||||||
|
if tt.err {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("got %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in any
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
in: "spam",
|
||||||
|
want: "4:spam",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
in: "",
|
||||||
|
want: "0:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int",
|
||||||
|
in: int64(42),
|
||||||
|
want: "i42e",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative int",
|
||||||
|
in: int64(-42),
|
||||||
|
want: "i-42e",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list",
|
||||||
|
in: []any{"a", "b"},
|
||||||
|
want: "l1:a1:be",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty list",
|
||||||
|
in: []any{},
|
||||||
|
want: "le",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dict",
|
||||||
|
in: map[string]any{"b": int64(2), "a": int64(1)},
|
||||||
|
want: "d1:ai1e1:bi2ee",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty dict",
|
||||||
|
in: map[string]any{},
|
||||||
|
want: "de",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := Encode(tt.in)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != tt.want {
|
||||||
|
t.Errorf("got %s, want %s", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
116
bt/msg.go
Normal file
116
bt/msg.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
package bt
|
||||||
|
|
||||||
|
// https://wiki.theory.org/BitTorrentSpecification#Messages
|
||||||
|
// Messages: <length prefix><message ID><payload>
|
||||||
|
// keep-alive: <len=0000>
|
||||||
|
// choke: <len=0001><type=0>
|
||||||
|
// unchoke: <len=0001><type=1>
|
||||||
|
// interested: <len=0001><type=2>
|
||||||
|
// not interested: <len=0001><type=3>
|
||||||
|
// have: <len=0005><type=4><piece index>
|
||||||
|
// bitfield: <len=0001+X><type=5><bitfield>
|
||||||
|
// request: <len=0013><type=6><index><begin><length>
|
||||||
|
// piece: <len=0009+X><type=7><index><begin><block>
|
||||||
|
// cancel: <len=0013><type=8><index><begin><length>
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MsgKind uint8
|
||||||
|
|
||||||
|
const BlockSize = 16384
|
||||||
|
|
||||||
|
const (
|
||||||
|
Choke MsgKind = iota
|
||||||
|
Unchoke
|
||||||
|
Interested
|
||||||
|
NotInterested
|
||||||
|
Have
|
||||||
|
Bitfield
|
||||||
|
Request
|
||||||
|
Piece
|
||||||
|
Cancel
|
||||||
|
)
|
||||||
|
|
||||||
|
type Msg struct {
|
||||||
|
Kind MsgKind
|
||||||
|
Index uint32
|
||||||
|
Begin uint32
|
||||||
|
Length uint32
|
||||||
|
Bitfield []byte
|
||||||
|
Block []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// [4: len] [1: type] [len-1: payload]
|
||||||
|
func readMsg(r io.Reader) (*Msg, int, error) {
|
||||||
|
var lenBuf [4]byte
|
||||||
|
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
n := binary.BigEndian.Uint32(lenBuf[:])
|
||||||
|
if n == 0 {
|
||||||
|
return nil, 4, nil
|
||||||
|
}
|
||||||
|
buf := make([]byte, n)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
msg := &Msg{Kind: MsgKind(buf[0])}
|
||||||
|
payload := buf[1:]
|
||||||
|
switch msg.Kind {
|
||||||
|
case Have:
|
||||||
|
msg.Index = binary.BigEndian.Uint32(payload)
|
||||||
|
case Bitfield:
|
||||||
|
msg.Bitfield = payload
|
||||||
|
case Request, Cancel:
|
||||||
|
msg.Index = binary.BigEndian.Uint32(payload[0:4])
|
||||||
|
msg.Begin = binary.BigEndian.Uint32(payload[4:8])
|
||||||
|
msg.Length = binary.BigEndian.Uint32(payload[8:12])
|
||||||
|
case Piece:
|
||||||
|
msg.Index = binary.BigEndian.Uint32(payload[0:4])
|
||||||
|
msg.Begin = binary.BigEndian.Uint32(payload[4:8])
|
||||||
|
msg.Block = payload[8:]
|
||||||
|
}
|
||||||
|
return msg, int(4 + n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMsg(w io.Writer, msg *Msg) (int, error) {
|
||||||
|
if msg == nil {
|
||||||
|
return w.Write(make([]byte, 4))
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload []byte
|
||||||
|
switch msg.Kind {
|
||||||
|
case Choke, Unchoke, Interested, NotInterested:
|
||||||
|
case Have:
|
||||||
|
payload = make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(payload, msg.Index)
|
||||||
|
case Bitfield:
|
||||||
|
payload = msg.Bitfield
|
||||||
|
case Request, Cancel:
|
||||||
|
payload = make([]byte, 12)
|
||||||
|
binary.BigEndian.PutUint32(payload[0:4], msg.Index)
|
||||||
|
binary.BigEndian.PutUint32(payload[4:8], msg.Begin)
|
||||||
|
binary.BigEndian.PutUint32(payload[8:12], msg.Length)
|
||||||
|
case Piece:
|
||||||
|
payload = make([]byte, 8+len(msg.Block))
|
||||||
|
binary.BigEndian.PutUint32(payload[0:4], msg.Index)
|
||||||
|
binary.BigEndian.PutUint32(payload[4:8], msg.Begin)
|
||||||
|
copy(payload[8:], msg.Block)
|
||||||
|
}
|
||||||
|
buf := make([]byte, 5+len(payload))
|
||||||
|
binary.BigEndian.PutUint32(buf[0:4], uint32(1+len(payload)))
|
||||||
|
buf[4] = byte(msg.Kind)
|
||||||
|
copy(buf[5:], payload)
|
||||||
|
return w.Write(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetBit(bf []byte, i int) {
|
||||||
|
bf[i/8] |= 1 << (7 - i%8)
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasBit(bf []byte, i int) bool {
|
||||||
|
return bf[i/8]&(1<<(7-i%8)) != 0
|
||||||
|
}
|
||||||
91
bt/msg_test.go
Normal file
91
bt/msg_test.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package bt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMsg(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
msg *Msg
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "choke",
|
||||||
|
msg: &Msg{Kind: Choke},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unchoke",
|
||||||
|
msg: &Msg{Kind: Unchoke},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "interested",
|
||||||
|
msg: &Msg{Kind: Interested},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not interested",
|
||||||
|
msg: &Msg{Kind: NotInterested},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "have",
|
||||||
|
msg: &Msg{
|
||||||
|
Kind: Have,
|
||||||
|
Index: 42,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bitfield",
|
||||||
|
msg: &Msg{
|
||||||
|
Kind: Bitfield,
|
||||||
|
Bitfield: []byte{0b1111_1111, 0b0000_0000},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "request",
|
||||||
|
msg: &Msg{
|
||||||
|
Kind: Request,
|
||||||
|
Index: 1,
|
||||||
|
Begin: 16384,
|
||||||
|
Length: 16384,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "piece",
|
||||||
|
msg: &Msg{
|
||||||
|
Kind: Piece,
|
||||||
|
Index: 1,
|
||||||
|
Begin: 0,
|
||||||
|
Block: []byte("data"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cancel",
|
||||||
|
msg: &Msg{
|
||||||
|
Kind: Cancel,
|
||||||
|
Index: 1,
|
||||||
|
Begin: 0,
|
||||||
|
Length: 16384,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keep-alive",
|
||||||
|
msg: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writeMsg(&buf, tt.msg)
|
||||||
|
data := buf.Bytes()
|
||||||
|
got, _, err := readMsg(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readMsg: %v", err)
|
||||||
|
}
|
||||||
|
var buf2 bytes.Buffer
|
||||||
|
writeMsg(&buf2, got)
|
||||||
|
if !bytes.Equal(buf2.Bytes(), data) {
|
||||||
|
t.Errorf("roundtrip mismatch: got %x, want %x", buf2.Bytes(), data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
160
bt/peer.go
Normal file
160
bt/peer.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package bt
|
||||||
|
|
||||||
|
// https://wiki.theory.org/BitTorrentSpecification#Handshake
|
||||||
|
// handshake: <pstrlen><pstr><reserved><info_hash><peer_id>
|
||||||
|
// pstrlen: string length of <pstr>, as a single raw byte.
|
||||||
|
// pstr: string identifier of the protocol.
|
||||||
|
// reserved: eight (8) reserved bytes.
|
||||||
|
// info_hash: 20-byte SHA1 hash of the info key in the metainfo file.
|
||||||
|
// peer_id: 20-byte string used as a unique ID for the client.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"storrent/utp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
Addr *net.TCPAddr
|
||||||
|
conn io.ReadWriteCloser
|
||||||
|
InfoHash [20]byte
|
||||||
|
PeerID [20]byte
|
||||||
|
Choked bool
|
||||||
|
Interested bool
|
||||||
|
Have []bool
|
||||||
|
Pending int
|
||||||
|
}
|
||||||
|
|
||||||
|
const btProto = "BitTorrent protocol"
|
||||||
|
|
||||||
|
const (
|
||||||
|
pstrlen = 19
|
||||||
|
reservedLen = 8
|
||||||
|
hashLen = 20
|
||||||
|
peerIDLen = 20
|
||||||
|
handshakeLen = 1 + pstrlen + reservedLen + hashLen + peerIDLen // 68
|
||||||
|
handshakeTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var peerID = genPeerID()
|
||||||
|
|
||||||
|
func genPeerID() [20]byte {
|
||||||
|
var id [20]byte
|
||||||
|
copy(id[:], "-SS0001-")
|
||||||
|
if _, err := rand.Read(id[8:]); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialContext(ctx context.Context, addr *net.TCPAddr, h [20]byte, utpSock *utp.Socket) (*Peer, error) {
|
||||||
|
if utpSock == nil {
|
||||||
|
return nil, fmt.Errorf("utp socket required")
|
||||||
|
}
|
||||||
|
udpAddr := &net.UDPAddr{IP: addr.IP, Port: addr.Port}
|
||||||
|
conn, err := utpSock.DialContext(ctx, udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p := &Peer{Addr: addr, conn: conn, InfoHash: h, Choked: true}
|
||||||
|
if err := p.handshake(); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// [1: pstrlen] [pstrlen: btProto] [reservedLen: reserved] [hashLen: info_hash] [peerIDLen: peer_id]
|
||||||
|
func (p *Peer) handshake() error {
|
||||||
|
buf := make([]byte, handshakeLen)
|
||||||
|
buf[0] = pstrlen
|
||||||
|
copy(buf[1:1+pstrlen], btProto)
|
||||||
|
copy(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen], p.InfoHash[:])
|
||||||
|
copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:])
|
||||||
|
if _, err := p.conn.Write(buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(p.conn, buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto {
|
||||||
|
return fmt.Errorf("invalid handshake")
|
||||||
|
}
|
||||||
|
if [hashLen]byte(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen]) != p.InfoHash {
|
||||||
|
return fmt.Errorf("info hash mismatch")
|
||||||
|
}
|
||||||
|
p.PeerID = [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) Send(msg *Msg) error {
|
||||||
|
_, err := writeMsg(p.conn, msg)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) Recv() (*Msg, error) {
|
||||||
|
msg, _, err := readMsg(p.conn)
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) Close() error {
|
||||||
|
return p.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Accept(ln net.Listener) (*Peer, error) {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||||
|
buf := make([]byte, handshakeLen)
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("invalid handshake")
|
||||||
|
}
|
||||||
|
infoHash := [hashLen]byte(buf[1+pstrlen+reservedLen : 1+pstrlen+reservedLen+hashLen])
|
||||||
|
remotePeerID := [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:])
|
||||||
|
copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:])
|
||||||
|
if _, err := conn.Write(buf); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn.SetDeadline(time.Time{})
|
||||||
|
return &Peer{
|
||||||
|
conn: conn,
|
||||||
|
InfoHash: infoHash,
|
||||||
|
PeerID: remotePeerID,
|
||||||
|
Choked: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// piece 0~7: byte 0 [7 6 5 4 3 2 1 0]
|
||||||
|
// piece 8~15: byte 1 [7 6 5 4 3 2 1 0]
|
||||||
|
func (p *Peer) SetPieces(data []byte, n int) {
|
||||||
|
p.Have = make([]bool, n)
|
||||||
|
for i := range n {
|
||||||
|
p.Have[i] = HasBit(data, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) HasPiece(i int) bool {
|
||||||
|
if i >= 0 && i < len(p.Have) {
|
||||||
|
return p.Have[i]
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) SetPiece(i int) {
|
||||||
|
if i >= 0 && i < len(p.Have) {
|
||||||
|
p.Have[i] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
37
bt/peer_test.go
Normal file
37
bt/peer_test.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package bt
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSetPieces(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
n int
|
||||||
|
i int
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "piece 0 set",
|
||||||
|
data: []byte{0b1000_0000},
|
||||||
|
n: 8,
|
||||||
|
i: 0,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "piece 7 unset",
|
||||||
|
data: []byte{0b1111_1110},
|
||||||
|
n: 8,
|
||||||
|
i: 7,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Peer{}
|
||||||
|
c.SetPieces(tt.data, tt.n)
|
||||||
|
if c.Have[tt.i] != tt.want {
|
||||||
|
t.Errorf("Have[%d]: got %v, want %v", tt.i, c.Have[tt.i], tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
139
client/dial.go
Normal file
139
client/dial.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"storrent/bt"
|
||||||
|
"storrent/dht"
|
||||||
|
"storrent/utp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxConcurrentDials = 8
|
||||||
|
|
||||||
|
type dialResult struct {
|
||||||
|
peer *bt.Peer
|
||||||
|
err error
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type dialPool struct {
|
||||||
|
sem chan struct{}
|
||||||
|
results chan dialResult
|
||||||
|
|
||||||
|
dht *dht.DHT
|
||||||
|
utp *utp.Socket
|
||||||
|
infoHash [20]byte
|
||||||
|
|
||||||
|
running bool
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDialPool(d *dht.DHT, utpSock *utp.Socket, infoHash [20]byte) *dialPool {
|
||||||
|
return &dialPool{
|
||||||
|
sem: make(chan struct{}, maxConcurrentDials),
|
||||||
|
results: make(chan dialResult, maxConcurrentDials),
|
||||||
|
dht: d,
|
||||||
|
utp: utpSock,
|
||||||
|
infoHash: infoHash,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dp *dialPool) start(parentCtx context.Context) {
|
||||||
|
if dp.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dp.running = true
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
dp.cancel = cancel
|
||||||
|
go dp.connectLoop(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dp *dialPool) stop() {
|
||||||
|
if !dp.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dp.running = false
|
||||||
|
if dp.cancel != nil {
|
||||||
|
dp.cancel()
|
||||||
|
dp.cancel = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dp *dialPool) connectLoop(ctx context.Context) {
|
||||||
|
dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout)
|
||||||
|
addrs, err := dp.dht.GetPeers(dhtCtx, dp.infoHash)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case dp.results <- dialResult{done: true, err: err}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
loop:
|
||||||
|
for _, addr := range addrs {
|
||||||
|
select {
|
||||||
|
case dp.sem <- struct{}{}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(addr *net.TCPAddr) {
|
||||||
|
defer wg.Done()
|
||||||
|
defer func() { <-dp.sem }()
|
||||||
|
|
||||||
|
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case dp.results <- dialResult{peer: p}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
}(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case dp.results <- dialResult{done: true}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dp *dialPool) dialSingle(ctx context.Context, addr *net.TCPAddr) {
|
||||||
|
select {
|
||||||
|
case dp.sem <- struct{}{}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { <-dp.sem }()
|
||||||
|
|
||||||
|
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case dp.results <- dialResult{peer: p}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
178
client/manager.go
Normal file
178
client/manager.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"storrent/bt"
|
||||||
|
"storrent/dht"
|
||||||
|
"storrent/metainfo"
|
||||||
|
"storrent/utp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNotFound = errors.New("torrent not found")
|
||||||
|
|
||||||
|
func (m *Manager) getTorrent(id int) (*torrent, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
t, ok := m.torrents[id]
|
||||||
|
m.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, errNotFound
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
torrents map[int]*torrent
|
||||||
|
nextID int
|
||||||
|
dir string
|
||||||
|
dht *dht.DHT
|
||||||
|
ln net.Listener
|
||||||
|
utp *utp.Socket
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(dir string, d *dht.DHT, btPort int, utpSock *utp.Socket) (*Manager, error) {
|
||||||
|
var ln net.Listener
|
||||||
|
if btPort > 0 {
|
||||||
|
var err error
|
||||||
|
ln, err = net.Listen("tcp", fmt.Sprintf(":%d", btPort))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Manager{
|
||||||
|
torrents: make(map[int]*torrent),
|
||||||
|
dir: dir,
|
||||||
|
dht: d,
|
||||||
|
ln: ln,
|
||||||
|
utp: utpSock,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Run() {
|
||||||
|
if m.ln == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
p, err := bt.Accept(m.ln)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.dispatch(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) dispatch(p *bt.Peer) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
for _, t := range m.torrents {
|
||||||
|
if t.meta.InfoHash == p.InfoHash {
|
||||||
|
select {
|
||||||
|
case t.incoming <- p:
|
||||||
|
default:
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Add(path string) ([]byte, error) {
|
||||||
|
mi, err := metainfo.Parse(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
id := m.nextID
|
||||||
|
m.nextID++
|
||||||
|
t := newTorrent(id, mi, m.dir, m.dht, m.utp)
|
||||||
|
complete := t.scheduler.isComplete()
|
||||||
|
go t.run()
|
||||||
|
m.torrents[id] = t
|
||||||
|
m.mu.Unlock()
|
||||||
|
if complete {
|
||||||
|
t.startSeed()
|
||||||
|
} else {
|
||||||
|
t.startDownload()
|
||||||
|
}
|
||||||
|
return []byte(strconv.Itoa(id)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Remove(id int) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
t, ok := m.torrents[id]
|
||||||
|
if !ok {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return errNotFound
|
||||||
|
}
|
||||||
|
delete(m.torrents, id)
|
||||||
|
m.mu.Unlock()
|
||||||
|
t.stopTorrent()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(id int) error {
|
||||||
|
t, err := m.getTorrent(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.startDownload()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Stop(id int) error {
|
||||||
|
t, err := m.getTorrent(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.stopTorrent()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Seed(id int) error {
|
||||||
|
t, err := m.getTorrent(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.startSeed()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddPeer(id int, addr string) error {
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t, err := m.getTorrent(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.enqueuePeer(tcpAddr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Status(id int, field string) ([]byte, error) {
|
||||||
|
t, err := m.getTorrent(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []byte(t.getStatus(field)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) List() []byte {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
var data []byte
|
||||||
|
for id := range m.torrents {
|
||||||
|
if len(data) > 0 {
|
||||||
|
data = append(data, '\n')
|
||||||
|
}
|
||||||
|
data = append(data, []byte(strconv.Itoa(id))...)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
50
client/peer.go
Normal file
50
client/peer.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import "storrent/bt"
|
||||||
|
|
||||||
|
type peerManager struct {
|
||||||
|
peers map[*bt.Peer]struct{}
|
||||||
|
maxPeers int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPeerManager(maxPeers int) *peerManager {
|
||||||
|
return &peerManager{
|
||||||
|
peers: make(map[*bt.Peer]struct{}),
|
||||||
|
maxPeers: maxPeers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *peerManager) add(p *bt.Peer) bool {
|
||||||
|
if len(pm.peers) >= pm.maxPeers {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
pm.peers[p] = struct{}{}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *peerManager) addUnlimited(p *bt.Peer) {
|
||||||
|
pm.peers[p] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *peerManager) remove(p *bt.Peer) {
|
||||||
|
delete(pm.peers, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *peerManager) count() int {
|
||||||
|
return len(pm.peers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *peerManager) all() []*bt.Peer {
|
||||||
|
peers := make([]*bt.Peer, 0, len(pm.peers))
|
||||||
|
for p := range pm.peers {
|
||||||
|
peers = append(peers, p)
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *peerManager) closeAll() {
|
||||||
|
for p := range pm.peers {
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
pm.peers = make(map[*bt.Peer]struct{})
|
||||||
|
}
|
||||||
69
client/piece.go
Normal file
69
client/piece.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
|
||||||
|
"storrent/bt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type piece struct {
|
||||||
|
hash [20]byte
|
||||||
|
size int64
|
||||||
|
data []byte
|
||||||
|
have []bool
|
||||||
|
reqs []bool
|
||||||
|
reqPeer []*bt.Peer
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *piece) put(begin int, data []byte) bool {
|
||||||
|
if p.done {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
block := begin / bt.BlockSize
|
||||||
|
if block >= len(p.have) || p.have[block] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
copy(p.data[begin:], data)
|
||||||
|
p.have[block] = true
|
||||||
|
for _, got := range p.have {
|
||||||
|
if !got {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sha1.Sum(p.data) != p.hash {
|
||||||
|
for i := range p.have {
|
||||||
|
p.have[i] = false
|
||||||
|
p.reqs[i] = false
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
p.done = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *piece) next(peer *bt.Peer) (begin, length int64, ok bool) {
|
||||||
|
for i := range p.have {
|
||||||
|
if p.have[i] || p.reqs[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p.reqs[i] = true
|
||||||
|
p.reqPeer[i] = peer
|
||||||
|
begin = int64(i) * bt.BlockSize
|
||||||
|
length = bt.BlockSize
|
||||||
|
if begin+length > p.size {
|
||||||
|
length = p.size - begin
|
||||||
|
}
|
||||||
|
return begin, length, true
|
||||||
|
}
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *piece) resetPeer(peer *bt.Peer) {
|
||||||
|
for i := range p.reqs {
|
||||||
|
if p.reqs[i] && p.reqPeer[i] == peer {
|
||||||
|
p.reqs[i] = false
|
||||||
|
p.reqPeer[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
134
client/scheduler.go
Normal file
134
client/scheduler.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
|
||||||
|
"storrent/bt"
|
||||||
|
"storrent/metainfo"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pieceScheduler struct {
|
||||||
|
pieces []piece
|
||||||
|
incomplete []int
|
||||||
|
verified int
|
||||||
|
meta *metainfo.File
|
||||||
|
dir string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPieceScheduler(m *metainfo.File, dir string) *pieceScheduler {
|
||||||
|
n := len(m.Info.Pieces)
|
||||||
|
pieces := make([]piece, n)
|
||||||
|
incomplete := make([]int, 0, n)
|
||||||
|
verified := 0
|
||||||
|
|
||||||
|
for i := range n {
|
||||||
|
size := m.Info.PieceSize
|
||||||
|
if i == n-1 {
|
||||||
|
size = m.Size - int64(n-1)*m.Info.PieceSize
|
||||||
|
}
|
||||||
|
nblocks := (size + bt.BlockSize - 1) / bt.BlockSize
|
||||||
|
pieces[i] = piece{
|
||||||
|
hash: m.Info.Pieces[i],
|
||||||
|
size: size,
|
||||||
|
data: make([]byte, size),
|
||||||
|
have: make([]bool, nblocks),
|
||||||
|
reqs: make([]bool, nblocks),
|
||||||
|
reqPeer: make([]*bt.Peer, nblocks),
|
||||||
|
}
|
||||||
|
|
||||||
|
data := m.Info.Read(dir, i, pieces[i].size)
|
||||||
|
if data != nil && sha1.Sum(data) == pieces[i].hash {
|
||||||
|
pieces[i].data = data
|
||||||
|
pieces[i].done = true
|
||||||
|
verified++
|
||||||
|
} else {
|
||||||
|
incomplete = append(incomplete, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pieceScheduler{
|
||||||
|
pieces: pieces,
|
||||||
|
incomplete: incomplete,
|
||||||
|
verified: verified,
|
||||||
|
meta: m,
|
||||||
|
dir: dir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) nextRequest(p *bt.Peer) *bt.Msg {
|
||||||
|
for _, idx := range ps.incomplete {
|
||||||
|
pc := &ps.pieces[idx]
|
||||||
|
if pc.done {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !p.HasPiece(idx) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if begin, length, ok := pc.next(p); ok {
|
||||||
|
return &bt.Msg{
|
||||||
|
Kind: bt.Request,
|
||||||
|
Index: uint32(idx),
|
||||||
|
Begin: uint32(begin),
|
||||||
|
Length: uint32(length),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) put(index int, begin int, block []byte) bool {
|
||||||
|
pc := &ps.pieces[index]
|
||||||
|
if pc.put(begin, block) {
|
||||||
|
ps.meta.Info.Write(ps.dir, index, pc.data)
|
||||||
|
ps.verified++
|
||||||
|
ps.removeIncomplete(index)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) removeIncomplete(index int) {
|
||||||
|
for i, idx := range ps.incomplete {
|
||||||
|
if idx == index {
|
||||||
|
ps.incomplete = append(ps.incomplete[:i], ps.incomplete[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) peerDisconnected(p *bt.Peer) {
|
||||||
|
for i := range ps.pieces {
|
||||||
|
ps.pieces[i].resetPeer(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) isComplete() bool {
|
||||||
|
return ps.verified == len(ps.pieces)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) progress() (verified, total int) {
|
||||||
|
return ps.verified, len(ps.pieces)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) pieceData(index int) []byte {
|
||||||
|
pc := &ps.pieces[index]
|
||||||
|
if !pc.done {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return pc.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) count() int {
|
||||||
|
return len(ps.pieces)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *pieceScheduler) bitfield(all bool) []byte {
|
||||||
|
n := len(ps.pieces)
|
||||||
|
bf := make([]byte, (n+7)/8)
|
||||||
|
for i := range ps.pieces {
|
||||||
|
if all || ps.pieces[i].done {
|
||||||
|
bt.SetBit(bf, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bf
|
||||||
|
}
|
||||||
419
client/torrent.go
Normal file
419
client/torrent.go
Normal file
@ -0,0 +1,419 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"storrent/bt"
|
||||||
|
"storrent/dht"
|
||||||
|
"storrent/metainfo"
|
||||||
|
"storrent/utp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type state int
|
||||||
|
|
||||||
|
const (
|
||||||
|
stopped state = iota
|
||||||
|
downloading
|
||||||
|
seeding
|
||||||
|
done
|
||||||
|
errored
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxPeers = 8
|
||||||
|
maxPending = 5
|
||||||
|
dialTimeout = 30 * time.Second
|
||||||
|
retryInterval = 30 * time.Second
|
||||||
|
dhtTimeout = 5 * time.Minute
|
||||||
|
recvBufSize = 128
|
||||||
|
)
|
||||||
|
|
||||||
|
type recv struct {
|
||||||
|
msg *bt.Msg
|
||||||
|
err error
|
||||||
|
p *bt.Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
type statusReq struct {
|
||||||
|
field string
|
||||||
|
resp chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
type torrent struct {
|
||||||
|
id int
|
||||||
|
meta *metainfo.File
|
||||||
|
dir string
|
||||||
|
dht *dht.DHT
|
||||||
|
utp *utp.Socket
|
||||||
|
|
||||||
|
start chan struct{}
|
||||||
|
stop chan struct{}
|
||||||
|
seed chan struct{}
|
||||||
|
incoming chan *bt.Peer
|
||||||
|
addPeer chan *net.TCPAddr
|
||||||
|
status chan statusReq
|
||||||
|
|
||||||
|
peers *peerManager
|
||||||
|
scheduler *pieceScheduler
|
||||||
|
dialPool *dialPool
|
||||||
|
|
||||||
|
st state
|
||||||
|
down int64
|
||||||
|
up int64
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
recvs chan recv
|
||||||
|
retry <-chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTorrent(id int, m *metainfo.File, dir string, d *dht.DHT, utpSock *utp.Socket) *torrent {
|
||||||
|
t := &torrent{
|
||||||
|
id: id,
|
||||||
|
meta: m,
|
||||||
|
dir: dir,
|
||||||
|
dht: d,
|
||||||
|
utp: utpSock,
|
||||||
|
start: make(chan struct{}, 1),
|
||||||
|
stop: make(chan struct{}, 1),
|
||||||
|
seed: make(chan struct{}, 1),
|
||||||
|
incoming: make(chan *bt.Peer),
|
||||||
|
addPeer: make(chan *net.TCPAddr, 8),
|
||||||
|
status: make(chan statusReq),
|
||||||
|
peers: newPeerManager(maxPeers),
|
||||||
|
scheduler: newPieceScheduler(m, dir),
|
||||||
|
dialPool: newDialPool(d, utpSock, m.InfoHash),
|
||||||
|
st: stopped,
|
||||||
|
recvs: make(chan recv, recvBufSize),
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) run() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.start:
|
||||||
|
t.handleStart()
|
||||||
|
|
||||||
|
case <-t.seed:
|
||||||
|
t.handleSeed()
|
||||||
|
|
||||||
|
case <-t.stop:
|
||||||
|
t.handleStop()
|
||||||
|
|
||||||
|
case <-t.retry:
|
||||||
|
t.handleRetry()
|
||||||
|
|
||||||
|
case d := <-t.dialPool.results:
|
||||||
|
t.handleDialResult(d)
|
||||||
|
|
||||||
|
case p := <-t.incoming:
|
||||||
|
t.handleIncoming(p)
|
||||||
|
|
||||||
|
case addr := <-t.addPeer:
|
||||||
|
t.handleAddPeer(addr)
|
||||||
|
|
||||||
|
case r := <-t.recvs:
|
||||||
|
t.handleRecv(r)
|
||||||
|
|
||||||
|
case req := <-t.status:
|
||||||
|
req.resp <- t.handleStatus(req.field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleStart() {
|
||||||
|
if t.st != stopped {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.st = downloading
|
||||||
|
t.ctx, t.cancel = context.WithCancel(context.Background())
|
||||||
|
t.dialPool.start(t.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleSeed() {
|
||||||
|
if t.st != stopped && t.st != done {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !t.scheduler.isComplete() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.st = seeding
|
||||||
|
t.ctx, t.cancel = context.WithCancel(context.Background())
|
||||||
|
t.dialPool.start(t.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleStop() {
|
||||||
|
if !t.isActive() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.cancel()
|
||||||
|
t.dialPool.stop()
|
||||||
|
t.peers.closeAll()
|
||||||
|
t.retry = nil
|
||||||
|
t.st = stopped
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleRetry() {
|
||||||
|
t.retry = nil
|
||||||
|
if t.isActive() {
|
||||||
|
t.dialPool.start(t.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleDialResult(d dialResult) {
|
||||||
|
if d.done {
|
||||||
|
if t.peers.count() < maxPeers && t.isActive() {
|
||||||
|
t.retry = time.After(retryInterval)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.peer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p := d.peer
|
||||||
|
if !t.peers.add(p) {
|
||||||
|
p.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go t.recvLoop(p)
|
||||||
|
if err := t.initPeer(p); err != nil {
|
||||||
|
p.Close()
|
||||||
|
t.peers.remove(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleIncoming(p *bt.Peer) {
|
||||||
|
t.peers.addUnlimited(p)
|
||||||
|
go t.recvLoop(p)
|
||||||
|
|
||||||
|
if err := p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(false)}); err != nil {
|
||||||
|
p.Close()
|
||||||
|
t.peers.remove(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.Send(&bt.Msg{Kind: bt.Unchoke}); err != nil {
|
||||||
|
p.Close()
|
||||||
|
t.peers.remove(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleAddPeer(addr *net.TCPAddr) {
|
||||||
|
if !t.isActive() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.dialPool.dialSingle(t.ctx, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleRecv(r recv) {
|
||||||
|
if r.err != nil {
|
||||||
|
r.p.Close()
|
||||||
|
t.peers.remove(r.p)
|
||||||
|
t.scheduler.peerDisconnected(r.p)
|
||||||
|
if t.peers.count() < maxPeers && t.isActive() {
|
||||||
|
t.dialPool.start(t.ctx)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.msg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p := r.p
|
||||||
|
m := r.msg
|
||||||
|
|
||||||
|
switch m.Kind {
|
||||||
|
case bt.Choke:
|
||||||
|
p.Choked = true
|
||||||
|
p.Pending = 0
|
||||||
|
|
||||||
|
case bt.Unchoke:
|
||||||
|
p.Choked = false
|
||||||
|
|
||||||
|
case bt.Have:
|
||||||
|
p.SetPiece(int(m.Index))
|
||||||
|
|
||||||
|
case bt.Bitfield:
|
||||||
|
p.SetPieces(m.Bitfield, t.scheduler.count())
|
||||||
|
|
||||||
|
case bt.Piece:
|
||||||
|
t.scheduler.put(int(m.Index), int(m.Begin), m.Block)
|
||||||
|
p.Pending--
|
||||||
|
t.down += int64(len(m.Block))
|
||||||
|
|
||||||
|
case bt.Request:
|
||||||
|
t.handleRequest(p, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.sendRequests(p)
|
||||||
|
t.checkCompletion()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleRequest(p *bt.Peer, m *bt.Msg) {
|
||||||
|
data := t.scheduler.pieceData(int(m.Index))
|
||||||
|
if data == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
end := int(m.Begin + m.Length)
|
||||||
|
if end > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.Send(&bt.Msg{
|
||||||
|
Kind: bt.Piece,
|
||||||
|
Index: m.Index,
|
||||||
|
Begin: m.Begin,
|
||||||
|
Block: data[m.Begin:end],
|
||||||
|
}); err != nil {
|
||||||
|
p.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.up += int64(m.Length)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) sendRequests(p *bt.Peer) {
|
||||||
|
if p.Choked {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for p.Pending < maxPending {
|
||||||
|
req := t.scheduler.nextRequest(p)
|
||||||
|
if req == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := p.Send(req); err != nil {
|
||||||
|
p.Close()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
p.Pending++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) checkCompletion() {
|
||||||
|
if t.st == downloading && t.scheduler.isComplete() {
|
||||||
|
t.st = done
|
||||||
|
t.peers.closeAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) initPeer(p *bt.Peer) error {
|
||||||
|
if t.st == seeding {
|
||||||
|
return p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(true)})
|
||||||
|
}
|
||||||
|
return p.Send(&bt.Msg{Kind: bt.Interested})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) recvLoop(p *bt.Peer) {
|
||||||
|
for {
|
||||||
|
msg, err := p.Recv()
|
||||||
|
t.recvs <- recv{msg, err, p}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) isActive() bool {
|
||||||
|
return t.st == downloading || t.st == seeding
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) startDownload() {
|
||||||
|
select {
|
||||||
|
case t.start <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) startSeed() {
|
||||||
|
select {
|
||||||
|
case t.seed <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) stopTorrent() {
|
||||||
|
select {
|
||||||
|
case t.stop <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) enqueuePeer(addr *net.TCPAddr) {
|
||||||
|
t.addPeer <- addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) getStatus(field string) string {
|
||||||
|
ch := make(chan string)
|
||||||
|
t.status <- statusReq{field, ch}
|
||||||
|
return <-ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) handleStatus(field string) string {
|
||||||
|
switch field {
|
||||||
|
case "name":
|
||||||
|
return t.meta.Info.Name
|
||||||
|
case "state":
|
||||||
|
return t.st.String()
|
||||||
|
case "progress":
|
||||||
|
verified, total := t.scheduler.progress()
|
||||||
|
if total == 0 {
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
return strconv.Itoa(verified * 100 / total)
|
||||||
|
case "size":
|
||||||
|
return strconv.FormatInt(t.meta.Size, 10)
|
||||||
|
case "down":
|
||||||
|
return strconv.FormatInt(t.down, 10)
|
||||||
|
case "up":
|
||||||
|
return strconv.FormatInt(t.up, 10)
|
||||||
|
case "pieces":
|
||||||
|
verified, total := t.scheduler.progress()
|
||||||
|
return fmt.Sprintf("%d/%d", verified, total)
|
||||||
|
case "peers":
|
||||||
|
return t.formatPeers()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *torrent) formatPeers() string {
|
||||||
|
var buf strings.Builder
|
||||||
|
npieces := t.scheduler.count()
|
||||||
|
for _, p := range t.peers.all() {
|
||||||
|
st := "unchoked"
|
||||||
|
if p.Choked {
|
||||||
|
st = "choked"
|
||||||
|
}
|
||||||
|
have := 0
|
||||||
|
for _, h := range p.Have {
|
||||||
|
if h {
|
||||||
|
have++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&buf, "%s %s %s have:%d/%d pending:%d\n",
|
||||||
|
p.Addr, p.PeerID[:8], st, have, npieces, p.Pending)
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s state) String() string {
|
||||||
|
switch s {
|
||||||
|
case stopped:
|
||||||
|
return "stopped"
|
||||||
|
case downloading:
|
||||||
|
return "downloading"
|
||||||
|
case seeding:
|
||||||
|
return "seeding"
|
||||||
|
case done:
|
||||||
|
return "done"
|
||||||
|
case errored:
|
||||||
|
return "error"
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
121
client/torrent_test.go
Normal file
121
client/torrent_test.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"storrent/bt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPut(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
in string
|
||||||
|
hash [20]byte
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
in: "hello",
|
||||||
|
hash: sha1.Sum([]byte("hello")),
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "wrong",
|
||||||
|
hash: sha1.Sum([]byte("hello")),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.in, func(t *testing.T) {
|
||||||
|
p := &piece{
|
||||||
|
hash: tt.hash,
|
||||||
|
size: int64(len(tt.in)),
|
||||||
|
data: make([]byte, len(tt.in)),
|
||||||
|
have: make([]bool, 1),
|
||||||
|
reqs: make([]bool, 1),
|
||||||
|
reqPeer: make([]*bt.Peer, 1),
|
||||||
|
}
|
||||||
|
if got := p.put(0, []byte(tt.in)); got != tt.want {
|
||||||
|
t.Errorf("got %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPutMultiple(t *testing.T) {
|
||||||
|
data := make([]byte, bt.BlockSize+5)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = byte(i)
|
||||||
|
}
|
||||||
|
p := &piece{
|
||||||
|
hash: sha1.Sum(data),
|
||||||
|
size: int64(len(data)),
|
||||||
|
data: make([]byte, len(data)),
|
||||||
|
have: make([]bool, 2),
|
||||||
|
reqs: make([]bool, 2),
|
||||||
|
reqPeer: make([]*bt.Peer, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
begin int
|
||||||
|
data []byte
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "first block",
|
||||||
|
begin: 0,
|
||||||
|
data: data[:bt.BlockSize],
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second block",
|
||||||
|
begin: bt.BlockSize,
|
||||||
|
data: data[bt.BlockSize:],
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := p.put(tt.begin, tt.data); got != tt.want {
|
||||||
|
t.Errorf("got %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pieceSize int64
|
||||||
|
have []bool
|
||||||
|
wantLength uint32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "last block partial",
|
||||||
|
pieceSize: bt.BlockSize + 100,
|
||||||
|
have: []bool{true, false},
|
||||||
|
wantLength: 100,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ps := &pieceScheduler{
|
||||||
|
pieces: []piece{{
|
||||||
|
size: tt.pieceSize,
|
||||||
|
have: tt.have,
|
||||||
|
reqs: make([]bool, len(tt.have)),
|
||||||
|
reqPeer: make([]*bt.Peer, len(tt.have)),
|
||||||
|
}},
|
||||||
|
incomplete: []int{0},
|
||||||
|
}
|
||||||
|
c := &bt.Peer{Have: []bool{true}}
|
||||||
|
msg := ps.nextRequest(c)
|
||||||
|
if msg == nil {
|
||||||
|
t.Fatal("expected request")
|
||||||
|
}
|
||||||
|
if msg.Length != tt.wantLength {
|
||||||
|
t.Errorf("got %d, want %d", msg.Length, tt.wantLength)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
228
dht/dht.go
Normal file
228
dht/dht.go
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
package dht
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
queryTimeout = 5 * time.Second
|
||||||
|
alpha = 3
|
||||||
|
maxQueries = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
type DHT struct {
|
||||||
|
conn *net.UDPConn
|
||||||
|
id [20]byte
|
||||||
|
nodes []node
|
||||||
|
pending map[string]chan *resp
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(port int) (*DHT, error) {
|
||||||
|
addr := &net.UDPAddr{Port: port}
|
||||||
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := &DHT{
|
||||||
|
conn: conn,
|
||||||
|
id: genNodeID(),
|
||||||
|
pending: make(map[string]chan *resp),
|
||||||
|
}
|
||||||
|
go d.run()
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DHT) Close() {
|
||||||
|
d.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DHT) run() {
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
for {
|
||||||
|
n, _, err := d.conn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg, err := decodeMsg(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d.mu.Lock()
|
||||||
|
ch, ok := d.pending[msg.T]
|
||||||
|
if ok {
|
||||||
|
delete(d.pending, msg.T)
|
||||||
|
}
|
||||||
|
d.mu.Unlock()
|
||||||
|
if ok {
|
||||||
|
ch <- &msg.R
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func xorDistance(a, b [20]byte) [20]byte {
|
||||||
|
var dist [20]byte
|
||||||
|
for i := range 20 {
|
||||||
|
dist[i] = a[i] ^ b[i]
|
||||||
|
}
|
||||||
|
return dist
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareDist(a, b [20]byte) int {
|
||||||
|
for i := range 20 {
|
||||||
|
if a[i] < b[i] {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if a[i] > b[i] {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, queryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
txid := genTxID()
|
||||||
|
a.ID = d.id
|
||||||
|
m := &msg{
|
||||||
|
T: txid,
|
||||||
|
Y: query,
|
||||||
|
Q: q,
|
||||||
|
A: a,
|
||||||
|
}
|
||||||
|
data, err := m.Encode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reply := make(chan *resp, 1)
|
||||||
|
d.mu.Lock()
|
||||||
|
d.pending[txid] = reply
|
||||||
|
d.mu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
d.mu.Lock()
|
||||||
|
delete(d.pending, txid)
|
||||||
|
d.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if _, err := d.conn.WriteToUDP(data, addr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case r := <-reply:
|
||||||
|
return r, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DHT) Bootstrap(ctx context.Context, addr string) error {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ips, err := net.LookupHost(host)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return fmt.Errorf("no bootstrap nodes")
|
||||||
|
}
|
||||||
|
p, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port: %w", err)
|
||||||
|
}
|
||||||
|
uaddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP(ips[0]),
|
||||||
|
Port: p,
|
||||||
|
}
|
||||||
|
resp, err := d.query(ctx, uaddr, findNode, args{Target: d.id})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
d.nodes = append(d.nodes, resp.Nodes...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortByDist(nodes []node, target [20]byte) {
|
||||||
|
slices.SortFunc(nodes, func(a, b node) int {
|
||||||
|
return compareDist(xorDistance(a.ID, target), xorDistance(b.ID, target))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) {
|
||||||
|
queried := make(map[string]bool)
|
||||||
|
candidates := make([]node, len(d.nodes))
|
||||||
|
copy(candidates, d.nodes)
|
||||||
|
sortByDist(candidates, h)
|
||||||
|
|
||||||
|
results := make(chan *resp)
|
||||||
|
inflight := 0
|
||||||
|
queryCount := 0
|
||||||
|
var peers []*net.TCPAddr
|
||||||
|
|
||||||
|
for queryCount < maxQueries {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
for inflight < alpha && queryCount < maxQueries {
|
||||||
|
n, i := nextCandidate(candidates, queried)
|
||||||
|
if i < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
candidates = slices.Delete(candidates, i, i+1)
|
||||||
|
queried[n.Addr.String()] = true
|
||||||
|
inflight++
|
||||||
|
queryCount++
|
||||||
|
go func(addr *net.UDPAddr) {
|
||||||
|
r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h})
|
||||||
|
results <- r
|
||||||
|
}(n.Addr)
|
||||||
|
}
|
||||||
|
if inflight == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
r := <-results
|
||||||
|
inflight--
|
||||||
|
if r == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, p := range r.Peers {
|
||||||
|
if a := decodePeer(p); a != nil {
|
||||||
|
peers = append(peers, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, c := range r.Nodes {
|
||||||
|
if !queried[c.Addr.String()] {
|
||||||
|
candidates = append(candidates, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sortByDist(candidates, h)
|
||||||
|
}
|
||||||
|
for range inflight {
|
||||||
|
<-results
|
||||||
|
}
|
||||||
|
if len(peers) > 0 {
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no peers found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
|
||||||
|
for i, c := range candidates {
|
||||||
|
if !queried[c.Addr.String()] {
|
||||||
|
return c, i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node{}, -1
|
||||||
|
}
|
||||||
246
dht/msg.go
Normal file
246
dht/msg.go
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
package dht
|
||||||
|
|
||||||
|
// https://www.bittorrent.org/beps/bep_0005.html
|
||||||
|
// Peer: TCP
|
||||||
|
// Node: UDP
|
||||||
|
// BEP 5: DHT Protocol - KRPC over UDP
|
||||||
|
// Message: bencode dict with keys:
|
||||||
|
// t: transaction id (2 bytes)
|
||||||
|
// y: type "q" (query), "r" (response), "e" (error)
|
||||||
|
// q: query name (ping, find_node, get_peers, announce_peer)
|
||||||
|
// a: query args dict
|
||||||
|
// r: response dict
|
||||||
|
// e: error [code, message]
|
||||||
|
// Compact formats:
|
||||||
|
// peer: 6 bytes (4 ip + 2 port)
|
||||||
|
// node: 26 bytes (20 id + 4 ip + 2 port)
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"storrent/bencode"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
query = "q"
|
||||||
|
response = "r"
|
||||||
|
error_ = "e"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ping = "ping"
|
||||||
|
findNode = "find_node"
|
||||||
|
getPeers = "get_peers"
|
||||||
|
announcePeer = "announce_peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type node struct {
|
||||||
|
ID [20]byte
|
||||||
|
Addr *net.UDPAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ID [20]byte
|
||||||
|
Target [20]byte
|
||||||
|
InfoHash [20]byte
|
||||||
|
Port int
|
||||||
|
Token string
|
||||||
|
ImpliedPort int
|
||||||
|
}
|
||||||
|
|
||||||
|
type resp struct {
|
||||||
|
ID [20]byte
|
||||||
|
Nodes []node
|
||||||
|
Peers []string
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
type errResp struct {
|
||||||
|
Code int
|
||||||
|
Msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
type msg struct {
|
||||||
|
T string
|
||||||
|
Y string
|
||||||
|
Q string
|
||||||
|
A args
|
||||||
|
R resp
|
||||||
|
E errResp
|
||||||
|
}
|
||||||
|
|
||||||
|
func genTxID() string {
|
||||||
|
b := make([]byte, 2)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genNodeID() [20]byte {
|
||||||
|
var id [20]byte
|
||||||
|
if _, err := rand.Read(id[:]); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *msg) Encode() ([]byte, error) {
|
||||||
|
d := make(map[string]any)
|
||||||
|
d["t"] = m.T
|
||||||
|
d["y"] = m.Y
|
||||||
|
|
||||||
|
switch m.Y {
|
||||||
|
case query:
|
||||||
|
d["q"] = m.Q
|
||||||
|
a := make(map[string]any)
|
||||||
|
a["id"] = string(m.A.ID[:])
|
||||||
|
switch m.Q {
|
||||||
|
case findNode:
|
||||||
|
a["target"] = string(m.A.Target[:])
|
||||||
|
case getPeers:
|
||||||
|
a["info_hash"] = string(m.A.InfoHash[:])
|
||||||
|
case announcePeer:
|
||||||
|
a["info_hash"] = string(m.A.InfoHash[:])
|
||||||
|
a["port"] = int64(m.A.Port)
|
||||||
|
a["token"] = m.A.Token
|
||||||
|
if m.A.ImpliedPort != 0 {
|
||||||
|
a["implied_port"] = int64(m.A.ImpliedPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d["a"] = a
|
||||||
|
case response:
|
||||||
|
r := make(map[string]any)
|
||||||
|
r["id"] = string(m.R.ID[:])
|
||||||
|
if len(m.R.Nodes) > 0 {
|
||||||
|
r["nodes"] = encodeNodes(m.R.Nodes)
|
||||||
|
}
|
||||||
|
if len(m.R.Peers) > 0 {
|
||||||
|
vals := make([]any, len(m.R.Peers))
|
||||||
|
for i, p := range m.R.Peers {
|
||||||
|
vals[i] = p
|
||||||
|
}
|
||||||
|
r["values"] = vals
|
||||||
|
}
|
||||||
|
if m.R.Token != "" {
|
||||||
|
r["token"] = m.R.Token
|
||||||
|
}
|
||||||
|
d["r"] = r
|
||||||
|
case error_:
|
||||||
|
d["e"] = []any{int64(m.E.Code), m.E.Msg}
|
||||||
|
}
|
||||||
|
return bencode.Encode(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeMsg(data []byte) (*msg, error) {
|
||||||
|
v, _, err := bencode.Decode(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
d, ok := v.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid krpc message")
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &msg{}
|
||||||
|
m.T, _ = d["t"].(string)
|
||||||
|
m.Y, _ = d["y"].(string)
|
||||||
|
switch m.Y {
|
||||||
|
case query:
|
||||||
|
m.Q, _ = d["q"].(string)
|
||||||
|
if a, ok := d["a"].(map[string]any); ok {
|
||||||
|
if id, ok := a["id"].(string); ok && len(id) == 20 {
|
||||||
|
copy(m.A.ID[:], id)
|
||||||
|
}
|
||||||
|
if target, ok := a["target"].(string); ok && len(target) == 20 {
|
||||||
|
copy(m.A.Target[:], target)
|
||||||
|
}
|
||||||
|
if ih, ok := a["info_hash"].(string); ok && len(ih) == 20 {
|
||||||
|
copy(m.A.InfoHash[:], ih)
|
||||||
|
}
|
||||||
|
if port, ok := a["port"].(int64); ok {
|
||||||
|
m.A.Port = int(port)
|
||||||
|
}
|
||||||
|
if token, ok := a["token"].(string); ok {
|
||||||
|
m.A.Token = token
|
||||||
|
}
|
||||||
|
if iport, ok := a["implied_port"].(int64); ok {
|
||||||
|
m.A.ImpliedPort = int(iport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case response:
|
||||||
|
if r, ok := d["r"].(map[string]any); ok {
|
||||||
|
if id, ok := r["id"].(string); ok && len(id) == 20 {
|
||||||
|
copy(m.R.ID[:], id)
|
||||||
|
}
|
||||||
|
if nodes, ok := r["nodes"].(string); ok {
|
||||||
|
m.R.Nodes = decodeNodes(nodes)
|
||||||
|
}
|
||||||
|
if vals, ok := r["values"].([]any); ok {
|
||||||
|
for _, v := range vals {
|
||||||
|
if p, ok := v.(string); ok {
|
||||||
|
m.R.Peers = append(m.R.Peers, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if token, ok := r["token"].(string); ok {
|
||||||
|
m.R.Token = token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case error_:
|
||||||
|
if e, ok := d["e"].([]any); ok && len(e) >= 2 {
|
||||||
|
if code, ok := e[0].(int64); ok {
|
||||||
|
m.E.Code = int(code)
|
||||||
|
}
|
||||||
|
m.E.Msg, _ = e[1].(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 26 = 20 id + 4 ip + 2 port
|
||||||
|
func encodeNodes(nodes []node) string {
|
||||||
|
buf := make([]byte, 26*len(nodes))
|
||||||
|
for i, n := range nodes {
|
||||||
|
off := i * 26
|
||||||
|
copy(buf[off:], n.ID[:])
|
||||||
|
copy(buf[off+20:], n.Addr.IP.To4())
|
||||||
|
binary.BigEndian.PutUint16(buf[off+24:], uint16(n.Addr.Port))
|
||||||
|
}
|
||||||
|
return string(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeNodes(s string) []node {
|
||||||
|
data := []byte(s)
|
||||||
|
if len(data)%26 != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n := len(data) / 26
|
||||||
|
nodes := make([]node, n)
|
||||||
|
for i := range nodes {
|
||||||
|
off := i * 26
|
||||||
|
copy(nodes[i].ID[:], data[off:])
|
||||||
|
ip := net.IP(data[off+20 : off+24])
|
||||||
|
port := binary.BigEndian.Uint16(data[off+24:])
|
||||||
|
nodes[i].Addr = &net.UDPAddr{IP: ip, Port: int(port)}
|
||||||
|
}
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6 = 4 ip + 2 port
|
||||||
|
func decodePeer(s string) *net.TCPAddr {
|
||||||
|
if len(s) != 6 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data := []byte(s)
|
||||||
|
return &net.TCPAddr{
|
||||||
|
IP: net.IP(data[:4]),
|
||||||
|
Port: int(binary.BigEndian.Uint16(data[4:])),
|
||||||
|
}
|
||||||
|
}
|
||||||
67
dht/msg_test.go
Normal file
67
dht/msg_test.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
package dht
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func nid(c string) [20]byte {
|
||||||
|
var id [20]byte
|
||||||
|
copy(id[:], strings.Repeat(c, 20))
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeNodes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
nodes []node
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
nodes: []node{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single",
|
||||||
|
nodes: []node{
|
||||||
|
{
|
||||||
|
ID: nid("a"),
|
||||||
|
Addr: &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 6881,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple",
|
||||||
|
nodes: []node{
|
||||||
|
{
|
||||||
|
ID: nid("a"),
|
||||||
|
Addr: &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 6881,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: nid("a"),
|
||||||
|
Addr: &net.UDPAddr{
|
||||||
|
IP: net.IPv4(192, 168, 1, 1),
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data := encodeNodes(tt.nodes)
|
||||||
|
got := decodeNodes(data)
|
||||||
|
data2 := encodeNodes(got)
|
||||||
|
if data != data2 {
|
||||||
|
t.Errorf("roundtrip mismatch: got %x, want %x", data2, data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
156
fs/fs.go
Normal file
156
fs/fs.go
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
package fs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/knusbaum/go9p/fs"
|
||||||
|
"github.com/knusbaum/go9p/proto"
|
||||||
|
|
||||||
|
"storrent/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dir struct {
|
||||||
|
stat proto.Stat
|
||||||
|
parent fs.Dir
|
||||||
|
man *client.Manager
|
||||||
|
fsys *fs.FS
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(e *client.Manager) *fs.FS {
|
||||||
|
fsys, root := fs.NewFS("storrent", "storrent", 0555)
|
||||||
|
root.AddChild(&fs.WrappedFile{
|
||||||
|
File: fs.NewBaseFile(fsys.NewStat("ctl", "storrent", "storrent", 0222)),
|
||||||
|
WriteF: func(fid uint64, offset uint64, data []byte) (uint32, error) {
|
||||||
|
cmd := strings.TrimSpace(string(data))
|
||||||
|
if strings.HasPrefix(cmd, "add ") {
|
||||||
|
path := strings.TrimSpace(cmd[4:])
|
||||||
|
_, err := e.Add(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return uint32(len(data)), nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("unknown command: %s", cmd)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
root.AddChild(fs.NewDynamicFile(
|
||||||
|
fsys.NewStat("list", "storrent", "storrent", 0444),
|
||||||
|
func() []byte {
|
||||||
|
data := e.List()
|
||||||
|
if len(data) > 0 {
|
||||||
|
return append(data, '\n')
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
dir := &Dir{
|
||||||
|
stat: *fsys.NewStat("torrents", "storrent", "storrent", 0555|proto.DMDIR),
|
||||||
|
man: e,
|
||||||
|
fsys: fsys,
|
||||||
|
}
|
||||||
|
dir.stat.Qid.Qtype = uint8(dir.stat.Mode >> 24)
|
||||||
|
root.AddChild(dir)
|
||||||
|
return fsys
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dir) Stat() proto.Stat {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
return d.stat
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dir) WriteStat(s *proto.Stat) error {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
d.stat = *s
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dir) SetParent(p fs.Dir) {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
d.parent = p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dir) Parent() fs.Dir {
|
||||||
|
d.RLock()
|
||||||
|
defer d.RUnlock()
|
||||||
|
return d.parent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dir) Children() map[string]fs.FSNode {
|
||||||
|
data := d.man.List()
|
||||||
|
m := make(map[string]fs.FSNode)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
for _, name := range strings.Split(string(data), "\n") {
|
||||||
|
id, err := strconv.Atoi(name)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m[name] = d.newTorrentDir(id)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dir) newTorrentDir(id int) fs.FSNode {
|
||||||
|
dir := fs.NewStaticDir(d.fsys.NewStat(strconv.Itoa(id), "storrent", "storrent", 0555))
|
||||||
|
dir.AddChild(&fs.WrappedFile{
|
||||||
|
File: fs.NewBaseFile(d.fsys.NewStat("ctl", "storrent", "storrent", 0222)),
|
||||||
|
WriteF: func(fid uint64, offset uint64, data []byte) (uint32, error) {
|
||||||
|
cmd := strings.TrimSpace(string(data))
|
||||||
|
var err error
|
||||||
|
switch {
|
||||||
|
case cmd == "start":
|
||||||
|
err = d.man.Start(id)
|
||||||
|
case cmd == "stop":
|
||||||
|
err = d.man.Stop(id)
|
||||||
|
case cmd == "seed":
|
||||||
|
err = d.man.Seed(id)
|
||||||
|
case cmd == "remove":
|
||||||
|
err = d.man.Remove(id)
|
||||||
|
case strings.HasPrefix(cmd, "peer "):
|
||||||
|
err = d.man.AddPeer(id, strings.TrimSpace(cmd[5:]))
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unknown command: %s", cmd)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return uint32(len(data)), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
dir.AddChild(d.newStatusFile(id, "name"))
|
||||||
|
dir.AddChild(d.newStatusFile(id, "state"))
|
||||||
|
dir.AddChild(d.newStatusFile(id, "progress"))
|
||||||
|
dir.AddChild(d.newStatusFile(id, "size"))
|
||||||
|
dir.AddChild(d.newStatusFile(id, "down"))
|
||||||
|
dir.AddChild(d.newStatusFile(id, "up"))
|
||||||
|
dir.AddChild(d.newStatusFile(id, "pieces"))
|
||||||
|
dir.AddChild(d.newStatusFile(id, "peers"))
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dir) newStatusFile(id int, field string) fs.FSNode {
|
||||||
|
return fs.NewDynamicFile(
|
||||||
|
d.fsys.NewStat(field, "storrent", "storrent", 0444),
|
||||||
|
func() []byte {
|
||||||
|
data, err := d.man.Status(id, field)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(data) > 0 {
|
||||||
|
return append(data, '\n')
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
12
go.mod
Normal file
12
go.mod
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
module storrent
|
||||||
|
|
||||||
|
go 1.25.4
|
||||||
|
|
||||||
|
require github.com/knusbaum/go9p v1.18.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
9fans.net/go v0.0.2 // indirect
|
||||||
|
github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d // indirect
|
||||||
|
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 // indirect
|
||||||
|
github.com/fhs/mux9p v0.3.1 // indirect
|
||||||
|
)
|
||||||
25
go.sum
Normal file
25
go.sum
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
9fans.net/go v0.0.2 h1:RYM6lWITV8oADrwLfdzxmt8ucfW6UtP9v1jg4qAbqts=
|
||||||
|
9fans.net/go v0.0.2/go.mod h1:lfPdxjq9v8pVQXUMBCx5EO5oLXWQFlKRQgs1kEkjoIM=
|
||||||
|
github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d h1:xH/U6K+HYxh1480TkQYRqRO8F2RJsg+R6wFiVJzdldg=
|
||||||
|
github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d/go.mod h1:UKp8dv9aeaZoQFWin7eQXtz89iHly1YAFZNn3MCutmQ=
|
||||||
|
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ=
|
||||||
|
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
|
||||||
|
github.com/fhs/mux9p v0.3.1 h1:x1UswUWZoA9vrA02jfisndCq3xQm+wrQUxUt5N99E08=
|
||||||
|
github.com/fhs/mux9p v0.3.1/go.mod h1:F4hwdenmit0WDoNVT2VMWlLJrBVCp/8UhzJa7scfjEQ=
|
||||||
|
github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok=
|
||||||
|
github.com/hanwen/go-fuse/v2 v2.0.3/go.mod h1:0EQM6aH2ctVpvZ6a+onrQ/vaykxh2GH7hy3e13vzTUY=
|
||||||
|
github.com/knusbaum/go9p v1.18.0 h1:/Y67RNvNKX1ZV1IOdnO1lIetiF0X+CumOyvEc0011GI=
|
||||||
|
github.com/knusbaum/go9p v1.18.0/go.mod h1:HtMoJKqZUe1Oqag5uJqG5RKQ9gWPSP+wolsnLLv44r8=
|
||||||
|
github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20201020230747-6e5568b54d1a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
49
main.go
Normal file
49
main.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/knusbaum/go9p"
|
||||||
|
|
||||||
|
"storrent/client"
|
||||||
|
"storrent/dht"
|
||||||
|
"storrent/fs"
|
||||||
|
"storrent/utp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
addr := flag.String("addr", ":5640", "9P listen address")
|
||||||
|
dir := flag.String("dir", "./download", "download directory")
|
||||||
|
dhtPort := flag.Int("dht", 6881, "DHT port")
|
||||||
|
btPort := flag.Int("bt", 0, "BT listen port")
|
||||||
|
utpPort := flag.Int("utp", 6882, "uTP port")
|
||||||
|
bootstrap := flag.String("bootstrap", "router.bittorrent.com:6881", "DHT bootstrap node")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
d, err := dht.New(*dhtPort)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := d.Bootstrap(context.Background(), *bootstrap); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
utpSock, err := utp.New(fmt.Sprintf(":%d", *utpPort))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
utpSock.Start()
|
||||||
|
|
||||||
|
m, err := client.NewManager(*dir, d, *btPort, utpSock)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
go m.Run()
|
||||||
|
|
||||||
|
if err := go9p.Serve(*addr, fs.New(m).Server()); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
182
metainfo/metainfo.go
Normal file
182
metainfo/metainfo.go
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
package metainfo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"storrent/bencode"
|
||||||
|
)
|
||||||
|
|
||||||
|
type File struct {
|
||||||
|
Info Info
|
||||||
|
InfoHash [20]byte
|
||||||
|
Size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Info struct {
|
||||||
|
Name string
|
||||||
|
PieceSize int64
|
||||||
|
Pieces [][20]byte
|
||||||
|
Size int64
|
||||||
|
Files []Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
type Entry struct {
|
||||||
|
Size int64
|
||||||
|
Offset int64
|
||||||
|
Path []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Segment struct {
|
||||||
|
File int
|
||||||
|
Offset int64
|
||||||
|
Size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Info) Segments(off, size int64) []Segment {
|
||||||
|
var segs []Segment
|
||||||
|
for idx, f := range i.Files {
|
||||||
|
end := f.Offset + f.Size
|
||||||
|
if off >= end {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n := min(size, end-off)
|
||||||
|
segs = append(segs, Segment{File: idx, Offset: off - f.Offset, Size: n})
|
||||||
|
off += n
|
||||||
|
size -= n
|
||||||
|
if size == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return segs
|
||||||
|
}
|
||||||
|
|
||||||
|
func Parse(path string) (*File, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ParseBytes(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseBytes(data []byte) (*File, error) {
|
||||||
|
v, _, err := bencode.Decode(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d, ok := v.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid torrent dict")
|
||||||
|
}
|
||||||
|
|
||||||
|
f := &File{}
|
||||||
|
raw, err := findInfoBytes(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.InfoHash = sha1.Sum(raw)
|
||||||
|
dict, ok := d["info"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid info dict")
|
||||||
|
}
|
||||||
|
f.Info, err = parseInfo(dict)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if f.Info.Size > 0 {
|
||||||
|
f.Size = f.Info.Size
|
||||||
|
} else {
|
||||||
|
for _, e := range f.Info.Files {
|
||||||
|
f.Size += e.Size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findInfoBytes(data []byte) ([]byte, error) {
|
||||||
|
if len(data) == 0 || data[0] != 'd' {
|
||||||
|
return nil, fmt.Errorf("not a dict")
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 1
|
||||||
|
for i < len(data) && data[i] != 'e' {
|
||||||
|
k, n, err := bencode.DecodeString(data[i:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i += n
|
||||||
|
if k == "info" {
|
||||||
|
_, n, err := bencode.Decode(data[i:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return data[i : i+n], nil
|
||||||
|
}
|
||||||
|
_, n, err = bencode.Decode(data[i:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i += n
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("info not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInfo(d map[string]any) (Info, error) {
|
||||||
|
var info Info
|
||||||
|
name, ok := d["name"].(string)
|
||||||
|
if !ok {
|
||||||
|
return info, fmt.Errorf("invalid name")
|
||||||
|
}
|
||||||
|
info.Name = name
|
||||||
|
pl, ok := d["piece length"].(int64)
|
||||||
|
if !ok {
|
||||||
|
return info, fmt.Errorf("invalid piece length")
|
||||||
|
}
|
||||||
|
info.PieceSize = pl
|
||||||
|
ps, ok := d["pieces"].(string)
|
||||||
|
if !ok || len(ps)%20 != 0 {
|
||||||
|
return info, fmt.Errorf("invalid pieces")
|
||||||
|
}
|
||||||
|
npieces := len(ps) / 20
|
||||||
|
info.Pieces = make([][20]byte, npieces)
|
||||||
|
for i := range npieces {
|
||||||
|
copy(info.Pieces[i][:], ps[i*20:(i+1)*20])
|
||||||
|
}
|
||||||
|
if n, ok := d["length"].(int64); ok {
|
||||||
|
info.Size = n
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fs, ok := d["files"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return info, fmt.Errorf("invalid files")
|
||||||
|
}
|
||||||
|
|
||||||
|
off := int64(0)
|
||||||
|
for _, f := range fs {
|
||||||
|
fd, ok := f.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return info, fmt.Errorf("invalid file entry")
|
||||||
|
}
|
||||||
|
n, ok := fd["length"].(int64)
|
||||||
|
if !ok {
|
||||||
|
return info, fmt.Errorf("invalid file length")
|
||||||
|
}
|
||||||
|
list, ok := fd["path"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return info, fmt.Errorf("invalid file path")
|
||||||
|
}
|
||||||
|
var path []string
|
||||||
|
for _, p := range list {
|
||||||
|
s, ok := p.(string)
|
||||||
|
if !ok {
|
||||||
|
return info, fmt.Errorf("invalid path element")
|
||||||
|
}
|
||||||
|
path = append(path, s)
|
||||||
|
}
|
||||||
|
info.Files = append(info.Files, Entry{Size: n, Offset: off, Path: path})
|
||||||
|
off += n
|
||||||
|
}
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
141
metainfo/metainfo_test.go
Normal file
141
metainfo/metainfo_test.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package metainfo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"storrent/bencode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
info map[string]any
|
||||||
|
wantName string
|
||||||
|
wantSize int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single file",
|
||||||
|
info: map[string]any{
|
||||||
|
"name": "test.txt",
|
||||||
|
"piece length": int64(16384),
|
||||||
|
"pieces": string(make([]byte, 20)),
|
||||||
|
"length": int64(100),
|
||||||
|
},
|
||||||
|
wantName: "test.txt",
|
||||||
|
wantSize: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi file",
|
||||||
|
info: map[string]any{
|
||||||
|
"name": "dir",
|
||||||
|
"piece length": int64(16384),
|
||||||
|
"pieces": string(make([]byte, 20)),
|
||||||
|
"files": []any{
|
||||||
|
map[string]any{"length": int64(100), "path": []any{"a.txt"}},
|
||||||
|
map[string]any{"length": int64(200), "path": []any{"b.txt"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantName: "dir",
|
||||||
|
wantSize: 300,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data, err := bencode.Encode(map[string]any{"info": tt.info})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encode: %v", err)
|
||||||
|
}
|
||||||
|
m, err := ParseBytes(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseBytes: %v", err)
|
||||||
|
}
|
||||||
|
if m.Info.Name != tt.wantName {
|
||||||
|
t.Errorf("Name: got %s, want %s", m.Info.Name, tt.wantName)
|
||||||
|
}
|
||||||
|
if m.Size != tt.wantSize {
|
||||||
|
t.Errorf("Size: got %d, want %d", m.Size, tt.wantSize)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfoHash(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
raw string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic",
|
||||||
|
raw: "d6:lengthi1e4:name4:test12:piece lengthi16384e6:pieces20:01234567890123456789e",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
raw := []byte(tt.raw)
|
||||||
|
torrent := append([]byte("d4:info"), raw...)
|
||||||
|
torrent = append(torrent, 'e')
|
||||||
|
m, err := ParseBytes(torrent)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseBytes: %v", err)
|
||||||
|
}
|
||||||
|
want := sha1.Sum(raw)
|
||||||
|
if m.InfoHash != want {
|
||||||
|
t.Errorf("InfoHash: got %x, want %x", m.InfoHash, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSegments(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
files []Entry
|
||||||
|
off int64
|
||||||
|
size int64
|
||||||
|
want []Segment
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single file",
|
||||||
|
files: []Entry{{Size: 32, Offset: 0}},
|
||||||
|
off: 0,
|
||||||
|
size: 32,
|
||||||
|
want: []Segment{{0, 0, 32}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "spans two files",
|
||||||
|
files: []Entry{{Size: 16, Offset: 0}, {Size: 16, Offset: 16}},
|
||||||
|
off: 0,
|
||||||
|
size: 32,
|
||||||
|
want: []Segment{{0, 0, 16}, {1, 0, 16}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "middle of file",
|
||||||
|
files: []Entry{{Size: 32, Offset: 0}},
|
||||||
|
off: 8,
|
||||||
|
size: 16,
|
||||||
|
want: []Segment{{0, 8, 16}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip first file",
|
||||||
|
files: []Entry{{Size: 16, Offset: 0}, {Size: 16, Offset: 16}},
|
||||||
|
off: 16,
|
||||||
|
size: 16,
|
||||||
|
want: []Segment{{1, 0, 16}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
info := Info{Files: tt.files}
|
||||||
|
got := info.Segments(tt.off, tt.size)
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Fatalf("got %d segments, want %d", len(got), len(tt.want))
|
||||||
|
}
|
||||||
|
for i := range got {
|
||||||
|
if got[i] != tt.want[i] {
|
||||||
|
t.Errorf("segment %d: got %+v, want %+v", i, got[i], tt.want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
68
metainfo/storage.go
Normal file
68
metainfo/storage.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package metainfo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (info *Info) Write(dir string, i int, data []byte) error {
|
||||||
|
off := int64(i) * info.PieceSize
|
||||||
|
if info.Size > 0 {
|
||||||
|
return writeAt(filepath.Join(dir, info.Name), off, data)
|
||||||
|
}
|
||||||
|
for _, seg := range info.Segments(off, int64(len(data))) {
|
||||||
|
f := info.Files[seg.File]
|
||||||
|
path := filepath.Join(dir, info.Name, filepath.Join(f.Path...))
|
||||||
|
if err := writeAt(path, seg.Offset, data[:seg.Size]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data = data[seg.Size:]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *Info) Read(dir string, i int, size int64) []byte {
|
||||||
|
off := int64(i) * info.PieceSize
|
||||||
|
if info.Size > 0 {
|
||||||
|
return readAt(filepath.Join(dir, info.Name), off, size)
|
||||||
|
}
|
||||||
|
data := make([]byte, size)
|
||||||
|
pos := int64(0)
|
||||||
|
for _, seg := range info.Segments(off, size) {
|
||||||
|
f := info.Files[seg.File]
|
||||||
|
path := filepath.Join(dir, info.Name, filepath.Join(f.Path...))
|
||||||
|
chunk := readAt(path, seg.Offset, seg.Size)
|
||||||
|
if chunk == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
copy(data[pos:], chunk)
|
||||||
|
pos += seg.Size
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeAt(path string, off int64, data []byte) error {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
_, err = f.WriteAt(data, off)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func readAt(path string, off, size int64) []byte {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
data := make([]byte, size)
|
||||||
|
if _, err := f.ReadAt(data, off); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
305
utp/conn.go
Normal file
305
utp/conn.go
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
package utp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
initRTO = 1000 // ms
|
||||||
|
retransCheck = 500 // ms
|
||||||
|
defaultWnd = 64 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
sock *Socket
|
||||||
|
addr *net.UDPAddr
|
||||||
|
recvID uint16
|
||||||
|
sendID uint16
|
||||||
|
|
||||||
|
in chan packet
|
||||||
|
reads chan readReq
|
||||||
|
writes chan writeReq
|
||||||
|
closeReq chan struct{}
|
||||||
|
ready chan struct{}
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
seq uint16
|
||||||
|
ack uint16
|
||||||
|
unacked []sent
|
||||||
|
reorder map[uint16][]byte
|
||||||
|
rest []byte
|
||||||
|
pending *readReq
|
||||||
|
rto uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type readReq struct {
|
||||||
|
p []byte
|
||||||
|
reply chan readResp
|
||||||
|
}
|
||||||
|
|
||||||
|
type readResp struct {
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type writeReq struct {
|
||||||
|
data []byte
|
||||||
|
reply chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
type sent struct {
|
||||||
|
seq uint16
|
||||||
|
data []byte
|
||||||
|
at time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) flush() {
|
||||||
|
if c.pending == nil || len(c.rest) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n := copy(c.pending.p, c.rest)
|
||||||
|
c.rest = c.rest[n:]
|
||||||
|
c.pending.reply <- readResp{n, nil}
|
||||||
|
c.pending = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) deliver(payload []byte) {
|
||||||
|
c.rest = append(c.rest, payload...)
|
||||||
|
c.flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) send(typ uint8, payload []byte) error {
|
||||||
|
c.seq++
|
||||||
|
h := &header{
|
||||||
|
typ: typ,
|
||||||
|
ver: 1,
|
||||||
|
connID: c.sendID,
|
||||||
|
timestamp: timestamp(),
|
||||||
|
wnd: defaultWnd,
|
||||||
|
seq: c.seq,
|
||||||
|
ack: c.ack,
|
||||||
|
}
|
||||||
|
buf := make([]byte, headerSize+len(payload))
|
||||||
|
encode(h, buf)
|
||||||
|
copy(buf[headerSize:], payload)
|
||||||
|
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if typ == Data || typ == Syn || typ == Fin {
|
||||||
|
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendState() error {
|
||||||
|
h := &header{
|
||||||
|
typ: State,
|
||||||
|
ver: 1,
|
||||||
|
connID: c.sendID,
|
||||||
|
timestamp: timestamp(),
|
||||||
|
wnd: defaultWnd,
|
||||||
|
seq: c.seq,
|
||||||
|
ack: c.ack,
|
||||||
|
}
|
||||||
|
buf := make([]byte, headerSize)
|
||||||
|
encode(h, buf)
|
||||||
|
_, err := c.sock.conn.WriteToUDP(buf, c.addr)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendSyn() error {
|
||||||
|
c.seq++
|
||||||
|
h := &header{
|
||||||
|
typ: Syn,
|
||||||
|
ver: 1,
|
||||||
|
connID: c.recvID,
|
||||||
|
timestamp: timestamp(),
|
||||||
|
wnd: defaultWnd,
|
||||||
|
seq: c.seq,
|
||||||
|
ack: c.ack,
|
||||||
|
}
|
||||||
|
buf := make([]byte, headerSize)
|
||||||
|
encode(h, buf)
|
||||||
|
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) retrans() error {
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
for i := range c.unacked {
|
||||||
|
p := &c.unacked[i]
|
||||||
|
if now.Sub(p.at) > time.Duration(c.rto)*time.Millisecond {
|
||||||
|
if _, err := c.sock.conn.WriteToUDP(p.data, c.addr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.at = now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) recv(p packet) (bool, error) {
|
||||||
|
switch p.hdr.typ {
|
||||||
|
case Syn:
|
||||||
|
c.ack = p.hdr.seq
|
||||||
|
if err := c.sendState(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
case State:
|
||||||
|
select {
|
||||||
|
case <-c.ready:
|
||||||
|
default:
|
||||||
|
c.ack = p.hdr.seq - 1
|
||||||
|
close(c.ready)
|
||||||
|
}
|
||||||
|
var remaining []sent
|
||||||
|
for _, s := range c.unacked {
|
||||||
|
if seqLess(p.hdr.ack, s.seq) {
|
||||||
|
remaining = append(remaining, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.unacked = remaining
|
||||||
|
case Data:
|
||||||
|
seq, payload := p.hdr.seq, p.payload
|
||||||
|
if seq == c.ack+1 {
|
||||||
|
c.deliver(payload)
|
||||||
|
c.ack++
|
||||||
|
for {
|
||||||
|
if p, ok := c.reorder[c.ack+1]; ok {
|
||||||
|
c.deliver(p)
|
||||||
|
delete(c.reorder, c.ack+1)
|
||||||
|
c.ack++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if seqLess(c.ack+1, seq) {
|
||||||
|
c.reorder[seq] = payload
|
||||||
|
}
|
||||||
|
if err := c.sendState(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
case Fin, Reset:
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) shutdown() {
|
||||||
|
if c.pending != nil {
|
||||||
|
c.pending.reply <- readResp{0, io.EOF}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-c.ready:
|
||||||
|
c.send(Fin, nil)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
c.unacked = nil
|
||||||
|
c.cancel()
|
||||||
|
c.sock.removeConn(c.recvID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) run() {
|
||||||
|
c.reorder = make(map[uint16][]byte)
|
||||||
|
c.rto = initRTO
|
||||||
|
|
||||||
|
if c.recvID < c.sendID {
|
||||||
|
c.seq = 1
|
||||||
|
if err := c.sendSyn(); err != nil {
|
||||||
|
c.shutdown()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(retransCheck * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case p := <-c.in:
|
||||||
|
done, err := c.recv(p)
|
||||||
|
if err != nil {
|
||||||
|
c.shutdown()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
c.shutdown()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case req := <-c.reads:
|
||||||
|
if c.pending != nil {
|
||||||
|
req.reply <- readResp{0, io.ErrNoProgress}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.pending = &req
|
||||||
|
c.flush()
|
||||||
|
case req := <-c.writes:
|
||||||
|
if err := c.send(Data, req.data); err != nil {
|
||||||
|
req.reply <- err
|
||||||
|
c.shutdown()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.reply <- nil
|
||||||
|
case <-c.closeReq:
|
||||||
|
c.shutdown()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
c.retrans()
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Read(p []byte) (int, error) {
|
||||||
|
reply := make(chan readResp, 1)
|
||||||
|
select {
|
||||||
|
case c.reads <- readReq{p, reply}:
|
||||||
|
r := <-reply
|
||||||
|
return r.n, r.err
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Write(p []byte) (int, error) {
|
||||||
|
reply := make(chan error, 1)
|
||||||
|
select {
|
||||||
|
case c.writes <- writeReq{p, reply}:
|
||||||
|
err := <-reply
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
select {
|
||||||
|
case c.closeReq <- struct{}{}:
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func timestamp() uint32 {
|
||||||
|
return uint32(time.Now().UnixMicro() & 0xFFFFFFFF)
|
||||||
|
}
|
||||||
|
|
||||||
|
func seqLess(a, b uint16) bool {
|
||||||
|
return int16(a-b) < 0
|
||||||
|
}
|
||||||
211
utp/socket.go
Normal file
211
utp/socket.go
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
package utp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Data = 0
|
||||||
|
Fin = 1
|
||||||
|
State = 2
|
||||||
|
Reset = 3
|
||||||
|
Syn = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
const headerSize = 20
|
||||||
|
|
||||||
|
type header struct {
|
||||||
|
typ uint8
|
||||||
|
ver uint8
|
||||||
|
ext uint8
|
||||||
|
connID uint16
|
||||||
|
timestamp uint32
|
||||||
|
timeDiff uint32
|
||||||
|
wnd uint32
|
||||||
|
seq uint16
|
||||||
|
ack uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func encode(h *header, buf []byte) {
|
||||||
|
buf[0] = (h.typ << 4) | (h.ver & 0x0F)
|
||||||
|
buf[1] = h.ext
|
||||||
|
binary.BigEndian.PutUint16(buf[2:], h.connID)
|
||||||
|
binary.BigEndian.PutUint32(buf[4:], h.timestamp)
|
||||||
|
binary.BigEndian.PutUint32(buf[8:], h.timeDiff)
|
||||||
|
binary.BigEndian.PutUint32(buf[12:], h.wnd)
|
||||||
|
binary.BigEndian.PutUint16(buf[16:], h.seq)
|
||||||
|
binary.BigEndian.PutUint16(buf[18:], h.ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decode(buf []byte) header {
|
||||||
|
return header{
|
||||||
|
typ: (buf[0] >> 4) & 0x0F,
|
||||||
|
ver: buf[0] & 0x0F,
|
||||||
|
ext: buf[1],
|
||||||
|
connID: binary.BigEndian.Uint16(buf[2:]),
|
||||||
|
timestamp: binary.BigEndian.Uint32(buf[4:]),
|
||||||
|
timeDiff: binary.BigEndian.Uint32(buf[8:]),
|
||||||
|
wnd: binary.BigEndian.Uint32(buf[12:]),
|
||||||
|
seq: binary.BigEndian.Uint16(buf[16:]),
|
||||||
|
ack: binary.BigEndian.Uint16(buf[18:]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Socket struct {
|
||||||
|
conn *net.UDPConn
|
||||||
|
mu sync.RWMutex
|
||||||
|
conns map[uint16]*Conn
|
||||||
|
accepts chan *Conn
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type packet struct {
|
||||||
|
hdr header
|
||||||
|
payload []byte
|
||||||
|
addr *net.UDPAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(addr string) (*Socket, error) {
|
||||||
|
a, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := net.ListenUDP("udp", a)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
s := &Socket{
|
||||||
|
conn: conn,
|
||||||
|
conns: make(map[uint16]*Conn),
|
||||||
|
accepts: make(chan *Conn, 16),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) Start() {
|
||||||
|
go s.reader()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) reader() {
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
for {
|
||||||
|
n, addr, err := s.conn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n < headerSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hdr := decode(buf)
|
||||||
|
payload := make([]byte, n-headerSize)
|
||||||
|
copy(payload, buf[headerSize:n])
|
||||||
|
|
||||||
|
if hdr.typ == Syn {
|
||||||
|
s.handleSyn(hdr, payload, addr)
|
||||||
|
} else {
|
||||||
|
s.dispatch(hdr, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) {
|
||||||
|
c := s.newConn(hdr.connID, addr, false)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.conns[c.recvID] = c
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
go c.run()
|
||||||
|
c.in <- packet{hdr, payload, addr}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.accepts <- c:
|
||||||
|
default:
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) dispatch(hdr header, payload []byte) {
|
||||||
|
s.mu.RLock()
|
||||||
|
c := s.conns[hdr.connID]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
select {
|
||||||
|
case c.in <- packet{hdr, payload, nil}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) removeConn(id uint16) {
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.conns, id)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn {
|
||||||
|
ctx, cancel := context.WithCancel(s.ctx)
|
||||||
|
c := &Conn{
|
||||||
|
sock: s,
|
||||||
|
addr: addr,
|
||||||
|
in: make(chan packet, 16),
|
||||||
|
reads: make(chan readReq),
|
||||||
|
writes: make(chan writeReq),
|
||||||
|
closeReq: make(chan struct{}),
|
||||||
|
ready: make(chan struct{}),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
if initiator {
|
||||||
|
c.recvID = randUint16()
|
||||||
|
c.sendID = c.recvID + 1
|
||||||
|
} else {
|
||||||
|
c.recvID = peerID + 1
|
||||||
|
c.sendID = peerID
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) DialContext(ctx context.Context, addr *net.UDPAddr) (*Conn, error) {
|
||||||
|
c := s.newConn(0, addr, true)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.conns[c.recvID] = c
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
go c.run()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.ready:
|
||||||
|
return c, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.Close()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) Accept() *Conn {
|
||||||
|
return <-s.accepts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) Close() {
|
||||||
|
s.conn.Close()
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func randUint16() uint16 {
|
||||||
|
var b [2]byte
|
||||||
|
if _, err := rand.Read(b[:]); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return binary.BigEndian.Uint16(b[:])
|
||||||
|
}
|
||||||
82
utp/utp_test.go
Normal file
82
utp/utp_test.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package utp
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestEncodeDecode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
h header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "syn",
|
||||||
|
h: header{
|
||||||
|
typ: Syn,
|
||||||
|
ver: 1,
|
||||||
|
ext: 0,
|
||||||
|
connID: 1234,
|
||||||
|
timestamp: 12345678,
|
||||||
|
timeDiff: 1000,
|
||||||
|
wnd: 65535,
|
||||||
|
seq: 1,
|
||||||
|
ack: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "data",
|
||||||
|
h: header{
|
||||||
|
typ: Data,
|
||||||
|
ver: 1,
|
||||||
|
connID: 5678,
|
||||||
|
timestamp: 99999999,
|
||||||
|
timeDiff: 500,
|
||||||
|
wnd: 32768,
|
||||||
|
seq: 100,
|
||||||
|
ack: 99,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "state",
|
||||||
|
h: header{
|
||||||
|
typ: State,
|
||||||
|
ver: 1,
|
||||||
|
connID: 1234,
|
||||||
|
wnd: 65535,
|
||||||
|
seq: 1,
|
||||||
|
ack: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
buf := make([]byte, headerSize)
|
||||||
|
encode(&tt.h, buf)
|
||||||
|
got := decode(buf)
|
||||||
|
if got != tt.h {
|
||||||
|
t.Errorf("got %+v, want %+v", got, tt.h)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSeqLess(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a, b uint16
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"1 < 2", 1, 2, true},
|
||||||
|
{"2 > 1", 2, 1, false},
|
||||||
|
{"wrap 0xFFFF < 0", 0xFFFF, 0, true},
|
||||||
|
{"wrap 0 > 0xFFFF", 0, 0xFFFF, false},
|
||||||
|
{"half 0x7FFF > 0", 0x7FFF, 0, false},
|
||||||
|
{"half 0 < 0x7FFF", 0, 0x7FFF, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := seqLess(tt.a, tt.b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("seqLess(%d, %d) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user