zebra_chain/
amount.rs

//! Strongly-typed zatoshi amounts that prevent under/overflows.
//!
//! The [`Amount`] type is parameterized by a [`Constraint`] implementation that
//! declares the range of allowed values. In contrast to regular arithmetic
//! operations, which return values, arithmetic on [`Amount`]s returns
//! [`Result`](std::result::Result)s.
7
8use std::{
9    cmp::Ordering,
10    fmt,
11    hash::{Hash, Hasher},
12    marker::PhantomData,
13    ops::RangeInclusive,
14};
15
16use crate::serialization::{ZcashDeserialize, ZcashSerialize};
17use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
18
19#[cfg(any(test, feature = "proptest-impl"))]
20pub mod arbitrary;
21
22#[cfg(test)]
23mod tests;
24
25/// The result of an amount operation.
26pub type Result<T, E = Error> = std::result::Result<T, E>;
27
28/// A runtime validated type for representing amounts of zatoshis
29//
30// TODO:
31// - remove the default NegativeAllowed bound, to make consensus rule reviews easier
32// - put a Constraint bound on the type generic, not just some implementations
33#[derive(Clone, Copy, Serialize, Deserialize, Default)]
34#[serde(try_from = "i64")]
35#[serde(into = "i64")]
36#[serde(bound = "C: Constraint + Clone")]
37pub struct Amount<C = NegativeAllowed>(
38    /// The inner amount value.
39    i64,
40    /// Used for [`Constraint`] type inference.
41    ///
42    /// # Correctness
43    ///
44    /// This internal Zebra marker type is not consensus-critical.
45    /// And it should be ignored during testing. (And other internal uses.)
46    #[serde(skip)]
47    PhantomData<C>,
48);
49
50impl<C> fmt::Display for Amount<C> {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        let zats = self.zatoshis();
53
54        f.pad_integral(zats > 0, "", &zats.to_string())
55    }
56}
57
58impl<C> fmt::Debug for Amount<C> {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.debug_tuple(&format!("Amount<{}>", std::any::type_name::<C>()))
61            .field(&self.0)
62            .finish()
63    }
64}
65
66impl<C> Amount<C> {
67    /// Convert this amount to a different Amount type if it satisfies the new constraint
68    pub fn constrain<C2>(self) -> Result<Amount<C2>>
69    where
70        C2: Constraint,
71    {
72        self.0.try_into()
73    }
74
75    /// Returns the number of zatoshis in this amount.
76    pub fn zatoshis(&self) -> i64 {
77        self.0
78    }
79
80    /// To little endian byte array
81    pub fn to_bytes(&self) -> [u8; 8] {
82        let mut buf: [u8; 8] = [0; 8];
83        LittleEndian::write_i64(&mut buf, self.0);
84        buf
85    }
86
87    /// From little endian byte array
88    pub fn from_bytes(bytes: [u8; 8]) -> Result<Amount<C>>
89    where
90        C: Constraint,
91    {
92        let amount = i64::from_le_bytes(bytes);
93        amount.try_into()
94    }
95
96    /// Create a zero `Amount`
97    pub fn zero() -> Amount<C>
98    where
99        C: Constraint,
100    {
101        0.try_into().expect("an amount of 0 is always valid")
102    }
103}
104
105impl<C> std::ops::Add<Amount<C>> for Amount<C>
106where
107    C: Constraint,
108{
109    type Output = Result<Amount<C>>;
110
111    fn add(self, rhs: Amount<C>) -> Self::Output {
112        let value = self
113            .0
114            .checked_add(rhs.0)
115            .expect("adding two constrained Amounts is always within an i64");
116        value.try_into()
117    }
118}
119
120impl<C> std::ops::Add<Amount<C>> for Result<Amount<C>>
121where
122    C: Constraint,
123{
124    type Output = Result<Amount<C>>;
125
126    fn add(self, rhs: Amount<C>) -> Self::Output {
127        self? + rhs
128    }
129}
130
131impl<C> std::ops::Add<Result<Amount<C>>> for Amount<C>
132where
133    C: Constraint,
134{
135    type Output = Result<Amount<C>>;
136
137    fn add(self, rhs: Result<Amount<C>>) -> Self::Output {
138        self + rhs?
139    }
140}
141
142impl<C> std::ops::AddAssign<Amount<C>> for Result<Amount<C>>
143where
144    Amount<C>: Copy,
145    C: Constraint,
146{
147    fn add_assign(&mut self, rhs: Amount<C>) {
148        if let Ok(lhs) = *self {
149            *self = lhs + rhs;
150        }
151    }
152}
153
154impl<C> std::ops::Sub<Amount<C>> for Amount<C>
155where
156    C: Constraint,
157{
158    type Output = Result<Amount<C>>;
159
160    fn sub(self, rhs: Amount<C>) -> Self::Output {
161        let value = self
162            .0
163            .checked_sub(rhs.0)
164            .expect("subtracting two constrained Amounts is always within an i64");
165        value.try_into()
166    }
167}
168
169impl<C> std::ops::Sub<Amount<C>> for Result<Amount<C>>
170where
171    C: Constraint,
172{
173    type Output = Result<Amount<C>>;
174
175    fn sub(self, rhs: Amount<C>) -> Self::Output {
176        self? - rhs
177    }
178}
179
180impl<C> std::ops::Sub<Result<Amount<C>>> for Amount<C>
181where
182    C: Constraint,
183{
184    type Output = Result<Amount<C>>;
185
186    fn sub(self, rhs: Result<Amount<C>>) -> Self::Output {
187        self - rhs?
188    }
189}
190
191impl<C> std::ops::SubAssign<Amount<C>> for Result<Amount<C>>
192where
193    Amount<C>: Copy,
194    C: Constraint,
195{
196    fn sub_assign(&mut self, rhs: Amount<C>) {
197        if let Ok(lhs) = *self {
198            *self = lhs - rhs;
199        }
200    }
201}
202
203impl<C> From<Amount<C>> for i64 {
204    fn from(amount: Amount<C>) -> Self {
205        amount.0
206    }
207}
208
209impl From<Amount<NonNegative>> for u64 {
210    fn from(amount: Amount<NonNegative>) -> Self {
211        amount.0.try_into().expect("non-negative i64 fits in u64")
212    }
213}
214
215impl<C> From<Amount<C>> for jubjub::Fr {
216    fn from(a: Amount<C>) -> jubjub::Fr {
217        // TODO: this isn't constant time -- does that matter?
218        if a.0 < 0 {
219            let abs_amount = i128::from(a.0)
220                .checked_abs()
221                .expect("absolute i64 fits in i128");
222            let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
223
224            jubjub::Fr::from(abs_amount).neg()
225        } else {
226            jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64"))
227        }
228    }
229}
230
231impl<C> From<Amount<C>> for halo2::pasta::pallas::Scalar {
232    fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar {
233        // TODO: this isn't constant time -- does that matter?
234        if a.0 < 0 {
235            let abs_amount = i128::from(a.0)
236                .checked_abs()
237                .expect("absolute i64 fits in i128");
238            let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
239
240            halo2::pasta::pallas::Scalar::from(abs_amount).neg()
241        } else {
242            halo2::pasta::pallas::Scalar::from(
243                u64::try_from(a.0).expect("non-negative i64 fits in u64"),
244            )
245        }
246    }
247}
248
249impl<C> TryFrom<i32> for Amount<C>
250where
251    C: Constraint,
252{
253    type Error = Error;
254
255    fn try_from(value: i32) -> Result<Self, Self::Error> {
256        C::validate(value.into()).map(|v| Self(v, PhantomData))
257    }
258}
259
260impl<C> TryFrom<i64> for Amount<C>
261where
262    C: Constraint,
263{
264    type Error = Error;
265
266    fn try_from(value: i64) -> Result<Self, Self::Error> {
267        C::validate(value).map(|v| Self(v, PhantomData))
268    }
269}
270
271impl<C> TryFrom<u64> for Amount<C>
272where
273    C: Constraint,
274{
275    type Error = Error;
276
277    fn try_from(value: u64) -> Result<Self, Self::Error> {
278        let value = value.try_into().map_err(|source| Error::Convert {
279            value: value.into(),
280            source,
281        })?;
282
283        C::validate(value).map(|v| Self(v, PhantomData))
284    }
285}
286
287/// Conversion from `i128` to `Amount`.
288///
289/// Used to handle the result of multiplying negative `Amount`s by `u64`.
290impl<C> TryFrom<i128> for Amount<C>
291where
292    C: Constraint,
293{
294    type Error = Error;
295
296    fn try_from(value: i128) -> Result<Self, Self::Error> {
297        let value = value
298            .try_into()
299            .map_err(|source| Error::Convert { value, source })?;
300
301        C::validate(value).map(|v| Self(v, PhantomData))
302    }
303}
304
305impl<C> Hash for Amount<C> {
306    /// Amounts with the same value are equal, even if they have different constraints
307    fn hash<H: Hasher>(&self, state: &mut H) {
308        self.0.hash(state);
309    }
310}
311
312impl<C1, C2> PartialEq<Amount<C2>> for Amount<C1> {
313    fn eq(&self, other: &Amount<C2>) -> bool {
314        self.0.eq(&other.0)
315    }
316}
317
318impl<C> PartialEq<i64> for Amount<C> {
319    fn eq(&self, other: &i64) -> bool {
320        self.0.eq(other)
321    }
322}
323
324impl<C> PartialEq<Amount<C>> for i64 {
325    fn eq(&self, other: &Amount<C>) -> bool {
326        self.eq(&other.0)
327    }
328}
329
330impl<C> Eq for Amount<C> {}
331
332impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
333    fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> {
334        Some(self.0.cmp(&other.0))
335    }
336}
337
338impl<C> Ord for Amount<C> {
339    fn cmp(&self, other: &Amount<C>) -> Ordering {
340        self.0.cmp(&other.0)
341    }
342}
343
344impl<C> std::ops::Mul<u64> for Amount<C>
345where
346    C: Constraint,
347{
348    type Output = Result<Amount<C>>;
349
350    fn mul(self, rhs: u64) -> Self::Output {
351        // use i128 for multiplication, so we can handle negative Amounts
352        let value = i128::from(self.0)
353            .checked_mul(i128::from(rhs))
354            .expect("multiplying i64 by u64 can't overflow i128");
355
356        value.try_into().map_err(|_| Error::MultiplicationOverflow {
357            amount: self.0,
358            multiplier: rhs,
359            overflowing_result: value,
360        })
361    }
362}
363
364impl<C> std::ops::Mul<Amount<C>> for u64
365where
366    C: Constraint,
367{
368    type Output = Result<Amount<C>>;
369
370    fn mul(self, rhs: Amount<C>) -> Self::Output {
371        rhs.mul(self)
372    }
373}
374
375impl<C> std::ops::Div<u64> for Amount<C>
376where
377    C: Constraint,
378{
379    type Output = Result<Amount<C>>;
380
381    fn div(self, rhs: u64) -> Self::Output {
382        let quotient = i128::from(self.0)
383            .checked_div(i128::from(rhs))
384            .ok_or(Error::DivideByZero { amount: self.0 })?;
385
386        Ok(quotient
387            .try_into()
388            .expect("division by a positive integer always stays within the constraint"))
389    }
390}
391
392impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
393where
394    C: Constraint,
395{
396    fn sum<I: Iterator<Item = Amount<C>>>(mut iter: I) -> Self {
397        let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount);
398
399        match sum {
400            Ok(sum) => Ok(sum),
401            Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow {
402                partial_sum: value,
403                remaining_items: iter.count(),
404            }),
405            Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error),
406        }
407    }
408}
409
410impl<'amt, C> std::iter::Sum<&'amt Amount<C>> for Result<Amount<C>>
411where
412    C: Constraint + Copy + 'amt,
413{
414    fn sum<I: Iterator<Item = &'amt Amount<C>>>(iter: I) -> Self {
415        iter.copied().sum()
416    }
417}
418
419// TODO: add infallible impls for NonNegative <-> NegativeOrZero,
420//       when Rust uses trait output types to disambiguate overlapping impls.
421impl<C> std::ops::Neg for Amount<C>
422where
423    C: Constraint,
424{
425    type Output = Amount<NegativeAllowed>;
426    fn neg(self) -> Self::Output {
427        Amount::<NegativeAllowed>::try_from(-self.0)
428            .expect("a negation of any Amount into NegativeAllowed is always valid")
429    }
430}
431
432#[allow(missing_docs)]
433#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
434/// Errors that can be returned when validating [`Amount`]s.
435pub enum Error {
436    /// input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}
437    Constraint {
438        value: i64,
439        range: RangeInclusive<i64>,
440    },
441
442    /// {value} could not be converted to an i64 Amount
443    Convert {
444        value: i128,
445        source: std::num::TryFromIntError,
446    },
447
448    /// i64 overflow when multiplying i64 amount {amount} by u64 {multiplier}, overflowing result {overflowing_result}
449    MultiplicationOverflow {
450        amount: i64,
451        multiplier: u64,
452        overflowing_result: i128,
453    },
454
455    /// cannot divide amount {amount} by zero
456    DivideByZero { amount: i64 },
457
458    /// i64 overflow when summing i64 amounts, partial_sum: {partial_sum}, remaining items: {remaining_items}
459    SumOverflow {
460        partial_sum: i64,
461        remaining_items: usize,
462    },
463}
464
465impl fmt::Display for Error {
466    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467        f.write_str(&match self {
468            Error::Constraint { value, range } => format!(
469                "input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}"
470            ),
471            Error::Convert { value, .. } => {
472                format!("{value} could not be converted to an i64 Amount")
473            }
474            Error::MultiplicationOverflow {
475                amount,
476                multiplier,
477                overflowing_result,
478            } => format!(
479                "overflow when calculating {amount}i64 * {multiplier}u64 = {overflowing_result}i128"
480            ),
481            Error::DivideByZero { amount } => format!("cannot divide amount {amount} by zero"),
482            Error::SumOverflow {
483                partial_sum,
484                remaining_items,
485            } => format!(
486                "overflow when summing i64 amounts; \
487                          partial sum: {partial_sum}, number of remaining items: {remaining_items}"
488            ),
489        })
490    }
491}
492
493impl Error {
494    /// Returns the invalid value for this error.
495    ///
496    /// This value may be an initial input value, partially calculated value,
497    /// or an overflowing or underflowing value.
498    pub fn invalid_value(&self) -> i128 {
499        use Error::*;
500
501        match self.clone() {
502            Constraint { value, .. } => value.into(),
503            Convert { value, .. } => value,
504            MultiplicationOverflow {
505                overflowing_result, ..
506            } => overflowing_result,
507            DivideByZero { amount } => amount.into(),
508            SumOverflow { partial_sum, .. } => partial_sum.into(),
509        }
510    }
511}
512
513/// Marker type for `Amount` that allows negative values.
514///
515/// ```
516/// # use zebra_chain::amount::{Constraint, MAX_MONEY, NegativeAllowed};
517/// assert_eq!(
518///     NegativeAllowed::valid_range(),
519///     -MAX_MONEY..=MAX_MONEY,
520/// );
521/// ```
522#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
523pub struct NegativeAllowed;
524
525impl Constraint for NegativeAllowed {
526    fn valid_range() -> RangeInclusive<i64> {
527        -MAX_MONEY..=MAX_MONEY
528    }
529}
530
531/// Marker type for `Amount` that requires nonnegative values.
532///
533/// ```
534/// # use zebra_chain::amount::{Constraint, MAX_MONEY, NonNegative};
535/// assert_eq!(
536///     NonNegative::valid_range(),
537///     0..=MAX_MONEY,
538/// );
539/// ```
540#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Default)]
541#[cfg_attr(
542    any(test, feature = "proptest-impl"),
543    derive(proptest_derive::Arbitrary)
544)]
545pub struct NonNegative;
546
547impl Constraint for NonNegative {
548    fn valid_range() -> RangeInclusive<i64> {
549        0..=MAX_MONEY
550    }
551}
552
553/// Marker type for `Amount` that requires negative or zero values.
554///
555/// Used for coinbase transactions in `getblocktemplate` RPCs.
556///
557/// ```
558/// # use zebra_chain::amount::{Constraint, MAX_MONEY, NegativeOrZero};
559/// assert_eq!(
560///     NegativeOrZero::valid_range(),
561///     -MAX_MONEY..=0,
562/// );
563/// ```
564#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
565pub struct NegativeOrZero;
566
567impl Constraint for NegativeOrZero {
568    fn valid_range() -> RangeInclusive<i64> {
569        -MAX_MONEY..=0
570    }
571}
572
/// The number of zatoshis in one ZEC.
pub const COIN: i64 = 100_000_000;

/// The maximum zatoshi amount: 21,000,000 ZEC expressed in zatoshis.
pub const MAX_MONEY: i64 = COIN * 21_000_000;
578
579/// A trait for defining constraints on `Amount`
580pub trait Constraint {
581    /// Returns the range of values that are valid under this constraint
582    fn valid_range() -> RangeInclusive<i64>;
583
584    /// Check if an input value is within the valid range
585    fn validate(value: i64) -> Result<i64, Error> {
586        let range = Self::valid_range();
587
588        if !range.contains(&value) {
589            Err(Error::Constraint { value, range })
590        } else {
591            Ok(value)
592        }
593    }
594}
595
596impl ZcashSerialize for Amount<NegativeAllowed> {
597    fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
598        writer.write_i64::<LittleEndian>(self.0)
599    }
600}
601
602impl ZcashDeserialize for Amount<NegativeAllowed> {
603    fn zcash_deserialize<R: std::io::Read>(
604        mut reader: R,
605    ) -> Result<Self, crate::serialization::SerializationError> {
606        Ok(reader.read_i64::<LittleEndian>()?.try_into()?)
607    }
608}
609
610impl ZcashSerialize for Amount<NonNegative> {
611    #[allow(clippy::unwrap_in_result)]
612    fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
613        let amount = self
614            .0
615            .try_into()
616            .expect("constraint guarantees value is positive");
617
618        writer.write_u64::<LittleEndian>(amount)
619    }
620}
621
622impl ZcashDeserialize for Amount<NonNegative> {
623    fn zcash_deserialize<R: std::io::Read>(
624        mut reader: R,
625    ) -> Result<Self, crate::serialization::SerializationError> {
626        Ok(reader.read_u64::<LittleEndian>()?.try_into()?)
627    }
628}