diff --git a/client/dial.go b/client/dial.go index 6b5eedb..071371b 100644 --- a/client/dial.go +++ b/client/dial.go @@ -63,10 +63,13 @@ func (dp *dialPool) stop() { func (dp *dialPool) connectLoop(ctx context.Context) { dhtCtx, cancel := context.WithTimeout(ctx, dhtTimeout) - // TODO: handle tokens - addrs, _, err := dp.dht.GetPeers(dhtCtx, dp.infoHash) + addrs, tokens, err := dp.dht.GetPeers(dhtCtx, dp.infoHash) cancel() + if len(tokens) > 0 { + dp.dht.AnnouncePeer(ctx, dp.infoHash, tokens) + } + if err != nil { select { case dp.results <- dialResult{done: true, err: err}: diff --git a/client/torrent.go b/client/torrent.go index d80f7f7..7f5bd54 100644 --- a/client/torrent.go +++ b/client/torrent.go @@ -26,7 +26,7 @@ const ( const ( maxPeers = 8 - maxPending = 5 + maxPending = 50 dialTimeout = 30 * time.Second retryInterval = 30 * time.Second dhtTimeout = 5 * time.Minute diff --git a/dht/dht.go b/dht/dht.go index 11d156d..3246434 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -8,6 +8,8 @@ import ( "strconv" "sync" "time" + + "storrent/utp" ) const ( @@ -18,39 +20,27 @@ const ( type DHT struct { conn *net.UDPConn + in <-chan utp.Packet id [20]byte nodes []node pending map[string]chan *resp mu sync.Mutex } -func New(port int) (*DHT, error) { - addr := &net.UDPAddr{Port: port} - conn, err := net.ListenUDP("udp", addr) - if err != nil { - return nil, err - } +func New(conn *net.UDPConn, in <-chan utp.Packet) *DHT { d := &DHT{ conn: conn, + in: in, id: genNodeID(), pending: make(map[string]chan *resp), } go d.run() - return d, nil -} - -func (d *DHT) Close() { - d.conn.Close() + return d } func (d *DHT) run() { - buf := make([]byte, 65535) - for { - n, _, err := d.conn.ReadFromUDP(buf) - if err != nil { - return - } - msg, err := decodeMsg(buf[:n]) + for pkt := range d.in { + msg, err := decodeMsg(pkt.Raw) if err != nil { continue } @@ -237,3 +227,23 @@ func nextCandidate(candidates []node, queried map[string]bool) (node, int) { } return node{}, -1 } + +func (d *DHT) AnnouncePeer(ctx context.Context, h [20]byte, tokens map[string]string) { + var wg sync.WaitGroup + for addrStr, token := range tokens { + addr, err := net.ResolveUDPAddr("udp", addrStr) + if err != nil { + continue + } + wg.Add(1) + go func(addr *net.UDPAddr, token string) { + defer wg.Done() + d.query(ctx, addr, announcePeer, args{ + InfoHash: h, + ImpliedPort: 1, + Token: token, + }) + }(addr, token) + } + wg.Wait() +} diff --git a/main.go b/main.go index 9460f03..3ff4276 100644 --- a/main.go +++ b/main.go @@ -3,8 +3,8 @@ package main import ( "context" "flag" - "fmt" "log" + "net" "github.com/knusbaum/go9p" @@ -17,27 +17,25 @@ import ( func main() { addr := flag.String("addr", ":5640", "9P listen address") dir := flag.String("dir", "./download", "download directory") - dhtPort := flag.Int("dht", 6881, "DHT port") - btPort := flag.Int("bt", 0, "BT listen port") - utpPort := flag.Int("utp", 6882, "uTP port") + udpPort := flag.Int("port", 6881, "UDP port (DHT + uTP)") + btPort := flag.Int("bt", 0, "BT listen port (TCP)") bootstrap := flag.String("bootstrap", "router.bittorrent.com:6881", "DHT bootstrap node") flag.Parse() - d, err := dht.New(*dhtPort) + conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: *udpPort}) if err != nil { log.Fatal(err) } - if err := d.Bootstrap(context.Background(), *bootstrap); err != nil { + + dhtCh := make(chan utp.Packet, 64) + sock := utp.New(conn, dhtCh) + dht := dht.New(conn, dhtCh) + sock.Start() + if err := dht.Bootstrap(context.Background(), *bootstrap); err != nil { log.Fatal(err) } - utpSock, err := utp.New(fmt.Sprintf(":%d", *utpPort)) - if err != nil { - log.Fatal(err) - } - utpSock.Start() - - m, err := client.NewManager(*dir, d, *btPort, utpSock) + m, err := client.NewManager(*dir, dht, *btPort, sock) if err != nil { log.Fatal(err) } diff --git a/utp/conn.go b/utp/conn.go index 539ab3a..6ab6c75 100644 --- a/utp/conn.go +++ b/utp/conn.go @@ -19,7 +19,7 @@ type Conn struct { recvID uint16 sendID uint16 - in chan packet + in chan Packet reads chan readReq writes chan writeReq closeReq chan struct{} @@ -150,10 +150,10 @@ func (c *Conn) retrans() error { return nil } -func (c *Conn) recv(p packet) (bool, error) { - switch p.hdr.typ { +func (c *Conn) recv(p Packet) (bool, error) { + switch p.Hdr.typ { case Syn: - c.ack = p.hdr.seq + c.ack = p.Hdr.seq if err := c.sendState(); err != nil { return false, err } @@ -161,18 +161,18 @@ func (c *Conn) recv(p packet) (bool, error) { select { case <-c.ready: default: - c.ack = p.hdr.seq - 1 + c.ack = p.Hdr.seq - 1 close(c.ready) } var remaining []sent for _, s := range c.unacked { - if seqLess(p.hdr.ack, s.seq) { + if seqLess(p.Hdr.ack, s.seq) { remaining = append(remaining, s) } } c.unacked = remaining case Data: - seq, payload := p.hdr.seq, p.payload + seq, payload := p.Hdr.seq, p.Payload if seq == c.ack+1 { c.deliver(payload) c.ack++ diff --git a/utp/socket.go b/utp/socket.go index 7911329..94ff453 100644 --- a/utp/socket.go +++ b/utp/socket.go @@ -57,6 +57,7 @@ func decode(buf []byte) header { type Socket struct { conn *net.UDPConn + dhtCh chan<- Packet mu sync.RWMutex conns map[uint16]*Conn accepts chan *Conn @@ -64,30 +65,24 @@ type Socket struct { cancel context.CancelFunc } -type packet struct { - hdr header - payload []byte - addr *net.UDPAddr +type Packet struct { + Hdr header + Payload []byte + Raw []byte + Addr *net.UDPAddr } -func New(addr string) (*Socket, error) { - a, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", a) - if err != nil { - return nil, err - } +func New(conn *net.UDPConn, dhtCh chan<- Packet) *Socket { ctx, cancel := context.WithCancel(context.Background()) s := &Socket{ conn: conn, + dhtCh: dhtCh, conns: make(map[uint16]*Conn), accepts: make(chan *Conn, 16), ctx: ctx, cancel: cancel, } - return s, nil + return s } func (s *Socket) Start() { @@ -101,30 +96,42 @@ func (s *Socket) reader() { if err != nil { return } + if n == 0 { + continue + } + + data := append([]byte(nil), buf[:n]...) + + if data[0] == 'd' { + select { + case s.dhtCh <- Packet{Raw: data, Addr: addr}: + default: + } + continue + } + if n < headerSize { continue } - hdr := decode(buf) - payload := make([]byte, n-headerSize) - copy(payload, buf[headerSize:n]) - if hdr.typ == Syn { - s.handleSyn(hdr, payload, addr) + pkt := NewPacket(data, addr) + if pkt.Hdr.typ == Syn { + s.handleSyn(pkt) } else { - s.dispatch(hdr, payload) + s.dispatch(pkt) } } } -func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) { - c := s.newConn(hdr.connID, addr, false) +func (s *Socket) handleSyn(pkt Packet) { + c := s.newConn(pkt.Hdr.connID, pkt.Addr, false) s.mu.Lock() s.conns[c.recvID] = c s.mu.Unlock() go c.run() - c.in <- packet{hdr, payload, addr} + c.in <- pkt select { case s.accepts <- c: @@ -133,14 +140,14 @@ func (s *Socket) handleSyn(hdr header, payload []byte, addr *net.UDPAddr) { } } -func (s *Socket) dispatch(hdr header, payload []byte) { +func (s *Socket) dispatch(pkt Packet) { s.mu.RLock() - c := s.conns[hdr.connID] + c := s.conns[pkt.Hdr.connID] s.mu.RUnlock() if c != nil { select { - case c.in <- packet{hdr, payload, nil}: + case c.in <- pkt: default: } } @@ -157,7 +164,7 @@ func (s *Socket) newConn(peerID uint16, addr *net.UDPAddr, initiator bool) *Conn c := &Conn{ sock: s, addr: addr, - in: make(chan packet, 16), + in: make(chan Packet, 256), reads: make(chan readReq), writes: make(chan writeReq), closeReq: make(chan struct{}), @@ -202,6 +209,10 @@ func (s *Socket) Close() { s.cancel() } +func (s *Socket) LocalPort() int { + return s.conn.LocalAddr().(*net.UDPAddr).Port +} + func randUint16() uint16 { var b [2]byte if _, err := rand.Read(b[:]); err != nil { @@ -209,3 +220,15 @@ func randUint16() uint16 { } return binary.BigEndian.Uint16(b[:]) } + +func NewPacket(data []byte, addr *net.UDPAddr) Packet { + pkt := Packet{ + Raw: data, + Addr: addr, + } + if len(data) >= headerSize { + pkt.Hdr = decode(data) + pkt.Payload = data[headerSize:] + } + return pkt +}