diff --git a/src/net/buffers.rs b/src/net/buffers.rs new file mode 100644 index 0000000..1769863 --- /dev/null +++ b/src/net/buffers.rs @@ -0,0 +1,68 @@ +use std::io::{self, Read}; +use std::ptr; + +use sha1::Sha1; + +use metainfo::Hash; + +/// Allocate a new buffer of the given size and read_exact into it. +pub fn read_exact(mut reader: R, size: usize) -> io::Result> where R: Read { + let mut buf = Vec::with_capacity(size); + unsafe { + buf.set_len(size); + } + reader.read_exact(&mut buf)?; + Ok(buf) +} + +pub struct PieceBuffer { + fragment_count: u32, + num_fragments: u32, + buffer: Vec, +} + +impl PieceBuffer { + pub fn new(num_fragments: u32, piece_length: u32) -> Self { + let mut buffer = Vec::with_capacity(piece_length as usize); + unsafe { + buffer.set_len(piece_length as usize); + } + PieceBuffer { + fragment_count: 0, + num_fragments: num_fragments, + buffer: buffer, + } + } + + pub fn add_fragment(&mut self, begin: u32, fragment: &[u8]) { + let begin = begin as usize; + assert!(begin + fragment.len() <= self.buffer.len()); + unsafe { + ptr::copy_nonoverlapping(fragment.as_ptr(), self.buffer[begin..].as_mut_ptr(), fragment.len()); + } + self.fragment_count += 1; + } + + #[inline] + pub fn complete(&self) -> bool { + self.fragment_count == self.num_fragments + } + + pub fn matches_hash(&self, info_hash: &Hash) -> bool { + if !self.complete() { + panic!("trying to hash piece buffer before it's complete"); + } else { + let mut m = Sha1::new(); + m.update(&self.buffer); + m.digest().bytes() == &info_hash[..] + } + } + + pub fn get(self) -> Vec { + if !self.complete() { + panic!("trying to get piece buffer before it's complete"); + } else { + self.buffer + } + } +} diff --git a/src/net/mod.rs b/src/net/mod.rs index 0eadd19..00df1d7 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,3 +1,4 @@ pub mod bitfield; +mod buffers; pub mod peer; pub mod session; diff --git a/src/net/peer.rs b/src/net/peer.rs index e964950..b042d20 100644 --- a/src/net/peer.rs +++ b/src/net/peer.rs @@ -4,6 +4,7 @@ use std::net::{Shutdown, TcpStream}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use metainfo::Hash; +use net::buffers; use tracker::Peer; #[derive(Debug)] @@ -68,11 +69,8 @@ impl PeerConnection { }, 5 => { let size = len as usize - 1; - let mut bitfield = Vec::with_capacity(size); - unsafe { - bitfield.set_len(size); - } - self.sock.read_exact(&mut bitfield)?; + let bitfield = buffers::read_exact(&mut self.sock, size)?; + Packet::Bitfield { bitfield: bitfield, } @@ -86,11 +84,7 @@ impl PeerConnection { let size = len as usize - 9; let index = self.sock.read_u32::()?; let begin = self.sock.read_u32::()?; - let mut block = Vec::with_capacity(size); - unsafe { - block.set_len(size); - } - self.sock.read_exact(&mut block)?; + let block = buffers::read_exact(&mut self.sock, size)?; Packet::Piece { index: index, diff --git a/src/net/session.rs b/src/net/session.rs index 52dfdad..d3630e1 100644 --- a/src/net/session.rs +++ b/src/net/session.rs @@ -9,10 +9,10 @@ use std::thread; use std::time::{Duration, Instant}; use libc; -use sha1::Sha1; use metainfo::{Hash, Metainfo}; use net::bitfield::BitField; +use net::buffers::PieceBuffer; use net::peer::{Packet, PeerConnection}; use tracker::http; @@ -132,9 +132,7 @@ pub struct SessionFragment { struct SessionPiece { fragments: BTreeMap, - buffer: Vec, - num_fragments: u32, - total_fragments: u32, + buffer: PieceBuffer, } struct SessionTorrent { @@ -197,9 +195,9 @@ impl SessionTorrent { for index in (0..self.own_bitfield.len()).filter(|index| !keys.contains(index)) { if !self.own_bitfield.is_set(index) && peer.bitfield.is_set(index) { - let total_fragments = f64::ceil(self.metainfo.piece_length as f64 / FRAGMENT_SIZE as f64) as u32; + let num_fragments = f64::ceil(self.metainfo.piece_length as f64 / FRAGMENT_SIZE as f64) as u32; let mut fragments = BTreeMap::new(); - for idx in 0..total_fragments { + for idx in 0..num_fragments { let begin = idx * FRAGMENT_SIZE; fragments.insert(begin ,SessionFragment { @@ -218,15 +216,9 @@ impl SessionTorrent { length = taken.length; } - let mut buffer = Vec::with_capacity(self.metainfo.piece_length as usize); - unsafe { - buffer.set_len(self.metainfo.piece_length as usize); - } self.pieces.insert(index, SessionPiece { fragments: fragments, - buffer: buffer, - num_fragments: 0, - total_fragments: total_fragments, + buffer: PieceBuffer::new(num_fragments, self.metainfo.piece_length), }); return Some((index, begin, length)); @@ -246,7 +238,7 @@ impl SessionTorrent { } } - fn write_piece(&mut self, index: u32) { + fn write_piece(&mut self, index: u32, buffer: Vec) { let mut start = 0; let mut end = 0; let mut pos: u64 = 0; @@ -259,7 +251,7 @@ impl SessionTorrent { if write < end { let len = cmp::min(remaining, fileinfo.length as u64); self.files[id].seek(SeekFrom::Start(write - start)); - self.files[id].write_all(&self.pieces[&index].buffer[pos as usize..(pos+len) as usize]); + self.files[id].write_all(&buffer[pos as usize..(pos+len) as usize]); write += len; pos += len; remaining -= len; @@ -280,36 +272,39 @@ impl SessionTorrent { fn piece_reply(&mut self, peer_id: &Hash, index: u32, begin: u32, block: Vec) { let mut remove = false; + let mut reset = false; { if let Some(piece) = self.pieces.get_mut(&index) { if let Some(fragment) = piece.fragments.get_mut(&begin) { fragment.status = FragmentStatus::Complete; - piece.buffer[begin as usize..(begin as usize)+block.len()].copy_from_slice(&block); - piece.num_fragments += 1; - if piece.num_fragments == piece.total_fragments { - // TODO check hash + piece.buffer.add_fragment(begin, &block); + if piece.buffer.complete() { println!("piece is done {}", index); - let mut m = Sha1::new(); - m.update(&piece.buffer); - if m.digest().bytes() == &self.metainfo.pieces[index as usize][..] { + if piece.buffer.matches_hash(&self.metainfo.pieces[index as usize]) { self.own_bitfield.set(index); println!("it's a match!"); remove = true; } else { + reset = true; println!("no match"); } } } else { println!("could not find fragment {}", begin); } + if reset { + for fragment in piece.fragments.values_mut() { + fragment.status = FragmentStatus::Available; + } + } } else { println!("could not find piece {}", index); } } if remove { - self.write_piece(index); - self.pieces.remove(&index); + let piece = self.pieces.remove(&index).expect("told to remove piece that doesn't exist"); + self.write_piece(index, piece.buffer.get()); } self.requeue(peer_id); }