session: add buffer utilities

This commit is contained in:
2016-12-14 16:36:43 -05:00
parent 3860112f21
commit 18d64ecebd
4 changed files with 92 additions and 34 deletions

68
src/net/buffers.rs Normal file
View File

@@ -0,0 +1,68 @@
use std::io::{self, Read};
use std::ptr;
use sha1::Sha1;
use metainfo::Hash;
/// Allocate a new buffer of the given size and read_exact into it.
pub fn read_exact<R>(mut reader: R, size: usize) -> io::Result<Vec<u8>> where R: Read {
let mut buf = Vec::with_capacity(size);
unsafe {
buf.set_len(size);
}
reader.read_exact(&mut buf)?;
Ok(buf)
}
pub struct PieceBuffer {
fragment_count: u32,
num_fragments: u32,
buffer: Vec<u8>,
}
impl PieceBuffer {
pub fn new(num_fragments: u32, piece_length: u32) -> Self {
let mut buffer = Vec::with_capacity(piece_length as usize);
unsafe {
buffer.set_len(piece_length as usize);
}
PieceBuffer {
fragment_count: 0,
num_fragments: num_fragments,
buffer: buffer,
}
}
pub fn add_fragment(&mut self, begin: u32, fragment: &[u8]) {
let begin = begin as usize;
assert!(begin + fragment.len() <= self.buffer.len());
unsafe {
ptr::copy_nonoverlapping(fragment.as_ptr(), self.buffer[begin..].as_mut_ptr(), fragment.len());
}
self.fragment_count += 1;
}
#[inline]
pub fn complete(&self) -> bool {
self.fragment_count == self.num_fragments
}
pub fn matches_hash(&self, info_hash: &Hash) -> bool {
if !self.complete() {
panic!("trying to hash piece buffer before it's complete");
} else {
let mut m = Sha1::new();
m.update(&self.buffer);
m.digest().bytes() == &info_hash[..]
}
}
pub fn get(self) -> Vec<u8> {
if !self.complete() {
panic!("trying to get piece buffer before it's complete");
} else {
self.buffer
}
}
}

View File

@@ -1,3 +1,4 @@
pub mod bitfield; pub mod bitfield;
mod buffers;
pub mod peer; pub mod peer;
pub mod session; pub mod session;

View File

@@ -4,6 +4,7 @@ use std::net::{Shutdown, TcpStream};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use metainfo::Hash; use metainfo::Hash;
use net::buffers;
use tracker::Peer; use tracker::Peer;
#[derive(Debug)] #[derive(Debug)]
@@ -68,11 +69,8 @@ impl PeerConnection {
}, },
5 => { 5 => {
let size = len as usize - 1; let size = len as usize - 1;
let mut bitfield = Vec::with_capacity(size); let bitfield = buffers::read_exact(&mut self.sock, size)?;
unsafe {
bitfield.set_len(size);
}
self.sock.read_exact(&mut bitfield)?;
Packet::Bitfield { Packet::Bitfield {
bitfield: bitfield, bitfield: bitfield,
} }
@@ -86,11 +84,7 @@ impl PeerConnection {
let size = len as usize - 9; let size = len as usize - 9;
let index = self.sock.read_u32::<BigEndian>()?; let index = self.sock.read_u32::<BigEndian>()?;
let begin = self.sock.read_u32::<BigEndian>()?; let begin = self.sock.read_u32::<BigEndian>()?;
let mut block = Vec::with_capacity(size); let block = buffers::read_exact(&mut self.sock, size)?;
unsafe {
block.set_len(size);
}
self.sock.read_exact(&mut block)?;
Packet::Piece { Packet::Piece {
index: index, index: index,

View File

@@ -9,10 +9,10 @@ use std::thread;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use libc; use libc;
use sha1::Sha1;
use metainfo::{Hash, Metainfo}; use metainfo::{Hash, Metainfo};
use net::bitfield::BitField; use net::bitfield::BitField;
use net::buffers::PieceBuffer;
use net::peer::{Packet, PeerConnection}; use net::peer::{Packet, PeerConnection};
use tracker::http; use tracker::http;
@@ -132,9 +132,7 @@ pub struct SessionFragment {
struct SessionPiece { struct SessionPiece {
fragments: BTreeMap<u32, SessionFragment>, fragments: BTreeMap<u32, SessionFragment>,
buffer: Vec<u8>, buffer: PieceBuffer,
num_fragments: u32,
total_fragments: u32,
} }
struct SessionTorrent { struct SessionTorrent {
@@ -197,9 +195,9 @@ impl SessionTorrent {
for index in (0..self.own_bitfield.len()).filter(|index| !keys.contains(index)) { for index in (0..self.own_bitfield.len()).filter(|index| !keys.contains(index)) {
if !self.own_bitfield.is_set(index) && peer.bitfield.is_set(index) { if !self.own_bitfield.is_set(index) && peer.bitfield.is_set(index) {
let total_fragments = f64::ceil(self.metainfo.piece_length as f64 / FRAGMENT_SIZE as f64) as u32; let num_fragments = f64::ceil(self.metainfo.piece_length as f64 / FRAGMENT_SIZE as f64) as u32;
let mut fragments = BTreeMap::new(); let mut fragments = BTreeMap::new();
for idx in 0..total_fragments { for idx in 0..num_fragments {
let begin = idx * FRAGMENT_SIZE; let begin = idx * FRAGMENT_SIZE;
fragments.insert(begin ,SessionFragment { fragments.insert(begin ,SessionFragment {
@@ -218,15 +216,9 @@ impl SessionTorrent {
length = taken.length; length = taken.length;
} }
let mut buffer = Vec::with_capacity(self.metainfo.piece_length as usize);
unsafe {
buffer.set_len(self.metainfo.piece_length as usize);
}
self.pieces.insert(index, SessionPiece { self.pieces.insert(index, SessionPiece {
fragments: fragments, fragments: fragments,
buffer: buffer, buffer: PieceBuffer::new(num_fragments, self.metainfo.piece_length),
num_fragments: 0,
total_fragments: total_fragments,
}); });
return Some((index, begin, length)); return Some((index, begin, length));
@@ -246,7 +238,7 @@ impl SessionTorrent {
} }
} }
fn write_piece(&mut self, index: u32) { fn write_piece(&mut self, index: u32, buffer: Vec<u8>) {
let mut start = 0; let mut start = 0;
let mut end = 0; let mut end = 0;
let mut pos: u64 = 0; let mut pos: u64 = 0;
@@ -259,7 +251,7 @@ impl SessionTorrent {
if write < end { if write < end {
let len = cmp::min(remaining, fileinfo.length as u64); let len = cmp::min(remaining, fileinfo.length as u64);
self.files[id].seek(SeekFrom::Start(write - start)); self.files[id].seek(SeekFrom::Start(write - start));
self.files[id].write_all(&self.pieces[&index].buffer[pos as usize..(pos+len) as usize]); self.files[id].write_all(&buffer[pos as usize..(pos+len) as usize]);
write += len; write += len;
pos += len; pos += len;
remaining -= len; remaining -= len;
@@ -280,36 +272,39 @@ impl SessionTorrent {
fn piece_reply(&mut self, peer_id: &Hash, index: u32, begin: u32, block: Vec<u8>) { fn piece_reply(&mut self, peer_id: &Hash, index: u32, begin: u32, block: Vec<u8>) {
let mut remove = false; let mut remove = false;
let mut reset = false;
{ {
if let Some(piece) = self.pieces.get_mut(&index) { if let Some(piece) = self.pieces.get_mut(&index) {
if let Some(fragment) = piece.fragments.get_mut(&begin) { if let Some(fragment) = piece.fragments.get_mut(&begin) {
fragment.status = FragmentStatus::Complete; fragment.status = FragmentStatus::Complete;
piece.buffer[begin as usize..(begin as usize)+block.len()].copy_from_slice(&block); piece.buffer.add_fragment(begin, &block);
piece.num_fragments += 1; if piece.buffer.complete() {
if piece.num_fragments == piece.total_fragments {
// TODO check hash
println!("piece is done {}", index); println!("piece is done {}", index);
let mut m = Sha1::new(); if piece.buffer.matches_hash(&self.metainfo.pieces[index as usize]) {
m.update(&piece.buffer);
if m.digest().bytes() == &self.metainfo.pieces[index as usize][..] {
self.own_bitfield.set(index); self.own_bitfield.set(index);
println!("it's a match!"); println!("it's a match!");
remove = true; remove = true;
} else { } else {
reset = true;
println!("no match"); println!("no match");
} }
} }
} else { } else {
println!("could not find fragment {}", begin); println!("could not find fragment {}", begin);
} }
if reset {
for fragment in piece.fragments.values_mut() {
fragment.status = FragmentStatus::Available;
}
}
} else { } else {
println!("could not find piece {}", index); println!("could not find piece {}", index);
} }
} }
if remove { if remove {
self.write_piece(index); let piece = self.pieces.remove(&index).expect("told to remove piece that doesn't exist");
self.pieces.remove(&index); self.write_piece(index, piece.buffer.get());
} }
self.requeue(peer_id); self.requeue(peer_id);
} }