tlib.rs - sraft - simple raft implementation
 (HTM) git clone https://git.parazyd.org/sraft
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) README
       ---
       tlib.rs (16370B)
       ---
            1 use std::{collections::HashMap, io, net::SocketAddr, time::Duration};
            2 
            3 use async_channel::{Receiver, Sender};
            4 use async_std::{
            5     io::{ReadExt, WriteExt},
            6     net::{TcpListener, TcpStream},
            7     stream::StreamExt,
            8     sync::Mutex,
            9     task,
           10 };
           11 use borsh::{BorshDeserialize, BorshSerialize};
           12 use futures::{select, FutureExt};
           13 use lazy_static::lazy_static;
           14 use log::{debug, error};
           15 use rand::Rng;
           16 
           17 mod method;
           18 use crate::method::{HeartbeatArgs, HeartbeatReply, RaftMethod, VoteArgs, VoteReply};
           19 
           20 #[derive(BorshSerialize, BorshDeserialize, Clone, Debug)]
           21 pub struct LogEntry {
           22     log_term: u64,
           23     log_index: u64,
           24     log_data: Vec<u8>,
           25 }
           26 
           27 pub struct LogStore(pub Vec<LogEntry>);
           28 
           29 impl LogStore {
           30     fn get_last_index(&self) -> u64 {
           31         let rlen = self.0.len();
           32         if rlen == 0 {
           33             return 0
           34         }
           35 
           36         self.0[rlen - 1].log_index
           37     }
           38 }
           39 
           40 lazy_static! {
           41     pub static ref LOG_STORE: Mutex<LogStore> = Mutex::new(LogStore(vec![]));
           42     // This is used for heartbeats
           43     pub static ref HEARTBEAT_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded();
           44     // This is used to let our node know when it has become a leader
           45     pub static ref TOLEADER_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded();
           46 
           47     pub static ref STATE: Mutex<State> = Mutex::new(State::new());
           48 }
           49 
           50 #[derive(Default)]
           51 pub struct State {
           52     pub current_term: u64,
           53     pub voted_for: u64,
           54     pub vote_count: u64,
           55 
           56     pub commit_index: u64,
           57     pub _last_applied: u64,
           58 
           59     pub next_index: Vec<u64>,
           60     pub match_index: Vec<u64>,
           61 }
           62 
           63 impl State {
           64     pub fn new() -> Self {
           65         Self {
           66             current_term: 0,
           67             voted_for: 0,
           68             vote_count: 0,
           69             commit_index: 0,
           70             _last_applied: 0,
           71             next_index: vec![],
           72             match_index: vec![],
           73         }
           74     }
           75 }
           76 
           77 pub enum Role {
           78     Follower,
           79     Candidate,
           80     Leader,
           81 }
           82 
           83 pub struct Raft {
           84     pub peers: HashMap<u64, SocketAddr>,
           85     node_id: u64,
           86     role: Role,
           87 }
           88 
           89 impl Raft {
           90     pub fn new(node_id: u64) -> Self {
           91         Self { peers: Default::default(), node_id, role: Role::Follower }
           92     }
           93 
           94     pub async fn start(&mut self) {
           95         debug!("Raft::start()");
           96         self.role = Role::Follower;
           97 
           98         let mut state = STATE.lock().await;
           99         state.current_term = 0;
          100         state.voted_for = 0;
          101         drop(state);
          102 
          103         let mut rng = rand::thread_rng();
          104 
          105         loop {
          106             let delay = Duration::from_millis(rng.gen_range(0..200) + 300);
          107 
          108             match self.role {
          109                 Role::Follower => {
          110                     select! {
          111                         _ = HEARTBEAT_CHAN.1.recv().fuse() => {
          112                             debug!("[FOLLOWER] Raft::start(): follower_{} got heartbeat", self.node_id);
          113                         }
          114                         _ = task::sleep(delay).fuse() => {
          115                             debug!("[FOLLOWER] Raft::start(): follower_{} timeout", self.node_id);
          116                             self.role = Role::Candidate;
          117                         }
          118                     }
          119                 }
          120 
          121                 Role::Candidate => {
          122                     debug!("[CANDIDATE] Raft::start(): peer_{} is now a candidate", self.node_id);
          123                     let mut state = STATE.lock().await;
          124                     state.current_term += 1;
          125                     state.voted_for = self.node_id;
          126                     state.vote_count = 1;
          127                     drop(state);
          128 
          129                     // TODO: In background
          130                     debug!("[CANDIDATE] Raft::start(): broadcasting request_vote");
          131                     self.broadcast_request_vote().await;
          132 
          133                     select! {
          134                         _ = task::sleep(delay).fuse() => {
          135                             debug!("[CANDIDATE] Raft::start(): Timeout as candidate, becoming a follower");
          136                             self.role = Role::Follower;
          137                         }
          138                         _ = TOLEADER_CHAN.1.recv().fuse() => {
          139                             debug!("[CANDIDATE] Raft::start(): We are now the leader");
          140                             self.role = Role::Leader;
          141 
          142                             let mut state = STATE.lock().await;
          143                             state.next_index = vec![1_u64; self.peers.len()];
          144                             state.match_index = vec![0_u64; self.peers.len()];
          145                             drop(state);
          146 
          147                             // TODO: In background
          148                             let t = task::spawn(async {
          149                                 let mut i = 0;
          150                                 loop {
          151                                     debug!("[CANDIDATE] Raft::start(): Appending data in bg loop");
          152                                     i += 1;
          153                                     let state = STATE.lock().await;
          154                                     let logentry = LogEntry {
          155                                         log_term: state.current_term,
          156                                         log_index: i,
          157                                         log_data: format!("user send: {}", i).as_bytes().to_vec(),
          158                                     };
          159                                     drop(state);
          160 
          161                                     debug!("[CANDIDATE] Raft::start(): Acquiring logstore lock in bg loop");
          162                                     let mut logstore = LOG_STORE.lock().await;
          163                                     logstore.0.push(logentry);
          164                                     drop(logstore);
          165                                     debug!("[CANDIDATE] Raft::start(): Dropped logstore lock in bg loop");
          166                                     task::sleep(Duration::from_secs(3)).await;
          167                                 }
          168                             });
          169                         }
          170                     }
          171                 }
          172 
          173                 Role::Leader => {
          174                     debug!("[LEADER] Raft::start(): Broadcasting heartbeat as leader");
          175                     self.broadcast_heartbeat().await;
          176                     task::sleep(Duration::from_millis(100)).await;
          177                 }
          178             }
          179         }
          180     }
          181 
          182     async fn broadcast_request_vote(&mut self) {
          183         debug!("Raft::broadcast_request_vote()");
          184         let state = STATE.lock().await;
          185         let args = VoteArgs { term: state.current_term, candidate_id: self.node_id };
          186         drop(state);
          187 
          188         // TODO: Do this concurrently
          189         for i in self.peers.clone() {
          190             debug!("Raft::broadcast_request_vote(): Sending req to peer {}", i.1);
          191             match self.send_request_vote(i.0, args.clone()).await {
          192                 Ok(v) => debug!("Raft::broadcast_request_vote(): Got reply: {:?}", v),
          193                 Err(e) => {
          194                     error!("Raft::broadcast_request_vote(): Failed vote to peer {}, ({})", i.1, e);
          195                     continue
          196                 }
          197             };
          198         }
          199     }
          200 
          201     async fn send_request_vote(
          202         &mut self,
          203         node_id: u64,
          204         args: VoteArgs,
          205     ) -> Result<VoteReply, io::Error> {
          206         debug!("Raft::send_request_vote()");
          207         let addr = self.peers[&node_id];
          208 
          209         let method = RaftMethod::Vote(args);
          210         let payload = method.try_to_vec().unwrap();
          211 
          212         debug!("Raft::send_request_vote(): Connecting to peer_{}", node_id);
          213         let mut stream = TcpStream::connect(addr).await?;
          214         debug!("Raft::send_request_vote(): Writing to stream");
          215         stream.write_all(&payload).await?;
          216         debug!("Raft::send_request_vote(): Wrote to stream");
          217 
          218         debug!("Raft::send_request_vote(): Reading from stream");
          219         let mut buf = vec![0_u8; 4096];
          220         stream.read(&mut buf).await?;
          221         debug!("Raft::send_request_vote(): Read from stream");
          222 
          223         let reply = try_from_slice_unchecked::<VoteReply>(&buf)?;
          224         let mut state = STATE.lock().await;
          225         if reply.term > state.current_term {
          226             debug!("Raft::send_request_vote(): reply.term > state.current_term");
          227             state.current_term = reply.term;
          228             state.voted_for = 0;
          229             drop(state);
          230             self.role = Role::Follower;
          231             return Ok(reply)
          232         }
          233         drop(state);
          234 
          235         if reply.vote_granted {
          236             debug!("Raft::send_request_vote(): reply.vote_granted == true");
          237             let mut state = STATE.lock().await;
          238             state.vote_count += 1;
          239             drop(state);
          240         }
          241 
          242         let state = STATE.lock().await;
          243         if state.vote_count >= (self.peers.len() / 2 + 1).try_into().unwrap() {
          244             debug!("Raft::send_request_vote(): Elected for leader");
          245             TOLEADER_CHAN.0.send(true).await.unwrap();
          246         }
          247         drop(state);
          248 
          249         Ok(reply)
          250     }
          251 
          252     async fn broadcast_heartbeat(&mut self) {
          253         debug!("[LEADER] Raft::broadcast_heartbeat()");
          254 
          255         for i in self.peers.clone() {
          256             let state = STATE.lock().await;
          257             let mut args = HeartbeatArgs {
          258                 term: state.current_term,
          259                 leader_id: self.node_id,
          260                 prev_log_index: 0,
          261                 prev_log_term: 0,
          262                 entries: vec![],
          263                 leader_commit: state.commit_index,
          264             };
          265 
          266             let prev_log_index = state.next_index[i.0 as usize] - 1;
          267             drop(state);
          268 
          269             debug!("[LEADER] Raft::broadcast_heartbeat(): Acquiring lock on LOG_STORE");
          270             let logstore = LOG_STORE.lock().await;
          271             if logstore.get_last_index() > prev_log_index {
          272                 args.prev_log_index = prev_log_index;
          273                 args.prev_log_term = logstore.0[prev_log_index as usize].log_term;
          274                 args.entries = logstore.0[prev_log_index as usize..].to_vec();
          275                 drop(logstore);
          276                 debug!("[LEADER] Raft::broadcast_heartbeat(): Dropped lock on LOG_STORE");
          277                 debug!("[LEADER] Raft::broadcast_heartbeat(): Send entries: {:?}", args.entries);
          278             }
          279 
          280             // TODO: Run in background
          281             match self.send_heartbeat(i.0, args).await {
          282                 Ok(v) => debug!("[LEADER] Raft::broadcast_heartbeat(): Got reply: {:?}", v),
          283                 Err(e) => {
          284                     error!(
          285                         "[LEADER] Raft::broadcast_heartbeat(): Failed heartbeat to peer_{} ({})",
          286                         i.0, e
          287                     );
          288                     continue
          289                 }
          290             };
          291         }
          292     }
          293 
          294     async fn send_heartbeat(
          295         &mut self,
          296         node_id: u64,
          297         args: HeartbeatArgs,
          298     ) -> Result<HeartbeatReply, io::Error> {
          299         debug!("Raft::send_heartbeat({}, {:?}", node_id, args);
          300         let addr = self.peers[&node_id];
          301 
          302         let method = RaftMethod::Heartbeat(args);
          303         let payload = method.try_to_vec()?;
          304 
          305         debug!("Raft::send_heartbeat(): Connecting to peer_{}", node_id);
          306         let mut stream = TcpStream::connect(addr).await?;
          307         debug!("Raft::send_heartbeat(): Writing to stream");
          308         stream.write_all(&payload).await?;
          309         debug!("Raft::send_heartbeat(): Wrote to stream");
          310 
          311         debug!("Raft::send_heartbeat(): Reading from stream");
          312         let mut buf = vec![0_u8; 4096];
          313         stream.read(&mut buf).await?;
          314         debug!("Raft::send_heartbeat(): Read from stream");
          315 
          316         let reply = try_from_slice_unchecked::<HeartbeatReply>(&buf)?;
          317 
          318         let mut state = STATE.lock().await;
          319         if reply.success {
          320             debug!("Raft::send_heartbeat(): Got success reply");
          321             if reply.next_index > 0 {
          322                 state.next_index[node_id as usize] = reply.next_index;
          323                 state.match_index[node_id as usize] = reply.next_index - 1;
          324             }
          325         } else if reply.term > state.current_term {
          326             debug!("Raft::send_heartbeat(): reply.term > state.current_term");
          327             state.current_term = reply.term;
          328             state.voted_for = 0;
          329             self.role = Role::Follower;
          330         }
          331         drop(state);
          332 
          333         Ok(reply)
          334     }
          335 }
          336 
          337 pub struct RaftRpc(pub SocketAddr);
          338 
          339 impl RaftRpc {
          340     pub async fn start(&self) {
          341         debug!("RaftRpc::start()");
          342 
          343         debug!("RaftRpc::start(): Binding to {}", self.0);
          344         let listener = TcpListener::bind(self.0).await.unwrap();
          345         let mut incoming = listener.incoming();
          346 
          347         while let Some(stream) = incoming.next().await {
          348             debug!("RaftRpc::start(): Got RPC request");
          349             let stream = stream.unwrap();
          350             let (reader, writer) = &mut (&stream, &stream);
          351 
          352             debug!("RaftRpc::start(): Reading from reader...");
          353             let mut buf = vec![0_u8; 4096];
          354             reader.read(&mut buf).await.unwrap();
          355             debug!("RaftRpc::start(): Read from reader");
          356 
          357             match try_from_slice_unchecked::<RaftMethod>(&buf).unwrap() {
          358                 RaftMethod::Vote(args) => {
          359                     debug!("RaftRpc::start(): Got RaftMethod::Vote");
          360                     let reply = self.request_vote(args).await;
          361                     let payload = reply.try_to_vec().unwrap();
          362 
          363                     debug!("RaftRpc::start(): Vote: Writing to writer...");
          364                     writer.write_all(&payload).await.unwrap();
          365                     debug!("RaftRpc::start(): Vote: Wrote to writer");
          366                 }
          367 
          368                 RaftMethod::Heartbeat(args) => {
          369                     debug!("RaftRpc::start(): Got RaftMethod::Heartbeat");
          370                     let reply = self.heartbeat(args).await;
          371                     let payload = reply.try_to_vec().unwrap();
          372 
          373                     debug!("RaftRpc::start(): Heartbeat: Writing to writer...");
          374                     writer.write_all(&payload).await.unwrap();
          375                     debug!("RaftRpc::start(): Heartbeat: Wrote to writer");
          376                 }
          377             }
          378         }
          379     }
          380 
          381     async fn request_vote(&self, args: VoteArgs) -> VoteReply {
          382         debug!("RaftRpc::request_vote()");
          383         let mut reply = VoteReply { term: 0, vote_granted: false };
          384 
          385         debug!("RaftRpc::request_vote(): Acquiring state lock");
          386         let mut state = STATE.lock().await;
          387         debug!("RaftRpc::request_vote(): Got lock");
          388 
          389         if args.term < state.current_term {
          390             reply.term = state.current_term;
          391             drop(state);
          392             reply.vote_granted = false;
          393             return reply
          394         }
          395 
          396         if state.voted_for == 0 {
          397             state.current_term = args.term;
          398             state.voted_for = args.candidate_id;
          399             drop(state);
          400             reply.term = args.term;
          401             reply.vote_granted = true;
          402             return reply
          403         }
          404 
          405         drop(state);
          406         reply
          407     }
          408 
          409     async fn heartbeat(&self, args: HeartbeatArgs) -> HeartbeatReply {
          410         debug!("RaftRpc::heartbeat()");
          411         let mut reply = HeartbeatReply { success: false, term: 0, next_index: 0 };
          412 
          413         debug!("RaftRpc::heartbeat(): Acquiring state lock");
          414         let state = STATE.lock().await;
          415         debug!("RaftRpc::heartbeat(): Got state lock");
          416         let current_term = state.current_term;
          417         drop(state);
          418         debug!("RaftRpc::heartbeat(): Dropped state lock");
          419 
          420         if args.term < current_term {
          421             reply.success = false;
          422             reply.term = current_term;
          423             return reply
          424         }
          425 
          426         debug!("RaftRpc::heartbeat(): Sending to channel");
          427         HEARTBEAT_CHAN.0.send(true).await.unwrap();
          428         debug!("RaftRpc::heartbeat(): Sent to channel");
          429 
          430         if args.entries.is_empty() {
          431             reply.success = true;
          432             reply.term = current_term;
          433             return reply
          434         }
          435 
          436         debug!("RaftRpc::heartbeat(): Acquiring logstore lock");
          437         let mut logstore = LOG_STORE.lock().await;
          438         debug!("RaftRpc::heartbeat(): Got logstore lock");
          439         if args.prev_log_index > logstore.get_last_index() {
          440             reply.success = false;
          441             reply.term = current_term;
          442             reply.next_index = logstore.get_last_index() + 1;
          443             drop(logstore);
          444             return reply
          445         }
          446 
          447         logstore.0.extend_from_slice(&args.entries);
          448         reply.next_index = logstore.get_last_index() + 1;
          449         drop(logstore);
          450         debug!("RaftRpc::heartbeat(): Dropped logstore lock");
          451 
          452         reply.success = true;
          453         reply.term = current_term;
          454 
          455         reply
          456     }
          457 }
          458 
          459 fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, io::Error> {
          460     let mut data_mut = data;
          461     let result = T::deserialize(&mut data_mut)?;
          462     Ok(result)
          463 }