Skip to content

Commit

Permalink
add db migration (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liuhaai authored Oct 11, 2024
1 parent 2dc9d1a commit 8765c27
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 56 deletions.
5 changes: 3 additions & 2 deletions risc0-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ tempfile = "3.7.1"
serde = "1.0"
serde_json = "1.0"
serde_derive = "1.0"
diesel = { version = "2.1.0", features = ["postgres"] }
diesel = { version = "2.1.0", features = ["postgres", "r2d2"] }
dotenvy = "0.15"
hex = { version = "0.4.3", default-features = false, features = ["alloc"] }
bincode = "1.3"
Expand All @@ -41,6 +41,7 @@ rustc-hex = "2.1.0"
tonic = "0.8"
tonic-reflection = "0.6.0"
prost = "0.11"
diesel_migrations = "2.2.0"

[dev-dependencies]
lazy_static = "=1.4.0"
Expand All @@ -51,4 +52,4 @@ tonic-build = "0.8"
[features]
cuda = ["risc0-zkvm/cuda"]
default = []
metal = ["risc0-zkvm/metal"]
metal = ["risc0-zkvm/metal"]
92 changes: 73 additions & 19 deletions risc0-server/src/db/pgdb.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use std::env;

use diesel::{PgConnection, Connection, SelectableHelper, RunQueryDsl, QueryDsl, ExpressionMethods};
use diesel::{
r2d2::{ConnectionManager, Pool},
Connection, ExpressionMethods, PgConnection, QueryDsl, RunQueryDsl, SelectableHelper,
};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use dotenvy::dotenv;

use crate::db::models::{NewPoof, NewVm, Proof, Vm};

pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");

pub fn establish_connection() -> PgConnection {
dotenv().ok();

Expand All @@ -13,9 +19,32 @@ pub fn establish_connection() -> PgConnection {
.unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
}

pub fn create_vm<'a>(conn: &mut PgConnection, prj_name: &'a str, elf_str: &'a str, id_str: &'a str) -> Result<Vm, diesel::result::Error> {
use crate::db::schema::vms::dsl::*;
pub fn get_connection_pool() -> Pool<ConnectionManager<PgConnection>> {
dotenv().ok();
let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");

let manager = ConnectionManager::<PgConnection>::new(database_url);
// Refer to the `r2d2` documentation for more methods to use
// when building a connection pool
Pool::builder()
.test_on_check_out(true)
.build(manager)
.expect("Could not build connection pool")
}

pub fn run_migration(conn: &mut PgConnection) {
conn.run_pending_migrations(MIGRATIONS)
.expect("Error running migrations");
}

pub fn create_vm<'a>(
conn: &mut PgConnection,
prj_name: &'a str,
elf_str: &'a str,
id_str: &'a str,
) -> Result<Vm, diesel::result::Error> {
use crate::db::schema::vms;
use crate::db::schema::vms::dsl::*;

let new_vm = NewVm {
project_name: prj_name,
Expand All @@ -27,36 +56,53 @@ pub fn create_vm<'a>(conn: &mut PgConnection, prj_name: &'a str, elf_str: &'a st
Ok(_) => (),
Err(err) => {
return Err(err);
},
}
};


diesel::insert_into(vms::table)
.values(&new_vm)
.returning(Vm::as_returning())
.get_result(conn)
.values(&new_vm)
.returning(Vm::as_returning())
.get_result(conn)
}

pub fn get_vm<'a>(conn: &mut PgConnection, id_str: &'a str) -> Result<Vm, diesel::result::Error> {
use crate::db::schema::vms::dsl::*;

// let results = vms.filter(image_id.eq(id_str)).limit(1).select(Vm::as_select()).load(conn).expect("Error loading vms");

let vm = vms.filter(image_id.eq(id_str)).select(Vm::as_select()).first(conn)?;
let vm = vms
.filter(image_id.eq(id_str))
.select(Vm::as_select())
.first(conn)?;
Ok(vm)
}

pub fn get_vm_by_project<'a>(conn: &mut PgConnection, project: &'a str) -> Result<Vm, diesel::result::Error> {
pub fn get_vm_by_project<'a>(
conn: &mut PgConnection,
project: &'a str,
) -> Result<Vm, diesel::result::Error> {
use crate::db::schema::vms::dsl::*;

// let results = vms.filter(image_id.eq(id_str)).limit(1).select(Vm::as_select()).load(conn).expect("Error loading vms");

let vm = vms.filter(project_name.eq(project)).select(Vm::as_select()).first(conn)?;
let vm = vms
.filter(project_name.eq(project))
.select(Vm::as_select())
.first(conn)?;
Ok(vm)
}

pub fn create_proof<'a>(conn: &mut PgConnection, project_id: &'a str, task_id: &'a str, client_id: &'a str, sequencer_sign: &'a str, image_id: &'a str,
datas_input: &'a str, receipt_type: &'a str, status: &'a str) -> Proof {
pub fn create_proof<'a>(
conn: &mut PgConnection,
project_id: &'a str,
task_id: &'a str,
client_id: &'a str,
sequencer_sign: &'a str,
image_id: &'a str,
datas_input: &'a str,
receipt_type: &'a str,
status: &'a str,
) -> Proof {
use crate::db::schema::proofs;

let new_proof = NewPoof {
Expand All @@ -71,13 +117,17 @@ pub fn create_proof<'a>(conn: &mut PgConnection, project_id: &'a str, task_id: &
};

diesel::insert_into(proofs::table)
.values(&new_proof)
.returning(Proof::as_returning())
.get_result(conn)
.expect("Error saving new proof")
.values(&new_proof)
.returning(Proof::as_returning())
.get_result(conn)
.expect("Error saving new proof")
}

pub fn update_proof_with_receipt<'a>(conn: &mut PgConnection, p: &'a Proof, r: &'a String) -> Proof {
pub fn update_proof_with_receipt<'a>(
conn: &mut PgConnection,
p: &'a Proof,
r: &'a String,
) -> Proof {
use crate::db::schema::proofs::dsl::*;

diesel::update(proofs.filter(id.eq(p.id)))
Expand All @@ -87,7 +137,11 @@ pub fn update_proof_with_receipt<'a>(conn: &mut PgConnection, p: &'a Proof, r: &
.expect("Error updating proof")
}

pub fn update_proof_status_with_receipt<'a>(conn: &mut PgConnection, p: &'a Proof, s: &'a String) -> Proof {
pub fn update_proof_status_with_receipt<'a>(
conn: &mut PgConnection,
p: &'a Proof,
s: &'a String,
) -> Proof {
use crate::db::schema::proofs::dsl::*;

diesel::update(proofs.filter(id.eq(p.id)))
Expand Down
72 changes: 45 additions & 27 deletions risc0-server/src/grpc/server.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
use std::{io::Read, str::FromStr};

use diesel::{
r2d2::{ConnectionManager, Pool},
PgConnection,
};
use ethers::abi::{encode, Token};
use flate2::read::ZlibDecoder;
use hex::FromHex;
use risc0_ethereum_contracts::groth16;
use risc0_zkvm::{InnerReceipt, Receipt};
use rust_grpc::grpc::vm_runtime::{
vm_runtime_server::VmRuntime, CreateRequest, CreateResponse, ExecuteRequest, ExecuteResponse,
};
use serde_json::Value;
use tonic::{Request, Response, Status};
use risc0_ethereum_contracts::groth16;

use crate::{db, handlers::proof::get_receipt, model::models::ProofType};

use rust_grpc::grpc::vm_runtime::{vm_runtime_server::VmRuntime, CreateRequest, CreateResponse, ExecuteRequest, ExecuteResponse};

pub struct Risc0Server {
db_conn_pool: Pool<ConnectionManager<PgConnection>>,
}

#[derive(Debug)]
pub struct Risc0Server {}
impl Risc0Server {
pub fn new() -> Self {
let pool = db::pgdb::get_connection_pool();
db::pgdb::run_migration(&mut pool.get().unwrap());
Risc0Server { db_conn_pool: pool }
}
}

#[tonic::async_trait]
impl VmRuntime for Risc0Server {
Expand Down Expand Up @@ -59,8 +72,8 @@ impl VmRuntime for Risc0Server {
}
}

let connection = &mut db::pgdb::establish_connection();
let _ = db::pgdb::create_vm(connection, &project, &elf_str, &id_str);
let conn = &mut self.db_conn_pool.get().unwrap();
let _ = db::pgdb::create_vm(conn, &project, &elf_str, &id_str);

Ok(Response::new(CreateResponse {}))
}
Expand All @@ -79,13 +92,18 @@ impl VmRuntime for Risc0Server {
let datas = request.datas;

if datas.len() == 0 {
return Err(Status::invalid_argument("need datas"))
return Err(Status::invalid_argument("need datas"));
}

let connection = &mut db::pgdb::establish_connection();
let image_id = match db::pgdb::get_vm_by_project(connection, &project_id.to_string()) {
let conn = &mut self.db_conn_pool.get().unwrap();
let image_id = match db::pgdb::get_vm_by_project(conn, &project_id.to_string()) {
Ok(v) => v.image_id,
Err(_) => return Err(Status::not_found(format!("{} not found", project_id.to_string()))),
Err(_) => {
return Err(Status::not_found(format!(
"{} not found",
project_id.to_string()
)))
}
};

// TODO move to guest method
Expand All @@ -95,7 +113,12 @@ impl VmRuntime for Risc0Server {
let mut receipt_type = ProofType::from_str("Snark").unwrap();
// TODO check v
if v.get("receipt_type").is_some() {
receipt_type = v["receipt_type"].as_str().unwrap().to_string().parse().unwrap();
receipt_type = v["receipt_type"]
.as_str()
.unwrap()
.to_string()
.parse()
.unwrap();
}
// let receipt_type: Result<ProofType, _> =
// v["receipt_type"].as_str().unwrap().to_string().parse();
Expand Down Expand Up @@ -123,17 +146,12 @@ impl VmRuntime for Risc0Server {
let seal = groth16::encode(risc_receipt.inner.groth16().unwrap().seal.clone()).unwrap();
let journal = risc_receipt.journal.bytes.clone();

let tokens = vec![
Token::Bytes(seal),
Token::Bytes(journal),
];
let tokens = vec![Token::Bytes(seal), Token::Bytes(journal)];

result = encode(&tokens);
}

Ok(Response::new(ExecuteResponse {
result,
}))
Ok(Response::new(ExecuteResponse { result }))
}
}

Expand All @@ -150,17 +168,17 @@ fn param() {
let journal = risc_receipt.journal.bytes.clone();
println!("journal {}", format!("0x{}", hex::encode(journal.clone())));

let tokens = vec![
Token::Bytes(seal),
Token::Bytes(journal),
];
let tokens = vec![Token::Bytes(seal), Token::Bytes(journal)];

let result = encode(&tokens);
println!("bytes_seal_journal {}", format!("0x{}", hex::encode(result.clone())));
},
let result = encode(&tokens);
println!(
"bytes_seal_journal {}",
format!("0x{}", hex::encode(result.clone()))
);
}
InnerReceipt::Composite(_) => todo!(),
InnerReceipt::Succinct(_) => todo!(),
InnerReceipt::Fake(_) => todo!(),
_ => todo!(),
}
}
}
17 changes: 9 additions & 8 deletions risc0-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,29 @@ use rust_grpc::grpc::vm_runtime::vm_runtime_server::VmRuntimeServer;

use tonic::transport::Server;

mod db;
mod core;
mod db;
mod grpc;
mod handlers;
mod model;
mod tools;
#[cfg(test)]
mod tests;
mod tools;

pub async fn start_grpc_server(addr: &str) {
let addr = addr.parse().unwrap();
let risc0_server = Risc0Server{};
let risc0_server = Risc0Server::new();

tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.init();
.with_max_level(tracing::Level::DEBUG)
.init();

tracing::info!(message = "Starting server.", %addr);

Server::builder()
.trace_fn(|_| tracing::info_span!("risc0_server"))
.add_service(VmRuntimeServer::new(risc0_server))
.serve(addr)
.await.unwrap();
}
.await
.unwrap();
}

0 comments on commit 8765c27

Please sign in to comment.