@@ -779,67 +779,98 @@ where
779
779
fn calculate_amount_to_forward_per_htlc (
780
780
htlcs : & [ InterceptedHTLC ] , total_amt_to_forward_msat : u64 ,
781
781
) -> Vec < ( InterceptId , u64 ) > {
782
+ // TODO: we should eventually make sure the HTLCs are all above ChannelDetails::next_outbound_minimum_msat
782
783
let total_received_msat: u64 =
783
784
htlcs. iter ( ) . map ( |htlc| htlc. expected_outbound_amount_msat ) . sum ( ) ;
784
785
785
- let mut fee_remaining_msat = total_received_msat - total_amt_to_forward_msat;
786
- let total_fee_msat = fee_remaining_msat;
786
+ match total_received_msat. checked_sub ( total_amt_to_forward_msat) {
787
+ Some ( total_fee_msat) => {
788
+ let mut fee_remaining_msat = total_fee_msat;
787
789
788
- let mut per_htlc_forwards = vec ! [ ] ;
790
+ let mut per_htlc_forwards = vec ! [ ] ;
789
791
790
- for ( index, htlc) in htlcs. iter ( ) . enumerate ( ) {
791
- let proportional_fee_amt_msat =
792
- total_fee_msat * htlc. expected_outbound_amount_msat / total_received_msat;
792
+ for ( index, htlc) in htlcs. iter ( ) . enumerate ( ) {
793
+ let proportional_fee_amt_msat =
794
+ total_fee_msat * ( htlc. expected_outbound_amount_msat / total_received_msat) ;
793
795
794
- let mut actual_fee_amt_msat = core:: cmp:: min ( fee_remaining_msat, proportional_fee_amt_msat) ;
795
- fee_remaining_msat -= actual_fee_amt_msat;
796
+ let mut actual_fee_amt_msat =
797
+ core:: cmp:: min ( fee_remaining_msat, proportional_fee_amt_msat) ;
798
+ fee_remaining_msat -= actual_fee_amt_msat;
796
799
797
- if index == htlcs. len ( ) - 1 {
798
- actual_fee_amt_msat += fee_remaining_msat;
799
- }
800
+ if index == htlcs. len ( ) - 1 {
801
+ actual_fee_amt_msat += fee_remaining_msat;
802
+ }
800
803
801
- let amount_to_forward_msat = htlc. expected_outbound_amount_msat - actual_fee_amt_msat;
804
+ let amount_to_forward_msat =
805
+ htlc. expected_outbound_amount_msat . saturating_sub ( actual_fee_amt_msat) ;
802
806
803
- per_htlc_forwards. push ( ( htlc. intercept_id , amount_to_forward_msat) )
804
- }
807
+ per_htlc_forwards. push ( ( htlc. intercept_id , amount_to_forward_msat) )
808
+ }
805
809
806
- per_htlc_forwards
810
+ per_htlc_forwards
811
+ }
812
+ None => Vec :: new ( ) ,
813
+ }
807
814
}
808
815
809
816
#[ cfg( test) ]
810
817
mod tests {
811
818
812
819
use super :: * ;
820
+ use proptest:: prelude:: * ;
813
821
814
- #[ test]
815
- fn test_calculate_amount_to_forward ( ) {
816
- // TODO: Use proptest to generate random allocations
817
- let htlcs = vec ! [
818
- InterceptedHTLC {
819
- intercept_id: InterceptId ( [ 0 ; 32 ] ) ,
820
- expected_outbound_amount_msat: 1000 ,
821
- } ,
822
- InterceptedHTLC {
823
- intercept_id: InterceptId ( [ 1 ; 32 ] ) ,
824
- expected_outbound_amount_msat: 2000 ,
825
- } ,
826
- InterceptedHTLC {
827
- intercept_id: InterceptId ( [ 2 ; 32 ] ) ,
828
- expected_outbound_amount_msat: 3000 ,
829
- } ,
830
- ] ;
831
-
832
- let total_amt_to_forward_msat = 5000 ;
833
-
834
- let result = calculate_amount_to_forward_per_htlc ( & htlcs, total_amt_to_forward_msat) ;
822
+ const MAX_VALUE_MSAT : u64 = 21_000_000_0000_0000_000 ;
835
823
836
- assert_eq ! ( result[ 0 ] . 0 , htlcs[ 0 ] . intercept_id) ;
837
- assert_eq ! ( result[ 0 ] . 1 , 834 ) ;
824
+ fn arb_forward_amounts ( ) -> impl Strategy < Value = ( u64 , u64 , u64 , u64 ) > {
825
+ ( 1u64 ..MAX_VALUE_MSAT , 1u64 ..MAX_VALUE_MSAT , 1u64 ..MAX_VALUE_MSAT , 1u64 ..MAX_VALUE_MSAT )
826
+ . prop_map ( |( a, b, c, d) | {
827
+ ( a, b, c, core:: cmp:: min ( d, a. saturating_add ( b) . saturating_add ( c) ) )
828
+ } )
829
+ }
838
830
839
- assert_eq ! ( result[ 1 ] . 0 , htlcs[ 1 ] . intercept_id) ;
840
- assert_eq ! ( result[ 1 ] . 1 , 1667 ) ;
831
+ proptest ! {
832
+ #[ test]
833
+ fn test_calculate_amount_to_forward( ( o_0, o_1, o_2, total_amt_to_forward_msat) in arb_forward_amounts( ) ) {
834
+ let htlcs = vec![
835
+ InterceptedHTLC {
836
+ intercept_id: InterceptId ( [ 0 ; 32 ] ) ,
837
+ expected_outbound_amount_msat: o_0
838
+ } ,
839
+ InterceptedHTLC {
840
+ intercept_id: InterceptId ( [ 1 ; 32 ] ) ,
841
+ expected_outbound_amount_msat: o_1
842
+ } ,
843
+ InterceptedHTLC {
844
+ intercept_id: InterceptId ( [ 2 ; 32 ] ) ,
845
+ expected_outbound_amount_msat: o_2
846
+ } ,
847
+ ] ;
848
+
849
+ let result = calculate_amount_to_forward_per_htlc( & htlcs, total_amt_to_forward_msat) ;
850
+ let total_received_msat = o_0 + o_1 + o_2;
851
+
852
+ if total_received_msat < total_amt_to_forward_msat {
853
+ assert_eq!( result. len( ) , 0 ) ;
854
+ } else {
855
+ assert_ne!( result. len( ) , 0 ) ;
856
+ assert_eq!( result[ 0 ] . 0 , htlcs[ 0 ] . intercept_id) ;
857
+ assert_eq!( result[ 1 ] . 0 , htlcs[ 1 ] . intercept_id) ;
858
+ assert_eq!( result[ 2 ] . 0 , htlcs[ 2 ] . intercept_id) ;
859
+ assert!( result[ 0 ] . 1 <= o_0) ;
860
+ assert!( result[ 1 ] . 1 <= o_1) ;
861
+ assert!( result[ 2 ] . 1 <= o_2) ;
862
+
863
+ let result_sum = result. iter( ) . map( |( _, f) | f) . sum:: <u64 >( ) ;
864
+ assert!( result_sum >= total_amt_to_forward_msat) ;
865
+ let five_pct = result_sum as f32 * 0.1 ;
866
+ let fair_share_0 = ( ( o_0 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max( o_0 as f32 ) ;
867
+ assert!( result[ 0 ] . 1 as f32 <= fair_share_0 + five_pct) ;
868
+ let fair_share_1 = ( ( o_1 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max( o_1 as f32 ) ;
869
+ assert!( result[ 1 ] . 1 as f32 <= fair_share_1 + five_pct) ;
870
+ let fair_share_2 = ( ( o_2 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max( o_2 as f32 ) ;
871
+ assert!( result[ 2 ] . 1 as f32 <= fair_share_2 + five_pct) ;
872
+ }
841
873
842
- assert_eq ! ( result[ 2 ] . 0 , htlcs[ 2 ] . intercept_id) ;
843
- assert_eq ! ( result[ 2 ] . 1 , 2499 ) ;
874
+ }
844
875
}
845
876
}
0 commit comments