/++
This module contains summation algorithms.

License: $(HTTP www.apache.org/licenses/LICENSE-2.0, Apache-2.0)

Authors: Ilia Ki

Copyright: 2020 Ilia Ki, Kaleidic Associates Advisory Limited, Symmetry Investments
+/
module mir.math.sum;

///
version(mir_test)
unittest
{
    import mir.ndslice.topology: map;
    import mir.ndslice.slice: sliced;

    // The two huge terms cancel each other; the compensated and exact
    // algorithms still recover the contribution of the two small terms.
    auto scaled = [1, 1e100, 1, -1e100].sliced.map!"a * 10_000";
    enum expected = 20_000;
    assert(scaled.sum!"kbn" == expected);
    assert(scaled.sum!"kb2" == expected);
    assert(scaled.sum!"precise" == expected);
    assert(scaled.sum!"decimal" == expected);
}

/// Decimal precise summation
version(mir_test)
unittest
{
    auto pair = [777.7, -777];
    assert(pair.sum!"decimal" == 0.7);
    assert(sum!"decimal"(777.7, -777) == 0.7);

    // Binary algorithms return the sum of the binary operands, 0.7000000000000455,
    // whereas decimal summation returns the decimal answer 0.7.
    enum binaryResult = 0.7000000000000455;
    assert(pair[0] + pair[1] == binaryResult);
    assert(pair.sum!"fast" == binaryResult);
    assert(pair.sum!"kahan" == binaryResult);
    assert(pair.sum!"kbn" == binaryResult);
    assert(pair.sum!"kb2" == binaryResult);
    assert(pair.sum!"precise" == binaryResult);

    assert([1e-20, 1].sum!"decimal" == 1);
}

///
version(mir_test)
unittest
{
    import mir.math.common;
    import mir.ndslice.concatenation: concatenation;
    import mir.ndslice.slice: sliced, slicedField;
    import mir.ndslice.topology: iota, map, retro;

    // Telescoping series: mathematically the terms add up to 1.7^^1000 - 1.
    auto increments = iota(1000).map!(n => 1.7L.pow(n + 1) - 1.7L.pow(n));
    real total = 1.7L.pow(1000);
    assert(sum!"precise"(concatenation(increments, [-total].sliced).slicedField) == -1);
    assert(sum!"precise"(increments.retro, -total) == -1);
}

/++
`Naive`, `Pairwise` and `Kahan` algorithms can be used for user defined types.
+/
version(mir_test)
unittest
{
    import mir.internal.utility: isFloatingPoint;

    static struct Quaternion(F)
        if (isFloatingPoint!F)
    {
        F[4] rijk;

        /// Component-wise `+` and `-`.
        Quaternion opBinary(string op)(auto ref const Quaternion rhs) const
            if (op == "+" || op == "-")
        {
            Quaternion result;
            mixin("result.rijk[] = rijk[] " ~ op ~ " rhs.rijk[];");
            return result;
        }

        /// Component-wise `+=` and `-=`.
        Quaternion opOpAssign(string op)(auto ref const Quaternion rhs)
            if (op == "+" || op == "-")
        {
            mixin("rijk[] " ~ op ~ "= rhs.rijk[];");
            return this;
        }

        /// Constructs a quaternion whose four components are all `f`.
        this(F f)
        {
            rijk[] = f;
        }

        /// Assigns `f` to all four components.
        void opAssign(F f)
        {
            rijk[] = f;
        }
    }

    Quaternion!double first, second, expected;
    first.rijk = [3, 4, 5, 9];
    second.rijk = [0, 1, 2, 4];
    expected.rijk = [3, 5, 7, 13];

    assert(expected == [first, second].sum!"naive");
    assert(expected == [first, second].sum!"pairwise");
    assert(expected == [first, second].sum!"kahan");
}

/++
All summation algorithms available for complex numbers.
+/
version(mir_test)
unittest
{
    import mir.complex: Complex;

    auto ar = [Complex!double(1.0, 2), Complex!double(2.0, 3), Complex!double(3.0, 4), Complex!double(4.0, 5)];
    Complex!double r = Complex!double(10.0, 14);
    assert(r == ar.sum!"fast");
    assert(r == ar.sum!"naive");
    assert(r == ar.sum!"pairwise");
    assert(r == ar.sum!"kahan");
    version(LDC) // DMD Internal error: backend/cgxmm.c 628
    {
        assert(r == ar.sum!"kbn");
        assert(r == ar.sum!"kb2");
    }
    assert(r == ar.sum!"precise");
    assert(r == ar.sum!"decimal");
}

///
version(mir_test)
@safe pure nothrow unittest
{
    import mir.ndslice.topology: repeat, iota;

    //simple integral summation
    assert(sum([ 1, 2, 3, 4]) == 10);

    //with initial value
    assert(sum([ 1, 2, 3, 4], 5) == 15);

    //with integral promotion
    assert(sum([false, true, true, false, true]) == 3);
    assert(sum(ubyte.max.repeat(100)) == 25_500);

    //The result may overflow
    assert(uint.max.repeat(3).sum == 4_294_967_293U );
    //But a seed can be used to change the summation primitive
    assert(uint.max.repeat(3).sum(ulong.init) == 12_884_901_885UL);

    //Floating point summation
    assert(sum([1.0, 2.0, 3.0, 4.0]) == 10);

    //Type overriding
    static assert(is(typeof(sum!double([1F, 2F, 3F, 4F])) == double));
    static assert(is(typeof(sum!double([1F, 2F, 3F, 4F], 5F)) == double));
    assert(sum([1F, 2, 3, 4]) == 10);
    assert(sum([1F, 2, 3, 4], 5F) == 15);

    //Force pair-wise floating point summation on large integers
    import mir.math : approxEqual;
    assert(iota!long([4096], uint.max / 2).sum(0.0)
               .approxEqual((uint.max / 2) * 4096.0 + 4096.0 * 4096.0 / 2));
}

/// Precise summation
version(mir_test)
nothrow @nogc unittest
{
    import mir.ndslice.topology: iota, map;
    import core.stdc.tgmath: pow;
    assert(iota(1000).map!(n => 1.7L.pow(real(n)+1) - 1.7L.pow(real(n)))
                .sum!"precise" == -1 + 1.7L.pow(1000.0L));
}

/// Precise summation with output range
version(mir_test)
nothrow @nogc unittest
{
    import mir.ndslice.topology: iota, map;
    import mir.math.common;
    auto r = iota(1000).map!(n => 1.7L.pow(n+1) - 1.7L.pow(n));
    Summator!(real, Summation.precise) s = 0.0;
    s.put(r);
    s -= 1.7L.pow(1000);
    assert(s.sum == -1);
}

/// Precise summation: overflow handling and conversion to a wider type
version(mir_test)
nothrow @nogc unittest
{
    import mir.math.common;
    float  M = 2.0f ^^ (float.max_exp-1);
    double N = 2.0  ^^ (float.max_exp-1);
    auto s = Summator!(float, Summation.precise)(0);
    s += M;
    s += M;
    assert(float.infinity == s.sum); //infinity
    auto e = cast(Summator!(double, Summation.precise)) s;
    assert(e.sum < double.infinity);
    assert(N+N == e.sum()); //finite number
}

/// Moving mean
version(mir_test)
@safe pure nothrow @nogc
unittest
{
    import mir.internal.utility: isFloatingPoint;
    import mir.math.sum;
    import mir.ndslice.topology: linspace;
    import mir.rc.array: rcarray;

    struct MovingAverage(T)
        if (isFloatingPoint!T)
    {
        import mir.math.stat: MeanAccumulator;

        MeanAccumulator!(T, Summation.precise) meanAccumulator;
        // The window elements; typed as `T` so that `put(T)` can swap with them.
        T[] circularBuffer;
        size_t frontIndex;

        @disable this(this);

        auto avg() @property const
        {
            return meanAccumulator.mean;
        }

        this(T[] buffer)
        {
            assert(buffer.length);
            circularBuffer = buffer;
            meanAccumulator.put(buffer);
        }

        ///operation without rounding
        void put(T x)
        {
            import mir.utility: swap;
            meanAccumulator.summator += x;
            // `x` receives the oldest element, which is then removed from the sum.
            swap(circularBuffer[frontIndex++], x);
            frontIndex = frontIndex == circularBuffer.length ? 0 : frontIndex;
            meanAccumulator.summator -= x;
        }
    }

    // ma always keeps precise average of last 1000 elements
    auto x = linspace!double([1000], [0.0, 999]).rcarray;
    auto ma = MovingAverage!double(x[]);
    assert(ma.avg == (1000 * 999 / 2) / 1000.0);
    // move by 10 elements
    foreach(e; linspace!double([10], [1000.0, 1009.0]))
        ma.put(e);
    assert(ma.avg == (1010 * 1009 / 2 - 10 * 9 / 2) / 1000.0);
}

/// Arbitrary sum
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.complex;
    alias C = Complex!double;
    assert(sum(1, 2, 3, 4) == 10);
    assert(sum!float(1, 2, 3, 4) == 10f);
    assert(sum(1f, 2, 3, 4) == 10f);
    assert(sum(C(1.0, 2), C(2, 3), C(3, 4), C(4, 5)) == C(10, 14));
}

version(X86)
    version = X86_Any;
version(X86_64)
    version = X86_Any;

/++
SIMD Vectors
Bugs: ICE 1662 (dmd only)
+/
version(LDC)
version(X86_Any)
version(mir_test)
unittest
{
    import core.simd;
    import std.meta : AliasSeq;
    double2 a = 1, b = 2, c = 3, d = 6;
    with(Summation)
    {
        foreach (algo; AliasSeq!(naive, fast, pairwise, kahan))
        {
            assert([a, b, c].sum!algo.array == d.array);
            assert([a, b].sum!algo(c).array == d.array);
        }
    }
}

import std.traits;
private alias AliasSeq(T...) = T;
import mir.internal.utility: Iota, isComplex;
import mir.math.common: fabs;

// Special-value predicates for floating-point values (and types providing
// `fabs` and an `infinity` property); NaN is the only value unequal to itself.
private alias isNaN = x => x != x;
private alias isFinite = x => x.fabs < x.infinity;
private alias isInfinity = x => x.fabs == x.infinity;


// Expands to the sequence `n, n / 2, n / 4, ..., 1` (empty for `n == 0`);
// used to unroll the blocked pairwise reduction in `Summator.put`.
private template chainSeq(size_t n)
{
    static if (n)
        alias chainSeq = AliasSeq!(n, chainSeq!(n / 2));
    else
        alias chainSeq = AliasSeq!();
}

/++
Summation algorithms.
+/
enum Summation
{
    /++
    Performs `pairwise` summation for floating point based types and `fast` summation for integral based types.
    +/
    appropriate,

    /++
    $(WEB en.wikipedia.org/wiki/Pairwise_summation, Pairwise summation) algorithm.
    +/
    pairwise,

    /++
    Precise summation algorithm.
    The value of the sum is rounded to the nearest representable
    floating-point number using the $(LUCKY round-half-to-even rule).
    On 32-bit `x86` the result can differ from the correctly rounded value by one ULP:
    `nextDown(exact) <= result && result <= nextUp(exact)`.
    The current implementation re-establishes special-value semantics across iterations (i.e. it handles ±inf).

    References: $(LINK2 http://www.cs.cmu.edu/afs/cs/project/quake/public/papers/robust-arithmetic.ps,
        "Adaptive Precision Floating-Point Arithmetic and Fast Robust Geometric Predicates", Jonathan Richard Shewchuk),
        $(LINK2 http://bugs.python.org/file10357/msum4.py, Mark Dickinson's post at bugs.python.org).
    +/

    /+
    Precise summation function as msum() by Raymond Hettinger in
    <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/393090>,
    enhanced with the exact partials sum and roundoff from Mark
    Dickinson's post at <http://bugs.python.org/file10357/msum4.py>.
    See those links for more details, proofs and other references.
    IEEE 754R floating point semantics are assumed.
    +/
    precise,

    /++
    Precise decimal summation algorithm.

    Each element is converted to the shortest decimal representation that converts back to the same floating-point number.
    The resulting decimal elements are summed without rounding.
    The decimal sum is converted back to a binary floating point representation using the round-half-to-even rule.

    See_also: The $(HTTPS github.com/ulfjack/ryu, Ryu algorithm)
    +/
    decimal,

    /++
    $(WEB en.wikipedia.org/wiki/Kahan_summation, Kahan summation) algorithm.
    +/
    /+
    ---------------------
    s := x[1]
    c := 0
    FOR k := 2 TO n DO
    y := x[k] - c
    t := s + y
    c := (t - s) - y
    s := t
    END DO
    ---------------------
    +/
    kahan,

    /++
    $(LUCKY Kahan-Babuška-Neumaier summation algorithm). `KBN` gives more accurate results than `Kahan`.
    +/
    /+
    ---------------------
    s := x[1]
    c := 0
    FOR i := 2 TO n DO
    t := s + x[i]
    IF ABS(s) >= ABS(x[i]) THEN
        c := c + ((s-t)+x[i])
    ELSE
        c := c + ((x[i]-t)+s)
    END IF
    s := t
    END DO
    s := s + c
    ---------------------
    +/
    kbn,

    /++
    $(LUCKY Generalized Kahan-Babuška summation algorithm), order 2. `KB2` gives more accurate results than `Kahan` and `KBN`.
    +/
    /+
    ---------------------
    s := 0 ; cs := 0 ; ccs := 0
    FOR i := 1 TO n DO
        t := s + x[i]
        IF ABS(s) >= ABS(x[i]) THEN
            c := (s-t) + x[i]
        ELSE
            c := (x[i]-t) + s
        END IF
        s := t
        t := cs + c
        IF ABS(cs) >= ABS(c) THEN
            cc := (cs-t) + c
        ELSE
            cc := (c-t) + cs
        END IF
        cs := t
        ccs := ccs + cc
    END FOR
    RETURN s+cs+ccs
    ---------------------
    +/
    kb2,

    /++
    Naive algorithm (one by one).
    +/
    naive,

    /++
    SIMD optimized summation algorithm.
    +/
    fast,
}

/++
Output range for summation.

Accumulates values of type `T` using the algorithm selected by `summation`;
the accumulated value is returned by `sum`.
+/
struct Summator(T, Summation summation)
    if (isMutable!T)
{
    import mir.internal.utility: isComplex;
    static if (is(T == class) || is(T == interface) || hasElaborateAssign!T && !isComplex!T)
        static assert (summation == Summation.naive,
            "Classes, interfaces, and structures with "
            ~ "elaborate constructor support only naive summation.");

    // `fast` summation applies LDC's `fastmath` attribute to the members below;
    // other compilers get no attribute.
    static if (summation == Summation.fast)
    {
        version (LDC)
        {
            import ldc.attributes: fastmath;
            alias attr = fastmath;
        }
        else
        {
            alias attr = AliasSeq!();
        }
    }
    else
    {
        alias attr = AliasSeq!();
    }

    @attr:

    static if (summation == Summation.pairwise)
    {
        // Element types for which `put(Range)` uses the unrolled, blocked
        // pairwise reduction over dynamic arrays.
        private enum bool fastPairwise =
            is(F == float) ||
            is(F == double) ||
            (isComplex!F && F.sizeof <= 16) ||
            is(F : __vector(W[N]), W, size_t N);
    }

    alias F = T;

    static if (summation == Summation.precise)
    {
        import std.internal.scopebuffer;
        import mir.appender;
        import mir.math.ieee: signbit;
    private:
        enum F M = (cast(F)(2)) ^^ (T.max_exp - 1);
        auto partials = scopedBuffer!(F, 8 * T.sizeof);
        //sum for NaN and infinity.
        F s = summationInitValue!F;
        //Overflow Degree. Count of 2^^F.max_exp minus count of -(2^^F.max_exp)
        sizediff_t o;


        /++
        Compute the sum of a list of nonoverlapping floats.
        On input, partials is a list of nonzero, nonspecial,
        nonoverlapping floats, strictly increasing in magnitude, but
        possibly not all having the same sign.
        On output, the sum of partials gives the error in the returned
        result, which is correctly rounded (using the round-half-to-even
        rule).
        Two floating point values x and y are non-overlapping if the least significant nonzero
        bit of x is more significant than the most significant nonzero bit of y, or vice-versa.
        +/
        static F partialsReduce(F s, in F[] partials)
        in
        {
            debug(numeric) assert(!partials.length || .isFinite(s));
        }
        do
        {
            bool _break;
            foreach_reverse (i, y; partials)
            {
                s = partialsReducePred(s, y, i ? partials[i-1] : 0, _break);
                if (_break)
                    break;
                debug(numeric) assert(.isFinite(s));
            }
            return s;
        }

        /++
        One reduction step of `partialsReduce`: adds `y` to `s`.
        `z` is the next (smaller) partial or `0`; it is used for half-even
        rounding across partials. Sets `_break` once the rounding error is nonzero,
        i.e. when the remaining partials can no longer change the rounded result.
        +/
        static F partialsReducePred(F s, F y, F z, out bool _break)
        out(result)
        {
            debug(numeric) assert(.isFinite(result));
        }
        do
        {
            F x = s;
            s = x + y;
            F d = s - x;
            F l = y - d;
            debug(numeric)
            {
                assert(.isFinite(x));
                assert(.isFinite(y));
                assert(.isFinite(s));
                assert(fabs(y) < fabs(x));
            }
            if (l)
            {
                //Make half-even rounding work across multiple partials.
                //Needed so that sum([1e-16, 1, 1e16]) will round-up the last
                //digit to two instead of down to zero (the 1e-16 makes the 1
                //slightly closer to two). Can guarantee commutativity.
                if (z && !signbit(l * z))
                {
                    l *= 2;
                    x = s + l;
                    F t = x - s;
                    if (l == t)
                        s = x;
                }
                _break = true;
            }
            return s;
        }

        //Returns corresponding infinity if is overflow and 0 otherwise.
        F overflow()() const
        {
            if (o == 0)
                return 0;
            if (partials.length && (o == -1 || o == 1) && signbit(o * partials.data[$-1]))
            {
                // problem case: decide whether result is representable
                F x = o * M;
                F y = partials.data[$-1] / 2;
                F h = x + y;
                F d = h - x;
                F l = (y - d) * 2;
                y = h * 2;
                d = h + l;
                F t = d - h;
                version(X86)
                {
                    if (!.isInfinity(cast(T)y) || !.isInfinity(sum()))
                        return 0;
                }
                else
                {
                    if (!.isInfinity(cast(T)y) ||
                        ((partials.length > 1 && !signbit(l * partials.data[$-2])) && t == l))
                        return 0;
                }
            }
            return F.infinity * o;
        }
    }
    else
    static if (summation == Summation.kb2)
    {
        F s = summationInitValue!F;
        F cs = summationInitValue!F;
        F ccs = summationInitValue!F;
    }
    else
    static if (summation == Summation.kbn)
    {
        F s = summationInitValue!F;
        F c = summationInitValue!F;
    }
    else
    static if (summation == Summation.kahan)
    {
        F s = summationInitValue!F;
        F c = summationInitValue!F;
        F y = summationInitValue!F; // do not declare in the loop/put (algo can be used for matrixes and etc)
        F t = summationInitValue!F; // ditto
    }
    else
    static if (summation == Summation.pairwise)
    {
        package size_t counter;
        size_t index;
        static if (fastPairwise)
            enum registersCount = 16;
        // partials[0 .. index] is the stack of pending partial sums.
        F[size_t.sizeof * 8] partials;
    }
    else
    static if (summation == Summation.naive)
    {
        F s = summationInitValue!F;
    }
    else
    static if (summation == Summation.fast)
    {
        F s = summationInitValue!F;
    }
    else
    static if (summation == Summation.decimal)
    {
        import mir.bignum.decimal;
        // exact decimal sum of the finite elements
        Decimal!128 s;
        // binary sum of the non-finite elements (NaN, ±inf)
        T ss = 0;
    }
    else
    static assert(0, "Unsupported summation type for std.numeric.Summator.");


public:

    /// Initializes the summator with the initial value `n`.
    this()(T n)
    {
        static if (summation == Summation.precise)
        {
            s = 0.0;
            o = 0;
            if (n) put(n);
        }
        else
        static if (summation == Summation.kb2)
        {
            s = n;
            static if (isComplex!T)
            {
                cs = T(0, 0f);
                ccs = T(0, 0f);
            }
            else
            {
                cs = 0.0;
                ccs = 0.0;
            }
        }
        else
        static if (summation == Summation.kbn)
        {
            s = n;
            static if (isComplex!T)
                c = T(0, 0f);
            else
                c = 0.0;
        }
        else
        static if (summation == Summation.kahan)
        {
            s = n;
            static if (isComplex!T)
                c = T(0, 0f);
            else
                c = 0.0;
        }
        else
        static if (summation == Summation.pairwise)
        {
            counter = index = 1;
            partials[0] = n;
        }
        else
        static if (summation == Summation.naive)
        {
            s = n;
        }
        else
        static if (summation == Summation.fast)
        {
            s = n;
        }
        else
        static if (summation == Summation.decimal)
        {
            ss = 0;
            // non-finite values bypass the decimal accumulator
            if (!(-n.infinity < n && n < n.infinity))
            {
                ss = n;
                n = 0;
            }
            s = n;
        }
        else
        static assert(0);
    }

    ///Adds `n` to the internal partial sums.
    void put(N)(N n)
        if (__traits(compiles, {T a = n; a = n; a += n;}))
    {
        static if (isCompesatorAlgorithm!summation)
            F x = n;
        static if (summation == Summation.precise)
        {
            if (.isFinite(x))
            {
                size_t i;
                auto partials_data = partials.data;
                foreach (y; partials_data[])
                {
                    F h = x + y;
                    if (.isInfinity(cast(T)h))
                    {
                        if (fabs(x) < fabs(y))
                        {
                            F t = x; x = y; y = t;
                        }
                        //h == -F.infinity
                        if (signbit(h))
                        {
                            x += M;
                            x += M;
                            o--;
                        }
                        //h == +F.infinity
                        else
                        {
                            x -= M;
                            x -= M;
                            o++;
                        }
                        debug(numeric) assert(x.isFinite);
                        h = x + y;
                    }
                    debug(numeric) assert(h.isFinite);
                    F l;
                    if (fabs(x) < fabs(y))
                    {
                        F t = h - y;
                        l = x - t;
                    }
                    else
                    {
                        F t = h - x;
                        l = y - t;
                    }
                    debug(numeric) assert(l.isFinite);
                    if (l)
                    {
                        partials_data[i++] = l;
                    }
                    x = h;
                }
                partials.shrinkTo(i);
                if (x)
                {
                    partials.put(x);
                }
            }
            else
            {
                s += x;
            }
        }
        else
        static if (summation == Summation.kb2)
        {
            static if (isFloatingPoint!F)
            {
                F t = s + x;
                F c = 0;
                if (fabs(s) >= fabs(x))
                {
                    F d = s - t;
                    c = d + x;
                }
                else
                {
                    F d = x - t;
                    c = d + s;
                }
                s = t;
                t = cs + c;
                if (fabs(cs) >= fabs(c))
                {
                    F d = cs - t;
                    d += c;
                    ccs += d;
                }
                else
                {
                    F d = c - t;
                    d += cs;
                    ccs += d;
                }
                cs = t;
            }
            else
            {
                // Complex values: the branch selection is applied to the real and
                // imaginary components independently by swapping components so that
                // `s` (resp. `cs`) holds the larger magnitude.
                F t = s + x;
                if (fabs(s.re) < fabs(x.re))
                {
                    auto s_re = s.re;
                    auto x_re = x.re;
                    s = F(x_re, s.im);
                    x = F(s_re, x.im);
                }
                if (fabs(s.im) < fabs(x.im))
                {
                    auto s_im = s.im;
                    auto x_im = x.im;
                    s = F(s.re, x_im);
                    x = F(x.re, s_im);
                }
                F c = (s-t)+x;
                s = t;
                // Second order: `t` must now be the new candidate value of `cs`,
                // exactly as in the real-valued branch above.
                t = cs + c;
                if (fabs(cs.re) < fabs(c.re))
                {
                    auto c_re = c.re;
                    auto cs_re = cs.re;
                    c = F(cs_re, c.im);
                    cs = F(c_re, cs.im);
                }
                if (fabs(cs.im) < fabs(c.im))
                {
                    auto c_im = c.im;
                    auto cs_im = cs.im;
                    c = F(c.re, cs_im);
                    cs = F(cs.re, c_im);
                }
                F d = cs - t;
                d += c;
                ccs += d;
                cs = t;
            }
        }
        else
        static if (summation == Summation.kbn)
        {
            static if (isFloatingPoint!F)
            {
                F t = s + x;
                if (fabs(s) >= fabs(x))
                {
                    F d = s - t;
                    d += x;
                    c += d;
                }
                else
                {
                    F d = x - t;
                    d += s;
                    c += d;
                }
                s = t;
            }
            else
            {
                F t = s + x;
                if (fabs(s.re) < fabs(x.re))
                {
                    auto s_re = s.re;
                    auto x_re = x.re;
                    s = F(x_re, s.im);
                    x = F(s_re, x.im);
                }
                if (fabs(s.im) < fabs(x.im))
                {
                    auto s_im = s.im;
                    auto x_im = x.im;
                    s = F(s.re, x_im);
                    x = F(x.re, s_im);
                }
                F d = s - t;
                d += x;
                c += d;
                s = t;
            }
        }
        else
        static if (summation == Summation.kahan)
        {
            y = x - c;
            t = s + y;
            c = t - s;
            c -= y;
            s = t;
        }
        else
        static if (summation == Summation.pairwise)
        {
            import mir.bitop: cttz;
            ++counter;
            partials[index] = n;
            // merge as many pending partials as there are trailing zero bits in `counter`
            foreach (_; 0 .. cttz(counter))
            {
                immutable newIndex = index - 1;
                partials[newIndex] += partials[index];
                index = newIndex;
            }
            ++index;
        }
        else
        static if (summation == Summation.naive)
        {
            s += n;
        }
        else
        static if (summation == Summation.fast)
        {
            s += n;
        }
        else
        static if (summation == Summation.decimal)
        {
            import mir.bignum.internal.ryu.generic_128: genericBinaryToDecimal;
            if (-n.infinity < n && n < n.infinity)
            {
                auto decimal = genericBinaryToDecimal(n);
                s += decimal;
            }
            else
            {
                ss += n;
            }
        }
        else
        static assert(0);
    }

    ///ditto
    void put(Range)(Range r)
        if (isIterable!Range && !is(Range : __vector(V[N]), V, size_t N))
    {
        static if (summation == Summation.pairwise && fastPairwise && isDynamicArray!Range)
        {
            F[registersCount] v;
            foreach (i, n; chainSeq!registersCount)
            {
                if (r.length >= n * 2) do
                {
                    foreach (j; Iota!n)
                        v[j] = cast(F) r[j];
                    foreach (j; Iota!n)
                        v[j] += cast(F) r[n + j];
                    foreach (m; chainSeq!(n / 2))
                        foreach (j; Iota!m)
                            v[j] += v[m + j];
                    put(v[0]);
                    r = r[n * 2 .. $];
                }
                while (!i && r.length >= n * 2);
            }
            if (r.length)
            {
                put(cast(F) r[0]);
                r = r[1 .. $];
            }
            assert(r.length == 0);
        }
        else
        static if (summation == Summation.fast)
        {
            static if (isComplex!F)
                F s0 = F(0, 0f);
            else
                F s0 = 0;
            foreach (ref elem; r)
                s0 += elem;
            s += s0;
        }
        else
        {
            foreach (ref elem; r)
                put(elem);
        }
    }

    import mir.ndslice.slice;

    /// ditto
    void put(Range: Slice!(Iterator, N, kind), Iterator, size_t N, SliceKind kind)(Range r)
    {
        static if (N > 1 && kind == Contiguous)
        {
            import mir.ndslice.topology: flattened;
            this.put(r.flattened);
        }
        else
        static if (isPointer!Iterator && kind == Contiguous)
        {
            this.put(r.field);
        }
        else
        static if (summation == Summation.fast && N == 1)
        {
            static if (isComplex!F)
                F s0 = F(0, 0f);
            else
                F s0 = 0;
            import mir.algorithm.iteration: reduce;
            s0 = s0.reduce!"a + b"(r);
            s += s0;
        }
        else
        {
            foreach(elem; r)
                this.put(elem);
        }
    }

    /++
    Returns: the value of the sum.

    For `Summation.precise` the value is rounded to the nearest representable
    floating-point number using the round-half-to-even rule. On `X86` the result can
    differ from the correctly rounded value by one ULP:
    `nextDown(exact) <= result && result <= nextUp(exact)`.
    +/
    T sum()() scope const
    {
        static if (summation == Summation.precise)
        {
            debug(mir_sum)
            {
                foreach (y; partials.data[])
                {
                    assert(y);
                    assert(y.isFinite);
                }
                //TODO: Add Non-Overlapping check to std.math
                import mir.ndslice.slice: sliced;
                import mir.ndslice.sorting: isSorted;
                import mir.ndslice.topology: map;
                assert(partials.data[].sliced.map!fabs.isSorted);
            }

            if (s)
                return s;
            auto parts = partials.data[];
            F y = 0.0;
            //pick last
            if (parts.length)
            {
                y = parts[$-1];
                parts = parts[0..$-1];
            }
            if (o)
            {
                immutable F of = o;
                if (y && (o == -1 || o == 1) && signbit(of * y))
                {
                    // problem case: decide whether result is representable
                    y /= 2;
                    F x = of * M;
                    immutable F h = x + y;
                    F t = h - x;
                    F l = (y - t) * 2;
                    y = h * 2;
                    if (.isInfinity(cast(T)y))
                    {
                        // overflow, except in edge case...
                        x = h + l;
                        t = x - h;
                        y = parts.length && t == l && !signbit(l*parts[$-1]) ?
                            x * 2 :
                            F.infinity * of;
                        parts = null;
                    }
                    else if (l)
                    {
                        bool _break;
                        y = partialsReducePred(y, l, parts.length ? parts[$-1] : 0, _break);
                        if (_break)
                            parts = null;
                    }
                }
                else
                {
                    y = F.infinity * of;
                    parts = null;
                }
            }
            return partialsReduce(y, parts);
        }
        else
        static if (summation == Summation.kb2)
        {
            return s + (cs + ccs);
        }
        else
        static if (summation == Summation.kbn)
        {
            return s + c;
        }
        else
        static if (summation == Summation.kahan)
        {
            return s;
        }
        else
        static if (summation == Summation.pairwise)
        {
            F s = summationInitValue!T;
            assert((counter == 0) == (index == 0));
            foreach_reverse (ref e; partials[0 .. index])
            {
                static if (is(F : __vector(W[N]), W, size_t N))
                    s += cast(Unqual!F) e; //DMD bug workaround
                else
                    s += e;
            }
            return s;
        }
        else
        static if (summation == Summation.naive)
        {
            return s;
        }
        else
        static if (summation == Summation.fast)
        {
            return s;
        }
        else
        static if (summation == Summation.decimal)
        {
            return cast(T) s + ss;
        }
        else
        static assert(0);
    }

    /++
    Returns a `Summator` of the same algorithm over the wider floating-point
    type `P`, carrying over the internal partial sums.
    +/
    C opCast(C : Summator!(P, _summation), P, Summation _summation)() const
        if (
            _summation == summation &&
            isMutable!C &&
            P.max_exp >= T.max_exp &&
            P.mant_dig >= T.mant_dig
        )
    {
        static if (is(P == T))
            return this;
        else
        static if (summation == Summation.precise)
        {
            auto ret = typeof(return).init;
            ret.s = s;
            foreach (p; partials.data[])
            {
                ret.put(p);
            }
            static if (P.max_exp == T.max_exp)
            {
                // Both types overflow at the same power of two: the degree carries over.
                ret.o = o;
            }
            else
            {
                // `o` counts multiples of `2^^T.max_exp`, which are finite in `P`.
                // Add their exact value to the partials of `ret` instead of
                // rescaling the degree (a power of two times an integer is exact).
                ret.o = 0;
                if (o)
                    ret.put((P(2) ^^ T.max_exp) * o);
            }
            return ret;
        }
        else
        static if (summation == Summation.kb2)
        {
            auto ret = typeof(return).init;
            ret.s = s;
            ret.cs = cs;
            ret.ccs = ccs;
            return ret;
        }
        else
        static if (summation == Summation.kbn)
        {
            auto ret = typeof(return).init;
            ret.s = s;
            ret.c = c;
            return ret;
        }
        else
        static if (summation == Summation.kahan)
        {
            auto ret = typeof(return).init;
            ret.s = s;
            ret.c = c;
            return ret;
        }
        else
        static if (summation == Summation.pairwise)
        {
            auto ret = typeof(return).init;
            ret.counter = counter;
            ret.index = index;
            foreach (i; 0 .. index)
                ret.partials[i] = partials[i];
            return ret;
        }
        else
        static if (summation == Summation.naive)
        {
            auto ret = typeof(return).init;
            ret.s = s;
            return ret;
        }
        else
        static if (summation == Summation.fast)
        {
            auto ret = typeof(return).init;
            ret.s = s;
            return ret;
        }
        else
        static assert(0);
    }

    /++
    `cast(C)` operator overloading. Returns `cast(C)sum()`.
    See also: `cast`
    +/
    C opCast(C)() const if (is(Unqual!C == T))
    {
        return cast(C)sum();
    }

    ///Operator overloading.
    // opAssign should initialize partials.
    void opAssign(T rhs)
    {
        static if (summation == Summation.precise)
        {
            partials.reset;
            s = 0.0;
            o = 0;
            if (rhs) put(rhs);
        }
        else
        static if (summation == Summation.kb2)
        {
            s = rhs;
            static if (isComplex!T)
            {
                cs = T(0, 0f);
                ccs = T(0.0, 0f);
            }
            else
            {
                cs = 0.0;
                ccs = 0.0;
            }
        }
        else
        static if (summation == Summation.kbn)
        {
            s = rhs;
            static if (isComplex!T)
                c = T(0, 0f);
            else
                c = 0.0;
        }
        else
        static if (summation == Summation.kahan)
        {
            s = rhs;
            static if (isComplex!T)
                c = T(0, 0f);
            else
                c = 0.0;
        }
        else
        static if (summation == Summation.pairwise)
        {
            counter = 1;
            index = 1;
            partials[0] = rhs;
        }
        else
        static if (summation == Summation.naive)
        {
            s = rhs;
        }
        else
        static if (summation == Summation.fast)
        {
            s = rhs;
        }
        else
        static if (summation == Summation.decimal)
        {
            __ctor(rhs);
        }
        else
        static assert(0);
    }

    ///ditto
    void opOpAssign(string op : "+")(T rhs)
    {
        put(rhs);
    }

    ///ditto
    void opOpAssign(string op : "+")(ref const Summator rhs)
    {
        static if (summation == Summation.precise)
        {
            s += rhs.s;
            o += rhs.o;
            foreach (f; rhs.partials.data[])
                put(f);
        }
        else
        static if (summation == Summation.kb2)
        {
            put(rhs.ccs);
            put(rhs.cs);
            put(rhs.s);
        }
        else
        static if (summation == Summation.kbn)
        {
            put(rhs.c);
            put(rhs.s);
        }
        else
        static if (summation == Summation.kahan)
        {
            put(rhs.s);
        }
        else
        static if (summation == Summation.pairwise)
        {
            foreach_reverse (e; rhs.partials[0 .. rhs.index])
                put(e);
            counter -= rhs.index;
            counter += rhs.counter;
        }
        else
        static if (summation == Summation.naive)
        {
            put(rhs.s);
        }
        else
        static if (summation == Summation.fast)
        {
            put(rhs.s);
        }
        else
        static assert(0);
    }

    ///ditto
    void opOpAssign(string op : "-")(T rhs)
    {
        static if (summation == Summation.precise)
        {
            put(-rhs);
        }
        else
        static if (summation == Summation.kb2)
        {
            put(-rhs);
        }
        else
        static if (summation == Summation.kbn)
        {
            put(-rhs);
        }
        else
        static if (summation == Summation.kahan)
        {
            y = 0.0;
            y -= rhs;
            y -= c;
            t = s + y;
            c = t - s;
            c -= y;
            s = t;
        }
        else
        static if (summation == Summation.pairwise)
        {
            put(-rhs);
        }
        else
        static if (summation == Summation.naive)
        {
            s -= rhs;
        }
        else
        static if (summation == Summation.fast)
        {
            s -= rhs;
        }
        else
        static assert(0);
    }

    ///ditto
    void opOpAssign(string op : "-")(ref const Summator rhs)
    {
        static if (summation == Summation.precise)
        {
            s -= rhs.s;
            o -= rhs.o;
            foreach (f; rhs.partials.data[])
                put(-f);
        }
        else
        static if (summation == Summation.kb2)
        {
            put(-rhs.ccs);
            put(-rhs.cs);
            put(-rhs.s);
        }
        else
        static if (summation == Summation.kbn)
        {
            put(-rhs.c);
            put(-rhs.s);
        }
        else
        static if (summation == Summation.kahan)
        {
            this -= rhs.s;
        }
        else
        static if (summation == Summation.pairwise)
        {
            foreach_reverse (e; rhs.partials[0 .. rhs.index])
                put(-e);
            counter -= rhs.index;
            counter += rhs.counter;
        }
        else
        static if (summation == Summation.naive)
        {
            s -= rhs.s;
        }
        else
        static if (summation == Summation.fast)
        {
            s -= rhs.s;
        }
        else
        static assert(0);
    }

    ///

    version(mir_test)
    @nogc nothrow unittest
    {
        import mir.math.common;
        import mir.ndslice.topology: iota, map;
        auto r1 = iota(500).map!(a => 1.7L.pow(a+1) - 1.7L.pow(a));
        auto r2 = iota([500], 500).map!(a => 1.7L.pow(a+1) - 1.7L.pow(a));
        Summator!(real, Summation.precise) s1 = 0, s2 = 0.0;
        foreach (e; r1) s1 += e;
        foreach (e; r2) s2 -= e;
        s1 -= s2;
        s1 -= 1.7L.pow(1000);
        assert(s1.sum == -1);
    }


    version(mir_test)
    @nogc nothrow unittest
    {
        with(Summation)
        foreach (summation; AliasSeq!(kahan, kbn, kb2, precise, pairwise))
        foreach (T; AliasSeq!(float, double, real))
        {
            Summator!(T, summation) sum = 1;
            sum += 3;
            assert(sum.sum == 4);
            sum -= 10;
            assert(sum.sum == -6);
            Summator!(T, summation) sum2 = 3;
            sum -= sum2;
            assert(sum.sum == -9);
            sum2 = 100;
            sum += 100;
            assert(sum.sum == 91);
            auto sum3 = cast(Summator!(real, summation))sum;
            assert(sum3.sum == 91);
            sum = sum2;
        }
    }


    version(mir_test)
    @nogc nothrow unittest
    {
        import mir.math.common: approxEqual;
        with(Summation)
        foreach (summation; AliasSeq!(naive, fast))
        foreach (T; AliasSeq!(float, double, real))
        {
            Summator!(T, summation) sum = 1;
            sum += 3.5;
            assert(sum.sum.approxEqual(4.5));
            sum = 2;
            assert(sum.sum == 2);
            sum -= 4;
            assert(sum.sum.approxEqual(-2));
        }
    }

    static if (summation == Summation.precise)
    {
        ///Returns `true` if current sum is a NaN.
        bool isNaN()() const
        {
            return .isNaN(s);
        }

        ///Returns `true` if current sum is finite (not infinite or NaN).
        bool isFinite()() const
        {
            if (s)
                return false;
            return !overflow;
        }

        ///Returns `true` if current sum is ±∞.
        bool isInfinity()() const
        {
            return .isInfinity(s) || overflow();
        }
    }
    else static if (isFloatingPoint!F)
    {
        ///Returns `true` if current sum is a NaN.
        bool isNaN()() const
        {
            return .isNaN(sum());
        }

        ///Returns `true` if current sum is finite (not infinite or NaN).
        bool isFinite()() const
        {
            return .isFinite(sum());
        }

        ///Returns `true` if current sum is ±∞.
        bool isInfinity()() const
        {
            return .isInfinity(sum());
        }
    }
    else
    {
        //User defined types
    }
}

// Regression: for complex numbers the second-order `kb2` compensation
// must be computed from `t = cs + c`, exactly as in the real-valued branch.
version(mir_test)
unittest
{
    import mir.complex: Complex;
    alias C = Complex!double;
    version(LDC) // complex `kbn`/`kb2` hit a DMD internal error (see the complex example)
    {
        Summator!(C, Summation.kb2) s = C(0.0, 0.0);
        s.put(C(1.0, -1.0));
        s.put(C(1e100, -1e100));
        s.put(C(1.0, -1.0));
        s.put(C(-1e100, 1e100));
        assert(s.sum == C(2.0, -2.0));
    }
}

// Regression: widening a `float` precise summator whose overflow degree
// exceeds one must produce the finite `double` value.
version(mir_test)
nothrow @nogc unittest
{
    import mir.math.common;
    float M = 2.0f ^^ (float.max_exp - 1);
    double N = 2.0 ^^ (float.max_exp + 3);
    auto s = Summator!(float, Summation.precise)(0);
    foreach (_; 0 .. 16)
        s += M;
    assert(float.infinity == s.sum);
    auto e = cast(Summator!(double, Summation.precise)) s;
    assert(e.sum == N);
}

version(mir_test)
unittest
{
    import mir.functional: Tuple, tuple;
    import mir.ndslice.topology: map, iota, retro;
    import mir.array.allocation: array;
    import std.math: isInfinity, isFinite, isNaN;

    Summator!(double, Summation.precise) summator = 0.0;

    enum double M = (cast(double)2) ^^ (double.max_exp - 1);
    Tuple!(double[], double)[] tests = [
        tuple(new double[0], 0.0),
        tuple([0.0], 0.0),
        tuple([1e100, 1.0, -1e100, 1e-100, 1e50, -1, -1e50], 1e-100),
        tuple([1e308, 1e308, -1e308], 1e308),
        tuple([-1e308, 1e308, 1e308], 1e308),
        tuple([1e308, -1e308, 1e308], 1e308),
        tuple([M, M, -2.0^^1000], 1.7976930277114552e+308),
        tuple([M, M, M, M, -M, -M, -M], 8.9884656743115795e+307),
        tuple([2.0^^53, -0.5, -2.0^^-54], 2.0^^53-1.0),
        tuple([2.0^^53, 1.0, 2.0^^-100], 2.0^^53+2.0),
        tuple([2.0^^53+10.0, 1.0, 2.0^^-100], 2.0^^53+12.0),
        tuple([2.0^^53-4.0, 0.5, 2.0^^-54], 2.0^^53-3.0),
        tuple([M-2.0^^970, -1, M], 1.7976931348623157e+308),
        tuple([double.max, double.max*2.^^-54], double.max),
tuple([double.max, double.max*2.^^-53], double.infinity), 1685 tuple(iota([1000], 1).map!(a => 1.0/a).array , 7.4854708605503451), 1686 tuple(iota([1000], 1).map!(a => (-1.0)^^a/a).array, -0.69264743055982025), //0.693147180559945309417232121458176568075500134360255254120680... 1687 tuple(iota([1000], 1).map!(a => 1.0/a).retro.array , 7.4854708605503451), 1688 tuple(iota([1000], 1).map!(a => (-1.0)^^a/a).retro.array, -0.69264743055982025), 1689 tuple([double.infinity, -double.infinity, double.nan], double.nan), 1690 tuple([double.nan, double.infinity, -double.infinity], double.nan), 1691 tuple([double.infinity, double.nan, double.infinity], double.nan), 1692 tuple([double.infinity, double.infinity], double.infinity), 1693 tuple([double.infinity, -double.infinity], double.nan), 1694 tuple([-double.infinity, 1e308, 1e308, -double.infinity], -double.infinity), 1695 tuple([M-2.0^^970, 0.0, M], double.infinity), 1696 tuple([M-2.0^^970, 1.0, M], double.infinity), 1697 tuple([M, M], double.infinity), 1698 tuple([M, M, -1], double.infinity), 1699 tuple([M, M, M, M, -M, -M], double.infinity), 1700 tuple([M, M, M, M, -M, M], double.infinity), 1701 tuple([-M, -M, -M, -M], -double.infinity), 1702 tuple([M, M, -2.^^971], double.max), 1703 tuple([M, M, -2.^^970], double.infinity), 1704 tuple([-2.^^970, M, M, -0X0.0000000000001P-0 * 2.^^-1022], double.max), 1705 tuple([M, M, -2.^^970, 0X0.0000000000001P-0 * 2.^^-1022], double.infinity), 1706 tuple([-M, 2.^^971, -M], -double.max), 1707 tuple([-M, -M, 2.^^970], -double.infinity), 1708 tuple([-M, -M, 2.^^970, 0X0.0000000000001P-0 * 2.^^-1022], -double.max), 1709 tuple([-0X0.0000000000001P-0 * 2.^^-1022, -M, -M, 2.^^970], -double.infinity), 1710 tuple([2.^^930, -2.^^980, M, M, M, -M], 1.7976931348622137e+308), 1711 tuple([M, M, -1e307], 1.6976931348623159e+308), 1712 tuple([1e16, 1., 1e-16], 10_000_000_000_000_002.0), 1713 ]; 1714 foreach (i, test; tests) 1715 { 1716 summator = 0.0; 1717 foreach (t; test.a) summator.put(t); 1718 auto 
r = test.b; 1719 auto s = summator.sum; 1720 assert(summator.isNaN() == r.isNaN()); 1721 assert(summator.isFinite() == r.isFinite()); 1722 assert(summator.isInfinity() == r.isInfinity()); 1723 assert(s == r || s.isNaN && r.isNaN); 1724 } 1725 } 1726 1727 /++ 1728 Sums elements of `r`, which must be a finite 1729 iterable. 1730 1731 A seed may be passed to `sum`. Not only will this seed be used as an initial 1732 value, but its type will be used if it is not specified. 1733 1734 Note that these specialized summing algorithms execute more primitive operations 1735 than vanilla summation. Therefore, if in certain cases maximum speed is required 1736 at expense of precision, one can use $(LREF Summation.fast). 1737 1738 Returns: 1739 The sum of all the elements in the range r. 1740 +/ 1741 template sum(F, Summation summation = Summation.appropriate) 1742 if (isMutable!F) 1743 { 1744 /// 1745 template sum(Range) 1746 if (isIterable!Range && isMutable!Range) 1747 { 1748 import core.lifetime: move; 1749 1750 /// 1751 F sum(Range r) 1752 { 1753 static if (isComplex!F && (summation == Summation.precise || summation == Summation.decimal)) 1754 { 1755 return sum(r, summationInitValue!F); 1756 } 1757 else 1758 { 1759 static if (summation == Summation.decimal) 1760 { 1761 Summator!(F, summation) sum = void; 1762 sum = 0; 1763 } 1764 else 1765 { 1766 Summator!(F, ResolveSummationType!(summation, Range, sumType!Range)) sum; 1767 } 1768 sum.put(r.move); 1769 return sum.sum; 1770 } 1771 } 1772 1773 /// 1774 F sum(Range r, F seed) 1775 { 1776 static if (isComplex!F && (summation == Summation.precise || summation == Summation.decimal)) 1777 { 1778 alias T = typeof(F.init.re); 1779 static if (summation == Summation.decimal) 1780 { 1781 Summator!(T, summation) sumRe = void; 1782 sumRe = seed.re; 1783 1784 Summator!(T, summation) sumIm = void; 1785 sumIm = seed.im; 1786 } 1787 else 1788 { 1789 auto sumRe = Summator!(T, Summation.precise)(seed.re); 1790 auto sumIm = Summator!(T, 
Summation.precise)(seed.im); 1791 } 1792 import mir.ndslice.slice: isSlice; 1793 static if (isSlice!Range) 1794 { 1795 import mir.algorithm.iteration: each; 1796 r.each!((auto ref elem) 1797 { 1798 sumRe.put(elem.re); 1799 sumIm.put(elem.im); 1800 }); 1801 } 1802 else 1803 { 1804 foreach (ref elem; r) 1805 { 1806 sumRe.put(elem.re); 1807 sumIm.put(elem.im); 1808 } 1809 } 1810 return F(sumRe.sum, sumIm.sum); 1811 } 1812 else 1813 { 1814 static if (summation == Summation.decimal) 1815 { 1816 Summator!(F, summation) sum = void; 1817 sum = seed; 1818 } 1819 else 1820 { 1821 auto sum = Summator!(F, ResolveSummationType!(summation, Range, F))(seed); 1822 } 1823 sum.put(r.move); 1824 return sum.sum; 1825 } 1826 } 1827 } 1828 1829 /// 1830 template sum(Range) 1831 if (isIterable!Range && !isMutable!Range) 1832 { 1833 /// 1834 F sum(Range r) 1835 { 1836 return .sum!(F, summation)(r.lightConst); 1837 } 1838 1839 /// 1840 F sum(Range r, F seed) 1841 { 1842 return .sum!(F, summation)(r.lightConst, seed); 1843 } 1844 } 1845 1846 /// 1847 F sum(scope const F[] r...) 
1848 { 1849 static if (isComplex!F && (summation == Summation.precise || summation == Summation.decimal)) 1850 { 1851 return sum(r, summationInitValue!F); 1852 } 1853 else 1854 { 1855 Summator!(F, ResolveSummationType!(summation, const(F)[], F)) sum; 1856 sum.put(r); 1857 return sum.sum; 1858 } 1859 } 1860 } 1861 1862 ///ditto 1863 template sum(Summation summation = Summation.appropriate) 1864 { 1865 /// 1866 sumType!Range sum(Range)(Range r) 1867 if (isIterable!Range && isMutable!Range) 1868 { 1869 import core.lifetime: move; 1870 alias F = typeof(return); 1871 alias s = .sum!(F, ResolveSummationType!(summation, Range, F)); 1872 return s(r.move); 1873 } 1874 1875 /// 1876 F sum(Range, F)(Range r, F seed) 1877 if (isIterable!Range && isMutable!Range) 1878 { 1879 import core.lifetime: move; 1880 alias s = .sum!(F, ResolveSummationType!(summation, Range, F)); 1881 return s(r.move, seed); 1882 } 1883 1884 /// 1885 sumType!Range sum(Range)(Range r) 1886 if (isIterable!Range && !isMutable!Range) 1887 { 1888 return .sum!(typeof(return), summation)(r.lightConst); 1889 } 1890 1891 /// 1892 F sum(Range, F)(Range r, F seed) 1893 if (isIterable!Range && !isMutable!Range) 1894 { 1895 return .sum!(F, summation)(r.lightConst, seed); 1896 } 1897 1898 /// 1899 sumType!T sum(T)(scope const T[] ar...) 1900 { 1901 alias F = typeof(return); 1902 return .sum!(F, ResolveSummationType!(summation, F[], F))(ar); 1903 } 1904 } 1905 1906 ///ditto 1907 template sum(F, string summation) 1908 if (isMutable!F) 1909 { 1910 mixin("alias sum = .sum!(F, Summation." ~ summation ~ ");"); 1911 } 1912 1913 ///ditto 1914 template sum(string summation) 1915 { 1916 mixin("alias sum = .sum!(Summation." ~ summation ~ ");"); 1917 } 1918 1919 private static immutable jaggedMsg = "sum: each slice should have the same length"; 1920 version(D_Exceptions) 1921 static immutable jaggedException = new Exception(jaggedMsg); 1922 1923 /++ 1924 Sum slices with a naive algorithm. 
+/
template sumSlices()
{
    import mir.primitives: DeepElementType;
    import mir.ndslice.slice: Slice, SliceKind, isSlice;
    ///
    auto sumSlices(Iterator, SliceKind kind)(Slice!(Iterator, 1, kind) sliceOfSlices)
        if (isSlice!(DeepElementType!(Slice!(Iterator, 1, kind))))
    {
        import mir.ndslice.topology: as;
        import mir.ndslice.allocation: slice;
        // Element type of the inner slices, without qualifiers.
        alias T = Unqual!(DeepElementType!(DeepElementType!(Slice!(Iterator, 1, kind))));
        if (sliceOfSlices.length == 0)
            return typeof(slice(as!T(sliceOfSlices.front))).init;
        // Allocate the result as a copy of the first slice, then add the rest.
        auto ret = slice(as!T(sliceOfSlices.front));
        sliceOfSlices.popFront;
        foreach (sl; sliceOfSlices)
        {
            // Compare the full shape, not just the top-level length: a
            // mismatched inner dimension would otherwise reach `ret[] += sl[]`,
            // which is only guarded by an assertion.
            if (sl.shape != ret.shape)
            {
                version (D_Exceptions)
                { import mir.exception : toMutable; throw jaggedException.toMutable; }
                else
                    assert(0);
            }
            ret[] += sl[];
        }
        return ret;
    }
}

///
version(mir_test)
unittest
{
    import mir.ndslice.topology: map, byDim;
    import mir.ndslice.slice: sliced;

    auto ar = [[1, 2, 3], [10, 20, 30]];
    assert(ar.map!sliced.sumSlices == [11, 22, 33]);

    import mir.ndslice.fuse: fuse;
    auto a = [[[1.2], [2.1]], [[4.1], [5.2]]].fuse;
    auto s = a.byDim!0.sumSlices;
    assert(s == [[5.3], [7.300000000000001]]);
}

version(mir_test)
@safe pure nothrow unittest
{
    static assert(is(typeof(sum([cast( byte)1])) == int));
    static assert(is(typeof(sum([cast(ubyte)1])) == int));
    static assert(is(typeof(sum([ 1, 2, 3, 4])) == int));
    static assert(is(typeof(sum([ 1U, 2U, 3U, 4U])) == uint));
    static assert(is(typeof(sum([ 1L, 2L, 3L, 4L])) == long));
    static assert(is(typeof(sum([1UL, 2UL, 3UL, 4UL])) == ulong));

    int[] empty;
    assert(sum(empty) == 0);
    assert(sum([42]) == 42);
    assert(sum([42, 43]) == 42 + 43);
    assert(sum([42, 43, 44]) == 42 + 43 + 44);
    assert(sum([42, 43, 44, 45]) == 42 + 43 + 44 + 45);
}


version(mir_test)
@safe pure nothrow unittest
{
    static assert(is(typeof(sum([1.0, 2.0, 3.0, 4.0])) == double));
    static assert(is(typeof(sum!double([ 1F, 2F, 3F, 4F])) == double));
    const(float[]) a = [1F, 2F, 3F, 4F];
    static assert(is(typeof(sum!double(a)) == double));
    const(float)[] b = [1F, 2F, 3F, 4F];
    // Mutable slice of const elements (as opposed to `a`, a const slice).
    static assert(is(typeof(sum!double(b)) == double));

    double[] empty;
    assert(sum(empty) == 0);
    assert(sum([42.]) == 42);
    assert(sum([42., 43.]) == 42 + 43);
    assert(sum([42., 43., 44.]) == 42 + 43 + 44);
    assert(sum([42., 43., 44., 45.5]) == 42 + 43 + 44 + 45.5);
}

version(mir_test)
@safe pure nothrow unittest
{
    import mir.ndslice.topology: iota;
    assert(iota(2, 3).sum == 15);
}

version(mir_test)
@safe pure nothrow unittest
{
    import std.container;
    static assert(is(typeof(sum!double(SList!float()[])) == double));
    static assert(is(typeof(sum(SList!double()[])) == double));
    static assert(is(typeof(sum(SList!real()[])) == real));

    assert(sum(SList!double()[]) == 0);
    assert(sum(SList!double(1)[]) == 1);
    assert(sum(SList!double(1, 2)[]) == 1 + 2);
    assert(sum(SList!double(1, 2, 3)[]) == 1 + 2 + 3);
    assert(sum(SList!double(1, 2, 3, 4)[]) == 10);
}


version(mir_test)
pure nothrow unittest // 12434
{
    import mir.ndslice.slice: sliced;
    import mir.ndslice.topology: map;
    immutable a = [10, 20];
    auto s = a.sliced;
    auto s1 = sum(a); // Error
    auto s2 = s.map!(x => x).sum; // Error
}

version(mir_test)
unittest
{
    import std.bigint;
    import mir.ndslice.topology: repeat;

    auto a = BigInt("1_000_000_000_000_000_000").repeat(10);
    auto b = (ulong.max/2).repeat(10);
    auto sa = a.sum();
    auto sb = b.sum(BigInt(0)); //reduce ulongs into bigint
    assert(sa == BigInt("10_000_000_000_000_000_000"));
    assert(sb == (BigInt(ulong.max/2) * 10));
}

version(mir_test)
unittest
{
    with(Summation)
    foreach (F; AliasSeq!(float, double, real))
    {
        F[] ar = [1, 2, 3, 4];
        F r = 10;
        assert(r == ar.sum!fast());
        assert(r == ar.sum!pairwise());
        assert(r == ar.sum!kahan());
        assert(r == ar.sum!kbn());
        assert(r == ar.sum!kb2());
    }
}

version(mir_test)
unittest
{
    assert(sum(1) == 1);
    assert(sum(1, 2, 3) == 6);
    assert(sum(1.0, 2.0, 3.0) == 6);
}

version(mir_test)
unittest
{
    assert(sum!float(1) == 1f);
    assert(sum!float(1, 2, 3) == 6f);
    assert(sum!float(1.0, 2.0, 3.0) == 6f);
}

version(mir_test)
unittest
{
    import mir.complex: Complex;

    assert(sum(Complex!float(1.0, 1.0), Complex!float(2.0, 2.0), Complex!float(3.0, 3.0)) == Complex!float(6.0, 6.0));
    assert(sum!(Complex!float)(Complex!float(1.0, 1.0), Complex!float(2.0, 2.0), Complex!float(3.0, 3.0)) == Complex!float(6.0, 6.0));
}

version(LDC)
version(X86_Any)
version(mir_test)
unittest
{
    import core.simd;
    static if (__traits(compiles, double2.init + double2.init))
    {

        alias S = Summation;
        alias sums = AliasSeq!(S.kahan, S.pairwise, S.naive, S.fast);

        double2[] ar = [double2([1.0, 2]), double2([2, 3]), double2([3, 4]), double2([4, 6])];
        double2 c = double2([10, 15]);

        foreach (sumType; sums)
        {
            double2 s = ar.sum!(sumType);
            assert(s.array == c.array);
        }
    }
}

version(LDC)
version(X86_Any)
version(mir_test)
unittest
{
    import core.simd;
    import mir.ndslice.topology: iota, as;

    alias S = Summation;
    alias sums = AliasSeq!(S.kahan, S.pairwise, S.naive, S.fast, S.precise,
                           S.kbn, S.kb2);

    int[2] ns = [9, 101];

    foreach (n; ns)
    {
        foreach (sumType; sums)
        {
            auto ar = iota(n).as!double;
            double c = n * (n - 1) / 2; // gauss for n=100
            double s = ar.sum!(sumType);
            assert(s == c);
        }
    }
}

// Confirm sum works for Slice!(const(double)*, 1))
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    double[] x = [1.0, 2, 3];
    auto y = x.sliced;
    auto z = y.toConst;
    assert(z.sum == 6);
    assert(z.sum(0.0) == 6);
    assert(z.sum!double == 6);
    assert(z.sum!double(0.0) == 6);
}

// Confirm sum works for const(Slice!(double*, 1))
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    double[] x = [1.0, 2, 3];
    auto y = x.sliced;
    const z = y;
    assert(z.sum == 6);
    assert(z.sum(0.0) == 6);
    assert(z.sum!double == 6);
    assert(z.sum!double(0.0) == 6);
}

// Confirm sum works for const(Slice!(const(double)*, 1))
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    double[] x = [1.0, 2, 3];
    auto y = x.sliced;
    const z = y.toConst;
    assert(z.sum == 6);
    assert(z.sum(0.0) == 6);
    assert(z.sum!double == 6);
    assert(z.sum!double(0.0) == 6);
}

// Maps `Summation.appropriate` to a concrete algorithm: `pairwise` when
// `Range` can be summed into `F`, `naive` for aggregate/class types, and
// `fast` otherwise. Any other `summation` is returned unchanged.
package(mir)
template ResolveSummationType(Summation summation, Range, F)
{
    static if (summation == Summation.appropriate)
    {
        static if (isSummable!(Range, F))
            enum ResolveSummationType = Summation.pairwise;
        else
        static if (is(F == class) || is(F == struct) || is(F == interface))
            enum ResolveSummationType = Summation.naive;
        else
            enum ResolveSummationType = Summation.fast;
    }
    else
    {
        enum ResolveSummationType = summation;
    }
}

// Returns a zero of type `T`, trying `0.0`, then `0`, then an imaginary
// zero, and falling back to `T.init`.
private T summationInitValue(T)()
{
    static if (__traits(compiles, {T a = 0.0;}))
    {
        T a = 0.0;
        return a;
    }
    else
    static if (__traits(compiles, {T a = 0;}))
    {
        T a = 0;
        return a;
    }
    else
    static if (__traits(compiles, {T a = 0 + 0fi;}))
    {
        T a = 0 + 0fi;
        return a;
    }
    else
    {
        return T.init;
    }
}

// Unqualified element type of an iterable (deep element type for slices),
// or the unqualified type itself for non-iterables.
package(mir)
template elementType(T)
{
    import mir.ndslice.slice: isSlice, DeepElementType;
    import std.traits: Unqual, ForeachType;

    static if (isIterable!T) {
        static if (isSlice!T)
            alias elementType = Unqual!(DeepElementType!(T.This));
        else
            alias elementType = Unqual!(ForeachType!T);
    } else {
        alias elementType = Unqual!T;
    }
}

// Type produced by adding two elements of `Range`; rejects element types
// that do not support both `+` and `+=`.
package(mir)
template sumType(Range)
{
    alias T = elementType!Range;

    static if (__traits(compiles, {
        auto a = T.init + T.init;
        a += T.init;
    }))
        alias sumType = typeof(T.init + T.init);
    else
        static assert(0, "sumType: Can't sum elements of type " ~ T.stringof);
}

/++
+/
template fillCollapseSums(Summation summation, alias combineParts, combineElements...)
{
    import mir.ndslice.slice: Slice, SliceKind;
    /++
    +/
    auto ref fillCollapseSums(Iterator, SliceKind kind)(Slice!(Iterator, 1, kind) data) @property
    {
        import mir.algorithm.iteration;
        import mir.functional: naryFun;
        import mir.ndslice.topology: iota, triplets;
        foreach (triplet; data.length.iota.triplets) with(triplet)
        {
            auto ref ce(size_t i)()
            {
                static if (summation == Summation.fast)
                {
                    return
                        sum!summation(naryFun!(combineElements[i])(center, left )) +
                        sum!summation(naryFun!(combineElements[i])(center, right));
                }
                else
                {
                    // `Summator` takes the accumulated type as its first
                    // argument; derive it from what the combiner produces.
                    alias R = typeof(naryFun!(combineElements[i])(center, left));
                    alias E = sumType!R;
                    Summator!(E, ResolveSummationType!(summation, R, E)) summator = 0;
                    summator.put(naryFun!(combineElements[i])(center, left));
                    summator.put(naryFun!(combineElements[i])(center, right));
                    return summator.sum;
                }
            }
            alias sums = staticMap!(ce, Iota!(combineElements.length));
            data[center] = naryFun!combineParts(center, sums);
        }
    }
}

package:

// `true` if `F` can be initialized from floating-point literals and
// supports `+`, `-`, `+=` and `-=`.
template isSummable(F)
{
    enum bool isSummable =
        __traits(compiles,
        {
            F a = 0.1, b, c;
            b = 2.3;
            c = a + b;
            c = a - b;
            a += b;
            a -= b;
        });
}

// `true` if `Range` is iterable, its sum type converts to `F`, and `F` is summable.
template isSummable(Range, F)
{
    enum bool isSummable =
        isIterable!Range &&
        isImplicitlyConvertible!(sumType!Range, F) &&
        isSummable!F;
}

version(mir_test)
unittest
{
    import mir.ndslice.topology: iota;
    static assert(isSummable!(typeof(iota([size_t.init])), double));
}

private enum bool isCompesatorAlgorithm(Summation summation) =
       summation == Summation.precise
    || summation == Summation.kb2
    || summation == Summation.kbn
    || summation == Summation.kahan;


version(mir_test)
unittest
{
    import mir.ndslice;

    auto p = iota([2, 3, 4, 5]);
    auto a = p.as!double;
    auto b = a.flattened;
    auto c = a.slice;
    auto d = c.flattened;
    auto s = p.flattened.sum;

    assert(a.sum == s);
    assert(b.sum == s);
    assert(c.sum == s);
    assert(d.sum == s);

    assert(a.canonical.sum == s);
    assert(b.canonical.sum == s);
    assert(c.canonical.sum == s);
    assert(d.canonical.sum == s);

    assert(a.universal.transposed!3.sum == s);
    assert(b.universal.sum == s);
    assert(c.universal.transposed!3.sum == s);
    assert(d.universal.sum == s);

    assert(a.sum!"fast" == s);
    assert(b.sum!"fast" == s);
    assert(c.sum!(float, "fast") == s);
    assert(d.sum!"fast" == s);

    assert(a.canonical.sum!"fast" == s);
    assert(b.canonical.sum!"fast" == s);
    assert(c.canonical.sum!"fast" == s);
    assert(d.canonical.sum!"fast" == s);

    assert(a.universal.transposed!3.sum!"fast" == s);
    assert(b.universal.sum!"fast" == s);
    assert(c.universal.transposed!3.sum!"fast" == s);
    assert(d.universal.sum!"fast" == s);

}