storrent/client/torrent.go
2026-01-19 21:13:01 +09:00

420 lines
7.0 KiB
Go

package client
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"time"
"storrent/bt"
"storrent/dht"
"storrent/metainfo"
"storrent/utp"
)
// state is the lifecycle phase of a torrent. It is stored in torrent.st and
// rendered for status queries by its String method.
type state int
// Lifecycle phases. Values are assigned by iota, so the declaration order
// fixes the numeric value of each phase.
const (
// stopped is the phase newTorrent starts in and handleStop returns to.
stopped state = iota
// downloading is entered by handleStart.
downloading
// seeding is entered by handleSeed once the scheduler reports completion.
seeding
// done is set by checkCompletion when a download finishes.
done
// errored is rendered as "error" by String; nothing in the visible part of
// this file assigns it.
errored
)
// Tuning constants for a single torrent.
const (
// maxPeers is the limit handed to newPeerManager and the threshold below
// which the torrent asks the dial pool for more connections.
maxPeers = 8
// maxPending caps the outstanding block requests per peer (see sendRequests).
maxPending = 5
// dialTimeout is not referenced in the visible part of this file.
// NOTE(review): presumably consumed by the dial pool - confirm.
dialTimeout = 30 * time.Second
// retryInterval is the delay before handleRetry restarts the dial pool after
// the pool reported it was done while peer slots were still free.
retryInterval = 30 * time.Second
// dhtTimeout is not referenced in the visible part of this file.
// NOTE(review): presumably bounds DHT lookups elsewhere - confirm.
dhtTimeout = 5 * time.Minute
// recvBufSize is the capacity of the torrent.recvs channel.
recvBufSize = 128
)
// recv is one result of a peer's Recv call, forwarded by recvLoop to the run
// loop over torrent.recvs. recvLoop builds it with a positional literal, so
// the field order below must not change.
type recv struct {
// msg is the received message; handleRecv ignores the event when it is nil
// and err is nil.
msg *bt.Msg
// err is the receive error; when non-nil handleRecv drops the peer.
err error
// p is the peer the result came from.
p *bt.Peer
}
// statusReq is a status query sent to the run loop by getStatus. getStatus
// builds it with a positional literal, so the field order must not change.
type statusReq struct {
// field selects what to report; see handleStatus for the accepted names.
field string
// resp receives exactly one answer from the run loop.
resp chan string
}
// torrent holds the state of one torrent together with the channels through
// which the run loop receives commands and events.
type torrent struct {
// id is the identifier given to newTorrent; it is not read in the visible
// part of this file.
id int
// meta is the parsed metainfo; handleStatus reads Info.Name and Size.
meta *metainfo.File
// dir is the directory passed to newPieceScheduler.
dir string
// dht and utp are the shared DHT node and uTP socket handed to newDialPool.
dht *dht.DHT
utp *utp.Socket
// start, stop and seed are one-slot signal channels filled by startDownload,
// stopTorrent and startSeed and drained by run.
start chan struct{}
stop chan struct{}
seed chan struct{}
// incoming delivers peers to handleIncoming.
// NOTE(review): presumably connections accepted by a listener - the sender
// is not visible here.
incoming chan *bt.Peer
// addPeer delivers addresses queued by enqueuePeer to handleAddPeer.
addPeer chan *net.TCPAddr
// status delivers queries from getStatus to handleStatus.
status chan statusReq
// peers tracks the connected peers.
peers *peerManager
// scheduler owns piece state: requests, received blocks, completion.
scheduler *pieceScheduler
// dialPool produces outbound connections; run reads its results channel.
dialPool *dialPool
// st is the current lifecycle phase.
st state
// down and up count the bytes received in Piece messages and sent in reply
// to Request messages.
down int64
up int64
// ctx and cancel are created by handleStart/handleSeed and passed to the
// dial pool; handleStop calls cancel.
ctx context.Context
cancel context.CancelFunc
// recvs carries results from the per-peer recvLoop goroutines.
recvs chan recv
// retry fires when the dial pool should be restarted; it is nil while no
// retry is scheduled, which makes its select case block.
retry <-chan time.Time
}
// newTorrent builds a torrent for metainfo m whose data lives under dir. The
// torrent starts in the stopped state; nothing runs until run is started and a
// start or seed signal is delivered. d and utpSock are handed to the dial pool.
func newTorrent(id int, m *metainfo.File, dir string, d *dht.DHT, utpSock *utp.Socket) *torrent {
	return &torrent{
		// Identity and shared collaborators.
		id:   id,
		meta: m,
		dir:  dir,
		dht:  d,
		utp:  utpSock,
		st:   stopped,

		// Command signals: one slot each, so repeated signals coalesce.
		start: make(chan struct{}, 1),
		stop:  make(chan struct{}, 1),
		seed:  make(chan struct{}, 1),

		// Event and query channels served by run.
		incoming: make(chan *bt.Peer),
		addPeer:  make(chan *net.TCPAddr, 8),
		status:   make(chan statusReq),
		recvs:    make(chan recv, recvBufSize),

		// Per-torrent helpers.
		peers:     newPeerManager(maxPeers),
		scheduler: newPieceScheduler(m, dir),
		dialPool:  newDialPool(d, utpSock, m.InfoHash),
	}
}
// run is the torrent's event loop. It waits on every command, event and query
// channel and dispatches each item to its handler, one at a time. It never
// returns.
func (t *torrent) run() {
	for {
		select {
		// Status queries: answer on the channel supplied by the caller.
		case query := <-t.status:
			answer := t.handleStatus(query.field)
			query.resp <- answer

		// Lifecycle commands.
		case <-t.start:
			t.handleStart()
		case <-t.seed:
			t.handleSeed()
		case <-t.stop:
			t.handleStop()

		// Peer acquisition. t.retry is nil unless a retry is scheduled, and a
		// nil channel never becomes ready.
		case <-t.retry:
			t.handleRetry()
		case result := <-t.dialPool.results:
			t.handleDialResult(result)
		case addr := <-t.addPeer:
			t.handleAddPeer(addr)
		case peer := <-t.incoming:
			t.handleIncoming(peer)

		// Messages and errors from the per-peer receive loops.
		case event := <-t.recvs:
			t.handleRecv(event)
		}
	}
}
// handleStart switches a stopped torrent to downloading: it creates a fresh
// cancellable context and starts the dial pool with it. In any other state the
// signal is ignored.
func (t *torrent) handleStart() {
	if t.st != stopped {
		return
	}
	ctx, cancel := context.WithCancel(context.Background())
	t.ctx, t.cancel = ctx, cancel
	t.st = downloading
	t.dialPool.start(ctx)
}
// handleSeed switches a stopped or done torrent to seeding, provided the
// scheduler reports that every piece is present. It creates a fresh
// cancellable context and starts the dial pool with it. Otherwise the signal
// is ignored.
func (t *torrent) handleSeed() {
	idle := t.st == stopped || t.st == done
	if !idle || !t.scheduler.isComplete() {
		return
	}
	ctx, cancel := context.WithCancel(context.Background())
	t.ctx, t.cancel = ctx, cancel
	t.st = seeding
	t.dialPool.start(ctx)
}
// handleStop shuts down an active (downloading or seeding) torrent: the
// context is cancelled, the dial pool stopped, all peers closed and any
// scheduled retry dropped. In any other state the signal is ignored.
func (t *torrent) handleStop() {
	if !t.isActive() {
		return
	}
	// Bookkeeping first; the calls below do not look at these fields.
	t.st, t.retry = stopped, nil

	t.cancel()
	t.dialPool.stop()
	t.peers.closeAll()
}
// handleRetry runs when the retry timer fires. The timer is cleared, and the
// dial pool is restarted only if the torrent is still active.
func (t *torrent) handleRetry() {
	t.retry = nil
	if !t.isActive() {
		return
	}
	t.dialPool.start(t.ctx)
}
// handleDialResult consumes one result from the dial pool.
//
// A result with done set means the pool has finished its current round; if
// the torrent is active and still has free peer slots, a retry is scheduled
// after retryInterval. Any other result carries a freshly connected peer (or
// nil, which is ignored).
//
// A connection can be delivered after the torrent has been stopped or has
// finished, because results already handed to the channel are not recalled.
// Such late peers are closed instead of being registered, so an inactive
// torrent never gains connections or receive goroutines.
func (t *torrent) handleDialResult(d dialResult) {
	if d.done {
		if t.peers.count() < maxPeers && t.isActive() {
			t.retry = time.After(retryInterval)
		}
		return
	}
	p := d.peer
	if p == nil {
		return
	}
	if !t.isActive() {
		// Late result for a torrent that is no longer running.
		p.Close()
		return
	}
	if !t.peers.add(p) {
		// The peer manager refused the peer; drop the connection.
		p.Close()
		return
	}
	// Start receiving before the first send so the peer's replies (and any
	// connection error) reach the run loop through t.recvs.
	go t.recvLoop(p)
	if err := t.initPeer(p); err != nil {
		p.Close()
		t.peers.remove(p)
	}
}
// handleIncoming registers a peer delivered on t.incoming, starts its receive
// loop, and greets it with our bitfield followed by an unchoke. If either send
// fails the peer is closed and removed again.
func (t *torrent) handleIncoming(p *bt.Peer) {
	t.peers.addUnlimited(p)
	go t.recvLoop(p)

	err := p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(false)})
	if err == nil {
		// Only unchoke once the bitfield went out.
		err = p.Send(&bt.Msg{Kind: bt.Unchoke})
	}
	if err != nil {
		p.Close()
		t.peers.remove(p)
	}
}
// handleAddPeer asks the dial pool to connect to a single address queued via
// enqueuePeer. Addresses are dropped while the torrent is not active.
func (t *torrent) handleAddPeer(addr *net.TCPAddr) {
	if t.isActive() {
		t.dialPool.dialSingle(t.ctx, addr)
	}
}
// handleRecv processes one result delivered by a recvLoop goroutine.
//
// A receive error drops the peer: it is closed, removed from the peer set and
// reported to the scheduler; if the torrent is active and has free peer slots
// the dial pool is started again. A result with neither error nor message is
// ignored. Otherwise the message updates peer and scheduler state, after which
// further block requests are issued to the peer and completion is re-checked.
func (t *torrent) handleRecv(r recv) {
	if r.err != nil {
		r.p.Close()
		t.peers.remove(r.p)
		t.scheduler.peerDisconnected(r.p)
		// Refill the freed slot while the torrent is still running.
		if t.peers.count() < maxPeers && t.isActive() {
			t.dialPool.start(t.ctx)
		}
		return
	}
	if r.msg == nil {
		return
	}
	p := r.p
	m := r.msg
	switch m.Kind {
	case bt.Choke:
		// Outstanding requests are written off once the peer chokes us.
		p.Choked = true
		p.Pending = 0
	case bt.Unchoke:
		p.Choked = false
	case bt.Have:
		p.SetPiece(int(m.Index))
	case bt.Bitfield:
		p.SetPieces(m.Bitfield, t.scheduler.count())
	case bt.Piece:
		t.scheduler.put(int(m.Index), int(m.Begin), m.Block)
		// A block may arrive after a Choke already reset Pending, or without
		// having been requested at all. Never let the counter drop below
		// zero, otherwise sendRequests would exceed maxPending.
		if p.Pending > 0 {
			p.Pending--
		}
		t.down += int64(len(m.Block))
	case bt.Request:
		t.handleRequest(p, m)
	}
	t.sendRequests(p)
	t.checkCompletion()
}
// handleRequest answers a peer's Request message with the requested block.
//
// The request is silently ignored when the scheduler has no data for the
// piece or when the requested range does not lie inside the piece. Begin and
// Length come from the remote peer, so the range is validated in int without
// ever adding the two raw fields: their sum can wrap around in the message's
// own integer type and slip past a simple "end > len" check, which would make
// the slice expression panic.
//
// If sending fails the peer is closed; otherwise the upload counter grows by
// the block length.
func (t *torrent) handleRequest(p *bt.Peer, m *bt.Msg) {
	data := t.scheduler.pieceData(int(m.Index))
	if data == nil {
		return
	}
	begin, length := int(m.Begin), int(m.Length)
	if begin < 0 || length < 0 || begin > len(data) || length > len(data)-begin {
		return
	}
	if err := p.Send(&bt.Msg{
		Kind:  bt.Piece,
		Index: m.Index,
		Begin: m.Begin,
		Block: data[begin : begin+length],
	}); err != nil {
		p.Close()
		return
	}
	t.up += int64(length)
}
// sendRequests tops up the peer's outstanding block requests to maxPending.
// Nothing is sent while the peer is choking us. The loop ends early when the
// scheduler has nothing more to ask this peer for, or when a send fails, in
// which case the peer is closed.
func (t *torrent) sendRequests(p *bt.Peer) {
	if p.Choked {
		return
	}
	// Pending is bumped only after a request went out successfully.
	for ; p.Pending < maxPending; p.Pending++ {
		next := t.scheduler.nextRequest(p)
		if next == nil {
			return
		}
		if err := p.Send(next); err != nil {
			p.Close()
			return
		}
	}
}
// checkCompletion moves a downloading torrent to done once the scheduler
// reports that every piece is present.
//
// Finishing performs the same teardown as handleStop: the context is
// cancelled, the dial pool stopped, all peers closed and any scheduled retry
// dropped. Closing the peers alone would leave the dial pool running, so it
// would keep delivering connections to a finished torrent, and the cancel
// function would be overwritten uncalled by a later handleSeed.
func (t *torrent) checkCompletion() {
	if t.st != downloading || !t.scheduler.isComplete() {
		return
	}
	t.st = done
	if t.cancel != nil {
		t.cancel()
	}
	t.dialPool.stop()
	t.peers.closeAll()
	t.retry = nil
}
// initPeer sends the opening message to a peer we dialled: our bitfield while
// seeding, an Interested message otherwise. It returns the send error, if any.
func (t *torrent) initPeer(p *bt.Peer) error {
	switch t.st {
	case seeding:
		return p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(true)})
	default:
		return p.Send(&bt.Msg{Kind: bt.Interested})
	}
}
// recvLoop runs in its own goroutine per peer. Every result of p.Recv, whether
// message or error, is forwarded to the run loop on t.recvs; the loop ends
// after the first error has been forwarded.
func (t *torrent) recvLoop(p *bt.Peer) {
	for {
		msg, err := p.Recv()
		t.recvs <- recv{msg: msg, err: err, p: p}
		if err != nil {
			break
		}
	}
}
// isActive reports whether the torrent is currently downloading or seeding.
func (t *torrent) isActive() bool {
	switch t.st {
	case downloading, seeding:
		return true
	default:
		return false
	}
}
// startDownload signals the run loop to start downloading. It never blocks:
// if the signal cannot be queued right now it is dropped.
func (t *torrent) startDownload() {
	signal := struct{}{}
	select {
	default:
		// Channel not ready for a send; nothing to do.
	case t.start <- signal:
	}
}
// startSeed signals the run loop to start seeding. It never blocks: if the
// signal cannot be queued right now it is dropped.
func (t *torrent) startSeed() {
	signal := struct{}{}
	select {
	default:
		// Channel not ready for a send; nothing to do.
	case t.seed <- signal:
	}
}
// stopTorrent signals the run loop to stop the torrent. It never blocks: if
// the signal cannot be queued right now it is dropped.
func (t *torrent) stopTorrent() {
	signal := struct{}{}
	select {
	default:
		// Channel not ready for a send; nothing to do.
	case t.stop <- signal:
	}
}
// enqueuePeer queues addr for the run loop, which passes it to handleAddPeer.
// Unlike startDownload, startSeed and stopTorrent this is a plain send, so the
// caller blocks while the addPeer buffer is full.
func (t *torrent) enqueuePeer(addr *net.TCPAddr) {
t.addPeer <- addr
}
// getStatus asks the run loop for one status field and waits for the answer.
// See handleStatus for the accepted field names. The call blocks until the run
// loop has picked up the request and replied.
func (t *torrent) getStatus(field string) string {
	reply := make(chan string)
	query := statusReq{field: field, resp: reply}
	t.status <- query
	answer := <-reply
	return answer
}
// handleStatus renders one status field as text. Accepted fields:
//
//	name      torrent name from the metainfo
//	state     current lifecycle phase
//	progress  verified pieces as an integer percentage ("0" when there are none)
//	size      total size from the metainfo
//	down, up  byte counters
//	pieces    "verified/total"
//	peers     one line per connected peer (see formatPeers)
//
// Any other field yields the empty string.
func (t *torrent) handleStatus(field string) string {
	switch field {
	case "name":
		return t.meta.Info.Name
	case "size":
		return strconv.FormatInt(t.meta.Size, 10)
	case "state":
		return t.st.String()
	case "down":
		return strconv.FormatInt(t.down, 10)
	case "up":
		return strconv.FormatInt(t.up, 10)
	case "peers":
		return t.formatPeers()
	case "pieces":
		have, all := t.scheduler.progress()
		return fmt.Sprintf("%d/%d", have, all)
	case "progress":
		have, all := t.scheduler.progress()
		percent := 0
		if all != 0 {
			percent = have * 100 / all
		}
		return strconv.Itoa(percent)
	default:
		return ""
	}
}
// formatPeers renders one line per connected peer: address, the first eight
// bytes of the peer ID, choke state, how many of the torrent's pieces the
// peer has, and the number of outstanding requests.
func (t *torrent) formatPeers() string {
	total := t.scheduler.count()

	var out strings.Builder
	for _, peer := range t.peers.all() {
		// Count the pieces the peer advertises.
		owned := 0
		for i := range peer.Have {
			if peer.Have[i] {
				owned++
			}
		}

		label := "choked"
		if !peer.Choked {
			label = "unchoked"
		}

		fmt.Fprintf(&out, "%s %s %s have:%d/%d pending:%d\n",
			peer.Addr, peer.PeerID[:8], label, owned, total, peer.Pending)
	}
	return out.String()
}
// String returns the lower-case name of the state as shown in status output.
// Note that errored is rendered as "error". Values outside the declared
// states yield "unknown".
func (s state) String() string {
	names := [...]string{
		stopped:     "stopped",
		downloading: "downloading",
		seeding:     "seeding",
		done:        "done",
		errored:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}