1use std::{
9 cmp::Ordering,
10 fmt,
11 hash::{Hash, Hasher},
12 marker::PhantomData,
13 ops::RangeInclusive,
14};
15
16use crate::serialization::{ZcashDeserialize, ZcashSerialize};
17use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
18
19#[cfg(any(test, feature = "proptest-impl"))]
20pub mod arbitrary;
21
22#[cfg(test)]
23mod tests;
24
25pub type Result<T, E = Error> = std::result::Result<T, E>;
27
/// A count of zatoshis whose valid range is described by the constraint type `C`.
///
/// Values built through the `TryFrom` conversions in this module are checked with
/// `C::validate`. The default constraint is `NegativeAllowed`.
///
/// With serde, an `Amount` is written as a plain `i64` (`into = "i64"`). It is read
/// back through `TryFrom<i64>` (`try_from = "i64"`), so out-of-range values are
/// rejected on deserialization.
#[derive(Clone, Copy, Serialize, Deserialize, Default)]
#[serde(try_from = "i64")]
#[serde(into = "i64")]
#[serde(bound = "C: Constraint + Clone")]
pub struct Amount<C = NegativeAllowed>(
    /// The inner amount value, in zatoshis.
    i64,
    /// Zero-sized marker carrying the constraint type; it holds no data and is
    /// skipped by serde.
    #[serde(skip)]
    PhantomData<C>,
);
49
50impl<C> fmt::Display for Amount<C> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 let zats = self.zatoshis();
53
54 f.pad_integral(zats > 0, "", &zats.to_string())
55 }
56}
57
58impl<C> fmt::Debug for Amount<C> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_tuple(&format!("Amount<{}>", std::any::type_name::<C>()))
61 .field(&self.0)
62 .finish()
63 }
64}
65
66impl Amount<NonNegative> {
67 pub const fn new_from_zec(zec_value: i64) -> Self {
69 Self::new(zec_value.checked_mul(COIN).expect("should fit in i64"))
70 }
71
72 pub const fn new(zatoshis: i64) -> Self {
74 assert!(zatoshis <= MAX_MONEY && zatoshis >= 0);
75 Self(zatoshis, PhantomData)
76 }
77
78 pub const fn div_exact(self, rhs: i64) -> Self {
80 let result = self.0.checked_div(rhs).expect("divisor must be non-zero");
81 if self.0 % rhs != 0 {
82 panic!("divisor must divide amount evenly, no remainder");
83 }
84
85 Self(result, PhantomData)
86 }
87}
88
89impl<C> Amount<C> {
90 pub fn constrain<C2>(self) -> Result<Amount<C2>>
92 where
93 C2: Constraint,
94 {
95 self.0.try_into()
96 }
97
98 pub fn zatoshis(&self) -> i64 {
100 self.0
101 }
102
103 pub fn checked_sub<C2: Constraint>(self, rhs: Amount<C2>) -> Option<Amount> {
105 self.0.checked_sub(rhs.0).and_then(|v| v.try_into().ok())
106 }
107
108 pub fn to_bytes(&self) -> [u8; 8] {
110 let mut buf: [u8; 8] = [0; 8];
111 LittleEndian::write_i64(&mut buf, self.0);
112 buf
113 }
114
115 pub fn from_bytes(bytes: [u8; 8]) -> Result<Amount<C>>
117 where
118 C: Constraint,
119 {
120 let amount = i64::from_le_bytes(bytes);
121 amount.try_into()
122 }
123
124 pub fn zero() -> Amount<C>
126 where
127 C: Constraint,
128 {
129 0.try_into().expect("an amount of 0 is always valid")
130 }
131}
132
133impl<C> std::ops::Add<Amount<C>> for Amount<C>
134where
135 C: Constraint,
136{
137 type Output = Result<Amount<C>>;
138
139 fn add(self, rhs: Amount<C>) -> Self::Output {
140 let value = self
141 .0
142 .checked_add(rhs.0)
143 .expect("adding two constrained Amounts is always within an i64");
144 value.try_into()
145 }
146}
147
148impl<C> std::ops::Add<Amount<C>> for Result<Amount<C>>
149where
150 C: Constraint,
151{
152 type Output = Result<Amount<C>>;
153
154 fn add(self, rhs: Amount<C>) -> Self::Output {
155 self? + rhs
156 }
157}
158
159impl<C> std::ops::Add<Result<Amount<C>>> for Amount<C>
160where
161 C: Constraint,
162{
163 type Output = Result<Amount<C>>;
164
165 fn add(self, rhs: Result<Amount<C>>) -> Self::Output {
166 self + rhs?
167 }
168}
169
170impl<C> std::ops::AddAssign<Amount<C>> for Result<Amount<C>>
171where
172 Amount<C>: Copy,
173 C: Constraint,
174{
175 fn add_assign(&mut self, rhs: Amount<C>) {
176 if let Ok(lhs) = *self {
177 *self = lhs + rhs;
178 }
179 }
180}
181
182impl<C> std::ops::Sub<Amount<C>> for Amount<C>
183where
184 C: Constraint,
185{
186 type Output = Result<Amount<C>>;
187
188 fn sub(self, rhs: Amount<C>) -> Self::Output {
189 let value = self
190 .0
191 .checked_sub(rhs.0)
192 .expect("subtracting two constrained Amounts is always within an i64");
193 value.try_into()
194 }
195}
196
197impl<C> std::ops::Sub<Amount<C>> for Result<Amount<C>>
198where
199 C: Constraint,
200{
201 type Output = Result<Amount<C>>;
202
203 fn sub(self, rhs: Amount<C>) -> Self::Output {
204 self? - rhs
205 }
206}
207
208impl<C> std::ops::Sub<Result<Amount<C>>> for Amount<C>
209where
210 C: Constraint,
211{
212 type Output = Result<Amount<C>>;
213
214 fn sub(self, rhs: Result<Amount<C>>) -> Self::Output {
215 self - rhs?
216 }
217}
218
219impl<C> std::ops::SubAssign<Amount<C>> for Result<Amount<C>>
220where
221 Amount<C>: Copy,
222 C: Constraint,
223{
224 fn sub_assign(&mut self, rhs: Amount<C>) {
225 if let Ok(lhs) = *self {
226 *self = lhs - rhs;
227 }
228 }
229}
230
231impl<C> From<Amount<C>> for i64 {
232 fn from(amount: Amount<C>) -> Self {
233 amount.0
234 }
235}
236
237impl From<Amount<NonNegative>> for u64 {
238 fn from(amount: Amount<NonNegative>) -> Self {
239 amount.0.try_into().expect("non-negative i64 fits in u64")
240 }
241}
242
243impl<C> From<Amount<C>> for jubjub::Fr {
244 fn from(a: Amount<C>) -> jubjub::Fr {
245 if a.0 < 0 {
247 let abs_amount = i128::from(a.0)
248 .checked_abs()
249 .expect("absolute i64 fits in i128");
250 let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
251
252 jubjub::Fr::from(abs_amount).neg()
253 } else {
254 jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64"))
255 }
256 }
257}
258
259impl<C> From<Amount<C>> for halo2::pasta::pallas::Scalar {
260 fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar {
261 if a.0 < 0 {
263 let abs_amount = i128::from(a.0)
264 .checked_abs()
265 .expect("absolute i64 fits in i128");
266 let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
267
268 halo2::pasta::pallas::Scalar::from(abs_amount).neg()
269 } else {
270 halo2::pasta::pallas::Scalar::from(
271 u64::try_from(a.0).expect("non-negative i64 fits in u64"),
272 )
273 }
274 }
275}
276
277impl<C> TryFrom<i32> for Amount<C>
278where
279 C: Constraint,
280{
281 type Error = Error;
282
283 fn try_from(value: i32) -> Result<Self, Self::Error> {
284 C::validate(value.into()).map(|v| Self(v, PhantomData))
285 }
286}
287
288impl<C> TryFrom<i64> for Amount<C>
289where
290 C: Constraint,
291{
292 type Error = Error;
293
294 fn try_from(value: i64) -> Result<Self, Self::Error> {
295 C::validate(value).map(|v| Self(v, PhantomData))
296 }
297}
298
299impl<C> TryFrom<u64> for Amount<C>
300where
301 C: Constraint,
302{
303 type Error = Error;
304
305 fn try_from(value: u64) -> Result<Self, Self::Error> {
306 let value = value.try_into().map_err(|source| Error::Convert {
307 value: value.into(),
308 source,
309 })?;
310
311 C::validate(value).map(|v| Self(v, PhantomData))
312 }
313}
314
315impl<C> TryFrom<i128> for Amount<C>
319where
320 C: Constraint,
321{
322 type Error = Error;
323
324 fn try_from(value: i128) -> Result<Self, Self::Error> {
325 let value = value
326 .try_into()
327 .map_err(|source| Error::Convert { value, source })?;
328
329 C::validate(value).map(|v| Self(v, PhantomData))
330 }
331}
332
333impl<C> Hash for Amount<C> {
334 fn hash<H: Hasher>(&self, state: &mut H) {
336 self.0.hash(state);
337 }
338}
339
340impl<C1, C2> PartialEq<Amount<C2>> for Amount<C1> {
341 fn eq(&self, other: &Amount<C2>) -> bool {
342 self.0.eq(&other.0)
343 }
344}
345
346impl<C> PartialEq<i64> for Amount<C> {
347 fn eq(&self, other: &i64) -> bool {
348 self.0.eq(other)
349 }
350}
351
352impl<C> PartialEq<Amount<C>> for i64 {
353 fn eq(&self, other: &Amount<C>) -> bool {
354 self.eq(&other.0)
355 }
356}
357
358impl<C> Eq for Amount<C> {}
359
360impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
361 fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> {
362 Some(self.0.cmp(&other.0))
363 }
364}
365
366impl<C> Ord for Amount<C> {
367 fn cmp(&self, other: &Amount<C>) -> Ordering {
368 self.0.cmp(&other.0)
369 }
370}
371
372impl<C> std::ops::Mul<u64> for Amount<C>
373where
374 C: Constraint,
375{
376 type Output = Result<Amount<C>>;
377
378 fn mul(self, rhs: u64) -> Self::Output {
379 let value = i128::from(self.0)
381 .checked_mul(i128::from(rhs))
382 .expect("multiplying i64 by u64 can't overflow i128");
383
384 value.try_into().map_err(|_| Error::MultiplicationOverflow {
385 amount: self.0,
386 multiplier: rhs,
387 overflowing_result: value,
388 })
389 }
390}
391
392impl<C> std::ops::Mul<Amount<C>> for u64
393where
394 C: Constraint,
395{
396 type Output = Result<Amount<C>>;
397
398 fn mul(self, rhs: Amount<C>) -> Self::Output {
399 rhs.mul(self)
400 }
401}
402
403impl<C> std::ops::Div<u64> for Amount<C>
404where
405 C: Constraint,
406{
407 type Output = Result<Amount<C>>;
408
409 fn div(self, rhs: u64) -> Self::Output {
410 let quotient = i128::from(self.0)
411 .checked_div(i128::from(rhs))
412 .ok_or(Error::DivideByZero { amount: self.0 })?;
413
414 Ok(quotient
415 .try_into()
416 .expect("division by a positive integer always stays within the constraint"))
417 }
418}
419
420impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
421where
422 C: Constraint,
423{
424 fn sum<I: Iterator<Item = Amount<C>>>(mut iter: I) -> Self {
425 let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount);
426
427 match sum {
428 Ok(sum) => Ok(sum),
429 Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow {
430 partial_sum: value,
431 remaining_items: iter.count(),
432 }),
433 Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error),
434 }
435 }
436}
437
438impl<'amt, C> std::iter::Sum<&'amt Amount<C>> for Result<Amount<C>>
439where
440 C: Constraint + Copy + 'amt,
441{
442 fn sum<I: Iterator<Item = &'amt Amount<C>>>(iter: I) -> Self {
443 iter.copied().sum()
444 }
445}
446
447impl<C> std::ops::Neg for Amount<C>
450where
451 C: Constraint,
452{
453 type Output = Amount<NegativeAllowed>;
454 fn neg(self) -> Self::Output {
455 Amount::<NegativeAllowed>::try_from(-self.0)
456 .expect("a negation of any Amount into NegativeAllowed is always valid")
457 }
458}
459
/// Errors produced when validating, converting, or doing arithmetic on amounts.
#[allow(missing_docs)]
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A value is outside the range allowed by the amount's constraint
    /// (returned by `Constraint::validate`).
    Constraint {
        /// The rejected zatoshi value.
        value: i64,
        /// The constraint's valid range.
        range: RangeInclusive<i64>,
    },

    /// A `u64` or `i128` value could not be narrowed to `i64`.
    Convert {
        /// The value that failed to convert.
        value: i128,
        /// The underlying integer conversion error.
        source: std::num::TryFromIntError,
    },

    /// `Amount * u64` produced a value that is not a valid amount.
    MultiplicationOverflow {
        /// The amount's zatoshi value before multiplying.
        amount: i64,
        /// The multiplier.
        multiplier: u64,
        /// The exact (out-of-range) product.
        overflowing_result: i128,
    },

    /// `Amount / u64` was called with a zero divisor.
    DivideByZero { amount: i64 },

    /// Summing an iterator of amounts left the valid range.
    SumOverflow {
        /// The first out-of-range partial sum.
        partial_sum: i64,
        /// Number of items not yet consumed when the overflow happened.
        remaining_items: usize,
    },
}
492
493impl fmt::Display for Error {
494 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
495 f.write_str(&match self {
496 Error::Constraint { value, range } => format!(
497 "input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}"
498 ),
499 Error::Convert { value, .. } => {
500 format!("{value} could not be converted to an i64 Amount")
501 }
502 Error::MultiplicationOverflow {
503 amount,
504 multiplier,
505 overflowing_result,
506 } => format!(
507 "overflow when calculating {amount}i64 * {multiplier}u64 = {overflowing_result}i128"
508 ),
509 Error::DivideByZero { amount } => format!("cannot divide amount {amount} by zero"),
510 Error::SumOverflow {
511 partial_sum,
512 remaining_items,
513 } => format!(
514 "overflow when summing i64 amounts; \
515 partial sum: {partial_sum}, number of remaining items: {remaining_items}"
516 ),
517 })
518 }
519}
520
521impl Error {
522 pub fn invalid_value(&self) -> i128 {
527 use Error::*;
528
529 match self.clone() {
530 Constraint { value, .. } => value.into(),
531 Convert { value, .. } => value,
532 MultiplicationOverflow {
533 overflowing_result, ..
534 } => overflowing_result,
535 DivideByZero { amount } => amount.into(),
536 SumOverflow { partial_sum, .. } => partial_sum.into(),
537 }
538 }
539}
540
541#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
551pub struct NegativeAllowed;
552
553impl Constraint for NegativeAllowed {
554 fn valid_range() -> RangeInclusive<i64> {
555 -MAX_MONEY..=MAX_MONEY
556 }
557}
558
559#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Default)]
569#[cfg_attr(
570 any(test, feature = "proptest-impl"),
571 derive(proptest_derive::Arbitrary)
572)]
573pub struct NonNegative;
574
575impl Constraint for NonNegative {
576 fn valid_range() -> RangeInclusive<i64> {
577 0..=MAX_MONEY
578 }
579}
580
581#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
593pub struct NegativeOrZero;
594
595impl Constraint for NegativeOrZero {
596 fn valid_range() -> RangeInclusive<i64> {
597 -MAX_MONEY..=0
598 }
599}
600
/// Number of zatoshis in one ZEC (10^8).
pub const COIN: i64 = 10_i64.pow(8);

/// Upper bound on the magnitude of an amount: 21 million ZEC, in zatoshis.
pub const MAX_MONEY: i64 = COIN * 21_000_000;
606
607pub trait Constraint {
609 fn valid_range() -> RangeInclusive<i64>;
611
612 fn validate(value: i64) -> Result<i64, Error> {
614 let range = Self::valid_range();
615
616 if !range.contains(&value) {
617 Err(Error::Constraint { value, range })
618 } else {
619 Ok(value)
620 }
621 }
622}
623
624impl ZcashSerialize for Amount<NegativeAllowed> {
625 fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
626 writer.write_i64::<LittleEndian>(self.0)
627 }
628}
629
630impl ZcashDeserialize for Amount<NegativeAllowed> {
631 fn zcash_deserialize<R: std::io::Read>(
632 mut reader: R,
633 ) -> Result<Self, crate::serialization::SerializationError> {
634 Ok(reader.read_i64::<LittleEndian>()?.try_into()?)
635 }
636}
637
638impl ZcashSerialize for Amount<NonNegative> {
639 #[allow(clippy::unwrap_in_result)]
640 fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
641 let amount = self
642 .0
643 .try_into()
644 .expect("constraint guarantees value is positive");
645
646 writer.write_u64::<LittleEndian>(amount)
647 }
648}
649
650impl ZcashDeserialize for Amount<NonNegative> {
651 fn zcash_deserialize<R: std::io::Read>(
652 mut reader: R,
653 ) -> Result<Self, crate::serialization::SerializationError> {
654 Ok(reader.read_u64::<LittleEndian>()?.try_into()?)
655 }
656}
657
658#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
660pub struct DeferredPoolBalanceChange(Amount);
661
662impl DeferredPoolBalanceChange {
663 pub fn new(amount: Amount) -> Self {
665 Self(amount)
666 }
667
668 pub fn zero() -> Self {
670 Self(Amount::zero())
671 }
672
673 pub fn value(self) -> Amount {
675 self.0
676 }
677}