storrent/client/dial.go
2026-01-19 21:13:01 +09:00

140 lines
2.3 KiB
Go

package client
import (
"context"
"net"
"sync"
"storrent/bt"
"storrent/dht"
"storrent/utp"
)
// maxConcurrentDials is the capacity used for both dialPool channels: the
// semaphore that bounds in-flight dial attempts and the results buffer.
const maxConcurrentDials = 8
// dialResult is the message type carried on dialPool.results.
//
// In this file a message is either a connected peer (peer set, done false)
// or a terminal marker (done true, optionally with err).
type dialResult struct {
	// peer is the connection produced by a successful dial.
	peer *bt.Peer
	// err is set together with done when the DHT peer lookup failed.
	err error
	// done marks the last message connectLoop sends for a run.
	done bool
}
// dialPool discovers peers for a single info hash and dials them with a
// bounded number of concurrent attempts, publishing outcomes on results.
type dialPool struct {
	// sem is a counting semaphore: a dial holds one slot for its duration.
	sem chan struct{}
	// results receives connected peers and the terminal done marker.
	results chan dialResult

	// dht is queried for peer addresses by connectLoop.
	dht *dht.DHT
	// utp is handed to bt.DialContext for every dial.
	utp *utp.Socket
	// infoHash identifies the torrent being looked up and dialed for.
	infoHash [20]byte

	// running is set by start and cleared by stop.
	running bool
	// cancel cancels the context created by start; nil when not started.
	cancel context.CancelFunc
}
// newDialPool returns an idle dialPool for infoHash that looks peers up
// through d and dials them using utpSock. Both channels are buffered to
// maxConcurrentDials. Call start to begin connecting.
func newDialPool(d *dht.DHT, utpSock *utp.Socket, infoHash [20]byte) *dialPool {
	pool := new(dialPool)
	pool.sem = make(chan struct{}, maxConcurrentDials)
	pool.results = make(chan dialResult, maxConcurrentDials)
	pool.dht = d
	pool.utp = utpSock
	pool.infoHash = infoHash
	return pool
}
// start launches connectLoop in a new goroutine under a cancellable child
// of parentCtx. It is a no-op while the pool is marked running.
//
// running is cleared only by stop in the code visible here, so a pool whose
// loop has finished on its own still needs stop before start has any effect.
// NOTE(review): running and cancel are accessed without locking; callers
// presumably serialize start/stop. Verify against call sites.
func (dp *dialPool) start(parentCtx context.Context) {
	if !dp.running {
		dp.running = true

		loopCtx, stopLoop := context.WithCancel(parentCtx)
		dp.cancel = stopLoop

		go dp.connectLoop(loopCtx)
	}
}
// stop marks the pool as not running and cancels the context handed to
// connectLoop by start. It is a no-op if the pool is not running.
//
// stop returns immediately: it does not wait for the loop or in-flight
// dials to exit, and it does not drain dp.results.
func (dp *dialPool) stop() {
	if dp.running {
		dp.running = false

		if stopLoop := dp.cancel; stopLoop != nil {
			stopLoop()
			dp.cancel = nil
		}
	}
}
// connectLoop asks the DHT for peers of dp.infoHash (bounded by dhtTimeout)
// and dials every returned address, each attempt bounded by dialTimeout and
// holding one slot of dp.sem while it runs.
//
// Messages sent on dp.results:
//   - one dialResult{peer: p} per successful dial; failed dials are dropped;
//   - a final dialResult{done: true} once all started dials have returned,
//     or dialResult{done: true, err: err} if the DHT lookup failed.
//
// Every send is abandoned once ctx is cancelled; a connected peer that
// cannot be delivered for that reason is closed instead.
func (dp *dialPool) connectLoop(ctx context.Context) {
	// deliver sends r unless ctx is cancelled first, reporting whether the
	// message was handed over.
	deliver := func(r dialResult) bool {
		select {
		case dp.results <- r:
			return true
		case <-ctx.Done():
			return false
		}
	}

	lookupCtx, stopLookup := context.WithTimeout(ctx, dhtTimeout)
	peers, lookupErr := dp.dht.GetPeers(lookupCtx, dp.infoHash)
	stopLookup()
	if lookupErr != nil {
		deliver(dialResult{done: true, err: lookupErr})
		return
	}

	// acquire takes a semaphore slot, or reports false if ctx is cancelled
	// while waiting for one.
	acquire := func() bool {
		select {
		case dp.sem <- struct{}{}:
			return true
		case <-ctx.Done():
			return false
		}
	}

	var inflight sync.WaitGroup

	// dial performs one attempt; the caller must already hold a slot and
	// have registered the attempt with inflight.
	dial := func(target *net.TCPAddr) {
		defer inflight.Done()
		defer func() { <-dp.sem }()

		attemptCtx, stopAttempt := context.WithTimeout(ctx, dialTimeout)
		defer stopAttempt()

		conn, dialErr := bt.DialContext(attemptCtx, target, dp.infoHash, dp.utp)
		if dialErr != nil {
			return
		}
		if !deliver(dialResult{peer: conn}) {
			// Nobody will receive this peer; do not leak the connection.
			conn.Close()
		}
	}

	for _, target := range peers {
		if !acquire() {
			// Cancelled: stop scheduling, but still wait for started dials.
			break
		}
		inflight.Add(1)
		go dial(target)
	}

	inflight.Wait()
	deliver(dialResult{done: true})
}
// dialSingle dials addr in the background using the same semaphore and
// results channel as connectLoop.
//
// The call blocks until a slot in dp.sem is free or ctx is cancelled; in the
// latter case nothing is started. The dial itself runs in a new goroutine,
// bounded by dialTimeout. A connected peer is sent on dp.results, or closed
// if ctx is cancelled before it can be delivered. Dial errors are dropped.
func (dp *dialPool) dialSingle(ctx context.Context, addr *net.TCPAddr) {
	select {
	case <-ctx.Done():
		return
	case dp.sem <- struct{}{}:
	}

	attempt := func() {
		defer func() { <-dp.sem }()

		attemptCtx, stopAttempt := context.WithTimeout(ctx, dialTimeout)
		defer stopAttempt()

		conn, dialErr := bt.DialContext(attemptCtx, addr, dp.infoHash, dp.utp)
		if dialErr != nil {
			return
		}

		select {
		case <-ctx.Done():
			// No receiver will take this peer; release the connection.
			conn.Close()
		case dp.results <- dialResult{peer: conn}:
		}
	}
	go attempt()
}