storrent/bt/peer.go
2026-01-19 21:13:01 +09:00

161 lines
3.8 KiB
Go

package bt
// https://wiki.theory.org/BitTorrentSpecification#Handshake
// handshake: <pstrlen><pstr><reserved><info_hash><peer_id>
// pstrlen: string length of <pstr>, as a single raw byte.
// pstr: string identifier of the protocol.
// reserved: eight (8) reserved bytes.
// info_hash: 20-byte SHA1 hash of the info key in the metainfo file.
// peer_id: 20-byte string used as a unique ID for the client.
import (
"context"
"crypto/rand"
"fmt"
"io"
"net"
"time"
"storrent/utp"
)
// Peer is one BitTorrent peer connection plus the per-peer state tracked for
// it. Peers are produced by DialContext (outbound) and Accept (inbound).
type Peer struct {
	// Addr is the peer's address. DialContext stores the address it was
	// given; it may be nil for inbound peers.
	Addr *net.TCPAddr

	// conn carries the handshake and all subsequent wire messages.
	conn io.ReadWriteCloser

	// InfoHash identifies the torrent the connection is for; PeerID is the
	// 20-byte ID the remote sent in its handshake.
	InfoHash, PeerID [20]byte

	// Choked starts out true for every new peer.
	// NOTE(review): whether this means "the peer chokes us" or "we choke the
	// peer" (and likewise for Interested) is not visible here — confirm.
	Choked, Interested bool

	// Have[i] reports whether the peer advertised piece i; it is filled by
	// SetPieces and updated by SetPiece.
	Have []bool

	// Pending is presumably the number of outstanding block requests to this
	// peer — TODO confirm against callers.
	Pending int
}
const (
	// btProto is the protocol identifier (pstr) carried in every handshake.
	btProto = "BitTorrent protocol"

	// Handshake field sizes in bytes. pstrlen must equal len(btProto); it is
	// kept as an untyped literal so it can be stored directly into a byte.
	pstrlen     = 19
	reservedLen = 8
	hashLen     = 20
	peerIDLen   = 20

	// handshakeLen is the full handshake size: 1 + 19 + 8 + 20 + 20 = 68.
	handshakeLen = 1 + pstrlen + reservedLen + hashLen + peerIDLen

	// handshakeTimeout is the I/O deadline applied to a handshake exchange.
	handshakeTimeout = 10 * time.Second
)
// peerID is this client's own peer ID, generated once when the package is
// initialised and placed in every handshake this package sends.
var peerID = genPeerID()

// genPeerID returns a 20-byte peer ID in the "-XXnnnn-" prefix style: the
// fixed client tag "-SS0001-" followed by 12 bytes from crypto/rand. It
// panics if the system random source reports an error, since it runs during
// package initialisation where there is no caller to return an error to.
func genPeerID() [20]byte {
	const clientTag = "-SS0001-"
	var out [20]byte
	tagLen := copy(out[:], clientTag)
	if _, err := rand.Read(out[tagLen:]); err != nil {
		panic(err)
	}
	return out
}
// DialContext connects to the peer at addr over uTP using utpSock and performs
// the outgoing BitTorrent handshake for info hash h.
//
// addr is a TCP address, but only its IP and port are used: they form the UDP
// address that utpSock dials. ctx is passed to the uTP dial. If the resulting
// connection supports SetDeadline, the handshake is also bounded by
// handshakeTimeout (or by ctx's deadline, if that is earlier). The deadline
// is cleared again before the Peer is returned.
//
// The returned Peer starts with Choked set to true. If the handshake (or
// deadline handling) fails after a successful dial, the connection is closed
// and a nil *Peer is returned.
func DialContext(ctx context.Context, addr *net.TCPAddr, h [20]byte, utpSock *utp.Socket) (*Peer, error) {
	if utpSock == nil {
		return nil, fmt.Errorf("utp socket required")
	}
	if addr == nil {
		return nil, fmt.Errorf("peer address required")
	}
	udpAddr := &net.UDPAddr{IP: addr.IP, Port: addr.Port}
	conn, err := utpSock.DialContext(ctx, udpAddr)
	if err != nil {
		return nil, err
	}
	p := &Peer{Addr: addr, conn: conn, InfoHash: h, Choked: true}

	// Without a deadline, a peer that accepts the connection but never answers
	// would block handshake (and this call) indefinitely. Accept applies the
	// same timeout to inbound handshakes.
	dl, canDeadline := p.conn.(interface{ SetDeadline(time.Time) error })
	if canDeadline {
		deadline := time.Now().Add(handshakeTimeout)
		if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
			deadline = d
		}
		if err := dl.SetDeadline(deadline); err != nil {
			conn.Close()
			return nil, err
		}
	}
	if err := p.handshake(); err != nil {
		conn.Close()
		return nil, err
	}
	if canDeadline {
		// Clear the handshake deadline so later Send/Recv calls are not cut off.
		if err := dl.SetDeadline(time.Time{}); err != nil {
			conn.Close()
			return nil, err
		}
	}
	return p, nil
}
// handshake runs the initiating side of the BitTorrent handshake on p.conn.
// It sends our handshake (reserved bytes all zero, p.InfoHash, and this
// client's peerID), then reads the remote's fixed-size reply. It rejects the
// reply if the protocol header or info hash does not match, and otherwise
// stores the remote peer ID in p.PeerID.
//
// Wire layout:
// [1: pstrlen] [pstrlen: btProto] [reservedLen: reserved] [hashLen: info_hash] [peerIDLen: peer_id]
func (p *Peer) handshake() error {
	const (
		protoOff = 1
		hashOff  = protoOff + pstrlen + reservedLen
		idOff    = hashOff + hashLen
	)
	msg := make([]byte, handshakeLen)
	msg[0] = pstrlen
	copy(msg[protoOff:protoOff+pstrlen], btProto)
	// The reserved bytes are left zero, so no extensions are advertised.
	copy(msg[hashOff:idOff], p.InfoHash[:])
	copy(msg[idOff:], peerID[:])
	if _, err := p.conn.Write(msg); err != nil {
		return err
	}

	// The reply has the same fixed length, so the send buffer is reused for it.
	if _, err := io.ReadFull(p.conn, msg); err != nil {
		return err
	}
	switch {
	case msg[0] != pstrlen, string(msg[protoOff:protoOff+pstrlen]) != btProto:
		return fmt.Errorf("invalid handshake")
	case [hashLen]byte(msg[hashOff:idOff]) != p.InfoHash:
		return fmt.Errorf("info hash mismatch")
	}
	p.PeerID = [peerIDLen]byte(msg[idOff:])
	return nil
}
// Send writes msg to the peer's connection via writeMsg. Only the error from
// that write is returned; writeMsg's first result is discarded.
func (p *Peer) Send(msg *Msg) error {
	var err error
	_, err = writeMsg(p.conn, msg)
	return err
}
// Recv reads the next message from the peer's connection via readMsg and
// returns it with readMsg's error. readMsg's middle result is discarded.
func (p *Peer) Recv() (*Msg, error) {
	var (
		m   *Msg
		err error
	)
	m, _, err = readMsg(p.conn)
	return m, err
}
// Close closes the peer's underlying connection and returns the result of
// that Close call.
func (p *Peer) Close() error {
	c := p.conn
	return c.Close()
}
// Accept waits for an inbound connection on ln and performs the responding
// side of the BitTorrent handshake. The handshake exchange is bounded by
// handshakeTimeout, and that deadline is cleared before the Peer is returned.
//
// Accept does not validate the info hash. The remote's info hash is returned
// in Peer.InfoHash, and the caller must check that it names a torrent it
// serves.
//
// The reply echoes that info hash and carries this client's peer ID. It
// advertises no extensions: all reserved bytes are zero. Peer.Addr is set
// when the connection's remote address is a *net.TCPAddr, and is nil
// otherwise.
//
// On any error after ln.Accept succeeds, the accepted connection is closed.
//
// NOTE(review): the handshake runs on the caller's goroutine, so one slow
// client can delay the next ln.Accept by up to handshakeTimeout.
func Accept(ln net.Listener) (*Peer, error) {
	conn, err := ln.Accept()
	if err != nil {
		return nil, err
	}
	fail := func(err error) (*Peer, error) {
		conn.Close()
		return nil, err
	}
	if err := conn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		return fail(err)
	}

	const (
		reservedOff = 1 + pstrlen
		hashOff     = reservedOff + reservedLen
		idOff       = hashOff + hashLen
	)
	buf := make([]byte, handshakeLen)
	if _, err := io.ReadFull(conn, buf); err != nil {
		return fail(err)
	}
	if buf[0] != pstrlen || string(buf[1:reservedOff]) != btProto {
		return fail(fmt.Errorf("invalid handshake"))
	}
	infoHash := [hashLen]byte(buf[hashOff:idOff])
	remotePeerID := [peerIDLen]byte(buf[idOff:])

	// Build the reply in buf: pstrlen, pstr and the info hash are already
	// correct. The reserved bytes must be cleared rather than echoed, or we
	// would claim support for whatever extensions the remote advertised
	// (e.g. the extension protocol or the fast extension) without
	// implementing any of them.
	clear(buf[reservedOff:hashOff])
	copy(buf[idOff:], peerID[:])
	if _, err := conn.Write(buf); err != nil {
		return fail(err)
	}
	if err := conn.SetDeadline(time.Time{}); err != nil {
		return fail(err)
	}

	addr, _ := conn.RemoteAddr().(*net.TCPAddr)
	return &Peer{
		Addr:     addr,
		conn:     conn,
		InfoHash: infoHash,
		PeerID:   remotePeerID,
		Choked:   true,
	}, nil
}
// SetPieces replaces p.Have with n entries decoded from a bitfield payload.
// Each bit is read with HasBit, in this layout:
//
//	piece 0~7:  byte 0 [7 6 5 4 3 2 1 0]
//	piece 8~15: byte 1 [7 6 5 4 3 2 1 0]
//
// Per the diagram, piece 0 is the high bit of byte 0.
//
// data presumably comes from a remote peer's bitfield message, so it is
// untrusted and may be shorter than ceil(n/8) bytes. Pieces with no
// corresponding byte are left false instead of being read out of range. A
// negative n yields an empty Have rather than a panic in make.
func (p *Peer) SetPieces(data []byte, n int) {
	p.Have = make([]bool, max(n, 0))
	for i := range p.Have {
		if i/8 >= len(data) {
			// Every later index falls beyond data as well.
			break
		}
		p.Have[i] = HasBit(data, i)
	}
}
// HasPiece reports whether the peer is recorded as having piece i. Any index
// outside the current Have slice, including a negative one, reports false.
func (p *Peer) HasPiece(i int) bool {
	if i < 0 || i >= len(p.Have) {
		return false
	}
	return p.Have[i]
}
// SetPiece records that the peer has piece i. Indices outside the current
// Have slice are ignored, so Have never grows here.
//
// NOTE(review): Have is nil until SetPieces runs. If a peer never sends a
// bitfield, every SetPiece call is dropped — confirm that callers initialise
// Have first.
func (p *Peer) SetPiece(i int) {
	if i < 0 || i >= len(p.Have) {
		return
	}
	p.Have[i] = true
}