Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(raft): add snapshot feature for raft #3

Merged
merged 12 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,23 @@ jobs:

- name: Set up QEMU
uses: docker/setup-qemu-action@v2
if: github.ref_name == 'main'

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
if: github.ref_name == 'main'

- name: Log in to Docker Hub
uses: docker/login-action@v2
if: github.ref_name == 'main'
with:
username: leibniz007
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Docker meta
id: meta
uses: docker/metadata-action@v4
if: github.ref_name == 'main'
with:
images: leibniz007/tinylsm
tags: |
Expand All @@ -71,6 +75,7 @@ jobs:

- name: Push to DockerHub
uses: docker/build-push-action@v3
if: github.ref_name == 'main'
with:
context: .
file: Dockerfile
Expand Down
13 changes: 12 additions & 1 deletion src/main/resources/log4j2.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@
<PatternLayout pattern="${LOG_PATTERN}"/>
<ThresholdFilter level="DEBUG"/>
</Console>
<Console name="STDOUT_NO_THREAD" target="SYSTEM_OUT">
<PatternLayout pattern="%d{yyyy-MM-dd_HH:mm:ss.SSS} %p %c{1}:%L: %m%n"/>
<ThresholdFilter level="DEBUG"/>
</Console>
</Appenders>

<Loggers>
<Logger name="org.apache.pekko" level="INFO" additivity="false">
<AppenderRef ref="STDOUT" />
<AppenderRef ref="STDOUT"/>
</Logger>

<Logger name="io.github.leibnizhu.tinylsm.raft.RaftNode$" level="INFO" additivity="false">
<AppenderRef ref="STDOUT_NO_THREAD"/>
</Logger>
<Logger name="io.github.leibnizhu.tinylsm.raft.RaftState" level="INFO" additivity="false">
<AppenderRef ref="STDOUT_NO_THREAD"/>
</Logger>

<!--<Root level="DEBUG">-->
Expand Down
149 changes: 139 additions & 10 deletions src/main/scala/io/github/leibnizhu/tinylsm/raft/Command.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import org.apache.pekko.actor.typed.ActorRef

sealed trait Command extends Serializable

sealed trait ResponsibleCommand[Resp <: Command] extends Command {
val replyTo: ActorRef[Resp]
}

/**
* 让 Leader 发送心跳的事件
*/
Expand All @@ -29,16 +33,23 @@ case class VoteRequest(
lastLogIndex: Int,
//Candidate最后日志条目的任期号
lastLogTerm: Int,
replyTo: ActorRef[Command],
replyTo: ActorRef[VoteResponse],
) extends Command

case class VoteResponse(
//当前任期号,以便于Candidate去更新自己的任期号
term: Int,
//Candidate赢得了此张选票时为真
voteGranted: Boolean) extends Command
voteGranted: Boolean
) extends Command

case class LogEntry(term: Int, index: Int, command: Array[Byte])
case class LogEntry(
term: Int,
index: Int,
command: Array[Byte]
) {
override def toString: String = s"LogEntry($index@$term, '${new String(command)}')"
}

case class AppendLogRequest(
//领导者的任期
Expand All @@ -53,8 +64,8 @@ case class AppendLogRequest(
entries: Array[LogEntry],
// Leader的已知已提交的最高的日志条目的索引
leaderCommit: Int,
replyTo: ActorRef[Command],
) extends Command
replyTo: ActorRef[AppendLogResponse],
) extends ResponsibleCommand[AppendLogResponse]

case class AppendLogResponse(
//当前任期,对于领导者而言它会更新自己的任期
Expand All @@ -70,18 +81,136 @@ case class AppendLogResponse(
) extends Command

/**
* 客户端的命令
* 【上层应用发送】 的命令(日志操作)
*
* @param command 命令,对应日志记录的命令
* @param replyTo 响应给上层应用用
*/
case class ClientRequest(command: Array[Byte]) extends Command
case class CommandRequest(
command: Array[Byte],
replyTo: ActorRef[CommandResponse]
) extends ResponsibleCommand[CommandResponse]

case class CommandResponse(
index: Int,
term: Int,
isLeader: Boolean
) extends Command

/**
* 发送给上层应用,让上层应用应用日志的命令
*
* @param commandValid 命令是否可用
* @param command 命令,对应日志记录的命令
* @param commandIndex 命令索引
* @param newLeader 是否产生了新leader
*/
case class ApplyLogRequest(
// 正常应用命令
commandValid: Boolean = false,
command: Array[Byte] = null,
commandIndex: Int = -1,
newLeader: Boolean = false) extends Command

case class QueryStateRequest(replyTo: ActorRef[Command]) extends Command
// snapshot相关
snapshotValid: Boolean = false,
snapshot: Array[Byte] = null,
snapshotTerm: Int = -1,
snapshotIndex: Int = -1,

// 产生了新leader
newLeader: Boolean = false
) extends Command {
override def toString: String = if (commandValid) {
s"ApplyLogRequest: Command($commandIndex, ${new String(command)})"
} else if (snapshotValid) {
s"ApplyLogRequest: Snapshot($snapshotIndex@$snapshotTerm@, snapshot size: ${snapshot.length})"
} else if (newLeader) {
s"ApplyLogRequest: NewLeader()"
} else {
super.toString
}
}

object ApplyLogRequest {
def newLeader(): ApplyLogRequest = ApplyLogRequest(newLeader = true)

def logEntry(logEntry: LogEntry): ApplyLogRequest = ApplyLogRequest(
commandValid = true,
command = logEntry.command,
commandIndex = logEntry.index
)

def snapshot(snapshotRequest: InstallSnapshotRequest): ApplyLogRequest = ApplyLogRequest(
snapshotValid = true,
snapshot = snapshotRequest.data,
snapshotTerm = snapshotRequest.lastIncludedTerm,
snapshotIndex = snapshotRequest.lastIncludedIndex
)
}

case class QueryStateRequest(
replyTo: ActorRef[QueryStateResponse]
) extends ResponsibleCommand[QueryStateResponse]

case class QueryStateResponse(state: RaftState) extends Command

/**
* 【上层应用发送】 的snapshot请求
*
* @param index 快照包含的最大日志index
* @param snapshot 快照内容
*/
case class Snapshot(index: Int, snapshot: Array[Byte]) extends Command

/**
* Raft Leader 节点告诉滞后的 Follower 节点用快照替换其状态
*
* @param term term
* @param leaderId leader序号
* @param lastIncludedIndex 快照包含的最大index
* @param lastIncludedTerm 快照包含的最大term
* @param data 快照本照
* @param replyTo leader地址
*/
case class InstallSnapshotRequest(
term: Int,
leaderId: Int,
lastIncludedIndex: Int,
lastIncludedTerm: Int,
data: Array[Byte],
replyTo: ActorRef[InstallSnapshotResponse]
) extends Command

/**
* 安装快照的响应
*
* @param nodeIdx 当前节点下标
* @param term 当前节点的Term
* @param reqTerm InstallSnapshotRequest请求带过来的Term
* @param lastIncludedIndex 安装的快照的lastIncludedIndex
*/
case class InstallSnapshotResponse(
nodeIdx: Int,
term: Int,
reqTerm: Int,
lastIncludedIndex: Int
) extends Command

/**
* 【上层应用】收到快照的ApplyLogRequest后,调用这个判断是否快照是否最新、是否可以安装快照
*
* @param lastIncludedTerm 快照包含的最大日志term
* @param lastIncludedIndex 快照包含的最大日志index
* @param snapshot 快照内容
* @param replyTo 响应给上层应用用
*/
case class CondInstallSnapshotRequest(
lastIncludedTerm: Int,
lastIncludedIndex: Int,
snapshot: Array[Byte],
replyTo: ActorRef[CondInstallSnapshotResponse],
) extends ResponsibleCommand[CondInstallSnapshotResponse]

case class CondInstallSnapshotResponse(success: Boolean) extends Command

case class QueryStateResponse(state: RaftState) extends Command
case class RenewState() extends Command
85 changes: 78 additions & 7 deletions src/main/scala/io/github/leibnizhu/tinylsm/raft/Persistor.scala
Original file line number Diff line number Diff line change
@@ -1,20 +1,87 @@
package io.github.leibnizhu.tinylsm.raft

import io.github.leibnizhu.tinylsm.utils.Config
import io.github.leibnizhu.tinylsm.utils.{ByteArrayReader, ByteArrayWriter, Config}

import java.io.*
import java.util.concurrent.locks.ReentrantReadWriteLock

case class RaftPersistState(
currentTerm: Int,
votedFor: Option[Int],
log: Array[LogEntry],
snapshot: Array[Byte],
snapshotLastIndex: Int,
snapshotLastTerm: Int,
) {
// 计算每条命令是 命令本身长度+命令长度存储的4byte + 存储index的4byte + 存储term的4byte
def logSize(): Int = log.map(_.command.length + 4 + 4 + 4).sum
}

sealed trait Persistor {

val (unPersistLock, persistLock) = {
val rwLock = ReentrantReadWriteLock()
(rwLock.readLock(), rwLock.writeLock())
}

def persist(data: Array[Byte]): Unit
def persist(state: RaftPersistState): Unit = {
val buf = new ByteArrayWriter()
buf.putUint32(state.currentTerm).putUint32(state.votedFor.getOrElse(-1)).putUint32(state.log.length)
state.log.foreach(logEntry => buf.putUint32(logEntry.term).putUint32(logEntry.index)
.putUint32(logEntry.command.length).putBytes(logEntry.command))

// 记录snapshot
buf.putUint32(state.snapshotLastTerm).putUint32(state.snapshotLastIndex)
.putUint32(state.snapshot.length).putBytes(state.snapshot)
doPersist(buf.toArray)
}

def readPersist(): Option[RaftPersistState] = {
val bytes = doReadPersist()
if (bytes == null) {
return None
}
val buf = new ByteArrayReader(bytes)
if (buf.remaining <= 0) {
return None
}
val currentTerm = buf.readUint32()
val votedFor = {
val v = buf.readUint32()
if (v == -1) None else Some(v)
}
// 读日志
val logLength = buf.readUint32()
val log = new Array[LogEntry](logLength)
for (i <- 0 until logLength) {
val logTerm = buf.readUint32()
val logIndex = buf.readUint32()
val commandLength = buf.readUint32()
val command = buf.readBytes(commandLength)
log(i) = LogEntry(logTerm, logTerm, command)
}

val snapshotLastTerm = buf.readUint32()
val snapshotLastIndex = buf.readUint32()
val snapshotLen = buf.readUint32()
val snapshot = if (snapshotLen > 0) {
buf.readBytes(snapshotLen)
} else Array[Byte]()
Some(RaftPersistState(
currentTerm = currentTerm,
votedFor = votedFor,
log = log,
snapshot = snapshot,
snapshotLastIndex = snapshotLastIndex,
snapshotLastTerm = snapshotLastTerm
))
}

def doPersist(data: Array[Byte]): Unit

def doReadPersist(): Array[Byte]

def readPersist(): Array[Byte]
def size(): Int
}

object PersistorFactory {
Expand All @@ -29,19 +96,21 @@ object PersistorFactory {
case class MemoryPersistor(nodeIdx: Int) extends Persistor {
private var data: Array[Byte] = _

override def persist(data: Array[Byte]): Unit = try {
override def doPersist(data: Array[Byte]): Unit = try {
persistLock.lock()
this.data = data
} finally {
persistLock.unlock()
}

override def readPersist(): Array[Byte] = try {
override def doReadPersist(): Array[Byte] = try {
unPersistLock.lock()
this.data
} finally {
unPersistLock.unlock()
}

override def size(): Int = data.length
}

case class FilePersistor(file: File) extends Persistor {
Expand All @@ -53,7 +122,7 @@ case class FilePersistor(file: File) extends Persistor {
file.createNewFile()
}

override def persist(data: Array[Byte]): Unit = try {
override def doPersist(data: Array[Byte]): Unit = try {
persistLock.lock()
val writer = new BufferedOutputStream(new FileOutputStream(file, true))
writer.write(data)
Expand All @@ -63,10 +132,12 @@ case class FilePersistor(file: File) extends Persistor {
}


override def readPersist(): Array[Byte] = try {
override def doReadPersist(): Array[Byte] = try {
unPersistLock.lock()
new BufferedInputStream(FileInputStream(file)).readAllBytes()
} finally {
unPersistLock.unlock()
}

override def size(): Int = file.length().toInt
}
Loading
Loading