module mir.random.flex.internal.types;

import std.traits: ReturnType, isFloatingPoint;

version(Flex_logging)
{
    import std.experimental.logger;
}

/**
Major data unit of the Flex algorithm.
It is used to store
- the boundaries `lx`, `rx` of the interval and its `T_c` family parameter `c`
- (cached) values of the transformation (and its derivatives) at both boundaries
- the hat and squeeze function of the interval
- area below the hat and squeeze function
*/
struct Interval(S)
    if (isFloatingPoint!S)
{
    /// left position of the interval
    S lx;

    /// right position of the interval
    S rx;

    /// T_c family of the interval
    S c;

    /// transformed left value of lx
    S ltx;

    /// transformed value of the first derivate of the left lx value
    S lt1x;

    /// transformed value of the second derivate of the left lx value
    S lt2x;

    /// transformed right value of rx
    S rtx;

    /// transformed value of the first derivate of the right rx value
    S rt1x;

    /// transformed value of the second derivate of the right rx value
    S rt2x;

    /// hat function of the interval
    LinearFun!S hat;

    /// squeeze function of the interval
    LinearFun!S squeeze;

    /// calculated area of the integrated hat function
    S hatArea;

    /// calculated area of the integrated squeeze function
    S squeezeArea;

    // workaround against @@@BUG 16331@@@
    // sets NaN's to be equal on comparison:
    // two intervals are equal if every scalar field - and every member
    // (slope, y, a) of `hat` and `squeeze` - either compares equal or is
    // NaN in both intervals.
    version(Flex_logging)
    bool opEquals(const Interval s2) const
    {
        import mir.math : isNaN;

        // NaN-tolerant scalar comparison
        static bool same(const S a, const S b)
        {
            return a.isNaN && b.isNaN || a == b;
        }

        // unrolled at compile time over all fields of the struct
        foreach (i, field; this.tupleof)
        {
            static if (isFloatingPoint!(typeof(field)))
            {
                if (!same(field, s2.tupleof[i]))
                    return false;
            }
            else
            {
                // LinearFun!S (hat, squeeze): compare slope, y and a one by one
                foreach (j, member; field.tupleof)
                {
                    if (!same(member, s2.tupleof[i].tupleof[j]))
                        return false;
                }
            }
        }
        return true;
    }

    /// exact (hexadecimal `%a`) textual representation of all fields
    version(Flex_logging_hex) string logHex() const
    {
        import std.format : format;
        return "Interval!%s(%a, %a, %a, %a, %a, %a, %a, %a, %a, %s, %s, %a, %a)"
            .format(S.stringof, lx, rx, c, ltx, lt1x, lt2x, rtx, rt1x, rt2x,
                    hat.logHex, squeeze.logHex, hatArea, squeezeArea);
    }
}

/**
Notations of different function types according to Botts et al. (2013).
It is based on this naming scheme:

- a: concAve
- b: convex
- Type 4 is the pure case without any inflection point
*/
enum FunType {undefined, T1a, T1b, T2a, T2b, T3a, T3b, T4a, T4b}

/**
Determine the function type of an interval.
Based on Theorem 1 of the Flex paper.

Params:
    iv = interval; `lx` must be smaller than `rx`

Returns:
    The `FunType` of `iv`, or `FunType.undefined` if none of the
    classification conditions hold (e.g. an unbounded interval on which the
    transformed density is not concave and strictly monotone).
*/
FunType determineType(S)(in Interval!S iv)
in
{
    assert(iv.lx < iv.rx, "invalid interval");
}
out(type)
{
    version(Flex_logging)
    if (!type)
        warningf("Interval has an undefined type: %s", iv);
}
do
{
    with(FunType)
    {
        // In each unbounded interval f must be concave and strictly monotone
        // Condition 4 in section 2.3 from Botts et al. (2013)
        if (iv.lx == -S.infinity)
        {
            if (iv.rt2x < 0 && iv.rt1x > 0)
                return T4a;
            return undefined;
        }

        if (iv.rx == +S.infinity)
        {
            if (iv.lt2x < 0 && iv.lt1x < 0)
                return T4a;
            return undefined;
        }

        // the transformed left boundary value is at its lower limit
        // (0 for c > 0, -inf for c <= 0): decide on the right boundary only
        if (iv.c > 0 && iv.ltx == 0 || iv.c <= 0 && iv.ltx == -S.infinity)
        {
            if (iv.rt2x < 0 && iv.rt1x > 0)
                return T4a;
            if (iv.rt2x > 0 && iv.rt1x > 0)
                return T4b;
            return undefined;
        }

        // same as above, mirrored for the right boundary
        if (iv.c > 0 && iv.rtx == 0 || iv.c <= 0 && iv.rtx == -S.infinity)
        {
            if (iv.lt2x < 0 && iv.lt1x < 0)
                return T4a;
            if (iv.lt2x > 0 && iv.lt1x < 0)
                return T4b;
            return undefined;
        }

        if (iv.c < 0)
        {
            if (iv.ltx == 0 && iv.rt2x > 0 || iv.rtx == 0 && iv.lt2x > 0)
                return T4b;
        }

        // slope of the interval
        auto R = (iv.rtx - iv.ltx) / (iv.rx - iv.lx);

        if (iv.lt1x >= R && iv.rt1x >= R)
            return T1a;
        if (iv.lt1x <= R && iv.rt1x <= R)
            return T1b;

        if (iv.lt2x <= 0 && iv.rt2x <= 0)
            return T4a;
        if (iv.lt2x >= 0 && iv.rt2x >= 0)
            return T4b;

        if (iv.lt1x >= R && R >= iv.rt1x)
        {
            if (iv.lt2x < 0 && iv.rt2x > 0)
                return T2a;
            if (iv.lt2x > 0 && iv.rt2x < 0)
                return T2b;
        }
        else if (iv.lt1x <= R && R <= iv.rt1x)
        {
            if (iv.lt2x < 0 && iv.rt2x > 0)
                return T3a;
            if (iv.lt2x > 0 && iv.rt2x < 0)
                return T3b;
        }

        return undefined;
    }
}

nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real)) with(FunType)
    {
        const f0 = (S x) => x ^^ 4;
        const f1 = (S x) => 4 * x ^^ 3;
        const f2 = (S x) => 12 * x * x;
        enum c = 42; // c doesn't matter here
        auto dt = (S l, S r) => determineType(Interval!S(l, r, c, f0(l), f1(l), f2(l),
                                                          f0(r), f1(r), f2(r)));

        // entirely convex
        assert(dt(-3.0, -1) == T4b);
        assert(dt(-1.0, 1) == T4b);
        assert(dt(1.0, 3) == T4b);
    }
}

// test x^3
nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real)) with(FunType)
    {
        const f0 = (S x) => x ^^ 3;
        const f1 = (S x) => 3 * x ^^ 2;
        const f2 = (S x) => 6 * x;
        enum c = 42; // c doesn't matter here
        auto dt = (S l, S r) => determineType(Interval!S(l, r, c, f0(l), f1(l), f2(l),
                                                          f0(r), f1(r), f2(r)));

        // concave
        assert(dt(-S.infinity, S(-1.0)) == T4a);
        assert(dt(S(-3.0), S(-1)) == T4a);

        // inflection point at x = 0, concave before
        assert(dt(S(-1.0), S(1)) == T1a);
        // convex
        assert(dt(S(1.0), S(3)) == T4b);
    }
}

// test sin(x)
nothrow pure @safe version(mir_random_test) unittest
{
    import mir.math: PI;
    // due to numerical errors a small padding must be added
    // see e.g. https://gist.github.com/wilzbach/3d27d06b55821aa9795deb15d4d47679
    import mir.math.common : cos, sin;

    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real)) with(FunType)
    {
        const f0 = (S x) => sin(x);
        const f1 = (S x) => cos(x);
        const f2 = (S x) => -sin(x);
        enum c = 42; // c doesn't matter here
        auto dt = (S l, S r) => determineType(Interval!S(l, r, c, f0(l), f1(l), f2(l),
                                                          f0(r), f1(r), f2(r)));
        // type 1a: concave
        assert(dt(0.01, 2 * PI - 0.01) == T1a);
        assert(dt(2 * PI + 0.01, 4 * PI - 0.01) == T1a);
        assert(dt(2, 4) == T1a);
        assert(dt(0.01, 5) == T1a);
        assert(dt(1, 5) == T1a);

        // type 1b: convex
        assert(dt(-PI, PI) == T1b);
        assert(dt(PI, 3 * PI) == T1b);
        assert(dt(4, 8) == T1b);

        // type 2a: concave
        assert(dt(1, 4) == T2a);

        // type 2b: convex
        assert(dt(6, 8) == T2b);

        // type 3a: concave
        assert(dt(3, 4) == T3a);
        assert(dt(2, 5.7) == T3a);

        // type 3b: convex
        assert(dt(-3, 0.1) == T3b);

        // type 4a - pure concave intervals (special case of 2a)
        assert(dt(0.01, PI - 0.01) == T4a);
        assert(dt(0.01, 3) == T4a);
        assert(dt(2 * PI + 0.01, 3 * PI - 0.01) == T4a);

        // type 4b - pure convex intervals (special case of 3b)
        assert(dt(-PI + 0.01, -0.01) == T4b);
        assert(dt(PI + 0.01, 2 * PI - 0.01) == T4b);
        assert(dt(4, 6) == T4b);
    }
}

nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real)) with(FunType)
    {
        const f0 = (S x) => x * x;
        const f1 = (S x) => 2 * x;
        const f2 = (S x) => 2.0;
        enum c = 42; // c doesn't matter here
        auto dt = (S l, S r) => determineType(Interval!S(l, r, c, f0(l), f1(l), f2(l),
                                                          f0(r), f1(r), f2(r)));
        // entirely convex
        assert(dt(-1, 1) == T4b);
        assert(dt(1, 3) == T4b);
    }
}


/**
Representation of a linear function of the form:

    f(x) = slope * (x - y) + a

i.e. a line with the given `slope` that passes through the point `(y, a)`.
This representation allows a bit higher precision than the
typical representation `f(x) = slope * x + intercept`.
*/
struct LinearFun(S)
{
    import std.format : FormatSpec;

    /// direction and steepness (aka beta)
    S slope;

    /// reference point (often an interval boundary); `f(y) == a`
    S y;

    /// value of the function at the reference point `y`
    S a;

    /**
    Params:
        slope = direction and steepness
        y = reference point, often an interval boundary
        a = value of the function at `y`, often f(y)
    */
    this(S slope, S y, S a)
    {
        this.slope = slope;
        this.y = y;
        this.a = a;
    }

    /**
    Textual representation of the function.

    `%l` prints the human-readable form `<slope>x ± <intercept>` (each with
    `%.2g`; terms approximately zero are omitted) or `#NaN#` for a NaN slope.
    `%s` and every other spec print `LinearFun!S(slope, y, a)` with `%.6g`.
    */
    void toString(W)(auto ref W w, const ref FormatSpec!char fmt) const
    {
        import std.range: put;
        import mir.math.common: fabs;
        import std.format: formatValue, singleSpec;
        switch(fmt.spec)
        {
            case 'l':
                import mir.math: approxEqual;
                if (slope != slope) // NaN check
                    w.put("#NaN#");
                else
                {
                    auto spec2g = singleSpec("%.2g");
                    if (!slope.approxEqual(0, 1e-5, 1e-5))
                    {
                        w.formatValue(slope, spec2g);
                        w.put("x");
                        if (!intercept.approxEqual(0, 1e-5, 1e-5))
                        {
                            auto sgn = intercept > 0 ? " + " : " - ";
                            w.put(sgn);
                            w.formatValue(fabs(intercept), spec2g);
                        }
                    }
                    else
                    {
                        w.formatValue(intercept, spec2g);
                    }
                }
                break;
            case 's':
            default:
                import std.traits : Unqual;
                w.put(Unqual!(typeof(this)).stringof);
                auto spec2g = singleSpec("%.6g");
                w.put("(");
                w.formatValue(slope, spec2g);
                w.put(", ");
                w.formatValue(y, spec2g);
                w.put(", ");
                w.formatValue(a, spec2g);
                w.put(")");
                break;
        }
    }

    /// call the linear function with x
    S opCall(in S x) const
    {
        S val = slope * (x - y);
        val += a;
        return val;
    }

    /// calculate inverse of x (the point at which the function equals x)
    S inverse(S x) const
    {
        return y + (x - a) / slope;
    }

    // calculate intercept, i.e. the value at x = 0 (for debugging)
    S intercept() @property const
    {
        return slope * -y + a;
    }

    /// exact (hexadecimal `%a`) textual representation of slope, y and a
    string logHex() const
    {
        import std.format : format;
        return "LinearFun!%s(%a, %a, %a)".format(S.stringof, slope, y, a);
    }
}

/**
Constructs a linear function of the form `f(x) = slope * (x - y) + a`.

Params:
    slope = direction and steepness
    y = reference point, often an interval boundary
    a = value of the function at `y`, often f(y)
Returns:
    A linear function constructed with the given parameters.
*/
LinearFun!S linearFun(S)(S slope, S y, S a)
{
    return LinearFun!S(slope, y, a);
}

/// tangent of a point
@safe version(mir_random_test) unittest
{
    import std.format : format;
    auto f = (double x) => x * x + 1;
    auto df = (double x) => 2 * x;
    auto buildTan = (double x) => linearFun(df(x), x, f(x));

    auto t0 = buildTan(0);
    assert("%l".format(t0)== "1");
    assert(t0(0) == 1);
    assert(t0(42) == 1);

    auto t1 = buildTan(1);
    assert("%l".format(t1) == "2x");
    assert(t1(1) == 2);
    assert(t1(2) == 4);

    auto t2 = buildTan(2);
    assert("%l".format(t2) == "4x - 3");
    assert(t2(1) == 1);
    assert(t2(2) == 5);
}

/// secant of two points
@safe version(mir_random_test) unittest
{
    import std.format : format;
    auto f = (double x) => x * x + 1;
    auto lx = 1, rx = 3;
    // compute the slope between lx and rx
    auto lf = linearFun((f(rx) - f(lx)) / (rx - lx), lx, f(lx));

    assert("%l".format(lf) == "4x - 2");
    assert(lf(1) == 2); // f(1)
    assert(lf(3) == 10); // f(3)
}

/// construct an arbitrary linear function
@safe version(mir_random_test) unittest
{
    import std.format : format;

    // 2 * x + 1
    auto t = linearFun!double(2, 0, 1);
    assert("%l".format(t) == "2x + 1");
    assert(t(1) == 3);
    assert(t(-2) == -3);
}

@nogc nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real))
    {
        auto f1 = (S x) => 2 * x;

        auto t1 = linearFun!S(f1(1), 1, 1);
        assert(t1.slope == 2);
        assert(t1.intercept == -1);

        auto t2 = linearFun!S(f1(0), 0, 0);
        assert(t2.slope == 0);
        assert(t2.intercept == 0);
    }
}

nothrow pure @safe version(mir_random_test) unittest
{
    import mir.math : cos, PI, approxEqual;
    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real))
    {
        auto f = (S x) => cos(x);
        auto buildTan = (S x, S y) => linearFun(f(x), x, y);
        auto t1 = buildTan(0, 0);
        assert(t1.slope == 1);
        assert(t1.intercept == 0);

        auto t2 = buildTan(PI / 2, 1);
        assert(t2.slope.approxEqual(0));
        assert(t2.intercept.approxEqual(1));
    }
}

// test default toString
@safe version(mir_random_test) unittest
{
    import std.format : format;
    auto t = linearFun!double(2, 0, 1);
    assert("%s".format(t) == "LinearFun!double(2, 0, 1)");
}

// test NaN behavior
@safe version(mir_random_test) unittest
{
    import std.format : format;
    auto t = linearFun!double(double.nan, 0, 1);
    assert("%s".format(t) == "LinearFun!double(nan, 0, 1)");
    assert("%l".format(t) == "#NaN#");
}

// logHex is callable on const values
@safe version(mir_random_test) unittest
{
    const t = linearFun!double(2, 0, 1);
    assert(t.logHex[0 .. 17] == "LinearFun!double(");
}

/**
Compares whether two linear functions are approximately equal.

The comparison is done member-wise on the representation (`slope`, `y`
and `a`), so two representations of the same line that use different
reference points `y` are not considered equal.

Params:
    x = first linear function to compare
    y = second linear function to compare
    maxRelDiff = maximum relative difference
    maxAbsDiff = maximum absolute difference

Returns:
    True if both linear functions are approximately equal.
*/
bool approxEqual(S)(LinearFun!S x, LinearFun!S y, S maxRelDiff = 1e-2, S maxAbsDiff = 1e-5)
{
    import mir.math : approxEqual;
    return x.slope.approxEqual(y.slope, maxRelDiff, maxAbsDiff) &&
           x.y.approxEqual(y.y, maxRelDiff, maxAbsDiff) &&
           x.a.approxEqual(y.a, maxRelDiff, maxAbsDiff);
}

///
@nogc nothrow pure @safe version(mir_random_test) unittest
{
    auto x = linearFun!double(2, 0, 1);
    auto x2 = linearFun!double(2, 0, 1);
    assert(x.approxEqual(x2));

    auto y = linearFun!double(2, 1e-9, 1);
    assert(x.approxEqual(y));

    auto z = linearFun!double(2, 4, 1);
    assert(!x.approxEqual(z));
}