// Written in the D programming language.

/**
This module is a port of a growing fragment of the $(D_PARAM numeric)
header in Alexander Stepanov's $(LINK2 https://en.wikipedia.org/wiki/Standard_Template_Library,
Standard Template Library), with a few additions.

Macros:
Copyright: Copyright Andrei Alexandrescu 2008 - 2009.
License:   $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).
Authors:   $(HTTP erdani.org, Andrei Alexandrescu),
           Don Clugston, Robert Jacques, Ilya Yaroshenko
Source:    $(PHOBOSSRC std/numeric.d)
*/
/*
         Copyright Andrei Alexandrescu 2008 - 2009.
Distributed under the Boost Software License, Version 1.0.
   (See accompanying file LICENSE_1_0.txt or copy at
         http://www.boost.org/LICENSE_1_0.txt)
*/
module std.numeric;

import std.complex;
import std.math;
import std.range.primitives;
import std.traits;
import std.typecons;

/// Format flags for CustomFloat.
public enum CustomFloatFlags
{
    /// Adds a sign bit to allow for signed numbers.
    signed = 1,

    /**
     * Store values in normalized form by default. The actual precision of the
     * significand is extended by 1 bit by assuming an implicit leading bit of 1
     * instead of 0. i.e. `1.nnnn` instead of `0.nnnn`.
     * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) types
     */
    storeNormalized = 2,

    /**
     * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0.
     */
    allowDenorm = 4,

    /**
     * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity,
     * IEEE754 _infinity) values.
     */
    infinity = 8,

    /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values.
    nan = 16,

    /**
     * If set, select an exponent bias such that max_exp = 1.
     * i.e. so that the maximum value is >= 1.0 and < 2.0.
     * Ignored if the exponent bias is manually specified.
     */
    probability = 32,

    /// If set, unsigned custom floats are assumed to be negative.
    negativeUnsigned = 64,

    /**If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) number.
     * Requires allowDenorm and storeNormalized.
     */
    allowDenormZeroOnly = 128 | allowDenorm | storeNormalized,

    /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options.
    ieee = signed | storeNormalized | allowDenorm | infinity | nan ,

    /// Include none of the above options.
    none = 0
}

// Maps a total bit count to the (precision, exponentWidth, flags) triple of
// the matching IEEE754 format. The 80-bit x87 format stores its leading bit
// explicitly, hence storeNormalized is toggled off for it.
private template CustomFloatParams(uint bits)
{
    enum CustomFloatFlags flags = CustomFloatFlags.ieee
        ^ ((bits == 80) ? CustomFloatFlags.storeNormalized : CustomFloatFlags.none);
    static if (bits ==  8) alias CustomFloatParams = CustomFloatParams!( 4,  3, flags);
    static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10,  5, flags);
    static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23,  8, flags);
    static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags);
    static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags);
}

// Appends the derived exponent bias to (precision, exponentWidth, flags).
private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags)
{
    import std.meta : AliasSeq;
    alias CustomFloatParams =
        AliasSeq!(
            precision,
            exponentWidth,
            flags,
            (1 << (exponentWidth - ((flags & flags.probability) == 0)))
             - ((flags & (flags.nan | flags.infinity)) != 0) - ((flags & flags.probability) != 0)
        ); // ((flags & CustomFloatFlags.probability) == 0)
}

/**
 * Allows user code to define custom floating-point formats. These formats are
 * for storage only; all operations on them are performed by implicitly
 * extracting them to `real` first. After the operation is completed the
 * result can be stored in a custom floating-point value via assignment.
 */
template CustomFloat(uint bits)
if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(bits));
}

/// ditto
template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee)
if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags));
}

///
@safe unittest
{
    import std.math : sin, cos;

    // Define a 16-bit floating point value
    CustomFloat!16                                x;     // Using the number of bits
    CustomFloat!(10, 5)                           y;     // Using the precision and exponent width
    CustomFloat!(10, 5,CustomFloatFlags.ieee)     z;     // Using the precision, exponent width and format flags
    CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w;     // Using the precision, exponent width, format flags and exponent offset bias

    // Use the 16-bit floats mostly like normal numbers
    w = x*y - 1;

    // Functions calls require conversion
    z = sin(+x)           + cos(+y);                     // Use unary plus to concisely convert to a real
    z = sin(x.get!float)  + cos(y.get!float);            // Or use get!T
    z = sin(cast(float) x) + cos(cast(float) y);         // Or use cast(T) to explicitly convert

    // Define an 8-bit custom float for storing probabilities
    alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed );
    auto p = Probability(0.5);
}

// Facilitate converting numeric types to custom float
private union ToBinary(F)
if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real))
{
    F set;

    // If on Linux or Mac, where 80-bit reals are padded, ignore the
    // padding.
    import std.algorithm.comparison : min;
    CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get;

    // Convert F to the correct binary type.
    static typeof(get) opCall(F value)
    {
        ToBinary r;
        r.set = value;
        return r.get;
    }
    alias get this;
}

/// ditto
struct CustomFloat(uint             precision,      // fraction bits (23 for float)
                   uint             exponentWidth,  // exponent bits (8 for float) Exponent width
                   CustomFloatFlags flags,
                   uint             bias)
if (isCorrectCustomFloat(precision, exponentWidth, flags))
{
    import std.bitmanip : bitfields;
    import std.meta : staticIndexOf;
private:
    // get the correct unsigned bitfield type to support > 32 bits
    template uType(uint bits)
    {
        static if (bits <= size_t.sizeof*8)  alias uType = size_t;
        else                                 alias uType = ulong ;
    }

    // get the correct signed bitfield type to support > 32 bits
    template sType(uint bits)
    {
        static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t;
        else                                     alias sType = long;
    }

    alias T_sig = uType!precision;
    alias T_exp = uType!exponentWidth;
    alias T_signed_exp = sType!exponentWidth;

    alias Flags = CustomFloatFlags;

    // Perform IEEE rounding with round to nearest detection
    void roundedShift(T,U)(ref T sig, U shift)
    {
        if (shift >= T.sizeof*8)
        {
            // avoid illegal shift
            sig = 0;
        }
        else if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1))
        {
            // round to even
            sig >>= shift;
            sig  += sig & 1;
        }
        else
        {
            sig >>= shift - 1;
            sig  += sig & 1;
            // Perform standard rounding
            sig >>= 1;
        }
    }

    // Convert the current value to signed exponent, normalized form
    void toNormalized(T,U)(ref T sig, ref U exp)
    {
        sig = significand;
        auto shift = (T.sizeof*8) - precision;
        exp = exponent;
        static if (flags&(Flags.infinity|Flags.nan))
        {
            // Handle inf or nan
            if (exp == exponent_max)
            {
                exp = exp.max;
                sig <<= shift;
                static if (flags&Flags.storeNormalized)
                {
                    // Save inf/nan in denormalized format
                    sig >>= 1;
                    sig += cast(T) 1uL << (T.sizeof*8 - 1);
                }
                return;
            }
        }
        if ((~flags&Flags.storeNormalized) ||
            // Convert denormalized form to normalized form
            ((flags&Flags.allowDenorm) && exp == 0))
        {
            if (sig > 0)
            {
                import core.bitop : bsr;
                auto shift2 = precision - bsr(sig);
                exp  -= shift2-1;
                shift += shift2;
            }
            else // value = 0.0
            {
                exp = exp.min;
                return;
            }
        }
        sig <<= shift;
        exp -= bias;
    }

    // Set the current value from signed exponent, normalized form
    void fromNormalized(T,U)(ref T sig, ref U exp)
    {
        auto shift = (T.sizeof*8) - precision;
        if (exp == exp.max)
        {
            // infinity or nan
            exp = exponent_max;
            static if (flags & Flags.storeNormalized)
                sig <<= 1;

            // convert back to normalized form
            static if (~flags & Flags.infinity)
                // No infinity support?
                assert(sig != 0, "Infinity floating point value assigned to a "
                    ~ typeof(this).stringof ~ " (no infinity support).");

            static if (~flags & Flags.nan)  // No NaN support?
                assert(sig == 0, "NaN floating point value assigned to a " ~
                    typeof(this).stringof ~ " (no nan support).");
            sig >>= shift;
            return;
        }
        if (exp == exp.min)     // 0.0
        {
             exp = 0;
             sig = 0;
             return;
        }

        exp += bias;
        if (exp <= 0)
        {
            static if ((flags&Flags.allowDenorm) ||
                // Convert from normalized form to denormalized
                (~flags&Flags.storeNormalized))
            {
                shift += -exp;
                roundedShift(sig,1);
                sig   += cast(T) 1uL << (T.sizeof*8 - 1);
                // Add the leading 1
                exp    = 0;
            }
            else
                assert((flags&Flags.storeNormalized) && exp == 0,
                    "Underflow occurred assigning to a " ~
                    typeof(this).stringof ~ " (no denormal support).");
        }
        else
        {
            static if (~flags&Flags.storeNormalized)
            {
                // Convert from normalized form to denormalized
                roundedShift(sig,1);
                sig += cast(T) 1uL << (T.sizeof*8 - 1);
                // Add the leading 1
            }
        }

        if (shift > 0)
            roundedShift(sig,shift);
        if (sig > significand_max)
        {
            // handle significand overflow (should only be 1 bit)
            static if (~flags&Flags.storeNormalized)
            {
                sig >>= 1;
            }
            else
                sig &= significand_max;
            exp++;
        }
        static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly)
        {
            // disallow non-zero denormals
            if (exp == 0)
            {
                sig <<= 1;
                if (sig > significand_max && (sig&significand_max) > 0)
                    // Check and round to even
                    exp++;
                sig = 0;
            }
        }

        if (exp >= exponent_max)
        {
            static if (flags&(Flags.infinity|Flags.nan))
            {
                sig         = 0;
                exp         = exponent_max;
                static if (~flags&(Flags.infinity))
                    assert(0, "Overflow occurred assigning to a " ~
                        typeof(this).stringof ~ " (no infinity support).");
            }
            else
                assert(exp == exponent_max, "Overflow occurred assigning to a "
                    ~ typeof(this).stringof ~ " (no infinity support).");
        }
    }

public:
    static if (precision == 64) // CustomFloat!80 support hack
    {
        ulong significand;
        enum ulong significand_max = ulong.max;
        mixin(bitfields!(
            T_exp , "exponent", exponentWidth,
            bool  , "sign"    , flags & flags.signed ));
    }
    else
    {
        mixin(bitfields!(
            T_sig, "significand", precision,
            T_exp, "exponent"   , exponentWidth,
            bool , "sign"       , flags & flags.signed ));
    }

    /// Returns: infinity value
    static if (flags & Flags.infinity)
        static @property CustomFloat infinity()
        {
            CustomFloat value;
            static if (flags & Flags.signed)
                value.sign          = 0;
            value.significand   = 0;
            value.exponent      = exponent_max;
            return value;
        }

    /// Returns: NaN value
    static if (flags & Flags.nan)
        static @property CustomFloat nan()
        {
            CustomFloat value;
            static if (flags & Flags.signed)
                value.sign          = 0;
            value.significand   = cast(typeof(significand_max)) 1L << (precision-1);
            value.exponent      = exponent_max;
            return value;
        }

    /// Returns: number of decimal digits of precision
    static @property size_t dig()
    {
        auto shiftcnt = precision - ((flags&Flags.storeNormalized) == 0);
        return shiftcnt == 64 ? 19 : cast(size_t) log10(1uL << shiftcnt);
    }

    /// Returns: smallest increment to the value 1
    static @property CustomFloat epsilon()
    {
        CustomFloat one = CustomFloat(1);
        CustomFloat onePlusEpsilon = one;
        onePlusEpsilon.significand = onePlusEpsilon.significand | 1; // |= does not work here

        return CustomFloat(onePlusEpsilon - one);
    }

    /// the number of bits in mantissa
    enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0);

    /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable
    static @property int max_10_exp(){ return cast(int) log10( +max ); }

    /// maximum int value such that 2<sup>max_exp-1</sup> is representable
    enum max_exp = exponent_max - bias - ((flags & (Flags.infinity | Flags.nan)) != 0) + 1;

    /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable
    static @property int min_10_exp(){ return cast(int) log10( +min_normal ); }

    /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value
    enum min_exp = cast(T_signed_exp) -(cast(long) bias) + 1 + ((flags & Flags.allowDenorm) != 0);

    /// Returns: largest representable value that's not infinity
    static @property CustomFloat max()
    {
        CustomFloat value;
        static if (flags & Flags.signed)
            value.sign        = 0;
        value.exponent    = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0);
        value.significand = significand_max;
        return value;
    }

    /// Returns: smallest representable normalized value that's not 0
    static @property CustomFloat min_normal()
    {
        CustomFloat value;
        static if (flags & Flags.signed)
            value.sign = 0;
        value.exponent = (flags & Flags.allowDenorm) != 0;
        static if (flags & Flags.storeNormalized)
            value.significand = 0;
        else
            value.significand = cast(T_sig) 1uL << (precision - 1);
        return value;
    }

    /// Returns: real part
    @property CustomFloat re() { return this; }

    /// Returns: imaginary part
    static @property CustomFloat im() { return CustomFloat(0.0f); }

    /// Initialize from any `real` compatible type.
    this(F)(F input) if (__traits(compiles, cast(real) input ))
    {
        this = input;
    }

    /// Self assignment
    void opAssign(F:CustomFloat)(F input)
    {
        static if (flags & Flags.signed)
            sign        = input.sign;
        exponent    = input.exponent;
        significand = input.significand;
    }

    /// Assigns from any `real` compatible type.
    void opAssign(F)(F input)
    if (__traits(compiles, cast(real) input))
    {
        import std.conv : text;

        static if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
            auto value = ToBinary!(Unqual!F)(input);
        else
            auto value = ToBinary!(real    )(input);

        // Assign the sign bit
        static if (~flags & Flags.signed)
            assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0),
                "Incorrectly signed floating point value assigned to a " ~
                typeof(this).stringof ~ " (no sign support).");
        else
            sign = value.sign;

        CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent;
        CommonType!(T_sig,        value.T_sig       ) sig = value.significand;

        value.toNormalized(sig,exp);
        fromNormalized(sig,exp);

        assert(exp <= exponent_max,    text(typeof(this).stringof ~
            " exponent too large: "   ,exp," > ",exponent_max,   "\t",input,"\t",sig));
        assert(sig <= significand_max, text(typeof(this).stringof ~
            " significand too large: ",sig," > ",significand_max,
            "\t",input,"\t",exp," ",exponent_max));
        exponent    = cast(T_exp) exp;
        significand = cast(T_sig) sig;
    }

    /// Fetches the stored value either as a `float`, `double` or `real`.
    @property F get(F)()
    if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
    {
        import std.conv : text;

        ToBinary!F result;

        static if (flags&Flags.signed)
            result.sign = sign;
        else
            result.sign = (flags&flags.negativeUnsigned) > 0;

        CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction
        CommonType!(T_sig,        result.get.T_sig        ) sig = significand;

        toNormalized(sig,exp);
        result.fromNormalized(sig,exp);
        assert(exp <= result.exponent_max,    text("get exponent too large: "   ,exp," > ",result.exponent_max) );
        assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) );
        result.exponent     = cast(result.get.T_exp) exp;
        result.significand  = cast(result.get.T_sig) sig;
        return result.set;
    }

    ///ditto
    alias opCast = get;

    /// Convert the CustomFloat to a real and perform the relevant operator on the result
    real opUnary(string op)()
    if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--")
    {
        static if (op=="++" || op=="--")
        {
            auto result = get!real;
            this = mixin(op~`result`);
            return result;
        }
        else
            return mixin(op~`get!real`);
    }

    /// ditto
    // Define an opBinary `CustomFloat op CustomFloat` so that those below
    // do not match equally, which is disallowed by the spec:
    // https://dlang.org/spec/operatoroverloading.html#binary
    real opBinary(string op,T)(T b)
    if (__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
    {
        return mixin(`get!real`~op~`b.get!real`);
    }

    /// ditto
    real opBinary(string op,T)(T b)
    if ( __traits(compiles, mixin(`get!real`~op~`b`)) &&
        !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
    {
        return mixin(`get!real`~op~`b`);
    }

    /// ditto
    real opBinaryRight(string op,T)(T a)
    if ( __traits(compiles, mixin(`a`~op~`get!real`)) &&
        !__traits(compiles, mixin(`get!real`~op~`b`)) &&
        !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
    {
        return mixin(`a`~op~`get!real`);
    }

    /// ditto
    int opCmp(T)(auto ref T b)
    if (__traits(compiles, cast(real) b))
    {
        auto x = get!real;
        auto y = cast(real) b;
        return  (x >= y)-(x <= y);
    }

    /// ditto
    void opOpAssign(string op, T)(auto ref T b)
    if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`)))
    {
        return mixin(`this = this `~op~` cast(real) b`);
    }

    /// ditto
    template toString()
    {
        import std.format : FormatSpec, formatValue;
        // Needs to be a template because of https://issues.dlang.org/show_bug.cgi?id=13737.
        void toString()(scope void delegate(const(char)[]) sink, scope const ref FormatSpec!char fmt)
        {
            sink.formatValue(get!real, fmt);
        }
    }
}

@safe unittest
{
    import std.meta;
    alias FPTypes =
        AliasSeq!(
            CustomFloat!(5, 10),
            CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            CustomFloat!(1, 7, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed)
        );

    foreach (F; FPTypes)
    {
        auto x = F(0.125);
        assert(x.get!float == 0.125F);
        assert(x.get!double == 0.125);

        x -= 0.0625;
        assert(x.get!float == 0.0625F);
        assert(x.get!double == 0.0625);

        x *= 2;
        assert(x.get!float == 0.125F);
        assert(x.get!double == 0.125);

        x /= 4;
        assert(x.get!float == 0.03125);
        assert(x.get!double == 0.03125);

        x = 0.5;
        x ^^= 4;
        assert(x.get!float == 1 / 16.0F);
        assert(x.get!double == 1 / 16.0);
    }
}

@system unittest
{
    // @system due to to!string(CustomFloat)
    import std.conv;
    CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125);
    assert(y.to!string == "0.125");
}

@safe unittest 661 { 662 alias cf = CustomFloat!(5, 2); 663 664 auto a = cf.infinity; 665 assert(a.sign == 0); 666 assert(a.exponent == 3); 667 assert(a.significand == 0); 668 669 auto b = cf.nan; 670 assert(b.exponent == 3); 671 assert(b.significand != 0); 672 673 assert(cf.dig == 1); 674 675 auto c = cf.epsilon; 676 assert(c.sign == 0); 677 assert(c.exponent == 0); 678 assert(c.significand == 1); 679 680 assert(cf.mant_dig == 6); 681 682 assert(cf.max_10_exp == 0); 683 assert(cf.max_exp == 2); 684 assert(cf.min_10_exp == 0); 685 assert(cf.min_exp == 1); 686 687 auto d = cf.max; 688 assert(d.sign == 0); 689 assert(d.exponent == 2); 690 assert(d.significand == 31); 691 692 auto e = cf.min_normal; 693 assert(e.sign == 0); 694 assert(e.exponent == 1); 695 assert(e.significand == 0); 696 697 assert(e.re == e); 698 assert(e.im == cf(0.0)); 699 } 700 701 // check whether CustomFloats identical to float/double behave like float/double 702 @safe unittest 703 { 704 import std.conv : to; 705 706 alias myFloat = CustomFloat!(23, 8); 707 708 static assert(myFloat.dig == float.dig); 709 static assert(myFloat.mant_dig == float.mant_dig); 710 assert(myFloat.max_10_exp == float.max_10_exp); 711 static assert(myFloat.max_exp == float.max_exp); 712 assert(myFloat.min_10_exp == float.min_10_exp); 713 static assert(myFloat.min_exp == float.min_exp); 714 assert(to!float(myFloat.epsilon) == float.epsilon); 715 assert(to!float(myFloat.max) == float.max); 716 assert(to!float(myFloat.min_normal) == float.min_normal); 717 718 alias myDouble = CustomFloat!(52, 11); 719 720 static assert(myDouble.dig == double.dig); 721 static assert(myDouble.mant_dig == double.mant_dig); 722 assert(myDouble.max_10_exp == double.max_10_exp); 723 static assert(myDouble.max_exp == double.max_exp); 724 assert(myDouble.min_10_exp == double.min_10_exp); 725 static assert(myDouble.min_exp == double.min_exp); 726 assert(to!double(myDouble.epsilon) == double.epsilon); 727 assert(to!double(myDouble.max) == 
double.max); 728 assert(to!double(myDouble.min_normal) == double.min_normal); 729 } 730 731 // testing .dig 732 @safe unittest 733 { 734 static assert(CustomFloat!(1, 6).dig == 0); 735 static assert(CustomFloat!(9, 6).dig == 2); 736 static assert(CustomFloat!(10, 5).dig == 3); 737 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).dig == 2); 738 static assert(CustomFloat!(11, 5, CustomFloatFlags.none).dig == 3); 739 static assert(CustomFloat!(64, 7).dig == 19); 740 } 741 742 // testing .mant_dig 743 @safe unittest 744 { 745 static assert(CustomFloat!(10, 5).mant_dig == 11); 746 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).mant_dig == 10); 747 } 748 749 // testing .max_exp 750 @safe unittest 751 { 752 static assert(CustomFloat!(1, 6).max_exp == 2^^5); 753 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_exp == 2^^5); 754 static assert(CustomFloat!(5, 10).max_exp == 2^^9); 755 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_exp == 2^^9); 756 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_exp == 2^^5); 757 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_exp == 2^^9); 758 } 759 760 // testing .min_exp 761 @safe unittest 762 { 763 static assert(CustomFloat!(1, 6).min_exp == -2^^5+3); 764 static assert(CustomFloat!(5, 10).min_exp == -2^^9+3); 765 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_exp == -2^^5+1); 766 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_exp == -2^^9+1); 767 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_exp == -2^^5+2); 768 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_exp == -2^^9+2); 769 static assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_exp == -2^^5+2); 770 static assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_exp == -2^^9+2); 771 } 772 773 // testing .max_10_exp 774 @safe unittest 775 { 776 assert(CustomFloat!(1, 6).max_10_exp == 9); 777 assert(CustomFloat!(5, 10).max_10_exp == 154); 778 
assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_10_exp == 9); 779 assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_10_exp == 154); 780 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_10_exp == 9); 781 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_10_exp == 154); 782 } 783 784 // testing .min_10_exp 785 @safe unittest 786 { 787 assert(CustomFloat!(1, 6).min_10_exp == -9); 788 assert(CustomFloat!(5, 10).min_10_exp == -153); 789 assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_10_exp == -9); 790 assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_10_exp == -154); 791 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_10_exp == -9); 792 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_10_exp == -153); 793 assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_10_exp == -9); 794 assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_10_exp == -153); 795 } 796 797 // testing .epsilon 798 @safe unittest 799 { 800 assert(CustomFloat!(1,6).epsilon.sign == 0); 801 assert(CustomFloat!(1,6).epsilon.exponent == 30); 802 assert(CustomFloat!(1,6).epsilon.significand == 0); 803 assert(CustomFloat!(2,5).epsilon.sign == 0); 804 assert(CustomFloat!(2,5).epsilon.exponent == 13); 805 assert(CustomFloat!(2,5).epsilon.significand == 0); 806 assert(CustomFloat!(3,4).epsilon.sign == 0); 807 assert(CustomFloat!(3,4).epsilon.exponent == 4); 808 assert(CustomFloat!(3,4).epsilon.significand == 0); 809 // the following epsilons are only available, when denormalized numbers are allowed: 810 assert(CustomFloat!(4,3).epsilon.sign == 0); 811 assert(CustomFloat!(4,3).epsilon.exponent == 0); 812 assert(CustomFloat!(4,3).epsilon.significand == 4); 813 assert(CustomFloat!(5,2).epsilon.sign == 0); 814 assert(CustomFloat!(5,2).epsilon.exponent == 0); 815 assert(CustomFloat!(5,2).epsilon.significand == 1); 816 } 817 818 // testing .max 819 @safe unittest 820 { 821 static assert(CustomFloat!(5,2).max.sign == 0); 822 static assert(CustomFloat!(5,2).max.exponent == 
2); 823 static assert(CustomFloat!(5,2).max.significand == 31); 824 static assert(CustomFloat!(4,3).max.sign == 0); 825 static assert(CustomFloat!(4,3).max.exponent == 6); 826 static assert(CustomFloat!(4,3).max.significand == 15); 827 static assert(CustomFloat!(3,4).max.sign == 0); 828 static assert(CustomFloat!(3,4).max.exponent == 14); 829 static assert(CustomFloat!(3,4).max.significand == 7); 830 static assert(CustomFloat!(2,5).max.sign == 0); 831 static assert(CustomFloat!(2,5).max.exponent == 30); 832 static assert(CustomFloat!(2,5).max.significand == 3); 833 static assert(CustomFloat!(1,6).max.sign == 0); 834 static assert(CustomFloat!(1,6).max.exponent == 62); 835 static assert(CustomFloat!(1,6).max.significand == 1); 836 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.exponent == 31); 837 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.significand == 7); 838 } 839 840 // testing .min_normal 841 @safe unittest 842 { 843 static assert(CustomFloat!(5,2).min_normal.sign == 0); 844 static assert(CustomFloat!(5,2).min_normal.exponent == 1); 845 static assert(CustomFloat!(5,2).min_normal.significand == 0); 846 static assert(CustomFloat!(4,3).min_normal.sign == 0); 847 static assert(CustomFloat!(4,3).min_normal.exponent == 1); 848 static assert(CustomFloat!(4,3).min_normal.significand == 0); 849 static assert(CustomFloat!(3,4).min_normal.sign == 0); 850 static assert(CustomFloat!(3,4).min_normal.exponent == 1); 851 static assert(CustomFloat!(3,4).min_normal.significand == 0); 852 static assert(CustomFloat!(2,5).min_normal.sign == 0); 853 static assert(CustomFloat!(2,5).min_normal.exponent == 1); 854 static assert(CustomFloat!(2,5).min_normal.significand == 0); 855 static assert(CustomFloat!(1,6).min_normal.sign == 0); 856 static assert(CustomFloat!(1,6).min_normal.exponent == 1); 857 static assert(CustomFloat!(1,6).min_normal.significand == 0); 858 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.exponent == 0); 859 static 
assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.significand == 4); 860 } 861 862 @safe unittest 863 { 864 import std.math : isNaN; 865 866 alias cf = CustomFloat!(5, 2); 867 868 auto f = cf.nan.get!float(); 869 assert(isNaN(f)); 870 871 cf a; 872 a = real.max; 873 assert(a == cf.infinity); 874 875 a = 0.015625; 876 assert(a.exponent == 0); 877 assert(a.significand == 0); 878 879 a = 0.984375; 880 assert(a.exponent == 1); 881 assert(a.significand == 0); 882 } 883 884 @system unittest 885 { 886 import std.exception : assertThrown; 887 import core.exception : AssertError; 888 889 alias cf = CustomFloat!(3, 5, CustomFloatFlags.none); 890 891 cf a; 892 assertThrown!AssertError(a = real.max); 893 } 894 895 @system unittest 896 { 897 import std.exception : assertThrown; 898 import core.exception : AssertError; 899 900 alias cf = CustomFloat!(3, 5, CustomFloatFlags.nan); 901 902 cf a; 903 assertThrown!AssertError(a = real.max); 904 } 905 906 @system unittest 907 { 908 import std.exception : assertThrown; 909 import core.exception : AssertError; 910 911 alias cf = CustomFloat!(24, 8, CustomFloatFlags.none); 912 913 cf a; 914 assertThrown!AssertError(a = float.infinity); 915 } 916 917 private bool isCorrectCustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags) @safe pure nothrow @nogc 918 { 919 // Restrictions from bitfield 920 // due to CustomFloat!80 support hack precision with 64 bits is handled specially 921 auto length = (flags & flags.signed) + exponentWidth + ((precision == 64) ? 
0 : precision); 922 if (length != 8 && length != 16 && length != 32 && length != 64) return false; 923 924 // mantissa needs to fit into real mantissa 925 if (precision > real.mant_dig - 1 && precision != 64) return false; 926 927 // exponent needs to fit into real exponent 928 if (1L << exponentWidth - 1 > real.max_exp) return false; 929 930 // mantissa should have at least one bit 931 if (precision == 0) return false; 932 933 // exponent should have at least one bit, in some cases two 934 if (exponentWidth <= ((flags & (flags.allowDenorm | flags.infinity | flags.nan)) != 0)) return false; 935 936 return true; 937 } 938 939 @safe pure nothrow @nogc unittest 940 { 941 assert(isCorrectCustomFloat(3,4,CustomFloatFlags.ieee)); 942 assert(isCorrectCustomFloat(3,5,CustomFloatFlags.none)); 943 assert(!isCorrectCustomFloat(3,3,CustomFloatFlags.ieee)); 944 assert(isCorrectCustomFloat(64,7,CustomFloatFlags.ieee)); 945 assert(!isCorrectCustomFloat(64,4,CustomFloatFlags.ieee)); 946 assert(!isCorrectCustomFloat(508,3,CustomFloatFlags.ieee)); 947 assert(!isCorrectCustomFloat(3,100,CustomFloatFlags.ieee)); 948 assert(!isCorrectCustomFloat(0,7,CustomFloatFlags.ieee)); 949 assert(!isCorrectCustomFloat(6,1,CustomFloatFlags.ieee)); 950 assert(isCorrectCustomFloat(7,1,CustomFloatFlags.none)); 951 assert(!isCorrectCustomFloat(8,0,CustomFloatFlags.none)); 952 } 953 954 /** 955 Defines the fastest type to use when storing temporaries of a 956 calculation intended to ultimately yield a result of type `F` 957 (where `F` must be one of `float`, `double`, or $(D 958 real)). When doing a multi-step computation, you may want to store 959 intermediate results as `FPTemporary!F`. 960 961 The necessity of `FPTemporary` stems from the optimized 962 floating-point operations and registers present in virtually all 963 processors. When adding numbers in the example above, the addition may 964 in fact be done in `real` precision internally. 
In that case, 965 storing the intermediate `result` in $(D double format) is not only 966 less precise, it is also (surprisingly) slower, because a conversion 967 from `real` to `double` is performed every pass through the 968 loop. This being a lose-lose situation, `FPTemporary!F` has been 969 defined as the $(I fastest) type to use for calculations at precision 970 `F`. There is no need to define a type for the $(I most accurate) 971 calculations, as that is always `real`. 972 973 Finally, there is no guarantee that using `FPTemporary!F` will 974 always be fastest, as the speed of floating-point calculations depends 975 on very many factors. 976 */ 977 template FPTemporary(F) 978 if (isFloatingPoint!F) 979 { 980 version (X86) 981 alias FPTemporary = real; 982 else 983 alias FPTemporary = Unqual!F; 984 } 985 986 /// 987 @safe unittest 988 { 989 import std.math : approxEqual; 990 991 // Average numbers in an array 992 double avg(in double[] a) 993 { 994 if (a.length == 0) return 0; 995 FPTemporary!double result = 0; 996 foreach (e; a) result += e; 997 return result / a.length; 998 } 999 1000 auto a = [1.0, 2.0, 3.0]; 1001 assert(approxEqual(avg(a), 2)); 1002 } 1003 1004 /** 1005 Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a 1006 root of the function `fun` starting from points $(D [xn_1, x_n]) 1007 (ideally close to the root). `Num` may be `float`, `double`, 1008 or `real`. 
*/
template secantMethod(alias fun)
{
    import std.functional : unaryFun;
    // Iterates x_{n+1} = x_n - f(x_n) * (x_n - x_{n-1}) / (f(x_n) - f(x_{n-1})),
    // expressed incrementally via the step `d`.
    Num secantMethod(Num)(Num xn_1, Num xn)
    {
        auto fxn = unaryFun!(fun)(xn_1), d = xn_1 - xn;
        typeof(fxn) fxn_1;

        xn = xn_1;
        // Stop when the step is negligibly small or becomes non-finite
        // (overflow/NaN from a vanishing denominator).
        while (!approxEqual(d, 0) && isFinite(d))
        {
            xn_1 = xn;
            xn -= d;
            fxn_1 = fxn;
            fxn = unaryFun!(fun)(xn);
            d *= -fxn / (fxn - fxn_1);
        }
        return xn;
    }
}

///
@safe unittest
{
    import std.math : approxEqual, cos;

    float f(float x)
    {
        return cos(x) - x*x*x;
    }
    auto x = secantMethod!(f)(0f, 1f);
    assert(approxEqual(x, 0.865474));
}

@system unittest
{
    // @system because of __gshared stderr
    import std.stdio;
    scope(failure) stderr.writeln("Failure testing secantMethod");
    float f(float x)
    {
        return cos(x) - x*x*x;
    }
    immutable x = secantMethod!(f)(0f, 1f);
    assert(approxEqual(x, 0.865474));
    // Also instantiable with a function pointer, not just an alias.
    auto d = &f;
    immutable y = secantMethod!(d)(0f, 1f);
    assert(approxEqual(y, 0.865474));
}


/**
 * Return true if a and b have opposite sign.
 */
private bool oppositeSigns(T1, T2)(T1 a, T2 b)
{
    // signbit-based, so it distinguishes -0.0 from +0.0 and works for NaNs.
    return signbit(a) != signbit(b);
}

public:

/**  Find a real root of a real function f(x) via bracketing.
 *
 * Given a function `f` and a range `[a .. b]` such that `f(a)`
 * and `f(b)` have opposite signs or at least one of them equals ±0,
 * returns the value of `x` in
 * the range which is closest to a root of `f(x)`.  If `f(x)`
 * has more than one root in the range, one will be chosen
 * arbitrarily.  If `f(x)` returns NaN, NaN will be returned;
 * otherwise, this algorithm is guaranteed to succeed.
 *
 * Uses an algorithm based on TOMS748, which uses inverse cubic
 * interpolation whenever possible, otherwise reverting to parabolic
 * or secant interpolation.
Compared to TOMS748, this implementation
 * improves worst-case performance by a factor of more than 100, and
 * typical performance by a factor of 2. For 80-bit reals, most
 * problems require 8 to 15 calls to `f(x)` to achieve full machine
 * precision. The worst-case performance (pathological cases) is
 * approximately twice the number of bits.
 *
 * References: "On Enclosing Simple Roots of Nonlinear Equations",
 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61,
 * pp733-744 (1993).  Fortran code available from $(HTTP
 * www.netlib.org,www.netlib.org) as algorithm TOMS748.
 *
 */
T findRoot(T, DF, DT)(scope DF f, in T a, in T b,
    scope DT tolerance) //= (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R, R) && isFloatingPoint!R
    )
{
    // Check the endpoints first so the bracketing precondition of the
    // five-argument overload (opposite signs) is not violated by an exact root.
    immutable fa = f(a);
    if (fa == 0)
        return a;
    immutable fb = f(b);
    if (fb == 0)
        return b;
    immutable r = findRoot(f, a, b, fa, fb, tolerance);
    // Return the first value if it is smaller or NaN
    // (the comparison is negated so NaN function values select r[0]).
    return !(fabs(r[2]) > fabs(r[3])) ? r[0] : r[1];
}

///ditto
T findRoot(T, DF)(scope DF f, in T a, in T b)
{
    // No early-termination condition: iterate to full machine precision.
    return findRoot(f, a, b, (T a, T b) => false);
}

/** Find root of a real function f(x) by bracketing, allowing the
 * termination condition to be specified.
 *
 * Params:
 *
 * f = Function to be analyzed
 *
 * ax = Left bound of initial range of `f` known to contain the
 * root.
 *
 * bx = Right bound of initial range of `f` known to contain the
 * root.
 *
 * fax = Value of `f(ax)`.
 *
 * fbx = Value of `f(bx)`. `fax` and `fbx` should have opposite signs.
 * (`f(ax)` and `f(bx)` are commonly known in advance.)
 *
 *
 * tolerance = Defines an early termination condition.
Receives the
 *             current upper and lower bounds on the root. The
 *             delegate must return `true` when these bounds are
 *             acceptable. If this function always returns `false`,
 *             full machine precision will be achieved.
 *
 * Returns:
 *
 * A tuple consisting of two ranges. The first two elements are the
 * range (in `x`) of the root, while the second pair of elements
 * are the corresponding function values at those points. If an exact
 * root was found, both of the first two elements will contain the
 * root, and the second pair of elements will be 0.
 */
Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, in T ax, in T bx, in R fax, in R fbx,
    scope DT tolerance) // = (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R) && isFloatingPoint!R
    )
in
{
    assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN");
    assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root.");
}
do
{
    // Author: Don Clugston. This code is (heavily) modified from TOMS748
    // (www.netlib.org). The changes to improve the worst-case performance are
    // entirely original.

    T a, b, d;    // [a .. b] is our current bracket. d is the third best guess.
    R fa, fb, fd; // Values of f at a, b, d.
    bool done = false; // Has a root been found?

    // Allow ax and bx to be provided in reverse order
    if (ax <= bx)
    {
        a = ax; fa = fax;
        b = bx; fb = fbx;
    }
    else
    {
        a = bx; fa = fbx;
        b = ax; fb = fax;
    }

    // Test the function at point c; update brackets accordingly
    void bracket(T c)
    {
        R fc = f(c);
        if (fc == 0 || fc.isNaN()) // Exact solution, or NaN
        {
            a = c;
            fa = fc;
            d = c;
            fd = fc;
            done = true;
            return;
        }

        // Determine new enclosing interval
        if (signbit(fa) != signbit(fc))
        {
            // Root lies in [a .. c]: old b becomes the third-best guess.
            d = b;
            fd = fb;
            b = c;
            fb = fc;
        }
        else
        {
            // Root lies in [c .. b]: old a becomes the third-best guess.
            d = a;
            fd = fa;
            a = c;
            fa = fc;
        }
    }

    /* Perform a secant interpolation. If the result would lie on a or b, or if
     a and b differ so wildly in magnitude that the result would be meaningless,
     perform a bisection instead.
    */
    static T secant_interpolate(T a, T b, R fa, R fb)
    {
        if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b)))
        {
            // Catastrophic cancellation
            if (a == 0)
                a = copysign(T(0), b);
            else if (b == 0)
                b = copysign(T(0), a);
            else if (signbit(a) != signbit(b))
                return 0;
            // Bisect in binary (exponent) space rather than linearly.
            T c = ieeeMean(a, b);
            return c;
        }
        // avoid overflow
        if (b - a > T.max)
            return b / 2 + a / 2;
        if (fb - fa > R.max)
            return a - (b - a) / 2;
        T c = a - (fa / (fb - fa)) * (b - a);
        if (c == a || c == b)
            return (a + b) / 2;
        return c;
    }

    /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the
       quadratic polynomial interpolating f(x) at a, b, and d.
       Returns:
       The approximate zero in [a .. b] of the quadratic polynomial.
    */
    T newtonQuadratic(int numsteps)
    {
        // Find the coefficients of the quadratic polynomial.
        immutable T a0 = fa;
        immutable T a1 = (fb - fa)/(b - a);
        immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a);

        // Determine the starting point of newton steps.
        T c = oppositeSigns(a2, fa) ? a : b;

        // start the safeguarded newton steps.
        foreach (int i; 0 .. numsteps)
        {
            immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a);
            immutable T pdc = a1 + a2*((2 * c) - (a + b));
            if (pdc == 0)
                return a - a0 / a1;
            else
                c = c - pc / pdc;
        }
        return c;
    }

    // On the first iteration we take a secant step:
    if (fa == 0 || fa.isNaN())
    {
        done = true;
        b = a;
        fb = fa;
    }
    else if (fb == 0 || fb.isNaN())
    {
        done = true;
        a = b;
        fa = fb;
    }
    else
    {
        bracket(secant_interpolate(a, b, fa, fb));
    }

    // Starting with the second iteration, higher-order interpolation can
    // be used.
    int itnum = 1;   // Iteration number
    int baditer = 1; // Num bisections to take if an iteration is bad.
    T c, e;          // e is our fourth best guess
    R fe;

    // Main loop: each pass performs up to two interpolation steps, one
    // double-length secant step, and — if the bracket is not shrinking
    // fast enough — explicit bisections in exponent space.
whileloop:
    while (!done && (b != nextUp(a)) && !tolerance(a, b))
    {
        T a0 = a, b0 = b; // record the brackets

        // Do two higher-order (cubic or parabolic) interpolation steps.
        foreach (int QQ; 0 .. 2)
        {
            // Cubic inverse interpolation requires that
            // all four function values fa, fb, fd, and fe are distinct;
            // otherwise use quadratic interpolation.
            bool distinct = (fa != fb) && (fa != fd) && (fa != fe)
                         && (fb != fd) && (fb != fe) && (fd != fe);
            // The first time, cubic interpolation is impossible.
            if (itnum<2) distinct = false;
            bool ok = distinct;
            if (distinct)
            {
                // Cubic inverse interpolation of f(x) at a, b, d, and e
                immutable q11 = (d - e) * fd / (fe - fd);
                immutable q21 = (b - d) * fb / (fd - fb);
                immutable q31 = (a - b) * fa / (fb - fa);
                immutable d21 = (b - d) * fd / (fd - fb);
                immutable d31 = (a - b) * fb / (fb - fa);

                immutable q22 = (d21 - q11) * fb / (fe - fb);
                immutable q32 = (d31 - q21) * fa / (fd - fa);
                immutable d32 = (d31 - q21) * fd / (fd - fa);
                immutable q33 = (d32 - q22) * fa / (fe - fa);
                c = a + (q31 + q32 + q33);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // DAC: If the interpolation predicts a or b, it's
                    // probable that it's the actual root. Only allow this if
                    // we're already close to the root.
                    if (c == a && a - b != a)
                    {
                        c = nextUp(a);
                    }
                    else if (c == b && a - b != -b)
                    {
                        c = nextDown(b);
                    }
                    else
                    {
                        ok = false;
                    }
                }
            }
            if (!ok)
            {
                // DAC: Alefeld doesn't explain why the number of newton steps
                // should vary.
                c = newtonQuadratic(distinct ? 3 : 2);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // Failure, try a secant step:
                    c = secant_interpolate(a, b, fa, fb);
                }
            }
            ++itnum;
            e = d;
            fe = fd;
            bracket(c);
            if (done || ( b == nextUp(a)) || tolerance(a, b))
                break whileloop;
            if (itnum == 2)
                continue whileloop;
        }

        // Now we take a double-length secant step:
        T u;
        R fu;
        if (fabs(fa) < fabs(fb))
        {
            u = a;
            fu = fa;
        }
        else
        {
            u = b;
            fu = fb;
        }
        c = u - 2 * (fu / (fb - fa)) * (b - a);

        // DAC: If the secant predicts a value equal to an endpoint, it's
        // probably false.
        if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2)
        {
            if ((a-b) == a || (b-a) == b)
            {
                if ((a>0 && b<0) || (a<0 && b>0))
                    c = 0;
                else
                {
                    if (a == 0)
                        c = ieeeMean(copysign(T(0), b), b);
                    else if (b == 0)
                        c = ieeeMean(copysign(T(0), a), a);
                    else
                        c = ieeeMean(a, b);
                }
            }
            else
            {
                c = a + (b - a) / 2;
            }
        }
        e = d;
        fe = fd;
        bracket(c);
        if (done || (b == nextUp(a)) || tolerance(a, b))
            break;

        // IMPROVE THE WORST-CASE PERFORMANCE
        // We must ensure that the bounds reduce by a factor of 2
        // in binary space! every iteration. If we haven't achieved this
        // yet, or if we don't yet know what the exponent is,
        // perform a binary chop.

        if ((a == 0 || b == 0 ||
            (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a)))
            && (b - a) < T(0.25) * (b0 - a0))
        {
            baditer = 1;
            continue;
        }

        // DAC: If this happens on consecutive iterations, we probably have a
        // pathological function. Perform a number of bisections equal to the
        // total number of consecutive bad iterations.

        if ((b - a) < T(0.25) * (b0 - a0))
            baditer = 1;
        foreach (int QQ; 0 .. baditer)
        {
            e = d;
            fe = fd;

            T w;
            if ((a>0 && b<0) || (a<0 && b>0))
                w = 0;
            else
            {
                // ieeeMean requires operands of the same sign; substitute a
                // signed zero when one endpoint is exactly zero.
                T usea = a;
                T useb = b;
                if (a == 0)
                    usea = copysign(T(0), b);
                else if (b == 0)
                    useb = copysign(T(0), a);
                w = ieeeMean(usea, useb);
            }
            bracket(w);
        }
        ++baditer;
    }
    return Tuple!(T, T, R, R)(a, b, fa, fb);
}

///ditto
Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, in T ax, in T bx, in R fax, in R fbx)
{
    return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false);
}

///ditto
T findRoot(T, R)(scope R delegate(T) f, in T a, in T b,
    scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false)
{
    return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance);
}

@safe nothrow unittest
{
    int numProblems = 0;
    int numCalls;

    // Each call verifies that the returned bracket still straddles the root
    // (or that an exact root was found).
    void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure
    {
        //numCalls=0;
        //++numProblems;
        assert(!x1.isNaN() && !x2.isNaN());
        assert(signbit(f(x1)) != signbit(f(x2)));
        auto result = findRoot(f, x1, x2, f(x1), f(x2),
            (real lo, real hi) { return false; });

        auto flo = f(result[0]);
        auto fhi = f(result[1]);
        if (flo != 0)
        {
            assert(oppositeSigns(flo, fhi));
        }
    }

    // Test functions
    real cubicfn(real x) @nogc @safe nothrow pure
    {
        //++numCalls;
        if (x>float.max)
            x = float.max;
        if (x<-float.max)
            x = -float.max;
        // This has a single real root at -59.286543284815
        return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
    }
    // Test a function with more than one root.
    real multisine(real x) { ++numCalls; return sin(x); }
    testFindRoot( &multisine, 6, 90);
    testFindRoot(&cubicfn, -100, 100);
    testFindRoot( &cubicfn, -double.max, real.max);


    /* Tests from the paper:
     * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra,
     * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
     */
    // Parameters common to many alefeld tests.
    int n;
    real ale_a, ale_b;

    int powercalls = 0;

    real power(real x)
    {
        ++powercalls;
        ++numCalls;
        return pow(x, n) + double.min_normal;
    }
    int [] power_nvals = [3, 5, 7, 9, 19, 25];
    // Alefeld paper states that pow(x,n) is a very poor case, where bisection
    // outperforms his method, and gives total numcalls =
    // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit),
    // 2624 for brent (6.8/bit)
    // ... but that is for double, not real80.
    // This poor performance seems mainly due to catastrophic cancellation,
    // which is avoided here by the use of ieeeMean().
    // I get: 231 (0.48/bit).
    // IE this is 10X faster in Alefeld's worst case
    numProblems=0;
    foreach (k; power_nvals)
    {
        n = k;
        testFindRoot(&power, -1, 10);
    }

    int powerProblems = numProblems;

    // Tests from Alefeld paper

    int [9] alefeldSums;
    real alefeld0(real x)
    {
        ++alefeldSums[0];
        ++numCalls;
        real q = sin(x) - x/2;
        for (int i=1; i<20; ++i)
            q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i));
        return q;
    }
    real alefeld1(real x)
    {
        ++numCalls;
        ++alefeldSums[1];
        return ale_a*x + exp(ale_b * x);
    }
    real alefeld2(real x)
    {
        ++numCalls;
        ++alefeldSums[2];
        return pow(x, n) - ale_a;
    }
    real alefeld3(real x)
    {
        ++numCalls;
        ++alefeldSums[3];
        return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2);
    }
    real alefeld4(real x)
    {
        ++numCalls;
        ++alefeldSums[4];
        return x*x - pow(1-x, n);
    }
    real alefeld5(real x)
    {
        ++numCalls;
        ++alefeldSums[5];
        return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4);
    }
    real alefeld6(real x)
    {
        ++numCalls;
        ++alefeldSums[6];
        return exp(-n*x)*(x-1.01L) + pow(x, n);
    }
    real alefeld7(real x)
    {
        ++numCalls;
        ++alefeldSums[7];
        return (n*x-1)/((n-1)*x);
    }

    numProblems=0;
    testFindRoot(&alefeld0, PI_2, PI);
    for (n=1; n <= 10; ++n)
    {
        testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L);
    }
    ale_a = -40; ale_b = -1;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -100; ale_b = -2;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -200; ale_b = -3;
    testFindRoot(&alefeld1, -9, 31);
    int [] nvals_3 = [1, 2, 5, 10, 15, 20];
    int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20];
    int [] nvals_6 = [1, 5, 10, 15, 20];
    int [] nvals_7 = [2, 5, 15, 20];

    for (int i=4; i<12; i+=2)
    {
        n = i;
        ale_a = 0.2;
        testFindRoot(&alefeld2, 0, 5);
        ale_a=1;
        testFindRoot(&alefeld2, 0.95, 4.05);
        testFindRoot(&alefeld2, 0, 1.5);
    }
    foreach (i; nvals_3)
    {
        n=i;
        testFindRoot(&alefeld3, 0, 1);
    }
    foreach (i; nvals_3)
    {
        n=i;
        testFindRoot(&alefeld4, 0, 1);
    }
    foreach (i; nvals_5)
    {
        n=i;
        testFindRoot(&alefeld5, 0, 1);
    }
    foreach (i; nvals_6)
    {
        n=i;
        testFindRoot(&alefeld6, 0, 1);
    }
    foreach (i; nvals_7)
    {
        n=i;
        testFindRoot(&alefeld7, 0.01L, 1);
    }
    // A step function: the worst possible case for any interpolation scheme.
    real worstcase(real x)
    {
        ++numCalls;
        return x<0.3*real.max? -0.999e-3 : 1.0;
    }
    testFindRoot(&worstcase, -real.max, real.max);

    // just check that the double + float cases compile
    findRoot((double x){ return 0.0; }, -double.max, double.max);
    findRoot((float x){ return 0.0f; }, -float.max, float.max);

/*
   int grandtotal=0;
   foreach (calls; alefeldSums)
   {
       grandtotal+=calls;
   }
   grandtotal-=2*numProblems;
   printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
   grandtotal, (1.0*grandtotal)/numProblems);
   powercalls -= 2*powerProblems;
   printf("POWER TOTAL = %d avg = %f ", powercalls,
   (1.0*powercalls)/powerProblems);
*/
    // https://issues.dlang.org/show_bug.cgi?id=14231
    auto xp = findRoot((float x) => x, 0f, 1f);
    auto xn = findRoot((float x) => x, -1f, -0f);
}

//regression control
@system unittest
{
    // @system due to the case in the 2nd line
    static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init)));
    static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init)));
    static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init)));
}

/++
Find a real minimum of a real function `f(x)` via bracketing.
Given a function `f` and a range `(ax ..
bx)`,
returns the value of `x` in the range which is closest to a minimum of `f(x)`.
`f` is never evaluated at the endpoints of `ax` and `bx`.
If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily.
If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned;
otherwise, this algorithm is guaranteed to succeed.

Params:
    f = Function to be analyzed
    ax = Left bound of initial range of f known to contain the minimum.
    bx = Right bound of initial range of f known to contain the minimum.
    relTolerance = Relative tolerance.
    absTolerance = Absolute tolerance.

Preconditions:
    `ax` and `bx` shall be finite reals. $(BR)
    `relTolerance` shall be normal positive real. $(BR)
    `absTolerance` shall be normal positive real no less than `T.epsilon*2`.

Returns:
    A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`.

    The method used is a combination of golden section search and
successive parabolic interpolation. Convergence is never much slower
than that for a Fibonacci search.

References:
    "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973)

See_Also: $(LREF findRoot), $(REF isNormal, std,math)
+/
Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error")
findLocalMin(T, DF)(
        scope DF f,
        in T ax,
        in T bx,
        in T relTolerance = sqrt(T.epsilon),
        in T absTolerance = sqrt(T.epsilon),
        )
if (isFloatingPoint!T
    && __traits(compiles, {T _ = DF.init(T.init);}))
in
{
    assert(isFinite(ax), "ax is not finite");
    assert(isFinite(bx), "bx is not finite");
    assert(isNormal(relTolerance), "relTolerance is not normal floating point number");
    assert(isNormal(absTolerance), "absTolerance is not normal floating point number");
    // NOTE(review): the message below mentions absTolerance, but the condition
    // actually checks relTolerance.
    assert(relTolerance >= 0, "absTolerance is not positive");
    assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`");
}
out (result)
{
    assert(isFinite(result.x));
}
do
{
    alias R = Unqual!(CommonType!(ReturnType!DF, T));
    // c is the squared inverse of the golden ratio
    // (3 - sqrt(5))/2
    // Value obtained from Wolfram Alpha.
    enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L;
    // cm1 is numerically 1 - c (approximately 0.618, the golden ratio conjugate).
    enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L;
    R tolerance;
    T a = ax > bx ? bx : ax;
    T b = ax > bx ? ax : bx;
    // sequence of declarations suitable for SIMD instructions
    T v = a * cm1 + b * c;
    assert(isFinite(v));
    R fv = f(v);
    if (isNaN(fv) || fv == -T.infinity)
    {
        return typeof(return)(v, fv, T.init);
    }
    T w = v;
    R fw = fv;
    T x = v;
    R fx = fv;
    size_t i;
    for (R d = 0, e = 0;;)
    {
        i++;
        T m = (a + b) / 2;
        // This fix is not part of the original algorithm
        if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R.
        {
            m = a / 2 + b / 2;
            if (!isFinite(m)) // fast-math compiler switch is enabled
            {
                //SIMD instructions can be used by compiler, do not reduce declarations
                int a_exp = void;
                int b_exp = void;
                immutable an = frexp(a, a_exp);
                immutable bn = frexp(b, b_exp);
                immutable am = ldexp(an, a_exp-1);
                immutable bm = ldexp(bn, b_exp-1);
                m = am + bm;
                if (!isFinite(m)) // wrong input: constraints are disabled in release mode
                {
                    return typeof(return).init;
                }
            }
        }
        tolerance = absTolerance * fabs(x) + relTolerance;
        immutable t2 = tolerance * 2;
        // check stopping criterion
        if (!(fabs(x - m) > t2 - (b - a) / 2))
        {
            break;
        }
        R p = 0;
        R q = 0;
        R r = 0;
        // fit parabola
        if (fabs(e) > tolerance)
        {
            immutable  xw =  x -  w;
            immutable fxw = fx - fw;
            immutable  xv =  x -  v;
            immutable fxv = fx - fv;
            immutable xwfxv = xw * fxv;
            immutable xvfxw = xv * fxw;
            p = xv * xvfxw - xw * xwfxv;
            q = (xvfxw - xwfxv) * 2;
            if (q > 0)
                p = -p;
            else
                q = -q;
            r = e;
            e = d;
        }
        T u;
        // a parabolic-interpolation step
        if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x))
        {
            d = p / q;
            u = x + d;
            // f must not be evaluated too close to a or b
            if (u - a < t2 || b - u < t2)
                d = x < m ? tolerance : -tolerance;
        }
        // a golden-section step
        else
        {
            e = (x < m ? b : a) - x;
            d = c * e;
        }
        // f must not be evaluated too close to x
        u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance);
        immutable fu = f(u);
        if (isNaN(fu) || fu == -T.infinity)
        {
            return typeof(return)(u, fu, T.init);
        }
        // update a, b, v, w, and x
        if (fu <= fx)
        {
            (u < x ? b : a) = x;
            v = w; fv = fw;
            w = x; fw = fx;
            x = u; fx = fu;
        }
        else
        {
            (u < x ? a : b) = u;
            if (fu <= fw || w == x)
            {
                v = w; fv = fw;
                w = u; fw = fu;
            }
            else if (fu <= fv || v == x || v == w)
            { // do not remove this braces
                v = u; fv = fu;
            }
        }
    }
    return typeof(return)(x, fx, tolerance * 3);
}

///
@safe unittest
{
    import std.math : approxEqual;

    auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7);
    assert(ret.x.approxEqual(4.0));
    assert(ret.y.approxEqual(0.0));
}

@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, float, real))
    {
        {
            auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7);
            assert(ret.x.approxEqual(T(4)));
            assert(ret.y.approxEqual(T(0)));
        }
        {
            auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon);
            assert(approxEqual(ret.x, T(1)));
            assert(approxEqual(ret.y, T(0)));
            assert(ret.error <= 10 * T.epsilon);
        }
        {
            // NaN-returning f short-circuits with a NaN error.
            auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon);
            assert(!ret.x.isNaN);
            assert(ret.y.isNaN);
            assert(ret.error.isNaN);
        }
        {
            auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon);
            assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x)+ T.min_normal));
            assert(ret.x >= 0 && ret.x <= ret.error);
        }
        {
            auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon);
            assert(ret.y < -18);
            assert(ret.error < 5e-08);
            assert(ret.x >= 0 && ret.x <= ret.error);
        }
        {
            auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon);
            assert(ret.x.fabs.approxEqual(T(1)));
            assert(ret.y.fabs.approxEqual(T(1)));
            assert(ret.error.approxEqual(T(0)));
        }
    }
}

/**
Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance,
Euclidean distance) between input
ranges `a` and
`b`. The two ranges must have the same length. The three-parameter
version stops computation as soon as the distance is greater than or
equal to `limit` (this is useful to save computation if a small
distance is sought).
 */
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    // If both ranges know their length, verify agreement once up front;
    // otherwise the only possible check is that both end together (below).
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) sumOfSquares = 0;
    while (!a.empty)
    {
        immutable diff = a.front - b.front;
        sumOfSquares += diff * diff;
        a.popFront();
        b.popFront();
    }
    static if (!haveLen) assert(b.empty);
    return sqrt(sumOfSquares);
}

/// Ditto
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    // Square the limit once so the loop compares squared sums and takes
    // the square root only at the very end.
    limit *= limit;
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) sumOfSquares = 0;
    while (true)
    {
        if (a.empty)
        {
            static if (!haveLen) assert(b.empty);
            break;
        }
        immutable diff = a.front - b.front;
        sumOfSquares += diff * diff;
        // Early exit: distance already at or beyond the caller's limit.
        if (sumOfSquares >= limit) break;
        a.popFront();
        b.popFront();
    }
    return sqrt(sumOfSquares);
}

@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        T[] a = [ 1.0, 2.0, ];
        T[] b = [ 4.0, 6.0, ];
        assert(euclideanDistance(a, b) == 5);
        assert(euclideanDistance(a, b, 6) == 5);
        assert(euclideanDistance(a, b, 5) == 5);
        assert(euclideanDistance(a, b, 4) == 5);
        assert(euclideanDistance(a, b, 2) == 3);
    }}
}

/**
Computes the $(LINK2
https://en.wikipedia.org/wiki/Dot_product,
dot product) of input ranges `a` and $(D
b). The two ranges must have the same length. If both ranges define
length, the check is done once; otherwise, it is done at each
iteration.
 */
CommonType!(ElementType!(Range1), ElementType!(Range2))
dotProduct(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2) &&
    !(isArray!(Range1) && isArray!(Range2)))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) result = 0;
    for (; !a.empty; a.popFront(), b.popFront())
    {
        result += a.front * b.front;
    }
    static if (!haveLen) assert(b.empty);
    return result;
}

/// Ditto
CommonType!(F1, F2)
dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
{
    immutable n = avector.length;
    assert(n == bvector.length);
    auto avec = avector.ptr, bvec = bvector.ptr;
    Unqual!(typeof(return)) sum0 = 0, sum1 = 0;

    // Three sentinels: a 16-element unrolled block, a 4-element unrolled
    // block, and a naive tail for the remaining 0..3 elements.
    const all_endp = avec + n;
    const smallblock_endp = avec + (n & ~3);
    const bigblock_endp = avec + (n & ~15);

    // 16-way unrolled main loop; the two accumulators sum0/sum1 are
    // interleaved across alternating elements.
    for (; avec != bigblock_endp; avec += 16, bvec += 16)
    {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
        sum0 += avec[4] * bvec[4];
        sum1 += avec[5] * bvec[5];
        sum0 += avec[6] * bvec[6];
        sum1 += avec[7] * bvec[7];
        sum0 += avec[8] * bvec[8];
        sum1 += avec[9] * bvec[9];
        sum0 += avec[10] * bvec[10];
        sum1 += avec[11] * bvec[11];
        sum0 += avec[12] * bvec[12];
        sum1 += avec[13] * bvec[13];
        sum0 += avec[14] * bvec[14];
        sum1 += avec[15] * bvec[15];
    }

    for (; avec != smallblock_endp; avec += 4, bvec += 4)
    {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
    }

    sum0 += sum1;

    /* Do trailing portion in naive loop. */
    while (avec != all_endp)
    {
        sum0 += *avec * *bvec;
        ++avec;
        ++bvec;
    }

    return sum0;
}

/// ditto
F dotProduct(F, uint N)(const ref scope F[N] a, const ref scope F[N] b)
if (N <= 16)
{
    // Fully unrolled at compile time for small fixed-size arrays.
    F sum0 = 0;
    F sum1 = 0;
    static foreach (i; 0 .. N / 2)
    {
        sum0 += a[i*2] * b[i*2];
        sum1 += a[i*2+1] * b[i*2+1];
    }
    static if (N % 2 == 1)
    {
        sum0 += a[N-1] * b[N-1];
    }
    return sum0 + sum1;
}

@system unittest
{
    // @system due to dotProduct and assertCTFEable
    import std.exception : assertCTFEable;
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        T[] a = [ 1.0, 2.0, ];
        T[] b = [ 4.0, 6.0, ];
        assert(dotProduct(a, b) == 16);
        assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);
        // Test with fixed-length arrays.
        T[2] c = [ 1.0, 2.0, ];
        T[2] d = [ 4.0, 6.0, ];
        assert(dotProduct(c, d) == 16);
        T[3] e = [1, 3, -5];
        T[3] f = [4, -2, -1];
        assert(dotProduct(e, f) == 3);
    }}

    // Make sure the unrolled loop codepath gets tested.
    static const x =
        [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22];
    static const y =
        [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23];
    assertCTFEable!({ assert(dotProduct(x, y) == 4048); });
}

/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity,
cosine similarity) of input ranges `a` and $(D
b). The two ranges must have the same length. If both ranges define
length, the check is done once; otherwise, it is done at each
iteration. If either range has all-zero elements, return 0.
 */
CommonType!(ElementType!(Range1), ElementType!(Range2))
cosineSimilarity(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    // A single pass accumulates both squared norms and the dot product.
    Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0;
    for (; !a.empty; a.popFront(), b.popFront())
    {
        immutable t1 = a.front, t2 = b.front;
        norma += t1 * t1;
        normb += t2 * t2;
        dotprod += t1 * t2;
    }
    static if (!haveLen) assert(b.empty);
    // A zero vector has no direction: return 0 rather than dividing by zero.
    if (norma == 0 || normb == 0) return 0;
    return dotprod / sqrt(norma * normb);
}

@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        T[] a = [ 1.0, 2.0, ];
        T[] b = [ 4.0, 3.0, ];
        assert(approxEqual(
                   cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25),
                   0.01));
    }}
}

/**
Normalizes values in `range` by multiplying each element with a
number chosen such that values sum up to `sum`. If elements in $(D
range) sum to zero, assigns $(D sum / range.length) to
all. Normalization makes sense only if all elements in `range` are
positive. `normalize` assumes that is the case without checking it.

Returns: `true` if normalization completed normally, `false` if
all elements in `range` were zero or if `range` is empty.
2157 */ 2158 bool normalize(R)(R range, ElementType!(R) sum = 1) 2159 if (isForwardRange!(R)) 2160 { 2161 ElementType!(R) s = 0; 2162 // Step 1: Compute sum and length of the range 2163 static if (hasLength!(R)) 2164 { 2165 const length = range.length; 2166 foreach (e; range) 2167 { 2168 s += e; 2169 } 2170 } 2171 else 2172 { 2173 uint length = 0; 2174 foreach (e; range) 2175 { 2176 s += e; 2177 ++length; 2178 } 2179 } 2180 // Step 2: perform normalization 2181 if (s == 0) 2182 { 2183 if (length) 2184 { 2185 immutable f = sum / range.length; 2186 foreach (ref e; range) e = f; 2187 } 2188 return false; 2189 } 2190 // The path most traveled 2191 assert(s >= 0); 2192 immutable f = sum / s; 2193 foreach (ref e; range) 2194 e *= f; 2195 return true; 2196 } 2197 2198 /// 2199 @safe unittest 2200 { 2201 double[] a = []; 2202 assert(!normalize(a)); 2203 a = [ 1.0, 3.0 ]; 2204 assert(normalize(a)); 2205 assert(a == [ 0.25, 0.75 ]); 2206 assert(normalize!(typeof(a))(a, 50)); // a = [12.5, 37.5] 2207 a = [ 0.0, 0.0 ]; 2208 assert(!normalize(a)); 2209 assert(a == [ 0.5, 0.5 ]); 2210 } 2211 2212 /** 2213 Compute the sum of binary logarithms of the input range `r`. 2214 The error of this method is much smaller than with a naive sum of log2. 
2215 */ 2216 ElementType!Range sumOfLog2s(Range)(Range r) 2217 if (isInputRange!Range && isFloatingPoint!(ElementType!Range)) 2218 { 2219 long exp = 0; 2220 Unqual!(typeof(return)) x = 1; 2221 foreach (e; r) 2222 { 2223 if (e < 0) 2224 return typeof(return).nan; 2225 int lexp = void; 2226 x *= frexp(e, lexp); 2227 exp += lexp; 2228 if (x < 0.5) 2229 { 2230 x *= 2; 2231 exp--; 2232 } 2233 } 2234 return exp + log2(x); 2235 } 2236 2237 /// 2238 @safe unittest 2239 { 2240 import std.math : isNaN; 2241 2242 assert(sumOfLog2s(new double[0]) == 0); 2243 assert(sumOfLog2s([0.0L]) == -real.infinity); 2244 assert(sumOfLog2s([-0.0L]) == -real.infinity); 2245 assert(sumOfLog2s([2.0L]) == 1); 2246 assert(sumOfLog2s([-2.0L]).isNaN()); 2247 assert(sumOfLog2s([real.nan]).isNaN()); 2248 assert(sumOfLog2s([-real.nan]).isNaN()); 2249 assert(sumOfLog2s([real.infinity]) == real.infinity); 2250 assert(sumOfLog2s([-real.infinity]).isNaN()); 2251 assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9); 2252 } 2253 2254 /** 2255 Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory), 2256 _entropy) of input range `r` in bits. This 2257 function assumes (without checking) that the values in `r` are all 2258 in $(D [0, 1]). For the entropy to be meaningful, often `r` should 2259 be normalized too (i.e., its values should sum to 1). The 2260 two-parameter version stops evaluating as soon as the intermediate 2261 result is greater than or equal to `max`. 
2262 */ 2263 ElementType!Range entropy(Range)(Range r) 2264 if (isInputRange!Range) 2265 { 2266 Unqual!(typeof(return)) result = 0.0; 2267 for (;!r.empty; r.popFront) 2268 { 2269 if (!r.front) continue; 2270 result -= r.front * log2(r.front); 2271 } 2272 return result; 2273 } 2274 2275 /// Ditto 2276 ElementType!Range entropy(Range, F)(Range r, F max) 2277 if (isInputRange!Range && 2278 !is(CommonType!(ElementType!Range, F) == void)) 2279 { 2280 Unqual!(typeof(return)) result = 0.0; 2281 for (;!r.empty; r.popFront) 2282 { 2283 if (!r.front) continue; 2284 result -= r.front * log2(r.front); 2285 if (result >= max) break; 2286 } 2287 return result; 2288 } 2289 2290 @safe unittest 2291 { 2292 import std.meta : AliasSeq; 2293 static foreach (T; AliasSeq!(double, const double, immutable double)) 2294 {{ 2295 T[] p = [ 0.0, 0, 0, 1 ]; 2296 assert(entropy(p) == 0); 2297 p = [ 0.25, 0.25, 0.25, 0.25 ]; 2298 assert(entropy(p) == 2); 2299 assert(entropy(p, 1) == 1); 2300 }} 2301 } 2302 2303 /** 2304 Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence, 2305 Kullback-Leibler divergence) between input ranges 2306 `a` and `b`, which is the sum $(D ai * log(ai / bi)). The base 2307 of logarithm is 2. The ranges are assumed to contain elements in $(D 2308 [0, 1]). Usually the ranges are normalized probability distributions, 2309 but this is not required or checked by $(D 2310 kullbackLeiblerDivergence). If any element `bi` is zero and the 2311 corresponding element `ai` nonzero, returns infinity. (Otherwise, 2312 if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is 2313 considered zero.) If the inputs are normalized, the result is 2314 positive. 
2315 */ 2316 CommonType!(ElementType!Range1, ElementType!Range2) 2317 kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b) 2318 if (isInputRange!(Range1) && isInputRange!(Range2)) 2319 { 2320 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2321 static if (haveLen) assert(a.length == b.length); 2322 Unqual!(typeof(return)) result = 0; 2323 for (; !a.empty; a.popFront(), b.popFront()) 2324 { 2325 immutable t1 = a.front; 2326 if (t1 == 0) continue; 2327 immutable t2 = b.front; 2328 if (t2 == 0) return result.infinity; 2329 assert(t1 > 0 && t2 > 0); 2330 result += t1 * log2(t1 / t2); 2331 } 2332 static if (!haveLen) assert(b.empty); 2333 return result; 2334 } 2335 2336 /// 2337 @safe unittest 2338 { 2339 import std.math : approxEqual; 2340 2341 double[] p = [ 0.0, 0, 0, 1 ]; 2342 assert(kullbackLeiblerDivergence(p, p) == 0); 2343 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2344 assert(kullbackLeiblerDivergence(p1, p1) == 0); 2345 assert(kullbackLeiblerDivergence(p, p1) == 2); 2346 assert(kullbackLeiblerDivergence(p1, p) == double.infinity); 2347 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2348 assert(approxEqual(kullbackLeiblerDivergence(p1, p2), 0.0719281)); 2349 assert(approxEqual(kullbackLeiblerDivergence(p2, p1), 0.0780719)); 2350 } 2351 2352 /** 2353 Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence, 2354 Jensen-Shannon divergence) between `a` and $(D 2355 b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * 2356 bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are 2357 assumed to contain elements in $(D [0, 1]). Usually the ranges are 2358 normalized probability distributions, but this is not required or 2359 checked by `jensenShannonDivergence`. If the inputs are normalized, 2360 the result is bounded within $(D [0, 1]). The three-parameter version 2361 stops evaluations as soon as the intermediate result is greater than 2362 or equal to `limit`. 
2363 */ 2364 CommonType!(ElementType!Range1, ElementType!Range2) 2365 jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b) 2366 if (isInputRange!Range1 && isInputRange!Range2 && 2367 is(CommonType!(ElementType!Range1, ElementType!Range2))) 2368 { 2369 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2370 static if (haveLen) assert(a.length == b.length); 2371 Unqual!(typeof(return)) result = 0; 2372 for (; !a.empty; a.popFront(), b.popFront()) 2373 { 2374 immutable t1 = a.front; 2375 immutable t2 = b.front; 2376 immutable avg = (t1 + t2) / 2; 2377 if (t1 != 0) 2378 { 2379 result += t1 * log2(t1 / avg); 2380 } 2381 if (t2 != 0) 2382 { 2383 result += t2 * log2(t2 / avg); 2384 } 2385 } 2386 static if (!haveLen) assert(b.empty); 2387 return result / 2; 2388 } 2389 2390 /// Ditto 2391 CommonType!(ElementType!Range1, ElementType!Range2) 2392 jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit) 2393 if (isInputRange!Range1 && isInputRange!Range2 && 2394 is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init 2395 >= F.init) : bool)) 2396 { 2397 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2398 static if (haveLen) assert(a.length == b.length); 2399 Unqual!(typeof(return)) result = 0; 2400 limit *= 2; 2401 for (; !a.empty; a.popFront(), b.popFront()) 2402 { 2403 immutable t1 = a.front; 2404 immutable t2 = b.front; 2405 immutable avg = (t1 + t2) / 2; 2406 if (t1 != 0) 2407 { 2408 result += t1 * log2(t1 / avg); 2409 } 2410 if (t2 != 0) 2411 { 2412 result += t2 * log2(t2 / avg); 2413 } 2414 if (result >= limit) break; 2415 } 2416 static if (!haveLen) assert(b.empty); 2417 return result / 2; 2418 } 2419 2420 /// 2421 @safe unittest 2422 { 2423 import std.math : approxEqual; 2424 2425 double[] p = [ 0.0, 0, 0, 1 ]; 2426 assert(jensenShannonDivergence(p, p) == 0); 2427 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2428 assert(jensenShannonDivergence(p1, p1) == 0); 2429 assert(approxEqual(jensenShannonDivergence(p1, p), 
0.548795)); 2430 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2431 assert(approxEqual(jensenShannonDivergence(p1, p2), 0.0186218)); 2432 assert(approxEqual(jensenShannonDivergence(p2, p1), 0.0186218)); 2433 assert(approxEqual(jensenShannonDivergence(p2, p1, 0.005), 0.00602366)); 2434 } 2435 2436 /** 2437 The so-called "all-lengths gap-weighted string kernel" computes a 2438 similarity measure between `s` and `t` based on all of their 2439 common subsequences of all lengths. Gapped subsequences are also 2440 included. 2441 2442 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes, 2443 consider first the case $(D lambda = 1) and the strings $(D s = 2444 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new", 2445 "world"]). In that case, `gapWeightedSimilarity` counts the 2446 following matches: 2447 2448 $(OL $(LI three matches of length 1, namely `"Hello"`, `"new"`, 2449 and `"world"`;) $(LI three matches of length 2, namely ($(D 2450 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));) 2451 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).)) 2452 2453 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of 2454 these matches and adds them up, returning 7. 2455 2456 ---- 2457 string[] s = ["Hello", "brave", "new", "world"]; 2458 string[] t = ["Hello", "new", "world"]; 2459 assert(gapWeightedSimilarity(s, t, 1) == 7); 2460 ---- 2461 2462 Note how the gaps in matching are simply ignored, for example ($(D 2463 "Hello", "new")) is deemed as good a match as ($(D "new", 2464 "world")). This may be too permissive for some applications. 
To 2465 eliminate gapped matches entirely, use $(D lambda = 0): 2466 2467 ---- 2468 string[] s = ["Hello", "brave", "new", "world"]; 2469 string[] t = ["Hello", "new", "world"]; 2470 assert(gapWeightedSimilarity(s, t, 0) == 4); 2471 ---- 2472 2473 The call above eliminated the gapped matches ($(D "Hello", "new")), 2474 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the 2475 tally. That leaves only 4 matches. 2476 2477 The most interesting case is when gapped matches still participate in 2478 the result, but not as strongly as ungapped matches. The result will 2479 be a smooth, fine-grained similarity measure between the input 2480 strings. This is where values of `lambda` between 0 and 1 enter 2481 into play: gapped matches are $(I exponentially penalized with the 2482 number of gaps) with base `lambda`. This means that an ungapped 2483 match adds 1 to the return value; a match with one gap in either 2484 string adds `lambda` to the return value; ...; a match with a total 2485 of `n` gaps in both strings adds $(D pow(lambda, n)) to the return 2486 value. In the example above, we have 4 matches without gaps, 2 matches 2487 with one gap, and 1 match with three gaps. The latter match is ($(D 2488 "Hello", "world")), which has two gaps in the first string and one gap 2489 in the second string, totaling to three gaps. Summing these up we get 2490 $(D 4 + 2 * lambda + pow(lambda, 3)). 2491 2492 ---- 2493 string[] s = ["Hello", "brave", "new", "world"]; 2494 string[] t = ["Hello", "new", "world"]; 2495 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125); 2496 ---- 2497 2498 `gapWeightedSimilarity` is useful wherever a smooth similarity 2499 measure between sequences allowing for approximate matches is 2500 needed. The examples above are given with words, but any sequences 2501 with elements comparable for equality are allowed, e.g. characters or 2502 numbers. 
`gapWeightedSimilarity` uses a highly optimized dynamic
programming implementation that needs $(D 16 * min(s.length,
t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time
to complete.
*/
F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda)
if (isRandomAccessRange!(R1) && hasLength!(R1) &&
    isRandomAccessRange!(R2) && hasLength!(R2))
{
    import core.exception : onOutOfMemoryError;
    import core.stdc.stdlib : malloc, free;
    import std.algorithm.mutation : swap;
    import std.functional : binaryFun;

    // Always iterate with the shorter range as `t` so the two DP rows
    // allocated below are as small as possible.
    if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda);
    if (!t.length) return 0;

    // One allocation holds both DP rows: dpvi is row i, dpvi1 is row i+1.
    auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length);
    if (!dpvi)
        onOutOfMemoryError();

    auto dpvi1 = dpvi + t.length;
    // The rows are swapped each outer iteration, so at scope exit the
    // original malloc base is whichever pointer compares smaller.
    scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1);
    dpvi[0 .. t.length] = 0;
    dpvi1[0] = 0;
    immutable lambda2 = lambda * lambda;

    F result = 0;
    foreach (i; 0 .. s.length)
    {
        const si = s[i];
        for (size_t j = 0;;)
        {
            F dpsij = void;
            if (binaryFun!(comp)(si, t[j]))
            {
                // A match extends every subsequence ending at (i-1, j-1).
                dpsij = 1 + dpvi[j];
                result += dpsij;
            }
            else
            {
                dpsij = 0;
            }
            immutable j1 = j + 1;
            if (j1 == t.length) break;
            // Recurrence from the Rousu et al. gap-weighted kernel:
            // inclusion-exclusion over the row above and the cell to the left.
            dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) -
                        lambda2 * dpvi[j];
            j = j1;
        }
        // Roll the DP rows: row i+1 becomes row i for the next iteration.
        swap(dpvi, dpvi1);
    }
    return result;
}

@system unittest
{
    string[] s = ["Hello", "brave", "new", "world"];
    string[] t = ["Hello", "new", "world"];
    assert(gapWeightedSimilarity(s, t, 1) == 7);
    assert(gapWeightedSimilarity(s, t, 0) == 4);
    assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125);
}

/**
The similarity per `gapWeightedSimilarity` has an issue in that it
grows with the lengths of the two strings, even though the strings are
not actually very similar.
For example, the range $(D ["Hello",
"world"]) is increasingly similar with the range $(D ["Hello",
"world", "world", "world",...]) as more instances of `"world"` are
appended. To prevent that, `gapWeightedSimilarityNormalized`
computes a normalized version of the similarity that is computed as
$(D gapWeightedSimilarity(s, t, lambda) /
sqrt(gapWeightedSimilarity(s, s, lambda) * gapWeightedSimilarity(t, t,
lambda))). The function `gapWeightedSimilarityNormalized` (a
so-called normalized kernel) is bounded in $(D [0, 1]), reaches `0`
only for ranges that don't match in any position, and `1` only for
identical ranges.

The optional parameters `sSelfSim` and `tSelfSim` are meant for
avoiding duplicate computation. Many applications may have already
computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D
gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed
as `sSelfSim` and `tSelfSim`, respectively.
2585 */ 2586 Select!(isFloatingPoint!(F), F, double) 2587 gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F) 2588 (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init) 2589 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2590 isRandomAccessRange!(R2) && hasLength!(R2)) 2591 { 2592 static bool uncomputed(F n) 2593 { 2594 static if (isFloatingPoint!(F)) 2595 return isNaN(n); 2596 else 2597 return n == n.init; 2598 } 2599 if (uncomputed(sSelfSim)) 2600 sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda); 2601 if (sSelfSim == 0) return 0; 2602 if (uncomputed(tSelfSim)) 2603 tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda); 2604 if (tSelfSim == 0) return 0; 2605 2606 return gapWeightedSimilarity!(comp)(s, t, lambda) / 2607 sqrt(cast(typeof(return)) sSelfSim * tSelfSim); 2608 } 2609 2610 /// 2611 @system unittest 2612 { 2613 import std.math : approxEqual, sqrt; 2614 2615 string[] s = ["Hello", "brave", "new", "world"]; 2616 string[] t = ["Hello", "new", "world"]; 2617 assert(gapWeightedSimilarity(s, s, 1) == 15); 2618 assert(gapWeightedSimilarity(t, t, 1) == 7); 2619 assert(gapWeightedSimilarity(s, t, 1) == 7); 2620 assert(approxEqual(gapWeightedSimilarityNormalized(s, t, 1), 2621 7.0 / sqrt(15.0 * 7), 0.01)); 2622 } 2623 2624 /** 2625 Similar to `gapWeightedSimilarity`, just works in an incremental 2626 manner by first revealing the matches of length 1, then gapped matches 2627 of length 2, and so on. The memory requirement is $(BIGOH s.length * 2628 t.length). The time complexity is $(BIGOH s.length * t.length) time 2629 for computing each step. Continuing on the previous example: 2630 2631 The implementation is based on the pseudocode in Fig. 4 of the paper 2632 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf, 2633 "Efficient Computation of Gapped Substring Kernels on Large Alphabets") 2634 by Rousu et al., with additional algorithmic and systems-level 2635 optimizations. 
 */
struct GapWeightedSimilarityIncremental(Range, F = double)
if (isRandomAccessRange!(Range) && hasLength!(Range))
{
    import core.stdc.stdlib : malloc, realloc, alloca, free;

private:
    // The (possibly trimmed) input ranges; see the constructor.
    Range s, t;
    // Similarity at the current match length; 0 signals exhaustion.
    F currentValue = 0;
    // s.length x t.length DP matrix of the previous step, malloc'ed;
    // null once `empty` has released it.
    F* kl;
    // Current match length minus one.
    size_t gram = void;
    F lambda = void, lambda2 = void;

public:
    /**
    Constructs an object given two ranges `s` and `t` and a penalty
    `lambda`. Constructor completes in $(BIGOH s.length * t.length)
    time and computes all matches of length 1.
    */
    this(Range s, Range t, F lambda)
    {
        import core.exception : onOutOfMemoryError;

        assert(lambda > 0);
        this.gram = 0;
        this.lambda = lambda;
        this.lambda2 = lambda * lambda; // for efficiency only

        size_t iMin = size_t.max, jMin = size_t.max,
            iMax = 0, jMax = 0;
        // Collect the (i, j) coordinates of every length-1 match, growing
        // the scratch array on demand.
        // NOTE(review): the realloc result is not checked for null (unlike
        // the malloc below) — on OOM the next write would dereference null.
        Tuple!(size_t, size_t) * k0;
        size_t k0len;
        scope(exit) free(k0);
        currentValue = 0;
        foreach (i, si; s)
        {
            foreach (j; 0 .. t.length)
            {
                if (si != t[j]) continue;
                k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof);
                with (k0[k0len - 1])
                {
                    field[0] = i;
                    field[1] = j;
                }
                // Maintain the minimum and maximum i and j
                if (iMin > i) iMin = i;
                if (iMax < i) iMax = i;
                if (jMin > j) jMin = j;
                if (jMax < j) jMax = j;
            }
        }

        // No matches at all: leave currentValue at 0, so `empty` is true.
        if (iMin > iMax) return;
        assert(k0len);

        // The length-1 similarity is simply the number of matches.
        currentValue = k0len;
        // Chop strings down to the useful sizes
        s = s[iMin .. iMax + 1];
        t = t[jMin .. jMax + 1];
        this.s = s;
        this.t = t;

        kl = cast(F*) malloc(s.length * t.length * F.sizeof);
        if (!kl)
            onOutOfMemoryError();

        // Seed the DP matrix: lambda2 at every length-1 match position
        // (one lambda factor per side of the match).
        kl[0 .. s.length * t.length] = 0;
        foreach (pos; 0 .. k0len)
        {
            with (k0[pos])
            {
                kl[(field[0] - iMin) * t.length + field[1] - jMin] = lambda2;
            }
        }
    }

    /**
    Returns: `this`.
    */
    ref GapWeightedSimilarityIncremental opSlice()
    {
        return this;
    }

    /**
    Computes the match of the next length. Completes in $(BIGOH s.length *
    t.length) time.
    */
    void popFront()
    {
        import std.algorithm.mutation : swap;

        // This is a large source of optimization: if similarity at
        // the gram-1 level was 0, then we can safely assume
        // similarity at the gram level is 0 as well.
        if (empty) return;

        // Now attempt to match gapped substrings of length `gram'
        ++gram;
        currentValue = 0;

        // Suffix-sum row S of the Rousu et al. recurrence, stack-allocated.
        auto Si = cast(F*) alloca(t.length * F.sizeof);
        Si[0 .. t.length] = 0;
        foreach (i; 0 .. s.length)
        {
            const si = s[i];
            F Sij_1 = 0;
            F Si_1j_1 = 0;
            auto kli = kl + i * t.length;
            for (size_t j = 0;;)
            {
                const klij = kli[j];
                const Si_1j = Si[j];
                // Inclusion-exclusion over the cell above, the cell to the
                // left, and the diagonal.
                const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1;
                // now update kl and currentValue
                if (si == t[j])
                    currentValue += kli[j] = lambda2 * Si_1j_1;
                else
                    kli[j] = 0;
                // commit to Si
                Si[j] = tmp;
                if (++j == t.length) break;
                // get ready for the next step; virtually increment j,
                // so essentially stuffj_1 <-- stuffj
                Si_1j_1 = Si_1j;
                Sij_1 = tmp;
            }
        }
        // Undo the lambda factors accumulated per matched endpoint.
        currentValue /= pow(lambda, 2 * (gram + 1));

        // Dead alternative implementation kept for reference; it references
        // symbols (Si_1, kl_1, maxPerimeter) that no longer exist.
        version (none)
        {
            Si_1[0 .. t.length] = 0;
            kl[0 .. min(t.length, maxPerimeter + 1)] = 0;
            foreach (i; 1 .. min(s.length, maxPerimeter + 1))
            {
                auto kli = kl + i * t.length;
                assert(s.length > i);
                const si = s[i];
                auto kl_1i_1 = kl_1 + (i - 1) * t.length;
                kli[0] = 0;
                F lastS = 0;
                foreach (j; 1 .. min(maxPerimeter - i + 1, t.length))
                {
                    immutable j_1 = j - 1;
                    immutable tmp = kl_1i_1[j_1]
                        + lambda * (Si_1[j] + lastS)
                        - lambda2 * Si_1[j_1];
                    kl_1i_1[j_1] = float.nan;
                    Si_1[j_1] = lastS;
                    lastS = tmp;
                    if (si == t[j])
                    {
                        currentValue += kli[j] = lambda2 * lastS;
                    }
                    else
                    {
                        kli[j] = 0;
                    }
                }
                Si_1[t.length - 1] = lastS;
            }
            currentValue /= pow(lambda, 2 * (gram + 1));
            // get ready for the next computation
            swap(kl, kl_1);
        }
    }

    /**
    Returns: The gapped similarity at the current match length (initially
    1, grows with each call to `popFront`).
    */
    @property F front() { return currentValue; }

    /**
    Returns: Whether there are more matches.

    Note: this property has a side effect — the first time it observes a
    zero `currentValue` it frees the internal DP matrix.
    */
    @property bool empty()
    {
        if (currentValue) return false;
        if (kl)
        {
            free(kl);
            kl = null;
        }
        return true;
    }
}

/**
Ditto
*/
GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F)
(R r1, R r2, F penalty)
{
    return typeof(return)(r1, r2, penalty);
}

///
@system unittest
{
    string[] s = ["Hello", "brave", "new", "world"];
    string[] t = ["Hello", "new", "world"];
    auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0);
    assert(simIter.front == 3); // three 1-length matches
    simIter.popFront();
    assert(simIter.front == 3); // three 2-length matches
    simIter.popFront();
    assert(simIter.front == 1); // one 3-length match
    simIter.popFront();
    assert(simIter.empty);     // no more match
}

@system unittest
{
    import std.conv : text;
    string[] s = ["Hello", "brave", "new", "world"];
    string[] t = ["Hello", "new", "world"];
    auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0);
    //foreach (e; simIter) writeln(e);
    assert(simIter.front == 3); // three 1-length matches
    simIter.popFront();
    assert(simIter.front == 3, text(simIter.front)); // three 2-length matches
    simIter.popFront();
    assert(simIter.front == 1); // one 3-length matches
    simIter.popFront();
    assert(simIter.empty);     // no more match

    s = ["Hello"];
    t = ["bye"];
    simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(simIter.empty);

    s = ["Hello"];
    t = ["Hello"];
    simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(simIter.front == 1); // one match
    simIter.popFront();
    assert(simIter.empty);

    s = ["Hello", "world"];
    t = ["Hello"];
    simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(simIter.front == 1); // one match
    simIter.popFront();
    assert(simIter.empty);

    s = ["Hello", "world"];
    t = ["Hello", "yah", "world"];
    simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(simIter.front == 2); // two 1-gram matches
    simIter.popFront();
    assert(simIter.front == 0.5, text(simIter.front)); // one 2-gram match, 1 gap
}

@system unittest
{
    GapWeightedSimilarityIncremental!(string[]) sim =
        GapWeightedSimilarityIncremental!(string[])(
            ["nyuk", "I", "have", "no", "chocolate", "giba"],
            ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"],
            0.5);
    double[] witness = [ 7.0, 4.03125, 0, 0 ];
    foreach (e; sim)
    {
        //writeln(e);
        assert(e == witness.front);
        witness.popFront();
    }
    witness = [ 3.0, 1.3125, 0.25 ];
    sim = GapWeightedSimilarityIncremental!(string[])(
        ["I", "have", "no", "chocolate"],
        ["I", "have", "some", "chocolate"],
        0.5);
    foreach (e; sim)
    {
        //writeln(e);
        assert(e == witness.front);
        witness.popFront();
    }
    assert(witness.empty);
}

/**
Computes the greatest common divisor of `a` and `b` by using
an efficient algorithm such as
$(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's)
or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm.

Params:
    T = Any numerical type that supports the modulo operator `%`. If
    bit-shifting `<<` and `>>` are also supported, Stein's algorithm will
    be used; otherwise, Euclid's algorithm is used as _a fallback.
Returns:
    The greatest common divisor of the given arguments.
 */
T gcd(T)(T a, T b)
if (isIntegral!T)
{
    static if (is(T == const) || is(T == immutable))
    {
        // Strip the qualifier once and compute on the mutable type.
        return gcd!(Unqual!T)(a, b);
    }
    else version (DigitalMars)
    {
        // Euclidean algorithm.
        static if (T.min < 0)
        {
            assert(a >= 0 && b >= 0);
        }
        while (b)
        {
            immutable t = b;
            b = a % b;
            a = t;
        }
        return a;
    }
    else
    {
        // Stein's (binary) algorithm. Like the Euclid branch above, it is
        // only valid for non-negative inputs: bsf and the shifts below
        // silently produce garbage for negative signed values, so assert
        // here for consistency with the other branch.
        static if (T.min < 0)
        {
            assert(a >= 0 && b >= 0);
        }
        if (a == 0)
            return b;
        if (b == 0)
            return a;

        import core.bitop : bsf;
        import std.algorithm.mutation : swap;

        // Number of common factors of two, restored at the end.
        immutable uint shift = bsf(a | b);
        a >>= a.bsf; // make a odd

        do
        {
            b >>= b.bsf; // make b odd
            if (a > b)
                swap(a, b);
            b -= a;
        } while (b);

        return a << shift;
    }
}

///
@safe unittest
{
    assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7);
    const int a = 5 * 13 * 23 * 23, b = 13 * 59;
    assert(gcd(a, b) == 13);
}

// This overload is for non-builtin numerical types like BigInt or
// user-defined types.
/// ditto
T gcd(T)(T a, T b)
if (!isIntegral!T &&
    is(typeof(T.init % T.init)) &&
    is(typeof(T.init == 0 || T.init > 0)))
{
    import std.algorithm.mutation : swap;

    // Stein's algorithm needs shifts, subtraction, and bitwise-and; probe
    // for them at compile time and fall back to Euclid otherwise.
    enum canUseBinaryGcd = is(typeof(() {
        T t, u;
        t <<= 1;
        t >>= 1;
        t -= u;
        bool b = (t & 1) == 0;
        swap(t, u);
    }));

    assert(a >= 0 && b >= 0);

    // Special cases.
    if (a == 0)
        return b;
    if (b == 0)
        return a;

    static if (canUseBinaryGcd)
    {
        // Factor out common powers of two.
        uint shift = 0;
        while ((a & 1) == 0 && (b & 1) == 0)
        {
            a >>= 1;
            b >>= 1;
            shift++;
        }

        if ((a & 1) == 0) swap(a, b);

        do
        {
            assert((a & 1) != 0);
            while ((b & 1) == 0)
                b >>= 1;
            if (a > b)
                swap(a, b);
            b -= a;
        } while (b);

        return a << shift;
    }
    else
    {
        // The only thing we have is %; fallback to Euclidean algorithm.
        while (b != 0)
        {
            auto t = b;
            b = a % b;
            a = t;
        }
        return a;
    }
}

// https://issues.dlang.org/show_bug.cgi?id=7102
@system pure unittest
{
    import std.bigint : BigInt;
    assert(gcd(BigInt("71_000_000_000_000_000_000"),
               BigInt("31_000_000_000_000_000_000")) ==
           BigInt("1_000_000_000_000_000_000"));

    assert(gcd(BigInt(0), BigInt(1234567)) == BigInt(1234567));
    assert(gcd(BigInt(1234567), BigInt(0)) == BigInt(1234567));
}

@safe pure nothrow unittest
{
    // A numerical type that only supports % and - (to force gcd implementation
    // to use Euclidean algorithm).
    struct CrippledInt
    {
        int impl;
        CrippledInt opBinary(string op : "%")(CrippledInt i)
        {
            return CrippledInt(impl % i.impl);
        }
        int opEquals(CrippledInt i) { return impl == i.impl; }
        int opEquals(int i) { return impl == i; }
        int opCmp(int i) { return (impl < i) ? -1 : (impl > i) ? 1 : 0; }
    }
    assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77));
}

// https://issues.dlang.org/show_bug.cgi?id=19514
@system pure unittest
{
    import std.bigint : BigInt;
    assert(gcd(BigInt(2), BigInt(1)) == BigInt(1));
}

// This is to make tweaking the speed/size vs.
accuracy tradeoff easy, 3089 // though floats seem accurate enough for all practical purposes, since 3090 // they pass the "approxEqual(inverseFft(fft(arr)), arr)" test even for 3091 // size 2 ^^ 22. 3092 private alias lookup_t = float; 3093 3094 /**A class for performing fast Fourier transforms of power of two sizes. 3095 * This class encapsulates a large amount of state that is reusable when 3096 * performing multiple FFTs of sizes smaller than or equal to that specified 3097 * in the constructor. This results in substantial speedups when performing 3098 * multiple FFTs with a known maximum size. However, 3099 * a free function API is provided for convenience if you need to perform a 3100 * one-off FFT. 3101 * 3102 * References: 3103 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) 3104 */ 3105 final class Fft 3106 { 3107 import core.bitop : bsf; 3108 import std.algorithm.iteration : map; 3109 import std.array : uninitializedArray; 3110 3111 private: 3112 immutable lookup_t[][] negSinLookup; 3113 3114 void enforceSize(R)(R range) const 3115 { 3116 import std.conv : text; 3117 assert(range.length <= size, text( 3118 "FFT size mismatch. Expected ", size, ", got ", range.length)); 3119 } 3120 3121 void fftImpl(Ret, R)(Stride!R range, Ret buf) const 3122 in 3123 { 3124 assert(range.length >= 4); 3125 assert(isPowerOf2(range.length)); 3126 } 3127 do 3128 { 3129 auto recurseRange = range; 3130 recurseRange.doubleSteps(); 3131 3132 if (buf.length > 4) 3133 { 3134 fftImpl(recurseRange, buf[0..$ / 2]); 3135 recurseRange.popHalf(); 3136 fftImpl(recurseRange, buf[$ / 2..$]); 3137 } 3138 else 3139 { 3140 // Do this here instead of in another recursion to save on 3141 // recursion overhead. 
            // Size-4 base case: two hand-coded size-2 DFTs instead of two
            // further recursion levels, to cut recursion overhead.
            slowFourier2(recurseRange, buf[0..$ / 2]);
            recurseRange.popHalf();
            slowFourier2(recurseRange, buf[$ / 2..$]);
        }

        // Combine the two half-size sub-transforms in place.
        butterfly(buf);
    }

    // This algorithm works by performing the even and odd parts of our FFT
    // using the "two for the price of one" method mentioned at
    // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521
    // by making the odd terms into the imaginary components of our new FFT,
    // and then using symmetry to recombine them.
    void fftImplPureReal(Ret, R)(R range, Ret buf) const
    in
    {
        assert(range.length >= 4);
        assert(isPowerOf2(range.length));
    }
    do
    {
        alias E = ElementType!R;

        // Converts odd indices of range to the imaginary components of
        // a range half the size. The even indices become the real components.
        static if (isArray!R && isFloatingPoint!E)
        {
            // Then the memory layout of complex numbers provides a dirt
            // cheap way to convert. This is a common case, so take advantage.
            auto oddsImag = cast(Complex!E[]) range;
        }
        else
        {
            // General case: Use a higher order range. We can assume
            // source.length is even because it has to be a power of 2.
            static struct OddToImaginary
            {
                R source;
                // Element type: wide enough for both the input and the buffer.
                alias C = Complex!(CommonType!(E, typeof(buf[0].re)));

                @property
                {
                    C front()
                    {
                        return C(source[0], source[1]);
                    }

                    C back()
                    {
                        immutable n = source.length;
                        return C(source[n - 2], source[n - 1]);
                    }

                    typeof(this) save()
                    {
                        return typeof(this)(source.save);
                    }

                    bool empty()
                    {
                        return source.empty;
                    }

                    size_t length()
                    {
                        return source.length / 2;
                    }
                }

                // Each logical element consumes two source elements.
                void popFront()
                {
                    source.popFront();
                    source.popFront();
                }

                void popBack()
                {
                    source.popBack();
                    source.popBack();
                }

                C opIndex(size_t index)
                {
                    return C(source[index * 2], source[index * 2 + 1]);
                }

                typeof(this) opSlice(size_t lower, size_t upper)
                {
                    return typeof(this)(source[lower * 2 .. upper * 2]);
                }
            }

            auto oddsImag = OddToImaginary(range);
        }

        // Half-size complex FFT over the packed even/odd data.
        fft(oddsImag, buf[0..$ / 2]);
        auto evenFft = buf[0..$ / 2];
        auto oddFft = buf[$ / 2..$];
        immutable halfN = evenFft.length;
        oddFft[0].re = buf[0].im;
        oddFft[0].im = 0;
        evenFft[0].im = 0;
        // evenFft[0].re is already right b/c it's aliased with buf[0].re.

        // Untangle the packed transform: for real input, the even-index FFT
        // is the conjugate-symmetric part and the odd-index FFT the
        // antisymmetric part of the half-size result.
        foreach (k; 1 .. halfN / 2 + 1)
        {
            immutable bufk = buf[k];
            immutable bufnk = buf[buf.length / 2 - k];
            evenFft[k].re = 0.5 * (bufk.re + bufnk.re);
            evenFft[halfN - k].re = evenFft[k].re;
            evenFft[k].im = 0.5 * (bufk.im - bufnk.im);
            evenFft[halfN - k].im = -evenFft[k].im;

            oddFft[k].re = 0.5 * (bufk.im + bufnk.im);
            oddFft[halfN - k].re = oddFft[k].re;
            oddFft[k].im = 0.5 * (bufnk.re - bufk.re);
            oddFft[halfN - k].im = -oddFft[k].im;
        }

        // Final combine of the even and odd half-transforms.
        butterfly(buf);
    }

    // Combines two adjacent half-size sub-transforms stored in buf into one
    // full-size transform, using the precomputed -sin lookup table.
    void butterfly(R)(R buf) const
    in
    {
        assert(isPowerOf2(buf.length));
    }
    do
    {
        immutable n = buf.length;
        // One lookup row exists per power-of-2 size; bsf(n) indexes the row
        // whose resolution matches n exactly.
        immutable localLookup = negSinLookup[bsf(n)];
        assert(localLookup.length == n);

        immutable cosMask = n - 1;
        immutable cosAdd = n / 4 * 3;

        lookup_t negSinFromLookup(size_t index) pure nothrow
        {
            return localLookup[index];
        }

        lookup_t cosFromLookup(size_t index) pure nothrow
        {
            // cos is just -sin shifted by PI * 3 / 2.
            return localLookup[(index + cosAdd) & cosMask];
        }

        immutable halfLen = n / 2;

        // This loop is unrolled and the two iterations are interleaved
        // relative to the textbook FFT to increase ILP. This gives roughly 5%
        // speedups on DMD.
        for (size_t k = 0; k < halfLen; k += 2)
        {
            immutable cosTwiddle1 = cosFromLookup(k);
            immutable sinTwiddle1 = negSinFromLookup(k);
            immutable cosTwiddle2 = cosFromLookup(k + 1);
            immutable sinTwiddle2 = negSinFromLookup(k + 1);

            immutable realLower1 = buf[k].re;
            immutable imagLower1 = buf[k].im;
            immutable realLower2 = buf[k + 1].re;
            immutable imagLower2 = buf[k + 1].im;

            immutable upperIndex1 = k + halfLen;
            immutable upperIndex2 = upperIndex1 + 1;
            immutable realUpper1 = buf[upperIndex1].re;
            immutable imagUpper1 = buf[upperIndex1].im;
            immutable realUpper2 = buf[upperIndex2].re;
            immutable imagUpper2 = buf[upperIndex2].im;

            // Complex multiply of the upper element by the twiddle factor
            // (cos + i * (-sin)), expanded into real arithmetic.
            immutable realAdd1 = cosTwiddle1 * realUpper1
                - sinTwiddle1 * imagUpper1;
            immutable imagAdd1 = sinTwiddle1 * realUpper1
                + cosTwiddle1 * imagUpper1;
            immutable realAdd2 = cosTwiddle2 * realUpper2
                - sinTwiddle2 * imagUpper2;
            immutable imagAdd2 = sinTwiddle2 * realUpper2
                + cosTwiddle2 * imagUpper2;

            // Butterfly: lower = lower + t, upper = lower - t.
            buf[k].re += realAdd1;
            buf[k].im += imagAdd1;
            buf[k + 1].re += realAdd2;
            buf[k + 1].im += imagAdd2;

            buf[upperIndex1].re = realLower1 - realAdd1;
            buf[upperIndex1].im = imagLower1 - imagAdd1;
            buf[upperIndex2].re = realLower2 - realAdd2;
            buf[upperIndex2].im = imagLower2 - imagAdd2;
        }
    }

    // This constructor is used within this module for allocating the
    // buffer space elsewhere besides the GC heap. It's definitely **NOT**
    // part of the public API and definitely **IS** subject to change.
    //
    // Also, this is unsafe because the memSpace buffer will be cast
    // to immutable.
    //
    // Public b/c of https://issues.dlang.org/show_bug.cgi?id=4636.
3342 public this(lookup_t[] memSpace) 3343 { 3344 immutable size = memSpace.length / 2; 3345 3346 /* Create a lookup table of all negative sine values at a resolution of 3347 * size and all smaller power of two resolutions. This may seem 3348 * inefficient, but having all the lookups be next to each other in 3349 * memory at every level of iteration is a huge win performance-wise. 3350 */ 3351 if (size == 0) 3352 { 3353 return; 3354 } 3355 3356 assert(isPowerOf2(size), 3357 "Can only do FFTs on ranges with a size that is a power of two."); 3358 3359 auto table = new lookup_t[][bsf(size) + 1]; 3360 3361 table[$ - 1] = memSpace[$ - size..$]; 3362 memSpace = memSpace[0 .. size]; 3363 3364 auto lastRow = table[$ - 1]; 3365 lastRow[0] = 0; // -sin(0) == 0. 3366 foreach (ptrdiff_t i; 1 .. size) 3367 { 3368 // The hard coded cases are for improved accuracy and to prevent 3369 // annoying non-zeroness when stuff should be zero. 3370 3371 if (i == size / 4) 3372 lastRow[i] = -1; // -sin(pi / 2) == -1. 3373 else if (i == size / 2) 3374 lastRow[i] = 0; // -sin(pi) == 0. 3375 else if (i == size * 3 / 4) 3376 lastRow[i] = 1; // -sin(pi * 3 / 2) == 1 3377 else 3378 lastRow[i] = -sin(i * 2.0L * PI / size); 3379 } 3380 3381 // Fill in all the other rows with strided versions. 3382 foreach (i; 1 .. table.length - 1) 3383 { 3384 immutable strideLength = size / (2 ^^ i); 3385 auto strided = Stride!(lookup_t[])(lastRow, strideLength); 3386 table[i] = memSpace[$ - strided.length..$]; 3387 memSpace = memSpace[0..$ - strided.length]; 3388 3389 size_t copyIndex; 3390 foreach (elem; strided) 3391 { 3392 table[i][copyIndex++] = elem; 3393 } 3394 } 3395 3396 negSinLookup = cast(immutable) table; 3397 } 3398 3399 public: 3400 /**Create an `Fft` object for computing fast Fourier transforms of 3401 * power of two sizes of `size` or smaller. `size` must be a 3402 * power of two. 
3403 */ 3404 this(size_t size) 3405 { 3406 // Allocate all twiddle factor buffers in one contiguous block so that, 3407 // when one is done being used, the next one is next in cache. 3408 auto memSpace = uninitializedArray!(lookup_t[])(2 * size); 3409 this(memSpace); 3410 } 3411 3412 @property size_t size() const 3413 { 3414 return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length; 3415 } 3416 3417 /**Compute the Fourier transform of range using the $(BIGOH N log N) 3418 * Cooley-Tukey Algorithm. `range` must be a random-access range with 3419 * slicing and a length equal to `size` as provided at the construction of 3420 * this object. The contents of range can be either numeric types, 3421 * which will be interpreted as pure real values, or complex types with 3422 * properties or members `.re` and `.im` that can be read. 3423 * 3424 * Note: Pure real FFTs are automatically detected and the relevant 3425 * optimizations are performed. 3426 * 3427 * Returns: An array of complex numbers representing the transformed data in 3428 * the frequency domain. 3429 * 3430 * Conventions: The exponent is negative and the factor is one, 3431 * i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ]. 3432 */ 3433 Complex!F[] fft(F = double, R)(R range) const 3434 if (isFloatingPoint!F && isRandomAccessRange!R) 3435 { 3436 enforceSize(range); 3437 Complex!F[] ret; 3438 if (range.length == 0) 3439 { 3440 return ret; 3441 } 3442 3443 // Don't waste time initializing the memory for ret. 3444 ret = uninitializedArray!(Complex!F[])(range.length); 3445 3446 fft(range, ret); 3447 return ret; 3448 } 3449 3450 /**Same as the overload, but allows for the results to be stored in a user- 3451 * provided buffer. The buffer must be of the same length as range, must be 3452 * a random-access range, must have slicing, and must contain elements that are 3453 * complex-like. 
This means that they must have a .re and a .im member or 3454 * property that can be both read and written and are floating point numbers. 3455 */ 3456 void fft(Ret, R)(R range, Ret buf) const 3457 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3458 { 3459 assert(buf.length == range.length); 3460 enforceSize(range); 3461 3462 if (range.length == 0) 3463 { 3464 return; 3465 } 3466 else if (range.length == 1) 3467 { 3468 buf[0] = range[0]; 3469 return; 3470 } 3471 else if (range.length == 2) 3472 { 3473 slowFourier2(range, buf); 3474 return; 3475 } 3476 else 3477 { 3478 alias E = ElementType!R; 3479 static if (is(E : real)) 3480 { 3481 return fftImplPureReal(range, buf); 3482 } 3483 else 3484 { 3485 static if (is(R : Stride!R)) 3486 return fftImpl(range, buf); 3487 else 3488 return fftImpl(Stride!R(range, 1), buf); 3489 } 3490 } 3491 } 3492 3493 /** 3494 * Computes the inverse Fourier transform of a range. The range must be a 3495 * random access range with slicing, have a length equal to the size 3496 * provided at construction of this object, and contain elements that are 3497 * either of type std.complex.Complex or have essentially 3498 * the same compile-time interface. 3499 * 3500 * Returns: The time-domain signal. 3501 * 3502 * Conventions: The exponent is positive and the factor is 1/N, i.e., 3503 * output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ]. 3504 */ 3505 Complex!F[] inverseFft(F = double, R)(R range) const 3506 if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F) 3507 { 3508 enforceSize(range); 3509 Complex!F[] ret; 3510 if (range.length == 0) 3511 { 3512 return ret; 3513 } 3514 3515 // Don't waste time initializing the memory for ret. 3516 ret = uninitializedArray!(Complex!F[])(range.length); 3517 3518 inverseFft(range, ret); 3519 return ret; 3520 } 3521 3522 /** 3523 * Inverse FFT that allows a user-supplied buffer to be provided. 
The buffer 3524 * must be a random access range with slicing, and its elements 3525 * must be some complex-like type. 3526 */ 3527 void inverseFft(Ret, R)(R range, Ret buf) const 3528 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3529 { 3530 enforceSize(range); 3531 3532 auto swapped = map!swapRealImag(range); 3533 fft(swapped, buf); 3534 3535 immutable lenNeg1 = 1.0 / buf.length; 3536 foreach (ref elem; buf) 3537 { 3538 immutable temp = elem.re * lenNeg1; 3539 elem.re = elem.im * lenNeg1; 3540 elem.im = temp; 3541 } 3542 } 3543 } 3544 3545 // This mixin creates an Fft object in the scope it's mixed into such that all 3546 // memory owned by the object is deterministically destroyed at the end of that 3547 // scope. 3548 private enum string MakeLocalFft = q{ 3549 import core.stdc.stdlib; 3550 import core.exception : onOutOfMemoryError; 3551 3552 auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof)) 3553 [0 .. 2 * range.length]; 3554 if (!lookupBuf.ptr) 3555 onOutOfMemoryError(); 3556 3557 scope(exit) free(cast(void*) lookupBuf.ptr); 3558 auto fftObj = scoped!Fft(lookupBuf); 3559 }; 3560 3561 /**Convenience functions that create an `Fft` object, run the FFT or inverse 3562 * FFT and return the result. Useful for one-off FFTs. 3563 * 3564 * Note: In addition to convenience, these functions are slightly more 3565 * efficient than manually creating an Fft object for a single use, 3566 * as the Fft object is deterministically destroyed before these 3567 * functions return. 
3568 */ 3569 Complex!F[] fft(F = double, R)(R range) 3570 { 3571 mixin(MakeLocalFft); 3572 return fftObj.fft!(F, R)(range); 3573 } 3574 3575 /// ditto 3576 void fft(Ret, R)(R range, Ret buf) 3577 { 3578 mixin(MakeLocalFft); 3579 return fftObj.fft!(Ret, R)(range, buf); 3580 } 3581 3582 /// ditto 3583 Complex!F[] inverseFft(F = double, R)(R range) 3584 { 3585 mixin(MakeLocalFft); 3586 return fftObj.inverseFft!(F, R)(range); 3587 } 3588 3589 /// ditto 3590 void inverseFft(Ret, R)(R range, Ret buf) 3591 { 3592 mixin(MakeLocalFft); 3593 return fftObj.inverseFft!(Ret, R)(range, buf); 3594 } 3595 3596 @system unittest 3597 { 3598 import std.algorithm; 3599 import std.conv; 3600 import std.range; 3601 // Test values from R and Octave. 3602 auto arr = [1,2,3,4,5,6,7,8]; 3603 auto fft1 = fft(arr); 3604 assert(approxEqual(map!"a.re"(fft1), 3605 [36.0, -4, -4, -4, -4, -4, -4, -4])); 3606 assert(approxEqual(map!"a.im"(fft1), 3607 [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568])); 3608 3609 auto fft1Retro = fft(retro(arr)); 3610 assert(approxEqual(map!"a.re"(fft1Retro), 3611 [36.0, 4, 4, 4, 4, 4, 4, 4])); 3612 assert(approxEqual(map!"a.im"(fft1Retro), 3613 [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568])); 3614 3615 auto fft1Float = fft(to!(float[])(arr)); 3616 assert(approxEqual(map!"a.re"(fft1), map!"a.re"(fft1Float))); 3617 assert(approxEqual(map!"a.im"(fft1), map!"a.im"(fft1Float))); 3618 3619 alias C = Complex!float; 3620 auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10), 3621 C(11,12), C(13,14), C(15,16)]; 3622 auto fft2 = fft(arr2); 3623 assert(approxEqual(map!"a.re"(fft2), 3624 [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137])); 3625 assert(approxEqual(map!"a.im"(fft2), 3626 [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137])); 3627 3628 auto inv1 = inverseFft(fft1); 3629 assert(approxEqual(map!"a.re"(inv1), arr)); 3630 assert(reduce!max(map!"a.im"(inv1)) < 1e-10); 3631 3632 auto inv2 = inverseFft(fft2); 3633 assert(approxEqual(map!"a.re"(inv2), 
map!"a.re"(arr2))); 3634 assert(approxEqual(map!"a.im"(inv2), map!"a.im"(arr2))); 3635 3636 // FFTs of size 0, 1 and 2 are handled as special cases. Test them here. 3637 ushort[] empty; 3638 assert(fft(empty) == null); 3639 assert(inverseFft(fft(empty)) == null); 3640 3641 real[] oneElem = [4.5L]; 3642 auto oneFft = fft(oneElem); 3643 assert(oneFft.length == 1); 3644 assert(oneFft[0].re == 4.5L); 3645 assert(oneFft[0].im == 0); 3646 3647 auto oneInv = inverseFft(oneFft); 3648 assert(oneInv.length == 1); 3649 assert(approxEqual(oneInv[0].re, 4.5)); 3650 assert(approxEqual(oneInv[0].im, 0)); 3651 3652 long[2] twoElems = [8, 4]; 3653 auto twoFft = fft(twoElems[]); 3654 assert(twoFft.length == 2); 3655 assert(approxEqual(twoFft[0].re, 12)); 3656 assert(approxEqual(twoFft[0].im, 0)); 3657 assert(approxEqual(twoFft[1].re, 4)); 3658 assert(approxEqual(twoFft[1].im, 0)); 3659 auto twoInv = inverseFft(twoFft); 3660 assert(approxEqual(twoInv[0].re, 8)); 3661 assert(approxEqual(twoInv[0].im, 0)); 3662 assert(approxEqual(twoInv[1].re, 4)); 3663 assert(approxEqual(twoInv[1].im, 0)); 3664 } 3665 3666 // Swaps the real and imaginary parts of a complex number. This is useful 3667 // for inverse FFTs. 3668 C swapRealImag(C)(C input) 3669 { 3670 return C(input.im, input.re); 3671 } 3672 3673 /** This function transforms `decimal` value into a value in the factorial number 3674 system stored in `fac`. 3675 3676 A factorial number is constructed as: 3677 $(D fac[0] * 0! + fac[1] * 1! + ... fac[20] * 20!) 3678 3679 Params: 3680 decimal = The decimal value to convert into the factorial number system. 3681 fac = The array to store the factorial number. The array is of size 21 as 3682 `ulong.max` requires 21 digits in the factorial number system. 3683 Returns: 3684 A variable storing the number of digits of the factorial number stored in 3685 `fac`. 
3686 */ 3687 size_t decimalToFactorial(ulong decimal, ref ubyte[21] fac) 3688 @safe pure nothrow @nogc 3689 { 3690 import std.algorithm.mutation : reverse; 3691 size_t idx; 3692 3693 for (ulong i = 1; decimal != 0; ++i) 3694 { 3695 auto temp = decimal % i; 3696 decimal /= i; 3697 fac[idx++] = cast(ubyte)(temp); 3698 } 3699 3700 if (idx == 0) 3701 { 3702 fac[idx++] = cast(ubyte) 0; 3703 } 3704 3705 reverse(fac[0 .. idx]); 3706 3707 // first digit of the number in factorial will always be zero 3708 assert(fac[idx - 1] == 0); 3709 3710 return idx; 3711 } 3712 3713 /// 3714 @safe pure @nogc unittest 3715 { 3716 ubyte[21] fac; 3717 size_t idx = decimalToFactorial(2982, fac); 3718 3719 assert(fac[0] == 4); 3720 assert(fac[1] == 0); 3721 assert(fac[2] == 4); 3722 assert(fac[3] == 1); 3723 assert(fac[4] == 0); 3724 assert(fac[5] == 0); 3725 assert(fac[6] == 0); 3726 } 3727 3728 @safe pure unittest 3729 { 3730 ubyte[21] fac; 3731 size_t idx = decimalToFactorial(0UL, fac); 3732 assert(idx == 1); 3733 assert(fac[0] == 0); 3734 3735 fac[] = 0; 3736 idx = 0; 3737 idx = decimalToFactorial(ulong.max, fac); 3738 assert(idx == 21); 3739 auto t = [7, 11, 12, 4, 3, 15, 3, 5, 3, 5, 0, 8, 3, 5, 0, 0, 0, 2, 1, 1, 0]; 3740 foreach (i, it; fac[0 .. 21]) 3741 { 3742 assert(it == t[i]); 3743 } 3744 3745 fac[] = 0; 3746 idx = decimalToFactorial(2982, fac); 3747 3748 assert(idx == 7); 3749 t = [4, 0, 4, 1, 0, 0, 0]; 3750 foreach (i, it; fac[0 .. idx]) 3751 { 3752 assert(it == t[i]); 3753 } 3754 } 3755 3756 private: 3757 // The reasons I couldn't use std.algorithm were b/c its stride length isn't 3758 // modifiable on the fly and because range has grown some performance hacks 3759 // for powers of 2. 
3760 struct Stride(R) 3761 { 3762 import core.bitop : bsf; 3763 Unqual!R range; 3764 size_t _nSteps; 3765 size_t _length; 3766 alias E = ElementType!(R); 3767 3768 this(R range, size_t nStepsIn) 3769 { 3770 this.range = range; 3771 _nSteps = nStepsIn; 3772 _length = (range.length + _nSteps - 1) / nSteps; 3773 } 3774 3775 size_t length() const @property 3776 { 3777 return _length; 3778 } 3779 3780 typeof(this) save() @property 3781 { 3782 auto ret = this; 3783 ret.range = ret.range.save; 3784 return ret; 3785 } 3786 3787 E opIndex(size_t index) 3788 { 3789 return range[index * _nSteps]; 3790 } 3791 3792 E front() @property 3793 { 3794 return range[0]; 3795 } 3796 3797 void popFront() 3798 { 3799 if (range.length >= _nSteps) 3800 { 3801 range = range[_nSteps .. range.length]; 3802 _length--; 3803 } 3804 else 3805 { 3806 range = range[0 .. 0]; 3807 _length = 0; 3808 } 3809 } 3810 3811 // Pops half the range's stride. 3812 void popHalf() 3813 { 3814 range = range[_nSteps / 2 .. range.length]; 3815 } 3816 3817 bool empty() const @property 3818 { 3819 return length == 0; 3820 } 3821 3822 size_t nSteps() const @property 3823 { 3824 return _nSteps; 3825 } 3826 3827 void doubleSteps() 3828 { 3829 _nSteps *= 2; 3830 _length /= 2; 3831 } 3832 3833 size_t nSteps(size_t newVal) @property 3834 { 3835 _nSteps = newVal; 3836 3837 // Using >> bsf(nSteps) is a few cycles faster than / nSteps. 3838 _length = (range.length + _nSteps - 1) >> bsf(nSteps); 3839 return newVal; 3840 } 3841 } 3842 3843 // Hard-coded base case for FFT of size 2. This is actually a TON faster than 3844 // using a generic slow DFT. This seems to be the best base case. (Size 1 3845 // can be coded inline as buf[0] = range[0]). 3846 void slowFourier2(Ret, R)(R range, Ret buf) 3847 { 3848 assert(range.length == 2); 3849 assert(buf.length == 2); 3850 buf[0] = range[0] + range[1]; 3851 buf[1] = range[0] - range[1]; 3852 } 3853 3854 // Hard-coded base case for FFT of size 4. 
Doesn't work as well as the size 3855 // 2 case. 3856 void slowFourier4(Ret, R)(R range, Ret buf) 3857 { 3858 alias C = ElementType!Ret; 3859 3860 assert(range.length == 4); 3861 assert(buf.length == 4); 3862 buf[0] = range[0] + range[1] + range[2] + range[3]; 3863 buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1); 3864 buf[2] = range[0] - range[1] + range[2] - range[3]; 3865 buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1); 3866 } 3867 3868 N roundDownToPowerOf2(N)(N num) 3869 if (isScalarType!N && !isFloatingPoint!N) 3870 { 3871 import core.bitop : bsr; 3872 return num & (cast(N) 1 << bsr(num)); 3873 } 3874 3875 @safe unittest 3876 { 3877 assert(roundDownToPowerOf2(7) == 4); 3878 assert(roundDownToPowerOf2(4) == 4); 3879 } 3880 3881 template isComplexLike(T) 3882 { 3883 enum bool isComplexLike = is(typeof(T.init.re)) && 3884 is(typeof(T.init.im)); 3885 } 3886 3887 @safe unittest 3888 { 3889 static assert(isComplexLike!(Complex!double)); 3890 static assert(!isComplexLike!(uint)); 3891 }