The OpenD Programming Language

1 /++
2 $(H2 Interpolation Modifier)
3 
4 License: $(HTTP www.apache.org/licenses/LICENSE-2.0, Apache-2.0)
5 Copyright: 2022 Ilia, Symmetry Investments
6 Authors: Ilia Ki
7 
8 Macros:
9 SUBREF = $(REF_ALTTEXT $(TT $2), $2, mir, interpolate, $1)$(NBSP)
10 T2=$(TR $(TDNW $(LREF $1)) $(TD $+))
11 +/
12 module mir.interpolate.mod;
13 
14 import mir.math.common;
15 
/++
Wraps an interpolator so that `fun` is applied to every value it produces.

Derivatives requested through `opCall!derivative` are propagated with the
chain rule. Derivatives of order 1 to 3 are supported for single-argument
interpolators only.

Params:
    fun = callable invoked as `fun(value, result)`, where `result` is a
        fixed-size array of length `derivative + 1`. `fun` must fill it with
        its own value and its first `derivative` derivatives at `value`:
        `result[0] = fun(value)`, `result[1] = fun'(value)`, and so on.
+/
template interpolationMap(alias fun)
{
    /++
    Params:
        interpolator = interpolator to wrap; it is moved into the result.
    Returns:
        an `InterpolationMap!fun` instance holding `interpolator`.
    +/
    auto interpolationMap(T)(T interpolator)
    {
        import core.lifetime: move;
        alias Mapped = InterpolationMap!fun;
        return Mapped!T(interpolator.move);
    }
}
32 
/// ditto
template InterpolationMap(alias fun)
{
    /++
    Interpolator adaptor produced by `interpolationMap`.

    Stores the wrapped interpolator. It forwards the `derivativeOrder` and
    `dimensionCount` constants when the wrapped type declares them.
    +/
    struct InterpolationMap(T)
    {
        static if (__traits(hasMember, T, "derivativeOrder"))
            enum derivativeOrder = T.derivativeOrder;

        static if (__traits(hasMember, T, "dimensionCount"))
            enum uint dimensionCount = T.dimensionCount;

        /// The wrapped interpolator.
        T interpolator;

        /// Takes ownership of `interpolator` by moving it into this wrapper.
        this(T interpolator)
        {
            import core.lifetime: move;
            this.interpolator = move(interpolator);
        }

        ///
        template opCall(uint derivative = 0)
            // Constraint currently disabled: if (derivative <= derivativeOrder)
        {
            /++
            Evaluates the wrapped interpolator at `xs` and maps the result through `fun`.

            With `derivative == 0` the mapped value is returned.

            With `derivative` between 1 and 3, only single-argument
            interpolators are supported. The result has the same type as the
            wrapped interpolator's `opCall!derivative`. It holds the mapped
            value followed by its derivatives, computed with the chain rule.
            Any other combination fails to compile.

            Complexity:
                one call to the wrapped interpolator's `opCall` plus a constant amount of work.
            +/
            auto opCall(X...)(const X xs) scope const @trusted
                // Constraint currently disabled: if (X.length == dimensionCount)
            {
                // NOTE(review): `@trusted` also vouches for the calls to `fun` and to the
                // wrapped interpolator; confirm both are memory-safe for every instantiation.
                auto values = interpolator.opCall!derivative(xs);

                static if (derivative == 0)
                {
                    // `fun` writes its value into a one-element buffer.
                    typeof(values)[1] mapped;
                    fun(values, mapped);
                    return mapped[0];
                }
                else static if (X.length != 1)
                {
                    // Derivatives of multi-argument interpolators are not supported.
                    static assert(0, "Not implemented");
                }
                else
                {
                    static assert(derivative <= 3, "Not implemented");

                    // values = [g, g', g'', ...] of the interpolated function g,
                    // as the chain-rule formulas below assume;
                    // outer  = [f, f', f'', ...] of `fun`, evaluated at g.
                    auto inner = values[0];
                    typeof(inner)[derivative + 1] outer;
                    fun(inner, outer);

                    typeof(values) composed;
                    composed[0] = outer[0];
                    // d/dx f(g) = f'(g) g'
                    composed[1] = outer[1] * values[1];
                    static if (derivative >= 2)
                        // d2/dx2 f(g) = f''(g) g'^2 + f'(g) g''
                        composed[2] = outer[2] * (values[1] * values[1]) + outer[1] * values[2];
                    static if (derivative >= 3)
                        // d3/dx3 f(g) = f'''(g) g'^3 + f'(g) g''' + 3 f''(g) g' g''
                        composed[3] = outer[3] * (values[1] * values[1] * values[1])
                            + outer[1] * values[3]
                            + 3 * (outer[2] * values[1] * values[2]);
                    return composed;
                }
            }
        }
    }
}
118 
///
version (mir_test)
unittest
{
    import mir.interpolate.spline;
    import mir.math.common: log;
    import mir.ndslice.allocation: rcslice;
    import mir.test;

    // Sample cubic g and its first derivative (d == 1); other orders yield NaN.
    alias g = (double x, uint d = 0) =>
        d == 0 ? 3 * x ^^ 3 + 5 * x ^^ 2 + 0.23 * x + 2 :
        d == 1 ? 9 * x ^^ 2 + 10 * x + 0.23 :
        double.nan;

    // Natural logarithm and up to three derivatives: 1/t, -1/t^2, 2/t^3.
    alias f = (double t, ref scope dst)
    {
        static assert(dst.length <= 4, "Not implemented");
        dst[0] = log(t);
        static if (dst.length > 1)
        {
            dst[1] = 1 / t;
            static if (dst.length > 2)
            {
                dst[2] = -dst[1] * dst[1];
                static if (dst.length > 3)
                    dst[3] = -2 * dst[1] * dst[2];
            }
        }
    };

    // Cubic spline through four samples of g.
    auto spl = spline!double(
        [0.1, 0.4, 0.5, 1.0].rcslice!(immutable double),
        [g(0.1), g(0.4), g(0.5), g(1.0)].rcslice!(const double)
    );

    // log composed with the spline.
    auto logSpl = spl.interpolationMap!f;

    enum pt = 0.7;
    logSpl(pt).shouldApprox == log(g(pt));

    // Value and first three derivatives of log(spline(x)) at pt.
    auto derivs = logSpl.opCall!3(pt);
    derivs[0].shouldApprox == log(g(pt));
    derivs[1].shouldApprox == 1 / g(pt) * g(pt, 1);
    derivs[2].shouldApprox == -0.252301;
    derivs[3].shouldApprox == -4.03705;
}
160 
/+
Square root plus up to three derivatives, in the `(value, result)` form used
by `interpolationMap`. Fills `res` with sqrt(arg), 1/(2 sqrt(arg)),
-1/(4 arg^(3/2)) and 3/(8 arg^(5/2)), stopping at `res.length`.
Buffers longer than four elements are rejected at compile time.
+/
private alias implSqrt = (arg, ref scope res)
{
    import mir.math.common: sqrt;
    static assert(res.length <= 4, "Not implemented");
    res[0] = sqrt(arg);
    static if (res.length > 1)
    {
        res[1] = 0.5f / res[0];
        static if (res.length > 2)
        {
            res[2] = -0.5f * res[1] / arg;
            static if (res.length > 3)
                res[3] = -1.5f * res[2] / arg;
        }
    }
};
174 
/++
Applies the square root to the interpolated value.

Derivatives requested through `opCall!derivative`, up to order 3, are
propagated with the chain rule.
+/
alias interpolationSqrt = interpolationMap!(implSqrt);
/// ditto
alias InterpolationSqrt = InterpolationMap!(implSqrt);
181 
///
version (mir_test)
unittest
{
    import mir.interpolate.spline;
    import mir.math.common: sqrt;
    import mir.ndslice.allocation: rcslice;
    import mir.test;

    // Sample cubic g and its first derivative (d == 1); other orders yield NaN.
    alias g = (double x, uint d = 0) =>
        d == 0 ? 3 * x ^^ 3 + 5 * x ^^ 2 + 0.23 * x + 2 :
        d == 1 ? 9 * x ^^ 2 + 10 * x + 0.23 :
        double.nan;

    // Cubic spline through four samples of g.
    auto spl = spline!double(
        [0.1, 0.4, 0.5, 1.0].rcslice!(immutable double),
        [g(0.1), g(0.4), g(0.5), g(1.0)].rcslice!(const double)
    );

    // Square root composed with the spline.
    auto rootSpl = spl.interpolationSqrt;

    enum pt = 0.7;
    rootSpl(pt).shouldApprox == sqrt(g(pt));

    // Value and first three derivatives of sqrt(spline(x)) at pt.
    auto derivs = rootSpl.opCall!3(pt);
    derivs[0].shouldApprox == sqrt(g(pt));
    derivs[1].shouldApprox == 0.5 / sqrt(g(pt)) * g(pt, 1);
    derivs[2].shouldApprox == 2.2292836438189414;
    derivs[3].shouldApprox == -3.11161;
}