The OpenD Programming Language

1 module mir.random.flex.internal.types;
2 
3 import std.traits: ReturnType, isFloatingPoint;
4 
5 version(Flex_logging)
6 {
7     import std.experimental.logger;
8 }
9 
/**
Major data unit of the Flex algorithm.
For a single interval `[lx, rx]` it stores
- the bounds of the interval and the parameter `c` of the `T_c` transformation,
- (cached) values of the transformation and of its first and second derivatives
  at both bounds,
- the hat and squeeze functions of the interval,
- the areas below the hat and squeeze functions.
*/
struct Interval(S)
    if (isFloatingPoint!S)
{
    /// left position of the interval
    S lx;

    /// right position of the interval
    S rx;

    /// T_c family of the interval
    S c;

    /// transformed left value of lx
    S ltx;

    /// transformed value of the first derivative at the left bound lx
    S lt1x;

    /// transformed value of the second derivative at the left bound lx
    S lt2x;

    /// transformed right value of rx
    S rtx;

    /// transformed value of the first derivative at the right bound rx
    S rt1x;

    /// transformed value of the second derivative at the right bound rx
    S rt2x;

    /// hat function of the interval
    LinearFun!S hat;

    /// squeeze function of the interval
    LinearFun!S squeeze;

    /// calculated area of the integrated hat function
    S hatArea;

    /// calculated area of the integrated squeeze function
    S squeezeArea;

    // workaround against @@@BUG 16331@@@
    // Field-wise comparison in which a NaN is considered equal to a NaN in the
    // same field. Every scalar field and every coefficient of `hat` and
    // `squeeze` must match on its own, so a NaN pair in one coefficient can no
    // longer mask a difference in another one.
    version(Flex_logging)
    bool opEquals(const Interval s2) const
    {
        // equal, or both NaN (`v != v` only holds for NaN)
        static bool same(const S lhs, const S rhs)
        {
            return lhs == rhs || (lhs != lhs && rhs != rhs);
        }

        // coefficient-wise comparison of two linear functions
        static bool sameFun(const LinearFun!S lhs, const LinearFun!S rhs)
        {
            return same(lhs.slope, rhs.slope) &&
                   same(lhs.y, rhs.y) &&
                   same(lhs.a, rhs.a);
        }

        return same(lx, s2.lx) && same(rx, s2.rx) && same(c, s2.c) &&
               same(ltx, s2.ltx) && same(lt1x, s2.lt1x) && same(lt2x, s2.lt2x) &&
               same(rtx, s2.rtx) && same(rt1x, s2.rt1x) && same(rt2x, s2.rt2x) &&
               sameFun(hat, s2.hat) && sameFun(squeeze, s2.squeeze) &&
               same(hatArea, s2.hatArea) && same(squeezeArea, s2.squeezeArea);
    }

    /// hexadecimal (`%a`) dump of all fields for exact logging
    version(Flex_logging_hex) string logHex()
    {
        import std.format : format;
        return "Interval!%s(%a, %a, %a, %a, %a, %a, %a, %a, %a, %s, %s, %a, %a)"
               .format(S.stringof, lx, rx, c, ltx, lt1x, lt2x, rtx, rt1x, rt2x,
                       hat.logHex, squeeze.logHex, hatArea, squeezeArea);
    }
}
117 
/**
Function types of an interval, following the notation of Botts et al. (2013).

Naming scheme of the suffixes:
- `a`: concAve
- `b`: convex

Type 4 is the pure case without any inflection point.
*/
enum FunType
{
    undefined, /// no classification applies
    T1a,
    T1b,
    T2a,
    T2b,
    T3a,
    T3b,
    T4a, /// concave, without inflection point
    T4b, /// convex, without inflection point
}
127 
/**
Determine the function type of an interval.
Based on Theorem 1 of the Flex paper (Botts et al. (2013)).

The classification inspects the transformed values `ltx`/`rtx` and the
transformed first (`lt1x`, `rt1x`) and second (`lt2x`, `rt2x`) derivatives at
both bounds. For bounded intervals whose bounds are not special, it also
compares the first derivatives with the slope of the secant through both
transformed end points.

Params:
    iv = interval; `iv.lx < iv.rx` is required (checked by the `in` contract)
Returns:
    the detected `FunType`, or `FunType.undefined` if no case applies
*/
FunType determineType(S)(in Interval!S iv)
in
{
    assert(iv.lx < iv.rx, "invalid interval");
}
out(type)
{
    version(Flex_logging)
    {
        if (type == FunType.undefined)
            warningf("Interval has an undefined type: %s", iv);
    }
}
do
{
    alias FT = FunType;

    // transformed first and second derivatives at both bounds
    immutable S l1 = iv.lt1x, l2 = iv.lt2x;
    immutable S r1 = iv.rt1x, r2 = iv.rt2x;

    // In each unbounded interval f must be concave and strictly monotone
    // (condition 4 in section 2.3 of Botts et al. (2013)).
    if (iv.lx == -S.infinity)
        return (r2 < 0 && r1 > 0) ? FT.T4a : FT.undefined;
    if (iv.rx == +S.infinity)
        return (l2 < 0 && l1 < 0) ? FT.T4a : FT.undefined;

    // The transformed value at the left bound is the limit value of T_c
    // (0 for c > 0, -inf for c <= 0); presumably f vanishes there.
    // Only the right end decides, and its first derivative must be positive.
    if ((iv.c > 0 && iv.ltx == 0) || (iv.c <= 0 && iv.ltx == -S.infinity))
    {
        if (r1 > 0)
        {
            if (r2 < 0)
                return FT.T4a;
            if (r2 > 0)
                return FT.T4b;
        }
        return FT.undefined;
    }

    // Mirror case for the right bound; the left first derivative must be negative.
    if ((iv.c > 0 && iv.rtx == 0) || (iv.c <= 0 && iv.rtx == -S.infinity))
    {
        if (l1 < 0)
        {
            if (l2 < 0)
                return FT.T4a;
            if (l2 > 0)
                return FT.T4b;
        }
        return FT.undefined;
    }

    // for c < 0 a transformed value of 0 presumably marks a pole of f
    if (iv.c < 0 && ((iv.ltx == 0 && r2 > 0) || (iv.rtx == 0 && l2 > 0)))
        return FT.T4b;

    // slope of the secant through both transformed end points
    immutable R = (iv.rtx - iv.ltx) / (iv.rx - iv.lx);

    if (l1 >= R && r1 >= R)
        return FT.T1a;
    if (l1 <= R && r1 <= R)
        return FT.T1b;

    if (l2 <= 0 && r2 <= 0)
        return FT.T4a;
    if (l2 >= 0 && r2 >= 0)
        return FT.T4b;

    // mixed curvature: the sign pattern of the second derivatives picks a/b
    immutable bool concaveToConvex = l2 < 0 && r2 > 0;
    immutable bool convexToConcave = l2 > 0 && r2 < 0;

    if (l1 >= R && R >= r1)
    {
        if (concaveToConvex)
            return FT.T2a;
        if (convexToConcave)
            return FT.T2b;
    }
    else if (l1 <= R && R <= r1)
    {
        if (concaveToConvex)
            return FT.T3a;
        if (convexToConcave)
            return FT.T3b;
    }

    return FT.undefined;
}
220 
nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    alias FT = FunType;
    foreach (S; AliasSeq!(float, double, real))
    {
        // f(x) = x^4 together with its first two derivatives
        const quartic = (S x) => x ^^ 4;
        const quarticD1 = (S x) => 4 * x ^^ 3;
        const quarticD2 = (S x) => 12 * x * x;
        enum c = 42; // c doesn't matter here
        auto classify = (S lo, S hi) => determineType(Interval!S(lo, hi, c,
            quartic(lo), quarticD1(lo), quarticD2(lo),
            quartic(hi), quarticD1(hi), quarticD2(hi)));

        // x^4 is convex everywhere
        assert(classify(-3.0, -1) == FT.T4b);
        assert(classify(-1.0, 1) == FT.T4b);
        assert(classify(1.0, 3) == FT.T4b);
    }
}
239 
// test x^3
nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    alias FT = FunType;
    foreach (S; AliasSeq!(float, double, real))
    {
        // f(x) = x^3 with f'(x) = 3x^2 and f''(x) = 6x
        const cube = (S x) => x ^^ 3;
        const cubeD1 = (S x) => 3 * x ^^ 2;
        const cubeD2 = (S x) => 6 * x;
        enum c = 42; // c doesn't matter here
        auto classify = (S lo, S hi) => determineType(Interval!S(lo, hi, c,
            cube(lo), cubeD1(lo), cubeD2(lo),
            cube(hi), cubeD1(hi), cubeD2(hi)));

        // concave on the negative half-line
        assert(classify(-S.infinity, S(-1.0)) == FT.T4a);
        assert(classify(S(-3.0), S(-1)) == FT.T4a);

        // inflection point at x = 0, concave to its left
        assert(classify(S(-1.0), S(1)) == FT.T1a);

        // convex on the positive half-line
        assert(classify(S(1.0), S(3)) == FT.T4b);
    }
}
263 
// test sin(x)
nothrow pure @safe version(mir_random_test) unittest
{
    import mir.math: PI;
    // due to numerical errors a small padding must be added
    // see e.g. https://gist.github.com/wilzbach/3d27d06b55821aa9795deb15d4d47679
    import mir.math.common : cos, sin;
    import std.meta : AliasSeq;

    alias FT = FunType;
    foreach (S; AliasSeq!(float, double, real))
    {
        // f = sin with f' = cos and f'' = -sin
        const fn = (S x) => sin(x);
        const dfn = (S x) => cos(x);
        const ddfn = (S x) => -sin(x);
        enum c = 42; // c doesn't matter here
        auto classify = (S lo, S hi) => determineType(Interval!S(lo, hi, c,
            fn(lo), dfn(lo), ddfn(lo),
            fn(hi), dfn(hi), ddfn(hi)));

        // type 1a: concave
        assert(classify(0.01, 2 * PI - 0.01) == FT.T1a);
        assert(classify(2 * PI + 0.01, 4 * PI - 0.01) == FT.T1a);
        assert(classify(2, 4) == FT.T1a);
        assert(classify(0.01, 5) == FT.T1a);
        assert(classify(1, 5) == FT.T1a);

        // type 1b: convex
        assert(classify(-PI, PI) == FT.T1b);
        assert(classify(PI, 3 * PI) == FT.T1b);
        assert(classify(4, 8) == FT.T1b);

        // type 2a: concave
        assert(classify(1, 4) == FT.T2a);

        // type 2b: convex
        assert(classify(6, 8) == FT.T2b);

        // type 3a: concave
        assert(classify(3, 4) == FT.T3a);
        assert(classify(2, 5.7) == FT.T3a);

        // type 3b: convex
        assert(classify(-3, 0.1) == FT.T3b);

        // type 4a - pure concave intervals (special case of 2a)
        assert(classify(0.01, PI - 0.01) == FT.T4a);
        assert(classify(0.01, 3) == FT.T4a);
        assert(classify(2 * PI + 0.01, 3 * PI - 0.01) == FT.T4a);

        // type 4b - pure convex intervals (special case of 3b)
        assert(classify(-PI + 0.01, -0.01) == FT.T4b);
        assert(classify(PI + 0.01, 2 * PI - 0.01) == FT.T4b);
        assert(classify(4, 6) == FT.T4b);
    }
}
318 
nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    alias FT = FunType;
    foreach (S; AliasSeq!(float, double, real))
    {
        // f(x) = x^2 with f'(x) = 2x and the constant f''(x) = 2
        const square = (S x) => x * x;
        const squareD1 = (S x) => 2 * x;
        const squareD2 = (S x) => 2.0;
        enum c = 42; // c doesn't matter here
        auto classify = (S lo, S hi) => determineType(Interval!S(lo, hi, c,
            square(lo), squareD1(lo), squareD2(lo),
            square(hi), squareD1(hi), squareD2(hi)));

        // x^2 is convex everywhere
        assert(classify(-1, 1) == FT.T4b);
        assert(classify(1, 3) == FT.T4b);
    }
}
335 
336 
/**
Representation of a linear function of the form:

    f(x) = slope * (x - y) + a

i.e. the line with the given `slope` that passes through the point `(y, a)`.
Anchoring the line at `y` allows a bit higher precision than the
typical representation `f(x) = slope * x + intercept`.
*/
struct LinearFun(S)
{
    import std.format : FormatSpec;

    /// direction and steepness (aka beta)
    S slope;

    /// anchor point of the line (e.g. the point of tangency of a tangent)
    S y;

    /// value of the function at the anchor point, i.e. f(y) == a
    S a;

    /**
    Params:
        slope = direction and steepness
        y = anchor point (e.g. the point of tangency)
        a = value of the function at `y` (e.g. the tangent's f(y))
    */
    this(S slope, S y, S a)
    {
        this.slope = slope;
        this.y = y;
        this.a = a;
    }

    /**
    Textual representation of the function.

    With `%l` the function is printed in the classic form `<slope>x ± <intercept>`
    with two significant digits. A slope approximately equal to zero (tolerance
    1e-5) is dropped and only the intercept is printed. An intercept approximately
    equal to zero next to a non-zero slope is dropped. A NaN slope prints `#NaN#`.
    Any other spec prints all three fields, e.g. `LinearFun!double(2, 0, 1)`.
    */
    void toString(W)(auto ref W w, const ref FormatSpec!char fmt) const
    {
        import std.range: put;
        import mir.math.common: fabs;
        import std.format: formatValue, singleSpec;
        switch(fmt.spec)
        {
            case 'l':
                import mir.math: approxEqual;
                if (slope != slope)
                    w.put("#NaN#");
                else
                {
                    auto spec2g = singleSpec("%.2g");
                    if (!slope.approxEqual(0, 1e-5, 1e-5))
                    {
                        w.formatValue(slope, spec2g);
                        w.put("x");
                        if (!intercept.approxEqual(0, 1e-5, 1e-5))
                        {
                            auto sgn = intercept > 0 ? " + " : " - ";
                            w.put(sgn);
                            w.formatValue(fabs(intercept), spec2g);
                        }
                    }
                    else
                    {
                        w.formatValue(intercept, spec2g);
                    }
                }
                break;
            case 's':
            default:
                import std.traits : Unqual;
                w.put(Unqual!(typeof(this)).stringof);
                auto spec2g = singleSpec("%.6g");
                w.put("(");
                w.formatValue(slope, spec2g);
                w.put(", ");
                w.formatValue(y, spec2g);
                w.put(", ");
                w.formatValue(a, spec2g);
                w.put(")");
                break;
        }
    }

    /// evaluate the linear function at x, i.e. `slope * (x - y) + a`
    S opCall(in S x) const
    {
        S val = slope * (x - y);
        val += a;
        return val;
    }

    /**
    Evaluate the inverse function at x, i.e. return the point at which this
    function takes the value x.
    A zero slope yields an infinite result, or NaN if `x == a`.
    */
    S inverse(S x) const
    {
        return y + (x - a) / slope;
    }

    // calculate the intercept f(0) of the classic form (for debugging)
    S intercept() @property const
    {
        return slope * -y + a;
    }

    /// exact hexadecimal (`%a`) representation; usable on const instances
    string logHex() const
    {
        import std.format : format;
        return "LinearFun!%s(%a, %a, %a)".format(S.stringof, slope, y, a);
    }
}
446 
/**
Convenience constructor for the linear function `f(x) = slope * (x - y) + a`,
i.e. the line with the given slope through the point `(y, a)`.
The floating-point type `S` is deduced from the arguments.

Params:
    slope = direction and steepness
    y = anchor point (e.g. the point of tangency)
    a = value of the function at `y`
Returns:
    `LinearFun!S(slope, y, a)`
*/
LinearFun!S linearFun(S)(S slope, S y, S a)
{
    return typeof(return)(slope, y, a);
}
461 
/// tangent of a point
@safe version(mir_random_test) unittest
{
    import std.format : format;
    // f(x) = x^2 + 1 and its derivative
    auto fn = (double x) => x * x + 1;
    auto dfn = (double x) => 2 * x;
    // the tangent at x0 touches f in (x0, f(x0)) with slope f'(x0)
    auto tangentAt = (double x0) => linearFun(dfn(x0), x0, fn(x0));

    auto atZero = tangentAt(0);
    assert(format("%l", atZero) == "1");
    assert(atZero(0) == 1);
    assert(atZero(42) == 1);

    auto atOne = tangentAt(1);
    assert(format("%l", atOne) == "2x");
    assert(atOne(1) == 2);
    assert(atOne(2) == 4);

    auto atTwo = tangentAt(2);
    assert(format("%l", atTwo) == "4x - 3");
    assert(atTwo(1) == 1);
    assert(atTwo(2) == 5);
}
485 
/// secant of two points
@safe version(mir_random_test) unittest
{
    import std.format : format;
    auto g = (double x) => x * x + 1;
    auto lo = 1, hi = 3;
    // the secant through (lo, g(lo)) and (hi, g(hi)), anchored at lo
    auto secant = linearFun((g(hi) - g(lo)) / (hi - lo), lo, g(lo));

    assert(format("%l", secant) == "4x - 2");
    // the secant reproduces g at both end points
    assert(secant(1) == 2);
    assert(secant(3) == 10);
}
499 
/// construct an arbitrary linear function
@safe version(mir_random_test) unittest
{
    import std.format : format;

    // f(x) = 2 * x + 1, anchored at 0 where f(0) = 1
    auto line = linearFun!double(2, 0, 1);
    assert(format("%l", line) == "2x + 1");
    assert(line(1) == 3);
    assert(line(-2) == -3);
}
511 
@nogc nothrow pure @safe version(mir_random_test) unittest
{
    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real))
    {
        // derivative of x^2
        auto slopeAt = (S x) => 2 * x;

        // line through (1, 1) with slope 2, i.e. f(x) = 2x - 1
        auto through11 = linearFun!S(slopeAt(1), 1, 1);
        assert(through11.slope == 2);
        assert(through11.intercept == -1);

        // horizontal line through the origin
        auto flat = linearFun!S(slopeAt(0), 0, 0);
        assert(flat.slope == 0);
        assert(flat.intercept == 0);
    }
}
528 
nothrow pure @safe version(mir_random_test) unittest
{
    import mir.math : cos, PI, approxEqual;
    import std.meta : AliasSeq;
    foreach (S; AliasSeq!(float, double, real))
    {
        // slope taken from cos(x), line anchored at (x, v)
        auto slopeFun = (S x) => cos(x);
        auto lineAt = (S x, S v) => linearFun(slopeFun(x), x, v);

        auto atZero = lineAt(0, 0);
        assert(atZero.slope == 1);
        assert(atZero.intercept == 0);

        auto atHalfPi = lineAt(PI / 2, 1);
        assert(atHalfPi.slope.approxEqual(0));
        assert(atHalfPi.intercept.approxEqual(1));
    }
}
546 
// test default toString
@safe version(mir_random_test) unittest
{
    import std.format : format;
    auto line = linearFun!double(2, 0, 1);
    assert(format("%s", line) == "LinearFun!double(2, 0, 1)");
}
554 
// test NaN behavior
@safe version(mir_random_test) unittest
{
    import std.format : format;
    auto nanLine = linearFun!double(double.nan, 0, 1);
    // the default spec still prints all fields, `%l` collapses to a marker
    assert(format("%s", nanLine) == "LinearFun!double(nan, 0, 1)");
    assert(format("%l", nanLine) == "#NaN#");
}
563 
/**
Compares whether two linear functions are approximately equal.

The functions are compared coefficient by coefficient (`slope`, then `y`,
then `a`) with `mir.math.approxEqual` and the given tolerances; the first
mismatch ends the comparison.

Params:
    x = first linear function to compare
    y = second linear function to compare
    maxRelDiff = maximum relative difference
    maxAbsDiff = maximum absolute difference

Returns:
    True if both linear functions are approximately equal.
*/
bool approxEqual(S)(LinearFun!S x, LinearFun!S y, S maxRelDiff = 1e-2, S maxAbsDiff = 1e-5)
{
    import mir.math : approxEqual;
    static foreach (coefficient; ["slope", "y", "a"])
    {
        if (!__traits(getMember, x, coefficient)
                .approxEqual(__traits(getMember, y, coefficient), maxRelDiff, maxAbsDiff))
            return false;
    }
    return true;
}
583 
///
@nogc nothrow pure @safe version(mir_random_test) unittest
{
    auto base = linearFun!double(2, 0, 1);

    // identical coefficients
    assert(approxEqual(base, linearFun!double(2, 0, 1)));

    // a tiny difference in the anchor point is tolerated
    assert(approxEqual(base, linearFun!double(2, 1e-9, 1)));

    // a clearly different anchor point is not
    assert(!approxEqual(base, linearFun!double(2, 4, 1)));
}