/++
$(H2 Interpolation Modifier)

License: $(HTTP www.apache.org/licenses/LICENSE-2.0, Apache-2.0)
Copyright: 2022 Ilia, Symmetry Investments
Authors: Ilia Ki

Macros:
SUBREF = $(REF_ALTTEXT $(TT $2), $2, mir, interpolate, $1)$(NBSP)
T2=$(TR $(TDNW $(LREF $1)) $(TD $+))
+/
module mir.interpolate.mod;

import mir.math.common;

/++
Applies function to the interpolated value.

The returned wrapper evaluates the wrapped interpolator and passes the result
through `fun`. When derivatives are requested, the derivatives of the
composition are computed with the chain rule; this is implemented for
single-argument interpolators and derivative orders up to 3.

Params:
    fun = callable invoked as `fun(x, y)`, where `x` is the interpolated value
        and `y` is a static array taken by reference. `fun` must store its
        value at `x` in `y[0]` and its `k`-th derivative at `x` in `y[k]`
        for every `k < y.length`.
+/
template interpolationMap(alias fun)
{
    /++
    Wraps `interpolator` into an `InterpolationMap!fun`.

    Params:
        interpolator = interpolator to wrap; it is moved into the result
    Returns:
        `InterpolationMap!fun` instance that owns `interpolator`
    +/
    auto interpolationMap(T)(T interpolator)
    {
        import core.lifetime: move;
        alias S = InterpolationMap!fun;
        return S!T(interpolator.move);
    }
}

/// ditto
template InterpolationMap(alias fun)
{
    /// Interpolator wrapper that applies `fun` to the values produced by `T`.
    struct InterpolationMap(T)
    {
        // Forward the static properties of the wrapped interpolator when it declares them.
        static if (__traits(hasMember, T, "derivativeOrder"))
            enum derivativeOrder = T.derivativeOrder;

        static if (__traits(hasMember, T, "dimensionCount"))
            enum uint dimensionCount = T.dimensionCount;

        /// The wrapped interpolator.
        T interpolator;

        /// Takes ownership of `interpolator` by moving it into the wrapper.
        this(T interpolator)
        {
            import core.lifetime: move;
            this.interpolator = interpolator.move;
        }

        /++
        Evaluation operator.

        Params:
            derivative = highest derivative order to compute;
                `0` (the default) computes the mapped value only
        +/
        template opCall(uint derivative = 0)
            // if (derivative <= derivativeOrder)
        {
            /++
            `(x)` operator.

            Params:
                xs = arguments forwarded to the wrapped interpolator
            Returns:
                For `derivative == 0`, the value of `fun` at the interpolated value.
                Otherwise an array of the same type as the wrapped interpolator's
                result, holding the mapped value at index `0` and the `k`-th
                derivative of the composition at index `k`.
            Complexity:
                One evaluation of the wrapped interpolator plus one call to `fun`;
                `O(log(grid.length))` is expected when the wrapped interpolator
                has that complexity (not enforced here).
            +/
            auto opCall(X...)(const X xs) scope const @trusted
                // if (X.length == dimensionCount)
            {
                auto g = interpolator.opCall!derivative(xs);

                static if (derivative == 0)
                {
                    // Only the value is requested: `g` is the interpolated value itself.
                    typeof(g)[1] ret;
                    fun(g, ret);
                    return ret[0];
                }
                else
                {
                    // The chain rule below covers only single-argument interpolators
                    // and derivative orders up to 3; reject everything else up front.
                    static assert(X.length == 1, "Not implemented");
                    static assert(derivative <= 3, "Not implemented");

                    // g[0] is the interpolated value, g[k] its k-th derivative.
                    auto g0 = g[0];

                    // f[k] receives the k-th derivative of `fun` evaluated at g0.
                    typeof(g0)[derivative + 1] f;
                    fun(g0, f);

                    typeof(g) r;
                    r[0] = f[0];
                    // d/dx f(g(x)) = f'(g) * g'
                    r[1] = f[1] * g[1];

                    static if (derivative >= 2)
                    {
                        // second derivative: f''(g) * g'^2 + f'(g) * g''
                        r[2] = f[2] * (g[1] * g[1]) + f[1] * g[2];
                    }
                    static if (derivative >= 3)
                    {
                        // third derivative: f'''(g) * g'^3 + f'(g) * g''' + 3 * f''(g) * g' * g''
                        r[3] = f[3] * (g[1] * g[1] * g[1]) + f[1] * g[3] + 3 * (f[2] * g[1] * g[2]);
                    }

                    return r;
                }
            }
        }
    }
}

///
version (mir_test)
unittest
{
    import mir.interpolate.spline;
    import mir.math.common: log;
    import mir.ndslice.allocation: rcslice;
    import mir.test;

    // Cubic polynomial (d == 0) and its first derivative (d == 1).
    alias g = (double x, uint d = 0) =>
        d == 0 ? 3 * x ^^ 3 + 5 * x ^^ 2 + 0.23 * x + 2 :
        d == 1 ? 9 * x ^^ 2 + 10 * x + 0.23 :
        double.nan;

    // Natural logarithm and its derivatives, written into `y`.
    alias f = (double x, ref scope y)
    {
        y[0] = log(x);
        static if (y.length >= 2)
            y[1] = 1 / x;
        static if (y.length >= 3)
            y[2] = -y[1] * y[1];
        static if (y.length >= 4)
            y[3] = -2 * y[1] * y[2];
        static if (y.length >= 5)
            static assert(0, "Not implemented");
    };

    auto s = spline!double(
        [0.1, 0.4, 0.5, 1.0].rcslice!(immutable double),
        [g(0.1), g(0.4), g(0.5), g(1.0)].rcslice!(const double)
    );

    auto m = s.interpolationMap!f;

    m(0.7).shouldApprox == log(g(0.7));
    auto d = m.opCall!3(0.7);
    d[0].shouldApprox == log(g(0.7));
    d[1].shouldApprox == 1 / g(0.7) * g(0.7, 1);
    d[2].shouldApprox == -0.252301;
    d[3].shouldApprox == -4.03705;
}

// Square root and its derivatives, written into `y`:
// y[0] = sqrt(x), y[1] = 1 / (2 * sqrt(x)), y[2] = -y[1] / (2 * x), y[3] = -3 * y[2] / (2 * x).
private alias implSqrt = (x, ref scope y)
{
    import mir.math.common: sqrt;
    y[0] = sqrt(x);
    static if (y.length >= 2)
        y[1] = 0.5f / y[0];
    static if (y.length >= 3)
        y[2] = -0.5f * y[1] / x;
    static if (y.length >= 4)
        y[3] = -1.5f * y[2] / x;
    static if (y.length >= 5)
        static assert(0, "Not implemented");
};

/++
Applies square root function to the interpolated value.
+/
alias interpolationSqrt = interpolationMap!implSqrt;
/// ditto
alias InterpolationSqrt = InterpolationMap!implSqrt;

///
version (mir_test)
unittest
{
    import mir.interpolate.spline;
    import mir.math.common: sqrt;
    import mir.ndslice.allocation: rcslice;
    import mir.test;

    // Cubic polynomial (d == 0) and its first derivative (d == 1).
    alias g = (double x, uint d = 0) =>
        d == 0 ? 3 * x ^^ 3 + 5 * x ^^ 2 + 0.23 * x + 2 :
        d == 1 ? 9 * x ^^ 2 + 10 * x + 0.23 :
        double.nan;

    auto s = spline!double(
        [0.1, 0.4, 0.5, 1.0].rcslice!(immutable double),
        [g(0.1), g(0.4), g(0.5), g(1.0)].rcslice!(const double)
    );

    auto m = s.interpolationSqrt;

    m(0.7).shouldApprox == sqrt(g(0.7));
    auto d = m.opCall!3(0.7);
    d[0].shouldApprox == sqrt(g(0.7));
    d[1].shouldApprox == 0.5 / sqrt(g(0.7)) * g(0.7, 1);
    d[2].shouldApprox == 2.2292836438189414;
    d[3].shouldApprox == -3.11161;
}