first commit

This commit is contained in:
2026-01-19 21:13:01 +09:00
commit c70d24be5c
28 changed files with 3674 additions and 0 deletions

139
client/dial.go Normal file
View File

@@ -0,0 +1,139 @@
package client
import (
"context"
"net"
"sync"
"storrent/bt"
"storrent/dht"
"storrent/utp"
)
const maxConcurrentDials = 8
type dialResult struct {
peer *bt.Peer
err error
done bool
}
type dialPool struct {
sem chan struct{}
results chan dialResult
dht *dht.DHT
utp *utp.Socket
infoHash [20]byte
running bool
cancel context.CancelFunc
}
func newDialPool(d *dht.DHT, utpSock *utp.Socket, infoHash [20]byte) *dialPool {
return &dialPool{
sem: make(chan struct{}, maxConcurrentDials),
results: make(chan dialResult, maxConcurrentDials),
dht: d,
utp: utpSock,
infoHash: infoHash,
}
}
func (dp *dialPool) start(parentCtx context.Context) {
if dp.running {
return
}
dp.running = true
ctx, cancel := context.WithCancel(parentCtx)
dp.cancel = cancel
go dp.connectLoop(ctx)
}
func (dp *dialPool) stop() {
if !dp.running {
return
}
dp.running = false
if dp.cancel != nil {
dp.cancel()
dp.cancel = nil
}
}
func (dp *dialPool) connectLoop(ctx context.Context) {
dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout)
addrs, err := dp.dht.GetPeers(dhtCtx, dp.infoHash)
cancel()
if err != nil {
select {
case dp.results <- dialResult{done: true, err: err}:
case <-ctx.Done():
}
return
}
var wg sync.WaitGroup
loop:
for _, addr := range addrs {
select {
case dp.sem <- struct{}{}:
case <-ctx.Done():
break loop
}
wg.Add(1)
go func(addr *net.TCPAddr) {
defer wg.Done()
defer func() { <-dp.sem }()
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
if err != nil {
return
}
select {
case dp.results <- dialResult{peer: p}:
case <-ctx.Done():
p.Close()
}
}(addr)
}
wg.Wait()
select {
case dp.results <- dialResult{done: true}:
case <-ctx.Done():
}
}
func (dp *dialPool) dialSingle(ctx context.Context, addr *net.TCPAddr) {
select {
case dp.sem <- struct{}{}:
case <-ctx.Done():
return
}
go func() {
defer func() { <-dp.sem }()
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
if err != nil {
return
}
select {
case dp.results <- dialResult{peer: p}:
case <-ctx.Done():
p.Close()
}
}()
}

178
client/manager.go Normal file
View File

@@ -0,0 +1,178 @@
package client
import (
"errors"
"fmt"
"net"
"strconv"
"sync"
"storrent/bt"
"storrent/dht"
"storrent/metainfo"
"storrent/utp"
)
var errNotFound = errors.New("torrent not found")
func (m *Manager) getTorrent(id int) (*torrent, error) {
m.mu.Lock()
t, ok := m.torrents[id]
m.mu.Unlock()
if !ok {
return nil, errNotFound
}
return t, nil
}
type Manager struct {
mu sync.Mutex
torrents map[int]*torrent
nextID int
dir string
dht *dht.DHT
ln net.Listener
utp *utp.Socket
}
func NewManager(dir string, d *dht.DHT, btPort int, utpSock *utp.Socket) (*Manager, error) {
var ln net.Listener
if btPort > 0 {
var err error
ln, err = net.Listen("tcp", fmt.Sprintf(":%d", btPort))
if err != nil {
return nil, err
}
}
return &Manager{
torrents: make(map[int]*torrent),
dir: dir,
dht: d,
ln: ln,
utp: utpSock,
}, nil
}
func (m *Manager) Run() {
if m.ln == nil {
return
}
for {
p, err := bt.Accept(m.ln)
if err != nil {
return
}
m.dispatch(p)
}
}
func (m *Manager) dispatch(p *bt.Peer) {
m.mu.Lock()
defer m.mu.Unlock()
for _, t := range m.torrents {
if t.meta.InfoHash == p.InfoHash {
select {
case t.incoming <- p:
default:
p.Close()
}
return
}
}
p.Close()
}
func (m *Manager) Add(path string) ([]byte, error) {
mi, err := metainfo.Parse(path)
if err != nil {
return nil, err
}
m.mu.Lock()
id := m.nextID
m.nextID++
t := newTorrent(id, mi, m.dir, m.dht, m.utp)
complete := t.scheduler.isComplete()
go t.run()
m.torrents[id] = t
m.mu.Unlock()
if complete {
t.startSeed()
} else {
t.startDownload()
}
return []byte(strconv.Itoa(id)), nil
}
func (m *Manager) Remove(id int) error {
m.mu.Lock()
t, ok := m.torrents[id]
if !ok {
m.mu.Unlock()
return errNotFound
}
delete(m.torrents, id)
m.mu.Unlock()
t.stopTorrent()
return nil
}
func (m *Manager) Start(id int) error {
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.startDownload()
return nil
}
func (m *Manager) Stop(id int) error {
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.stopTorrent()
return nil
}
func (m *Manager) Seed(id int) error {
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.startSeed()
return nil
}
func (m *Manager) AddPeer(id int, addr string) error {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return err
}
t, err := m.getTorrent(id)
if err != nil {
return err
}
t.enqueuePeer(tcpAddr)
return nil
}
func (m *Manager) Status(id int, field string) ([]byte, error) {
t, err := m.getTorrent(id)
if err != nil {
return nil, err
}
return []byte(t.getStatus(field)), nil
}
func (m *Manager) List() []byte {
m.mu.Lock()
defer m.mu.Unlock()
var data []byte
for id := range m.torrents {
if len(data) > 0 {
data = append(data, '\n')
}
data = append(data, []byte(strconv.Itoa(id))...)
}
return data
}

50
client/peer.go Normal file
View File

@@ -0,0 +1,50 @@
package client
import "storrent/bt"
type peerManager struct {
peers map[*bt.Peer]struct{}
maxPeers int
}
func newPeerManager(maxPeers int) *peerManager {
return &peerManager{
peers: make(map[*bt.Peer]struct{}),
maxPeers: maxPeers,
}
}
func (pm *peerManager) add(p *bt.Peer) bool {
if len(pm.peers) >= pm.maxPeers {
return false
}
pm.peers[p] = struct{}{}
return true
}
func (pm *peerManager) addUnlimited(p *bt.Peer) {
pm.peers[p] = struct{}{}
}
func (pm *peerManager) remove(p *bt.Peer) {
delete(pm.peers, p)
}
func (pm *peerManager) count() int {
return len(pm.peers)
}
func (pm *peerManager) all() []*bt.Peer {
peers := make([]*bt.Peer, 0, len(pm.peers))
for p := range pm.peers {
peers = append(peers, p)
}
return peers
}
func (pm *peerManager) closeAll() {
for p := range pm.peers {
p.Close()
}
pm.peers = make(map[*bt.Peer]struct{})
}

69
client/piece.go Normal file
View File

@@ -0,0 +1,69 @@
package client
import (
"crypto/sha1"
"storrent/bt"
)
type piece struct {
hash [20]byte
size int64
data []byte
have []bool
reqs []bool
reqPeer []*bt.Peer
done bool
}
func (p *piece) put(begin int, data []byte) bool {
if p.done {
return false
}
block := begin / bt.BlockSize
if block >= len(p.have) || p.have[block] {
return false
}
copy(p.data[begin:], data)
p.have[block] = true
for _, got := range p.have {
if !got {
return false
}
}
if sha1.Sum(p.data) != p.hash {
for i := range p.have {
p.have[i] = false
p.reqs[i] = false
}
return false
}
p.done = true
return true
}
func (p *piece) next(peer *bt.Peer) (begin, length int64, ok bool) {
for i := range p.have {
if p.have[i] || p.reqs[i] {
continue
}
p.reqs[i] = true
p.reqPeer[i] = peer
begin = int64(i) * bt.BlockSize
length = bt.BlockSize
if begin+length > p.size {
length = p.size - begin
}
return begin, length, true
}
return 0, 0, false
}
func (p *piece) resetPeer(peer *bt.Peer) {
for i := range p.reqs {
if p.reqs[i] && p.reqPeer[i] == peer {
p.reqs[i] = false
p.reqPeer[i] = nil
}
}
}

134
client/scheduler.go Normal file
View File

@@ -0,0 +1,134 @@
package client
import (
"crypto/sha1"
"storrent/bt"
"storrent/metainfo"
)
type pieceScheduler struct {
pieces []piece
incomplete []int
verified int
meta *metainfo.File
dir string
}
func newPieceScheduler(m *metainfo.File, dir string) *pieceScheduler {
n := len(m.Info.Pieces)
pieces := make([]piece, n)
incomplete := make([]int, 0, n)
verified := 0
for i := range n {
size := m.Info.PieceSize
if i == n-1 {
size = m.Size - int64(n-1)*m.Info.PieceSize
}
nblocks := (size + bt.BlockSize - 1) / bt.BlockSize
pieces[i] = piece{
hash: m.Info.Pieces[i],
size: size,
data: make([]byte, size),
have: make([]bool, nblocks),
reqs: make([]bool, nblocks),
reqPeer: make([]*bt.Peer, nblocks),
}
data := m.Info.Read(dir, i, pieces[i].size)
if data != nil && sha1.Sum(data) == pieces[i].hash {
pieces[i].data = data
pieces[i].done = true
verified++
} else {
incomplete = append(incomplete, i)
}
}
return &pieceScheduler{
pieces: pieces,
incomplete: incomplete,
verified: verified,
meta: m,
dir: dir,
}
}
func (ps *pieceScheduler) nextRequest(p *bt.Peer) *bt.Msg {
for _, idx := range ps.incomplete {
pc := &ps.pieces[idx]
if pc.done {
continue
}
if !p.HasPiece(idx) {
continue
}
if begin, length, ok := pc.next(p); ok {
return &bt.Msg{
Kind: bt.Request,
Index: uint32(idx),
Begin: uint32(begin),
Length: uint32(length),
}
}
}
return nil
}
func (ps *pieceScheduler) put(index int, begin int, block []byte) bool {
pc := &ps.pieces[index]
if pc.put(begin, block) {
ps.meta.Info.Write(ps.dir, index, pc.data)
ps.verified++
ps.removeIncomplete(index)
return true
}
return false
}
func (ps *pieceScheduler) removeIncomplete(index int) {
for i, idx := range ps.incomplete {
if idx == index {
ps.incomplete = append(ps.incomplete[:i], ps.incomplete[i+1:]...)
return
}
}
}
func (ps *pieceScheduler) peerDisconnected(p *bt.Peer) {
for i := range ps.pieces {
ps.pieces[i].resetPeer(p)
}
}
func (ps *pieceScheduler) isComplete() bool {
return ps.verified == len(ps.pieces)
}
func (ps *pieceScheduler) progress() (verified, total int) {
return ps.verified, len(ps.pieces)
}
func (ps *pieceScheduler) pieceData(index int) []byte {
pc := &ps.pieces[index]
if !pc.done {
return nil
}
return pc.data
}
func (ps *pieceScheduler) count() int {
return len(ps.pieces)
}
func (ps *pieceScheduler) bitfield(all bool) []byte {
n := len(ps.pieces)
bf := make([]byte, (n+7)/8)
for i := range ps.pieces {
if all || ps.pieces[i].done {
bt.SetBit(bf, i)
}
}
return bf
}

419
client/torrent.go Normal file
View File

@@ -0,0 +1,419 @@
package client
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"time"
"storrent/bt"
"storrent/dht"
"storrent/metainfo"
"storrent/utp"
)
type state int
const (
stopped state = iota
downloading
seeding
done
errored
)
const (
maxPeers = 8
maxPending = 5
dialTimeout = 30 * time.Second
retryInterval = 30 * time.Second
dhtTimeout = 5 * time.Minute
recvBufSize = 128
)
type recv struct {
msg *bt.Msg
err error
p *bt.Peer
}
type statusReq struct {
field string
resp chan string
}
type torrent struct {
id int
meta *metainfo.File
dir string
dht *dht.DHT
utp *utp.Socket
start chan struct{}
stop chan struct{}
seed chan struct{}
incoming chan *bt.Peer
addPeer chan *net.TCPAddr
status chan statusReq
peers *peerManager
scheduler *pieceScheduler
dialPool *dialPool
st state
down int64
up int64
ctx context.Context
cancel context.CancelFunc
recvs chan recv
retry <-chan time.Time
}
func newTorrent(id int, m *metainfo.File, dir string, d *dht.DHT, utpSock *utp.Socket) *torrent {
t := &torrent{
id: id,
meta: m,
dir: dir,
dht: d,
utp: utpSock,
start: make(chan struct{}, 1),
stop: make(chan struct{}, 1),
seed: make(chan struct{}, 1),
incoming: make(chan *bt.Peer),
addPeer: make(chan *net.TCPAddr, 8),
status: make(chan statusReq),
peers: newPeerManager(maxPeers),
scheduler: newPieceScheduler(m, dir),
dialPool: newDialPool(d, utpSock, m.InfoHash),
st: stopped,
recvs: make(chan recv, recvBufSize),
}
return t
}
func (t *torrent) run() {
for {
select {
case <-t.start:
t.handleStart()
case <-t.seed:
t.handleSeed()
case <-t.stop:
t.handleStop()
case <-t.retry:
t.handleRetry()
case d := <-t.dialPool.results:
t.handleDialResult(d)
case p := <-t.incoming:
t.handleIncoming(p)
case addr := <-t.addPeer:
t.handleAddPeer(addr)
case r := <-t.recvs:
t.handleRecv(r)
case req := <-t.status:
req.resp <- t.handleStatus(req.field)
}
}
}
func (t *torrent) handleStart() {
if t.st != stopped {
return
}
t.st = downloading
t.ctx, t.cancel = context.WithCancel(context.Background())
t.dialPool.start(t.ctx)
}
func (t *torrent) handleSeed() {
if t.st != stopped && t.st != done {
return
}
if !t.scheduler.isComplete() {
return
}
t.st = seeding
t.ctx, t.cancel = context.WithCancel(context.Background())
t.dialPool.start(t.ctx)
}
func (t *torrent) handleStop() {
if !t.isActive() {
return
}
t.cancel()
t.dialPool.stop()
t.peers.closeAll()
t.retry = nil
t.st = stopped
}
func (t *torrent) handleRetry() {
t.retry = nil
if t.isActive() {
t.dialPool.start(t.ctx)
}
}
func (t *torrent) handleDialResult(d dialResult) {
if d.done {
if t.peers.count() < maxPeers && t.isActive() {
t.retry = time.After(retryInterval)
}
return
}
if d.peer == nil {
return
}
p := d.peer
if !t.peers.add(p) {
p.Close()
return
}
go t.recvLoop(p)
if err := t.initPeer(p); err != nil {
p.Close()
t.peers.remove(p)
}
}
func (t *torrent) handleIncoming(p *bt.Peer) {
t.peers.addUnlimited(p)
go t.recvLoop(p)
if err := p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(false)}); err != nil {
p.Close()
t.peers.remove(p)
return
}
if err := p.Send(&bt.Msg{Kind: bt.Unchoke}); err != nil {
p.Close()
t.peers.remove(p)
}
}
func (t *torrent) handleAddPeer(addr *net.TCPAddr) {
if !t.isActive() {
return
}
t.dialPool.dialSingle(t.ctx, addr)
}
func (t *torrent) handleRecv(r recv) {
if r.err != nil {
r.p.Close()
t.peers.remove(r.p)
t.scheduler.peerDisconnected(r.p)
if t.peers.count() < maxPeers && t.isActive() {
t.dialPool.start(t.ctx)
}
return
}
if r.msg == nil {
return
}
p := r.p
m := r.msg
switch m.Kind {
case bt.Choke:
p.Choked = true
p.Pending = 0
case bt.Unchoke:
p.Choked = false
case bt.Have:
p.SetPiece(int(m.Index))
case bt.Bitfield:
p.SetPieces(m.Bitfield, t.scheduler.count())
case bt.Piece:
t.scheduler.put(int(m.Index), int(m.Begin), m.Block)
p.Pending--
t.down += int64(len(m.Block))
case bt.Request:
t.handleRequest(p, m)
}
t.sendRequests(p)
t.checkCompletion()
}
func (t *torrent) handleRequest(p *bt.Peer, m *bt.Msg) {
data := t.scheduler.pieceData(int(m.Index))
if data == nil {
return
}
end := int(m.Begin + m.Length)
if end > len(data) {
return
}
if err := p.Send(&bt.Msg{
Kind: bt.Piece,
Index: m.Index,
Begin: m.Begin,
Block: data[m.Begin:end],
}); err != nil {
p.Close()
return
}
t.up += int64(m.Length)
}
func (t *torrent) sendRequests(p *bt.Peer) {
if p.Choked {
return
}
for p.Pending < maxPending {
req := t.scheduler.nextRequest(p)
if req == nil {
break
}
if err := p.Send(req); err != nil {
p.Close()
break
}
p.Pending++
}
}
func (t *torrent) checkCompletion() {
if t.st == downloading && t.scheduler.isComplete() {
t.st = done
t.peers.closeAll()
}
}
func (t *torrent) initPeer(p *bt.Peer) error {
if t.st == seeding {
return p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(true)})
}
return p.Send(&bt.Msg{Kind: bt.Interested})
}
func (t *torrent) recvLoop(p *bt.Peer) {
for {
msg, err := p.Recv()
t.recvs <- recv{msg, err, p}
if err != nil {
return
}
}
}
func (t *torrent) isActive() bool {
return t.st == downloading || t.st == seeding
}
func (t *torrent) startDownload() {
select {
case t.start <- struct{}{}:
default:
}
}
func (t *torrent) startSeed() {
select {
case t.seed <- struct{}{}:
default:
}
}
func (t *torrent) stopTorrent() {
select {
case t.stop <- struct{}{}:
default:
}
}
func (t *torrent) enqueuePeer(addr *net.TCPAddr) {
t.addPeer <- addr
}
func (t *torrent) getStatus(field string) string {
ch := make(chan string)
t.status <- statusReq{field, ch}
return <-ch
}
func (t *torrent) handleStatus(field string) string {
switch field {
case "name":
return t.meta.Info.Name
case "state":
return t.st.String()
case "progress":
verified, total := t.scheduler.progress()
if total == 0 {
return "0"
}
return strconv.Itoa(verified * 100 / total)
case "size":
return strconv.FormatInt(t.meta.Size, 10)
case "down":
return strconv.FormatInt(t.down, 10)
case "up":
return strconv.FormatInt(t.up, 10)
case "pieces":
verified, total := t.scheduler.progress()
return fmt.Sprintf("%d/%d", verified, total)
case "peers":
return t.formatPeers()
}
return ""
}
func (t *torrent) formatPeers() string {
var buf strings.Builder
npieces := t.scheduler.count()
for _, p := range t.peers.all() {
st := "unchoked"
if p.Choked {
st = "choked"
}
have := 0
for _, h := range p.Have {
if h {
have++
}
}
fmt.Fprintf(&buf, "%s %s %s have:%d/%d pending:%d\n",
p.Addr, p.PeerID[:8], st, have, npieces, p.Pending)
}
return buf.String()
}
func (s state) String() string {
switch s {
case stopped:
return "stopped"
case downloading:
return "downloading"
case seeding:
return "seeding"
case done:
return "done"
case errored:
return "error"
}
return "unknown"
}

121
client/torrent_test.go Normal file
View File

@@ -0,0 +1,121 @@
package client
import (
"crypto/sha1"
"testing"
"storrent/bt"
)
func TestPut(t *testing.T) {
tests := []struct {
in string
hash [20]byte
want bool
}{
{
in: "hello",
hash: sha1.Sum([]byte("hello")),
want: true,
},
{
in: "wrong",
hash: sha1.Sum([]byte("hello")),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.in, func(t *testing.T) {
p := &piece{
hash: tt.hash,
size: int64(len(tt.in)),
data: make([]byte, len(tt.in)),
have: make([]bool, 1),
reqs: make([]bool, 1),
reqPeer: make([]*bt.Peer, 1),
}
if got := p.put(0, []byte(tt.in)); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestPutMultiple(t *testing.T) {
data := make([]byte, bt.BlockSize+5)
for i := range data {
data[i] = byte(i)
}
p := &piece{
hash: sha1.Sum(data),
size: int64(len(data)),
data: make([]byte, len(data)),
have: make([]bool, 2),
reqs: make([]bool, 2),
reqPeer: make([]*bt.Peer, 2),
}
tests := []struct {
name string
begin int
data []byte
want bool
}{
{
name: "first block",
begin: 0,
data: data[:bt.BlockSize],
want: false,
},
{
name: "second block",
begin: bt.BlockSize,
data: data[bt.BlockSize:],
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := p.put(tt.begin, tt.data); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestNextRequest(t *testing.T) {
tests := []struct {
name string
pieceSize int64
have []bool
wantLength uint32
}{
{
name: "last block partial",
pieceSize: bt.BlockSize + 100,
have: []bool{true, false},
wantLength: 100,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ps := &pieceScheduler{
pieces: []piece{{
size: tt.pieceSize,
have: tt.have,
reqs: make([]bool, len(tt.have)),
reqPeer: make([]*bt.Peer, len(tt.have)),
}},
incomplete: []int{0},
}
c := &bt.Peer{Have: []bool{true}}
msg := ps.nextRequest(c)
if msg == nil {
t.Fatal("expected request")
}
if msg.Length != tt.wantLength {
t.Errorf("got %d, want %d", msg.Length, tt.wantLength)
}
})
}
}