1use std::{
9 cmp::Ordering,
10 fmt,
11 hash::{Hash, Hasher},
12 marker::PhantomData,
13 ops::RangeInclusive,
14};
15
16use crate::serialization::{ZcashDeserialize, ZcashSerialize};
17use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
18
19#[cfg(any(test, feature = "proptest-impl"))]
20pub mod arbitrary;
21
22#[cfg(test)]
23mod tests;
24
/// A `Result` whose error type defaults to this module's amount [`Error`].
///
/// The second parameter can still be overridden, e.g. `Result<(), std::io::Error>`.
pub type Result<T, E = Error> = std::result::Result<T, E>;
27
/// A zatoshi amount whose allowed range of values is set by the constraint type `C`
/// (see [`Constraint`]); `C` defaults to [`NegativeAllowed`].
///
/// Serde deserializes through `TryFrom<i64>` (so the value is checked against `C`)
/// and serializes through `Into<i64>`.
#[derive(Clone, Copy, Serialize, Deserialize, Default)]
#[serde(try_from = "i64")]
#[serde(into = "i64")]
#[serde(bound = "C: Constraint + Clone")]
pub struct Amount<C = NegativeAllowed>(
    /// The amount in zatoshis.
    i64,
    /// Zero-sized marker carrying the constraint type; skipped by serde.
    #[serde(skip)]
    PhantomData<C>,
);
49
50impl<C> fmt::Display for Amount<C> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 let zats = self.zatoshis();
53
54 f.pad_integral(zats > 0, "", &zats.to_string())
55 }
56}
57
58impl<C> fmt::Debug for Amount<C> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_tuple(&format!("Amount<{}>", std::any::type_name::<C>()))
61 .field(&self.0)
62 .finish()
63 }
64}
65
66impl<C> Amount<C> {
67 pub fn constrain<C2>(self) -> Result<Amount<C2>>
69 where
70 C2: Constraint,
71 {
72 self.0.try_into()
73 }
74
75 pub fn zatoshis(&self) -> i64 {
77 self.0
78 }
79
80 pub fn to_bytes(&self) -> [u8; 8] {
82 let mut buf: [u8; 8] = [0; 8];
83 LittleEndian::write_i64(&mut buf, self.0);
84 buf
85 }
86
87 pub fn from_bytes(bytes: [u8; 8]) -> Result<Amount<C>>
89 where
90 C: Constraint,
91 {
92 let amount = i64::from_le_bytes(bytes);
93 amount.try_into()
94 }
95
96 pub fn zero() -> Amount<C>
98 where
99 C: Constraint,
100 {
101 0.try_into().expect("an amount of 0 is always valid")
102 }
103}
104
105impl<C> std::ops::Add<Amount<C>> for Amount<C>
106where
107 C: Constraint,
108{
109 type Output = Result<Amount<C>>;
110
111 fn add(self, rhs: Amount<C>) -> Self::Output {
112 let value = self
113 .0
114 .checked_add(rhs.0)
115 .expect("adding two constrained Amounts is always within an i64");
116 value.try_into()
117 }
118}
119
120impl<C> std::ops::Add<Amount<C>> for Result<Amount<C>>
121where
122 C: Constraint,
123{
124 type Output = Result<Amount<C>>;
125
126 fn add(self, rhs: Amount<C>) -> Self::Output {
127 self? + rhs
128 }
129}
130
131impl<C> std::ops::Add<Result<Amount<C>>> for Amount<C>
132where
133 C: Constraint,
134{
135 type Output = Result<Amount<C>>;
136
137 fn add(self, rhs: Result<Amount<C>>) -> Self::Output {
138 self + rhs?
139 }
140}
141
142impl<C> std::ops::AddAssign<Amount<C>> for Result<Amount<C>>
143where
144 Amount<C>: Copy,
145 C: Constraint,
146{
147 fn add_assign(&mut self, rhs: Amount<C>) {
148 if let Ok(lhs) = *self {
149 *self = lhs + rhs;
150 }
151 }
152}
153
154impl<C> std::ops::Sub<Amount<C>> for Amount<C>
155where
156 C: Constraint,
157{
158 type Output = Result<Amount<C>>;
159
160 fn sub(self, rhs: Amount<C>) -> Self::Output {
161 let value = self
162 .0
163 .checked_sub(rhs.0)
164 .expect("subtracting two constrained Amounts is always within an i64");
165 value.try_into()
166 }
167}
168
169impl<C> std::ops::Sub<Amount<C>> for Result<Amount<C>>
170where
171 C: Constraint,
172{
173 type Output = Result<Amount<C>>;
174
175 fn sub(self, rhs: Amount<C>) -> Self::Output {
176 self? - rhs
177 }
178}
179
180impl<C> std::ops::Sub<Result<Amount<C>>> for Amount<C>
181where
182 C: Constraint,
183{
184 type Output = Result<Amount<C>>;
185
186 fn sub(self, rhs: Result<Amount<C>>) -> Self::Output {
187 self - rhs?
188 }
189}
190
191impl<C> std::ops::SubAssign<Amount<C>> for Result<Amount<C>>
192where
193 Amount<C>: Copy,
194 C: Constraint,
195{
196 fn sub_assign(&mut self, rhs: Amount<C>) {
197 if let Ok(lhs) = *self {
198 *self = lhs - rhs;
199 }
200 }
201}
202
203impl<C> From<Amount<C>> for i64 {
204 fn from(amount: Amount<C>) -> Self {
205 amount.0
206 }
207}
208
209impl From<Amount<NonNegative>> for u64 {
210 fn from(amount: Amount<NonNegative>) -> Self {
211 amount.0.try_into().expect("non-negative i64 fits in u64")
212 }
213}
214
215impl<C> From<Amount<C>> for jubjub::Fr {
216 fn from(a: Amount<C>) -> jubjub::Fr {
217 if a.0 < 0 {
219 let abs_amount = i128::from(a.0)
220 .checked_abs()
221 .expect("absolute i64 fits in i128");
222 let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
223
224 jubjub::Fr::from(abs_amount).neg()
225 } else {
226 jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64"))
227 }
228 }
229}
230
231impl<C> From<Amount<C>> for halo2::pasta::pallas::Scalar {
232 fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar {
233 if a.0 < 0 {
235 let abs_amount = i128::from(a.0)
236 .checked_abs()
237 .expect("absolute i64 fits in i128");
238 let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
239
240 halo2::pasta::pallas::Scalar::from(abs_amount).neg()
241 } else {
242 halo2::pasta::pallas::Scalar::from(
243 u64::try_from(a.0).expect("non-negative i64 fits in u64"),
244 )
245 }
246 }
247}
248
249impl<C> TryFrom<i32> for Amount<C>
250where
251 C: Constraint,
252{
253 type Error = Error;
254
255 fn try_from(value: i32) -> Result<Self, Self::Error> {
256 C::validate(value.into()).map(|v| Self(v, PhantomData))
257 }
258}
259
260impl<C> TryFrom<i64> for Amount<C>
261where
262 C: Constraint,
263{
264 type Error = Error;
265
266 fn try_from(value: i64) -> Result<Self, Self::Error> {
267 C::validate(value).map(|v| Self(v, PhantomData))
268 }
269}
270
271impl<C> TryFrom<u64> for Amount<C>
272where
273 C: Constraint,
274{
275 type Error = Error;
276
277 fn try_from(value: u64) -> Result<Self, Self::Error> {
278 let value = value.try_into().map_err(|source| Error::Convert {
279 value: value.into(),
280 source,
281 })?;
282
283 C::validate(value).map(|v| Self(v, PhantomData))
284 }
285}
286
287impl<C> TryFrom<i128> for Amount<C>
291where
292 C: Constraint,
293{
294 type Error = Error;
295
296 fn try_from(value: i128) -> Result<Self, Self::Error> {
297 let value = value
298 .try_into()
299 .map_err(|source| Error::Convert { value, source })?;
300
301 C::validate(value).map(|v| Self(v, PhantomData))
302 }
303}
304
305impl<C> Hash for Amount<C> {
306 fn hash<H: Hasher>(&self, state: &mut H) {
308 self.0.hash(state);
309 }
310}
311
312impl<C1, C2> PartialEq<Amount<C2>> for Amount<C1> {
313 fn eq(&self, other: &Amount<C2>) -> bool {
314 self.0.eq(&other.0)
315 }
316}
317
318impl<C> PartialEq<i64> for Amount<C> {
319 fn eq(&self, other: &i64) -> bool {
320 self.0.eq(other)
321 }
322}
323
324impl<C> PartialEq<Amount<C>> for i64 {
325 fn eq(&self, other: &Amount<C>) -> bool {
326 self.eq(&other.0)
327 }
328}
329
330impl<C> Eq for Amount<C> {}
331
332impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
333 fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> {
334 Some(self.0.cmp(&other.0))
335 }
336}
337
338impl<C> Ord for Amount<C> {
339 fn cmp(&self, other: &Amount<C>) -> Ordering {
340 self.0.cmp(&other.0)
341 }
342}
343
344impl<C> std::ops::Mul<u64> for Amount<C>
345where
346 C: Constraint,
347{
348 type Output = Result<Amount<C>>;
349
350 fn mul(self, rhs: u64) -> Self::Output {
351 let value = i128::from(self.0)
353 .checked_mul(i128::from(rhs))
354 .expect("multiplying i64 by u64 can't overflow i128");
355
356 value.try_into().map_err(|_| Error::MultiplicationOverflow {
357 amount: self.0,
358 multiplier: rhs,
359 overflowing_result: value,
360 })
361 }
362}
363
364impl<C> std::ops::Mul<Amount<C>> for u64
365where
366 C: Constraint,
367{
368 type Output = Result<Amount<C>>;
369
370 fn mul(self, rhs: Amount<C>) -> Self::Output {
371 rhs.mul(self)
372 }
373}
374
375impl<C> std::ops::Div<u64> for Amount<C>
376where
377 C: Constraint,
378{
379 type Output = Result<Amount<C>>;
380
381 fn div(self, rhs: u64) -> Self::Output {
382 let quotient = i128::from(self.0)
383 .checked_div(i128::from(rhs))
384 .ok_or(Error::DivideByZero { amount: self.0 })?;
385
386 Ok(quotient
387 .try_into()
388 .expect("division by a positive integer always stays within the constraint"))
389 }
390}
391
392impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
393where
394 C: Constraint,
395{
396 fn sum<I: Iterator<Item = Amount<C>>>(mut iter: I) -> Self {
397 let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount);
398
399 match sum {
400 Ok(sum) => Ok(sum),
401 Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow {
402 partial_sum: value,
403 remaining_items: iter.count(),
404 }),
405 Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error),
406 }
407 }
408}
409
410impl<'amt, C> std::iter::Sum<&'amt Amount<C>> for Result<Amount<C>>
411where
412 C: Constraint + Copy + 'amt,
413{
414 fn sum<I: Iterator<Item = &'amt Amount<C>>>(iter: I) -> Self {
415 iter.copied().sum()
416 }
417}
418
419impl<C> std::ops::Neg for Amount<C>
422where
423 C: Constraint,
424{
425 type Output = Amount<NegativeAllowed>;
426 fn neg(self) -> Self::Output {
427 Amount::<NegativeAllowed>::try_from(-self.0)
428 .expect("a negation of any Amount into NegativeAllowed is always valid")
429 }
430}
431
/// Errors that can occur when constructing, converting, or doing arithmetic on
/// [`Amount`]s.
#[allow(missing_docs)]
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A value was outside the range accepted by a [`Constraint`].
    Constraint {
        /// The rejected zatoshi value.
        value: i64,
        /// The inclusive range the value had to fall within.
        range: RangeInclusive<i64>,
    },

    /// An integer (`u64` or `i128`) could not be narrowed to `i64`.
    Convert {
        /// The original value, widened to `i128`.
        value: i128,
        /// The underlying integer conversion error.
        source: std::num::TryFromIntError,
    },

    /// Multiplying an amount by a `u64` did not produce a valid amount: the
    /// product either does not fit in an `i64` or is outside the constraint.
    MultiplicationOverflow {
        /// The amount being multiplied, in zatoshis.
        amount: i64,
        /// The multiplier.
        multiplier: u64,
        /// The full-width product that was rejected.
        overflowing_result: i128,
    },

    /// An amount was divided by zero.
    DivideByZero {
        /// The dividend, in zatoshis.
        amount: i64,
    },

    /// Summing amounts produced a running total outside the valid range.
    SumOverflow {
        /// The out-of-range running total at the point of failure.
        partial_sum: i64,
        /// How many items were left in the iterator after the failing one.
        remaining_items: usize,
    },
}
464
465impl fmt::Display for Error {
466 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467 f.write_str(&match self {
468 Error::Constraint { value, range } => format!(
469 "input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}"
470 ),
471 Error::Convert { value, .. } => {
472 format!("{value} could not be converted to an i64 Amount")
473 }
474 Error::MultiplicationOverflow {
475 amount,
476 multiplier,
477 overflowing_result,
478 } => format!(
479 "overflow when calculating {amount}i64 * {multiplier}u64 = {overflowing_result}i128"
480 ),
481 Error::DivideByZero { amount } => format!("cannot divide amount {amount} by zero"),
482 Error::SumOverflow {
483 partial_sum,
484 remaining_items,
485 } => format!(
486 "overflow when summing i64 amounts; \
487 partial sum: {partial_sum}, number of remaining items: {remaining_items}"
488 ),
489 })
490 }
491}
492
493impl Error {
494 pub fn invalid_value(&self) -> i128 {
499 use Error::*;
500
501 match self.clone() {
502 Constraint { value, .. } => value.into(),
503 Convert { value, .. } => value,
504 MultiplicationOverflow {
505 overflowing_result, ..
506 } => overflowing_result,
507 DivideByZero { amount } => amount.into(),
508 SumOverflow { partial_sum, .. } => partial_sum.into(),
509 }
510 }
511}
512
513#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
523pub struct NegativeAllowed;
524
525impl Constraint for NegativeAllowed {
526 fn valid_range() -> RangeInclusive<i64> {
527 -MAX_MONEY..=MAX_MONEY
528 }
529}
530
531#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Default)]
541#[cfg_attr(
542 any(test, feature = "proptest-impl"),
543 derive(proptest_derive::Arbitrary)
544)]
545pub struct NonNegative;
546
547impl Constraint for NonNegative {
548 fn valid_range() -> RangeInclusive<i64> {
549 0..=MAX_MONEY
550 }
551}
552
553#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
565pub struct NegativeOrZero;
566
567impl Constraint for NegativeOrZero {
568 fn valid_range() -> RangeInclusive<i64> {
569 -MAX_MONEY..=0
570 }
571}
572
/// The number of zatoshis in one coin (1 ZEC).
pub const COIN: i64 = 100_000_000;

/// The largest absolute zatoshi amount (21 million coins).
///
/// This is the bound used by every [`Constraint`] defined in this module.
pub const MAX_MONEY: i64 = 21_000_000 * COIN;
578
579pub trait Constraint {
581 fn valid_range() -> RangeInclusive<i64>;
583
584 fn validate(value: i64) -> Result<i64, Error> {
586 let range = Self::valid_range();
587
588 if !range.contains(&value) {
589 Err(Error::Constraint { value, range })
590 } else {
591 Ok(value)
592 }
593 }
594}
595
596impl ZcashSerialize for Amount<NegativeAllowed> {
597 fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
598 writer.write_i64::<LittleEndian>(self.0)
599 }
600}
601
602impl ZcashDeserialize for Amount<NegativeAllowed> {
603 fn zcash_deserialize<R: std::io::Read>(
604 mut reader: R,
605 ) -> Result<Self, crate::serialization::SerializationError> {
606 Ok(reader.read_i64::<LittleEndian>()?.try_into()?)
607 }
608}
609
610impl ZcashSerialize for Amount<NonNegative> {
611 #[allow(clippy::unwrap_in_result)]
612 fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
613 let amount = self
614 .0
615 .try_into()
616 .expect("constraint guarantees value is positive");
617
618 writer.write_u64::<LittleEndian>(amount)
619 }
620}
621
622impl ZcashDeserialize for Amount<NonNegative> {
623 fn zcash_deserialize<R: std::io::Read>(
624 mut reader: R,
625 ) -> Result<Self, crate::serialization::SerializationError> {
626 Ok(reader.read_u64::<LittleEndian>()?.try_into()?)
627 }
628}