Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Sign TeeProof in notary #477

Merged
merged 13 commits into from
Feb 19, 2025
Merged
23 changes: 3 additions & 20 deletions book/book.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,10 @@ title ="Web Prover"
description ="Backend for Web Proofs"

[build]
build-dir="."
extra-watch-dirs=[] # Don't watch any extra directories
create-missing=false # Don't create missing files
build-dir ="."
extra-watch-dirs =[] # Don't watch any extra directories
create-missing =false # Don't create missing files
use-default-preprocessors=false
exclude=[
"target/**/*",
"**/target/**/*",
"**/node_modules/**/*",
"client_wasm/demo/**/*", # Explicitly exclude all demo content
"client_wasm/demo/static/build/**/*", # Extra specific exclusion for build artifacts
"client_wasm/demo/pkg/**/*", # Extra specific exclusion for pkg
"client_wasm/demo/node_modules/**/*", # Extra specific exclusion for node_modules
"build/**/*",
"bin/**/*",
"client/**/*",
"client_ios/**/*",
"fixture/**/*",
"notary/**/*",
"tls/**/*",
"proofs/src/**/*",
]

[preprocessor.links]

Expand Down
3 changes: 3 additions & 0 deletions client/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ pub enum ClientErrors {

#[error(transparent)]
WitnessGeneratorError(#[from] web_proof_circuits_witness_generator::WitnessGeneratorError),

#[error("TEE proof missing")]
TeeProofMissing,
}

#[cfg(target_arch = "wasm32")]
Expand Down
58 changes: 15 additions & 43 deletions client/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
extern crate core;

pub mod tlsn;
#[cfg(not(target_arch = "wasm32"))] mod tlsn_native;
#[cfg(target_arch = "wasm32")] mod tlsn_wasm32;
Expand Down Expand Up @@ -178,60 +180,30 @@ pub async fn prover_inner_proxy(config: config::Config) -> Result<Proof, errors:
};

let response = client.post(url).json(&proxy_config).send().await?;
assert!(response.status() == hyper::StatusCode::OK);
assert_eq!(response.status(), hyper::StatusCode::OK);
let tee_proof = response.json::<TeeProof>().await?;
Ok(Proof::Proxy(tee_proof))
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TeeProof {
pub data: TeeProofData,
pub signature: String,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TeeProofData {
pub manifest_hash: Vec<u8>,
pub signature: SignedVerificationReply,
}

impl TeeProof {
pub fn from_manifest(manifest: &Manifest) -> Self {
let manifest_hash = manifest.to_keccak_digest();
let data = TeeProofData { manifest_hash: manifest_hash.to_vec() };
// TODO: How do I sign this?
let signature = "sign(hash(TeeProofData))".to_string();
TeeProof { data, signature }
}
impl TryFrom<&[u8]> for TeeProof {
type Error = serde_json::Error;

pub fn to_write_bytes(&self) -> Vec<u8> {
let serialized = self.to_bytes();
let length = serialized.len() as u32;
let mut wire_data = length.to_le_bytes().to_vec();
wire_data.extend(serialized);
wire_data
}

pub fn from_wire_bytes(buffer: &[u8]) -> Self {
// Confirm the buffer is at least large enough to contain the "header"
if buffer.len() < 4 {
panic!("Unexpected buffer length: {} < 4", buffer.len());
}

// Extract the first 4 bytes as the length prefix
let length_bytes = &buffer[..4];
let length = u32::from_le_bytes(length_bytes.try_into().unwrap()) as usize;

// Ensure the buffer contains enough data for the length specified
if buffer.len() < 4 + length {
panic!("Unexpected buffer length: {} < {} + 4", buffer.len(), length);
}
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> { serde_json::from_slice(bytes) }
}

// Extract the serialized data from the buffer
let serialized_data = &buffer[4..4 + length];
Self::from_bytes(serialized_data)
}
impl TryFrom<TeeProof> for Vec<u8> {
type Error = serde_json::Error;

fn to_bytes(&self) -> Vec<u8> { serde_json::to_vec(&self).unwrap() }
fn try_from(proof: TeeProof) -> Result<Self, Self::Error> { serde_json::to_vec(&proof) }
}

fn from_bytes(bytes: &[u8]) -> TeeProof { serde_json::from_slice(bytes).unwrap() }
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TeeProofData {
pub manifest_hash: Vec<u8>,
}
87 changes: 15 additions & 72 deletions client/src/origo.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// logic common to wasm32 and native
use std::collections::HashMap;

use futures::AsyncReadExt;
use proofs::{
circuits::construct_setup_data,
program::{
Expand Down Expand Up @@ -166,8 +165,7 @@ pub(crate) async fn proxy_and_sign_and_generate_proof(
response_inputs.clone(),
)
.await?;
let flattened_plaintext: Vec<u8> =
response_inputs.plaintext.into_iter().flat_map(|x| x).collect();
let flattened_plaintext: Vec<u8> = response_inputs.plaintext.into_iter().flatten().collect();
let http_body = compute_http_witness(&flattened_plaintext, HttpMaskType::Body);
let value = json_value_digest::<{ proofs::circuits::MAX_STACK_HEIGHT }>(
&http_body,
Expand Down Expand Up @@ -213,6 +211,18 @@ pub(crate) async fn generate_proof(
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OrigoSecrets(HashMap<String, Vec<u8>>);

impl TryFrom<&OrigoSecrets> for Vec<u8> {
type Error = serde_json::Error;

fn try_from(secrets: &OrigoSecrets) -> Result<Self, Self::Error> { serde_json::to_vec(secrets) }
}

impl TryFrom<&[u8]> for OrigoSecrets {
type Error = serde_json::Error;

fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> { serde_json::from_slice(bytes) }
}

impl OrigoSecrets {
pub fn handshake_server_iv(&self) -> Option<Vec<u8>> {
self.0.get("Handshake:server_iv").cloned()
Expand All @@ -233,69 +243,6 @@ impl OrigoSecrets {
pub fn from_origo_conn(origo_conn: &OrigoConnection) -> Self {
Self(origo_conn.secret_map.clone())
}

/// Serializes the `OrigoSecrets` into a length-prefixed byte array.
pub fn to_wire_bytes(&self) -> Vec<u8> {
let serialized = self.to_bytes();
let length = serialized.len() as u32;
// Create the "header" with the length (as little-endian bytes)
let mut wire_data = length.to_le_bytes().to_vec();
wire_data.extend(serialized);
wire_data
}

/// Deserializes a `OrigoSecrets` from a length-prefixed byte buffer.
///
/// Expects a buffer with a 4-byte little-endian "header" followed by the serialized data.
pub fn from_wire_bytes(buffer: &[u8]) -> Self {
// Confirm the buffer is at least large enough to contain the "header"
if buffer.len() < 4 {
panic!("Unexpected buffer length: {} < 4", buffer.len());
}

// Extract the first 4 bytes as the length prefix
let length_bytes = &buffer[..4];
let length = u32::from_le_bytes(length_bytes.try_into().unwrap()) as usize;

// Ensure the buffer contains enough data for the length specified
if buffer.len() < 4 + length {
panic!("Unexpected buffer length: {} < {} + 4", buffer.len(), length);
}

// Extract the serialized data from the buffer
let serialized_data = &buffer[4..4 + length];
Self::from_bytes(serialized_data).unwrap()
}

fn to_bytes(&self) -> Vec<u8> { serde_json::to_vec(&self).unwrap() }

fn from_bytes(bytes: &[u8]) -> Result<Self, ClientErrors> {
let secrets: HashMap<String, Vec<u8>> = serde_json::from_slice(bytes)?;
Ok(Self(secrets))
}
}

// TODO: Refactor into struct helpers/trait
pub(crate) async fn read_wire_struct<R: AsyncReadExt + Unpin>(stream: &mut R) -> Vec<u8> {
// Buffer to store the "header" (4 bytes, indicating the length of the struct)
let mut len_buf = [0u8; 4];
stream.read_exact(&mut len_buf).await.unwrap();
// dbg!(format!("len_buf={:?}", len_buf));

// Deserialize the length prefix (convert from little-endian to usize)
let body_len = u32::from_le_bytes(len_buf) as usize;
// dbg!(format!("body_len={body_len}"));

// Allocate a buffer to hold only the bytes needed for the struct
let mut body_buf = vec![0u8; body_len];
stream.read_exact(&mut body_buf).await.unwrap();
// dbg!(format!("manifest_buf={:?}", manifest_buf));

// Prepend len_buf to manifest_buf
let mut wire_struct_buf = len_buf.to_vec();
wire_struct_buf.extend(body_buf);

wire_struct_buf
}

#[cfg(test)]
Expand All @@ -307,12 +254,8 @@ mod tests {
origo_conn.secret_map.insert("Handshake:server_iv".to_string(), vec![1, 2, 3]);
let origo_secrets = OrigoSecrets::from_origo_conn(&origo_conn);

let serialized = origo_secrets.to_bytes();
let deserialized: OrigoSecrets = OrigoSecrets::from_bytes(&serialized).unwrap();
let serialized: Vec<u8> = origo_secrets.try_into().unwrap();
let deserialized: OrigoSecrets = OrigoSecrets::try_from(&serialized).unwrap();
assert_eq!(origo_secrets, deserialized);

let wire_serialized = origo_secrets.to_wire_bytes();
let wire_deserialized: OrigoSecrets = OrigoSecrets::from_wire_bytes(&wire_serialized);
assert_eq!(origo_secrets, wire_deserialized);
}
}
31 changes: 21 additions & 10 deletions client/src/origo_native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ use caratls_ekm_client::DummyTokenVerifier;
use caratls_ekm_client::TeeTlsConnector;
#[cfg(feature = "tee-google-confidential-space-token-verifier")]
use caratls_ekm_google_confidential_space_client::GoogleConfidentialSpaceTokenVerifier;
use futures::{channel::oneshot, AsyncReadExt, AsyncWriteExt as FuturesWriteExt};
use futures::{
channel::oneshot, AsyncReadExt, AsyncWriteExt as FuturesWriteExt, SinkExt, StreamExt,
};
use http_body_util::{BodyExt, Full};
use hyper::{body::Bytes, Request, StatusCode};
use tls_client2::origo::OrigoConnection;
use tokio::io::AsyncWriteExt;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tokio_util::{
codec::{Framed, LengthDelimitedCodec},
compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt},
};
use tracing::debug;

use crate::{
Expand Down Expand Up @@ -139,7 +144,6 @@ async fn handle_origo_mode(
tokio::spawn(handled_connection_fut);

let response = request_sender.send_request(config.to_request()?).await?;

assert_eq!(response.status(), StatusCode::OK);

let payload = response.into_body().collect().await?.to_bytes();
Expand Down Expand Up @@ -271,15 +275,22 @@ async fn handle_tee_mode(
debug!("Waiting for magic byte, received: {:?}", buffer[0]);
}

let manifest_bytes = config.proving.manifest.unwrap().to_wire_bytes();
reunited_socket.write_all(&manifest_bytes).await?;
let mut framed_reunited_socket =
Framed::new(reunited_socket.compat(), LengthDelimitedCodec::new());

let manifest_bytes: Vec<u8> = config.proving.manifest.unwrap().try_into()?;
framed_reunited_socket.send(Bytes::from(manifest_bytes)).await?;

let origo_secret = OrigoSecrets::from_origo_conn(&origo_conn);
let origo_secret_bytes: Vec<u8> = (&origo_secret).try_into().unwrap();
framed_reunited_socket.send(Bytes::from(origo_secret_bytes)).await?;

let origo_secret_bytes = OrigoSecrets::from_origo_conn(&origo_conn).to_wire_bytes();
reunited_socket.write_all(&origo_secret_bytes).await?;
framed_reunited_socket.flush().await?;

let tee_proof_bytes = crate::origo::read_wire_struct(&mut reunited_socket).await;
let tee_proof = TeeProof::from_wire_bytes(&tee_proof_bytes);
// reunited_socket.close().await?;
let tee_proof_frame =
framed_reunited_socket.next().await.ok_or_else(|| ClientErrors::TeeProofMissing)??;
let tee_proof = TeeProof::try_from(tee_proof_frame.as_ref())?;
debug!("TeeProof: {:?}", tee_proof);

Ok((origo_conn, tee_proof))
}
Expand Down
31 changes: 21 additions & 10 deletions client/src/origo_wasm32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@ use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};

use bytes::Bytes;
#[cfg(feature = "tee-dummy-token-verifier")]
use caratls_ekm_client::DummyTokenVerifier;
use caratls_ekm_client::TeeTlsConnector;
#[cfg(feature = "tee-google-confidential-space-token-verifier")]
use caratls_ekm_google_confidential_space_client::GoogleConfidentialSpaceTokenVerifier;
use futures::{channel::oneshot, AsyncReadExt, AsyncWriteExt};
use futures::{channel::oneshot, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
use hyper::StatusCode;
use tls_client2::origo::OrigoConnection;
use tokio_util::compat::TokioAsyncReadCompatExt;
use tokio_util::{
codec::{Framed, LengthDelimitedCodec},
compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt},
};
use tracing::debug;
use wasm_bindgen_futures::{spawn_local, JsFuture};
use web_sys::window;
Expand Down Expand Up @@ -165,7 +168,7 @@ async fn handle_tee_mode(
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(client_tls_conn).await?;

let (connection_sender, connection_receiver) = oneshot::channel();
let (connection_sender, _) = oneshot::channel();
let connection_fut = connection.without_shutdown();
spawn_local(async {
let result = connection_fut.await;
Expand All @@ -191,14 +194,22 @@ async fn handle_tee_mode(
debug!("Waiting for magic byte, received: {:?}", buffer[0]);
}

let manifest_bytes = config.proving.manifest.unwrap().to_wire_bytes();
reunited_socket.write_all(&manifest_bytes).await?;
let mut framed_reunited_socket =
Framed::new(reunited_socket.compat(), LengthDelimitedCodec::new());

let manifest = config.proving.manifest.unwrap();
let manifest_bytes: Vec<u8> = (&manifest).try_into().unwrap();
framed_reunited_socket.send(Bytes::from(manifest_bytes)).await?;

let origo_secrets = OrigoSecrets::from_origo_conn(&origo_conn);
let origo_secrets_bytes: Vec<u8> = (&origo_secrets).try_into().unwrap();
framed_reunited_socket.send(Bytes::from(origo_secrets_bytes)).await?;

let origo_secret_bytes = OrigoSecrets::from_origo_conn(&origo_conn).to_wire_bytes();
reunited_socket.write_all(&origo_secret_bytes).await?;
framed_reunited_socket.flush().await?;

let tee_proof_bytes = crate::origo::read_wire_struct(&mut reunited_socket).await;
let tee_proof = TeeProof::from_wire_bytes(&tee_proof_bytes);
let tee_proof_frame =
framed_reunited_socket.next().await.ok_or_else(|| ClientErrors::TeeProofMissing)??;
let tee_proof = TeeProof::try_from(tee_proof_frame.as_ref())?;

// TODO something will be dropped here. if it's dropped, it closes ...
// let mut client_socket = connection_receiver.await.unwrap()?.io.into_inner();
Expand Down
8 changes: 7 additions & 1 deletion client_wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ wasm-bindgen-test="0.3.34"
cargo_metadata="0.19.1"

[package.metadata.wasm-pack.profile.release]
wasm-opt=["-Oz", "--enable-threads", "--enable-mutable-globals", "--enable-bulk-memory", "--enable-nontrapping-float-to-int"]
wasm-opt=[
"-Oz",
"--enable-threads",
"--enable-mutable-globals",
"--enable-bulk-memory",
"--enable-nontrapping-float-to-int",
]

[package.metadata.wasm-pack.profile.dev.wasm-bindgen]
dwarf-debug-info=true
Loading
Loading