first commit

This commit is contained in:
Hojun-Cho 2026-01-19 21:13:01 +09:00
commit c70d24be5c
28 changed files with 3674 additions and 0 deletions

38
.gitignore vendored Normal file
View File

@ -0,0 +1,38 @@
# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
*.torrent
storrent
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Code coverage profiles and other test artifacts
*.out
coverage.*
*.coverprofile
profile.cov
# Dependency directories (remove the comment below to include it)
# vendor/
# Go workspace file
go.work
go.work.sum
# env file
.env
# Editor/IDE
# .idea/
# .vscode/

64
README.md Normal file
View File

@ -0,0 +1,64 @@
# storrent
A BitTorrent client with a 9P filesystem interface.
## Dependencies
- plan9port
## Features
- DHT for peer discovery
- μTP transport
- 9P control interface
## Build
go build
## Run
./storrent -dht 6881 -utp 6882 -dir ./download -addr :5640 &
9pfuse localhost:5640 /mnt/storrent
## Usage
Interact via files:
echo 'add /path/to/file.torrent' > /mnt/storrent/ctl
cat /mnt/storrent/list
cat /mnt/storrent/torrents/0/progress
Torrent control commands (write to `torrents/<id>/ctl`):
start start downloading
stop stop torrent
seed start seeding
remove remove torrent
peer <addr> add peer manually
Status files in `torrents/<id>/`:
name torrent name
state current state
progress download progress
size total size
down bytes downloaded
up bytes uploaded
pieces piece info
peers connected peers
## TODO
- μTP congestion control
- ReadAt for streaming and seeding while downloading
## Packages
bencode/ bencode encoding/decoding
bt/ BitTorrent protocol messages
client/ torrent management
dht/ DHT implementation
fs/ 9P filesystem
metainfo/ .torrent file parsing
utp/ uTP transport

173
bencode/bencode.go Normal file
View File

@ -0,0 +1,173 @@
package bencode
// bencode types:
// string <len>:<data> "4:spam"
// int i<n>e "i42e"
// list l<items>e "l4:spam4:eggse"
// dict d<kv pairs>e "d3:cow3:mooe"
import (
"fmt"
"sort"
"strconv"
)
func Decode(data []byte) (any, int, error) {
if len(data) == 0 {
return nil, 0, fmt.Errorf("empty data")
}
switch data[0] {
case 'i':
return decodeInt(data)
case 'l':
return decodeList(data)
case 'd':
return decodeDict(data)
default:
if data[0] >= '0' && data[0] <= '9' {
return DecodeString(data)
}
return nil, 0, fmt.Errorf("invalid bencode: %q", data[0])
}
}
func DecodeString(data []byte) (string, int, error) {
i := 0
for i < len(data) && data[i] != ':' {
i++
}
if i >= len(data) {
return "", 0, fmt.Errorf("bencode: missing ':' in string at %d", i)
}
n, err := strconv.Atoi(string(data[:i]))
if err != nil {
return "", 0, err
}
i++
if i+n > len(data) {
return "", 0, fmt.Errorf("bencode: truncated string at %d", i)
}
return string(data[i : i+n]), i + n, nil
}
func decodeInt(data []byte) (int64, int, error) {
if len(data) < 3 {
return 0, 0, fmt.Errorf("bencode: int too short at 0")
}
i := 1
for i < len(data) && data[i] != 'e' {
i++
}
if i >= len(data) {
return 0, 0, fmt.Errorf("bencode: missing 'e' in int at %d", i)
}
val, err := strconv.ParseInt(string(data[1:i]), 10, 64)
if err != nil {
return 0, 0, err
}
return val, i + 1, nil
}
func decodeList(data []byte) ([]any, int, error) {
if len(data) == 0 {
return nil, 0, fmt.Errorf("bencode: empty list at 0")
}
var list []any
i := 1
for i < len(data) && data[i] != 'e' {
v, n, err := Decode(data[i:])
if err != nil {
return nil, 0, err
}
list = append(list, v)
i += n
}
if i >= len(data) {
return nil, 0, fmt.Errorf("bencode: truncated list at %d", i)
}
return list, i + 1, nil
}
func decodeDict(data []byte) (map[string]any, int, error) {
if len(data) == 0 {
return nil, 0, fmt.Errorf("bencode: empty dict at 0")
}
d := make(map[string]any)
i := 1
for i < len(data) && data[i] != 'e' {
k, n, err := DecodeString(data[i:])
if err != nil {
return nil, 0, err
}
i += n
v, n, err := Decode(data[i:])
if err != nil {
return nil, 0, err
}
d[k] = v
i += n
}
if i >= len(data) {
return nil, 0, fmt.Errorf("bencode: truncated dict at %d", i)
}
return d, i + 1, nil
}
func Encode(v any) ([]byte, error) {
switch v := v.(type) {
case string:
return encodeString(v), nil
case []byte:
return encodeString(string(v)), nil
case int:
return encodeInt(int64(v)), nil
case int64:
return encodeInt(v), nil
case []any:
return encodeList(v)
case map[string]any:
return encodeDict(v)
}
return nil, fmt.Errorf("cannot encode %T", v)
}
func encodeString(s string) []byte {
return fmt.Appendf(nil, "%d:%s", len(s), s)
}
func encodeInt(n int64) []byte {
return fmt.Appendf(nil, "i%de", n)
}
func encodeList(list []any) ([]byte, error) {
buf := []byte{'l'}
for _, v := range list {
enc, err := Encode(v)
if err != nil {
return nil, err
}
buf = append(buf, enc...)
}
return append(buf, 'e'), nil
}
func encodeDict(d map[string]any) ([]byte, error) {
keys := make([]string, 0, len(d))
for k := range d {
keys = append(keys, k)
}
sort.Strings(keys)
buf := []byte{'d'}
for _, k := range keys {
buf = append(buf, encodeString(k)...)
enc, err := Encode(d[k])
if err != nil {
return nil, err
}
buf = append(buf, enc...)
}
return append(buf, 'e'), nil
}

113
bencode/bencode_test.go Normal file
View File

@ -0,0 +1,113 @@
package bencode
import "testing"
func TestDecode(t *testing.T) {
tests := []struct {
in string
want any
err bool
}{
{
in: "4:spam",
want: "spam",
},
{
in: "0:",
want: "",
},
{
in: "i42e",
want: int64(42),
},
{
in: "i-42e",
want: int64(-42),
},
{
in: "i0e",
want: int64(0),
},
{
in: "",
err: true,
},
}
for _, tt := range tests {
t.Run(tt.in, func(t *testing.T) {
got, _, err := Decode([]byte(tt.in))
if tt.err {
if err == nil {
t.Error("expected error")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestEncode(t *testing.T) {
tests := []struct {
name string
in any
want string
}{
{
name: "string",
in: "spam",
want: "4:spam",
},
{
name: "empty string",
in: "",
want: "0:",
},
{
name: "int",
in: int64(42),
want: "i42e",
},
{
name: "negative int",
in: int64(-42),
want: "i-42e",
},
{
name: "list",
in: []any{"a", "b"},
want: "l1:a1:be",
},
{
name: "empty list",
in: []any{},
want: "le",
},
{
name: "dict",
in: map[string]any{"b": int64(2), "a": int64(1)},
want: "d1:ai1e1:bi2ee",
},
{
name: "empty dict",
in: map[string]any{},
want: "de",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Encode(tt.in)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != tt.want {
t.Errorf("got %s, want %s", got, tt.want)
}
})
}
}

116
bt/msg.go Normal file
View File

@ -0,0 +1,116 @@
package bt
// https://wiki.theory.org/BitTorrentSpecification#Messages
// Messages: <length prefix><message ID><payload>
// keep-alive: <len=0000>
// choke: <len=0001><type=0>
// unchoke: <len=0001><type=1>
// interested: <len=0001><type=2>
// not interested: <len=0001><type=3>
// have: <len=0005><type=4><piece index>
// bitfield: <len=0001+X><type=5><bitfield>
// request: <len=0013><type=6><index><begin><length>
// piece: <len=0009+X><type=7><index><begin><block>
// cancel: <len=0013><type=8><index><begin><length>
import (
"encoding/binary"
"io"
)
type MsgKind uint8
const BlockSize = 16384
const (
Choke MsgKind = iota
Unchoke
Interested
NotInterested
Have
Bitfield
Request
Piece
Cancel
)
type Msg struct {
Kind MsgKind
Index uint32
Begin uint32
Length uint32
Bitfield []byte
Block []byte
}
// [4: len] [1: type] [len-1: payload]
func readMsg(r io.Reader) (*Msg, int, error) {
var lenBuf [4]byte
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
return nil, 0, err
}
n := binary.BigEndian.Uint32(lenBuf[:])
if n == 0 {
return nil, 4, nil
}
buf := make([]byte, n)
if _, err := io.ReadFull(r, buf); err != nil {
return nil, 0, err
}
msg := &Msg{Kind: MsgKind(buf[0])}
payload := buf[1:]
switch msg.Kind {
case Have:
msg.Index = binary.BigEndian.Uint32(payload)
case Bitfield:
msg.Bitfield = payload
case Request, Cancel:
msg.Index = binary.BigEndian.Uint32(payload[0:4])
msg.Begin = binary.BigEndian.Uint32(payload[4:8])
msg.Length = binary.BigEndian.Uint32(payload[8:12])
case Piece:
msg.Index = binary.BigEndian.Uint32(payload[0:4])
msg.Begin = binary.BigEndian.Uint32(payload[4:8])
msg.Block = payload[8:]
}
return msg, int(4 + n), nil
}
func writeMsg(w io.Writer, msg *Msg) (int, error) {
if msg == nil {
return w.Write(make([]byte, 4))
}
var payload []byte
switch msg.Kind {
case Choke, Unchoke, Interested, NotInterested:
case Have:
payload = make([]byte, 4)
binary.BigEndian.PutUint32(payload, msg.Index)
case Bitfield:
payload = msg.Bitfield
case Request, Cancel:
payload = make([]byte, 12)
binary.BigEndian.PutUint32(payload[0:4], msg.Index)
binary.BigEndian.PutUint32(payload[4:8], msg.Begin)
binary.BigEndian.PutUint32(payload[8:12], msg.Length)
case Piece:
payload = make([]byte, 8+len(msg.Block))
binary.BigEndian.PutUint32(payload[0:4], msg.Index)
binary.BigEndian.PutUint32(payload[4:8], msg.Begin)
copy(payload[8:], msg.Block)
}
buf := make([]byte, 5+len(payload))
binary.BigEndian.PutUint32(buf[0:4], uint32(1+len(payload)))
buf[4] = byte(msg.Kind)
copy(buf[5:], payload)
return w.Write(buf)
}
func SetBit(bf []byte, i int) {
bf[i/8] |= 1 << (7 - i%8)
}
func HasBit(bf []byte, i int) bool {
return bf[i/8]&(1<<(7-i%8)) != 0
}

91
bt/msg_test.go Normal file
View File

@ -0,0 +1,91 @@
package bt
import (
"bytes"
"testing"
)
func TestMsg(t *testing.T) {
tests := []struct {
name string
msg *Msg
}{
{
name: "choke",
msg: &Msg{Kind: Choke},
},
{
name: "unchoke",
msg: &Msg{Kind: Unchoke},
},
{
name: "interested",
msg: &Msg{Kind: Interested},
},
{
name: "not interested",
msg: &Msg{Kind: NotInterested},
},
{
name: "have",
msg: &Msg{
Kind: Have,
Index: 42,
},
},
{
name: "bitfield",
msg: &Msg{
Kind: Bitfield,
Bitfield: []byte{0b1111_1111, 0b0000_0000},
},
},
{
name: "request",
msg: &Msg{
Kind: Request,
Index: 1,
Begin: 16384,
Length: 16384,
},
},
{
name: "piece",
msg: &Msg{
Kind: Piece,
Index: 1,
Begin: 0,
Block: []byte("data"),
},
},
{
name: "cancel",
msg: &Msg{
Kind: Cancel,
Index: 1,
Begin: 0,
Length: 16384,
},
},
{
name: "keep-alive",
msg: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
writeMsg(&buf, tt.msg)
data := buf.Bytes()
got, _, err := readMsg(bytes.NewReader(data))
if err != nil {
t.Fatalf("readMsg: %v", err)
}
var buf2 bytes.Buffer
writeMsg(&buf2, got)
if !bytes.Equal(buf2.Bytes(), data) {
t.Errorf("roundtrip mismatch: got %x, want %x", buf2.Bytes(), data)
}
})
}
}

160
bt/peer.go Normal file
View File

@ -0,0 +1,160 @@
package bt
// https://wiki.theory.org/BitTorrentSpecification#Handshake
// handshake: <pstrlen><pstr><reserved><info_hash><peer_id>
// pstrlen: string length of <pstr>, as a single raw byte.
// pstr: string identifier of the protocol.
// reserved: eight (8) reserved bytes.
// info_hash: 20-byte SHA1 hash of the info key in the metainfo file.
// peer_id: 20-byte string used as a unique ID for the client.
import (
"context"
"crypto/rand"
"fmt"
"io"
"net"
"time"
"storrent/utp"
)
type Peer struct {
Addr *net.TCPAddr
conn io.ReadWriteCloser
InfoHash [20]byte
PeerID [20]byte
Choked bool
Interested bool
Have []bool
Pending int
}
const btProto = "BitTorrent protocol"
const (
pstrlen = 19
reservedLen = 8
hashLen = 20
peerIDLen = 20
handshakeLen = 1 + pstrlen + reservedLen + hashLen + peerIDLen // 68
handshakeTimeout = 10 * time.Second
)
var peerID = genPeerID()
func genPeerID() [20]byte {
var id [20]byte
copy(id[:], "-SS0001-")
if _, err := rand.Read(id[8:]); err != nil {
panic(err)
}
return id
}
func DialContext(ctx context.Context, addr *net.TCPAddr, h [20]byte, utpSock *utp.Socket) (*Peer, error) {
if utpSock == nil {
return nil, fmt.Errorf("utp socket required")
}
udpAddr := &net.UDPAddr{IP: addr.IP, Port: addr.Port}
conn, err := utpSock.DialContext(ctx, udpAddr)
if err != nil {
return nil, err
}
p := &Peer{Addr: addr, conn: conn, InfoHash: h, Choked: true}
if err := p.handshake(); err != nil {
conn.Close()
return nil, err
}
return p, nil
}
// [1: pstrlen] [pstrlen: btProto] [reservedLen: reserved] [hashLen: info_hash] [peerIDLen: peer_id]
func (p *Peer) handshake() error {
buf := make([]byte, handshakeLen)
buf[0] = pstrlen
copy(buf[1:1+pstrlen], btProto)
copy(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen], p.InfoHash[:])
copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:])
if _, err := p.conn.Write(buf); err != nil {
return err
}
if _, err := io.ReadFull(p.conn, buf); err != nil {
return err
}
if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto {
return fmt.Errorf("invalid handshake")
}
if [hashLen]byte(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen]) != p.InfoHash {
return fmt.Errorf("info hash mismatch")
}
p.PeerID = [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:])
return nil
}
func (p *Peer) Send(msg *Msg) error {
_, err := writeMsg(p.conn, msg)
return err
}
func (p *Peer) Recv() (*Msg, error) {
msg, _, err := readMsg(p.conn)
return msg, err
}
func (p *Peer) Close() error {
return p.conn.Close()
}
func Accept(ln net.Listener) (*Peer, error) {
conn, err := ln.Accept()
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(handshakeTimeout))
buf := make([]byte, handshakeLen)
if _, err := io.ReadFull(conn, buf); err != nil {
conn.Close()
return nil, err
}
if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto {
conn.Close()
return nil, fmt.Errorf("invalid handshake")
}
infoHash := [hashLen]byte(buf[1+pstrlen+reservedLen : 1+pstrlen+reservedLen+hashLen])
remotePeerID := [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:])
copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:])
if _, err := conn.Write(buf); err != nil {
conn.Close()
return nil, err
}
conn.SetDeadline(time.Time{})
return &Peer{
conn: conn,
InfoHash: infoHash,
PeerID: remotePeerID,
Choked: true,
}, nil
}
// piece 0~7: byte 0 [7 6 5 4 3 2 1 0]
// piece 8~15: byte 1 [7 6 5 4 3 2 1 0]
func (p *Peer) SetPieces(data []byte, n int) {
p.Have = make([]bool, n)
for i := range n {
p.Have[i] = HasBit(data, i)
}
}
func (p *Peer) HasPiece(i int) bool {
if i >= 0 && i < len(p.Have) {
return p.Have[i]
}
return false
}
func (p *Peer) SetPiece(i int) {
if i >= 0 && i < len(p.Have) {
p.Have[i] = true
}
}

37
bt/peer_test.go Normal file
View File

@ -0,0 +1,37 @@
package bt
import "testing"
func TestSetPieces(t *testing.T) {
tests := []struct {
name string
data []byte
n int
i int
want bool
}{
{
name: "piece 0 set",
data: []byte{0b1000_0000},
n: 8,
i: 0,
want: true,
},
{
name: "piece 7 unset",
data: []byte{0b1111_1110},
n: 8,
i: 7,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Peer{}
c.SetPieces(tt.data, tt.n)
if c.Have[tt.i] != tt.want {
t.Errorf("Have[%d]: got %v, want %v", tt.i, c.Have[tt.i], tt.want)
}
})
}
}

139
client/dial.go Normal file
View File

@ -0,0 +1,139 @@
package client
import (
"context"
"net"
"sync"
"storrent/bt"
"storrent/dht"
"storrent/utp"
)
const maxConcurrentDials = 8
type dialResult struct {
peer *bt.Peer
err error
done bool
}
type dialPool struct {
sem chan struct{}
results chan dialResult
dht *dht.DHT
utp *utp.Socket
infoHash [20]byte
running bool
cancel context.CancelFunc
}
func newDialPool(d *dht.DHT, utpSock *utp.Socket, infoHash [20]byte) *dialPool {
return &dialPool{
sem: make(chan struct{}, maxConcurrentDials),
results: make(chan dialResult, maxConcurrentDials),
dht: d,
utp: utpSock,
infoHash: infoHash,
}
}
func (dp *dialPool) start(parentCtx context.Context) {
if dp.running {
return
}
dp.running = true
ctx, cancel := context.WithCancel(parentCtx)
dp.cancel = cancel
go dp.connectLoop(ctx)
}
func (dp *dialPool) stop() {
if !dp.running {
return
}
dp.running = false
if dp.cancel != nil {
dp.cancel()
dp.cancel = nil
}
}
func (dp *dialPool) connectLoop(ctx context.Context) {
dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout)
addrs, err := dp.dht.GetPeers(dhtCtx, dp.infoHash)
cancel()
if err != nil {
select {
case dp.results <- dialResult{done: true, err: err}:
case <-ctx.Done():
}
return
}
var wg sync.WaitGroup
loop:
for _, addr := range addrs {
select {
case dp.sem <- struct{}{}:
case <-ctx.Done():
break loop
}
wg.Add(1)
go func(addr *net.TCPAddr) {
defer wg.Done()
defer func() { <-dp.sem }()
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
if err != nil {
return
}
select {
case dp.results <- dialResult{peer: p}:
case <-ctx.Done():
p.Close()
}
}(addr)
}
wg.Wait()
select {
case dp.results <- dialResult{done: true}:
case <-ctx.Done():
}
}
func (dp *dialPool) dialSingle(ctx context.Context, addr *net.TCPAddr) {
select {
case dp.sem <- struct{}{}:
case <-ctx.Done():
return
}
go func() {
defer func() { <-dp.sem }()
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
if err != nil {
return
}
select {
case dp.results <- dialResult{peer: p}:
case <-ctx.Done():
p.Close()
}
}()
}

178
client/manager.go Normal file
View File

@ -0,0 +1,178 @@
package client
import (
"errors"
"fmt"
"net"
"strconv"
"sync"
"storrent/bt"
"storrent/dht"
"storrent/metainfo"
"storrent/utp"
)
var errNotFound = errors.New("torrent not found")
func (m *Manager) getTorrent(id int) (*torrent, error) {
m.mu.Lock()
t, ok := m.torrents[id]
m.mu.Unlock()
if !ok {
return nil, errNotFound
}
return t, nil
}
type Manager struct {
mu sync.Mutex
torrents map[int]*torrent
nextID int
dir string
dht *dht.DHT
ln net.Listener
utp *utp.Socket
}
func NewManager(dir string, d *dht.DHT, btPort int, utpSock *utp.Socket) (*Manager, error) {
var ln net.Listener
if btPort > 0 {
var err error
ln, err = net.Listen("tcp", fmt.Sprintf(":%d", btPort))
if err != nil {
return nil, err
}
}
return &Manager{
torrents: make(map[int]*torrent),
dir: dir,
dht: d,
ln: ln,
utp: utpSock,
}, nil
}
func (m *Manager) Run() {
if m.ln == nil {
return
}
for {
p, err := bt.Accept(m.ln)
if err != nil {
return
}
m.dispatch(p)
}
}
func (m *Manager) dispatch(p *bt.Peer) {
m.mu.Lock()
defer m.mu.Unlock()
for _, t := range m.torrents {
if t.meta.InfoHash == p.InfoHash {
select {
case t.incoming <- p:
default:
p.Close()
}
return
}
}
p.Close()
}
func (m *Manager) Add(path string) ([]byte, error) {
mi, err := metainfo.Parse(path)
if err != nil {
return nil, err
}
m.mu.Lock()
id := m.nextID
m.nextID++
t := newTorrent(id, mi, m.dir, m.dht, m.utp)
complete := t.scheduler.isComplete()
go t.run()
m.torrents[id] = t
m.mu.Unlock()
if complete {
t.startSeed()
} else {
t.startDownload()
}
return []byte(strconv.Itoa(id)), nil
}
func (m *Manager) Remove(id int) error {
m.mu.Lock()
t, ok := m.torrents[id]
if !ok {
m.mu.Unlock()
return errNotFound
}
delete(m.torrents, id)
m.mu.Unlock()
t.stopTorrent()
return nil
}
func (m *Manager) Start(id int) error {
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.startDownload()
return nil
}
func (m *Manager) Stop(id int) error {
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.stopTorrent()
return nil
}
func (m *Manager) Seed(id int) error {
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.startSeed()
return nil
}
func (m *Manager) AddPeer(id int, addr string) error {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return err
}
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.enqueuePeer(tcpAddr)
return nil
}
func (m *Manager) Status(id int, field string) ([]byte, error) {
t, err := m.getTorrent(id)
if err != nil {
return nil, err
}
return []byte(t.getStatus(field)), nil
}
func (m *Manager) List() []byte {
m.mu.Lock()
defer m.mu.Unlock()
var data []byte
for id := range m.torrents {
if len(data) > 0 {
data = append(data, '\n')
}
data = append(data, []byte(strconv.Itoa(id))...)
}
return data
}

50
client/peer.go Normal file
View File

@ -0,0 +1,50 @@
package client
import "storrent/bt"
type peerManager struct {
peers map[*bt.Peer]struct{}
maxPeers int
}
func newPeerManager(maxPeers int) *peerManager {
return &peerManager{
peers: make(map[*bt.Peer]struct{}),
maxPeers: maxPeers,
}
}
func (pm *peerManager) add(p *bt.Peer) bool {
if len(pm.peers) >= pm.maxPeers {
return false
}
pm.peers[p] = struct{}{}
return true
}
func (pm *peerManager) addUnlimited(p *bt.Peer) {
pm.peers[p] = struct{}{}
}
func (pm *peerManager) remove(p *bt.Peer) {
delete(pm.peers, p)
}
func (pm *peerManager) count() int {
return len(pm.peers)
}
func (pm *peerManager) all() []*bt.Peer {
peers := make([]*bt.Peer, 0, len(pm.peers))
for p := range pm.peers {
peers = append(peers, p)
}
return peers
}
func (pm *peerManager) closeAll() {
for p := range pm.peers {
p.Close()
}
pm.peers = make(map[*bt.Peer]struct{})
}

69
client/piece.go Normal file
View File

@ -0,0 +1,69 @@
package client
import (
"crypto/sha1"
"storrent/bt"
)
type piece struct {
hash [20]byte
size int64
data []byte
have []bool
reqs []bool
reqPeer []*bt.Peer
done bool
}
func (p *piece) put(begin int, data []byte) bool {
if p.done {
return false
}
block := begin / bt.BlockSize
if block >= len(p.have) || p.have[block] {
return false
}
copy(p.data[begin:], data)
p.have[block] = true
for _, got := range p.have {
if !got {
return false
}
}
if sha1.Sum(p.data) != p.hash {
for i := range p.have {
p.have[i] = false
p.reqs[i] = false
}
return false
}
p.done = true
return true
}
func (p *piece) next(peer *bt.Peer) (begin, length int64, ok bool) {
for i := range p.have {
if p.have[i] || p.reqs[i] {
continue
}
p.reqs[i] = true
p.reqPeer[i] = peer
begin = int64(i) * bt.BlockSize
length = bt.BlockSize
if begin+length > p.size {
length = p.size - begin
}
return begin, length, true
}
return 0, 0, false
}
func (p *piece) resetPeer(peer *bt.Peer) {
for i := range p.reqs {
if p.reqs[i] && p.reqPeer[i] == peer {
p.reqs[i] = false
p.reqPeer[i] = nil
}
}
}

134
client/scheduler.go Normal file
View File

@ -0,0 +1,134 @@
package client
import (
"crypto/sha1"
"storrent/bt"
"storrent/metainfo"
)
type pieceScheduler struct {
pieces []piece
incomplete []int
verified int
meta *metainfo.File
dir string
}
func newPieceScheduler(m *metainfo.File, dir string) *pieceScheduler {
n := len(m.Info.Pieces)
pieces := make([]piece, n)
incomplete := make([]int, 0, n)
verified := 0
for i := range n {
size := m.Info.PieceSize
if i == n-1 {
size = m.Size - int64(n-1)*m.Info.PieceSize
}
nblocks := (size + bt.BlockSize - 1) / bt.BlockSize
pieces[i] = piece{
hash: m.Info.Pieces[i],
size: size,
data: make([]byte, size),
have: make([]bool, nblocks),
reqs: make([]bool, nblocks),
reqPeer: make([]*bt.Peer, nblocks),
}
data := m.Info.Read(dir, i, pieces[i].size)
if data != nil && sha1.Sum(data) == pieces[i].hash {
pieces[i].data = data
pieces[i].done = true
verified++
} else {
incomplete = append(incomplete, i)
}
}
return &pieceScheduler{
pieces: pieces,
incomplete: incomplete,
verified: verified,
meta: m,
dir: dir,
}
}
func (ps *pieceScheduler) nextRequest(p *bt.Peer) *bt.Msg {
for _, idx := range ps.incomplete {
pc := &ps.pieces[idx]
if pc.done {
continue
}
if !p.HasPiece(idx) {
continue
}
if begin, length, ok := pc.next(p); ok {
return &bt.Msg{
Kind: bt.Request,
Index: uint32(idx),
Begin: uint32(begin),
Length: uint32(length),
}
}
}
return nil
}
func (ps *pieceScheduler) put(index int, begin int, block []byte) bool {
pc := &ps.pieces[index]
if pc.put(begin, block) {
ps.meta.Info.Write(ps.dir, index, pc.data)
ps.verified++
ps.removeIncomplete(index)
return true
}
return false
}
func (ps *pieceScheduler) removeIncomplete(index int) {
for i, idx := range ps.incomplete {
if idx == index {
ps.incomplete = append(ps.incomplete[:i], ps.incomplete[i+1:]...)
return
}
}
}
func (ps *pieceScheduler) peerDisconnected(p *bt.Peer) {
for i := range ps.pieces {
ps.pieces[i].resetPeer(p)
}
}
func (ps *pieceScheduler) isComplete() bool {
return ps.verified == len(ps.pieces)
}
func (ps *pieceScheduler) progress() (verified, total int) {
return ps.verified, len(ps.pieces)
}
func (ps *pieceScheduler) pieceData(index int) []byte {
pc := &ps.pieces[index]
if !pc.done {
return nil
}
return pc.data
}
func (ps *pieceScheduler) count() int {
return len(ps.pieces)
}
func (ps *pieceScheduler) bitfield(all bool) []byte {
n := len(ps.pieces)
bf := make([]byte, (n+7)/8)
for i := range ps.pieces {
if all || ps.pieces[i].done {
bt.SetBit(bf, i)
}
}
return bf
}

419
client/torrent.go Normal file
View File

@ -0,0 +1,419 @@
package client
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"time"
"storrent/bt"
"storrent/dht"
"storrent/metainfo"
"storrent/utp"
)
type state int
const (
stopped state = iota
downloading
seeding
done
errored
)
const (
maxPeers = 8
maxPending = 5
dialTimeout = 30 * time.Second
retryInterval = 30 * time.Second
dhtTimeout = 5 * time.Minute
recvBufSize = 128
)
type recv struct {
msg *bt.Msg
err error
p *bt.Peer
}
type statusReq struct {
field string
resp chan string
}
type torrent struct {
id int
meta *metainfo.File
dir string
dht *dht.DHT
utp *utp.Socket
start chan struct{}
stop chan struct{}
seed chan struct{}
incoming chan *bt.Peer
addPeer chan *net.TCPAddr
status chan statusReq
peers *peerManager
scheduler *pieceScheduler
dialPool *dialPool
st state
down int64
up int64
ctx context.Context
cancel context.CancelFunc
recvs chan recv
retry <-chan time.Time
}
func newTorrent(id int, m *metainfo.File, dir string, d *dht.DHT, utpSock *utp.Socket) *torrent {
t := &torrent{
id: id,
meta: m,
dir: dir,
dht: d,
utp: utpSock,
start: make(chan struct{}, 1),
stop: make(chan struct{}, 1),
seed: make(chan struct{}, 1),
incoming: make(chan *bt.Peer),
addPeer: make(chan *net.TCPAddr, 8),
status: make(chan statusReq),
peers: newPeerManager(maxPeers),
scheduler: newPieceScheduler(m, dir),
dialPool: newDialPool(d, utpSock, m.InfoHash),
st: stopped,
recvs: make(chan recv, recvBufSize),
}
return t
}
func (t *torrent) run() {
for {
select {
case <-t.start:
t.handleStart()
case <-t.seed:
t.handleSeed()
case <-t.stop:
t.handleStop()
case <-t.retry:
t.handleRetry()
case d := <-t.dialPool.results:
t.handleDialResult(d)
case p := <-t.incoming:
t.handleIncoming(p)
case addr := <-t.addPeer:
t.handleAddPeer(addr)
case r := <-t.recvs:
t.handleRecv(r)
case req := <-t.status:
req.resp <- t.handleStatus(req.field)
}
}
}
func (t *torrent) handleStart() {
if t.st != stopped {
return
}
t.st = downloading
t.ctx, t.cancel = context.WithCancel(context.Background())
t.dialPool.start(t.ctx)
}
func (t *torrent) handleSeed() {
if t.st != stopped && t.st != done {
return
}
if !t.scheduler.isComplete() {
return
}
t.st = seeding
t.ctx, t.cancel = context.WithCancel(context.Background())
t.dialPool.start(t.ctx)
}
func (t *torrent) handleStop() {
if !t.isActive() {
return
}
t.cancel()
t.dialPool.stop()
t.peers.closeAll()
t.retry = nil
t.st = stopped
}
func (t *torrent) handleRetry() {
t.retry = nil
if t.isActive() {
t.dialPool.start(t.ctx)
}
}
func (t *torrent) handleDialResult(d dialResult) {
if d.done {
if t.peers.count() < maxPeers && t.isActive() {
t.retry = time.After(retryInterval)
}
return
}
if d.peer == nil {
return
}
p := d.peer
if !t.peers.add(p) {
p.Close()
return
}
go t.recvLoop(p)
if err := t.initPeer(p); err != nil {
p.Close()
t.peers.remove(p)
}
}
func (t *torrent) handleIncoming(p *bt.Peer) {
t.peers.addUnlimited(p)
go t.recvLoop(p)
if err := p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(false)}); err != nil {
p.Close()
t.peers.remove(p)
return
}
if err := p.Send(&bt.Msg{Kind: bt.Unchoke}); err != nil {
p.Close()
t.peers.remove(p)
}
}
func (t *torrent) handleAddPeer(addr *net.TCPAddr) {
if !t.isActive() {
return
}
t.dialPool.dialSingle(t.ctx, addr)
}
func (t *torrent) handleRecv(r recv) {
if r.err != nil {
r.p.Close()
t.peers.remove(r.p)
t.scheduler.peerDisconnected(r.p)
if t.peers.count() < maxPeers && t.isActive() {
t.dialPool.start(t.ctx)
}
return
}
if r.msg == nil {
return
}
p := r.p
m := r.msg
switch m.Kind {
case bt.Choke:
p.Choked = true
p.Pending = 0
case bt.Unchoke:
p.Choked = false
case bt.Have:
p.SetPiece(int(m.Index))
case bt.Bitfield:
p.SetPieces(m.Bitfield, t.scheduler.count())
case bt.Piece:
t.scheduler.put(int(m.Index), int(m.Begin), m.Block)
p.Pending--
t.down += int64(len(m.Block))
case bt.Request:
t.handleRequest(p, m)
}
t.sendRequests(p)
t.checkCompletion()
}
func (t *torrent) handleRequest(p *bt.Peer, m *bt.Msg) {
data := t.scheduler.pieceData(int(m.Index))
if data == nil {
return
}
end := int(m.Begin + m.Length)
if end > len(data) {
return
}
if err := p.Send(&bt.Msg{
Kind: bt.Piece,
Index: m.Index,
Begin: m.Begin,
Block: data[m.Begin:end],
}); err != nil {
p.Close()
return
}
t.up += int64(m.Length)
}
func (t *torrent) sendRequests(p *bt.Peer) {
if p.Choked {
return
}
for p.Pending < maxPending {
req := t.scheduler.nextRequest(p)
if req == nil {
break
}
if err := p.Send(req); err != nil {
p.Close()
break
}
p.Pending++
}
}
func (t *torrent) checkCompletion() {
if t.st == downloading && t.scheduler.isComplete() {
t.st = done
t.peers.closeAll()
}
}
func (t *torrent) initPeer(p *bt.Peer) error {
if t.st == seeding {
return p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(true)})
}
return p.Send(&bt.Msg{Kind: bt.Interested})
}
func (t *torrent) recvLoop(p *bt.Peer) {
for {
msg, err := p.Recv()
t.recvs <- recv{msg, err, p}
if err != nil {
return
}
}
}
func (t *torrent) isActive() bool {
return t.st == downloading || t.st == seeding
}
func (t *torrent) startDownload() {
select {
case t.start <- struct{}{}:
default:
}
}
func (t *torrent) startSeed() {
select {
case t.seed <- struct{}{}:
default:
}
}
func (t *torrent) stopTorrent() {
select {
case t.stop <- struct{}{}:
default:
}
}
func (t *torrent) enqueuePeer(addr *net.TCPAddr) {
t.addPeer <- addr
}
func (t *torrent) getStatus(field string) string {
ch := make(chan string)
t.status <- statusReq{field, ch}
return <-ch
}
func (t *torrent) handleStatus(field string) string {
switch field {
case "name":
return t.meta.Info.Name
case "state":
return t.st.String()
case "progress":
verified, total := t.scheduler.progress()
if total == 0 {
return "0"
}
return strconv.Itoa(verified * 100 / total)
case "size":
return strconv.FormatInt(t.meta.Size, 10)
case "down":
return strconv.FormatInt(t.down, 10)
case "up":
return strconv.FormatInt(t.up, 10)
case "pieces":
verified, total := t.scheduler.progress()
return fmt.Sprintf("%d/%d", verified, total)
case "peers":
return t.formatPeers()
}
return ""
}
func (t *torrent) formatPeers() string {
var buf strings.Builder
npieces := t.scheduler.count()
for _, p := range t.peers.all() {
st := "unchoked"
if p.Choked {
st = "choked"
}
have := 0
for _, h := range p.Have {
if h {
have++
}
}
fmt.Fprintf(&buf, "%s %s %s have:%d/%d pending:%d\n",
p.Addr, p.PeerID[:8], st, have, npieces, p.Pending)
}
return buf.String()
}
func (s state) String() string {
switch s {
case stopped:
return "stopped"
case downloading:
return "downloading"
case seeding:
return "seeding"
case done:
return "done"
case errored:
return "error"
}
return "unknown"
}

121
client/torrent_test.go Normal file
View File

@ -0,0 +1,121 @@
package client
import (
"crypto/sha1"
"testing"
"storrent/bt"
)
func TestPut(t *testing.T) {
tests := []struct {
in string
hash [20]byte
want bool
}{
{
in: "hello",
hash: sha1.Sum([]byte("hello")),
want: true,
},
{
in: "wrong",
hash: sha1.Sum([]byte("hello")),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.in, func(t *testing.T) {
p := &piece{
hash: tt.hash,
size: int64(len(tt.in)),
data: make([]byte, len(tt.in)),
have: make([]bool, 1),
reqs: make([]bool, 1),
reqPeer: make([]*bt.Peer, 1),
}
if got := p.put(0, []byte(tt.in)); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestPutMultiple(t *testing.T) {
data := make([]byte, bt.BlockSize+5)
for i := range data {
data[i] = byte(i)
}
p := &piece{
hash: sha1.Sum(data),
size: int64(len(data)),
data: make([]byte, len(data)),
have: make([]bool, 2),
reqs: make([]bool, 2),
reqPeer: make([]*bt.Peer, 2),
}
tests := []struct {
name string
begin int
data []byte
want bool
}{
{
name: "first block",
begin: 0,
data: data[:bt.BlockSize],
want: false,
},
{
name: "second block",
begin: bt.BlockSize,
data: data[bt.BlockSize:],
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := p.put(tt.begin, tt.data); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestNextRequest(t *testing.T) {
tests := []struct {
name string
pieceSize int64
have []bool
wantLength uint32
}{
{
name: "last block partial",
pieceSize: bt.BlockSize + 100,
have: []bool{true, false},
wantLength: 100,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ps := &pieceScheduler{
pieces: []piece{{
size: tt.pieceSize,
have: tt.have,
reqs: make([]bool, len(tt.have)),
reqPeer: make([]*bt.Peer, len(tt.have)),
}},
incomplete: []int{0},
}
c := &bt.Peer{Have: []bool{true}}
msg := ps.nextRequest(c)
if msg == nil {
t.Fatal("expected request")
}
if msg.Length != tt.wantLength {
t.Errorf("got %d, want %d", msg.Length, tt.wantLength)
}
})
}
}

228
dht/dht.go Normal file
View File

@ -0,0 +1,228 @@
package dht
import (
"context"
"fmt"
"net"
"slices"
"strconv"
"sync"
"time"
)
const (
queryTimeout = 5 * time.Second
alpha = 3
maxQueries = 32
)
type DHT struct {
conn *net.UDPConn
id [20]byte
nodes []node
pending map[string]chan *resp
mu sync.Mutex
}
func New(port int) (*DHT, error) {
addr := &net.UDPAddr{Port: port}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return nil, err
}
d := &DHT{
conn: conn,
id: genNodeID(),
pending: make(map[string]chan *resp),
}
go d.run()
return d, nil
}
func (d *DHT) Close() {
d.conn.Close()
}
func (d *DHT) run() {
buf := make([]byte, 65535)
for {
n, _, err := d.conn.ReadFromUDP(buf)
if err != nil {
return
}
msg, err := decodeMsg(buf[:n])
if err != nil {
continue
}
d.mu.Lock()
ch, ok := d.pending[msg.T]
if ok {
delete(d.pending, msg.T)
}
d.mu.Unlock()
if ok {
ch <- &msg.R
}
}
}
func xorDistance(a, b [20]byte) [20]byte {
var dist [20]byte
for i := range 20 {
dist[i] = a[i] ^ b[i]
}
return dist
}
func compareDist(a, b [20]byte) int {
for i := range 20 {
if a[i] < b[i] {
return -1
}
if a[i] > b[i] {
return 1
}
}
return 0
}
func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) {
ctx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
txid := genTxID()
a.ID = d.id
m := &msg{
T: txid,
Y: query,
Q: q,
A: a,
}
data, err := m.Encode()
if err != nil {
return nil, err
}
reply := make(chan *resp, 1)
d.mu.Lock()
d.pending[txid] = reply
d.mu.Unlock()
defer func() {
d.mu.Lock()
delete(d.pending, txid)
d.mu.Unlock()
}()
if _, err := d.conn.WriteToUDP(data, addr); err != nil {
return nil, err
}
select {
case r := <-reply:
return r, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (d *DHT) Bootstrap(ctx context.Context, addr string) error {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return err
}
ips, err := net.LookupHost(host)
if err != nil {
return err
}
if len(ips) == 0 {
return fmt.Errorf("no bootstrap nodes")
}
p, err := strconv.Atoi(port)
if err != nil {
return fmt.Errorf("invalid port: %w", err)
}
uaddr := &net.UDPAddr{
IP: net.ParseIP(ips[0]),
Port: p,
}
resp, err := d.query(ctx, uaddr, findNode, args{Target: d.id})
if err != nil {
return err
}
d.nodes = append(d.nodes, resp.Nodes...)
return nil
}
func sortByDist(nodes []node, target [20]byte) {
slices.SortFunc(nodes, func(a, b node) int {
return compareDist(xorDistance(a.ID, target), xorDistance(b.ID, target))
})
}
func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) {
queried := make(map[string]bool)
candidates := make([]node, len(d.nodes))
copy(candidates, d.nodes)
sortByDist(candidates, h)
results := make(chan *resp)
inflight := 0
queryCount := 0
var peers []*net.TCPAddr
for queryCount < maxQueries {
if ctx.Err() != nil {
break
}
for inflight < alpha && queryCount < maxQueries {
n, i := nextCandidate(candidates, queried)
if i < 0 {
break
}
candidates = slices.Delete(candidates, i, i+1)
queried[n.Addr.String()] = true
inflight++
queryCount++
go func(addr *net.UDPAddr) {
r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h})
results <- r
}(n.Addr)
}
if inflight == 0 {
break
}
r := <-results
inflight--
if r == nil {
continue
}
for _, p := range r.Peers {
if a := decodePeer(p); a != nil {
peers = append(peers, a)
}
}
for _, c := range r.Nodes {
if !queried[c.Addr.String()] {
candidates = append(candidates, c)
}
}
sortByDist(candidates, h)
}
for range inflight {
<-results
}
if len(peers) > 0 {
return peers, nil
}
return nil, fmt.Errorf("no peers found")
}
func nextCandidate(candidates []node, queried map[string]bool) (node, int) {
for i, c := range candidates {
if !queried[c.Addr.String()] {
return c, i
}
}
return node{}, -1
}

246
dht/msg.go Normal file
View File

@ -0,0 +1,246 @@
package dht
// https://www.bittorrent.org/beps/bep_0005.html
// Peer: TCP
// Node: UDP
// BEP 5: DHT Protocol - KRPC over UDP
// Message: bencode dict with keys:
// t: transaction id (2 bytes)
// y: type "q" (query), "r" (response), "e" (error)
// q: query name (ping, find_node, get_peers, announce_peer)
// a: query args dict
// r: response dict
// e: error [code, message]
// Compact formats:
// peer: 6 bytes (4 ip + 2 port)
// node: 26 bytes (20 id + 4 ip + 2 port)
import (
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"storrent/bencode"
)
const (
query = "q"
response = "r"
error_ = "e"
)
const (
ping = "ping"
findNode = "find_node"
getPeers = "get_peers"
announcePeer = "announce_peer"
)
type node struct {
ID [20]byte
Addr *net.UDPAddr
}
type args struct {
ID [20]byte
Target [20]byte
InfoHash [20]byte
Port int
Token string
ImpliedPort int
}
type resp struct {
ID [20]byte
Nodes []node
Peers []string
Token string
}
type errResp struct {
Code int
Msg string
}
type msg struct {
T string
Y string
Q string
A args
R resp
E errResp
}
func genTxID() string {
b := make([]byte, 2)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return string(b)
}
func genNodeID() [20]byte {
var id [20]byte
if _, err := rand.Read(id[:]); err != nil {
panic(err)
}
return id
}
func (m *msg) Encode() ([]byte, error) {
d := make(map[string]any)
d["t"] = m.T
d["y"] = m.Y
switch m.Y {
case query:
d["q"] = m.Q
a := make(map[string]any)
a["id"] = string(m.A.ID[:])
switch m.Q {
case findNode:
a["target"] = string(m.A.Target[:])
case getPeers:
a["info_hash"] = string(m.A.InfoHash[:])
case announcePeer:
a["info_hash"] = string(m.A.InfoHash[:])
a["port"] = int64(m.A.Port)
a["token"] = m.A.Token
if m.A.ImpliedPort != 0 {
a["implied_port"] = int64(m.A.ImpliedPort)
}
}
d["a"] = a
case response:
r := make(map[string]any)
r["id"] = string(m.R.ID[:])
if len(m.R.Nodes) > 0 {
r["nodes"] = encodeNodes(m.R.Nodes)
}
if len(m.R.Peers) > 0 {
vals := make([]any, len(m.R.Peers))
for i, p := range m.R.Peers {
vals[i] = p
}
r["values"] = vals
}
if m.R.Token != "" {
r["token"] = m.R.Token
}
d["r"] = r
case error_:
d["e"] = []any{int64(m.E.Code), m.E.Msg}
}
return bencode.Encode(d)
}
func decodeMsg(data []byte) (*msg, error) {
v, _, err := bencode.Decode(data)
if err != nil {
return nil, err
}
d, ok := v.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid krpc message")
}
m := &msg{}
m.T, _ = d["t"].(string)
m.Y, _ = d["y"].(string)
switch m.Y {
case query:
m.Q, _ = d["q"].(string)
if a, ok := d["a"].(map[string]any); ok {
if id, ok := a["id"].(string); ok && len(id) == 20 {
copy(m.A.ID[:], id)
}
if target, ok := a["target"].(string); ok && len(target) == 20 {
copy(m.A.Target[:], target)
}
if ih, ok := a["info_hash"].(string); ok && len(ih) == 20 {
copy(m.A.InfoHash[:], ih)
}
if port, ok := a["port"].(int64); ok {
m.A.Port = int(port)
}
if token, ok := a["token"].(string); ok {
m.A.Token = token
}
if iport, ok := a["implied_port"].(int64); ok {
m.A.ImpliedPort = int(iport)
}
}
case response:
if r, ok := d["r"].(map[string]any); ok {
if id, ok := r["id"].(string); ok && len(id) == 20 {
copy(m.R.ID[:], id)
}
if nodes, ok := r["nodes"].(string); ok {
m.R.Nodes = decodeNodes(nodes)
}
if vals, ok := r["values"].([]any); ok {
for _, v := range vals {
if p, ok := v.(string); ok {
m.R.Peers = append(m.R.Peers, p)
}
}
}
if token, ok := r["token"].(string); ok {
m.R.Token = token
}
}
case error_:
if e, ok := d["e"].([]any); ok && len(e) >= 2 {
if code, ok := e[0].(int64); ok {
m.E.Code = int(code)
}
m.E.Msg, _ = e[1].(string)
}
}
return m, nil
}
// 26 = 20 id + 4 ip + 2 port
func encodeNodes(nodes []node) string {
buf := make([]byte, 26*len(nodes))
for i, n := range nodes {
off := i * 26
copy(buf[off:], n.ID[:])
copy(buf[off+20:], n.Addr.IP.To4())
binary.BigEndian.PutUint16(buf[off+24:], uint16(n.Addr.Port))
}
return string(buf)
}
func decodeNodes(s string) []node {
data := []byte(s)
if len(data)%26 != 0 {
return nil
}
n := len(data) / 26
nodes := make([]node, n)
for i := range nodes {
off := i * 26
copy(nodes[i].ID[:], data[off:])
ip := net.IP(data[off+20 : off+24])
port := binary.BigEndian.Uint16(data[off+24:])
nodes[i].Addr = &net.UDPAddr{IP: ip, Port: int(port)}
}
return nodes
}
// 6 = 4 ip + 2 port
func decodePeer(s string) *net.TCPAddr {
if len(s) != 6 {
return nil
}
data := []byte(s)
return &net.TCPAddr{
IP: net.IP(data[:4]),
Port: int(binary.BigEndian.Uint16(data[4:])),
}
}

67
dht/msg_test.go Normal file
View File

@ -0,0 +1,67 @@
package dht
import (
"net"
"strings"
"testing"
)
func nid(c string) [20]byte {
var id [20]byte
copy(id[:], strings.Repeat(c, 20))
return id
}
func TestEncodeDecodeNodes(t *testing.T) {
tests := []struct {
name string
nodes []node
}{
{
name: "empty",
nodes: []node{},
},
{
name: "single",
nodes: []node{
{
ID: nid("a"),
Addr: &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 6881,
},
},
},
},
{
name: "multiple",
nodes: []node{
{
ID: nid("a"),
Addr: &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 6881,
},
},
{
ID: nid("a"),
Addr: &net.UDPAddr{
IP: net.IPv4(192, 168, 1, 1),
Port: 8080,
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data := encodeNodes(tt.nodes)
got := decodeNodes(data)
data2 := encodeNodes(got)
if data != data2 {
t.Errorf("roundtrip mismatch: got %x, want %x", data2, data)
}
})
}
}

156
fs/fs.go Normal file
View File

@ -0,0 +1,156 @@
package fs
import (
"fmt"
"strconv"
"strings"
"sync"
"github.com/knusbaum/go9p/fs"
"github.com/knusbaum/go9p/proto"
"storrent/client"
)
type Dir struct {
stat proto.Stat
parent fs.Dir
man *client.Manager
fsys *fs.FS
sync.RWMutex
}
func New(e *client.Manager) *fs.FS {
fsys, root := fs.NewFS("storrent", "storrent", 0555)
root.AddChild(&fs.WrappedFile{
File: fs.NewBaseFile(fsys.NewStat("ctl", "storrent", "storrent", 0222)),
WriteF: func(fid uint64, offset uint64, data []byte) (uint32, error) {
cmd := strings.TrimSpace(string(data))
if strings.HasPrefix(cmd, "add ") {
path := strings.TrimSpace(cmd[4:])
_, err := e.Add(path)
if err != nil {
return 0, err
}
return uint32(len(data)), nil
}
return 0, fmt.Errorf("unknown command: %s", cmd)
},
})
root.AddChild(fs.NewDynamicFile(
fsys.NewStat("list", "storrent", "storrent", 0444),
func() []byte {
data := e.List()
if len(data) > 0 {
return append(data, '\n')
}
return data
},
))
dir := &Dir{
stat: *fsys.NewStat("torrents", "storrent", "storrent", 0555|proto.DMDIR),
man: e,
fsys: fsys,
}
dir.stat.Qid.Qtype = uint8(dir.stat.Mode >> 24)
root.AddChild(dir)
return fsys
}
func (d *Dir) Stat() proto.Stat {
d.Lock()
defer d.Unlock()
return d.stat
}
func (d *Dir) WriteStat(s *proto.Stat) error {
d.Lock()
defer d.Unlock()
d.stat = *s
return nil
}
func (d *Dir) SetParent(p fs.Dir) {
d.Lock()
defer d.Unlock()
d.parent = p
}
func (d *Dir) Parent() fs.Dir {
d.RLock()
defer d.RUnlock()
return d.parent
}
func (d *Dir) Children() map[string]fs.FSNode {
data := d.man.List()
m := make(map[string]fs.FSNode)
if len(data) == 0 {
return m
}
for _, name := range strings.Split(string(data), "\n") {
id, err := strconv.Atoi(name)
if err != nil {
continue
}
m[name] = d.newTorrentDir(id)
}
return m
}
func (d *Dir) newTorrentDir(id int) fs.FSNode {
dir := fs.NewStaticDir(d.fsys.NewStat(strconv.Itoa(id), "storrent", "storrent", 0555))
dir.AddChild(&fs.WrappedFile{
File: fs.NewBaseFile(d.fsys.NewStat("ctl", "storrent", "storrent", 0222)),
WriteF: func(fid uint64, offset uint64, data []byte) (uint32, error) {
cmd := strings.TrimSpace(string(data))
var err error
switch {
case cmd == "start":
err = d.man.Start(id)
case cmd == "stop":
err = d.man.Stop(id)
case cmd == "seed":
err = d.man.Seed(id)
case cmd == "remove":
err = d.man.Remove(id)
case strings.HasPrefix(cmd, "peer "):
err = d.man.AddPeer(id, strings.TrimSpace(cmd[5:]))
default:
return 0, fmt.Errorf("unknown command: %s", cmd)
}
if err != nil {
return 0, err
}
return uint32(len(data)), nil
},
})
dir.AddChild(d.newStatusFile(id, "name"))
dir.AddChild(d.newStatusFile(id, "state"))
dir.AddChild(d.newStatusFile(id, "progress"))
dir.AddChild(d.newStatusFile(id, "size"))
dir.AddChild(d.newStatusFile(id, "down"))
dir.AddChild(d.newStatusFile(id, "up"))
dir.AddChild(d.newStatusFile(id, "pieces"))
dir.AddChild(d.newStatusFile(id, "peers"))
return dir
}
func (d *Dir) newStatusFile(id int, field string) fs.FSNode {
return fs.NewDynamicFile(
d.fsys.NewStat(field, "storrent", "storrent", 0444),
func() []byte {
data, err := d.man.Status(id, field)
if err != nil {
return nil
}
if len(data) > 0 {
return append(data, '\n')
}
return data
},
)
}

12
go.mod Normal file
View File

@ -0,0 +1,12 @@
module storrent
go 1.25.4
require github.com/knusbaum/go9p v1.18.0
require (
9fans.net/go v0.0.2 // indirect
github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d // indirect
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 // indirect
github.com/fhs/mux9p v0.3.1 // indirect
)

25
go.sum Normal file
View File

@ -0,0 +1,25 @@
9fans.net/go v0.0.2 h1:RYM6lWITV8oADrwLfdzxmt8ucfW6UtP9v1jg4qAbqts=
9fans.net/go v0.0.2/go.mod h1:lfPdxjq9v8pVQXUMBCx5EO5oLXWQFlKRQgs1kEkjoIM=
github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d h1:xH/U6K+HYxh1480TkQYRqRO8F2RJsg+R6wFiVJzdldg=
github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d/go.mod h1:UKp8dv9aeaZoQFWin7eQXtz89iHly1YAFZNn3MCutmQ=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
github.com/fhs/mux9p v0.3.1 h1:x1UswUWZoA9vrA02jfisndCq3xQm+wrQUxUt5N99E08=
github.com/fhs/mux9p v0.3.1/go.mod h1:F4hwdenmit0WDoNVT2VMWlLJrBVCp/8UhzJa7scfjEQ=
github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok=
github.com/hanwen/go-fuse/v2 v2.0.3/go.mod h1:0EQM6aH2ctVpvZ6a+onrQ/vaykxh2GH7hy3e13vzTUY=
github.com/knusbaum/go9p v1.18.0 h1:/Y67RNvNKX1ZV1IOdnO1lIetiF0X+CumOyvEc0011GI=
github.com/knusbaum/go9p v1.18.0/go.mod h1:HtMoJKqZUe1Oqag5uJqG5RKQ9gWPSP+wolsnLLv44r8=
github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201020230747-6e5568b54d1a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

49
main.go Normal file
View File

@ -0,0 +1,49 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"github.com/knusbaum/go9p"
"storrent/client"
"storrent/dht"
"storrent/fs"
"storrent/utp"
)
func main() {
addr := flag.String("addr", ":5640", "9P listen address")
dir := flag.String("dir", "./download", "download directory")
dhtPort := flag.Int("dht", 6881, "DHT port")
btPort := flag.Int("bt", 0, "BT listen port")
utpPort := flag.Int("utp", 6882, "uTP port")
bootstrap := flag.String("bootstrap", "router.bittorrent.com:6881", "DHT bootstrap node")
flag.Parse()
d, err := dht.New(*dhtPort)
if err != nil {
log.Fatal(err)
}
if err := d.Bootstrap(context.Background(), *bootstrap); err != nil {
log.Fatal(err)
}
utpSock, err := utp.New(fmt.Sprintf(":%d", *utpPort))
if err != nil {
log.Fatal(err)
}
utpSock.Start()
m, err := client.NewManager(*dir, d, *btPort, utpSock)
if err != nil {
log.Fatal(err)
}
go m.Run()
if err := go9p.Serve(*addr, fs.New(m).Server()); err != nil {
log.Fatal(err)
}
}

182
metainfo/metainfo.go Normal file
View File

@ -0,0 +1,182 @@
package metainfo
import (
"crypto/sha1"
"fmt"
"os"
"storrent/bencode"
)
type File struct {
Info Info
InfoHash [20]byte
Size int64
}
type Info struct {
Name string
PieceSize int64
Pieces [][20]byte
Size int64
Files []Entry
}
type Entry struct {
Size int64
Offset int64
Path []string
}
type Segment struct {
File int
Offset int64
Size int64
}
func (i *Info) Segments(off, size int64) []Segment {
var segs []Segment
for idx, f := range i.Files {
end := f.Offset + f.Size
if off >= end {
continue
}
n := min(size, end-off)
segs = append(segs, Segment{File: idx, Offset: off - f.Offset, Size: n})
off += n
size -= n
if size == 0 {
break
}
}
return segs
}
func Parse(path string) (*File, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return ParseBytes(data)
}
func ParseBytes(data []byte) (*File, error) {
v, _, err := bencode.Decode(data)
if err != nil {
return nil, err
}
d, ok := v.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid torrent dict")
}
f := &File{}
raw, err := findInfoBytes(data)
if err != nil {
return nil, err
}
f.InfoHash = sha1.Sum(raw)
dict, ok := d["info"].(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid info dict")
}
f.Info, err = parseInfo(dict)
if err != nil {
return nil, err
}
if f.Info.Size > 0 {
f.Size = f.Info.Size
} else {
for _, e := range f.Info.Files {
f.Size += e.Size
}
}
return f, nil
}
func findInfoBytes(data []byte) ([]byte, error) {
if len(data) == 0 || data[0] != 'd' {
return nil, fmt.Errorf("not a dict")
}
i := 1
for i < len(data) && data[i] != 'e' {
k, n, err := bencode.DecodeString(data[i:])
if err != nil {
return nil, err
}
i += n
if k == "info" {
_, n, err := bencode.Decode(data[i:])
if err != nil {
return nil, err
}
return data[i : i+n], nil
}
_, n, err = bencode.Decode(data[i:])
if err != nil {
return nil, err
}
i += n
}
return nil, fmt.Errorf("info not found")
}
func parseInfo(d map[string]any) (Info, error) {
var info Info
name, ok := d["name"].(string)
if !ok {
return info, fmt.Errorf("invalid name")
}
info.Name = name
pl, ok := d["piece length"].(int64)
if !ok {
return info, fmt.Errorf("invalid piece length")
}
info.PieceSize = pl
ps, ok := d["pieces"].(string)
if !ok || len(ps)%20 != 0 {
return info, fmt.Errorf("invalid pieces")
}
npieces := len(ps) / 20
info.Pieces = make([][20]byte, npieces)
for i := range npieces {
copy(info.Pieces[i][:], ps[i*20:(i+1)*20])
}
if n, ok := d["length"].(int64); ok {
info.Size = n
return info, nil
}
fs, ok := d["files"].([]any)
if !ok {
return info, fmt.Errorf("invalid files")
}
off := int64(0)
for _, f := range fs {
fd, ok := f.(map[string]any)
if !ok {
return info, fmt.Errorf("invalid file entry")
}
n, ok := fd["length"].(int64)
if !ok {
return info, fmt.Errorf("invalid file length")
}
list, ok := fd["path"].([]any)
if !ok {
return info, fmt.Errorf("invalid file path")
}
var path []string
for _, p := range list {
s, ok := p.(string)
if !ok {
return info, fmt.Errorf("invalid path element")
}
path = append(path, s)
}
info.Files = append(info.Files, Entry{Size: n, Offset: off, Path: path})
off += n
}
return info, nil
}

141
metainfo/metainfo_test.go Normal file
View File

@ -0,0 +1,141 @@
package metainfo
import (
"crypto/sha1"
"testing"
"storrent/bencode"
)
func TestParse(t *testing.T) {
tests := []struct {
name string
info map[string]any
wantName string
wantSize int64
}{
{
name: "single file",
info: map[string]any{
"name": "test.txt",
"piece length": int64(16384),
"pieces": string(make([]byte, 20)),
"length": int64(100),
},
wantName: "test.txt",
wantSize: 100,
},
{
name: "multi file",
info: map[string]any{
"name": "dir",
"piece length": int64(16384),
"pieces": string(make([]byte, 20)),
"files": []any{
map[string]any{"length": int64(100), "path": []any{"a.txt"}},
map[string]any{"length": int64(200), "path": []any{"b.txt"}},
},
},
wantName: "dir",
wantSize: 300,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := bencode.Encode(map[string]any{"info": tt.info})
if err != nil {
t.Fatalf("Encode: %v", err)
}
m, err := ParseBytes(data)
if err != nil {
t.Fatalf("ParseBytes: %v", err)
}
if m.Info.Name != tt.wantName {
t.Errorf("Name: got %s, want %s", m.Info.Name, tt.wantName)
}
if m.Size != tt.wantSize {
t.Errorf("Size: got %d, want %d", m.Size, tt.wantSize)
}
})
}
}
func TestInfoHash(t *testing.T) {
tests := []struct {
name string
raw string
}{
{
name: "basic",
raw: "d6:lengthi1e4:name4:test12:piece lengthi16384e6:pieces20:01234567890123456789e",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
raw := []byte(tt.raw)
torrent := append([]byte("d4:info"), raw...)
torrent = append(torrent, 'e')
m, err := ParseBytes(torrent)
if err != nil {
t.Fatalf("ParseBytes: %v", err)
}
want := sha1.Sum(raw)
if m.InfoHash != want {
t.Errorf("InfoHash: got %x, want %x", m.InfoHash, want)
}
})
}
}
func TestSegments(t *testing.T) {
tests := []struct {
name string
files []Entry
off int64
size int64
want []Segment
}{
{
name: "single file",
files: []Entry{{Size: 32, Offset: 0}},
off: 0,
size: 32,
want: []Segment{{0, 0, 32}},
},
{
name: "spans two files",
files: []Entry{{Size: 16, Offset: 0}, {Size: 16, Offset: 16}},
off: 0,
size: 32,
want: []Segment{{0, 0, 16}, {1, 0, 16}},
},
{
name: "middle of file",
files: []Entry{{Size: 32, Offset: 0}},
off: 8,
size: 16,
want: []Segment{{0, 8, 16}},
},
{
name: "skip first file",
files: []Entry{{Size: 16, Offset: 0}, {Size: 16, Offset: 16}},
off: 16,
size: 16,
want: []Segment{{1, 0, 16}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := Info{Files: tt.files}
got := info.Segments(tt.off, tt.size)
if len(got) != len(tt.want) {
t.Fatalf("got %d segments, want %d", len(got), len(tt.want))
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("segment %d: got %+v, want %+v", i, got[i], tt.want[i])
}
}
})
}
}

68
metainfo/storage.go Normal file
View File

@ -0,0 +1,68 @@
package metainfo
import (
"os"
"path/filepath"
)
func (info *Info) Write(dir string, i int, data []byte) error {
off := int64(i) * info.PieceSize
if info.Size > 0 {
return writeAt(filepath.Join(dir, info.Name), off, data)
}
for _, seg := range info.Segments(off, int64(len(data))) {
f := info.Files[seg.File]
path := filepath.Join(dir, info.Name, filepath.Join(f.Path...))
if err := writeAt(path, seg.Offset, data[:seg.Size]); err != nil {
return err
}
data = data[seg.Size:]
}
return nil
}
func (info *Info) Read(dir string, i int, size int64) []byte {
off := int64(i) * info.PieceSize
if info.Size > 0 {
return readAt(filepath.Join(dir, info.Name), off, size)
}
data := make([]byte, size)
pos := int64(0)
for _, seg := range info.Segments(off, size) {
f := info.Files[seg.File]
path := filepath.Join(dir, info.Name, filepath.Join(f.Path...))
chunk := readAt(path, seg.Offset, seg.Size)
if chunk == nil {
return nil
}
copy(data[pos:], chunk)
pos += seg.Size
}
return data
}
func writeAt(path string, off int64, data []byte) error {
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer f.Close()
_, err = f.WriteAt(data, off)
return err
}
func readAt(path string, off, size int64) []byte {
f, err := os.Open(path)
if err != nil {
return nil
}
defer f.Close()
data := make([]byte, size)
if _, err := f.ReadAt(data, off); err != nil {
return nil
}
return data
}

305
utp/conn.go Normal file
View File

@ -0,0 +1,305 @@
package utp
import (
"context"
"io"
"net"
"time"
)
const (
initRTO = 1000 // ms
retransCheck = 500 // ms
defaultWnd = 64 * 1024
)
type Conn struct {
sock *Socket
addr *net.UDPAddr
recvID uint16
sendID uint16
in chan packet
reads chan readReq
writes chan writeReq
closeReq chan struct{}
ready chan struct{}
ctx context.Context
cancel context.CancelFunc
seq uint16
ack uint16
unacked []sent
reorder map[uint16][]byte
rest []byte
pending *readReq
rto uint32
}
type readReq struct {
p []byte
reply chan readResp
}
type readResp struct {
n int
err error
}
type writeReq struct {
data []byte
reply chan error
}
type sent struct {
seq uint16
data []byte
at time.Time
}
func (c *Conn) flush() {
if c.pending == nil || len(c.rest) == 0 {
return
}
n := copy(c.pending.p, c.rest)
c.rest = c.rest[n:]
c.pending.reply <- readResp{n, nil}
c.pending = nil
}
func (c *Conn) deliver(payload []byte) {
c.rest = append(c.rest, payload...)
c.flush()
}
func (c *Conn) send(typ uint8, payload []byte) error {
c.seq++
h := &header{
typ: typ,
ver: 1,
connID: c.sendID,
timestamp: timestamp(),
wnd: defaultWnd,
seq: c.seq,
ack: c.ack,
}
buf := make([]byte, headerSize+len(payload))
encode(h, buf)
copy(buf[headerSize:], payload)
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
return err
}
if typ == Data || typ == Syn || typ == Fin {
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
}
return nil
}
func (c *Conn) sendState() error {
h := &header{
typ: State,
ver: 1,
connID: c.sendID,
timestamp: timestamp(),
wnd: defaultWnd,
seq: c.seq,
ack: c.ack,
}
buf := make([]byte, headerSize)
encode(h, buf)
_, err := c.sock.conn.WriteToUDP(buf, c.addr)
return err
}
func (c *Conn) sendSyn() error {
c.seq++
h := &header{
typ: Syn,
ver: 1,
connID: c.recvID,
timestamp: timestamp(),
wnd: defaultWnd,
seq: c.seq,
ack: c.ack,
}
buf := make([]byte, headerSize)
encode(h, buf)
if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil {
return err
}
c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()})
return nil
}
func (c *Conn) retrans() error {
select {
case <-c.ctx.Done():
return nil
default:
}
now := time.Now()
for i := range c.unacked {
p := &c.unacked[i]
if now.Sub(p.at) > time.Duration(c.rto)*time.Millisecond {
if _, err := c.sock.conn.WriteToUDP(p.data, c.addr); err != nil {
return err
}
p.at = now
}
}
return nil
}
func (c *Conn) recv(p packet) (bool, error) {
switch p.hdr.typ {
case Syn:
c.ack = p.hdr.seq
if err := c.sendState(); err != nil {
return false, err
}
case State:
select {
case <-c.ready:
default:
c.ack = p.hdr.seq - 1
close(c.ready)
}
var remaining []sent
for _, s := range c.unacked {
if seqLess(p.hdr.ack, s.seq) {
remaining = append(remaining, s)
}
}
c.unacked = remaining
case Data:
seq, payload := p.hdr.seq, p.payload
if seq == c.ack+1 {
c.deliver(payload)
c.ack++
for {
if p, ok := c.reorder[c.ack+1]; ok {
c.deliver(p)
delete(c.reorder, c.ack+1)
c.ack++
} else {
break
}
}
} else if seqLess(c.ack+1, seq) {
c.reorder[seq] = payload
}
if err := c.sendState(); err != nil {
return false, err
}
case Fin, Reset:
return true, nil
}
return false, nil
}
func (c *Conn) shutdown() {
if c.pending != nil {
c.pending.reply <- readResp{0, io.EOF}
}
select {
case <-c.ready:
c.send(Fin, nil)
default:
}
c.unacked = nil
c.cancel()
c.sock.removeConn(c.recvID)
}
func (c *Conn) run() {
c.reorder = make(map[uint16][]byte)
c.rto = initRTO
if c.recvID < c.sendID {
c.seq = 1
if err := c.sendSyn(); err != nil {
c.shutdown()
return
}
}
ticker := time.NewTicker(retransCheck * time.Millisecond)
defer ticker.Stop()
for {
select {
case p := <-c.in:
done, err := c.recv(p)
if err != nil {
c.shutdown()
return
}
if done {
c.shutdown()
return
}
case req := <-c.reads:
if c.pending != nil {
req.reply <- readResp{0, io.ErrNoProgress}
continue
}
c.pending = &req
c.flush()
case req := <-c.writes:
if err := c.send(Data, req.data); err != nil {
req.reply <- err
c.shutdown()
return
}
req.reply <- nil
case <-c.closeReq:
c.shutdown()
return
case <-ticker.C:
c.retrans()
case <-c.ctx.Done():
return
}
}
}
func (c *Conn) Read(p []byte) (int, error) {
reply := make(chan readResp, 1)
select {
case c.reads <- readReq{p, reply}:
r := <-reply
return r.n, r.err
case <-c.ctx.Done():
return 0, io.EOF
}
}
func (c *Conn) Write(p []byte) (int, error) {
reply := make(chan error, 1)
select {
case c.writes <- writeReq{p, reply}:
err := <-reply
if err != nil {
return 0, err
}
return len(p), nil
case <-c.ctx.Done():
return 0, io.ErrClosedPipe
}
}
func (c *Conn) Close() error {
select {
case c.closeReq <- struct{}{}:
case <-c.ctx.Done():
}
return nil
}
func timestamp() uint32 {
return uint32(time.Now().UnixMicro() & 0xFFFFFFFF)
}
func seqLess(a, b uint16) bool {
return int16(a-b) < 0
}

211
utp/socket.go Normal file
View File

@ -0,0 +1,211 @@
package utp
import (
"context"
"crypto/rand"
"encoding/binary"
"net"
"sync"
)
const (
Data = 0
Fin = 1
State = 2
Reset = 3
Syn = 4
)
const headerSize = 20
type header struct {
typ uint8
ver uint8
ext uint8
connID uint16
timestamp uint32
timeDiff uint32
wnd uint32
seq uint16
ack uint16
}
func encode(h *header, buf []byte) {
buf[0] = (h.typ << 4) | (h.ver & 0x0F)
buf[1] = h.ext
binary.BigEndian.PutUint16(buf[2:], h.connID)
binary.BigEndian.PutUint32(buf[4:], h.timestamp)
binary.BigEndian.PutUint32(buf[8:], h.timeDiff)
binary.BigEndian.PutUint32(buf[12:], h.wnd)
binary.BigEndian.PutUint16(buf[16:], h.seq)
binary.BigEndian.PutUint16(buf[18:], h.ack)
}
func decode(buf []byte) header {
return header{
typ: (buf[0] >> 4) & 0x0F,
ver: buf[0] & 0x0F,
ext: buf[1],
connID: binary.BigEndian.Uint16(buf[2:]),
timestamp: binary.BigEndian.Uint32(buf[4:]),
timeDiff: binary.BigEndian.Uint32(buf[8:]),
wnd: binary.BigEndian.Uint32(buf[12:]),
seq: binary.BigEndian.Uint16(buf[16:]),
ack: binary.BigEndian.Uint16(buf[18:]),
}
}
type Socket struct {
conn *net.UDPConn
mu sync.RWMutex
conns map[uint16]*Conn
accepts chan *Conn
ctx context.Context
cancel context.CancelFunc
}
type packet struct {
hdr header
payload []byte
addr *net.UDPAddr
}
func New(addr string) (*Socket, error) {
a, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", a)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
s := &Socket{
conn: conn,
conns: make(map[uint16]*Conn),
accepts: make(chan *Conn, 16),
ctx: ctx,
cancel: cancel,
}
return s, nil
}
func (s *Socket) Start() {
go s.reader()
}
func (s *Socket) reader() {
buf := make([]byte, 65535)
for {
n, addr, err := s.conn.ReadFromUDP(buf)
if err != nil {
return
}
if n < headerSize {
continue
}
hdr := decode(buf)
payload := make([]byte, n-headerSize)
copy(payload, buf[headerSize:n])
if hdr.typ == Syn {
s.handleSyn(hdr, payload, addr)
} else {
s.dispatch(hdr, payload)
}
}
}
func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) {
c := s.newConn(hdr.connID, addr, false)
s.mu.Lock()
s.conns[c.recvID] = c
s.mu.Unlock()
go c.run()
c.in <- packet{hdr, payload, addr}
select {
case s.accepts <- c:
default:
c.Close()
}
}
func (s *Socket) dispatch(hdr header, payload []byte) {
s.mu.RLock()
c := s.conns[hdr.connID]
s.mu.RUnlock()
if c != nil {
select {
case c.in <- packet{hdr, payload, nil}:
default:
}
}
}
func (s *Socket) removeConn(id uint16) {
s.mu.Lock()
delete(s.conns, id)
s.mu.Unlock()
}
func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn {
ctx, cancel := context.WithCancel(s.ctx)
c := &Conn{
sock: s,
addr: addr,
in: make(chan packet, 16),
reads: make(chan readReq),
writes: make(chan writeReq),
closeReq: make(chan struct{}),
ready: make(chan struct{}),
ctx: ctx,
cancel: cancel,
}
if initiator {
c.recvID = randUint16()
c.sendID = c.recvID + 1
} else {
c.recvID = peerID + 1
c.sendID = peerID
}
return c
}
func (s *Socket) DialContext(ctx context.Context, addr *net.UDPAddr) (*Conn, error) {
c := s.newConn(0, addr, true)
s.mu.Lock()
s.conns[c.recvID] = c
s.mu.Unlock()
go c.run()
select {
case <-c.ready:
return c, nil
case <-ctx.Done():
c.Close()
return nil, ctx.Err()
}
}
func (s *Socket) Accept() *Conn {
return <-s.accepts
}
func (s *Socket) Close() {
s.conn.Close()
s.cancel()
}
func randUint16() uint16 {
var b [2]byte
if _, err := rand.Read(b[:]); err != nil {
panic(err)
}
return binary.BigEndian.Uint16(b[:])
}

82
utp/utp_test.go Normal file
View File

@ -0,0 +1,82 @@
package utp
import "testing"
func TestEncodeDecode(t *testing.T) {
tests := []struct {
name string
h header
}{
{
name: "syn",
h: header{
typ: Syn,
ver: 1,
ext: 0,
connID: 1234,
timestamp: 12345678,
timeDiff: 1000,
wnd: 65535,
seq: 1,
ack: 0,
},
},
{
name: "data",
h: header{
typ: Data,
ver: 1,
connID: 5678,
timestamp: 99999999,
timeDiff: 500,
wnd: 32768,
seq: 100,
ack: 99,
},
},
{
name: "state",
h: header{
typ: State,
ver: 1,
connID: 1234,
wnd: 65535,
seq: 1,
ack: 1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := make([]byte, headerSize)
encode(&tt.h, buf)
got := decode(buf)
if got != tt.h {
t.Errorf("got %+v, want %+v", got, tt.h)
}
})
}
}
func TestSeqLess(t *testing.T) {
tests := []struct {
name string
a, b uint16
want bool
}{
{"1 < 2", 1, 2, true},
{"2 > 1", 2, 1, false},
{"wrap 0xFFFF < 0", 0xFFFF, 0, true},
{"wrap 0 > 0xFFFF", 0, 0xFFFF, false},
{"half 0x7FFF > 0", 0x7FFF, 0, false},
{"half 0 < 0x7FFF", 0, 0x7FFF, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := seqLess(tt.a, tt.b)
if got != tt.want {
t.Errorf("seqLess(%d, %d) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}