/++
This is a submodule of $(MREF mir,ndslice).

Note:
    The combination of
    $(SUBREF topology, pairwise) with lambda `"a <= b"` (`"a < b"`) and
    $(REF_ALTTEXT $(TT all), all, mir, algorithm, iteration)$(NBSP) can be used
    to check if an ndslice is sorted (strictly monotonic), in place of
    `isSorted` (`isStrictlyMonotonic`).
    $(SUBREF topology, iota) can be used to make an index.
    $(SUBREF topology, map) and $(SUBREF topology, zip) can be used to implement a Schwartzian transform.
    See also the examples in the module.

See_also: $(SUBREF topology, flattened)

License: $(HTTP www.apache.org/licenses/LICENSE-2.0, Apache-2.0)
Copyright: 2020 Ilia Ki, Kaleidic Associates Advisory Limited, Symmetry Investments
Authors: Andrei Alexandrescu (Phobos), Ilia Ki (API, rework, Mir adaptation)

Macros:
    SUBREF = $(REF_ALTTEXT $(TT $2), $2, mir, ndslice, $1)$(NBSP)
+/
module mir.ndslice.sorting;

/// Check if ndslice is sorted, or strictly monotonic.
@safe pure version(mir_ndslice_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.slice: sliced;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: pairwise;

    auto arr = [1, 1, 2].sliced;

    assert(arr.pairwise!"a <= b".all);
    assert(!arr.pairwise!"a < b".all);

    arr = [4, 3, 2, 1].sliced;

    assert(!arr.pairwise!"a <= b".all);
    assert(!arr.pairwise!"a < b".all);

    sort(arr);

    assert(arr.pairwise!"a <= b".all);
    assert(arr.pairwise!"a < b".all);
}

/// Create index
version(mir_ndslice_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.allocation: slice;
    import mir.ndslice.slice: sliced;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: iota, pairwise;

    auto arr = [4, 2, 3, 1].sliced;

    auto index = arr.length.iota.slice;
    index.sort!((a, b) => arr[a] < arr[b]);

    assert(arr[index].pairwise!"a <= b".all);
}

/// Schwartzian transform
version(mir_ndslice_test)
unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.allocation: slice;
    import mir.ndslice.slice: sliced;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: zip, map, pairwise;

    alias transform = (a) => (a - 3) ^^ 2;

    auto arr = [4, 2, 3, 1].sliced;

    arr.map!transform.slice.zip(arr).sort!((l, r) => l.a < r.a);

    assert(arr.map!transform.pairwise!"a <= b".all);
}

import mir.ndslice.slice;
import mir.math.common: fmamath;

@fmamath:

@safe pure version(mir_ndslice_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.topology: pairwise;

    auto a = [1, 2, 3].sliced;
    assert(a[0 .. 0].pairwise!"a <= b".all);
    assert(a[0 .. 1].pairwise!"a <= b".all);
    assert(a.pairwise!"a <= b".all);
    auto b = [1, 3, 2].sliced;
    assert(!b.pairwise!"a <= b".all);

    // ignores duplicates
    auto c = [1, 1, 2].sliced;
    assert(c.pairwise!"a <= b".all);
}

@safe pure version(mir_ndslice_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.topology: pairwise;

    assert([1, 2, 3][0 .. 0].sliced.pairwise!"a < b".all);
    assert([1, 2, 3][0 .. 1].sliced.pairwise!"a < b".all);
    assert([1, 2, 3].sliced.pairwise!"a < b".all);
    assert(![1, 3, 2].sliced.pairwise!"a < b".all);
    assert(![1, 1, 2].sliced.pairwise!"a < b".all);
}


/++
Sorts ndslice, array, or series.

Sorting is performed in place and is not guaranteed to be stable:
the relative order of equivalent elements is unspecified.

Params:
    less = strict ordering predicate; either a callable or a string lambda such as `"a < b"`.

See_also: $(SUBREF topology, flattened).
+/
template sort(alias less = "a < b")
{
    import mir.functional: naryFun;
    import mir.series: Series;
    static if (__traits(isSame, naryFun!less, less))
    {
    @fmamath:
        /++
        Sort n-dimensional slice.

        All elements are sorted in the order of the slice's
        $(SUBREF topology, flattened) view.

        Params:
            slice = slice to sort in place
        Returns:
            `slice` itself, after sorting
        +/
        Slice!(Iterator, N, kind) sort(Iterator, size_t N, SliceKind kind)
            (Slice!(Iterator, N, kind) slice)
        {
            // `quickSortImpl` is `@trusted`; this never-executed block makes the
            // compiler infer this function's attributes (`@safe`, `nothrow`, ...)
            // from element copy/assignment and from `less` themselves.
            if (false) // break safety
            {
                auto elem = typeof(*slice._iterator).init;
                elem = elem;
                auto l = less(elem, elem);
            }
            import mir.ndslice.topology: flattened;
            if (slice.anyEmpty)
                return slice;
            .quickSortImpl!less(slice.flattened);
            return slice;
        }

        /++
        Sort for arrays

        Params:
            ar = array to sort in place
        Returns:
            `ar` itself, after sorting
        +/
        T[] sort(T)(T[] ar)
        {
            return .sort!less(ar.sliced).field;
        }

        /++
        Sort for one-dimensional Series.

        The series is sorted by its index; the data is permuted alongside the index.

        Params:
            series = series to sort in place
        Returns:
            `series` itself, after sorting
        +/
        Series!(IndexIterator, Iterator, N, kind)
            sort(IndexIterator, Iterator, size_t N, SliceKind kind)
            (Series!(IndexIterator, Iterator, N, kind) series)
            if (N == 1)
        {
            import mir.ndslice.sorting: sort;
            import mir.ndslice.topology: zip;
            with(series)
                index.zip(data).sort!((a, b) => less(a.a, b.a));
            return series;
        }

        /++
        Sort for n-dimensional Series.

        The series is sorted by its index; every one-dimensional slice of the data
        along its first dimension is then permuted accordingly.

        Params:
            series = series to sort in place
            indexBuffer = buffer of `series.length` elements; receives the sorting permutation
            dataBuffer = buffer of `series.length` elements; scratch space used to reorder the data
        Returns:
            `series` itself, after sorting
        +/
        Series!(IndexIterator, Iterator, N, kind)
            sort(
                IndexIterator,
                Iterator,
                size_t N,
                SliceKind kind,
                SortIndexIterator,
                DataIterator,
                )
            (
                Series!(IndexIterator, Iterator, N, kind) series,
                Slice!SortIndexIterator indexBuffer,
                Slice!DataIterator dataBuffer,
            )
        {
            import mir.algorithm.iteration: each;
            import mir.ndslice.sorting: sort;
            import mir.ndslice.topology: iota, zip, ipack, evertPack;

            assert(indexBuffer.length == series.length);
            assert(dataBuffer.length == series.length);
            // Sort the index together with the identity permutation, so that
            // afterwards `indexBuffer[i]` is the original position of row `i`.
            indexBuffer[] = indexBuffer.length.iota!(typeof(indexBuffer.front));
            series.index.zip(indexBuffer).sort!((a, b) => less(a.a, b.a));
            // Gather each data slice through the permutation, then copy it back.
            series.data.ipack!1.evertPack.each!((sl)
            {
                assert(sl.shape == dataBuffer.shape);
                dataBuffer[] = sl[indexBuffer];
                sl[] = dataBuffer;
            });
            return series;
        }
    }
    else
        alias sort = .sort!(naryFun!less);
}

///
@safe pure version(mir_ndslice_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.slice;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: pairwise;

    int[10] arr = [7,1,3,2,9,0,5,4,8,6];

    auto data = arr[].sliced(arr.length);
    data.sort();
    assert(data.pairwise!"a <= b".all);
}

/// one-dimensional series
pure version(mir_ndslice_test) unittest
{
    import mir.series;

    auto index = [4, 2, 1, 3, 0].sliced;
    auto data = [5.6, 3.4, 2.1, 7.8, 0.1].sliced;
    auto series = index.series(data);
    series.sort;
    assert(series.index == [0, 1, 2, 3, 4]);
    assert(series.data == [0.1, 2.1, 3.4, 7.8, 5.6]);
    // initial index and data are the same
    assert(index.iterator is series.index.iterator);
    assert(data.iterator is series.data.iterator);

    foreach(obs; series)
    {
        static assert(is(typeof(obs) == Observation!(int, double)));
    }
}

/// two-dimensional series
pure
version(mir_ndslice_test) unittest
{
    import mir.series;
    import mir.ndslice.allocation: uninitSlice;

    auto index = [4, 2, 3, 1].sliced;
    auto data =
        [2.1, 3.4,
         5.6, 7.8,
         3.9, 9.0,
         4.0, 2.0].sliced(4, 2);
    auto series = index.series(data);

    series.sort(
        uninitSlice!size_t(series.length), // index buffer
        uninitSlice!double(series.length), // data buffer
        );

    assert(series.index == [1, 2, 3, 4]);
    assert(series.data ==
        [[4.0, 2.0],
         [5.6, 7.8],
         [3.9, 9.0],
         [2.1, 3.4]]);
    /// initial index and data are the same
    assert(index.iterator is series.index.iterator);
    assert(data.iterator is series.data.iterator);
}

// Sorts the one-dimensional `slice` in place according to `less` (not stable).
//
// Quicksort that recurses into the shorter of the two partitions and keeps
// looping on the longer one, so the recursion depth stays logarithmic in
// `slice.length`. Slices of at most `naive` elements (and every slice during
// CTFE) are finished with insertion sort instead.
//
// NOTE(review): `max_depth` is declared but never used; there is no
// depth-limited fallback, so adversarial inputs may take quadratic time.
void quickSortImpl(alias less, Iterator)(Slice!Iterator slice) @trusted
{
    import mir.utility : swap, swapStars;

    enum max_depth = 64;
    // Insertion-sort threshold: the number of elements fitting in 1024 bytes,
    // but never fewer than 32.
    enum naive_est = 1024 / slice.Element!0.sizeof;
    enum size_t naive = 32 > naive_est ? 32 : naive_est;
    //enum size_t naive = 1;
    static assert(naive >= 1);

    for(;;)
    {
        // [l, r) is the current (sub)slice being sorted.
        auto l = slice._iterator;
        auto r = l;
        r += slice.length;

        if(slice.length <= 1)
            return;

        static if (naive > 1)
        {
            if (slice.length <= naive || __ctfe)
            {
                // Insertion sort: `p` walks from the second-to-last element down
                // to the first; each `*p` is inserted into the already sorted
                // suffix [p + 1, r) by shifting smaller elements one slot left.
                auto p = r;
                --p;
                while(p != l)
                {
                    --p;
                    //static if (is(typeof(() nothrow
                    //    {
                    //        auto t = slice[0]; if (less(t, slice[0])) slice[0] = slice[0];
                    //    })))
                    //{
                    auto d = p;
                    import mir.functional: unref;
                    // The slot `*d` is overwritten while shifting, so keep a copy
                    // (`unref` presumably strips reference wrappers such as those
                    // produced by zipped iterators — confirm in mir.functional).
                    auto temp = unref(*d);
                    auto c = d;
                    ++c;
                    if (less(*c, temp))
                    {
                        do
                        {
                            d[0] = *c;
                            ++d;
                            ++c;
                        }
                        while (c != r && less(*c, temp));
                        d[0] = temp;
                    }
                    //}
                    //else
                    //{
                    //    auto d = p;
                    //    auto c = d;
                    //    ++c;
                    //    while (less(*c, *d))
                    //    {
                    //        swap(*d, *c);
                    //        d = c;
                    //        ++c;
                    //        if (c == maxJ) break;
                    //    }
                    //}
                }
                return;
            }
        }

        // partition
        auto lessI = l;
        --r; // `r` now points at the last element
        auto pivotIdx = l + slice.length / 2;
        // Median-of-3 / median-of-5 pivot selection; the chosen value ends up at `pivotIdx`.
        setPivot!less(slice.length, l, pivotIdx, r);
        import mir.functional: unref;
        auto pivot = unref(*pivotIdx);
        // Start one before `l` because the scan below pre-increments.
        --lessI;
        auto greaterI = r;
        // Park the pivot in the last slot; it also stops the `lessI` scan.
        swapStars(pivotIdx, greaterI);

        outer: for (;;)
        {
            do ++lessI;
            while (less(*lessI, pivot));
            assert(lessI <= greaterI, "sort: invalid comparison function.");
            for (;;)
            {
                if (greaterI == lessI)
                    break outer;
                --greaterI;
                if (!less(pivot, *greaterI))
                    break;
            }
            assert(lessI <= greaterI, "sort: invalid comparison function.");
            if (lessI == greaterI)
                break;
            swapStars(lessI, greaterI);
        }

        // Move the pivot from the last slot into its final position.
        swapStars(r, lessI);

        // Recurse into the shorter side, keep looping on the longer one.
        ptrdiff_t len = lessI - l;
        auto tail = slice[len + 1 .. $];
        slice = slice[0 .. len];
        if (tail.length > slice.length)
            swap(slice, tail);
        quickSortImpl!less(tail);
    }
}

// Pivot selection for `quickSortImpl`: reorders elements so that `*mid` holds
// the median of `l`, `mid`, `r` (32 <= length < 512) or of the five positions
// `l`, `mid - length/4`, `mid`, `mid + length/4`, `r` (length >= 512).
// Shorter slices are left untouched.
void setPivot(alias less, Iterator)(size_t length, ref Iterator l, ref Iterator mid, ref Iterator r) @trusted
{
    if (length < 512)
    {
        if (length >= 32)
            medianOf!less(l, mid, r);
        return;
    }
    auto quarter = length >> 2;
    auto b = mid - quarter;
    auto e = mid + quarter;
    // The five-element overload puts the median into its third argument (`mid`).
    medianOf!less(l, e, mid, b, r);
}

// Orders two elements so that `!less(*b, *a)`. `leanRight` is unused here.
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b) @trusted
{
    import mir.utility : swapStars;

    if (less(*b, *a)) {
        swapStars(a, b);
    }
    assert(!less(*b, *a));
}

// Sorts three elements in place (`*a <= *b <= *c` under `less`), so `*b` is
// their median. `leanRight` is unused here.
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b, ref Iterator c) @trusted
{
    import mir.utility : swapStars;

    if (less(*c, *a)) // c < a
    {
        if (less(*a, *b)) // c < a < b
        {
            swapStars(a, b);
            swapStars(a, c);
        }
        else // c < a, b <= a
        {
            swapStars(a, c);
            if (less(*b, *a)) swapStars(a, b);
        }
    }
    else // a <= c
    {
        if (less(*b, *a)) // b < a <= c
        {
            swapStars(a, b);
        }
        else // a <= c, a <= b
        {
            if (less(*c, *b)) swapStars(b, c);
        }
    }
    assert(!less(*b, *a));
    assert(!less(*c, *b));
}

// Median of four elements.
// leanRight == false: `d` is made >= `b` and `c`, then `a, b, c` are sorted,
//   leaving the lower median of the four in `*b`.
// leanRight == true: `a` is made <= `b` and `c`, then `b, c, d` are sorted,
//   leaving the upper median of the four in `*c`.
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b, ref Iterator c, ref Iterator d) @trusted
{
    import mir.utility: swapStars;

    static if (!leanRight)
    {
        // Eliminate the rightmost from the competition
        if (less(*d, *c)) swapStars(c, d); // c <= d
        if (less(*d, *b)) swapStars(b, d); // b <= d
        medianOf!less(a, b, c);
    }
    else
    {
        // Eliminate the leftmost from the competition
        if (less(*b, *a)) swapStars(a, b); // a <= b
        if (less(*c, *a)) swapStars(a, c); // a <= c
        medianOf!less(b, c, d);
    }
}

// Median of five elements using a fixed comparison/swap network: afterwards
// `*c` is not less than `*a` or `*b` and not greater than `*d` or `*e`
// (checked by the `scope(success)` block in unittest builds).
// `leanRight` is unused here.
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b, ref Iterator c, ref Iterator d, ref Iterator e) @trusted
{
    import mir.utility: swapStars; // Credit: Teppo Niinimäki

    version(unittest) scope(success)
    {
        assert(!less(*c, *a));
        assert(!less(*c, *b));
        assert(!less(*d, *c));
        assert(!less(*e, *c));
    }

    if (less(*c, *a)) swapStars(a, c);
    if (less(*d, *b)) swapStars(b, d);
    if (less(*d, *c))
    {
        swapStars(c, d);
        swapStars(a, b);
    }
    if (less(*e, *b)) swapStars(b, e);
    if (less(*e, *c))
    {
        swapStars(c, e);
        if (less(*c, *a)) swapStars(a, c);
    }
    else
    {
        if (less(*c, *b)) swapStars(b, c);
    }
}


/++
Returns: `true` if a sorted array contains the value.

Params:
    test = strict ordering symmetric predicate

For non-symmetric predicates please use a structure with two `opCall`s or an alias of two global functions,
that corresponds to `(array[i], value)` and `(value, array[i])` cases.

See_also: $(LREF transitionIndex).
+/
template assumeSortedContains(alias test = "a < b")
{
    import mir.functional: naryFun;
    static if (__traits(isSame, naryFun!test, test))
    {
    @fmamath:
        /++
        Params:
            slice = sorted one-dimensional slice or array.
            v = value to search for. `test` is called both as `test(slice[i], v)`
                (binary search) and as `test(v, slice[i])` (final equivalence check).
        +/
        bool assumeSortedContains(Iterator, SliceKind kind, V)
            (auto ref Slice!(Iterator, 1, kind) slice, auto ref scope const V v)
        {
            // First element that is not ordered before `v` ...
            auto ti = transitionIndex!test(slice, v);
            // ... is a match iff `v` is not ordered before it either.
            return ti < slice.length && !test(v, slice[ti]);
        }

        /// ditto
        bool assumeSortedContains(T, V)(scope T[] ar, auto ref scope const V v)
        {
            return .assumeSortedContains!test(ar.sliced, v);
        }
    }
    else
        alias assumeSortedContains = .assumeSortedContains!(naryFun!test);
}

///
version(mir_ndslice_test)
@safe pure unittest
{
    // sorted: a < b
    auto a = [0, 1, 2, 3, 4, 6];

    assert(a.assumeSortedContains(3));
    assert(!a.assumeSortedContains(5));
}

/++
Returns: the smallest index `i` such that `array[i]` is equivalent to the value
    under the predicate (neither `test(array[i], value)` nor `test(value, array[i])` holds),
    or `-1` if the array doesn't contain such an element.

Params:
    test = strict ordering symmetric predicate.

For non-symmetric predicates please use a structure with two `opCall`s or an alias of two global functions,
that corresponds to `(array[i], value)` and `(value, array[i])` cases.

See_also: $(LREF transitionIndex).
+/
template assumeSortedEqualIndex(alias test = "a < b")
{
    import mir.functional: naryFun;
    static if (__traits(isSame, naryFun!test, test))
    {
    @fmamath:
        /++
        Params:
            slice = sorted one-dimensional slice or array.
            v = value to search for. `test` is called both as `test(slice[i], v)`
                (binary search) and as `test(v, slice[i])` (final equivalence check).
        +/
        sizediff_t assumeSortedEqualIndex(Iterator, SliceKind kind, V)
            (auto ref Slice!(Iterator, 1, kind) slice, auto ref scope const V v)
        {
            auto ti = transitionIndex!test(slice, v);
            return ti < slice.length && !test(v, slice[ti]) ? ti : -1;
        }

        /// ditto
        sizediff_t assumeSortedEqualIndex(T, V)(scope T[] ar, auto ref scope const V v)
        {
            return .assumeSortedEqualIndex!test(ar.sliced, v);
        }
    }
    else
        alias assumeSortedEqualIndex = .assumeSortedEqualIndex!(naryFun!test);
}

///
version(mir_ndslice_test)
@safe pure unittest
{
    // sorted: a < b
    auto a = [0, 1, 2, 3, 4, 6];

    assert(a.assumeSortedEqualIndex(2) == 2);
    assert(a.assumeSortedEqualIndex(5) == -1);

    // <= non strict predicates don't work
    assert(a.assumeSortedEqualIndex!"a <= b"(2) == -1);
}

/++
Computes transition index using binary search.
It is low-level API for lower and upper bounds of a sorted array.

The slice must be partitioned by `test`: every element for which `test(slice[i], value)`
holds must come before every element for which it does not.

Params:
    test = ordering predicate for (`(array[i], value)`) pairs.

Returns:
    the first index `i` for which `test(slice[i], value)` is `false`,
    or `slice.length` if it holds for every element.

See_also: $(LREF assumeSortedEqualIndex).
+/
template transitionIndex(alias test = "a < b")
{
    import mir.functional: naryFun;
    static if (__traits(isSame, naryFun!test, test))
    {
    @fmamath:
        /++
        Params:
            slice = sorted one-dimensional slice or array.
            v = value to test with. It is passed to second argument.
        +/
        size_t transitionIndex(Iterator, SliceKind kind, V)
            (auto ref Slice!(Iterator, 1, kind) slice, auto ref scope const V v)
        {
            // Invariant: the answer lies in [first, first + count].
            size_t first = 0, count = slice.length;
            while (count > 0)
            {
                immutable step = count / 2, it = first + step;
                if (test(slice[it], v))
                {
                    first = it + 1;
                    count -= step + 1;
                }
                else
                {
                    count = step;
                }
            }
            return first;
        }

        /// ditto
        size_t transitionIndex(T, V)(scope T[] ar, auto ref scope const V v)
        {
            return .transitionIndex!test(ar.sliced, v);
        }

    }
    else
        alias transitionIndex = .transitionIndex!(naryFun!test);
}

///
version(mir_ndslice_test)
@safe pure unittest
{
    // sorted: a < b
    auto a = [0, 1, 2, 3, 4, 6];

    auto i = a.transitionIndex(2);
    assert(i == 2);
    auto lowerBound = a[0 .. i];

    auto j = a.transitionIndex!"a <= b"(2);
    assert(j == 3);
    auto upperBound = a[j .. $];

    assert(a.transitionIndex(a[$ - 1]) == a.length - 1);
    assert(a.transitionIndex!"a <= b"(a[$ - 1]) == a.length);
}

/++
Computes an index for `r` based on the comparison `less`. The
index is a sorted array of indices into the original
range.

This technique is similar to sorting, but it is more flexible
because (1) it allows "sorting" of immutable collections, (2) allows
binary search even if the original collection does not offer random
access, (3) allows multiple indices, each on a different predicate,
and (4) may be faster when dealing with large objects. However, using
an index may also be slower under certain circumstances due to the
extra indirection, and is always larger than a sorting-based solution
because it needs space for the index in addition to the original
collection. The complexity is the same as `sort`'s.

Can be combined with $(SUBREF topology, indexed) to create a view that is sorted
based on the index.

Params:
    less = The comparison to use.
    r = The slice/array to index.

Returns:
    A new slice (an array for the array overload) of `I` indices `idx` such that
    `r[idx[0]], r[idx[1]], ...` is ordered by `less`. The underlying `sort` is not
    guaranteed to be stable, so the relative order of indices of equivalent
    elements is unspecified.

See_Also:
    $(HTTPS numpy.org/doc/stable/reference/generated/numpy.argsort.html, numpy.argsort)
+/
Slice!(I*) makeIndex(I = size_t, alias less = "a < b", Iterator, SliceKind kind)(Slice!(Iterator, 1, kind) r)
{
    import mir.functional: naryFun;
    import mir.ndslice.allocation: slice;
    import mir.ndslice.topology: iota;
    // Allocate the identity permutation and sort it by the referenced values.
    return r
        .length
        .iota!I
        .slice
        .sort!((a, b) => naryFun!less(r[a], r[b]));
}

/// ditto
I[] makeIndex(I = size_t, alias less = "a < b", T)(scope T[] r)
{
    return .makeIndex!(I, less)(r.sliced).field;
}

///
version(mir_ndslice_test)
@safe pure nothrow
unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.topology: indexed, pairwise;

    immutable arr = [ 2, 3, 1, 5, 0 ];
    auto index = arr.makeIndex;

    assert(arr.indexed(index).pairwise!"a < b".all);
}

/// Sort based on index created from a separate array
version(mir_ndslice_test)
@safe pure nothrow
unittest
{
    import mir.algorithm.iteration: equal;
    import mir.ndslice.topology: indexed;

    immutable arr0 = [ 2, 3, 1, 5, 0 ];
    immutable arr1 = [ 1, 5, 4, 2, -1 ];

    auto index = makeIndex(arr0);
    assert(index.equal([4, 2, 0, 1, 3]));
    auto view = arr1.indexed(index);
    assert(view.equal([-1, 4, 1, 5, 2]));
}

/++
Partitions `slice` around `pivot` using comparison function `less`, algorithm
akin to $(LINK2 https://en.wikipedia.org/wiki/Quicksort#Hoare_partition_scheme,
Hoare partition). Specifically, permutes elements of `slice` (viewed as
$(SUBREF topology, flattened) when it is multidimensional) and returns
an index `k < slice.elementCount` such that:

$(UL

$(LI `slice[pivot]` is swapped to `slice[k]`)


$(LI All elements `e` in subrange `slice[0 .. k]` satisfy `!less(slice[k], e)`
(i.e.
`slice[k]` is greater than or equal to each element to its left according
to predicate `less`))

$(LI All elements `e` in subrange `slice[k .. $]` satisfy
`!less(e, slice[k])` (i.e. `slice[k]` is less than or equal to each element to
its right according to predicate `less`)))

If `slice` contains equivalent elements, multiple permutations of `slice` may
satisfy these constraints. In such cases, `pivotPartition` attempts to
distribute equivalent elements fairly to the left and right of `k` such that `k`
stays close to `slice.elementCount / 2`.

Params:
    less = The predicate used for comparison

Returns:
    The new position of the pivot

See_Also:
    $(HTTP jgrcs.info/index.php/jgrcs/article/view/142, Engineering of a Quicksort
    Partitioning Algorithm), D. Abhyankar, Journal of Global Research in Computer
    Science, February 2011. $(HTTPS youtube.com/watch?v=AxnotgLql0k, ACCU 2016
    Keynote), Andrei Alexandrescu.
+/
@trusted
template pivotPartition(alias less = "a < b")
{
    import mir.functional: naryFun;

    static if (__traits(isSame, naryFun!less, less))
    {
        /++
        Params:
            slice = slice being partitioned
            pivot = The index of the pivot for partitioning (into the flattened view),
            must be less than `slice.elementCount` or `0` if `slice.elementCount` is `0`
        +/
        size_t pivotPartition(Iterator, size_t N, SliceKind kind)
            (Slice!(Iterator, N, kind) slice,
            size_t pivot)
        {
            assert(pivot < slice.elementCount || slice.elementCount == 0 && pivot == 0, "pivotPartition: pivot must be less than the elementCount of the slice or the slice must be empty and pivot zero");

            if (slice.elementCount <= 1) return 0;

            import mir.ndslice.topology: flattened;

            auto flattenedSlice = slice.flattened;
            auto frontI = flattenedSlice._iterator;
            auto lastI = frontI + flattenedSlice.length - 1;
            auto pivotI = frontI + pivot;
            pivotPartitionImpl!less(frontI, lastI, pivotI);
            return pivotI - frontI;
        }
    } else {
        alias pivotPartition = .pivotPartition!(naryFun!less);
    }
}

/// pivotPartition with 1-dimensional Slice
version(mir_ndslice_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: all;

    auto x = [5, 3, 2, 6, 4, 1, 3, 7].sliced;
    size_t pivot = pivotPartition(x, x.length / 2);

    assert(x[0 .. pivot].all!(a => a <= x[pivot]));
    assert(x[pivot .. $].all!(a => a >= x[pivot]));
}

/// pivotPartition with 2-dimensional Slice
version(mir_ndslice_test)
@safe pure
unittest
{
    import mir.ndslice.fuse: fuse;
    import mir.ndslice.topology: flattened;
    import mir.algorithm.iteration: all;

    auto x = [
        [5, 3, 2, 6],
        [4, 1, 3, 7]
    ].fuse;

    size_t pivot = pivotPartition(x, x.elementCount / 2);

    auto xFlattened = x.flattened;
    assert(xFlattened[0 .. pivot].all!(a => a <= xFlattened[pivot]));
    assert(xFlattened[pivot .. $].all!(a => a >= xFlattened[pivot]));
}

version(mir_ndslice_test)
@safe
unittest
{
    void test(alias less)()
    {
        import mir.ndslice.slice: sliced;
        import mir.algorithm.iteration: all, equal;

        Slice!(int*) x;
        size_t pivot;

        x = [-9, -4, -2, -2, 9].sliced;
        pivot = pivotPartition!less(x, x.length / 2);

        assert(x[0 .. pivot].all!(a => a <= x[pivot]));
        assert(x[pivot .. $].all!(a => a >= x[pivot]));

        x = [9, 2, 8, -5, 5, 4, -8, -4, 9].sliced;
        pivot = pivotPartition!less(x, x.length / 2);

        assert(x[0 .. pivot].all!(a => a <= x[pivot]));
        assert(x[pivot .. $].all!(a => a >= x[pivot]));

        x = [ 42 ].sliced;
        pivot = pivotPartition!less(x, x.length / 2);

        assert(pivot == 0);
        assert(x.equal([ 42 ]));

        x = [ 43, 42 ].sliced;
        pivot = pivotPartition!less(x, 0);
        assert(pivot == 1);
        assert(x.equal([ 42, 43 ]));

        x = [ 43, 42 ].sliced;
        pivot = pivotPartition!less(x, 1);

        assert(pivot == 0);
        assert(x.equal([ 42, 43 ]));

        x = [ 42, 42 ].sliced;
        pivot = pivotPartition!less(x, 0);

        assert(pivot == 0 || pivot == 1);
        assert(x.equal([ 42, 42 ]));

        pivot = pivotPartition!less(x, 1);

        assert(pivot == 0 || pivot == 1);
        assert(x.equal([ 42, 42 ]));
    }
    test!"a < b";
    static bool myLess(int a, int b)
    {
        static bool bogus;
        if (bogus) throw new Exception(""); // just to make it no-nothrow
        return a < b;
    }
    test!myLess;
}

// Hoare-style partition of the inclusive range [frontI, lastI] around the
// element at `pivotI`. On return `pivotI` points at the pivot's final position:
// no element before it compares greater and no element after it compares less
// (according to `less`). `frontI` and `lastI` themselves are not reassigned.
@trusted
template pivotPartitionImpl(alias less)
{
    void pivotPartitionImpl(Iterator)
        (ref Iterator frontI,
        ref Iterator lastI,
        ref Iterator pivotI)
    {
        assert(pivotI <= lastI && pivotI >= frontI, "pivotPartition: pivot must be less than the length of slice or slice must be empty and pivot zero");

        if (frontI == lastI) return;

        import mir.utility: swapStars;

        // Pivot at the front
        swapStars(pivotI, frontI);

        // Fork implementation depending on nothrow copy, assignment, and
        // comparison. If all of these are nothrow, use the specialized
        // implementation discussed at
        // https://youtube.com/watch?v=AxnotgLql0k.
        // The fast path below moves *elements* through a temporary ("vacancy"),
        // so the element copy/assignment itself must be nothrow, not just the
        // iterator copy: a throw mid-loop would leave duplicated/lost elements.
        static if (is(typeof(
            () nothrow
            {
                auto it = frontI; it = frontI;
                auto e = *it; e = *it; *it = e;
                return less(*it, *it);
            }
        )))
        {
            // Plant the pivot in the end as well as a sentinel
            auto loI = frontI;
            auto hiI = lastI;
            auto save = *hiI;
            *hiI = *frontI; // Vacancy is in r[$ - 1] now

            // Start process
            for (;;)
            {
                // Loop invariant
                version(mir_ndslice_test)
                {
                    // this used to import std.algorithm.all, but we want to
                    // save imports when unittests are enabled if possible.
                    size_t len = lastI - frontI + 1;
                    foreach (x; 0 .. (loI - frontI))
                        assert(!less(*frontI, frontI[x]), "pivotPartition: *frontI must not be less than frontI[x]");
                    foreach (x; (hiI - frontI + 1) .. len)
                        assert(!less(frontI[x], *frontI), "pivotPartition: frontI[x] must not be less than *frontI");
                }
                do ++loI; while (less(*loI, *frontI));
                *(hiI) = *(loI);
                // Vacancy is now in slice[lo]
                do --hiI; while (less(*frontI, *hiI));
                if (loI >= hiI) break;
                *(loI) = *(hiI);
                // Vacancy is now in slice[hi]
            }
            // Fixup
            assert(loI - hiI <= 2, "pivotPartition: Following compare not possible");
            assert(!less(*frontI, *hiI), "pivotPartition: *hiI must not be less than *frontI");
            if (loI - hiI == 2)
            {
                assert(!less(hiI[1], *frontI), "pivotPartition: *(hiI + 1) must not be less than *frontI");
                *(loI) = hiI[1];
                --loI;
            }
            *loI = save;
            if (less(*frontI, save)) --loI;
            assert(!less(*frontI, *loI), "pivotPartition: *frontI must not be less than *loI");
        } else {
            // Swap-based variant: only uses `swapStars`, so no element is ever
            // duplicated if an operation throws.
            auto loI = frontI;
            ++loI;
            auto hiI = lastI;

            loop: for (;; loI++, hiI--)
            {
                for (;; ++loI)
                {
                    if (loI > hiI) break loop;
                    if (!less(*loI, *frontI)) break;
                }
                // found the left bound: !less(*loI, *frontI)
                assert(loI <= hiI, "pivotPartition: loI must be less or equal than hiI");
                for (;; --hiI)
                {
                    if (loI >= hiI) break loop;
                    if (!less(*frontI, *hiI)) break;
                }
                // found the right bound: !less(*frontI, *hiI), swap & make progress
                assert(!less(*loI, *hiI), "pivotPartition: *loI must not be less than *hiI");
                swapStars(loI, hiI);
            }
            --loI;
        }

        // Move the pivot from the front into its final position.
        swapStars(loI, frontI);
        pivotI = loI;
    }
}

version(mir_ndslice_test)
@safe pure nothrow
unittest {
    import mir.ndslice.sorting: partitionAt;
    import mir.ndslice.allocation: rcslice;
    auto x = rcslice!double(4);
    x[0] = 3;
    x[1] = 2;
    x[2] = 1;
    x[3] = 0;
    partitionAt!("a > b")(x, 2);
    // descending order: 3, 2, 1, 0
    assert(x[2] == 1);
}


version(mir_ndslice_test)
@trusted pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: all;

    auto x = [5, 3, 2, 6, 4, 1, 3, 7].sliced;
    auto frontI = x._iterator;
    auto lastI = x._iterator + x.length - 1;
    auto pivotI = frontI + x.length / 2;
    alias less = (a, b) => (a < b);
    pivotPartitionImpl!less(frontI, lastI, pivotI);
    size_t pivot = pivotI - frontI;

    assert(x[0 .. pivot].all!(a => a <= x[pivot]));
    assert(x[pivot .. $].all!(a => a >= x[pivot]));
}

/++
Partitions `slice`, such that all elements `e1` in `slice[0 .. nth]`
satisfy `!less(slice[nth], e1)`, and all elements `e2` in `slice[nth .. $]`
satisfy `!less(e2, slice[nth])`. This effectively reorders
`slice` such that `slice[nth]` refers to the element that would fall there if
the range were fully sorted. Performs an expected $(BIGOH slice.length)
evaluations of `less` and `swap`, with a worst case of $(BIGOH slice.length^^2).

This function implements the [Fast, Deterministic Selection](https://erdani.com/research/sea2017.pdf)
algorithm that is implemented in the [`topN`](https://dlang.org/library/std/algorithm/sorting/top_n.html)
function in D's standard library (as of version `2.092.0`).

Params:
    less = The predicate to sort by.

Note:
    For multidimensional slices, `nth` indexes the $(SUBREF topology, flattened) view.
    `less` must be callable from `pure nothrow @nogc` code, because the
    implementation is annotated with those attributes.

See_Also:
    $(LREF pivotPartition), https://erdani.com/research/sea2017.pdf

+/
template partitionAt(alias less = "a < b")
{
    import mir.functional: naryFun;

    static if (__traits(isSame, naryFun!less, less))
    {
        /++
        Params:
            slice = n-dimensional slice; must not be empty
            nth = The index of the element that should be in sorted position after the
                function is finished. Must be less than `slice.elementCount`.
        +/
        void partitionAt(Iterator, size_t N, SliceKind kind)
            (Slice!(Iterator, N, kind) slice, size_t nth) @trusted nothrow @nogc
        {
            import mir.qualifier: lightScope;
            import core.lifetime: move;
            import mir.ndslice.topology: flattened;

            assert(slice.elementCount > 0, "partitionAt: slice must have elementCount greater than 0");
            // `nth` is unsigned, so only the upper bound needs checking.
            assert(nth < slice.elementCount, "partitionAt: nth must be less than the elementCount of the slice");

            bool useSampling = true;
            auto flattenedSlice = slice.move.flattened;
            auto frontI = flattenedSlice._iterator.lightScope;
            // Inclusive upper bound of the range handed to the selection loop.
            auto lastI = frontI + (flattenedSlice.length - 1);
            partitionAtImpl!less(frontI, lastI, nth, useSampling);
        }
    }
    else
        alias partitionAt = .partitionAt!(naryFun!less);
}

/// Partition 1-dimensional slice at nth
version(mir_ndslice_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 2;
    auto x = [3, 1, 5, 2, 0].sliced;
    x.partitionAt(nth);
    assert(x[nth] == 2);
}

/// Partition 2-dimensional slice
version(mir_ndslice_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 4;
    auto x = [3, 1, 5, 2, 0, 7].sliced(3, 2);
    x.partitionAt(nth);
    assert(x[2, 0] == 5);
}

/// Can supply alternate ordering function
version(mir_ndslice_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 2;
    auto x = [3, 1, 5, 2, 0].sliced;
    // With "a > b" the sorted order is descending: 5, 3, 2, 1, 0.
    x.partitionAt!("a > b")(nth);
    assert(x[nth] == 2);
}

// Check issue #328 fixed
version(mir_ndslice_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    auto slice = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17].sliced;
    // The elements are exactly 0 .. 17, so the element selected at index i must equal i.
    partitionAt(slice, 8);
    assert(slice[8] == 8);
    partitionAt(slice, 9);
    assert(slice[9] == 9);
}

version(unittest) {
    /++
    Test helper: returns `true` when, for every index `nth` of `x`, applying
    `partitionAt!less` to a fresh copy of `x` leaves at `nth` the same element
    that `sort!less` places there.
    +/
    template checkPartitionAtAll(alias less = "a < b")
    {
        import mir.functional: naryFun;
        import mir.ndslice.slice: SliceKind, Slice;

        static if (__traits(isSame, naryFun!less, less))
        {
            @safe pure nothrow
            static bool checkPartitionAtAll
                (Iterator, SliceKind kind)(
                    Slice!(Iterator, 1, kind) x)
            {
                // Reference answer: a fully sorted copy.
                auto x_sorted = x.dup;
                x_sorted.sort!less;

                bool result = true;

                foreach (nth; 0 .. x.length)
                {
                    // Each index is checked on an unmodified copy of the input.
                    auto x_i = x.dup;
                    x_i.partitionAt!less(nth);
                    if (x_i[nth] != x_sorted[nth]) {
                        result = false;
                        break;
                    }
                }
                return result;
            }
        } else {
            alias checkPartitionAtAll = .checkPartitionAtAll!(naryFun!less);
        }
    }
}

version(mir_ndslice_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    assert(checkPartitionAtAll([2, 2].sliced));

    assert(checkPartitionAtAll([3, 1, 5, 2, 0].sliced));
    assert(checkPartitionAtAll([3, 1, 5, 0, 2].sliced));
    assert(checkPartitionAtAll([0, 0, 4, 3, 3].sliced));
    assert(checkPartitionAtAll([5, 1, 5, 1, 5].sliced));
    assert(checkPartitionAtAll([2, 2, 0, 0, 0].sliced));

    assert(checkPartitionAtAll([ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced));
    assert(checkPartitionAtAll([ 4, 18, 16, 0, 15, 6, 2, 17, 10, 16].sliced));
    assert(checkPartitionAtAll([ 7, 5, 9, 4, 4, 2, 12, 20, 15, 15].sliced));

    assert(checkPartitionAtAll([17, 87, 58, 50, 34, 98, 25, 77, 88, 79].sliced));

    assert(checkPartitionAtAll([ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced));
    assert(checkPartitionAtAll([21, 3, 11, 22, 24, 12, 14, 12, 15, 15, 1, 3, 12, 15, 25, 19, 9, 16, 16, 19].sliced));
    assert(checkPartitionAtAll([22, 6, 18, 0, 1, 8, 13, 13, 16, 19, 23, 17, 4, 6, 12, 24, 15, 20, 11, 17].sliced));
    assert(checkPartitionAtAll([19, 23, 14, 5, 12, 3, 13, 7, 25, 25, 24, 9, 21, 25, 12, 22, 15, 22, 7, 11].sliced));
    assert(checkPartitionAtAll([ 0, 2, 7, 16, 2, 20, 1, 11, 17, 5, 22, 17, 25, 13, 14, 5, 22, 21, 24, 14].sliced));
}

/++
Selection loop behind `partitionAt`, working on the closed interval
`loI`..`hiI`. Both iterators point at elements; `hiI` is the last one.

On return, the element at offset `n` from the original `loI` is the one
that would be there if the interval were sorted by `less`.

Each iteration picks a strategy from the current length `len` and offset `n`:
- `n == 0`: linear scan for the minimum, swapped into place; done.
- `n + 1 == len`: linear scan for the maximum, swapped into place; done.
- `len <= 12`: `pivotPartitionImpl` around the middle element.
- `n * 16 <= (len - 1) * 7`: `partitionAtPartitionOffMedian!(less, false)`.
- `n * 16 >= (len - 1) * 9`: `partitionAtPartitionOffMedian!(less, true)`.
- otherwise: `partitionAtPartition`.

After partitioning, the loop continues only on the side of the pivot that
contains `n`. `useSampling` is cleared when a sampled pivot falls outside the
accepted bounds. Because it is passed by value, this affects only this call
and the helpers it invokes afterwards.

Params:
    loI = iterator to the first element
    hiI = iterator to the last element (inclusive)
    n = offset from `loI` of the element to select; callers keep it below `hiI - loI + 1`
    useSampling = forwarded to the partition helpers, which run an extra
        grouping pass when it is false
+/
private @trusted pure nothrow @nogc
void partitionAtImpl(alias less, Iterator)(
    Iterator loI,
    Iterator hiI,
    size_t n,
    bool useSampling)
{
    assert(loI <= hiI, "partitionAtImpl: loI must be less than or equal to hiI");

    import mir.utility: swapStars;
    import mir.functional: reverseArgs;

    Iterator pivotI;
    size_t len;

    for (;;) {
        // Closed interval, hence the + 1.
        len = hiI - loI + 1;

        if (len <= 1) {
            break;
        }

        // Selecting the smallest element: one linear scan is enough.
        if (n == 0) {
            pivotI = loI;
            foreach (i; 1 .. len) {
                if (less(loI[i], *pivotI)) {
                    pivotI = loI + i;
                }
            }
            swapStars(loI + n, pivotI);
            break;
        }

        // Selecting the largest element: same scan with the predicate's arguments swapped.
        if (n + 1 == len) {
            pivotI = loI;
            foreach (i; 1 .. len) {
                if (reverseArgs!less(loI[i], *pivotI)) {
                    pivotI = loI + i;
                }
            }
            swapStars(loI + n, pivotI);
            break;
        }

        if (len <= 12) {
            // Short interval: partition around the middle element.
            pivotI = loI + len / 2;
            pivotPartitionImpl!less(loI, hiI, pivotI);
        } else if (n * 16 <= (len - 1) * 7) {
            // n lies in the left part: sample off the median, leaning left.
            pivotI = partitionAtPartitionOffMedian!(less, false)(loI, hiI, n, useSampling);
            // Quality check: stop sampling if the pivot fell outside these bounds.
            if (useSampling)
            {
                auto pivot = pivotI - loI;
                if (pivot < n)
                {
                    if (pivot * 4 < len)
                    {
                        useSampling = false;
                    }
                }
                else if ((len - pivot) * 8 < len * 3)
                {
                    useSampling = false;
                }
            }
        } else if (n * 16 >= (len - 1) * 9) {
            // n lies in the right part: sample off the median, leaning right.
            pivotI = partitionAtPartitionOffMedian!(less, true)(loI, hiI, n, useSampling);
            // Quality check: mirror image of the branch above.
            if (useSampling)
            {
                auto pivot = pivotI - loI;
                if (pivot < n)
                {
                    if (pivot * 8 < len * 3)
                    {
                        useSampling = false;
                    }
                }
                else if ((len - pivot) * 4 < len)
                {
                    useSampling = false;
                }
            }
        } else {
            // n is near the middle.
            pivotI = partitionAtPartition!less(loI, hiI, n, useSampling);
            // Quality check
            if (useSampling) {
                auto pivot = pivotI - loI;
                if (pivot * 9 < len * 2 || pivot * 9 > len * 7)
                {
                    // Failed - abort sampling going forward
                    useSampling = false;
                }
            }
        }

        // Keep only the side of the pivot that contains n.
        if (n < (pivotI - loI)) {
            hiI = pivotI - 1;
        } else if (n > (pivotI - loI)) {
            n -= (pivotI - loI + 1);
            loI = pivotI;
            ++loI;
        } else {
            // The pivot itself is the element at offset n.
            break;
        }
    }
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 2;
    auto x = [3, 1, 5, 2, 0].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    // Partition at `nth` itself, so the assertion below checks the selected position.
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 2);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 4;
    // partitionAtImpl on a 2-dimensional slice's iterator: flat index 4 is x[2, 0].
    auto x = [3, 1, 5, 2, 0, 7].sliced(3, 2);
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[2, 0] == 5);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 1;
    // Input with duplicates; sorted it is 0, 0, 3, 3, 4.
    auto x = [0, 0, 4, 3, 3].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 0);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 2;
    auto x = [0, 0, 4, 3, 3].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 3);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 3;
    auto x = [0, 0, 4, 3, 3].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 3);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 4;
    // Sorted this input is 1, 1, 2, 2, 7, 8, 10, 12, 19, 20.
    auto x = [ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 7);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 5;
    auto x = [ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 8);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 6;
    auto x = [ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 10);
}

// Check all partitionAt
// For every index i, restore the original data and check that partitionAtImpl
// leaves at i the element a full sort puts there.
version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;
    import mir.ndslice.allocation: slice;

    static immutable raw = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22];

    // Copies `raw` into `x`, undoing the previous partition.
    static void fill(T)(T x) {
        for (size_t i = 0; i < x.length; i++) {
            x[i] = raw[i];
        }
    }
    auto x = slice!int(raw.length);
    fill(x);
    auto x_sort = x.dup;
    x_sort = x_sort.sort;
    size_t i = 0;
    while (i < raw.length) {
        auto frontI = x._iterator;
        auto lastI = frontI + x.length - 1;
        partitionAtImpl!((a, b) => (a < b))(frontI, lastI, i, true);
        assert(x[i] == x_sort[i]);
        fill(x);
        i++;
    }
}

/++
Partition step used by `partitionAtImpl` when `n` is near the middle of the
closed interval `frontI`..`lastI`.

Steps, as implemented below:
- choose the sample `loI .. hiI` of `ninth = len / 9` elements, placed so that
  its upper median lines up with the element at offset `len / 2`;
- when `useSampling` is false, first run `p3` over the wider block
  `loI - ninth .. hiI + ninth`;
- run `p3` over `loI .. hiI`;
- recursively select, within the closed interval `loI`..`hiI - 1`, the element
  at offset `n * (ninth - 1) / (len - 1)`, which maps `n` proportionally into the sample;
- let `expandPartition` extend the partition around that element to the whole interval.

Params:
    frontI = iterator to the first element; taken by `ref` and not reassigned here
    lastI = iterator to the last element (inclusive); taken by `ref` and not reassigned here
    n = offset from `frontI` of the element the caller is looking for; must be less than the length
    useSampling = when false, the extra `p3` pass described above is performed

Returns:
    the iterator returned by `expandPartition`, which is the pivot's final position.
+/
private @trusted pure nothrow @nogc
Iterator partitionAtPartition(alias less, Iterator)(
    ref Iterator frontI,
    ref Iterator lastI,
    size_t n,
    bool useSampling)
{
    size_t len = lastI - frontI + 1;

    // NOTE(review): the condition accepts len == 9 although the message says "longer than 9".
    assert(len >= 9 && n < len, "partitionAtPartition: length must be longer than 9 and n must be less than r.length");

    size_t ninth = len / 9;
    size_t pivot = ninth / 2;
    // Position subrange r[loI .. hiI] to have length equal to ninth and its upper
    // median r[loI .. hiI][$ / 2] in exactly the same place as the upper median
    // of the entire range r[$ / 2]. This is to improve behavior for searching
    // the median in already sorted ranges.
    auto loI = frontI;
    loI += len / 2 - pivot;
    auto hiI = loI;
    hiI += ninth;

    // We have either one straggler on the left, one on the right, or none.
    assert(loI - frontI <= lastI - hiI + 1 || lastI - hiI <= loI - frontI + 1, "partitionAtPartition: straggler check failed for loI, len, hiI");
    assert(loI - frontI >= ninth * 4, "partitionAtPartition: loI - frontI >= ninth * 4");
    assert((lastI + 1) - hiI >= ninth * 4, "partitionAtPartition: (lastI + 1) - hiI >= ninth * 4");

    // Partition in groups of 3, and the mid tertile again in groups of 3
    if (!useSampling) {
        // Widen the block by one ninth on each side (3 * ninth elements in total).
        auto loI_ = loI;
        loI_ -= ninth;
        auto hiI_ = hiI;
        hiI_ += ninth;
        p3!(less, Iterator)(frontI, lastI, loI_, hiI_);
    }
    p3!(less, Iterator)(frontI, lastI, loI, hiI);

    // Get the median of medians of medians
    // Map the full interval of n to the full interval of the ninth
    pivot = (n * (ninth - 1)) / (len - 1);
    if (hiI > loI) {
        // partitionAtImpl takes a closed interval, so pass the last element, not one past it.
        auto hiI_minus = hiI;
        --hiI_minus;
        partitionAtImpl!less(loI, hiI_minus, pivot, useSampling);
    }

    auto pivotI = loI;
    pivotI += pivot;

    return expandPartition!less(frontI, lastI, loI, pivotI, hiI);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;
    auto x = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced;
    auto x_sort = x.dup;
    x_sort = x_sort.sort;
    auto frontI = x._iterator;
    auto lastI = frontI + x.length - 1;
    size_t n = x.length / 2;
    // NOTE(review): partitionAtPartition only partitions around the pivot it returns;
    // that x[n - 1] matches the sorted value depends on this particular input.
    partitionAtPartition!((a, b) => (a < b))(frontI, lastI, n, true);
    assert(x[n - 1] == x_sort[n - 1]);
}

/++
Partition step used by `partitionAtImpl` when `n` lies far from the middle of
the closed interval `frontI`..`lastI` (requires at least 12 elements).

With `_4 = len / 4`, the sampled region starts at `leftLimitI`, which is
`frontI + _4`, or `frontI + 2 * _4` when `leanRight` is true.
- When `useSampling` is false, `p4!(less, leanRight)` first runs over
  `leftLimitI .. leftLimitI + _4`.
- `p3` then runs over a sample of `_12 = _4 / 3` elements starting
  `_12` elements after `leftLimitI`.
- Within that sample, the element at offset `n * (_12 - 1) / (len - 1)` is
  selected recursively.
- `expandPartition` extends the partition around it to the whole interval.

Params:
    frontI = iterator to the first element; taken by `ref` and not reassigned here
    lastI = iterator to the last element (inclusive); taken by `ref` and not reassigned here
    n = offset from `frontI` of the element the caller is looking for; must be less than the length
    useSampling = when false, the extra `p4` pass described above is performed

Returns:
    the iterator returned by `expandPartition`, which is the pivot's final position.
+/
private @trusted pure nothrow @nogc
Iterator partitionAtPartitionOffMedian(alias less, bool leanRight, Iterator)(
    ref Iterator frontI,
    ref Iterator lastI,
    size_t n,
    bool useSampling)
{
    size_t len = lastI - frontI + 1;

    assert(len >= 12, "partitionAtPartitionOffMedian: len must be greater than 11");
    assert(n < len, "partitionAtPartitionOffMedian: n must be less than len");
    auto _4 = len / 4;
    auto leftLimitI = frontI;
    static if (leanRight)
        leftLimitI += 2 * _4;
    else
        leftLimitI += _4;
    // Partition in groups of 4, and the left quartile again in groups of 3
    if (!useSampling)
    {
        auto leftLimit_plus_4 = leftLimitI;
        leftLimit_plus_4 += _4;
        p4!(less, leanRight)(frontI, lastI, leftLimitI, leftLimit_plus_4);
    }
    auto _12 = _4 / 3;
    auto loI = leftLimitI;
    loI += _12;
    auto hiI = loI;
    hiI += _12;
    p3!less(frontI, lastI, loI, hiI);

    // Get the median of medians of medians
    // Map the full interval of n to the full interval of the `_12`-element sample
    auto pivot = (n * (_12 - 1)) / (len - 1);
    if (hiI > loI) {
        // partitionAtImpl takes a closed interval, so pass the last element, not one past it.
        auto hiI_minus = hiI;
        --hiI_minus;
        partitionAtImpl!less(loI, hiI_minus, pivot, useSampling);
    }
    auto pivotI = loI;
    pivotI += pivot;
    return expandPartition!less(frontI, lastI, loI, pivotI, hiI);
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: equal;

    // Pins the exact element layout produced for the left-leaning variant.
    auto x = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.length - 1;
    partitionAtPartitionOffMedian!((a, b) => (a < b), false)(frontI, lastI, 5, true);
    assert(x.equal([6, 7, 8, 9, 5, 0, 2, 7, 9, 15, 10, 25, 11, 10, 13, 18, 17, 13, 25, 22]));
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: equal;

    // Pins the exact element layout produced for the right-leaning variant.
    auto x = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.length - 1;
    partitionAtPartitionOffMedian!((a, b) => (a < b), true)(frontI, lastI, 15, true);
    assert(x.equal([6, 7, 8, 7, 5, 2, 9, 0, 9, 15, 25, 10, 11, 10, 13, 18, 17, 13, 25, 22]));
}

/++
Groups-of-three pass. With `d = hiI - loI`, for every position `p` in
`loI .. hiI` this calls `medianOf!less` on the elements at `p - d`, `p` and
`p + d`. The unittest below shows each such triple ending up in ascending order.
`medianOf` is defined elsewhere in this module.

The asserts require `loI <= hiI <= lastI` and keep every touched position
inside the closed interval `frontI`..`lastI`.
+/
private @trusted
void p3(alias less, Iterator)(
    Iterator frontI,
    Iterator lastI,
    Iterator loI,
    Iterator hiI)
{
    assert(loI <= hiI && hiI <= lastI, "p3: loI must be less than or equal to hiI and hiI must be less than or equal to lastI");
    immutable diffI = hiI - loI;
    Iterator lo_loI;
    Iterator hi_loI;
    for (; loI < hiI; ++loI)
    {
        // Partners of loI one stride to the left and one stride to the right.
        lo_loI = loI;
        lo_loI -= diffI;
        hi_loI = loI;
        hi_loI += diffI;
        assert(lo_loI >= frontI, "p3: lo_loI must be greater than or equal to frontI");
        assert(hi_loI <= lastI, "p3: hi_loI must be less than or equal to lastI");
        medianOf!less(lo_loI, loI, hi_loI);
    }
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: equal;

    // Stride 2: triples at positions (0, 2, 4) and (1, 3, 5).
    auto x = [3, 4, 0, 5, 2, 1].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.length - 1;
    auto loI = frontI + 2;
    auto hiI = frontI + 4;
    p3!((a, b) => (a < b))(frontI, lastI, loI, hiI);
    assert(x.equal([0, 1, 2, 4, 3, 5]));
}

/++
Groups-of-four analogue of `p3`. With `d = hiI - loI`, for every position `p`
in `loI .. hiI` this calls `medianOf!(less, leanRight)` on the elements at
`p - 2d, p - d, p, p + d` when `leanRight` is true, and at
`p - d, p, p + d, p + 2d` otherwise.
NOTE(review): the ordering `medianOf` establishes for four elements is
defined elsewhere in this module; confirm there.
+/
private @trusted
template p4(alias less, bool leanRight)
{
    void p4(Iterator)(
        Iterator frontI,
        Iterator lastI,
        Iterator loI,
        Iterator hiI)
    {
        assert(loI <= hiI && hiI <= lastI, "p4: loI must be less than or equal to hiI and hiI must be less than or equal to lastI");

        immutable diffI = hiI - loI;
        immutable diffI2 = diffI * 2;

        Iterator lo_loI;
        Iterator hi_loI;

        // The fourth element lies two strides to the left (leanRight) or to the right.
        static if (leanRight)
            Iterator lo2_loI;
        else
            Iterator hi2_loI;

        for (; loI < hiI; ++loI)
        {
            lo_loI = loI - diffI;
            hi_loI = loI + diffI;

            assert(lo_loI >= frontI, "p4: lo_loI must be greater than or equal to frontI");
            assert(hi_loI <= lastI, "p4: hi_loI must be less than or equal to lastI");

            static if (leanRight) {
                lo2_loI = loI - diffI2;
                assert(lo2_loI >= frontI, "lo2_loI must be greater than or equal to frontI");
                medianOf!(less, leanRight)(lo2_loI, lo_loI, loI, hi_loI);
            } else {
                hi2_loI = loI + diffI2;
                assert(hi2_loI <= lastI, "hi2_loI must be less than or equal to lastI");
                medianOf!(less, leanRight)(lo_loI, loI, hi_loI, hi2_loI);
            }
        }
    }
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: equal;

    // leanRight = false: quadruples at positions (0, 3, 6, ...) shifted by 0 and 1, stride 2.
    auto x = [3, 4, 0, 7, 2, 6, 5, 1, 4].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.length - 1;
    auto loI = frontI + 3;
    auto hiI = frontI + 5;
    p4!((a, b) => (a < b), false)(frontI, lastI, loI, hiI);
    assert(x.equal([3, 1, 0, 4, 2, 6, 4, 7, 5]));
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: equal;

    // leanRight = true: the extra element is taken two strides to the left.
    auto x = [3, 4, 0, 8, 2, 7, 5, 1, 4, 3].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.length - 1;
    auto loI = frontI + 4;
    auto hiI = frontI + 6;
    p4!((a, b) => (a < b), true)(frontI, lastI, loI, hiI);
    assert(x.equal([0, 4, 2, 1, 3, 7, 5, 8, 4, 3]));
}

/++
Grows an existing partition of the block `loI .. hiI` around `*pivotI` into a
partition of the whole closed interval `frontI`..`lastI`.

Preconditions (asserted on entry):
- `frontI <= loI <= pivotI < hiI <= lastI`;
- no element in `loI .. pivotI + 1` compares greater than `*pivotI`;
- no element in `pivotI + 1 .. hiI` compares less than `*pivotI`.

`hiI` and `pivotI` are `ref` parameters and are modified. `hiI` is
decremented so that it points at the last element of the block, and `pivotI`
moves while elements are exchanged. `frontI`, `lastI` and `loI` are not
reassigned.

Returns:
    iterator to the pivot's final position. The closing asserts check that no
    element before it compares greater and no element after it compares less.
+/
private @trusted
template expandPartition(alias less)
{
    Iterator expandPartition(Iterator)(
        ref Iterator frontI,
        ref Iterator lastI,
        ref Iterator loI,
        ref Iterator pivotI,
        ref Iterator hiI)
    {
        import mir.algorithm.iteration: all;

        assert(frontI <= loI, "expandPartition: frontI must be less than or equal to loI");
        assert(loI <= pivotI, "expandPartition: loI must be less than or equal pivotI");
        assert(pivotI < hiI, "expandPartition: pivotI must be less than hiI");
        assert(hiI <= lastI, "expandPartition: hiI must be less than or equal to lastI");

        foreach(x; loI .. (pivotI + 1))
            assert(!less(*pivotI, *x), "expandPartition: loI .. (pivotI + 1) failed test");
        foreach(x; (pivotI + 1) .. hiI)
            assert(!less(*x, *pivotI), "expandPartition: (pivotI + 1) .. hiI failed test");

        import mir.utility: swapStars;
        import mir.algorithm.iteration: all;
        // We work with closed intervals!
        --hiI;

        // Phase 1: scan inward from both ends of the whole interval, swapping pairs
        // that sit on the wrong side of the pivot, until one scan reaches the
        // already-partitioned block (leftI == loI or rightI == hiI).
        auto leftI = frontI;
        auto rightI = lastI;
        loop: for (;; ++leftI, --rightI)
        {
            for (;; ++leftI)
            {
                if (leftI == loI) break loop;
                if (!less(*leftI, *pivotI)) break;
            }
            for (;; --rightI)
            {
                if (rightI == hiI) break loop;
                if (!less(*pivotI, *rightI)) break;
            }
            swapStars(leftI, rightI);
        }

        foreach(x; loI .. (pivotI + 1))
            assert(!less(*pivotI, *x), "expandPartition: loI .. (pivotI + 1) failed less than test");
        foreach(x; (pivotI + 1) .. (hiI + 1))
            assert(!less(*x, *pivotI), "expandPartition: (pivotI + 1) .. (hiI + 1) failed less than test");
        foreach(x; frontI .. leftI)
            assert(!less(*pivotI, *x), "expandPartition: frontI .. leftI failed less than test");
        foreach(x; (rightI + 1) .. (lastI + 1))
            assert(!less(*x, *pivotI), "expandPartition: (rightI + 1) .. (lastI + 1) failed less than test");

        // The pivot value stays at oldPivotI while pivotI tracks where it will finally go.
        auto oldPivotI = pivotI;

        // Phase 2a: the left scan stopped before reaching the block. Move the
        // remaining elements that compare greater than the pivot past pivotI,
        // shifting pivotI leftwards.
        if (leftI < loI)
        {
            // First loop: spend r[loI .. pivot]
            for (; loI < pivotI; ++leftI)
            {
                if (leftI == loI) goto done;
                if (!less(*oldPivotI, *leftI)) continue;
                --pivotI;
                assert(!less(*oldPivotI, *pivotI), "expandPartition: less check failed");
                swapStars(leftI, pivotI);
            }
            // Second loop: make leftI and pivot meet
            for (;; ++leftI)
            {
                if (leftI == pivotI) goto done;
                if (!less(*oldPivotI, *leftI)) continue;
                for (;;)
                {
                    if (leftI == pivotI) goto done;
                    --pivotI;
                    if (less(*pivotI, *oldPivotI))
                    {
                        swapStars(leftI, pivotI);
                        break;
                    }
                }
            }
        }

        // Phase 2b: the right scan stopped before reaching the block (mirror of 2a).
        // First loop (right side): spend r[pivot .. hi]
        for (; hiI != pivotI; --rightI)
        {
            if (rightI == hiI) goto done;
            if (!less(*rightI, *oldPivotI)) continue;
            ++pivotI;
            assert(!less(*pivotI, *oldPivotI), "expandPartition: less check failed");
            swapStars(rightI, pivotI);
        }
        // Second loop: make rightI and pivotI meet
        for (; rightI > pivotI; --rightI)
        {
            if (!less(*rightI, *oldPivotI)) continue;
            while (rightI > pivotI)
            {
                ++pivotI;
                if (less(*oldPivotI, *pivotI))
                {
                    swapStars(rightI, pivotI);
                    break;
                }
            }
        }

    done:
        // Put the pivot value into its final slot.
        swapStars(oldPivotI, pivotI);


        foreach(x; frontI .. (pivotI + 1))
            assert(!less(*pivotI, *x), "expandPartition: frontI .. (pivotI + 1) failed test");
        foreach(x; (pivotI + 1) .. (lastI + 1))
            assert(!less(*x, *pivotI), "expandPartition: (pivotI + 1) .. (lastI + 1) failed test");
        return pivotI;
    }
}

version(mir_ndslice_test)
@trusted pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;

    // Block a[4 .. 6] = [8, 11] is already partitioned around a[5] = 11.
    auto a = [ 10, 5, 3, 4, 8, 11, 13, 3, 9, 4, 10 ].sliced;
    auto frontI = a._iterator;
    auto lastI = frontI + a.length - 1;
    auto loI = frontI + 4;
    auto pivotI = frontI + 5;
    auto hiI = frontI + 6;
    assert(expandPartition!((a, b) => a < b)(frontI, lastI, loI, pivotI, hiI) == (frontI + 9));
}

// Sorting a larger array of structs by a key; the test only checks that it runs.
version(mir_ndslice_test)
unittest
{
    import std.random;
    import mir.ndslice.sorting: sort;

    static struct StructA
    {
        double val0;
        double val1;
        double val2;
    }

    static struct StructB
    {
        ulong productId;
        StructA strA;
    }

    auto createStructBArray(uint nbTrades)
    {
        // Fixed seed, so the generated data is the same on every run.
        auto rnd = Random(42);

        auto p = StructA(0,0,0);

        StructB[] ret;
        foreach(i;0..nbTrades)
        {
            ret ~= StructB(uniform(0, nbTrades, rnd), p);
        }

        return ret;
    }

    auto arrayB = createStructBArray(10000).sort!((a,b) => a.productId<b.productId);
}