first commit
This commit is contained in:
139
client/dial.go
Normal file
139
client/dial.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"storrent/bt"
|
||||
"storrent/dht"
|
||||
"storrent/utp"
|
||||
)
|
||||
|
||||
const maxConcurrentDials = 8
|
||||
|
||||
type dialResult struct {
|
||||
peer *bt.Peer
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
type dialPool struct {
|
||||
sem chan struct{}
|
||||
results chan dialResult
|
||||
|
||||
dht *dht.DHT
|
||||
utp *utp.Socket
|
||||
infoHash [20]byte
|
||||
|
||||
running bool
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newDialPool(d *dht.DHT, utpSock *utp.Socket, infoHash [20]byte) *dialPool {
|
||||
return &dialPool{
|
||||
sem: make(chan struct{}, maxConcurrentDials),
|
||||
results: make(chan dialResult, maxConcurrentDials),
|
||||
dht: d,
|
||||
utp: utpSock,
|
||||
infoHash: infoHash,
|
||||
}
|
||||
}
|
||||
|
||||
func (dp *dialPool) start(parentCtx context.Context) {
|
||||
if dp.running {
|
||||
return
|
||||
}
|
||||
dp.running = true
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
dp.cancel = cancel
|
||||
go dp.connectLoop(ctx)
|
||||
}
|
||||
|
||||
func (dp *dialPool) stop() {
|
||||
if !dp.running {
|
||||
return
|
||||
}
|
||||
dp.running = false
|
||||
if dp.cancel != nil {
|
||||
dp.cancel()
|
||||
dp.cancel = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (dp *dialPool) connectLoop(ctx context.Context) {
|
||||
dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout)
|
||||
addrs, err := dp.dht.GetPeers(dhtCtx, dp.infoHash)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case dp.results <- dialResult{done: true, err: err}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
loop:
|
||||
for _, addr := range addrs {
|
||||
select {
|
||||
case dp.sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(addr *net.TCPAddr) {
|
||||
defer wg.Done()
|
||||
defer func() { <-dp.sem }()
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
|
||||
defer cancel()
|
||||
|
||||
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case dp.results <- dialResult{peer: p}:
|
||||
case <-ctx.Done():
|
||||
p.Close()
|
||||
}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case dp.results <- dialResult{done: true}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (dp *dialPool) dialSingle(ctx context.Context, addr *net.TCPAddr) {
|
||||
select {
|
||||
case dp.sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { <-dp.sem }()
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
|
||||
defer cancel()
|
||||
|
||||
p, err := bt.DialContext(dialCtx, addr, dp.infoHash, dp.utp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case dp.results <- dialResult{peer: p}:
|
||||
case <-ctx.Done():
|
||||
p.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
178
client/manager.go
Normal file
178
client/manager.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"storrent/bt"
|
||||
"storrent/dht"
|
||||
"storrent/metainfo"
|
||||
"storrent/utp"
|
||||
)
|
||||
|
||||
var errNotFound = errors.New("torrent not found")
|
||||
|
||||
func (m *Manager) getTorrent(id int) (*torrent, error) {
|
||||
m.mu.Lock()
|
||||
t, ok := m.torrents[id]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
torrents map[int]*torrent
|
||||
nextID int
|
||||
dir string
|
||||
dht *dht.DHT
|
||||
ln net.Listener
|
||||
utp *utp.Socket
|
||||
}
|
||||
|
||||
func NewManager(dir string, d *dht.DHT, btPort int, utpSock *utp.Socket) (*Manager, error) {
|
||||
var ln net.Listener
|
||||
if btPort > 0 {
|
||||
var err error
|
||||
ln, err = net.Listen("tcp", fmt.Sprintf(":%d", btPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Manager{
|
||||
torrents: make(map[int]*torrent),
|
||||
dir: dir,
|
||||
dht: d,
|
||||
ln: ln,
|
||||
utp: utpSock,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Run() {
|
||||
if m.ln == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
p, err := bt.Accept(m.ln)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
m.dispatch(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) dispatch(p *bt.Peer) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, t := range m.torrents {
|
||||
if t.meta.InfoHash == p.InfoHash {
|
||||
select {
|
||||
case t.incoming <- p:
|
||||
default:
|
||||
p.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func (m *Manager) Add(path string) ([]byte, error) {
|
||||
mi, err := metainfo.Parse(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
id := m.nextID
|
||||
m.nextID++
|
||||
t := newTorrent(id, mi, m.dir, m.dht, m.utp)
|
||||
complete := t.scheduler.isComplete()
|
||||
go t.run()
|
||||
m.torrents[id] = t
|
||||
m.mu.Unlock()
|
||||
if complete {
|
||||
t.startSeed()
|
||||
} else {
|
||||
t.startDownload()
|
||||
}
|
||||
return []byte(strconv.Itoa(id)), nil
|
||||
}
|
||||
|
||||
func (m *Manager) Remove(id int) error {
|
||||
m.mu.Lock()
|
||||
t, ok := m.torrents[id]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return errNotFound
|
||||
}
|
||||
delete(m.torrents, id)
|
||||
m.mu.Unlock()
|
||||
t.stopTorrent()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Start(id int) error {
|
||||
t, err := m.getTorrent(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.startDownload()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(id int) error {
|
||||
t, err := m.getTorrent(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.stopTorrent()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Seed(id int) error {
|
||||
t, err := m.getTorrent(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.startSeed()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(id int, addr string) error {
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t, err := m.getTorrent(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.enqueuePeer(tcpAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Status(id int, field string) ([]byte, error) {
|
||||
t, err := m.getTorrent(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(t.getStatus(field)), nil
|
||||
}
|
||||
|
||||
func (m *Manager) List() []byte {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
var data []byte
|
||||
for id := range m.torrents {
|
||||
if len(data) > 0 {
|
||||
data = append(data, '\n')
|
||||
}
|
||||
data = append(data, []byte(strconv.Itoa(id))...)
|
||||
}
|
||||
return data
|
||||
}
|
||||
50
client/peer.go
Normal file
50
client/peer.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package client
|
||||
|
||||
import "storrent/bt"
|
||||
|
||||
type peerManager struct {
|
||||
peers map[*bt.Peer]struct{}
|
||||
maxPeers int
|
||||
}
|
||||
|
||||
func newPeerManager(maxPeers int) *peerManager {
|
||||
return &peerManager{
|
||||
peers: make(map[*bt.Peer]struct{}),
|
||||
maxPeers: maxPeers,
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *peerManager) add(p *bt.Peer) bool {
|
||||
if len(pm.peers) >= pm.maxPeers {
|
||||
return false
|
||||
}
|
||||
pm.peers[p] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (pm *peerManager) addUnlimited(p *bt.Peer) {
|
||||
pm.peers[p] = struct{}{}
|
||||
}
|
||||
|
||||
func (pm *peerManager) remove(p *bt.Peer) {
|
||||
delete(pm.peers, p)
|
||||
}
|
||||
|
||||
func (pm *peerManager) count() int {
|
||||
return len(pm.peers)
|
||||
}
|
||||
|
||||
func (pm *peerManager) all() []*bt.Peer {
|
||||
peers := make([]*bt.Peer, 0, len(pm.peers))
|
||||
for p := range pm.peers {
|
||||
peers = append(peers, p)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
func (pm *peerManager) closeAll() {
|
||||
for p := range pm.peers {
|
||||
p.Close()
|
||||
}
|
||||
pm.peers = make(map[*bt.Peer]struct{})
|
||||
}
|
||||
69
client/piece.go
Normal file
69
client/piece.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
|
||||
"storrent/bt"
|
||||
)
|
||||
|
||||
type piece struct {
|
||||
hash [20]byte
|
||||
size int64
|
||||
data []byte
|
||||
have []bool
|
||||
reqs []bool
|
||||
reqPeer []*bt.Peer
|
||||
done bool
|
||||
}
|
||||
|
||||
func (p *piece) put(begin int, data []byte) bool {
|
||||
if p.done {
|
||||
return false
|
||||
}
|
||||
block := begin / bt.BlockSize
|
||||
if block >= len(p.have) || p.have[block] {
|
||||
return false
|
||||
}
|
||||
copy(p.data[begin:], data)
|
||||
p.have[block] = true
|
||||
for _, got := range p.have {
|
||||
if !got {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if sha1.Sum(p.data) != p.hash {
|
||||
for i := range p.have {
|
||||
p.have[i] = false
|
||||
p.reqs[i] = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
p.done = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *piece) next(peer *bt.Peer) (begin, length int64, ok bool) {
|
||||
for i := range p.have {
|
||||
if p.have[i] || p.reqs[i] {
|
||||
continue
|
||||
}
|
||||
p.reqs[i] = true
|
||||
p.reqPeer[i] = peer
|
||||
begin = int64(i) * bt.BlockSize
|
||||
length = bt.BlockSize
|
||||
if begin+length > p.size {
|
||||
length = p.size - begin
|
||||
}
|
||||
return begin, length, true
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func (p *piece) resetPeer(peer *bt.Peer) {
|
||||
for i := range p.reqs {
|
||||
if p.reqs[i] && p.reqPeer[i] == peer {
|
||||
p.reqs[i] = false
|
||||
p.reqPeer[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
134
client/scheduler.go
Normal file
134
client/scheduler.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
|
||||
"storrent/bt"
|
||||
"storrent/metainfo"
|
||||
)
|
||||
|
||||
type pieceScheduler struct {
|
||||
pieces []piece
|
||||
incomplete []int
|
||||
verified int
|
||||
meta *metainfo.File
|
||||
dir string
|
||||
}
|
||||
|
||||
func newPieceScheduler(m *metainfo.File, dir string) *pieceScheduler {
|
||||
n := len(m.Info.Pieces)
|
||||
pieces := make([]piece, n)
|
||||
incomplete := make([]int, 0, n)
|
||||
verified := 0
|
||||
|
||||
for i := range n {
|
||||
size := m.Info.PieceSize
|
||||
if i == n-1 {
|
||||
size = m.Size - int64(n-1)*m.Info.PieceSize
|
||||
}
|
||||
nblocks := (size + bt.BlockSize - 1) / bt.BlockSize
|
||||
pieces[i] = piece{
|
||||
hash: m.Info.Pieces[i],
|
||||
size: size,
|
||||
data: make([]byte, size),
|
||||
have: make([]bool, nblocks),
|
||||
reqs: make([]bool, nblocks),
|
||||
reqPeer: make([]*bt.Peer, nblocks),
|
||||
}
|
||||
|
||||
data := m.Info.Read(dir, i, pieces[i].size)
|
||||
if data != nil && sha1.Sum(data) == pieces[i].hash {
|
||||
pieces[i].data = data
|
||||
pieces[i].done = true
|
||||
verified++
|
||||
} else {
|
||||
incomplete = append(incomplete, i)
|
||||
}
|
||||
}
|
||||
|
||||
return &pieceScheduler{
|
||||
pieces: pieces,
|
||||
incomplete: incomplete,
|
||||
verified: verified,
|
||||
meta: m,
|
||||
dir: dir,
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) nextRequest(p *bt.Peer) *bt.Msg {
|
||||
for _, idx := range ps.incomplete {
|
||||
pc := &ps.pieces[idx]
|
||||
if pc.done {
|
||||
continue
|
||||
}
|
||||
if !p.HasPiece(idx) {
|
||||
continue
|
||||
}
|
||||
if begin, length, ok := pc.next(p); ok {
|
||||
return &bt.Msg{
|
||||
Kind: bt.Request,
|
||||
Index: uint32(idx),
|
||||
Begin: uint32(begin),
|
||||
Length: uint32(length),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) put(index int, begin int, block []byte) bool {
|
||||
pc := &ps.pieces[index]
|
||||
if pc.put(begin, block) {
|
||||
ps.meta.Info.Write(ps.dir, index, pc.data)
|
||||
ps.verified++
|
||||
ps.removeIncomplete(index)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) removeIncomplete(index int) {
|
||||
for i, idx := range ps.incomplete {
|
||||
if idx == index {
|
||||
ps.incomplete = append(ps.incomplete[:i], ps.incomplete[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) peerDisconnected(p *bt.Peer) {
|
||||
for i := range ps.pieces {
|
||||
ps.pieces[i].resetPeer(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) isComplete() bool {
|
||||
return ps.verified == len(ps.pieces)
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) progress() (verified, total int) {
|
||||
return ps.verified, len(ps.pieces)
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) pieceData(index int) []byte {
|
||||
pc := &ps.pieces[index]
|
||||
if !pc.done {
|
||||
return nil
|
||||
}
|
||||
return pc.data
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) count() int {
|
||||
return len(ps.pieces)
|
||||
}
|
||||
|
||||
func (ps *pieceScheduler) bitfield(all bool) []byte {
|
||||
n := len(ps.pieces)
|
||||
bf := make([]byte, (n+7)/8)
|
||||
for i := range ps.pieces {
|
||||
if all || ps.pieces[i].done {
|
||||
bt.SetBit(bf, i)
|
||||
}
|
||||
}
|
||||
return bf
|
||||
}
|
||||
419
client/torrent.go
Normal file
419
client/torrent.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"storrent/bt"
|
||||
"storrent/dht"
|
||||
"storrent/metainfo"
|
||||
"storrent/utp"
|
||||
)
|
||||
|
||||
type state int
|
||||
|
||||
const (
|
||||
stopped state = iota
|
||||
downloading
|
||||
seeding
|
||||
done
|
||||
errored
|
||||
)
|
||||
|
||||
const (
|
||||
maxPeers = 8
|
||||
maxPending = 5
|
||||
dialTimeout = 30 * time.Second
|
||||
retryInterval = 30 * time.Second
|
||||
dhtTimeout = 5 * time.Minute
|
||||
recvBufSize = 128
|
||||
)
|
||||
|
||||
type recv struct {
|
||||
msg *bt.Msg
|
||||
err error
|
||||
p *bt.Peer
|
||||
}
|
||||
|
||||
type statusReq struct {
|
||||
field string
|
||||
resp chan string
|
||||
}
|
||||
|
||||
type torrent struct {
|
||||
id int
|
||||
meta *metainfo.File
|
||||
dir string
|
||||
dht *dht.DHT
|
||||
utp *utp.Socket
|
||||
|
||||
start chan struct{}
|
||||
stop chan struct{}
|
||||
seed chan struct{}
|
||||
incoming chan *bt.Peer
|
||||
addPeer chan *net.TCPAddr
|
||||
status chan statusReq
|
||||
|
||||
peers *peerManager
|
||||
scheduler *pieceScheduler
|
||||
dialPool *dialPool
|
||||
|
||||
st state
|
||||
down int64
|
||||
up int64
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
recvs chan recv
|
||||
retry <-chan time.Time
|
||||
}
|
||||
|
||||
func newTorrent(id int, m *metainfo.File, dir string, d *dht.DHT, utpSock *utp.Socket) *torrent {
|
||||
t := &torrent{
|
||||
id: id,
|
||||
meta: m,
|
||||
dir: dir,
|
||||
dht: d,
|
||||
utp: utpSock,
|
||||
start: make(chan struct{}, 1),
|
||||
stop: make(chan struct{}, 1),
|
||||
seed: make(chan struct{}, 1),
|
||||
incoming: make(chan *bt.Peer),
|
||||
addPeer: make(chan *net.TCPAddr, 8),
|
||||
status: make(chan statusReq),
|
||||
peers: newPeerManager(maxPeers),
|
||||
scheduler: newPieceScheduler(m, dir),
|
||||
dialPool: newDialPool(d, utpSock, m.InfoHash),
|
||||
st: stopped,
|
||||
recvs: make(chan recv, recvBufSize),
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *torrent) run() {
|
||||
for {
|
||||
select {
|
||||
case <-t.start:
|
||||
t.handleStart()
|
||||
|
||||
case <-t.seed:
|
||||
t.handleSeed()
|
||||
|
||||
case <-t.stop:
|
||||
t.handleStop()
|
||||
|
||||
case <-t.retry:
|
||||
t.handleRetry()
|
||||
|
||||
case d := <-t.dialPool.results:
|
||||
t.handleDialResult(d)
|
||||
|
||||
case p := <-t.incoming:
|
||||
t.handleIncoming(p)
|
||||
|
||||
case addr := <-t.addPeer:
|
||||
t.handleAddPeer(addr)
|
||||
|
||||
case r := <-t.recvs:
|
||||
t.handleRecv(r)
|
||||
|
||||
case req := <-t.status:
|
||||
req.resp <- t.handleStatus(req.field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) handleStart() {
|
||||
if t.st != stopped {
|
||||
return
|
||||
}
|
||||
t.st = downloading
|
||||
t.ctx, t.cancel = context.WithCancel(context.Background())
|
||||
t.dialPool.start(t.ctx)
|
||||
}
|
||||
|
||||
func (t *torrent) handleSeed() {
|
||||
if t.st != stopped && t.st != done {
|
||||
return
|
||||
}
|
||||
if !t.scheduler.isComplete() {
|
||||
return
|
||||
}
|
||||
t.st = seeding
|
||||
t.ctx, t.cancel = context.WithCancel(context.Background())
|
||||
t.dialPool.start(t.ctx)
|
||||
}
|
||||
|
||||
func (t *torrent) handleStop() {
|
||||
if !t.isActive() {
|
||||
return
|
||||
}
|
||||
t.cancel()
|
||||
t.dialPool.stop()
|
||||
t.peers.closeAll()
|
||||
t.retry = nil
|
||||
t.st = stopped
|
||||
}
|
||||
|
||||
func (t *torrent) handleRetry() {
|
||||
t.retry = nil
|
||||
if t.isActive() {
|
||||
t.dialPool.start(t.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) handleDialResult(d dialResult) {
|
||||
if d.done {
|
||||
if t.peers.count() < maxPeers && t.isActive() {
|
||||
t.retry = time.After(retryInterval)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if d.peer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p := d.peer
|
||||
if !t.peers.add(p) {
|
||||
p.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go t.recvLoop(p)
|
||||
if err := t.initPeer(p); err != nil {
|
||||
p.Close()
|
||||
t.peers.remove(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) handleIncoming(p *bt.Peer) {
|
||||
t.peers.addUnlimited(p)
|
||||
go t.recvLoop(p)
|
||||
|
||||
if err := p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(false)}); err != nil {
|
||||
p.Close()
|
||||
t.peers.remove(p)
|
||||
return
|
||||
}
|
||||
if err := p.Send(&bt.Msg{Kind: bt.Unchoke}); err != nil {
|
||||
p.Close()
|
||||
t.peers.remove(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) handleAddPeer(addr *net.TCPAddr) {
|
||||
if !t.isActive() {
|
||||
return
|
||||
}
|
||||
t.dialPool.dialSingle(t.ctx, addr)
|
||||
}
|
||||
|
||||
func (t *torrent) handleRecv(r recv) {
|
||||
if r.err != nil {
|
||||
r.p.Close()
|
||||
t.peers.remove(r.p)
|
||||
t.scheduler.peerDisconnected(r.p)
|
||||
if t.peers.count() < maxPeers && t.isActive() {
|
||||
t.dialPool.start(t.ctx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if r.msg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p := r.p
|
||||
m := r.msg
|
||||
|
||||
switch m.Kind {
|
||||
case bt.Choke:
|
||||
p.Choked = true
|
||||
p.Pending = 0
|
||||
|
||||
case bt.Unchoke:
|
||||
p.Choked = false
|
||||
|
||||
case bt.Have:
|
||||
p.SetPiece(int(m.Index))
|
||||
|
||||
case bt.Bitfield:
|
||||
p.SetPieces(m.Bitfield, t.scheduler.count())
|
||||
|
||||
case bt.Piece:
|
||||
t.scheduler.put(int(m.Index), int(m.Begin), m.Block)
|
||||
p.Pending--
|
||||
t.down += int64(len(m.Block))
|
||||
|
||||
case bt.Request:
|
||||
t.handleRequest(p, m)
|
||||
}
|
||||
|
||||
t.sendRequests(p)
|
||||
t.checkCompletion()
|
||||
}
|
||||
|
||||
func (t *torrent) handleRequest(p *bt.Peer, m *bt.Msg) {
|
||||
data := t.scheduler.pieceData(int(m.Index))
|
||||
if data == nil {
|
||||
return
|
||||
}
|
||||
end := int(m.Begin + m.Length)
|
||||
if end > len(data) {
|
||||
return
|
||||
}
|
||||
if err := p.Send(&bt.Msg{
|
||||
Kind: bt.Piece,
|
||||
Index: m.Index,
|
||||
Begin: m.Begin,
|
||||
Block: data[m.Begin:end],
|
||||
}); err != nil {
|
||||
p.Close()
|
||||
return
|
||||
}
|
||||
t.up += int64(m.Length)
|
||||
}
|
||||
|
||||
func (t *torrent) sendRequests(p *bt.Peer) {
|
||||
if p.Choked {
|
||||
return
|
||||
}
|
||||
for p.Pending < maxPending {
|
||||
req := t.scheduler.nextRequest(p)
|
||||
if req == nil {
|
||||
break
|
||||
}
|
||||
if err := p.Send(req); err != nil {
|
||||
p.Close()
|
||||
break
|
||||
}
|
||||
p.Pending++
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) checkCompletion() {
|
||||
if t.st == downloading && t.scheduler.isComplete() {
|
||||
t.st = done
|
||||
t.peers.closeAll()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) initPeer(p *bt.Peer) error {
|
||||
if t.st == seeding {
|
||||
return p.Send(&bt.Msg{Kind: bt.Bitfield, Bitfield: t.scheduler.bitfield(true)})
|
||||
}
|
||||
return p.Send(&bt.Msg{Kind: bt.Interested})
|
||||
}
|
||||
|
||||
func (t *torrent) recvLoop(p *bt.Peer) {
|
||||
for {
|
||||
msg, err := p.Recv()
|
||||
t.recvs <- recv{msg, err, p}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) isActive() bool {
|
||||
return t.st == downloading || t.st == seeding
|
||||
}
|
||||
|
||||
func (t *torrent) startDownload() {
|
||||
select {
|
||||
case t.start <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) startSeed() {
|
||||
select {
|
||||
case t.seed <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) stopTorrent() {
|
||||
select {
|
||||
case t.stop <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (t *torrent) enqueuePeer(addr *net.TCPAddr) {
|
||||
t.addPeer <- addr
|
||||
}
|
||||
|
||||
func (t *torrent) getStatus(field string) string {
|
||||
ch := make(chan string)
|
||||
t.status <- statusReq{field, ch}
|
||||
return <-ch
|
||||
}
|
||||
|
||||
func (t *torrent) handleStatus(field string) string {
|
||||
switch field {
|
||||
case "name":
|
||||
return t.meta.Info.Name
|
||||
case "state":
|
||||
return t.st.String()
|
||||
case "progress":
|
||||
verified, total := t.scheduler.progress()
|
||||
if total == 0 {
|
||||
return "0"
|
||||
}
|
||||
return strconv.Itoa(verified * 100 / total)
|
||||
case "size":
|
||||
return strconv.FormatInt(t.meta.Size, 10)
|
||||
case "down":
|
||||
return strconv.FormatInt(t.down, 10)
|
||||
case "up":
|
||||
return strconv.FormatInt(t.up, 10)
|
||||
case "pieces":
|
||||
verified, total := t.scheduler.progress()
|
||||
return fmt.Sprintf("%d/%d", verified, total)
|
||||
case "peers":
|
||||
return t.formatPeers()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (t *torrent) formatPeers() string {
|
||||
var buf strings.Builder
|
||||
npieces := t.scheduler.count()
|
||||
for _, p := range t.peers.all() {
|
||||
st := "unchoked"
|
||||
if p.Choked {
|
||||
st = "choked"
|
||||
}
|
||||
have := 0
|
||||
for _, h := range p.Have {
|
||||
if h {
|
||||
have++
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s %s %s have:%d/%d pending:%d\n",
|
||||
p.Addr, p.PeerID[:8], st, have, npieces, p.Pending)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (s state) String() string {
|
||||
switch s {
|
||||
case stopped:
|
||||
return "stopped"
|
||||
case downloading:
|
||||
return "downloading"
|
||||
case seeding:
|
||||
return "seeding"
|
||||
case done:
|
||||
return "done"
|
||||
case errored:
|
||||
return "error"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
121
client/torrent_test.go
Normal file
121
client/torrent_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"testing"
|
||||
|
||||
"storrent/bt"
|
||||
)
|
||||
|
||||
func TestPut(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
hash [20]byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
in: "hello",
|
||||
hash: sha1.Sum([]byte("hello")),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
in: "wrong",
|
||||
hash: sha1.Sum([]byte("hello")),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := &piece{
|
||||
hash: tt.hash,
|
||||
size: int64(len(tt.in)),
|
||||
data: make([]byte, len(tt.in)),
|
||||
have: make([]bool, 1),
|
||||
reqs: make([]bool, 1),
|
||||
reqPeer: make([]*bt.Peer, 1),
|
||||
}
|
||||
if got := p.put(0, []byte(tt.in)); got != tt.want {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutMultiple(t *testing.T) {
|
||||
data := make([]byte, bt.BlockSize+5)
|
||||
for i := range data {
|
||||
data[i] = byte(i)
|
||||
}
|
||||
p := &piece{
|
||||
hash: sha1.Sum(data),
|
||||
size: int64(len(data)),
|
||||
data: make([]byte, len(data)),
|
||||
have: make([]bool, 2),
|
||||
reqs: make([]bool, 2),
|
||||
reqPeer: make([]*bt.Peer, 2),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
begin int
|
||||
data []byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "first block",
|
||||
begin: 0,
|
||||
data: data[:bt.BlockSize],
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "second block",
|
||||
begin: bt.BlockSize,
|
||||
data: data[bt.BlockSize:],
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := p.put(tt.begin, tt.data); got != tt.want {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pieceSize int64
|
||||
have []bool
|
||||
wantLength uint32
|
||||
}{
|
||||
{
|
||||
name: "last block partial",
|
||||
pieceSize: bt.BlockSize + 100,
|
||||
have: []bool{true, false},
|
||||
wantLength: 100,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ps := &pieceScheduler{
|
||||
pieces: []piece{{
|
||||
size: tt.pieceSize,
|
||||
have: tt.have,
|
||||
reqs: make([]bool, len(tt.have)),
|
||||
reqPeer: make([]*bt.Peer, len(tt.have)),
|
||||
}},
|
||||
incomplete: []int{0},
|
||||
}
|
||||
c := &bt.Peer{Have: []bool{true}}
|
||||
msg := ps.nextRequest(c)
|
||||
if msg == nil {
|
||||
t.Fatal("expected request")
|
||||
}
|
||||
if msg.Length != tt.wantLength {
|
||||
t.Errorf("got %d, want %d", msg.Length, tt.wantLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user