first commit
This commit is contained in:
116
bt/msg.go
Normal file
116
bt/msg.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package bt
|
||||
|
||||
// https://wiki.theory.org/BitTorrentSpecification#Messages
|
||||
// Messages: <length prefix><message ID><payload>
|
||||
// keep-alive: <len=0000>
|
||||
// choke: <len=0001><type=0>
|
||||
// unchoke: <len=0001><type=1>
|
||||
// interested: <len=0001><type=2>
|
||||
// not interested: <len=0001><type=3>
|
||||
// have: <len=0005><type=4><piece index>
|
||||
// bitfield: <len=0001+X><type=5><bitfield>
|
||||
// request: <len=0013><type=6><index><begin><length>
|
||||
// piece: <len=0009+X><type=7><index><begin><block>
|
||||
// cancel: <len=0013><type=8><index><begin><length>
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
type MsgKind uint8
|
||||
|
||||
const BlockSize = 16384
|
||||
|
||||
const (
|
||||
Choke MsgKind = iota
|
||||
Unchoke
|
||||
Interested
|
||||
NotInterested
|
||||
Have
|
||||
Bitfield
|
||||
Request
|
||||
Piece
|
||||
Cancel
|
||||
)
|
||||
|
||||
type Msg struct {
|
||||
Kind MsgKind
|
||||
Index uint32
|
||||
Begin uint32
|
||||
Length uint32
|
||||
Bitfield []byte
|
||||
Block []byte
|
||||
}
|
||||
|
||||
// [4: len] [1: type] [len-1: payload]
|
||||
func readMsg(r io.Reader) (*Msg, int, error) {
|
||||
var lenBuf [4]byte
|
||||
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
n := binary.BigEndian.Uint32(lenBuf[:])
|
||||
if n == 0 {
|
||||
return nil, 4, nil
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
msg := &Msg{Kind: MsgKind(buf[0])}
|
||||
payload := buf[1:]
|
||||
switch msg.Kind {
|
||||
case Have:
|
||||
msg.Index = binary.BigEndian.Uint32(payload)
|
||||
case Bitfield:
|
||||
msg.Bitfield = payload
|
||||
case Request, Cancel:
|
||||
msg.Index = binary.BigEndian.Uint32(payload[0:4])
|
||||
msg.Begin = binary.BigEndian.Uint32(payload[4:8])
|
||||
msg.Length = binary.BigEndian.Uint32(payload[8:12])
|
||||
case Piece:
|
||||
msg.Index = binary.BigEndian.Uint32(payload[0:4])
|
||||
msg.Begin = binary.BigEndian.Uint32(payload[4:8])
|
||||
msg.Block = payload[8:]
|
||||
}
|
||||
return msg, int(4 + n), nil
|
||||
}
|
||||
|
||||
func writeMsg(w io.Writer, msg *Msg) (int, error) {
|
||||
if msg == nil {
|
||||
return w.Write(make([]byte, 4))
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
switch msg.Kind {
|
||||
case Choke, Unchoke, Interested, NotInterested:
|
||||
case Have:
|
||||
payload = make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(payload, msg.Index)
|
||||
case Bitfield:
|
||||
payload = msg.Bitfield
|
||||
case Request, Cancel:
|
||||
payload = make([]byte, 12)
|
||||
binary.BigEndian.PutUint32(payload[0:4], msg.Index)
|
||||
binary.BigEndian.PutUint32(payload[4:8], msg.Begin)
|
||||
binary.BigEndian.PutUint32(payload[8:12], msg.Length)
|
||||
case Piece:
|
||||
payload = make([]byte, 8+len(msg.Block))
|
||||
binary.BigEndian.PutUint32(payload[0:4], msg.Index)
|
||||
binary.BigEndian.PutUint32(payload[4:8], msg.Begin)
|
||||
copy(payload[8:], msg.Block)
|
||||
}
|
||||
buf := make([]byte, 5+len(payload))
|
||||
binary.BigEndian.PutUint32(buf[0:4], uint32(1+len(payload)))
|
||||
buf[4] = byte(msg.Kind)
|
||||
copy(buf[5:], payload)
|
||||
return w.Write(buf)
|
||||
}
|
||||
|
||||
func SetBit(bf []byte, i int) {
|
||||
bf[i/8] |= 1 << (7 - i%8)
|
||||
}
|
||||
|
||||
func HasBit(bf []byte, i int) bool {
|
||||
return bf[i/8]&(1<<(7-i%8)) != 0
|
||||
}
|
||||
91
bt/msg_test.go
Normal file
91
bt/msg_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package bt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *Msg
|
||||
}{
|
||||
{
|
||||
name: "choke",
|
||||
msg: &Msg{Kind: Choke},
|
||||
},
|
||||
{
|
||||
name: "unchoke",
|
||||
msg: &Msg{Kind: Unchoke},
|
||||
},
|
||||
{
|
||||
name: "interested",
|
||||
msg: &Msg{Kind: Interested},
|
||||
},
|
||||
{
|
||||
name: "not interested",
|
||||
msg: &Msg{Kind: NotInterested},
|
||||
},
|
||||
{
|
||||
name: "have",
|
||||
msg: &Msg{
|
||||
Kind: Have,
|
||||
Index: 42,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bitfield",
|
||||
msg: &Msg{
|
||||
Kind: Bitfield,
|
||||
Bitfield: []byte{0b1111_1111, 0b0000_0000},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request",
|
||||
msg: &Msg{
|
||||
Kind: Request,
|
||||
Index: 1,
|
||||
Begin: 16384,
|
||||
Length: 16384,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "piece",
|
||||
msg: &Msg{
|
||||
Kind: Piece,
|
||||
Index: 1,
|
||||
Begin: 0,
|
||||
Block: []byte("data"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancel",
|
||||
msg: &Msg{
|
||||
Kind: Cancel,
|
||||
Index: 1,
|
||||
Begin: 0,
|
||||
Length: 16384,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "keep-alive",
|
||||
msg: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
writeMsg(&buf, tt.msg)
|
||||
data := buf.Bytes()
|
||||
got, _, err := readMsg(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
t.Fatalf("readMsg: %v", err)
|
||||
}
|
||||
var buf2 bytes.Buffer
|
||||
writeMsg(&buf2, got)
|
||||
if !bytes.Equal(buf2.Bytes(), data) {
|
||||
t.Errorf("roundtrip mismatch: got %x, want %x", buf2.Bytes(), data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
160
bt/peer.go
Normal file
160
bt/peer.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package bt
|
||||
|
||||
// https://wiki.theory.org/BitTorrentSpecification#Handshake
|
||||
// handshake: <pstrlen><pstr><reserved><info_hash><peer_id>
|
||||
// pstrlen: string length of <pstr>, as a single raw byte.
|
||||
// pstr: string identifier of the protocol.
|
||||
// reserved: eight (8) reserved bytes.
|
||||
// info_hash: 20-byte SHA1 hash of the info key in the metainfo file.
|
||||
// peer_id: 20-byte string used as a unique ID for the client.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"storrent/utp"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
Addr *net.TCPAddr
|
||||
conn io.ReadWriteCloser
|
||||
InfoHash [20]byte
|
||||
PeerID [20]byte
|
||||
Choked bool
|
||||
Interested bool
|
||||
Have []bool
|
||||
Pending int
|
||||
}
|
||||
|
||||
const btProto = "BitTorrent protocol"
|
||||
|
||||
const (
|
||||
pstrlen = 19
|
||||
reservedLen = 8
|
||||
hashLen = 20
|
||||
peerIDLen = 20
|
||||
handshakeLen = 1 + pstrlen + reservedLen + hashLen + peerIDLen // 68
|
||||
handshakeTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var peerID = genPeerID()
|
||||
|
||||
func genPeerID() [20]byte {
|
||||
var id [20]byte
|
||||
copy(id[:], "-SS0001-")
|
||||
if _, err := rand.Read(id[8:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func DialContext(ctx context.Context, addr *net.TCPAddr, h [20]byte, utpSock *utp.Socket) (*Peer, error) {
|
||||
if utpSock == nil {
|
||||
return nil, fmt.Errorf("utp socket required")
|
||||
}
|
||||
udpAddr := &net.UDPAddr{IP: addr.IP, Port: addr.Port}
|
||||
conn, err := utpSock.DialContext(ctx, udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &Peer{Addr: addr, conn: conn, InfoHash: h, Choked: true}
|
||||
if err := p.handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// [1: pstrlen] [pstrlen: btProto] [reservedLen: reserved] [hashLen: info_hash] [peerIDLen: peer_id]
|
||||
func (p *Peer) handshake() error {
|
||||
buf := make([]byte, handshakeLen)
|
||||
buf[0] = pstrlen
|
||||
copy(buf[1:1+pstrlen], btProto)
|
||||
copy(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen], p.InfoHash[:])
|
||||
copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:])
|
||||
if _, err := p.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.ReadFull(p.conn, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto {
|
||||
return fmt.Errorf("invalid handshake")
|
||||
}
|
||||
if [hashLen]byte(buf[1+pstrlen+reservedLen:1+pstrlen+reservedLen+hashLen]) != p.InfoHash {
|
||||
return fmt.Errorf("info hash mismatch")
|
||||
}
|
||||
p.PeerID = [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Peer) Send(msg *Msg) error {
|
||||
_, err := writeMsg(p.conn, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Peer) Recv() (*Msg, error) {
|
||||
msg, _, err := readMsg(p.conn)
|
||||
return msg, err
|
||||
}
|
||||
|
||||
func (p *Peer) Close() error {
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
func Accept(ln net.Listener) (*Peer, error) {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||
buf := make([]byte, handshakeLen)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if buf[0] != pstrlen || string(buf[1:1+pstrlen]) != btProto {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("invalid handshake")
|
||||
}
|
||||
infoHash := [hashLen]byte(buf[1+pstrlen+reservedLen : 1+pstrlen+reservedLen+hashLen])
|
||||
remotePeerID := [peerIDLen]byte(buf[1+pstrlen+reservedLen+hashLen:])
|
||||
copy(buf[1+pstrlen+reservedLen+hashLen:], peerID[:])
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
conn.SetDeadline(time.Time{})
|
||||
return &Peer{
|
||||
conn: conn,
|
||||
InfoHash: infoHash,
|
||||
PeerID: remotePeerID,
|
||||
Choked: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// piece 0~7: byte 0 [7 6 5 4 3 2 1 0]
|
||||
// piece 8~15: byte 1 [7 6 5 4 3 2 1 0]
|
||||
func (p *Peer) SetPieces(data []byte, n int) {
|
||||
p.Have = make([]bool, n)
|
||||
for i := range n {
|
||||
p.Have[i] = HasBit(data, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) HasPiece(i int) bool {
|
||||
if i >= 0 && i < len(p.Have) {
|
||||
return p.Have[i]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Peer) SetPiece(i int) {
|
||||
if i >= 0 && i < len(p.Have) {
|
||||
p.Have[i] = true
|
||||
}
|
||||
}
|
||||
37
bt/peer_test.go
Normal file
37
bt/peer_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package bt
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSetPieces(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
n int
|
||||
i int
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "piece 0 set",
|
||||
data: []byte{0b1000_0000},
|
||||
n: 8,
|
||||
i: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "piece 7 unset",
|
||||
data: []byte{0b1111_1110},
|
||||
n: 8,
|
||||
i: 7,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Peer{}
|
||||
c.SetPieces(tt.data, tt.n)
|
||||
if c.Have[tt.i] != tt.want {
|
||||
t.Errorf("Have[%d]: got %v, want %v", tt.i, c.Have[tt.i], tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user