diff --git a/book/book.toml b/book/book.toml index 53142352..4e472e3f 100644 --- a/book/book.toml +++ b/book/book.toml @@ -7,27 +7,10 @@ title ="Web Prover" description ="Backend for Web Proofs" [build] -build-dir="." -extra-watch-dirs=[] # Don't watch any extra directories -create-missing=false # Don't create missing files +build-dir ="." +extra-watch-dirs =[] # Don't watch any extra directories +create-missing =false # Don't create missing files use-default-preprocessors=false -exclude=[ - "target/**/*", - "**/target/**/*", - "**/node_modules/**/*", - "client_wasm/demo/**/*", # Explicitly exclude all demo content - "client_wasm/demo/static/build/**/*", # Extra specific exclusion for build artifacts - "client_wasm/demo/pkg/**/*", # Extra specific exclusion for pkg - "client_wasm/demo/node_modules/**/*", # Extra specific exclusion for node_modules - "build/**/*", - "bin/**/*", - "client/**/*", - "client_ios/**/*", - "fixture/**/*", - "notary/**/*", - "tls/**/*", - "proofs/src/**/*", -] [preprocessor.links] diff --git a/client/src/errors.rs b/client/src/errors.rs index c1516491..0c036a23 100644 --- a/client/src/errors.rs +++ b/client/src/errors.rs @@ -104,6 +104,9 @@ pub enum ClientErrors { #[error(transparent)] WitnessGeneratorError(#[from] web_proof_circuits_witness_generator::WitnessGeneratorError), + + #[error("TEE proof missing")] + TeeProofMissing, } #[cfg(target_arch = "wasm32")] diff --git a/client/src/lib.rs b/client/src/lib.rs index 51cd0228..e0a7b603 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -1,3 +1,5 @@ +extern crate core; + pub mod tlsn; #[cfg(not(target_arch = "wasm32"))] mod tlsn_native; #[cfg(target_arch = "wasm32")] mod tlsn_wasm32; @@ -178,7 +180,7 @@ pub async fn prover_inner_proxy(config: config::Config) -> Result().await?; Ok(Proof::Proxy(tee_proof)) } @@ -186,52 +188,22 @@ pub async fn prover_inner_proxy(config: config::Config) -> Result, + pub signature: SignedVerificationReply, } -impl TeeProof { - pub fn from_manifest(manifest: &Manifest) -> Self { - let manifest_hash = manifest.to_keccak_digest(); - let data = TeeProofData { manifest_hash: manifest_hash.to_vec() }; - // TODO: How do I sign this? - let signature = "sign(hash(TeeProofData))".to_string(); - TeeProof { data, signature } - } +impl TryFrom<&[u8]> for TeeProof { + type Error = serde_json::Error; - pub fn to_write_bytes(&self) -> Vec { - let serialized = self.to_bytes(); - let length = serialized.len() as u32; - let mut wire_data = length.to_le_bytes().to_vec(); - wire_data.extend(serialized); - wire_data - } - - pub fn from_wire_bytes(buffer: &[u8]) -> Self { - // Confirm the buffer is at least large enough to contain the "header" - if buffer.len() < 4 { - panic!("Unexpected buffer length: {} < 4", buffer.len()); - } - - // Extract the first 4 bytes as the length prefix - let length_bytes = &buffer[..4]; - let length = u32::from_le_bytes(length_bytes.try_into().unwrap()) as usize; - - // Ensure the buffer contains enough data for the length specified - if buffer.len() < 4 + length { - panic!("Unexpected buffer length: {} < {} + 4", buffer.len(), length); - } + fn try_from(bytes: &[u8]) -> Result { serde_json::from_slice(bytes) } +} - // Extract the serialized data from the buffer - let serialized_data = &buffer[4..4 + length]; - Self::from_bytes(serialized_data) - } +impl TryFrom for Vec { + type Error = serde_json::Error; - fn to_bytes(&self) -> Vec { serde_json::to_vec(&self).unwrap() } + fn try_from(proof: TeeProof) -> Result { serde_json::to_vec(&proof) } +} - fn from_bytes(bytes: &[u8]) -> TeeProof { serde_json::from_slice(bytes).unwrap() } +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct TeeProofData { + pub manifest_hash: Vec, } diff --git a/client/src/origo.rs b/client/src/origo.rs index c9975090..5ed17801 100644 --- a/client/src/origo.rs +++ b/client/src/origo.rs @@ -1,7 +1,6 @@ // logic common to wasm32 and native use std::collections::HashMap; -use futures::AsyncReadExt; use proofs::{ circuits::construct_setup_data, program::{ @@ -166,8 +165,7 @@ pub(crate) async fn proxy_and_sign_and_generate_proof( response_inputs.clone(), ) .await?; - let flattened_plaintext: Vec = - response_inputs.plaintext.into_iter().flat_map(|x| x).collect(); + let flattened_plaintext: Vec = response_inputs.plaintext.into_iter().flatten().collect(); let http_body = compute_http_witness(&flattened_plaintext, HttpMaskType::Body); let value = json_value_digest::<{ proofs::circuits::MAX_STACK_HEIGHT }>( &http_body, @@ -213,6 +211,18 @@ pub(crate) async fn generate_proof( #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct OrigoSecrets(HashMap>); +impl TryFrom<&OrigoSecrets> for Vec { + type Error = serde_json::Error; + + fn try_from(secrets: &OrigoSecrets) -> Result { serde_json::to_vec(secrets) } +} + +impl TryFrom<&[u8]> for OrigoSecrets { + type Error = serde_json::Error; + + fn try_from(bytes: &[u8]) -> Result { serde_json::from_slice(bytes) } +} + impl OrigoSecrets { pub fn handshake_server_iv(&self) -> Option> { self.0.get("Handshake:server_iv").cloned() @@ -233,69 +243,6 @@ impl OrigoSecrets { pub fn from_origo_conn(origo_conn: &OrigoConnection) -> Self { Self(origo_conn.secret_map.clone()) } - - /// Serializes the `OrigoSecrets` into a length-prefixed byte array. - pub fn to_wire_bytes(&self) -> Vec { - let serialized = self.to_bytes(); - let length = serialized.len() as u32; - // Create the "header" with the length (as little-endian bytes) - let mut wire_data = length.to_le_bytes().to_vec(); - wire_data.extend(serialized); - wire_data - } - - /// Deserializes a `OrigoSecrets` from a length-prefixed byte buffer. - /// - /// Expects a buffer with a 4-byte little-endian "header" followed by the serialized data. - pub fn from_wire_bytes(buffer: &[u8]) -> Self { - // Confirm the buffer is at least large enough to contain the "header" - if buffer.len() < 4 { - panic!("Unexpected buffer length: {} < 4", buffer.len()); - } - - // Extract the first 4 bytes as the length prefix - let length_bytes = &buffer[..4]; - let length = u32::from_le_bytes(length_bytes.try_into().unwrap()) as usize; - - // Ensure the buffer contains enough data for the length specified - if buffer.len() < 4 + length { - panic!("Unexpected buffer length: {} < {} + 4", buffer.len(), length); - } - - // Extract the serialized data from the buffer - let serialized_data = &buffer[4..4 + length]; - Self::from_bytes(serialized_data).unwrap() - } - - fn to_bytes(&self) -> Vec { serde_json::to_vec(&self).unwrap() } - - fn from_bytes(bytes: &[u8]) -> Result { - let secrets: HashMap> = serde_json::from_slice(bytes)?; - Ok(Self(secrets)) - } -} - -// TODO: Refactor into struct helpers/trait -pub(crate) async fn read_wire_struct(stream: &mut R) -> Vec { - // Buffer to store the "header" (4 bytes, indicating the length of the struct) - let mut len_buf = [0u8; 4]; - stream.read_exact(&mut len_buf).await.unwrap(); - // dbg!(format!("len_buf={:?}", len_buf)); - - // Deserialize the length prefix (convert from little-endian to usize) - let body_len = u32::from_le_bytes(len_buf) as usize; - // dbg!(format!("body_len={body_len}")); - - // Allocate a buffer to hold only the bytes needed for the struct - let mut body_buf = vec![0u8; body_len]; - stream.read_exact(&mut body_buf).await.unwrap(); - // dbg!(format!("manifest_buf={:?}", manifest_buf)); - - // Prepend len_buf to manifest_buf - let mut wire_struct_buf = len_buf.to_vec(); - wire_struct_buf.extend(body_buf); - - wire_struct_buf } #[cfg(test)] @@ -307,12 +254,8 @@ mod tests { origo_conn.secret_map.insert("Handshake:server_iv".to_string(), vec![1, 2, 3]); let origo_secrets = OrigoSecrets::from_origo_conn(&origo_conn); - let serialized = origo_secrets.to_bytes(); - let deserialized: OrigoSecrets = OrigoSecrets::from_bytes(&serialized).unwrap(); + let serialized: Vec = origo_secrets.try_into().unwrap(); + let deserialized: OrigoSecrets = OrigoSecrets::try_from(&serialized).unwrap(); assert_eq!(origo_secrets, deserialized); - - let wire_serialized = origo_secrets.to_wire_bytes(); - let wire_deserialized: OrigoSecrets = OrigoSecrets::from_wire_bytes(&wire_serialized); - assert_eq!(origo_secrets, wire_deserialized); } } diff --git a/client/src/origo_native.rs b/client/src/origo_native.rs index c6d03079..c4586300 100644 --- a/client/src/origo_native.rs +++ b/client/src/origo_native.rs @@ -5,12 +5,17 @@ use caratls_ekm_client::DummyTokenVerifier; use caratls_ekm_client::TeeTlsConnector; #[cfg(feature = "tee-google-confidential-space-token-verifier")] use caratls_ekm_google_confidential_space_client::GoogleConfidentialSpaceTokenVerifier; -use futures::{channel::oneshot, AsyncReadExt, AsyncWriteExt as FuturesWriteExt}; +use futures::{ + channel::oneshot, AsyncReadExt, AsyncWriteExt as FuturesWriteExt, SinkExt, StreamExt, +}; use http_body_util::{BodyExt, Full}; use hyper::{body::Bytes, Request, StatusCode}; use tls_client2::origo::OrigoConnection; use tokio::io::AsyncWriteExt; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tokio_util::{ + codec::{Framed, LengthDelimitedCodec}, + compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}, +}; use tracing::debug; use crate::{ @@ -139,7 +144,6 @@ async fn handle_origo_mode( tokio::spawn(handled_connection_fut); let response = request_sender.send_request(config.to_request()?).await?; - assert_eq!(response.status(), StatusCode::OK); let payload = response.into_body().collect().await?.to_bytes(); @@ -271,15 +275,22 @@ async fn handle_tee_mode( debug!("Waiting for magic byte, received: {:?}", buffer[0]); } - let manifest_bytes = config.proving.manifest.unwrap().to_wire_bytes(); - reunited_socket.write_all(&manifest_bytes).await?; + let mut framed_reunited_socket = + Framed::new(reunited_socket.compat(), LengthDelimitedCodec::new()); + + let manifest_bytes: Vec = config.proving.manifest.unwrap().try_into()?; + framed_reunited_socket.send(Bytes::from(manifest_bytes)).await?; + + let origo_secret = OrigoSecrets::from_origo_conn(&origo_conn); + let origo_secret_bytes: Vec = (&origo_secret).try_into().unwrap(); + framed_reunited_socket.send(Bytes::from(origo_secret_bytes)).await?; - let origo_secret_bytes = OrigoSecrets::from_origo_conn(&origo_conn).to_wire_bytes(); - reunited_socket.write_all(&origo_secret_bytes).await?; + framed_reunited_socket.flush().await?; - let tee_proof_bytes = crate::origo::read_wire_struct(&mut reunited_socket).await; - let tee_proof = TeeProof::from_wire_bytes(&tee_proof_bytes); - // reunited_socket.close().await?; + let tee_proof_frame = + framed_reunited_socket.next().await.ok_or_else(|| ClientErrors::TeeProofMissing)??; + let tee_proof = TeeProof::try_from(tee_proof_frame.as_ref())?; + debug!("TeeProof: {:?}", tee_proof); Ok((origo_conn, tee_proof)) } diff --git a/client/src/origo_wasm32.rs b/client/src/origo_wasm32.rs index 8f1114fc..86c34772 100644 --- a/client/src/origo_wasm32.rs +++ b/client/src/origo_wasm32.rs @@ -4,18 +4,21 @@ use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, - time::Duration, }; +use bytes::Bytes; #[cfg(feature = "tee-dummy-token-verifier")] use caratls_ekm_client::DummyTokenVerifier; use caratls_ekm_client::TeeTlsConnector; #[cfg(feature = "tee-google-confidential-space-token-verifier")] use caratls_ekm_google_confidential_space_client::GoogleConfidentialSpaceTokenVerifier; -use futures::{channel::oneshot, AsyncReadExt, AsyncWriteExt}; +use futures::{channel::oneshot, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt}; use hyper::StatusCode; use tls_client2::origo::OrigoConnection; -use tokio_util::compat::TokioAsyncReadCompatExt; +use tokio_util::{ + codec::{Framed, LengthDelimitedCodec}, + compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}, +}; use tracing::debug; use wasm_bindgen_futures::{spawn_local, JsFuture}; use web_sys::window; @@ -165,7 +168,7 @@ async fn handle_tee_mode( let (mut request_sender, connection) = hyper::client::conn::http1::handshake(client_tls_conn).await?; - let (connection_sender, connection_receiver) = oneshot::channel(); + let (connection_sender, _) = oneshot::channel(); let connection_fut = connection.without_shutdown(); spawn_local(async { let result = connection_fut.await; @@ -191,14 +194,22 @@ async fn handle_tee_mode( debug!("Waiting for magic byte, received: {:?}", buffer[0]); } - let manifest_bytes = config.proving.manifest.unwrap().to_wire_bytes(); - reunited_socket.write_all(&manifest_bytes).await?; + let mut framed_reunited_socket = + Framed::new(reunited_socket.compat(), LengthDelimitedCodec::new()); + + let manifest = config.proving.manifest.unwrap(); + let manifest_bytes: Vec = (&manifest).try_into().unwrap(); + framed_reunited_socket.send(Bytes::from(manifest_bytes)).await?; + + let origo_secrets = OrigoSecrets::from_origo_conn(&origo_conn); + let origo_secrets_bytes: Vec = (&origo_secrets).try_into().unwrap(); + framed_reunited_socket.send(Bytes::from(origo_secrets_bytes)).await?; - let origo_secret_bytes = OrigoSecrets::from_origo_conn(&origo_conn).to_wire_bytes(); - reunited_socket.write_all(&origo_secret_bytes).await?; + framed_reunited_socket.flush().await?; - let tee_proof_bytes = crate::origo::read_wire_struct(&mut reunited_socket).await; - let tee_proof = TeeProof::from_wire_bytes(&tee_proof_bytes); + let tee_proof_frame = + framed_reunited_socket.next().await.ok_or_else(|| ClientErrors::TeeProofMissing)??; + let tee_proof = TeeProof::try_from(tee_proof_frame.as_ref())?; // TODO something will be dropped here. if it's dropped, it closes ... // let mut client_socket = connection_receiver.await.unwrap()?.io.into_inner(); diff --git a/client_wasm/Cargo.toml b/client_wasm/Cargo.toml index d0702372..64f693d7 100644 --- a/client_wasm/Cargo.toml +++ b/client_wasm/Cargo.toml @@ -43,7 +43,13 @@ wasm-bindgen-test="0.3.34" cargo_metadata="0.19.1" [package.metadata.wasm-pack.profile.release] -wasm-opt=["-Oz", "--enable-threads", "--enable-mutable-globals", "--enable-bulk-memory", "--enable-nontrapping-float-to-int"] +wasm-opt=[ + "-Oz", + "--enable-threads", + "--enable-mutable-globals", + "--enable-bulk-memory", + "--enable-nontrapping-float-to-int", +] [package.metadata.wasm-pack.profile.dev.wasm-bindgen] dwarf-debug-info=true diff --git a/notary/src/errors.rs b/notary/src/errors.rs index a52708be..11dc028d 100644 --- a/notary/src/errors.rs +++ b/notary/src/errors.rs @@ -8,6 +8,8 @@ use thiserror::Error; use tlsn_verifier::tls::{VerifierConfigBuilderError, VerifierError}; use tracing::error; +use crate::tls_parser::Direction; + #[derive(Debug, Error)] pub enum ProxyError { #[error(transparent)] @@ -89,8 +91,8 @@ pub enum NotaryServerError { #[error(transparent)] ProofError(#[from] ProofError), - #[error("Missing application-level data messages. Expected: {0}, Actual: {1}")] - MissingAppDataMessages(usize, usize), + #[error("Missing application-level data messages. Direction={0}, expected={1}, received={2}")] + MissingAppDataMessages(Direction, usize, usize), // TODO: Update to contain feedback #[error("Manifest-request mismatch")] @@ -99,6 +101,12 @@ pub enum NotaryServerError { // TODO: Update to contain feedback #[error("Manifest-response mismatch")] ManifestResponseMismatch, + + #[error("Manifest is missing")] + ManifestMissing, + + #[error("Origo secrets are missing")] + MissingOrigoSecrets, } impl From for NotaryServerError { diff --git a/notary/src/main.rs b/notary/src/main.rs index ccf242a8..9ad6a416 100644 --- a/notary/src/main.rs +++ b/notary/src/main.rs @@ -15,7 +15,6 @@ use axum::{ use errors::NotaryServerError; use hyper::{body::Incoming, server::conn::http1}; use hyper_util::rt::TokioIo; -use k256::ecdsa::SigningKey as Secp256k1SigningKey; use p256::{ecdsa::SigningKey, elliptic_curve::rand_core, pkcs8::DecodePrivateKey}; use rustls::{ pki_types::{CertificateDer, PrivateKeyDer}, @@ -30,6 +29,8 @@ use tower_service::Service; use tracing::{error, info}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use crate::origo::OrigoSigningKey; + mod axum_websocket; mod config; mod errors; @@ -44,7 +45,7 @@ mod websocket_proxy; struct SharedState { notary_signing_key: SigningKey, - origo_signing_key: Secp256k1SigningKey, + origo_signing_key: OrigoSigningKey, tlsn_max_sent_data: usize, tlsn_max_recv_data: usize, origo_sessions: Arc>>>, @@ -101,7 +102,7 @@ async fn main() -> Result<(), NotaryServerError> { let shared_state = Arc::new(SharedState { notary_signing_key: load_notary_signing_key(&c.notary_signing_key), - origo_signing_key: load_origo_signing_key(&c.origo_signing_key), + origo_signing_key: OrigoSigningKey::load(&c.origo_signing_key), tlsn_max_sent_data: c.tlsn_max_sent_data, tlsn_max_recv_data: c.tlsn_max_recv_data, origo_sessions: Default::default(), @@ -377,21 +378,6 @@ pub fn load_notary_signing_key(private_key_pem_path: &str) -> SigningKey { pub fn ephemeral_notary_signing_key() -> SigningKey { SigningKey::random(&mut rand_core::OsRng) } -pub fn load_origo_signing_key(private_key_pem_path: &str) -> Secp256k1SigningKey { - if private_key_pem_path.is_empty() { - info!("Using ephemeral origo signing key"); - ephemeral_origo_signing_key() - } else { - info!("Using origo signing key: {}", private_key_pem_path); - let raw = fs::read_to_string(private_key_pem_path).unwrap(); - Secp256k1SigningKey::from_pkcs8_pem(&raw).unwrap() - } -} - -pub fn ephemeral_origo_signing_key() -> Secp256k1SigningKey { - Secp256k1SigningKey::random(&mut rand_core::OsRng) -} - async fn meta_keys( Path(key): Path, State(state): State>, @@ -404,7 +390,7 @@ async fn meta_keys( }, "origo.pub" => { - let vkey = state.origo_signing_key.verifying_key(); + let vkey = state.origo_signing_key.0.verifying_key(); let hex = hex::encode(vkey.to_sec1_bytes()); (StatusCode::OK, hex) diff --git a/notary/src/origo.rs b/notary/src/origo.rs index 76723852..ea62e14b 100644 --- a/notary/src/origo.rs +++ b/notary/src/origo.rs @@ -1,4 +1,7 @@ -use std::sync::{Arc, Mutex}; +use std::{ + fs, + sync::{Arc, Mutex}, +}; use alloy_primitives::utils::keccak256; use axum::{ @@ -9,6 +12,9 @@ use axum::{ use client::origo::{SignBody, SignedVerificationReply, VerifyBody, VerifyReply}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; +use k256::{ + ecdsa::SigningKey as Secp256k1SigningKey, elliptic_curve::rand_core, pkcs8::DecodePrivateKey, +}; use proofs::{ circuits::{CIRCUIT_SIZE_512, MAX_STACK_HEIGHT}, errors::ProofError, @@ -37,6 +43,23 @@ use crate::{ verifier, SharedState, }; +pub struct OrigoSigningKey(pub(crate) Secp256k1SigningKey); + +impl OrigoSigningKey { + pub fn load(private_key_pem_path: &str) -> Self { + if private_key_pem_path.is_empty() { + info!("Using ephemeral origo signing key"); + Self::ephemeral() + } else { + info!("Using origo signing key: {}", private_key_pem_path); + let raw = fs::read_to_string(private_key_pem_path).unwrap(); + Self(Secp256k1SigningKey::from_pkcs8_pem(&raw).unwrap()) + } + } + + pub fn ephemeral() -> Self { Self(Secp256k1SigningKey::random(&mut rand_core::OsRng)) } +} + #[derive(Deserialize)] pub struct SignQuery { session_id: Uuid, @@ -120,16 +143,16 @@ pub async fn sign( // need secp256k1 here for Solidity let (signature, recover_id) = - state.origo_signing_key.clone().sign_prehash_recoverable(&merkle_root).unwrap(); + state.origo_signing_key.0.sign_prehash_recoverable(&merkle_root).unwrap(); let signer_address = - alloy_primitives::Address::from_public_key(state.origo_signing_key.verifying_key()); + alloy_primitives::Address::from_public_key(state.origo_signing_key.0.verifying_key()); let verifying_key = k256::ecdsa::VerifyingKey::recover_from_prehash(&merkle_root.clone(), &signature, recover_id) .unwrap(); - assert_eq!(state.origo_signing_key.verifying_key(), &verifying_key); + assert_eq!(state.origo_signing_key.0.verifying_key(), &verifying_key); // TODO is this right? we need lower form S for sure though let s = if signature.normalize_s().is_some() { @@ -157,7 +180,7 @@ pub async fn sign( pub fn sign_verification( query: VerifyReply, State(state): State>, -) -> Result, ProxyError> { +) -> Result { // TODO check OSCP and CT (maybe) // TODO check target_name matches SNI and/or cert name (let's discuss) @@ -170,16 +193,16 @@ pub fn sign_verification( // need secp256k1 here for Solidity let (signature, recover_id) = - state.origo_signing_key.clone().sign_prehash_recoverable(&merkle_root).unwrap(); + state.origo_signing_key.0.sign_prehash_recoverable(&merkle_root).unwrap(); let signer_address = - alloy_primitives::Address::from_public_key(state.origo_signing_key.verifying_key()); + alloy_primitives::Address::from_public_key(state.origo_signing_key.0.verifying_key()); let verifying_key = k256::ecdsa::VerifyingKey::recover_from_prehash(&merkle_root.clone(), &signature, recover_id) .unwrap(); - assert_eq!(state.origo_signing_key.verifying_key(), &verifying_key); + assert_eq!(state.origo_signing_key.0.verifying_key(), &verifying_key); // TODO is this right? we need lower form S for sure though let s = if signature.normalize_s().is_some() { @@ -204,7 +227,7 @@ pub fn sign_verification( signer: "0x".to_string() + &hex::encode(signer_address), }; - Ok(Json(response)) + Ok(response) } #[derive(Clone)] @@ -390,7 +413,7 @@ pub async fn verify( debug!( "value_polynomial_digest: {:?}", polynomial_digest( - &payload.origo_proof.value.clone().unwrap().as_bytes(), + payload.origo_proof.value.clone().unwrap().as_bytes(), ciphertext_digest, 0, ) @@ -409,7 +432,7 @@ pub async fn verify( }, }; - sign_verification(verify_output, State(state)) + sign_verification(verify_output, State(state)).map(Json) } pub async fn websocket_notarize( diff --git a/notary/src/proxy.rs b/notary/src/proxy.rs index 2fe15928..9703c4bc 100644 --- a/notary/src/proxy.rs +++ b/notary/src/proxy.rs @@ -5,17 +5,14 @@ use axum::{ Json, }; use client::TeeProof; -use proofs::program::{ - http::{JsonKey, ManifestRequest, ManifestResponse, ResponseBody}, - manifest::HTTP_1_1, -}; +use proofs::program::http::{JsonKey, ManifestRequest, ManifestResponse, ResponseBody}; use reqwest::{Request, Response}; use serde::Deserialize; use serde_json::Value; -use tracing::{debug, info}; +use tracing::info; use uuid::Uuid; -use crate::{errors::NotaryServerError, SharedState}; +use crate::{errors::NotaryServerError, tee::create_tee_proof, SharedState}; #[derive(Deserialize)] pub struct NotarizeQuery { @@ -51,18 +48,7 @@ pub async fn proxy( let response = from_reqwest_response(reqwest_response).await; // debug!("{:?}", response); - if !payload.manifest.request.is_subset_of(&request) { - return Err(NotaryServerError::ManifestRequestMismatch); - } - - if !payload.manifest.response.is_subset_of(&response) { - return Err(NotaryServerError::ManifestResponseMismatch); - } - - // TODO: Maybe move that to `TeeProof::from_manifest`? - payload.manifest.validate()?; - - let tee_proof = TeeProof::from_manifest(&payload.manifest); + let tee_proof = create_tee_proof(&payload.manifest, &request, &response, State(state))?; Ok(Json(tee_proof)) } diff --git a/notary/src/tee.rs b/notary/src/tee.rs index 53caa3ce..f3f7fb90 100644 --- a/notary/src/tee.rs +++ b/notary/src/tee.rs @@ -9,8 +9,12 @@ use caratls_ekm_google_confidential_space_server::GoogleConfidentialSpaceTokenGe #[cfg(feature = "tee-dummy-token-generator")] use caratls_ekm_server::DummyTokenGenerator; use caratls_ekm_server::TeeTlsAcceptor; -use client::{errors::ClientErrors, origo::OrigoSecrets, TeeProof, TeeProofData}; -use hyper::upgrade::Upgraded; +use client::{ + origo::{OrigoSecrets, VerifyReply}, + TeeProof, TeeProofData, +}; +use futures_util::SinkExt; +use hyper::{body::Bytes, upgrade::Upgraded}; use hyper_util::rt::TokioIo; use proofs::program::{ http::{ManifestRequest, ManifestResponse}, @@ -19,18 +23,22 @@ use proofs::program::{ use serde::Deserialize; use tls_client2::tls_core::msgs::message::MessagePayload; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio_util::compat::FuturesAsyncReadCompatExt; -use tracing::{debug, error, field::debug, info}; +use tokio_stream::StreamExt; +use tokio_util::{ + codec::{Framed, LengthDelimitedCodec}, + compat::FuturesAsyncReadCompatExt, +}; +use tracing::{debug, error, info}; use uuid::Uuid; use ws_stream_tungstenite::WsStream; use crate::{ axum_websocket::WebSocket, errors::NotaryServerError, - origo::proxy_service, + origo::{proxy_service, sign_verification}, tls_parser::{Direction, ParsedMessage, WrappedPayload}, tlsn::ProtocolUpgrade, - verifier, SharedState, + SharedState, }; #[derive(Deserialize)] @@ -148,15 +156,16 @@ pub async fn tee_proxy_service( debug!("Sending magic byte to indicate readiness to read"); tee_tls_stream.write_all(&[0xAA]).await?; - // TODO: Consider implementing from_stream instead of read_wire_struct - let manifest_bytes = read_wire_struct(&mut tee_tls_stream).await; - let manifest = Manifest::from_wire_bytes(&manifest_bytes); - // Checking if manifest is valid - manifest.validate()?; + let mut framed_stream = Framed::new(tee_tls_stream, LengthDelimitedCodec::new()); + + let manifest_frame = + framed_stream.next().await.ok_or_else(|| NotaryServerError::ManifestMissing)??; + let manifest = Manifest::try_from(manifest_frame.as_ref())?; // dbg!(&manifest); - let secret_bytes = read_wire_struct(&mut tee_tls_stream).await; - let origo_secrets = OrigoSecrets::from_wire_bytes(&secret_bytes); + let secret_frame = + framed_stream.next().await.ok_or_else(|| NotaryServerError::MissingOrigoSecrets)??; + let origo_secrets = OrigoSecrets::try_from(secret_frame.as_ref())?; // dbg!(&origo_secrets); let handshake_server_key = @@ -185,46 +194,90 @@ pub async fn tee_proxy_service( .unwrap(); // dbg!(parsed_transcript); - let mut app_data_vec = Vec::new(); - for message in &parsed_transcript.payload { - if let ParsedMessage { payload, direction, .. } = message { - if let Some(app_data) = get_app_data(payload) { - // if let Ok(readable_data) = String::from_utf8(app_data.clone()) { - // debug!("{:?} app_data: {}", direction, readable_data); - // } - app_data_vec.push(app_data); + let (request, response) = extract_request_and_response(&parsed_transcript.payload)?; + + // send TeeProof to client + let tee_proof = create_tee_proof(&manifest, &request, &response, State(state))?; + let tee_proof_bytes: Vec = tee_proof.try_into()?; + framed_stream.send(Bytes::from(tee_proof_bytes)).await?; + framed_stream.flush().await?; + + Ok(()) +} + +pub fn extract_request_and_response( + parsed_transcript: &[ParsedMessage], +) -> Result<(ManifestRequest, ManifestResponse), NotaryServerError> { + let mut request_header = None; + let mut request_body = None; + let mut response_header = None; + let mut response_body = None; + + // Classify parsed messages into headers and bodies + for ParsedMessage { payload, direction, .. } in parsed_transcript { + if let Some(app_data) = get_app_data(payload) { + debug!("App data message {:?}: {:?}", direction, String::from_utf8_lossy(&app_data)); + match direction { + Direction::Sent if request_header.is_none() => request_header = Some(app_data), + Direction::Sent => request_body = Some(app_data), + Direction::Received if response_header.is_none() => response_header = Some(app_data), + Direction::Received => response_body = Some(app_data), } } } - if app_data_vec.len() != 3 { - return Err(NotaryServerError::MissingAppDataMessages(3, app_data_vec.len())); - } + // Ensure mandatory headers are present + let request_header = + request_header.ok_or(NotaryServerError::MissingAppDataMessages(Direction::Sent, 2, 0))?; + let response_header = + response_header.ok_or(NotaryServerError::MissingAppDataMessages(Direction::Received, 2, 0))?; + + let request = ManifestRequest::from_payload(&request_header, request_body.as_deref())?; + let response = + ManifestResponse::from_payload(&response_header, response_body.as_deref().unwrap_or_default())?; - let request_header = app_data_vec[0].clone(); - // TODO: Do we expect to get request_body as well part of app_data? - let response_header = app_data_vec[1].clone(); - let response_body = app_data_vec[2].clone(); + Ok((request, response)) +} + +// TODO: Should TeeProof and other proofs be moved to `proofs` crate? +// Otherwise, adding TeeProof::manifest necessitates extra dependencies on the client +// Notice that all inputs to this function are from `proofs` crate +pub fn create_tee_proof( + manifest: &Manifest, + request: &ManifestRequest, + response: &ManifestResponse, + State(state): State>, +) -> Result { + validate_notarization_legal(manifest, request, response)?; - let request = ManifestRequest::from_payload(&request_header, None)?; - debug!("{:?}", request); + let manifest_hash = manifest.to_keccak_digest()?; + let to_sign = VerifyReply { + // Using manifest hash as a value here since we are not exposing any values extracted + // from the request or response + value: format!("0x{}", hex::encode(manifest_hash)), + manifest: manifest.clone(), + }; + let signature = sign_verification(to_sign, State(state)).unwrap(); - let response = ManifestResponse::from_payload(&response_header, &response_body)?; - debug!("{:?}", response); + let data = TeeProofData { manifest_hash: manifest_hash.to_vec() }; + Ok(TeeProof { data, signature }) +} + +/// Check if `manifest`, `request`, and `response` all fulfill requirements necessary for +/// a proof to be created +fn validate_notarization_legal( + manifest: &Manifest, + request: &ManifestRequest, + response: &ManifestResponse, +) -> Result<(), NotaryServerError> { + manifest.validate()?; if !manifest.request.is_subset_of(&request) { return Err(NotaryServerError::ManifestRequestMismatch); } - if !manifest.response.is_subset_of(&response) { return Err(NotaryServerError::ManifestResponseMismatch); } - - // send TeeProof to client - let tee_proof = TeeProof::from_manifest(&manifest); - let tee_proof_bytes = tee_proof.to_write_bytes(); - tee_tls_stream.write_all(&tee_proof_bytes).await?; - Ok(()) } @@ -237,26 +290,3 @@ fn get_app_data(payload: &WrappedPayload) -> Option> { _ => None, } } - -// TODO: Refactor into struct helpers/trait -async fn read_wire_struct(stream: &mut R) -> Vec { - // Buffer to store the "header" (4 bytes, indicating the length of the struct) - let mut len_buf = [0u8; 4]; - stream.read_exact(&mut len_buf).await.unwrap(); - // dbg!(format!("len_buf={:?}", len_buf)); - - // Deserialize the length prefix (convert from little-endian to usize) - let body_len = u32::from_le_bytes(len_buf) as usize; - // dbg!(format!("body_len={body_len}")); - - // Allocate a buffer to hold only the bytes needed for the struct - let mut body_buf = vec![0u8; body_len]; - stream.read_exact(&mut body_buf).await.unwrap(); - // dbg!(format!("body_buf={:?}", body_buf)); - - // Prepend len_buf to manifest_buf - let mut wire_struct_buf = len_buf.to_vec(); - wire_struct_buf.extend(body_buf); - - wire_struct_buf -} diff --git a/notary/src/tls_parser.rs b/notary/src/tls_parser.rs index 732779e8..f31500ef 100644 --- a/notary/src/tls_parser.rs +++ b/notary/src/tls_parser.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, io::Cursor}; +use std::io::Cursor; use nom::{bytes::streaming::take, IResult}; use tls_client2::{ @@ -34,6 +34,15 @@ pub enum Direction { Received, } +impl std::fmt::Display for Direction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Direction::Sent => write!(f, "Sent"), + Direction::Received => write!(f, "Received"), + } + } +} + #[derive(Debug)] pub enum WrappedPayload { Encrypted(OpaqueMessage), diff --git a/proofs/src/program/http.rs b/proofs/src/program/http.rs index 56ad44d2..b8eb2b7b 100644 --- a/proofs/src/program/http.rs +++ b/proofs/src/program/http.rs @@ -576,7 +576,6 @@ pub(crate) mod tests { br#"{"key1": "value1"}"#, ) .unwrap(); - let expected_response = create_response!( body: ResponseBody { json: vec![JsonKey::String("key1".to_string())] diff --git a/proofs/src/program/manifest.rs b/proofs/src/program/manifest.rs index 24fb7af0..a0e047d7 100644 --- a/proofs/src/program/manifest.rs +++ b/proofs/src/program/manifest.rs @@ -82,6 +82,24 @@ pub struct Manifest { pub response: ManifestResponse, } +impl TryFrom<&[u8]> for Manifest { + type Error = serde_json::Error; + + fn try_from(bytes: &[u8]) -> Result { serde_json::from_slice(bytes) } +} + +impl TryFrom<&Manifest> for Vec { + type Error = serde_json::Error; + + fn try_from(manifest: &Manifest) -> Result { serde_json::to_vec(manifest) } +} + +impl TryFrom for Vec { + type Error = serde_json::Error; + + fn try_from(manifest: Manifest) -> Result { serde_json::to_vec(&manifest) } +} + impl Manifest { /// Validates `Manifest` request and response fields. They are validated against valid statuses, /// http methods, and template variables. @@ -98,62 +116,16 @@ impl Manifest { Ok(()) } - /// Serializes the `Manifest` into a length-prefixed byte array. - pub fn to_wire_bytes(&self) -> Vec { - let serialized = self.to_bytes(); - let length = serialized.len() as u32; - // Create the "header" with the length (as little-endian bytes) - let mut wire_data = length.to_le_bytes().to_vec(); - wire_data.extend(serialized); - wire_data - } - - /// Deserializes a `Manifest` from a length-prefixed byte buffer. - /// - /// Expects a buffer with a 4-byte little-endian "header" followed by the serialized data. - pub fn from_wire_bytes(buffer: &[u8]) -> Self { - // Confirm the buffer is at least large enough to contain the "header" - if buffer.len() < 4 { - panic!("Unexpected buffer length: {} < 4", buffer.len()); - } - - // Extract the first 4 bytes as the length prefix - let length_bytes = &buffer[..4]; - let length = u32::from_le_bytes(length_bytes.try_into().unwrap()) as usize; - - // Ensure the buffer contains enough data for the length specified - if buffer.len() < 4 + length { - panic!("Unexpected buffer length: {} < {} + 4", buffer.len(), length); - } - - // Extract the serialized data from the buffer - let serialized_data = &buffer[4..4 + length]; - Self::from_bytes(serialized_data) - } - - /// Serializes the `Manifest` to raw bytes. - /// - /// Doesn't expect a "wire" header. - fn to_bytes(&self) -> Vec { - // Serializing as JSON because `untagged` in `JsonKey` break bincode - serde_json::to_vec(&self).unwrap() - } - - /// Deserializes a `Manifest` from raw bytes. - /// - /// Doesn't expect a "wire" header. - fn from_bytes(bytes: &[u8]) -> Manifest { serde_json::from_slice(bytes).unwrap() } - /// Compute a `Keccak256` hash of the serialized Manifest - pub fn to_keccak_digest(&self) -> [u8; 32] { - let bytes = self.to_bytes(); + pub fn to_keccak_digest(&self) -> Result<[u8; 32], ProofError> { + let as_bytes: Vec = self.try_into()?; let mut hasher = Keccak::v256(); let mut output = [0u8; 32]; - hasher.update(&bytes); + hasher.update(&as_bytes); hasher.finalize(&mut output); - output + Ok(output) } } @@ -1015,14 +987,9 @@ mod tests { #[test] fn test_manifest_serialization() { let manifest: Manifest = serde_json::from_str(TEST_MANIFEST).unwrap(); - - let serialized = manifest.to_bytes(); - let deserialized: Manifest = Manifest::from_bytes(&serialized); + let serialized: Vec = manifest.clone().try_into().unwrap(); + let deserialized = Manifest::try_from(serialized.as_slice()).unwrap(); assert_eq!(manifest, deserialized); - - let wire_serialized = manifest.to_wire_bytes(); - let wire_deserialized: Manifest = Manifest::from_wire_bytes(&wire_serialized); - assert_eq!(manifest, wire_deserialized); } macro_rules! create_manifest {