commit c70d24be5c04927d85095aa1c755be54f208a672 Author: Hojun-Cho Date: Mon Jan 19 21:13:01 2026 +0900 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dd17e7e --- /dev/null +++ b/.gitignore @@ -0,0 +1,38 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# + +*.torrent +storrent + +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Code coverage profiles and other test artifacts +*.out +coverage.* +*.coverprofile +profile.cov + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env + +# Editor/IDE +# .idea/ +# .vscode/ + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..280060a --- /dev/null +++ b/README.md @@ -0,0 +1,64 @@ +# storrent + +A BitTorrent client with a 9P filesystem interface. + +## Dependencies + +- plan9port + +## Features + +- DHT for peer discovery +- μTP transport +- 9P control interface + +## Build + + go build + +## Run + + ./storrent -dht 6881 -utp 6882 -dir ./download -addr :5640 & + 9pfuse localhost:5640 /mnt/storrent + +## Usage + +Interact via files: + + echo 'add /path/to/file.torrent' > /mnt/storrent/ctl + cat /mnt/storrent/list + cat /mnt/storrent/torrents/0/progress + +Torrent control commands (write to `torrents//ctl`): + + start start downloading + stop stop torrent + seed start seeding + remove remove torrent + peer add peer manually + +Status files in `torrents//`: + + name torrent name + state current state + progress download progress + size total size + down bytes downloaded + up bytes uploaded + pieces piece info + peers connected peers + +## TODO + +- μTP congestion control +- ReadAt for streaming and seeding while downloading + +## Packages + + bencode/ bencode encoding/decoding + bt/ BitTorrent protocol messages + client/ torrent management + dht/ DHT implementation + fs/ 9P filesystem + metainfo/ .torrent file parsing + utp/ uTP transport diff --git a/bencode/bencode.go b/bencode/bencode.go new file mode 100644 index 0000000..9249549 --- /dev/null +++ b/bencode/bencode.go @@ -0,0 +1,173 @@ +package bencode + +// bencode types: +// string : "4:spam" +// int ie "i42e" +// list le "l4:spam4:eggse" +// dict de "d3:cow3:mooe" + +import ( + "fmt" + "sort" + "strconv" +) + +func Decode(data []byte) (any, int, error) { + if len(data) == 0 { + return nil, 0, fmt.Errorf("empty data") + } + + switch data[0] { + case 'i': + return decodeInt(data) + case 'l': + return decodeList(data) + case 'd': + return decodeDict(data) + default: + if data[0] >= '0' && data[0] <= '9' { + return DecodeString(data) + } + return nil, 0, fmt.Errorf("invalid bencode: %q", data[0]) + } +} + +func DecodeString(data []byte) (string, int, error) { + i := 0 + for i < len(data) && data[i] != ':' { + i++ + } + if i >= len(data) { + return "", 0, fmt.Errorf("bencode: missing ':' in string at %d", i) + } + n, err := strconv.Atoi(string(data[:i])) + if err != nil { + return "", 0, err + } + i++ + if i+n > len(data) { + return "", 0, fmt.Errorf("bencode: truncated string at %d", i) + } + return string(data[i : i+n]), i + n, nil +} + +func decodeInt(data []byte) (int64, int, error) { + if len(data) < 3 { + return 0, 0, fmt.Errorf("bencode: int too short at 0") + } + i := 1 + for i < len(data) && data[i] != 'e' { + i++ + } + if i >= len(data) { + return 0, 0, fmt.Errorf("bencode: missing 'e' in int at %d", i) + } + val, err := strconv.ParseInt(string(data[1:i]), 10, 64) + if err != nil { + return 0, 0, err + } + return val, i + 1, nil +} + +func decodeList(data []byte) ([]any, int, error) { + if len(data) == 0 { + return nil, 0, fmt.Errorf("bencode: empty list at 0") + } + + var list []any + i := 1 + for i < len(data) && data[i] != 'e' { + v, n, err := Decode(data[i:]) + if err != nil { + return nil, 0, err + } + list = append(list, v) + i += n + } + if i >= len(data) { + return nil, 0, fmt.Errorf("bencode: truncated list at %d", i) + } + return list, i + 1, nil +} + +func decodeDict(data []byte) (map[string]any, int, error) { + if len(data) == 0 { + return nil, 0, fmt.Errorf("bencode: empty dict at 0") + } + + d := make(map[string]any) + i := 1 + for i < len(data) && data[i] != 'e' { + k, n, err := DecodeString(data[i:]) + if err != nil { + return nil, 0, err + } + i += n + v, n, err := Decode(data[i:]) + if err != nil { + return nil, 0, err + } + d[k] = v + i += n + } + if i >= len(data) { + return nil, 0, fmt.Errorf("bencode: truncated dict at %d", i) + } + return d, i + 1, nil +} + +func Encode(v any) ([]byte, error) { + switch v := v.(type) { + case string: + return encodeString(v), nil + case []byte: + return encodeString(string(v)), nil + case int: + return encodeInt(int64(v)), nil + case int64: + return encodeInt(v), nil + case []any: + return encodeList(v) + case map[string]any: + return encodeDict(v) + } + return nil, fmt.Errorf("cannot encode %T", v) +} + +func encodeString(s string) []byte { + return fmt.Appendf(nil, "%d:%s", len(s), s) +} + +func encodeInt(n int64) []byte { + return fmt.Appendf(nil, "i%de", n) +} + +func encodeList(list []any) ([]byte, error) { + buf := []byte{'l'} + for _, v := range list { + enc, err := Encode(v) + if err != nil { + return nil, err + } + buf = append(buf, enc...) + } + return append(buf, 'e'), nil +} + +func encodeDict(d map[string]any) ([]byte, error) { + keys := make([]string, 0, len(d)) + for k := range d { + keys = append(keys, k) + } + sort.Strings(keys) + buf := []byte{'d'} + for _, k := range keys { + buf = append(buf, encodeString(k)...) + enc, err := Encode(d[k]) + if err != nil { + return nil, err + } + buf = append(buf, enc...) + } + return append(buf, 'e'), nil +} diff --git a/bencode/bencode_test.go b/bencode/bencode_test.go new file mode 100644 index 0000000..2a17990 --- /dev/null +++ b/bencode/bencode_test.go @@ -0,0 +1,113 @@ +package bencode + +import "testing" + +func TestDecode(t *testing.T) { + tests := []struct { + in string + want any + err bool + }{ + { + in: "4:spam", + want: "spam", + }, + { + in: "0:", + want: "", + }, + { + in: "i42e", + want: int64(42), + }, + { + in: "i-42e", + want: int64(-42), + }, + { + in: "i0e", + want: int64(0), + }, + { + in: "", + err: true, + }, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + got, _, err := Decode([]byte(tt.in)) + if tt.err { + if err == nil { + t.Error("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestEncode(t *testing.T) { + tests := []struct { + name string + in any + want string + }{ + { + name: "string", + in: "spam", + want: "4:spam", + }, + { + name: "empty string", + in: "", + want: "0:", + }, + { + name: "int", + in: int64(42), + want: "i42e", + }, + { + name: "negative int", + in: int64(-42), + want: "i-42e", + }, + { + name: "list", + in: []any{"a", "b"}, + want: "l1:a1:be", + }, + { + name: "empty list", + in: []any{}, + want: "le", + }, + { + name: "dict", + in: map[string]any{"b": int64(2), "a": int64(1)}, + want: "d1:ai1e1:bi2ee", + }, + { + name: "empty dict", + in: map[string]any{}, + want: "de", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Encode(tt.in) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != tt.want { + t.Errorf("got %s, want %s", got, tt.want) + } + }) + } +} diff --git a/bt/msg.go b/bt/msg.go new file mode 100644 index 0000000..ad21b81 --- /dev/null +++ b/bt/msg.go @@ -0,0 +1,116 @@ +package bt + +// https://wiki.theory.org/BitTorrentSpecification#Messages +// Messages: +// keep-alive: +// choke: +// unchoke: +// interested: +// not interested: +// have: +// bitfield: +// request: +// piece: +// cancel: + +import ( + "encoding/binary" + "io" +) + +type MsgKind uint8 + +const BlockSize = 16384 + +const ( + Choke MsgKind = iota + Unchoke + Interested + NotInterested + Have + Bitfield + Request + Piece + Cancel +) + +type Msg struct { + Kind MsgKind + Index uint32 + Begin uint32 + Length uint32 + Bitfield []byte + Block []byte +} + +// [4: len] [1: type] [len-1: payload] +func readMsg(r io.Reader) (*Msg, int, error) { + var lenBuf [4]byte + if _, err := io.ReadFull(r, lenBuf[:]); err != nil { + return nil, 0, err + } + n := binary.BigEndian.Uint32(lenBuf[:]) + if n == 0 { + return nil, 4, nil + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, 0, err + } + msg := &Msg{Kind: MsgKind(buf[0])} + payload := buf[1:] + switch msg.Kind { + case Have: + msg.Index = binary.BigEndian.Uint32(payload) + case Bitfield: + msg.Bitfield = payload + case Request, Cancel: + msg.Index = binary.BigEndian.Uint32(payload[0:4]) + msg.Begin = binary.BigEndian.Uint32(payload[4:8]) + msg.Length = binary.BigEndian.Uint32(payload[8:12]) + case Piece: + msg.Index = binary.BigEndian.Uint32(payload[0:4]) + msg.Begin = binary.BigEndian.Uint32(payload[4:8]) + msg.Block = payload[8:] + } + return msg, int(4 + n), nil +} + +func writeMsg(w io.Writer, msg *Msg) (int, error) { + if msg == nil { + return w.Write(make([]byte, 4)) + } + + var payload []byte + switch msg.Kind { + case Choke, Unchoke, Interested, NotInterested: + case Have: + payload = make([]byte, 4) + binary.BigEndian.PutUint32(payload, msg.Index) + case Bitfield: + payload = msg.Bitfield + case Request, Cancel: + payload = make([]byte, 12) + binary.BigEndian.PutUint32(payload[0:4], msg.Index) + binary.BigEndian.PutUint32(payload[4:8], msg.Begin) + binary.BigEndian.PutUint32(payload[8:12], msg.Length) + case Piece: + payload = make([]byte, 8+len(msg.Block)) + binary.BigEndian.PutUint32(payload[0:4], msg.Index) + binary.BigEndian.PutUint32(payload[4:8], msg.Begin) + copy(payload[8:], msg.Block) + } + buf := make([]byte, 5+len(payload)) + binary.BigEndian.PutUint32(buf[0:4], uint32(1+len(payload))) + buf[4] = byte(msg.Kind) + copy(buf[5:], payload) + return w.Write(buf) +} + +func SetBit(bf []byte, i int) { + bf[i/8] |= 1 << (7 - i%8) +} + +func HasBit(bf []byte, i int) bool { + return bf[i/8]&(1<<(7-i%8)) != 0 +} diff --git a/bt/msg_test.go b/bt/msg_test.go new file mode 100644 index 0000000..a2de47d --- /dev/null +++ b/bt/msg_test.go @@ -0,0 +1,91 @@ +package bt + +import ( + "bytes" + "testing" +) + +func TestMsg(t *testing.T) { + tests := []struct { + name string + msg *Msg + }{ + { + name: "choke", + msg: &Msg{Kind: Choke}, + }, + { + name: "unchoke", + msg: &Msg{Kind: Unchoke}, + }, + { + name: "interested", + msg: &Msg{Kind: Interested}, + }, + { + name: "not interested", + msg: &Msg{Kind: NotInterested}, + }, + { + name: "have", + msg: &Msg{ + Kind: Have, + Index: 42, + }, + }, + { + name: "bitfield", + msg: &Msg{ + Kind: Bitfield, + Bitfield: []byte{0b1111_1111, 0b0000_0000}, + }, + }, + { + name: "request", + msg: &Msg{ + Kind: Request, + Index: 1, + Begin: 16384, + Length: 16384, + }, + }, + { + name: "piece", + msg: &Msg{ + Kind: Piece, + Index: 1, + Begin: 0, + Block: []byte("data"), + }, + }, + { + name: "cancel", + msg: &Msg{ + Kind: Cancel, + Index: 1, + Begin: 0, + Length: 16384, + }, + }, + { + name: "keep-alive", + msg: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + writeMsg(&buf, tt.msg) + data := buf.Bytes() + got, _, err := readMsg(bytes.NewReader(data)) + if err != nil { + t.Fatalf("readMsg: %v", err) + } + var buf2 bytes.Buffer + writeMsg(&buf2, got) + if !bytes.Equal(buf2.Bytes(), data) { + t.Errorf("roundtrip mismatch: got %x, want %x", buf2.Bytes(), data) + } + }) + } +} diff --git a/bt/peer.go b/bt/peer.go new file mode 100644 index 0000000..b3c5e46 --- /dev/null +++ b/bt/peer.go @@ -0,0 +1,160 @@ +package bt + +// https://wiki.theory.org/BitTorrentSpecification#Handshake +// handshake: +// pstrlen: string length of , as a single raw byte. +// pstr: string identifier of the protocol. +// reserved: eight (8) reserved bytes. +// info_hash: 20-byte SHA1 hash of the info key in the metainfo file. +// peer_id: 20-byte string used as a unique ID for the client. + +import ( + "context" + "crypto/rand" + "fmt" + "io" + "net" + "time" + + "storrent/utp" +) + +type Peer struct { + Addr *net.TCPAddr + conn io.ReadWriteCloser + InfoHash [20]byte + PeerID [20]byte + Choked bool + Interested bool + Have []bool + Pending int +} + +const btProto = "BitTorrent protocol" + +const ( + pstrlen = 19 + reservedLen = 8 + hashLen = 20 + peerIDLen = 20 + handshakeLen = 1 + pstrlen + reservedLen + hashLen + peerIDLen // 68 + handshakeTimeout = 10 * time.Second +) + +var peerID = genPeerID() + +func genPeerID() [20]byte { + var id [20]byte + copy(id[:], "-SS0001-") + if _, err := rand.Read(id[8:]); err != nil { + panic(err) + } + return id +} + +func DialContext(ctx context.Context, addr *net.TCPAddr, h [20]byte, utpSock *utp.Socket) (*Peer, error) { + if utpSock == nil { + return nil, fmt.Errorf("utp socket required") + } + udpAddr := &net.UDPAddr{IP: addr.IP, Port: addr.Port} + conn, err := utpSock.DialContext(ctx, udpAddr) + if err != nil { + return nil, err + } + p := &Peer{Addr: addr, conn: conn, InfoHash: h, Choked: true} + if err := p.handshake(); err != nil { + conn.Close() + return nil, err + } + return p, nil +} + +// [1: pstrlen] [pstrlen: btProto] [reservedLen: reserved] [hashLen: info_hash] [peerIDLen: peer_id] +func (p *Peer) handshake() error { + buf := make([]byte, handshakeLen) + buf[0] = pstrlen + copy(buf[1:1+pstrlen], btProto) + copy(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen], p.InfoHash[:]) + copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:]) + if _, err := p.conn.Write(buf); err != nil { + return err + } + if _, err := io.ReadFull(p.conn, buf); err != nil { + return err + } + if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto { + return fmt.Errorf("invalid handshake") + } + if [hashLen]byte(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen]) != p.InfoHash { + return fmt.Errorf("info hash mismatch") + } + p.PeerID = [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:]) + return nil +} + +func (p *Peer) Send(msg *Msg) error { + _, err := writeMsg(p.conn, msg) + return err +} + +func (p *Peer) Recv() (*Msg, error) { + msg, _, err := readMsg(p.conn) + return msg, err +} + +func (p *Peer) Close() error { + return p.conn.Close() +} + +func Accept(ln net.Listener) (*Peer, error) { + conn, err := ln.Accept() + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(handshakeTimeout)) + buf := make([]byte, handshakeLen) + if _, err := io.ReadFull(conn, buf); err != nil { + conn.Close() + return nil, err + } + if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto { + conn.Close() + return nil, fmt.Errorf("invalid handshake") + } + infoHash := [hashLen]byte(buf[1+pstrlen+reservedLen : 1+pstrlen+reservedLen+hashLen]) + remotePeerID := [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:]) + copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:]) + if _, err := conn.Write(buf); err != nil { + conn.Close() + return nil, err + } + conn.SetDeadline(time.Time{}) + return &Peer{ + conn: conn, + InfoHash: infoHash, + PeerID: remotePeerID, + Choked: true, + }, nil +} + +// piece 0~7: byte 0 [7 6 5 4 3 2 1 0] +// piece 8~15: byte 1 [7 6 5 4 3 2 1 0] +func (p *Peer) SetPieces(data []byte, n int) { + p.Have = make([]bool, n) + for i := range n { + p.Have[i] = HasBit(data, i) + } +} + +func (p *Peer) HasPiece(i int) bool { + if i >= 0 && i < len(p.Have) { + return p.Have[i] + } + return false +} + +func (p *Peer) SetPiece(i int) { + if i >= 0 && i < len(p.Have) { + p.Have[i] = true + } +} diff --git a/bt/peer_test.go b/bt/peer_test.go new file mode 100644 index 0000000..48d6ba7 --- /dev/null +++ b/bt/peer_test.go @@ -0,0 +1,37 @@ +package bt + +import "testing" + +func TestSetPieces(t *testing.T) { + tests := []struct { + name string + data []byte + n int + i int + want bool + }{ + { + name: "piece 0 set", + data: []byte{0b1000_0000}, + n: 8, + i: 0, + want: true, + }, + { + name: "piece 7 unset", + data: []byte{0b1111_1110}, + n: 8, + i: 7, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Peer{} + c.SetPieces(tt.data, tt.n) + if c.Have[tt.i] != tt.want { + t.Errorf("Have[%d]: got %v, want %v", tt.i, c.Have[tt.i], tt.want) + } + }) + } +} diff --git a/client/dial.go b/client/dial.go new file mode 100644 index 0000000..e5f5f7a --- /dev/null +++ b/client/dial.go @@ -0,0 +1,139 @@ +package client + +import ( + "context" + "net" + "sync" + + "storrent/bt" + "storrent/dht" + "storrent/utp" +) + +const maxConcurrentDials = 8 + +type dialResult struct { + peer *bt.Peer + err error + done bool +} + +type dialPool struct { + sem chan struct{} + results chan dialResult + + dht *dht.DHT + utp *utp.Socket + infoHash [20]byte + + running bool + cancel context.CancelFunc +} + +func newDialPool(d *dht.DHT, utpSock *utp.Socket, infoHash [20]byte) *dialPool { + return &dialPool{ + sem: make(chan struct{}, maxConcurrentDials), + results: make(chan dialResult, maxConcurrentDials), + dht: d, + utp: utpSock, + infoHash: infoHash, + } +} + +func (dp *dialPool) start(parentCtx context.Context) { + if dp.running { + return + } + dp.running = true + ctx, cancel := context.WithCancel(parentCtx) + dp.cancel = cancel + go dp.connectLoop(ctx) +} + +func (dp *dialPool) stop() { + if !dp.running { + return + } + dp.running = false + if dp.cancel != nil { + dp.cancel() + dp.cancel = nil + } +} + +func (dp *dialPool) connectLoop(ctx context.Context) { + dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout) + addrs, err := dp.dht.GetPeers(dhtCtx, dp.infoHash) + cancel() + + if err != nil { + select { + case dp.results <- dialResult{done: true, err: err}: + case <-ctx.Done(): + } + return + } + + var wg sync.WaitGroup +loop: + for _, addr := range addrs { + select { + case dp.sem <- struct{}{}: + case <-ctx.Done(): + break loop + } + + wg.Add(1) + go func(addr *net.TCPAddr) { + defer wg.Done() + defer func() { <-dp.sem }() + + dialCtx, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() + + p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp) + if err != nil { + return + } + + select { + case dp.results <- dialResult{peer: p}: + case <-ctx.Done(): + p.Close() + } + }(addr) + } + + wg.Wait() + + select { + case dp.results <- dialResult{done: true}: + case <-ctx.Done(): + } +} + +func (dp *dialPool) dialSingle(ctx context.Context, addr *net.TCPAddr) { + select { + case dp.sem <- struct{}{}: + case <-ctx.Done(): + return + } + + go func() { + defer func() { <-dp.sem }() + + dialCtx, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() + + p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp) + if err != nil { + return + } + + select { + case dp.results <- dialResult{peer: p}: + case <-ctx.Done(): + p.Close() + } + }() +} diff --git a/client/manager.go b/client/manager.go new file mode 100644 index 0000000..a21fb0d --- /dev/null +++ b/client/manager.go @@ -0,0 +1,178 @@ +package client + +import ( + "errors" + "fmt" + "net" + "strconv" + "sync" + + "storrent/bt" + "storrent/dht" + "storrent/metainfo" + "storrent/utp" +) + +var errNotFound = errors.New("torrent not found") + +func (m *Manager) getTorrent(id int) (*torrent, error) { + m.mu.Lock() + t, ok := m.torrents[id] + m.mu.Unlock() + if !ok { + return nil, errNotFound + } + return t, nil +} + +type Manager struct { + mu sync.Mutex + torrents map[int]*torrent + nextID int + dir string + dht *dht.DHT + ln net.Listener + utp *utp.Socket +} + +func NewManager(dir string, d *dht.DHT, btPort int, utpSock *utp.Socket) (*Manager, error) { + var ln net.Listener + if btPort > 0 { + var err error + ln, err = net.Listen("tcp", fmt.Sprintf(":%d", btPort)) + if err != nil { + return nil, err + } + } + return &Manager{ + torrents: make(map[int]*torrent), + dir: dir, + dht: d, + ln: ln, + utp: utpSock, + }, nil +} + +func (m *Manager) Run() { + if m.ln == nil { + return + } + for { + p, err := bt.Accept(m.ln) + if err != nil { + return + } + m.dispatch(p) + } +} + +func (m *Manager) dispatch(p *bt.Peer) { + m.mu.Lock() + defer m.mu.Unlock() + for _, t := range m.torrents { + if t.meta.InfoHash == p.InfoHash { + select { + case t.incoming <- p: + default: + p.Close() + } + return + } + } + p.Close() +} + +func (m *Manager) Add(path string) ([]byte, error) { + mi, err := metainfo.Parse(path) + if err != nil { + return nil, err + } + m.mu.Lock() + id := m.nextID + m.nextID++ + t := newTorrent(id, mi, m.dir, m.dht, m.utp) + complete := t.scheduler.isComplete() + go t.run() + m.torrents[id] = t + m.mu.Unlock() + if complete { + t.startSeed() + } else { + t.startDownload() + } + return []byte(strconv.Itoa(id)), nil +} + +func (m *Manager) Remove(id int) error { + m.mu.Lock() + t, ok := m.torrents[id] + if !ok { + m.mu.Unlock() + return errNotFound + } + delete(m.torrents, id) + m.mu.Unlock() + t.stopTorrent() + return nil +} + +func (m *Manager) Start(id int) error { + t, err := m.getTorrent(id) + if err != nil { + return err + } + t.startDownload() + return nil +} + +func (m *Manager) Stop(id int) error { + t, err := m.getTorrent(id) + if err != nil { + return err + } + t.stopTorrent() + return nil +} + +func (m *Manager) Seed(id int) error { + t, err := m.getTorrent(id) + if err != nil { + return err + } + t.startSeed() + return nil +} + +func (m *Manager) AddPeer(id int, addr string) error { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return err + } + t, err := m.getTorrent(id) + if err != nil { + return err + } + t.enqueuePeer(tcpAddr) + return nil +} + +func (m *Manager) Status(id int, field string) ([]byte, error) { + t, err := m.getTorrent(id) + if err != nil { + return nil, err + } + return []byte(t.getStatus(field)), nil +} + +func (m *Manager) List() []byte { + m.mu.Lock() + defer m.mu.Unlock() + var data []byte + for id := range m.torrents { + if len(data) > 0 { + data = append(data, '\n') + } + data = append(data, []byte(strconv.Itoa(id))...) + } + return data +} diff --git a/client/peer.go b/client/peer.go new file mode 100644 index 0000000..53e8a72 --- /dev/null +++ b/client/peer.go @@ -0,0 +1,50 @@ +package client + +import "storrent/bt" + +type peerManager struct { + peers map[*bt.Peer]struct{} + maxPeers int +} + +func newPeerManager(maxPeers int) *peerManager { + return &peerManager{ + peers: make(map[*bt.Peer]struct{}), + maxPeers: maxPeers, + } +} + +func (pm *peerManager) add(p *bt.Peer) bool { + if len(pm.peers) >= pm.maxPeers { + return false + } + pm.peers[p] = struct{}{} + return true +} + +func (pm *peerManager) addUnlimited(p *bt.Peer) { + pm.peers[p] = struct{}{} +} + +func (pm *peerManager) remove(p *bt.Peer) { + delete(pm.peers, p) +} + +func (pm *peerManager) count() int { + return len(pm.peers) +} + +func (pm *peerManager) all() []*bt.Peer { + peers := make([]*bt.Peer, 0, len(pm.peers)) + for p := range pm.peers { + peers = append(peers, p) + } + return peers +} + +func (pm *peerManager) closeAll() { + for p := range pm.peers { + p.Close() + } + pm.peers = make(map[*bt.Peer]struct{}) +} diff --git a/client/piece.go b/client/piece.go new file mode 100644 index 0000000..e07833b --- /dev/null +++ b/client/piece.go @@ -0,0 +1,69 @@ +package client + +import ( + "crypto/sha1" + + "storrent/bt" +) + +type piece struct { + hash [20]byte + size int64 + data []byte + have []bool + reqs []bool + reqPeer []*bt.Peer + done bool +} + +func (p *piece) put(begin int, data []byte) bool { + if p.done { + return false + } + block := begin / bt.BlockSize + if block >= len(p.have) || p.have[block] { + return false + } + copy(p.data[begin:], data) + p.have[block] = true + for _, got := range p.have { + if !got { + return false + } + } + if sha1.Sum(p.data) != p.hash { + for i := range p.have { + p.have[i] = false + p.reqs[i] = false + } + return false + } + p.done = true + return true +} + +func (p *piece) next(peer *bt.Peer) (begin, length int64, ok bool) { + for i := range p.have { + if p.have[i] || p.reqs[i] { + continue + } + p.reqs[i] = true + p.reqPeer[i] = peer + begin = int64(i) * bt.BlockSize + length = bt.BlockSize + if begin+length > p.size { + length = p.size - begin + } + return begin, length, true + } + return 0, 0, false +} + +func (p *piece) resetPeer(peer *bt.Peer) { + for i := range p.reqs { + if p.reqs[i] && p.reqPeer[i] == peer { + p.reqs[i] = false + p.reqPeer[i] = nil + } + } +} diff --git a/client/scheduler.go b/client/scheduler.go new file mode 100644 index 0000000..67d95ba --- /dev/null +++ b/client/scheduler.go @@ -0,0 +1,134 @@ +package client + +import ( + "crypto/sha1" + + "storrent/bt" + "storrent/metainfo" +) + +type pieceScheduler struct { + pieces []piece + incomplete []int + verified int + meta *metainfo.File + dir string +} + +func newPieceScheduler(m *metainfo.File, dir string) *pieceScheduler { + n := len(m.Info.Pieces) + pieces := make([]piece, n) + incomplete := make([]int, 0, n) + verified := 0 + + for i := range n { + size := m.Info.PieceSize + if i == n-1 { + size = m.Size - int64(n-1)*m.Info.PieceSize + } + nblocks := (size + bt.BlockSize - 1) / bt.BlockSize + pieces[i] = piece{ + hash: m.Info.Pieces[i], + size: size, + data: make([]byte, size), + have: make([]bool, nblocks), + reqs: make([]bool, nblocks), + reqPeer: make([]*bt.Peer, nblocks), + } + + data := m.Info.Read(dir, i, pieces[i].size) + if data != nil && sha1.Sum(data) == pieces[i].hash { + pieces[i].data = data + pieces[i].done = true + verified++ + } else { + incomplete = append(incomplete, i) + } + } + + return &pieceScheduler{ + pieces: pieces, + incomplete: incomplete, + verified: verified, + meta: m, + dir: dir, + } +} + +func (ps *pieceScheduler) nextRequest(p *bt.Peer) *bt.Msg { + for _, idx := range ps.incomplete { + pc := &ps.pieces[idx] + if pc.done { + continue + } + if !p.HasPiece(idx) { + continue + } + if begin, length, ok := pc.next(p); ok { + return &bt.Msg{ + Kind: bt.Request, + Index: uint32(idx), + Begin: uint32(begin), + Length: uint32(length), + } + } + } + return nil +} + +func (ps *pieceScheduler) put(index int, begin int, block []byte) bool { + pc := &ps.pieces[index] + if pc.put(begin, block) { + ps.meta.Info.Write(ps.dir, index, pc.data) + ps.verified++ + ps.removeIncomplete(index) + return true + } + return false +} + +func (ps *pieceScheduler) removeIncomplete(index int) { + for i, idx := range ps.incomplete { + if idx == index { + ps.incomplete = append(ps.incomplete[:i], ps.incomplete[i+1:]...) + return + } + } +} + +func (ps *pieceScheduler) peerDisconnected(p *bt.Peer) { + for i := range ps.pieces { + ps.pieces[i].resetPeer(p) + } +} + +func (ps *pieceScheduler) isComplete() bool { + return ps.verified == len(ps.pieces) +} + +func (ps *pieceScheduler) progress() (verified, total int) { + return ps.verified, len(ps.pieces) +} + +func (ps *pieceScheduler) pieceData(index int) []byte { + pc := &ps.pieces[index] + if !pc.done { + return nil + } + return pc.data +} + +func (ps *pieceScheduler) count() int { + return len(ps.pieces) +} + +func (ps *pieceScheduler) bitfield(all bool) []byte { + n := len(ps.pieces) + bf := make([]byte, (n+7)/8) + for i := range ps.pieces { + if all || ps.pieces[i].done { + bt.SetBit(bf, i) + } + } + return bf +} diff --git a/client/torrent.go b/client/torrent.go new file mode 100644 index 0000000..d80f7f7 --- /dev/null +++ b/client/torrent.go @@ -0,0 +1,419 @@ +package client + +import ( + "context" + "fmt" + "net" + "strconv" + "strings" + "time" + + "storrent/bt" + "storrent/dht" + "storrent/metainfo" + "storrent/utp" +) + +type state int + +const ( + stopped state = iota + downloading + seeding + done + errored +) + +const ( + maxPeers = 8 + maxPending = 5 + dialTimeout = 30 * time.Second + retryInterval = 30 * time.Second + dhtTimeout = 5 * time.Minute + recvBufSize = 128 +) + +type recv struct { + msg *bt.Msg + err error + p *bt.Peer +} + +type statusReq struct { + field string + resp chan string +} + +type torrent struct { + id int + meta *metainfo.File + dir string + dht *dht.DHT + utp *utp.Socket + + start chan struct{} + stop chan struct{} + seed chan struct{} + incoming chan *bt.Peer + addPeer chan *net.TCPAddr + status chan statusReq + + peers *peerManager + scheduler *pieceScheduler + dialPool *dialPool + + st state + down int64 + up int64 + + ctx context.Context + cancel context.CancelFunc + recvs chan recv + retry <-chan time.Time +} + +func newTorrent(id int, m *metainfo.File, dir string, d *dht.DHT, utpSock *utp.Socket) *torrent { + t := &torrent{ + id: id, + meta: m, + dir: dir, + dht: d, + utp: utpSock, + start: make(chan struct{}, 1), + stop: make(chan struct{}, 1), + seed: make(chan struct{}, 1), + incoming: make(chan *bt.Peer), + addPeer: make(chan *net.TCPAddr, 8), + status: make(chan statusReq), + peers: newPeerManager(maxPeers), + scheduler: newPieceScheduler(m, dir), + dialPool: newDialPool(d, utpSock, m.InfoHash), + st: stopped, + recvs: make(chan recv, recvBufSize), + } + return t +} + +func (t *torrent) run() { + for { + select { + case <-t.start: + t.handleStart() + + case <-t.seed: + t.handleSeed() + + case <-t.stop: + t.handleStop() + + case <-t.retry: + t.handleRetry() + + case d := <-t.dialPool.results: + t.handleDialResult(d) + + case p := <-t.incoming: + t.handleIncoming(p) + + case addr := <-t.addPeer: + t.handleAddPeer(addr) + + case r := <-t.recvs: + t.handleRecv(r) + + case req := <-t.status: + req.resp <- t.handleStatus(req.field) + } + } +} + +func (t *torrent) handleStart() { + if t.st != stopped { + return + } + t.st = downloading + t.ctx, t.cancel = context.WithCancel(context.Background()) + t.dialPool.start(t.ctx) +} + +func (t *torrent) handleSeed() { + if t.st != stopped && t.st != done { + return + } + if !t.scheduler.isComplete() { + return + } + t.st = seeding + t.ctx, t.cancel = context.WithCancel(context.Background()) + t.dialPool.start(t.ctx) +} + +func (t *torrent) handleStop() { + if !t.isActive() { + return + } + t.cancel() + t.dialPool.stop() + t.peers.closeAll() + t.retry = nil + t.st = stopped +} + +func (t *torrent) handleRetry() { + t.retry = nil + if t.isActive() { + t.dialPool.start(t.ctx) + } +} + +func (t *torrent) handleDialResult(d dialResult) { + if d.done { + if t.peers.count() < maxPeers && t.isActive() { + t.retry = time.After(retryInterval) + } + return + } + + if d.peer == nil { + return + } + + p := d.peer + if !t.peers.add(p) { + p.Close() + return + } + + go t.recvLoop(p) + if err := t.initPeer(p); err != nil { + p.Close() + t.peers.remove(p) + } +} + +func (t *torrent) handleIncoming(p *bt.Peer) { + t.peers.addUnlimited(p) + go t.recvLoop(p) + + if err := p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(false)}); err != nil { + p.Close() + t.peers.remove(p) + return + } + if err := p.Send(&bt.Msg{Kind: bt.Unchoke}); err != nil { + p.Close() + t.peers.remove(p) + } +} + +func (t *torrent) handleAddPeer(addr *net.TCPAddr) { + if !t.isActive() { + return + } + t.dialPool.dialSingle(t.ctx, addr) +} + +func (t *torrent) handleRecv(r recv) { + if r.err != nil { + r.p.Close() + t.peers.remove(r.p) + t.scheduler.peerDisconnected(r.p) + if t.peers.count() < maxPeers && t.isActive() { + t.dialPool.start(t.ctx) + } + return + } + + if r.msg == nil { + return + } + + p := r.p + m := r.msg + + switch m.Kind { + case bt.Choke: + p.Choked = true + p.Pending = 0 + + case bt.Unchoke: + p.Choked = false + + case bt.Have: + p.SetPiece(int(m.Index)) + + case bt.Bitfield: + p.SetPieces(m.Bitfield, t.scheduler.count()) + + case bt.Piece: + t.scheduler.put(int(m.Index), int(m.Begin), m.Block) + p.Pending-- + t.down += int64(len(m.Block)) + + case bt.Request: + t.handleRequest(p, m) + } + + t.sendRequests(p) + t.checkCompletion() +} + +func (t *torrent) handleRequest(p *bt.Peer, m *bt.Msg) { + data := t.scheduler.pieceData(int(m.Index)) + if data == nil { + return + } + end := int(m.Begin + m.Length) + if end > len(data) { + return + } + if err := p.Send(&bt.Msg{ + Kind: bt.Piece, + Index: m.Index, + Begin: m.Begin, + Block: data[m.Begin:end], + }); err != nil { + p.Close() + return + } + t.up += int64(m.Length) +} + +func (t *torrent) sendRequests(p *bt.Peer) { + if p.Choked { + return + } + for p.Pending < maxPending { + req := t.scheduler.nextRequest(p) + if req == nil { + break + } + if err := p.Send(req); err != nil { + p.Close() + break + } + p.Pending++ + } +} + +func (t *torrent) checkCompletion() { + if t.st == downloading && t.scheduler.isComplete() { + t.st = done + t.peers.closeAll() + } +} + +func (t *torrent) initPeer(p *bt.Peer) error { + if t.st == seeding { + return p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(true)}) + } + return p.Send(&bt.Msg{Kind: bt.Interested}) +} + +func (t *torrent) recvLoop(p *bt.Peer) { + for { + msg, err := p.Recv() + t.recvs <- recv{msg, err, p} + if err != nil { + return + } + } +} + +func (t *torrent) isActive() bool { + return t.st == downloading || t.st == seeding +} + +func (t *torrent) startDownload() { + select { + case t.start <- struct{}{}: + default: + } +} + +func (t *torrent) startSeed() { + select { + case t.seed <- struct{}{}: + default: + } +} + +func (t *torrent) stopTorrent() { + select { + case t.stop <- struct{}{}: + default: + } +} + +func (t *torrent) enqueuePeer(addr *net.TCPAddr) { + t.addPeer <- addr +} + +func (t *torrent) getStatus(field string) string { + ch := make(chan string) + t.status <- statusReq{field, ch} + return <-ch +} + +func (t *torrent) handleStatus(field string) string { + switch field { + case "name": + return t.meta.Info.Name + case "state": + return t.st.String() + case "progress": + verified, total := t.scheduler.progress() + if total == 0 { + return "0" + } + return strconv.Itoa(verified * 100 / total) + case "size": + return strconv.FormatInt(t.meta.Size, 10) + case "down": + return strconv.FormatInt(t.down, 10) + case "up": + return strconv.FormatInt(t.up, 10) + case "pieces": + verified, total := t.scheduler.progress() + return fmt.Sprintf("%d/%d", verified, total) + case "peers": + return t.formatPeers() + } + return "" +} + +func (t *torrent) formatPeers() string { + var buf strings.Builder + npieces := t.scheduler.count() + for _, p := range t.peers.all() { + st := "unchoked" + if p.Choked { + st = "choked" + } + have := 0 + for _, h := range p.Have { + if h { + have++ + } + } + fmt.Fprintf(&buf, "%s %s %s have:%d/%d pending:%d\n", + p.Addr, p.PeerID[:8], st, have, npieces, p.Pending) + } + return buf.String() +} + +func (s state) String() string { + switch s { + case stopped: + return "stopped" + case downloading: + return "downloading" + case seeding: + return "seeding" + case done: + return "done" + case errored: + return "error" + } + return "unknown" +} diff --git a/client/torrent_test.go b/client/torrent_test.go new file mode 100644 index 0000000..088d9e0 --- /dev/null +++ b/client/torrent_test.go @@ -0,0 +1,121 @@ +package client + +import ( + "crypto/sha1" + "testing" + + "storrent/bt" +) + +func TestPut(t *testing.T) { + tests := []struct { + in string + hash [20]byte + want bool + }{ + { + in: "hello", + hash: sha1.Sum([]byte("hello")), + want: true, + }, + { + in: "wrong", + hash: sha1.Sum([]byte("hello")), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + p := &piece{ + hash: tt.hash, + size: int64(len(tt.in)), + data: make([]byte, len(tt.in)), + have: make([]bool, 1), + reqs: make([]bool, 1), + reqPeer: make([]*bt.Peer, 1), + } + if got := p.put(0, []byte(tt.in)); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestPutMultiple(t *testing.T) { + data := make([]byte, bt.BlockSize+5) + for i := range data { + data[i] = byte(i) + } + p := &piece{ + hash: sha1.Sum(data), + size: int64(len(data)), + data: make([]byte, len(data)), + have: make([]bool, 2), + reqs: make([]bool, 2), + reqPeer: make([]*bt.Peer, 2), + } + + tests := []struct { + name string + begin int + data []byte + want bool + }{ + { + name: "first block", + begin: 0, + data: data[:bt.BlockSize], + want: false, + }, + { + name: "second block", + begin: bt.BlockSize, + data: data[bt.BlockSize:], + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := p.put(tt.begin, tt.data); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestNextRequest(t *testing.T) { + tests := []struct { + name string + pieceSize int64 + have []bool + wantLength uint32 + }{ + { + name: "last block partial", + pieceSize: bt.BlockSize + 100, + have: []bool{true, false}, + wantLength: 100, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ps := &pieceScheduler{ + pieces: []piece{{ + size: tt.pieceSize, + have: tt.have, + reqs: make([]bool, len(tt.have)), + reqPeer: make([]*bt.Peer, len(tt.have)), + }}, + incomplete: []int{0}, + } + c := &bt.Peer{Have: []bool{true}} + msg := ps.nextRequest(c) + if msg == nil { + t.Fatal("expected request") + } + if msg.Length != tt.wantLength { + t.Errorf("got %d, want %d", msg.Length, tt.wantLength) + } + }) + } +} diff --git a/dht/dht.go b/dht/dht.go new file mode 100644 index 0000000..038e918 --- /dev/null +++ b/dht/dht.go @@ -0,0 +1,228 @@ +package dht + +import ( + "context" + "fmt" + "net" + "slices" + "strconv" + "sync" + "time" +) + +const ( + queryTimeout = 5 * time.Second + alpha = 3 + maxQueries = 32 +) + +type DHT struct { + conn *net.UDPConn + id [20]byte + nodes []node + pending map[string]chan *resp + mu sync.Mutex +} + +func New(port int) (*DHT, error) { + addr := &net.UDPAddr{Port: port} + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + d := &DHT{ + conn: conn, + id: genNodeID(), + pending: make(map[string]chan *resp), + } + go d.run() + return d, nil +} + +func (d *DHT) Close() { + d.conn.Close() +} + +func (d *DHT) run() { + buf := make([]byte, 65535) + for { + n, _, err := d.conn.ReadFromUDP(buf) + if err != nil { + return + } + msg, err := decodeMsg(buf[:n]) + if err != nil { + continue + } + d.mu.Lock() + ch, ok := d.pending[msg.T] + if ok { + delete(d.pending, msg.T) + } + d.mu.Unlock() + if ok { + ch <- &msg.R + } + } +} + +func xorDistance(a, b [20]byte) [20]byte { + var dist [20]byte + for i := range 20 { + dist[i] = a[i] ^ b[i] + } + return dist +} + +func compareDist(a, b [20]byte) int { + for i := range 20 { + if a[i] < b[i] { + return -1 + } + if a[i] > b[i] { + return 1 + } + } + return 0 +} + +func (d *DHT) query(ctx context.Context, addr *net.UDPAddr, q string, a args) (*resp, error) { + ctx, cancel := context.WithTimeout(ctx, queryTimeout) + defer cancel() + + txid := genTxID() + a.ID = d.id + m := &msg{ + T: txid, + Y: query, + Q: q, + A: a, + } + data, err := m.Encode() + if err != nil { + return nil, err + } + + reply := make(chan *resp, 1) + d.mu.Lock() + d.pending[txid] = reply + d.mu.Unlock() + + defer func() { + d.mu.Lock() + delete(d.pending, txid) + d.mu.Unlock() + }() + + if _, err := d.conn.WriteToUDP(data, addr); err != nil { + return nil, err + } + + select { + case r := <-reply: + return r, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (d *DHT) Bootstrap(ctx context.Context, addr string) error { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return err + } + ips, err := net.LookupHost(host) + if err != nil { + return err + } + if len(ips) == 0 { + return fmt.Errorf("no bootstrap nodes") + } + p, err := strconv.Atoi(port) + if err != nil { + return fmt.Errorf("invalid port: %w", err) + } + uaddr := &net.UDPAddr{ + IP: net.ParseIP(ips[0]), + Port: p, + } + resp, err := d.query(ctx, uaddr, findNode, args{Target: d.id}) + if err != nil { + return err + } + d.nodes = append(d.nodes, resp.Nodes...) + return nil +} + +func sortByDist(nodes []node, target [20]byte) { + slices.SortFunc(nodes, func(a, b node) int { + return compareDist(xorDistance(a.ID, target), xorDistance(b.ID, target)) + }) +} + +func (d *DHT) GetPeers(ctx context.Context, h [20]byte) ([]*net.TCPAddr, error) { + queried := make(map[string]bool) + candidates := make([]node, len(d.nodes)) + copy(candidates, d.nodes) + sortByDist(candidates, h) + + results := make(chan *resp) + inflight := 0 + queryCount := 0 + var peers []*net.TCPAddr + + for queryCount < maxQueries { + if ctx.Err() != nil { + break + } + for inflight < alpha && queryCount < maxQueries { + n, i := nextCandidate(candidates, queried) + if i < 0 { + break + } + candidates = slices.Delete(candidates, i, i+1) + queried[n.Addr.String()] = true + inflight++ + queryCount++ + go func(addr *net.UDPAddr) { + r, _ := d.query(ctx, addr, getPeers, args{InfoHash: h}) + results <- r + }(n.Addr) + } + if inflight == 0 { + break + } + r := <-results + inflight-- + if r == nil { + continue + } + for _, p := range r.Peers { + if a := decodePeer(p); a != nil { + peers = append(peers, a) + } + } + for _, c := range r.Nodes { + if !queried[c.Addr.String()] { + candidates = append(candidates, c) + } + } + sortByDist(candidates, h) + } + for range inflight { + <-results + } + if len(peers) > 0 { + return peers, nil + } + return nil, fmt.Errorf("no peers found") +} + +func nextCandidate(candidates []node, queried map[string]bool) (node, int) { + for i, c := range candidates { + if !queried[c.Addr.String()] { + return c, i + } + } + return node{}, -1 +} diff --git a/dht/msg.go b/dht/msg.go new file mode 100644 index 0000000..2e757f2 --- /dev/null +++ b/dht/msg.go @@ -0,0 +1,246 @@ +package dht + +// https://www.bittorrent.org/beps/bep_0005.html +// Peer: TCP +// Node: UDP +// BEP 5: DHT Protocol - KRPC over UDP +// Message: bencode dict with keys: +// t: transaction id (2 bytes) +// y: type "q" (query), "r" (response), "e" (error) +// q: query name (ping, find_node, get_peers, announce_peer) +// a: query args dict +// r: response dict +// e: error [code, message] +// Compact formats: +// peer: 6 bytes (4 ip + 2 port) +// node: 26 bytes (20 id + 4 ip + 2 port) + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "net" + + "storrent/bencode" +) + +const ( + query = "q" + response = "r" + error_ = "e" +) + +const ( + ping = "ping" + findNode = "find_node" + getPeers = "get_peers" + announcePeer = "announce_peer" +) + +type node struct { + ID [20]byte + Addr *net.UDPAddr +} + +type args struct { + ID [20]byte + Target [20]byte + InfoHash [20]byte + Port int + Token string + ImpliedPort int +} + +type resp struct { + ID [20]byte + Nodes []node + Peers []string + Token string +} + +type errResp struct { + Code int + Msg string +} + +type msg struct { + T string + Y string + Q string + A args + R resp + E errResp +} + +func genTxID() string { + b := make([]byte, 2) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return string(b) +} + +func genNodeID() [20]byte { + var id [20]byte + if _, err := rand.Read(id[:]); err != nil { + panic(err) + } + return id +} + +func (m *msg) Encode() ([]byte, error) { + d := make(map[string]any) + d["t"] = m.T + d["y"] = m.Y + + switch m.Y { + case query: + d["q"] = m.Q + a := make(map[string]any) + a["id"] = string(m.A.ID[:]) + switch m.Q { + case findNode: + a["target"] = string(m.A.Target[:]) + case getPeers: + a["info_hash"] = string(m.A.InfoHash[:]) + case announcePeer: + a["info_hash"] = string(m.A.InfoHash[:]) + a["port"] = int64(m.A.Port) + a["token"] = m.A.Token + if m.A.ImpliedPort != 0 { + a["implied_port"] = int64(m.A.ImpliedPort) + } + } + d["a"] = a + case response: + r := make(map[string]any) + r["id"] = string(m.R.ID[:]) + if len(m.R.Nodes) > 0 { + r["nodes"] = encodeNodes(m.R.Nodes) + } + if len(m.R.Peers) > 0 { + vals := make([]any, len(m.R.Peers)) + for i, p := range m.R.Peers { + vals[i] = p + } + r["values"] = vals + } + if m.R.Token != "" { + r["token"] = m.R.Token + } + d["r"] = r + case error_: + d["e"] = []any{int64(m.E.Code), m.E.Msg} + } + return bencode.Encode(d) +} + +func decodeMsg(data []byte) (*msg, error) { + v, _, err := bencode.Decode(data) + if err != nil { + return nil, err + } + + d, ok := v.(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid krpc message") + } + + m := &msg{} + m.T, _ = d["t"].(string) + m.Y, _ = d["y"].(string) + switch m.Y { + case query: + m.Q, _ = d["q"].(string) + if a, ok := d["a"].(map[string]any); ok { + if id, ok := a["id"].(string); ok && len(id) == 20 { + copy(m.A.ID[:], id) + } + if target, ok := a["target"].(string); ok && len(target) == 20 { + copy(m.A.Target[:], target) + } + if ih, ok := a["info_hash"].(string); ok && len(ih) == 20 { + copy(m.A.InfoHash[:], ih) + } + if port, ok := a["port"].(int64); ok { + m.A.Port = int(port) + } + if token, ok := a["token"].(string); ok { + m.A.Token = token + } + if iport, ok := a["implied_port"].(int64); ok { + m.A.ImpliedPort = int(iport) + } + } + + case response: + if r, ok := d["r"].(map[string]any); ok { + if id, ok := r["id"].(string); ok && len(id) == 20 { + copy(m.R.ID[:], id) + } + if nodes, ok := r["nodes"].(string); ok { + m.R.Nodes = decodeNodes(nodes) + } + if vals, ok := r["values"].([]any); ok { + for _, v := range vals { + if p, ok := v.(string); ok { + m.R.Peers = append(m.R.Peers, p) + } + } + } + if token, ok := r["token"].(string); ok { + m.R.Token = token + } + } + + case error_: + if e, ok := d["e"].([]any); ok && len(e) >= 2 { + if code, ok := e[0].(int64); ok { + m.E.Code = int(code) + } + m.E.Msg, _ = e[1].(string) + } + } + return m, nil +} + +// 26 = 20 id + 4 ip + 2 port +func encodeNodes(nodes []node) string { + buf := make([]byte, 26*len(nodes)) + for i, n := range nodes { + off := i * 26 + copy(buf[off:], n.ID[:]) + copy(buf[off+20:], n.Addr.IP.To4()) + binary.BigEndian.PutUint16(buf[off+24:], uint16(n.Addr.Port)) + } + return string(buf) +} + +func decodeNodes(s string) []node { + data := []byte(s) + if len(data)%26 != 0 { + return nil + } + n := len(data) / 26 + nodes := make([]node, n) + for i := range nodes { + off := i * 26 + copy(nodes[i].ID[:], data[off:]) + ip := net.IP(data[off+20 : off+24]) + port := binary.BigEndian.Uint16(data[off+24:]) + nodes[i].Addr = &net.UDPAddr{IP: ip, Port: int(port)} + } + return nodes +} + +// 6 = 4 ip + 2 port +func decodePeer(s string) *net.TCPAddr { + if len(s) != 6 { + return nil + } + data := []byte(s) + return &net.TCPAddr{ + IP: net.IP(data[:4]), + Port: int(binary.BigEndian.Uint16(data[4:])), + } +} diff --git a/dht/msg_test.go b/dht/msg_test.go new file mode 100644 index 0000000..f6c7757 --- /dev/null +++ b/dht/msg_test.go @@ -0,0 +1,67 @@ +package dht + +import ( + "net" + "strings" + "testing" +) + +func nid(c string) [20]byte { + var id [20]byte + copy(id[:], strings.Repeat(c, 20)) + return id +} + +func TestEncodeDecodeNodes(t *testing.T) { + tests := []struct { + name string + nodes []node + }{ + { + name: "empty", + nodes: []node{}, + }, + { + name: "single", + nodes: []node{ + { + ID: nid("a"), + Addr: &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 6881, + }, + }, + }, + }, + { + name: "multiple", + nodes: []node{ + { + ID: nid("a"), + Addr: &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 6881, + }, + }, + { + ID: nid("a"), + Addr: &net.UDPAddr{ + IP: net.IPv4(192, 168, 1, 1), + Port: 8080, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := encodeNodes(tt.nodes) + got := decodeNodes(data) + data2 := encodeNodes(got) + if data != data2 { + t.Errorf("roundtrip mismatch: got %x, want %x", data2, data) + } + }) + } +} diff --git a/fs/fs.go b/fs/fs.go new file mode 100644 index 0000000..954265f --- /dev/null +++ b/fs/fs.go @@ -0,0 +1,156 @@ +package fs + +import ( + "fmt" + "strconv" + "strings" + "sync" + + "github.com/knusbaum/go9p/fs" + "github.com/knusbaum/go9p/proto" + + "storrent/client" +) + +type Dir struct { + stat proto.Stat + parent fs.Dir + man *client.Manager + fsys *fs.FS + sync.RWMutex +} + +func New(e *client.Manager) *fs.FS { + fsys, root := fs.NewFS("storrent", "storrent", 0555) + root.AddChild(&fs.WrappedFile{ + File: fs.NewBaseFile(fsys.NewStat("ctl", "storrent", "storrent", 0222)), + WriteF: func(fid uint64, offset uint64, data []byte) (uint32, error) { + cmd := strings.TrimSpace(string(data)) + if strings.HasPrefix(cmd, "add ") { + path := strings.TrimSpace(cmd[4:]) + _, err := e.Add(path) + if err != nil { + return 0, err + } + return uint32(len(data)), nil + } + return 0, fmt.Errorf("unknown command: %s", cmd) + }, + }) + + root.AddChild(fs.NewDynamicFile( + fsys.NewStat("list", "storrent", "storrent", 0444), + func() []byte { + data := e.List() + if len(data) > 0 { + return append(data, '\n') + } + return data + }, + )) + + dir := &Dir{ + stat: *fsys.NewStat("torrents", "storrent", "storrent", 0555|proto.DMDIR), + man: e, + fsys: fsys, + } + dir.stat.Qid.Qtype = uint8(dir.stat.Mode >> 24) + root.AddChild(dir) + return fsys +} + +func (d *Dir) Stat() proto.Stat { + d.Lock() + defer d.Unlock() + return d.stat +} + +func (d *Dir) WriteStat(s *proto.Stat) error { + d.Lock() + defer d.Unlock() + d.stat = *s + return nil +} + +func (d *Dir) SetParent(p fs.Dir) { + d.Lock() + defer d.Unlock() + d.parent = p +} + +func (d *Dir) Parent() fs.Dir { + d.RLock() + defer d.RUnlock() + return d.parent +} + +func (d *Dir) Children() map[string]fs.FSNode { + data := d.man.List() + m := make(map[string]fs.FSNode) + if len(data) == 0 { + return m + } + for _, name := range strings.Split(string(data), "\n") { + id, err := strconv.Atoi(name) + if err != nil { + continue + } + m[name] = d.newTorrentDir(id) + } + return m +} + +func (d *Dir) newTorrentDir(id int) fs.FSNode { + dir := fs.NewStaticDir(d.fsys.NewStat(strconv.Itoa(id), "storrent", "storrent", 0555)) + dir.AddChild(&fs.WrappedFile{ + File: fs.NewBaseFile(d.fsys.NewStat("ctl", "storrent", "storrent", 0222)), + WriteF: func(fid uint64, offset uint64, data []byte) (uint32, error) { + cmd := strings.TrimSpace(string(data)) + var err error + switch { + case cmd == "start": + err = d.man.Start(id) + case cmd == "stop": + err = d.man.Stop(id) + case cmd == "seed": + err = d.man.Seed(id) + case cmd == "remove": + err = d.man.Remove(id) + case strings.HasPrefix(cmd, "peer "): + err = d.man.AddPeer(id, strings.TrimSpace(cmd[5:])) + default: + return 0, fmt.Errorf("unknown command: %s", cmd) + } + if err != nil { + return 0, err + } + return uint32(len(data)), nil + }, + }) + + dir.AddChild(d.newStatusFile(id, "name")) + dir.AddChild(d.newStatusFile(id, "state")) + dir.AddChild(d.newStatusFile(id, "progress")) + dir.AddChild(d.newStatusFile(id, "size")) + dir.AddChild(d.newStatusFile(id, "down")) + dir.AddChild(d.newStatusFile(id, "up")) + dir.AddChild(d.newStatusFile(id, "pieces")) + dir.AddChild(d.newStatusFile(id, "peers")) + return dir +} + +func (d *Dir) newStatusFile(id int, field string) fs.FSNode { + return fs.NewDynamicFile( + d.fsys.NewStat(field, "storrent", "storrent", 0444), + func() []byte { + data, err := d.man.Status(id, field) + if err != nil { + return nil + } + if len(data) > 0 { + return append(data, '\n') + } + return data + }, + ) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..47409d0 --- /dev/null +++ b/go.mod @@ -0,0 +1,12 @@ +module storrent + +go 1.25.4 + +require github.com/knusbaum/go9p v1.18.0 + +require ( + 9fans.net/go v0.0.2 // indirect + github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d // indirect + github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 // indirect + github.com/fhs/mux9p v0.3.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6fc824c --- /dev/null +++ b/go.sum @@ -0,0 +1,25 @@ +9fans.net/go v0.0.2 h1:RYM6lWITV8oADrwLfdzxmt8ucfW6UtP9v1jg4qAbqts= +9fans.net/go v0.0.2/go.mod h1:lfPdxjq9v8pVQXUMBCx5EO5oLXWQFlKRQgs1kEkjoIM= +github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d h1:xH/U6K+HYxh1480TkQYRqRO8F2RJsg+R6wFiVJzdldg= +github.com/Plan9-Archive/libauth v0.0.0-20180917063427-d1ca9e94969d/go.mod h1:UKp8dv9aeaZoQFWin7eQXtz89iHly1YAFZNn3MCutmQ= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ= +github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= +github.com/fhs/mux9p v0.3.1 h1:x1UswUWZoA9vrA02jfisndCq3xQm+wrQUxUt5N99E08= +github.com/fhs/mux9p v0.3.1/go.mod h1:F4hwdenmit0WDoNVT2VMWlLJrBVCp/8UhzJa7scfjEQ= +github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok= +github.com/hanwen/go-fuse/v2 v2.0.3/go.mod h1:0EQM6aH2ctVpvZ6a+onrQ/vaykxh2GH7hy3e13vzTUY= +github.com/knusbaum/go9p v1.18.0 h1:/Y67RNvNKX1ZV1IOdnO1lIetiF0X+CumOyvEc0011GI= +github.com/knusbaum/go9p v1.18.0/go.mod h1:HtMoJKqZUe1Oqag5uJqG5RKQ9gWPSP+wolsnLLv44r8= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201020230747-6e5568b54d1a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/main.go b/main.go new file mode 100644 index 0000000..9460f03 --- /dev/null +++ b/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + + "github.com/knusbaum/go9p" + + "storrent/client" + "storrent/dht" + "storrent/fs" + "storrent/utp" +) + +func main() { + addr := flag.String("addr", ":5640", "9P listen address") + dir := flag.String("dir", "./download", "download directory") + dhtPort := flag.Int("dht", 6881, "DHT port") + btPort := flag.Int("bt", 0, "BT listen port") + utpPort := flag.Int("utp", 6882, "uTP port") + bootstrap := flag.String("bootstrap", "router.bittorrent.com:6881", "DHT bootstrap node") + flag.Parse() + + d, err := dht.New(*dhtPort) + if err != nil { + log.Fatal(err) + } + if err := d.Bootstrap(context.Background(), *bootstrap); err != nil { + log.Fatal(err) + } + + utpSock, err := utp.New(fmt.Sprintf(":%d", *utpPort)) + if err != nil { + log.Fatal(err) + } + utpSock.Start() + + m, err := client.NewManager(*dir, d, *btPort, utpSock) + if err != nil { + log.Fatal(err) + } + go m.Run() + + if err := go9p.Serve(*addr, fs.New(m).Server()); err != nil { + log.Fatal(err) + } +} diff --git a/metainfo/metainfo.go b/metainfo/metainfo.go new file mode 100644 index 0000000..eadc913 --- /dev/null +++ b/metainfo/metainfo.go @@ -0,0 +1,182 @@ +package metainfo + +import ( + "crypto/sha1" + "fmt" + "os" + + "storrent/bencode" +) + +type File struct { + Info Info + InfoHash [20]byte + Size int64 +} + +type Info struct { + Name string + PieceSize int64 + Pieces [][20]byte + Size int64 + Files []Entry +} + +type Entry struct { + Size int64 + Offset int64 + Path []string +} + +type Segment struct { + File int + Offset int64 + Size int64 +} + +func (i *Info) Segments(off, size int64) []Segment { + var segs []Segment + for idx, f := range i.Files { + end := f.Offset + f.Size + if off >= end { + continue + } + n := min(size, end-off) + segs = append(segs, Segment{File: idx, Offset: off - f.Offset, Size: n}) + off += n + size -= n + if size == 0 { + break + } + } + return segs +} + +func Parse(path string) (*File, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return ParseBytes(data) +} + +func ParseBytes(data []byte) (*File, error) { + v, _, err := bencode.Decode(data) + if err != nil { + return nil, err + } + d, ok := v.(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid torrent dict") + } + + f := &File{} + raw, err := findInfoBytes(data) + if err != nil { + return nil, err + } + f.InfoHash = sha1.Sum(raw) + dict, ok := d["info"].(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid info dict") + } + f.Info, err = parseInfo(dict) + if err != nil { + return nil, err + } + if f.Info.Size > 0 { + f.Size = f.Info.Size + } else { + for _, e := range f.Info.Files { + f.Size += e.Size + } + } + return f, nil +} + +func findInfoBytes(data []byte) ([]byte, error) { + if len(data) == 0 || data[0] != 'd' { + return nil, fmt.Errorf("not a dict") + } + + i := 1 + for i < len(data) && data[i] != 'e' { + k, n, err := bencode.DecodeString(data[i:]) + if err != nil { + return nil, err + } + i += n + if k == "info" { + _, n, err := bencode.Decode(data[i:]) + if err != nil { + return nil, err + } + return data[i : i+n], nil + } + _, n, err = bencode.Decode(data[i:]) + if err != nil { + return nil, err + } + i += n + } + return nil, fmt.Errorf("info not found") +} + +func parseInfo(d map[string]any) (Info, error) { + var info Info + name, ok := d["name"].(string) + if !ok { + return info, fmt.Errorf("invalid name") + } + info.Name = name + pl, ok := d["piece length"].(int64) + if !ok { + return info, fmt.Errorf("invalid piece length") + } + info.PieceSize = pl + ps, ok := d["pieces"].(string) + if !ok || len(ps)%20 != 0 { + return info, fmt.Errorf("invalid pieces") + } + npieces := len(ps) / 20 + info.Pieces = make([][20]byte, npieces) + for i := range npieces { + copy(info.Pieces[i][:], ps[i*20:(i+1)*20]) + } + if n, ok := d["length"].(int64); ok { + info.Size = n + return info, nil + } + + fs, ok := d["files"].([]any) + if !ok { + return info, fmt.Errorf("invalid files") + } + + off := int64(0) + for _, f := range fs { + fd, ok := f.(map[string]any) + if !ok { + return info, fmt.Errorf("invalid file entry") + } + n, ok := fd["length"].(int64) + if !ok { + return info, fmt.Errorf("invalid file length") + } + list, ok := fd["path"].([]any) + if !ok { + return info, fmt.Errorf("invalid file path") + } + var path []string + for _, p := range list { + s, ok := p.(string) + if !ok { + return info, fmt.Errorf("invalid path element") + } + path = append(path, s) + } + info.Files = append(info.Files, Entry{Size: n, Offset: off, Path: path}) + off += n + } + return info, nil +} diff --git a/metainfo/metainfo_test.go b/metainfo/metainfo_test.go new file mode 100644 index 0000000..9633c7e --- /dev/null +++ b/metainfo/metainfo_test.go @@ -0,0 +1,141 @@ +package metainfo + +import ( + "crypto/sha1" + "testing" + + "storrent/bencode" +) + +func TestParse(t *testing.T) { + tests := []struct { + name string + info map[string]any + wantName string + wantSize int64 + }{ + { + name: "single file", + info: map[string]any{ + "name": "test.txt", + "piece length": int64(16384), + "pieces": string(make([]byte, 20)), + "length": int64(100), + }, + wantName: "test.txt", + wantSize: 100, + }, + { + name: "multi file", + info: map[string]any{ + "name": "dir", + "piece length": int64(16384), + "pieces": string(make([]byte, 20)), + "files": []any{ + map[string]any{"length": int64(100), "path": []any{"a.txt"}}, + map[string]any{"length": int64(200), "path": []any{"b.txt"}}, + }, + }, + wantName: "dir", + wantSize: 300, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := bencode.Encode(map[string]any{"info": tt.info}) + if err != nil { + t.Fatalf("Encode: %v", err) + } + m, err := ParseBytes(data) + if err != nil { + t.Fatalf("ParseBytes: %v", err) + } + if m.Info.Name != tt.wantName { + t.Errorf("Name: got %s, want %s", m.Info.Name, tt.wantName) + } + if m.Size != tt.wantSize { + t.Errorf("Size: got %d, want %d", m.Size, tt.wantSize) + } + }) + } +} + +func TestInfoHash(t *testing.T) { + tests := []struct { + name string + raw string + }{ + { + name: "basic", + raw: "d6:lengthi1e4:name4:test12:piece lengthi16384e6:pieces20:01234567890123456789e", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + raw := []byte(tt.raw) + torrent := append([]byte("d4:info"), raw...) + torrent = append(torrent, 'e') + m, err := ParseBytes(torrent) + if err != nil { + t.Fatalf("ParseBytes: %v", err) + } + want := sha1.Sum(raw) + if m.InfoHash != want { + t.Errorf("InfoHash: got %x, want %x", m.InfoHash, want) + } + }) + } +} + +func TestSegments(t *testing.T) { + tests := []struct { + name string + files []Entry + off int64 + size int64 + want []Segment + }{ + { + name: "single file", + files: []Entry{{Size: 32, Offset: 0}}, + off: 0, + size: 32, + want: []Segment{{0, 0, 32}}, + }, + { + name: "spans two files", + files: []Entry{{Size: 16, Offset: 0}, {Size: 16, Offset: 16}}, + off: 0, + size: 32, + want: []Segment{{0, 0, 16}, {1, 0, 16}}, + }, + { + name: "middle of file", + files: []Entry{{Size: 32, Offset: 0}}, + off: 8, + size: 16, + want: []Segment{{0, 8, 16}}, + }, + { + name: "skip first file", + files: []Entry{{Size: 16, Offset: 0}, {Size: 16, Offset: 16}}, + off: 16, + size: 16, + want: []Segment{{1, 0, 16}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := Info{Files: tt.files} + got := info.Segments(tt.off, tt.size) + if len(got) != len(tt.want) { + t.Fatalf("got %d segments, want %d", len(got), len(tt.want)) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("segment %d: got %+v, want %+v", i, got[i], tt.want[i]) + } + } + }) + } +} diff --git a/metainfo/storage.go b/metainfo/storage.go new file mode 100644 index 0000000..d52f997 --- /dev/null +++ b/metainfo/storage.go @@ -0,0 +1,68 @@ +package metainfo + +import ( + "os" + "path/filepath" +) + +func (info *Info) Write(dir string, i int, data []byte) error { + off := int64(i) * info.PieceSize + if info.Size > 0 { + return writeAt(filepath.Join(dir, info.Name), off, data) + } + for _, seg := range info.Segments(off, int64(len(data))) { + f := info.Files[seg.File] + path := filepath.Join(dir, info.Name, filepath.Join(f.Path...)) + if err := writeAt(path, seg.Offset, data[:seg.Size]); err != nil { + return err + } + data = data[seg.Size:] + } + return nil +} + +func (info *Info) Read(dir string, i int, size int64) []byte { + off := int64(i) * info.PieceSize + if info.Size > 0 { + return readAt(filepath.Join(dir, info.Name), off, size) + } + data := make([]byte, size) + pos := int64(0) + for _, seg := range info.Segments(off, size) { + f := info.Files[seg.File] + path := filepath.Join(dir, info.Name, filepath.Join(f.Path...)) + chunk := readAt(path, seg.Offset, seg.Size) + if chunk == nil { + return nil + } + copy(data[pos:], chunk) + pos += seg.Size + } + return data +} + +func writeAt(path string, off int64, data []byte) error { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer f.Close() + _, err = f.WriteAt(data, off) + return err +} + +func readAt(path string, off, size int64) []byte { + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + data := make([]byte, size) + if _, err := f.ReadAt(data, off); err != nil { + return nil + } + return data +} diff --git a/utp/conn.go b/utp/conn.go new file mode 100644 index 0000000..539ab3a --- /dev/null +++ b/utp/conn.go @@ -0,0 +1,305 @@ +package utp + +import ( + "context" + "io" + "net" + "time" +) + +const ( + initRTO = 1000 // ms + retransCheck = 500 // ms + defaultWnd = 64 * 1024 +) + +type Conn struct { + sock *Socket + addr *net.UDPAddr + recvID uint16 + sendID uint16 + + in chan packet + reads chan readReq + writes chan writeReq + closeReq chan struct{} + ready chan struct{} + ctx context.Context + cancel context.CancelFunc + + seq uint16 + ack uint16 + unacked []sent + reorder map[uint16][]byte + rest []byte + pending *readReq + rto uint32 +} + +type readReq struct { + p []byte + reply chan readResp +} + +type readResp struct { + n int + err error +} + +type writeReq struct { + data []byte + reply chan error +} + +type sent struct { + seq uint16 + data []byte + at time.Time +} + +func (c *Conn) flush() { + if c.pending == nil || len(c.rest) == 0 { + return + } + n := copy(c.pending.p, c.rest) + c.rest = c.rest[n:] + c.pending.reply <- readResp{n, nil} + c.pending = nil +} + +func (c *Conn) deliver(payload []byte) { + c.rest = append(c.rest, payload...) + c.flush() +} + +func (c *Conn) send(typ uint8, payload []byte) error { + c.seq++ + h := &header{ + typ: typ, + ver: 1, + connID: c.sendID, + timestamp: timestamp(), + wnd: defaultWnd, + seq: c.seq, + ack: c.ack, + } + buf := make([]byte, headerSize+len(payload)) + encode(h, buf) + copy(buf[headerSize:], payload) + if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil { + return err + } + if typ == Data || typ == Syn || typ == Fin { + c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()}) + } + return nil +} + +func (c *Conn) sendState() error { + h := &header{ + typ: State, + ver: 1, + connID: c.sendID, + timestamp: timestamp(), + wnd: defaultWnd, + seq: c.seq, + ack: c.ack, + } + buf := make([]byte, headerSize) + encode(h, buf) + _, err := c.sock.conn.WriteToUDP(buf, c.addr) + return err +} + +func (c *Conn) sendSyn() error { + c.seq++ + h := &header{ + typ: Syn, + ver: 1, + connID: c.recvID, + timestamp: timestamp(), + wnd: defaultWnd, + seq: c.seq, + ack: c.ack, + } + buf := make([]byte, headerSize) + encode(h, buf) + if _, err := c.sock.conn.WriteToUDP(buf, c.addr); err != nil { + return err + } + c.unacked = append(c.unacked, sent{c.seq, buf, time.Now()}) + return nil +} + +func (c *Conn) retrans() error { + select { + case <-c.ctx.Done(): + return nil + default: + } + now := time.Now() + for i := range c.unacked { + p := &c.unacked[i] + if now.Sub(p.at) > time.Duration(c.rto)*time.Millisecond { + if _, err := c.sock.conn.WriteToUDP(p.data, c.addr); err != nil { + return err + } + p.at = now + } + } + return nil +} + +func (c *Conn) recv(p packet) (bool, error) { + switch p.hdr.typ { + case Syn: + c.ack = p.hdr.seq + if err := c.sendState(); err != nil { + return false, err + } + case State: + select { + case <-c.ready: + default: + c.ack = p.hdr.seq - 1 + close(c.ready) + } + var remaining []sent + for _, s := range c.unacked { + if seqLess(p.hdr.ack, s.seq) { + remaining = append(remaining, s) + } + } + c.unacked = remaining + case Data: + seq, payload := p.hdr.seq, p.payload + if seq == c.ack+1 { + c.deliver(payload) + c.ack++ + for { + if p, ok := c.reorder[c.ack+1]; ok { + c.deliver(p) + delete(c.reorder, c.ack+1) + c.ack++ + } else { + break + } + } + } else if seqLess(c.ack+1, seq) { + c.reorder[seq] = payload + } + if err := c.sendState(); err != nil { + return false, err + } + case Fin, Reset: + return true, nil + } + return false, nil +} + +func (c *Conn) shutdown() { + if c.pending != nil { + c.pending.reply <- readResp{0, io.EOF} + } + select { + case <-c.ready: + c.send(Fin, nil) + default: + } + c.unacked = nil + c.cancel() + c.sock.removeConn(c.recvID) +} + +func (c *Conn) run() { + c.reorder = make(map[uint16][]byte) + c.rto = initRTO + + if c.recvID < c.sendID { + c.seq = 1 + if err := c.sendSyn(); err != nil { + c.shutdown() + return + } + } + + ticker := time.NewTicker(retransCheck * time.Millisecond) + defer ticker.Stop() + + for { + select { + case p := <-c.in: + done, err := c.recv(p) + if err != nil { + c.shutdown() + return + } + if done { + c.shutdown() + return + } + case req := <-c.reads: + if c.pending != nil { + req.reply <- readResp{0, io.ErrNoProgress} + continue + } + c.pending = &req + c.flush() + case req := <-c.writes: + if err := c.send(Data, req.data); err != nil { + req.reply <- err + c.shutdown() + return + } + req.reply <- nil + case <-c.closeReq: + c.shutdown() + return + case <-ticker.C: + c.retrans() + case <-c.ctx.Done(): + return + } + } +} + +func (c *Conn) Read(p []byte) (int, error) { + reply := make(chan readResp, 1) + select { + case c.reads <- readReq{p, reply}: + r := <-reply + return r.n, r.err + case <-c.ctx.Done(): + return 0, io.EOF + } +} + +func (c *Conn) Write(p []byte) (int, error) { + reply := make(chan error, 1) + select { + case c.writes <- writeReq{p, reply}: + err := <-reply + if err != nil { + return 0, err + } + return len(p), nil + case <-c.ctx.Done(): + return 0, io.ErrClosedPipe + } +} + +func (c *Conn) Close() error { + select { + case c.closeReq <- struct{}{}: + case <-c.ctx.Done(): + } + return nil +} + +func timestamp() uint32 { + return uint32(time.Now().UnixMicro() & 0xFFFFFFFF) +} + +func seqLess(a, b uint16) bool { + return int16(a-b) < 0 +} diff --git a/utp/socket.go b/utp/socket.go new file mode 100644 index 0000000..7911329 --- /dev/null +++ b/utp/socket.go @@ -0,0 +1,211 @@ +package utp + +import ( + "context" + "crypto/rand" + "encoding/binary" + "net" + "sync" +) + +const ( + Data = 0 + Fin = 1 + State = 2 + Reset = 3 + Syn = 4 +) + +const headerSize = 20 + +type header struct { + typ uint8 + ver uint8 + ext uint8 + connID uint16 + timestamp uint32 + timeDiff uint32 + wnd uint32 + seq uint16 + ack uint16 +} + +func encode(h *header, buf []byte) { + buf[0] = (h.typ << 4) | (h.ver & 0x0F) + buf[1] = h.ext + binary.BigEndian.PutUint16(buf[2:], h.connID) + binary.BigEndian.PutUint32(buf[4:], h.timestamp) + binary.BigEndian.PutUint32(buf[8:], h.timeDiff) + binary.BigEndian.PutUint32(buf[12:], h.wnd) + binary.BigEndian.PutUint16(buf[16:], h.seq) + binary.BigEndian.PutUint16(buf[18:], h.ack) +} + +func decode(buf []byte) header { + return header{ + typ: (buf[0] >> 4) & 0x0F, + ver: buf[0] & 0x0F, + ext: buf[1], + connID: binary.BigEndian.Uint16(buf[2:]), + timestamp: binary.BigEndian.Uint32(buf[4:]), + timeDiff: binary.BigEndian.Uint32(buf[8:]), + wnd: binary.BigEndian.Uint32(buf[12:]), + seq: binary.BigEndian.Uint16(buf[16:]), + ack: binary.BigEndian.Uint16(buf[18:]), + } +} + +type Socket struct { + conn *net.UDPConn + mu sync.RWMutex + conns map[uint16]*Conn + accepts chan *Conn + ctx context.Context + cancel context.CancelFunc +} + +type packet struct { + hdr header + payload []byte + addr *net.UDPAddr +} + +func New(addr string) (*Socket, error) { + a, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", a) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + s := &Socket{ + conn: conn, + conns: make(map[uint16]*Conn), + accepts: make(chan *Conn, 16), + ctx: ctx, + cancel: cancel, + } + return s, nil +} + +func (s *Socket) Start() { + go s.reader() +} + +func (s *Socket) reader() { + buf := make([]byte, 65535) + for { + n, addr, err := s.conn.ReadFromUDP(buf) + if err != nil { + return + } + if n < headerSize { + continue + } + hdr := decode(buf) + payload := make([]byte, n-headerSize) + copy(payload, buf[headerSize:n]) + + if hdr.typ == Syn { + s.handleSyn(hdr, payload, addr) + } else { + s.dispatch(hdr, payload) + } + } +} + +func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) { + c := s.newConn(hdr.connID, addr, false) + + s.mu.Lock() + s.conns[c.recvID] = c + s.mu.Unlock() + + go c.run() + c.in <- packet{hdr, payload, addr} + + select { + case s.accepts <- c: + default: + c.Close() + } +} + +func (s *Socket) dispatch(hdr header, payload []byte) { + s.mu.RLock() + c := s.conns[hdr.connID] + s.mu.RUnlock() + + if c != nil { + select { + case c.in <- packet{hdr, payload, nil}: + default: + } + } +} + +func (s *Socket) removeConn(id uint16) { + s.mu.Lock() + delete(s.conns, id) + s.mu.Unlock() +} + +func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn { + ctx, cancel := context.WithCancel(s.ctx) + c := &Conn{ + sock: s, + addr: addr, + in: make(chan packet, 16), + reads: make(chan readReq), + writes: make(chan writeReq), + closeReq: make(chan struct{}), + ready: make(chan struct{}), + ctx: ctx, + cancel: cancel, + } + if initiator { + c.recvID = randUint16() + c.sendID = c.recvID + 1 + } else { + c.recvID = peerID + 1 + c.sendID = peerID + } + return c +} + +func (s *Socket) DialContext(ctx context.Context, addr *net.UDPAddr) (*Conn, error) { + c := s.newConn(0, addr, true) + + s.mu.Lock() + s.conns[c.recvID] = c + s.mu.Unlock() + + go c.run() + + select { + case <-c.ready: + return c, nil + case <-ctx.Done(): + c.Close() + return nil, ctx.Err() + } +} + +func (s *Socket) Accept() *Conn { + return <-s.accepts +} + +func (s *Socket) Close() { + s.conn.Close() + s.cancel() +} + +func randUint16() uint16 { + var b [2]byte + if _, err := rand.Read(b[:]); err != nil { + panic(err) + } + return binary.BigEndian.Uint16(b[:]) +} diff --git a/utp/utp_test.go b/utp/utp_test.go new file mode 100644 index 0000000..e00fa59 --- /dev/null +++ b/utp/utp_test.go @@ -0,0 +1,82 @@ +package utp + +import "testing" + +func TestEncodeDecode(t *testing.T) { + tests := []struct { + name string + h header + }{ + { + name: "syn", + h: header{ + typ: Syn, + ver: 1, + ext: 0, + connID: 1234, + timestamp: 12345678, + timeDiff: 1000, + wnd: 65535, + seq: 1, + ack: 0, + }, + }, + { + name: "data", + h: header{ + typ: Data, + ver: 1, + connID: 5678, + timestamp: 99999999, + timeDiff: 500, + wnd: 32768, + seq: 100, + ack: 99, + }, + }, + { + name: "state", + h: header{ + typ: State, + ver: 1, + connID: 1234, + wnd: 65535, + seq: 1, + ack: 1, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, headerSize) + encode(&tt.h, buf) + got := decode(buf) + if got != tt.h { + t.Errorf("got %+v, want %+v", got, tt.h) + } + }) + } +} + +func TestSeqLess(t *testing.T) { + tests := []struct { + name string + a, b uint16 + want bool + }{ + {"1 < 2", 1, 2, true}, + {"2 > 1", 2, 1, false}, + {"wrap 0xFFFF < 0", 0xFFFF, 0, true}, + {"wrap 0 > 0xFFFF", 0, 0xFFFF, false}, + {"half 0x7FFF > 0", 0x7FFF, 0, false}, + {"half 0 < 0x7FFF", 0, 0x7FFF, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := seqLess(tt.a, tt.b) + if got != tt.want { + t.Errorf("seqLess(%d, %d) = %v, want %v", tt.a, tt.b, got, tt.want) + } + }) + } +}