/++
$(SCRIPT inhibitQuickIndex = 1;)

$(BOOKTABLE $(H2 Multidimensional Random Variables),

$(TR $(TH Generator name) $(TH Description))
$(RVAR Sphere, Uniform distribution on a unit-sphere)
$(RVAR Simplex, Uniform distribution on a standard-simplex)
$(RVAR Dirichlet, $(WIKI_D Dirichlet))
$(RVAR Multinomial, $(WIKI_D Multinomial))
$(RVAR MultivariateNormal, $(WIKI_D Multivariate_normal))
)

Authors: Simon Bürger, Ilya Yaroshenko
Copyright: Mir Community 2017-.
License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).

Macros:
    WIKI_D = $(HTTP en.wikipedia.org/wiki/$1_distribution, $1 random variable)
    WIKI_D2 = $(HTTP en.wikipedia.org/wiki/$1_distribution, $2 random variable)
    T2=$(TR $(TDNW $(LREF $1)) $(TD $+))
    RVAR = $(TR $(TDNW $(LREF $1Variable)) $(TD $+))
+/
module mir.random.ndvariable;

import mir.random;
import std.traits;
import mir.math.common;

/++
Test if `T` is an n-dimensional random variable.

`T` qualifies when it declares `enum isNdRandomVariable = true`, exposes an
`Element` alias, and can be called both as `rv(gen, E[])` and as
`rv.opCall!Random(gen, E[])` with a `void` result.
+/
template isNdRandomVariable(T)
{
    static if (is(typeof(T.isNdRandomVariable) : bool))
    {
        static if (T.isNdRandomVariable)
        {
            alias E = T.Element;
            enum isNdRandomVariable =
                is(typeof(((T rv, Random* gen) => rv(*gen, E[].init))(T.init, null)) == void)
                &&
                is(typeof(((T rv, Random* gen) => rv.opCall!Random(*gen, E[].init))(T.init, null)) == void);
        }
        else
        {
            enum isNdRandomVariable = false;
        }
    }
    else
    {
        enum isNdRandomVariable = false;
    }
}

///
unittest
{
    static assert(isNdRandomVariable!(SphereVariable!double));
}

/++
Uniform distribution on the unit sphere.
Returns: a point `X` with `X[0]^^2 + .. + X[$-1]^^2 = 1`, written into `result`.
+/
struct SphereVariable(T)
    if (isFloatingPoint!T)
{
    ///
    enum isNdRandomVariable = true;
    ///
    alias Element = T;


    /++
    Fills `result` with a uniformly distributed point on the unit sphere of
    dimension `result.length`. `result` must not be empty.
    +/
    pragma(inline, false)
    void opCall(G)(scope ref G gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        import mir.random.variable : NormalVariable;

        assert(result.length);
        // Draw independent standard normal coordinates, accumulating the
        // squared Euclidean norm, then project onto the sphere by dividing
        // every coordinate by that norm.
        T summator = 0;
        auto norm = NormalVariable!T(0, 1);
        foreach (ref e; result)
        {
            auto x = e = norm(gen);
            summator += x * x;
        }
        result[] /= summator.sqrt;
    }
    /// ditto
    void opCall(G)(scope G* gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        pragma(inline, true);
        opCall(*gen, result);
    }
}

/// ditto
SphereVariable!T sphereVar(T = double)()
    if (isFloatingPoint!T)
{
    return typeof(return).init;
}

/// ditto
alias sphereVariable = sphereVar;

/// Generate random points on a circle
@nogc nothrow @safe version(mir_random_test) unittest
{
    import mir.random.engine;
    import mir.math.common: fabs;

    double[2] x;
    sphereVar()(rne, x);
    assert(fabs(x[0] * x[0] + x[1] * x[1] - 1) < 1e-10);
}

@nogc nothrow @safe version(mir_random_test) unittest
{
    import mir.math.common: fabs;

    Random* gen = threadLocalPtr!Random;
    double[2] x;
    sphereVar()(gen, x);
    assert(fabs(x[0] * x[0] + x[1] * x[1] - 1) < 1e-10);
}

/++
Uniform distribution on the standard simplex.
Returns: a point `X` with `X[i] >= 0` and `X[0] + .. + X[$-1] = 1`, written into `result`.
+/
struct SimplexVariable(T)
    if (isFloatingPoint!T)
{
    static assert(is(typeof({ import mir.ndslice.slice; })), "mir.ndslice package is required for 'SimplexVariable', it can be found in 'mir-algorithm'");

    ///
    enum isNdRandomVariable = true;
    ///
    alias Element = T;

    /++
    Fills `result` with a uniformly distributed point on the standard simplex
    of dimension `result.length`. `result` must not be empty.
    +/
    pragma(inline, false)
    void opCall(G)(scope ref G gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        import mir.ndslice.sorting : sort;
        import mir.ndslice.topology: diff, retro;

        assert(result.length);
        // n - 1 random cut points (fabs keeps them non-negative) plus the
        // fixed end point 1, sorted in ascending order.
        foreach (ref e; result[0 .. $ - 1])
            e = gen.rand!T.fabs;
        result[$-1] = T(1);
        sort(result[0 .. $ - 1]);
        // Replace result[1 .. $] by consecutive differences (the spacings).
        // Walking from the end (retro) means each difference still reads two
        // values that have not been overwritten yet; result[0] is left as the
        // first spacing (distance from 0).
        result[1 .. $].retro[] = result.diff.retro;
    }
    /// ditto
    void opCall(G)(scope G* gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        pragma(inline, true);
        opCall(*gen, result);
    }
}

/// ditto
SimplexVariable!T simplexVar(T = double)()
    if (isFloatingPoint!T)
{
    return typeof(return).init;
}

/// ditto
alias simplexVariable = simplexVar;

///
@nogc nothrow @safe version(mir_random_test) unittest
{
    import mir.math.common: fabs;
    // mir.ndslice package is required for 'SimplexVariable', it can be found in 'mir-algorithm'
    static if (is(typeof({ import mir.ndslice.slice; })))
    {
        import mir.random.engine;
        auto rv = simplexVar;
        double[3] x;
        rv(rne, x);
        assert(x[0] >= 0 && x[1] >= 0 && x[2] >= 0);
        assert(fabs(x[0] + x[1] + x[2] - 1) < 1e-10);
    }
}

///
@nogc nothrow @safe version(mir_random_test) unittest
{
    import mir.random.engine;
    import mir.math.common: fabs;

    // mir.ndslice package is required for 'SimplexVariable', it can be found in 'mir-algorithm'
    static if (is(typeof({ import mir.ndslice.slice; })))
    {
        import mir.ndslice.slice;

        Random* gen = threadLocalPtr!Random;
        SimplexVariable!double rv;
        double[3] x;
        rv(gen, x);
        assert(x[0] >= 0 && x[1] >= 0 && x[2] >= 0);
        assert(fabs(x[0] + x[1] + x[2] - 1) < 1e-10);
    }
}

/++
Dirichlet distribution.
+/
struct DirichletVariable(T)
    if (isFloatingPoint!T)
{
    import mir.random.variable : GammaVariable;

    ///
    enum isNdRandomVariable = true;
    ///
    alias Element = T;

    /// Concentration parameters, one per output coordinate (referenced, not copied).
    const(T)[] alpha;

    /++
    Params:
        alpha = concentration parameters
    Constraints: `alpha[i] > 0`
    +/
    this()(const(T)[] alpha)
    {
        this.alpha = alpha;
    }

    /++
    Fills `result` with a Dirichlet sample: `result[i]` is drawn from
    `GammaVariable!T(alpha[i], 1)` and the vector is then divided by its sum.
    `result.length` must equal `alpha.length`.
    +/
    pragma(inline, false)
    void opCall(G)(scope ref G gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        assert(result.length == alpha.length);
        T summator = 0;
        foreach (size_t i; 0 .. result.length)
            summator += result[i] = GammaVariable!T(alpha[i], 1)(gen);
        result[] /= summator;
    }
    /// ditto
    void opCall(G)(scope G* gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        pragma(inline, true);
        opCall(*gen, result);
    }
}

/// ditto
// `return const` (rather than `in`, which is `scope const` under -preview=in)
// because `alpha` escapes into the returned variable; mirrors multinomialVar.
DirichletVariable!T dirichletVar(T)(return const T[] alpha)
    if (isFloatingPoint!T)
{
    return typeof(return)(alpha);
}

/// ditto
alias dirichletVariable = dirichletVar;

///
nothrow @safe version(mir_random_test) unittest
{
    import mir.random.engine;
    import mir.math.common: fabs;

    auto rv = dirichletVar([1.0, 5.7, 0.3]);
    double[3] x;
    rv(rne, x);
    assert(x[0] >= 0 && x[1] >= 0 && x[2] >= 0);
    assert(fabs(x[0] + x[1] + x[2] - 1) < 1e-10);
}

///
nothrow @safe version(mir_random_test) unittest
{
    import mir.random.engine;
    import mir.math.common: fabs;

    Random* gen = threadLocalPtr!Random;
    auto rv = DirichletVariable!double([1.0, 5.7, 0.3]);
    double[3] x;
    rv(gen, x);
    assert(x[0] >= 0 && x[1] >= 0 && x[2] >= 0);
    assert(fabs(x[0] + x[1] + x[2] - 1) < 1e-10);
}
/++
Multinomial distribution.
+/
struct MultinomialVariable(T)
    if (isFloatingPoint!T)
{
    import mir.random.variable : binomialVar;

    ///
    enum isNdRandomVariable = true;
    ///
    alias Element = uint;

    /// Probabilities of the individual categories (referenced, not copied).
    const(T)[] probs;
    /// Number of trials (rolls).
    size_t N;


    /++
    Params:
        N = Number of rolls
        probs = probabilities of the multinomial distribution
    Constraints: `N <= uint.max`, `probs[i] >= 0` and `sum(probs[i]) == 1`
    (up to rounding)
    +/
    this()(size_t N, const(T)[] probs)
    {
        this.N = N;
        this.probs = probs;
        // Counts are returned as `uint`, so a larger number of trials could overflow.
        assert(N <= uint.max, "MultinomialVariable: N does not fit into uint");
        // When assertions are enabled, verify that probs is a valid
        // probability vector: non-negative entries summing to one up to
        // a rounding tolerance proportional to its length.
        version(assert)
        {
            T norm = 0;
            foreach (p; this.probs)
            {
                assert(p >= 0, "MultinomialVariable: probabilities must be non-negative");
                norm += p;
            }
            assert(fabs(norm - 1) <= T.epsilon * probs.length * 2,
                "MultinomialVariable: probabilities must add up to one");
        }
    }

    /++
    Fills `result[k]` with the number of trials that fell into category `k`;
    the counts sum to `N`. `result.length` must equal `probs.length`.

    Categories are sampled one after another: category `k` receives a binomial
    draw over the trials not yet assigned, with its probability conditioned on
    none of the earlier categories having been chosen.
    +/
    pragma(inline, false)
    void opCall(G)(scope ref G gen, scope uint[] result)
        if (isSaturatedRandomEngine!G)
    {
        assert(result.length == probs.length);

        // Index of the last category with positive probability. It takes all
        // remaining trials, so the counts always add up to N even when the
        // conditional probabilities below carry rounding errors.
        size_t last = probs.length;
        foreach_reverse (k, p; probs)
        {
            if (p > 0)
            {
                last = k;
                break;
            }
        }

        T sumP = 0;
        size_t sumN = 0;
        foreach (k, p; probs)
        {
            uint count = 0;
            if (p > 0 && sumN < N)
            {
                if (k == last)
                {
                    count = cast(uint)(N - sumN);
                }
                else
                {
                    // Conditional probability of category k, clamped to 1 so
                    // rounding (or an exhausted remaining mass) can never yield
                    // an invalid binomial parameter.
                    const T rest = 1 - sumP;
                    const T q = rest > p ? p / rest : T(1);
                    count = cast(uint) binomialVar!T(N - sumN, q)(gen);
                }
            }
            result[k] = count;
            sumP += p;
            sumN += count;
        }
    }
    /// ditto
    void opCall(G)(scope G* gen, scope uint[] result)
        if (isSaturatedRandomEngine!G)
    {
        pragma(inline, true);
        opCall(*gen, result);
    }
}

/// ditto
MultinomialVariable!(T) multinomialVar(T)(size_t N, return const T[] probs)
    if (isFloatingPoint!T)
{
    return typeof(return)(N, probs);
}

/// ditto
alias multinomialVariable = multinomialVar;

/// Tests if sample returned is of correct size.
nothrow @safe version(mir_random_test) unittest
{
    import mir.random.engine;
    size_t s = 10000;
    double[6] p =[1/6., 1/6., 1/6., 1/6., 1/6., 1/6.]; // probs must add up to one
    auto rv = multinomialVar(s, p);
    uint[6] x;
    rv(rne, x[]);
    assert(x[0]+x[1]+x[2]+x[3]+x[4]+x[5] == s);
}

nothrow @safe version(mir_random_test) unittest
{
    import mir.random.engine;

    Random* gen = threadLocalPtr!Random;
    size_t s = 1000;
    double[3] p = [0.1, 0.5, 0.4];
    auto rv = MultinomialVariable!(double)(s, p);
    uint[3] x;
    rv(gen, x[]);
    assert(x[0]+x[1]+x[2] == s);
}

/++
Multivariate normal distribution.
Beta version (has not been properly tested).
+/
struct MultivariateNormalVariable(T)
    if(isFloatingPoint!T)
{
    static assert(is(typeof({ import mir.ndslice.slice; })), "mir.ndslice package is required for 'MultivariateNormalVariable', it can be found in 'mir-algorithm'");


    /++
    Compute Cholesky decomposition in place. Only accesses lower/left half of
    the matrix. Returns false if the matrix is not positive definite.
    +/
    static bool cholesky()(Slice!(T*, 2) m)
    {
        import mir.algorithm.iteration: reduce;
        assert(m.length!0 == m.length!1);

        /* this is a straight-forward implementation of the Cholesky-Crout algorithm
        from https://en.wikipedia.org/wiki/Cholesky_decomposition#Computation */
        foreach(size_t i; 0 .. m.length)
        {
            auto r = m[i];
            foreach(size_t j; 0 .. i)
                r[j] = (r[j] - reduce!"a + b * c"(typeof(r[j])(0), r[0 .. j], m[j, 0 .. j])) / m[j, j];
            r[i] -= reduce!"a + b * b"(typeof(r[i])(0), r[0 .. i]);
            if (!(r[i] > 0)) // this catches NaNs as well
                return false;
            r[i] = sqrt(r[i]);
        }
        return true;
    }

    ///
    enum isNdRandomVariable = true;
    ///
    alias Element = T;

    private size_t n;
    private const(T)* sigma; // cholesky decomposition of covariance matrix
    private const(T)* mu; // mean vector (null when no mean was supplied)

    /++
    Constructor computes the Cholesky decomposition of `sigma` in place without
    memory allocation (skipped when `chol` is true). `sigma` is assumed to be a
    symmetric matrix, but only the lower/left half is actually accessed.

    Only pointers to the data of `mu` and `sigma` are stored, so both must
    outlive the variable.

    Params:
        mu = mean vector (assumed zero if not supplied)
        sigma = covariance matrix; overwritten with its Cholesky factor unless `chol` is true
        chol = optional flag indicating that sigma is already Cholesky decomposed

    Constraints: sigma has to be positive-definite and its dimensions must match `mu`
    +/
    this()(Slice!(const(T)*) mu, Slice!(T*, 2) sigma, bool chol = false)
    {
        //Check the dimensions even in release mode to _guarantee_
        //that unless memory corruption has already occurred sigma
        //and mu have the correct dimensions and it is correct in opCall
        //to "@trust" slicing sigma to [n x n] and mu to [n].
        if ((mu.length != sigma.length!0) | (mu.length != sigma.length!1))
            assert(false, "mean vector and covariance matrix dimensions do not match");

        if(!chol && !cholesky(sigma))
            assert(false, "covariance matrix not positive definite");

        this.n = sigma.length;
        this.mu = mu.iterator;
        this.sigma = sigma.iterator;
    }

    /++ ditto +/
    this()(Slice!(T*, 2) sigma, bool chol = false)
    {
        //Check the dimensions even in release mode to _guarantee_
        //that unless memory corruption has already occurred sigma
        //is square and it is correct in opCall to "@trust" slicing
        //sigma as (n,n). No mean vector is stored by this overload.
        if (sigma.length!0 != sigma.length!1)
            assert(false, "covariance matrix is not square");

        if(!chol && !cholesky(sigma))
            assert(false, "covariance matrix not positive definite");

        this.n = sigma.length;
        this.mu = null;
        this.sigma = sigma.iterator;
    }

    /++
    Fills `result` with a sample `mu + L * z`, where `z` holds independent
    standard normal draws and `L` is the stored lower-triangular Cholesky
    factor (`mu` is omitted when none was supplied).
    `result.length` must equal the dimension of the covariance matrix.
    +/
    pragma(inline, false)
    void opCall(G)(scope ref G gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        import mir.algorithm.iteration: reduce;
        import mir.ndslice.slice: sliced;
        assert(result.length == n);
        import mir.random.variable : NormalVariable;
        auto norm = NormalVariable!T(0, 1);

        auto s = (() @trusted => sigma.sliced(n, n))();//sigma is n x n matrix.
        foreach(ref e; result)
            e = norm(gen);
        // In-place lower-triangular product: walking i downwards keeps
        // result[0 .. i + 1] holding the untransformed normal draws.
        foreach_reverse(size_t i; 0 .. n)
            result[i] = reduce!"a + b * c"(T(0), s[i, 0 .. i + 1], result[0 .. i + 1]);
        if (mu)
            result.sliced[] +=(() @trusted => mu.sliced(n))();//mu is n vector.
    }
    /// ditto
    void opCall(G)(scope G* gen, scope T[] result)
        if (isSaturatedRandomEngine!G)
    {
        pragma(inline, true);
        opCall(*gen, result);
    }
}

static if (is(typeof({import mir.ndslice.slice;})))
{
    import mir.ndslice.slice: Slice;

    /// ditto
    MultivariateNormalVariable!T multivariateNormalVar(T)(Slice!(const(T)*) mu, Slice!(T*, 2) sigma, bool chol = false)
    {
        return typeof(return)(mu, sigma, chol);
    }

    /// ditto
    MultivariateNormalVariable!T multivariateNormalVar(T)(Slice!(T*, 2) sigma, bool chol = false)
    {
        return typeof(return)(sigma, chol);
    }
}
else
{
    // Fallbacks that turn any use into a clear compile-time error when
    // mir.ndslice is unavailable.
    auto multivariateNormalVar(S)(S sigma, bool chol = false)
    {
        static assert(0, "mir.ndslice package is required for 'MultivariateNormalVariable', it can be found in 'mir-algorithm'");
    }

    auto multivariateNormalVar(M, S)(M mu, S sigma, bool chol = false)
    {
        static assert(0, "mir.ndslice package is required for 'MultivariateNormalVariable', it can be found in 'mir-algorithm'");
    }
}

/// ditto
alias multivariateNormalVariable = multivariateNormalVar;

///
nothrow @safe version(mir_random_test) unittest
{
    // mir.ndslice package is required for 'multivariateNormalVar', it can be found in 'mir-algorithm'
    static if (is(typeof({ import mir.ndslice.slice; })))
    {
        import mir.ndslice.slice: sliced;
        import mir.random.engine;

        // Two-dimensional normal with negatively correlated coordinates.
        auto mean = [10.0, 0.0].sliced;
        auto covariance = [2.0, -1.5, -1.5, 2.0].sliced(2, 2);
        auto mvn = multivariateNormalVar(mean, covariance);
        double[2] sample;
        mvn(rne, sample[]);
    }
}

///
nothrow @safe version(mir_random_test) unittest
{
    // mir.ndslice package is required for 'multivariateNormalVar', it can be found in 'mir-algorithm'
    static if (is(typeof({ import mir.ndslice.slice; })))
    {
        import mir.random.engine;
        import mir.ndslice.slice: sliced;

        // Same distribution, driven through a pointer to the thread-local engine.
        Random* engine = threadLocalPtr!Random;
        auto mean = [10.0, 0.0].sliced;
        auto covariance = [2.0, -1.5, -1.5, 2.0].sliced(2, 2);
        auto mvn = multivariateNormalVar(mean, covariance);
        double[2] sample;
        mvn(engine, sample[]);
    }
}