use std::future::Future;
use futures::future::{Either, FutureExt};
use jsonrpc_core::{
middleware::Middleware,
types::{Call, Failure, Output, Response},
BoxFuture, ErrorCode, Metadata, MethodCall, Notification,
};
use crate::constants::{INVALID_PARAMETERS_ERROR_CODE, MAX_PARAMS_LOG_LENGTH};
/// JSON-RPC middleware that post-processes the output of every call.
///
/// Failed outputs whose error code is `ErrorCode::InvalidParams` have that code replaced
/// with `INVALID_PARAMETERS_ERROR_CODE`, and every failed output is logged together with
/// a short description of the call that produced it.
pub struct FixRpcResponseMiddleware;
impl<M: Metadata> Middleware<M> for FixRpcResponseMiddleware {
    type Future = BoxFuture<Option<Response>>;
    type CallFuture = BoxFuture<Option<Output>>;

    /// Forwards `call` to `next`, then post-processes the produced output.
    ///
    /// Once the inner future resolves, the output's error code is fixed up first and the
    /// (possibly rewritten) failure is logged second. The returned future is always the
    /// boxed `Either::Left` variant.
    fn on_call<Next, NextFuture>(
        &self,
        call: Call,
        meta: M,
        next: Next,
    ) -> Either<Self::CallFuture, NextFuture>
    where
        Next: Fn(Call, M) -> NextFuture + Send + Sync,
        NextFuture: Future<Output = Option<Output>> + Send + 'static,
    {
        // `next` consumes the call, so keep a copy around for the error log.
        let logged_call = call.clone();
        let pending = next(call, meta);

        let processed = async move {
            let mut output = pending.await;
            Self::fix_error_codes(&mut output);
            Self::log_if_error(&output, logged_call);
            output
        };

        Either::Left(processed.boxed())
    }
}
impl FixRpcResponseMiddleware {
    /// Rewrites the error code of a failed `output` in place.
    ///
    /// A failure whose code is [`ErrorCode::InvalidParams`] gets its code replaced with
    /// [`INVALID_PARAMETERS_ERROR_CODE`], and the replacement is logged at debug level.
    /// Successful outputs, `None`, and failures with any other code are left untouched.
    fn fix_error_codes(output: &mut Option<Output>) {
        if let Some(Output::Failure(Failure { error, .. })) = output {
            if matches!(error.code, ErrorCode::InvalidParams) {
                let original_code = error.code.clone();
                error.code = INVALID_PARAMETERS_ERROR_CODE;
                tracing::debug!("Replacing RPC error: {original_code:?} with {error}");
            }
        }
    }

    /// Builds a short, log-friendly description of `call`.
    ///
    /// Method calls and notifications are described by their method name and their
    /// debug-formatted parameters, shortened to at most [`MAX_PARAMS_LOG_LENGTH`] bytes
    /// (plus a trailing `...` when something was cut). Invalid calls are described as
    /// `"invalid request"`.
    fn call_description(call: &Call) -> String {
        match call {
            Call::MethodCall(MethodCall { method, params, .. }) => {
                let params = Self::params_for_log(params);
                format!(r#"method = {method:?}, params = {params}"#)
            }
            Call::Notification(Notification { method, params, .. }) => {
                let params = Self::params_for_log(params);
                format!(r#"notification = {method:?}, params = {params}"#)
            }
            Call::Invalid { .. } => "invalid request".to_owned(),
        }
    }

    /// Debug-formats `params` and shortens the text to at most [`MAX_PARAMS_LOG_LENGTH`]
    /// bytes, appending `...` only when text was actually removed.
    fn params_for_log(params: &impl std::fmt::Debug) -> String {
        let mut text = format!("{params:?}");

        if text.len() > MAX_PARAMS_LOG_LENGTH {
            // `String::truncate` panics when the cut falls inside a multi-byte character,
            // and the debug output can contain non-ASCII text, so back up to the nearest
            // char boundary. Offset 0 is always a boundary, so this loop terminates.
            let mut cut = MAX_PARAMS_LOG_LENGTH;
            while !text.is_char_boundary(cut) {
                cut -= 1;
            }

            text.truncate(cut);
            text.push_str("...");
        }

        text
    }

    /// Logs an info-level message if `output` is a failure.
    ///
    /// The message contains the error and a description of the `call` that caused it.
    /// Nothing is logged for successful outputs or `None`.
    fn log_if_error(output: &Option<Output>, call: Call) {
        if let Some(Output::Failure(Failure { error, .. })) = output {
            let call_description = Self::call_description(&call);
            tracing::info!("RPC error: {error} in call: {call_description}");
        }
    }
}