1use std::future::Future;
6
7use std::pin::Pin;
8
9use futures::{future, FutureExt};
10use http_body_util::BodyExt;
11use hyper::header;
12use jsonrpsee::{
13 core::BoxError,
14 server::{HttpBody, HttpRequest, HttpResponse},
15};
16use jsonrpsee_types::ErrorObject;
17use serde::{Deserialize, Serialize};
18use tower::Service;
19
20use super::cookie::Cookie;
21
22use base64::{engine::general_purpose::URL_SAFE, Engine as _};
23
/// HTTP middleware placed in front of a wrapped service.
///
/// For each request it checks cookie credentials (when a cookie is configured),
/// forces a JSON `Content-Type` when the header is missing or `text/plain`, and
/// rewrites recognized pre-2.0 JSON-RPC requests to JSON-RPC 2.0 before calling
/// the wrapped service, converting the response back to the client's form.
#[derive(Clone, Debug)]
pub struct HttpRequestMiddleware<S> {
    /// The wrapped service that receives the (possibly rewritten) requests.
    service: S,
    /// The authentication cookie; `None` disables credential checks.
    cookie: Option<Cookie>,
}
60
61impl<S> HttpRequestMiddleware<S> {
62 pub fn new(service: S, cookie: Option<Cookie>) -> Self {
64 Self { service, cookie }
65 }
66
67 pub fn check_credentials(&self, headers: &header::HeaderMap) -> bool {
69 self.cookie.as_ref().is_none_or(|internal_cookie| {
70 headers
71 .get(header::AUTHORIZATION)
72 .and_then(|auth_header| auth_header.to_str().ok())
73 .and_then(|auth_header| auth_header.split_whitespace().nth(1))
74 .and_then(|encoded| URL_SAFE.decode(encoded).ok())
75 .and_then(|decoded| String::from_utf8(decoded).ok())
76 .and_then(|request_cookie| request_cookie.split(':').nth(1).map(String::from))
77 .is_some_and(|passwd| internal_cookie.authenticate(passwd))
78 })
79 }
80
81 pub fn insert_or_replace_content_type_header(headers: &mut header::HeaderMap) {
103 if !headers.contains_key(header::CONTENT_TYPE)
104 || headers
105 .get(header::CONTENT_TYPE)
106 .filter(|value| {
107 value
108 .to_str()
109 .ok()
110 .unwrap_or_default()
111 .starts_with("text/plain")
112 })
113 .is_some()
114 {
115 headers.insert(
116 header::CONTENT_TYPE,
117 header::HeaderValue::from_static("application/json"),
118 );
119 }
120 }
121
122 async fn request_to_json_rpc_2(
124 request: HttpRequest<HttpBody>,
125 ) -> (JsonRpcVersion, HttpRequest<HttpBody>) {
126 let (parts, body) = request.into_parts();
127 let bytes = body
128 .collect()
129 .await
130 .expect("Failed to collect body data")
131 .to_bytes();
132 let (version, bytes) =
133 if let Ok(request) = serde_json::from_slice::<'_, JsonRpcRequest>(bytes.as_ref()) {
134 let version = request.version();
135 if matches!(version, JsonRpcVersion::Unknown) {
136 (version, bytes)
137 } else {
138 (
139 version,
140 serde_json::to_vec(&request.into_2()).expect("valid").into(),
141 )
142 }
143 } else {
144 (JsonRpcVersion::Unknown, bytes)
145 };
146 (
147 version,
148 HttpRequest::from_parts(parts, HttpBody::from(bytes.as_ref().to_vec())),
149 )
150 }
151 async fn response_from_json_rpc_2(
153 version: JsonRpcVersion,
154 response: HttpResponse<HttpBody>,
155 ) -> HttpResponse<HttpBody> {
156 let (parts, body) = response.into_parts();
157 let bytes = body
158 .collect()
159 .await
160 .expect("Failed to collect body data")
161 .to_bytes();
162 let bytes =
163 if let Ok(response) = serde_json::from_slice::<'_, JsonRpcResponse>(bytes.as_ref()) {
164 serde_json::to_vec(&response.into_version(version))
165 .expect("valid")
166 .into()
167 } else {
168 bytes
169 };
170 HttpResponse::from_parts(parts, HttpBody::from(bytes.as_ref().to_vec()))
171 }
172}
173
/// A [`tower::Layer`] that wraps services in [`HttpRequestMiddleware`], giving
/// each one a clone of the configured cookie.
#[derive(Clone)]
pub struct HttpRequestMiddlewareLayer {
    /// The cookie handed to every middleware; `None` disables authentication.
    cookie: Option<Cookie>,
}
179
180impl HttpRequestMiddlewareLayer {
181 pub fn new(cookie: Option<Cookie>) -> Self {
183 Self { cookie }
184 }
185}
186
187impl<S> tower::Layer<S> for HttpRequestMiddlewareLayer {
188 type Service = HttpRequestMiddleware<S>;
189
190 fn layer(&self, service: S) -> Self::Service {
191 HttpRequestMiddleware::new(service, self.cookie.clone())
192 }
193}
194
/// Builder-style configuration: consumes a value and returns it updated with a `T`.
pub trait With<T> {
    /// Returns `self` updated with `value`.
    fn with(self, value: T) -> Self;
}
200
201impl<S> With<Cookie> for HttpRequestMiddleware<S> {
202 fn with(mut self, cookie: Cookie) -> Self {
203 self.cookie = Some(cookie);
204 self
205 }
206}
207
208impl<S> Service<HttpRequest<HttpBody>> for HttpRequestMiddleware<S>
209where
210 S: Service<HttpRequest, Response = HttpResponse> + std::clone::Clone + Send + 'static,
211 S::Error: Into<BoxError> + 'static,
212 S::Future: Send + 'static,
213{
214 type Response = S::Response;
215 type Error = BoxError;
216 type Future =
217 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
218
219 fn poll_ready(
220 &mut self,
221 cx: &mut std::task::Context<'_>,
222 ) -> std::task::Poll<Result<(), Self::Error>> {
223 self.service.poll_ready(cx).map_err(Into::into)
224 }
225
226 fn call(&mut self, mut request: HttpRequest<HttpBody>) -> Self::Future {
227 if !self.check_credentials(request.headers_mut()) {
229 let error = ErrorObject::borrowed(401, "unauthenticated method", None);
230 return future::err(BoxError::from(error)).boxed();
232 }
233
234 Self::insert_or_replace_content_type_header(request.headers_mut());
236
237 let mut service = self.service.clone();
238
239 async move {
240 let (version, request) = Self::request_to_json_rpc_2(request).await;
241 let response = service.call(request).await.map_err(Into::into)?;
242 Ok(Self::response_from_json_rpc_2(version, response).await)
243 }
244 .boxed()
245 }
246}
247
/// The JSON-RPC form a request was recognized as; used to convert the response
/// back into the same form.
#[derive(Clone, Copy, Debug)]
enum JsonRpcVersion {
    /// No `jsonrpc` member, with both `params` and `id` present.
    Bitcoind,
    /// `"jsonrpc": "1.0"`, with both `params` and `id` present.
    Lightwalletd,
    /// `"jsonrpc": "2.0"` with an `id` that is absent, null, a string or a number.
    TwoPointZero,
    /// Anything else; the request body is forwarded as received.
    Unknown,
}
260
261#[derive(Debug, Deserialize, Serialize)]
263struct JsonRpcRequest {
264 #[serde(skip_serializing_if = "Option::is_none")]
265 jsonrpc: Option<String>,
266 method: String,
267 #[serde(skip_serializing_if = "Option::is_none")]
268 params: Option<serde_json::Value>,
269 #[serde(skip_serializing_if = "Option::is_none")]
270 id: Option<serde_json::Value>,
271}
272
273impl JsonRpcRequest {
274 fn version(&self) -> JsonRpcVersion {
275 match (self.jsonrpc.as_deref(), &self.params, &self.id) {
276 (
277 Some("2.0"),
278 _,
279 None
280 | Some(
281 serde_json::Value::Null
282 | serde_json::Value::String(_)
283 | serde_json::Value::Number(_),
284 ),
285 ) => JsonRpcVersion::TwoPointZero,
286 (Some("1.0"), Some(_), Some(_)) => JsonRpcVersion::Lightwalletd,
287 (None, Some(_), Some(_)) => JsonRpcVersion::Bitcoind,
288 _ => JsonRpcVersion::Unknown,
289 }
290 }
291
292 fn into_2(mut self) -> Self {
293 self.jsonrpc = Some("2.0".into());
294 self
295 }
296}
/// A single JSON-RPC response object, as returned by the wrapped service.
///
/// Deserialization fails when there is no `id` member. Members that are `None`
/// are omitted when the response is serialized.
#[derive(Debug, Deserialize, Serialize)]
struct JsonRpcResponse {
    /// The `jsonrpc` protocol tag, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    jsonrpc: Option<String>,
    /// The identifier of the request this response answers.
    id: serde_json::Value,
    /// The call's result, kept as raw JSON text so it is re-emitted without being
    /// parsed into a `serde_json::Value`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Box<serde_json::value::RawValue>>,
    /// The error object, if the call failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<serde_json::Value>,
}
308
309impl JsonRpcResponse {
310 fn into_version(mut self, version: JsonRpcVersion) -> Self {
311 match version {
312 JsonRpcVersion::Bitcoind => {
313 self.jsonrpc = None;
314 self.result = self
315 .result
316 .or_else(|| serde_json::value::to_raw_value(&()).ok());
317 self.error = self.error.or(Some(serde_json::Value::Null));
318 }
319 JsonRpcVersion::Lightwalletd => {
320 self.jsonrpc = Some("1.0".into());
321 self.result = self
322 .result
323 .or_else(|| serde_json::value::to_raw_value(&()).ok());
324 self.error = self.error.or(Some(serde_json::Value::Null));
325 }
326 JsonRpcVersion::TwoPointZero => {
327 assert_eq!(self.jsonrpc.as_deref(), Some("2.0"));
331 if self.error.is_none() {
332 self.result = self
333 .result
334 .or_else(|| serde_json::value::to_raw_value(&()).ok());
335 } else {
336 assert!(self.result.is_none());
337 }
338 }
339 JsonRpcVersion::Unknown => (),
340 }
341 self
342 }
343}