/++
	Helper functions for generating database stuff.

	Note: this is heavily biased toward Postgres
+/
module arsd.database_generation;

/*

	FIXME: support partial indexes and maybe "using"
	FIXME: support views

	Let's put indexes in there too and make index functions be the preferred way of doing a query
	by making them convenient af.
*/

/// Marker placed on the attribute structs below so they can be recognized as this module's UDAs.
private enum UDA;

/// Attach as `@PrimaryKey` to a field; [generateCreateTableFor] then emits `PRIMARY KEY(field)`.
@UDA struct PrimaryKey {
	string sql;
}

/// Attach as `@Default("sql expression")`; the expression is emitted after `DEFAULT` in the column definition.
@UDA struct Default {
	string sql;
}

/// Attach as `@Unique` to a field to add `UNIQUE` to its column definition.
@UDA struct Unique { }

/++
	Attach as `@ForeignKey!(OtherTable.field, policy)`.

	`toWhat` is the referenced field and `behavior` is appended verbatim to the
	generated `FOREIGN KEY` clause (see [CASCADE], [NULLIFY], [RESTRICT]).
+/
@UDA struct ForeignKey(alias toWhat, string behavior) {
	/// The aggregate that declares the referenced field.
	alias ReferencedTable = __traits(parent, toWhat);
}

/// Ready-made `behavior` strings for [ForeignKey].
enum CASCADE = "ON UPDATE CASCADE ON DELETE CASCADE";
/// ditto
enum NULLIFY = "ON UPDATE CASCADE ON DELETE SET NULL";
/// ditto
enum RESTRICT = "ON UPDATE CASCADE ON DELETE RESTRICT";

// NOTE(review): DBName is declared here but nothing in this file reads it.
@UDA struct DBName { string name; }

/++
	A value that may be SQL NULL. Starts out null; assigning a `T` makes it
	non-null, assigning `null` makes it null again.
+/
struct Nullable(T) {
	bool isNull = true;
	T value;

	void opAssign(typeof(null)) {
		isNull = true;
	}

	void opAssign(T v) {
		isNull = false;
		value = v;
	}

	/// Returns the stored value regardless of `isNull`.
	T toArsdJsvar() { return value; }
}

/// A timestamp column, carried around as the string the database uses.
struct Timestamp {
	string value;
	string toArsdJsvar() { return value; }

	/++
		Builds a Timestamp from separate date and time strings by joining them
		with `T` and round-tripping through `SysTime`'s ISO extended format.
		A time shorter than six characters (e.g. `HH:MM`) gets `:00` appended.
	+/
	// FIXME: timezone
	static Timestamp fromStrings(string date, string time) {
		if(time.length < 6)
			time ~= ":00";
		import std.datetime;
		return Timestamp(SysTime.fromISOExtString(date ~ "T" ~ time).toISOExtString());
	}
}

// NOTE(review): SysTime is not imported at module scope in this file;
// presumably it is re-exported by arsd.database - confirm.

/// Parses the string held by a [Timestamp]; see the `string` overload.
SysTime parseDbTimestamp(Timestamp s) {
	return parseDbTimestamp(s.value);
}

/++
	Parses a database timestamp of the form `YYYY-MM-DD<sep>HH:MM:SS[.fff][tz]`.

	The single separator character at index 10 is replaced by `T` and the
	result is handed to `SysTime.fromISOExtString`.

	Returns: `SysTime.init` for an empty string, the parsed time otherwise.
	Throws: whatever `SysTime.fromISOExtString` throws on malformed input.
+/
SysTime parseDbTimestamp(string s) {
	if(s.length == 0) return SysTime.init;

	// Too short to contain a date plus separator: hand it over unchanged so the
	// parser reports the problem instead of a slice going out of bounds.
	if(s.length < 11)
		return SysTime.fromISOExtString(s);

	// Everything after the separator (time, optional fraction, optional zone) is
	// passed through as-is; slicing to $ keeps short inputs such as a 19 character
	// timestamp without zone from raising a range error.
	return SysTime.fromISOExtString(s[0 .. 10] ~ "T" ~ s[11 .. $]);
}

/// Declare a member of this type to emit `CONSTRAINT memberName sql` in the table definition.
struct Constraint(string sql) {}

/// Declare a member of this type to emit `CREATE INDEX table_memberName ON table(fields...)`.
struct Index(Fields...) {}
/// Same as [Index] but emits `CREATE UNIQUE INDEX`.
struct UniqueIndex(Fields...) {}

/// An auto-incrementing integer column; a value of 0 means "not assigned yet" to [insert].
struct Serial {
	int value;
	int toArsdJsvar() { return value; }
	int getValue() { return value; }
	alias getValue this;
}


/++
	Generates the `CREATE TABLE` statement (plus any `CREATE INDEX` statements)
	for the aggregate `O`.

	The table name comes from [toTableName]. Each data member becomes a column
	whose SQL type is chosen from its D type; [Constraint], [Index] and
	[UniqueIndex] members and the [Default], [Unique], [PrimaryKey] and
	[ForeignKey] attributes add the corresponding clauses.

	Compile with `version=dbgenerate_sqlite` for the SQLite flavored output.
+/
string generateCreateTableFor(alias O)() {
	enum tableName = toTableName(O.stringof);
	string sql = "CREATE TABLE " ~ tableName ~ " (";
	string postSql;
	bool outputtedPostSql = false;

	string afterTableSql;

	// statements that must run after the CREATE TABLE (the indexes)
	void addAfterTableSql(string s) {
		afterTableSql ~= s;
		// each statement is terminated so that several of them can be run together
		afterTableSql ~= ";\n";
	}

	// table-level clauses that go after all the column definitions
	void addPostSql(string s) {
		if(outputtedPostSql) {
			postSql ~= ",";
		}
		postSql ~= "\n";
		postSql ~= "\t" ~ s;
		outputtedPostSql = true;
	}

	bool outputted = false;
	static foreach(memberName; __traits(allMembers, O)) {{
		alias member = __traits(getMember, O, memberName);
		static if(is(typeof(member) == Constraint!constraintSql, string constraintSql)) {
		version(dbgenerate_sqlite) {} else { // FIXME: make it work here too, it is the specifics of the constraint strings
			if(outputted) {
				sql ~= ",";
			}
			sql ~= "\n";
			sql ~= "\tCONSTRAINT " ~ memberName;
			sql ~= " ";
			sql ~= constraintSql;
			outputted = true;
		}
		} else static if(is(typeof(member) == Index!Fields, Fields...)) {
			string fields = "";
			static foreach(field; Fields) {
				if(fields.length)
					fields ~= ", ";
				fields ~= __traits(identifier, field);
			}
			addAfterTableSql("CREATE INDEX " ~ tableName ~ "_" ~ memberName ~ " ON " ~ tableName ~ "("~fields~")");
		} else static if(is(typeof(member) == UniqueIndex!Fields, Fields...)) {
			string fields = "";
			static foreach(field; Fields) {
				if(fields.length)
					fields ~= ", ";
				fields ~= __traits(identifier, field);
			}
			addAfterTableSql("CREATE UNIQUE INDEX " ~ tableName ~ "_" ~ memberName ~ " ON " ~ tableName ~ "("~fields~")");
		} else static if(is(typeof(member) T)) {
			if(outputted) {
				sql ~= ",";
			}
			sql ~= "\n";
			sql ~= "\t" ~ memberName;

			static if(is(T == Nullable!P, P)) {
				static if(is(P == int))
					sql ~= " INTEGER NULL";
				else static if(is(P == string))
					sql ~= " TEXT NULL";
				else static if(is(P == double))
					sql ~= " FLOAT NULL";
				else static if(is(P == Timestamp))
					sql ~= " TIMESTAMPTZ NULL";
				else static assert(0, P.stringof);
			} else static if(is(T == int))
				sql ~= " INTEGER NOT NULL";
			else static if(is(T == Serial)) {
				version(dbgenerate_sqlite)
					sql ~= " INTEGER PRIMARY KEY AUTOINCREMENT";
				else
					sql ~= " SERIAL"; // FIXME postgresism
			} else static if(is(T == string))
				sql ~= " TEXT NOT NULL";
			else static if(is(T == double))
				sql ~= " FLOAT NOT NULL";
			else static if(is(T == bool))
				sql ~= " BOOLEAN NOT NULL";
			else static if(is(T == Timestamp)) {
				version(dbgenerate_sqlite)
					sql ~= " TEXT NOT NULL";
				else
					sql ~= " TIMESTAMPTZ NOT NULL"; // FIXME: postgresism
			} else static if(is(T == enum))
				sql ~= " INTEGER NOT NULL"; // potentially crap but meh

			// double braces: each attribute gets its own scope, so the locals and
			// is-expression symbols below cannot clash between attributes
			static foreach(attr; __traits(getAttributes, member)) {{
				static if(is(typeof(attr) == Default)) {
					// FIXME: postgresism there, try current_timestamp in sqlite
					version(dbgenerate_sqlite) {
						import std.string;
						sql ~= " DEFAULT " ~ std.string.replace(attr.sql, "now()", "current_timestamp");
					} else
						sql ~= " DEFAULT " ~ attr.sql;
				} else static if(is(attr == Unique)) {
					sql ~= " UNIQUE";
				} else static if(is(attr == PrimaryKey)) {
					version(dbgenerate_sqlite) {
						static if(is(T == Serial)) {} // skip, it is done above
						else
						addPostSql("PRIMARY KEY(" ~ memberName ~ ")");
					} else
						addPostSql("PRIMARY KEY(" ~ memberName ~ ")");
				} else static if(is(attr == ForeignKey!(to, sqlPolicy), alias to, string sqlPolicy)) {
					string refTable = toTableName(__traits(parent, to).stringof);
					string refField = to.stringof;
					addPostSql("FOREIGN KEY(" ~ memberName ~ ") REFERENCES "~refTable~"("~refField~(sqlPolicy.length ? ") " : ")") ~ sqlPolicy);
				}
			}}

			outputted = true;
		}
	}}

	if(postSql.length && outputted)
		sql ~= ",\n";

	sql ~= postSql;
	sql ~= "\n);\n";
	sql ~= afterTableSql;

	return sql;
}

/++
	Converts a D type name to the table name used by this module: the name is
	lower-cased with `_` between words, then pluralized, e.g. `MyItem` becomes
	`my_items`.
+/
string toTableName(string t) {
	// any count other than 1 asks plural() for the plural form
	return plural(50, beautify(t, '_', true));
}

// copy/pasted from english.d
/*
	Returns `word` unchanged when `count` is 1 or `word` is empty, `pluralWord`
	when one is given, and otherwise applies a simple suffix rule based on the
	last letter only (s -> ses, f -> ves, y -> ies, anything else -> s).

	The rules are intentionally left naive: existing table names depend on them.
*/
private string plural(int count, string word, string pluralWord = null) {
	if(count == 1 || word.length == 0)
		return word; // it isn't actually plural

	if(pluralWord !is null)
		return pluralWord;

	switch(word[$ - 1]) {
		case 's':
			return word ~ "es";
		case 'f':
			return word[0 .. $-1] ~ "ves";
		case 'y':
			return word[0 .. $-1] ~ "ies";
		case 'a', 'e', 'i', 'o', 'u':
		default:
			return word ~ "s";
	}
}

// copy/pasted from cgi
/*
	Splits an identifier into words: `space` is inserted before an upper case
	letter or underscore that starts a new word (runs of capitals are kept
	together), the first character is capitalized, and everything is
	lower-cased when `allLowerCase` is set.

	Returns `name` unchanged if the result would not fit the 160 byte buffer.
*/
private string beautify(string name, char space = ' ', bool allLowerCase = false) {
	if(name == "id")
		return allLowerCase ? name : "ID";

	char[160] buffer;
	int bufferIndex = 0;
	bool shouldCap = true;
	bool shouldSpace;
	bool lastWasCap;
	foreach(idx, char ch; name) {
		if(bufferIndex == buffer.length) return name; // out of space, just give up, not that important

		if((ch >= 'A' && ch <= 'Z') || ch == '_') {
			if(lastWasCap) {
				// two caps in a row, don't change. Prolly acronym.
			} else {
				if(idx)
					shouldSpace = true; // new word, add space
			}

			lastWasCap = true;
		} else {
			lastWasCap = false;
		}

		if(shouldSpace) {
			buffer[bufferIndex++] = space;
			if(bufferIndex == buffer.length) return name; // out of space, just give up, not that important
			shouldSpace = false;
		}
		if(shouldCap) {
			if(ch >= 'a' && ch <= 'z')
				ch -= 32;
			shouldCap = false;
		}
		if(allLowerCase && ch >= 'A' && ch <= 'Z')
			ch += 32;
		buffer[bufferIndex++] = ch;
	}
	return buffer[0 .. bufferIndex].idup;
}

import arsd.database;
/++
	Saves the object to the database. Currently this always inserts.
+/
void save(O)(ref O t, Database db) {
	t.insert(db);
}

/++
	Inserts `t` into its table (see [toTableName]) and stores the id assigned
	by the database back into `t.id.value`.

	A [Serial] member holding 0 and a [Timestamp] member holding an empty string
	are left out of the statement so the database can fill them in. A null
	[Nullable] member is written as `NULL`.
+/
void insert(O)(ref O t, Database db) {
	auto builder = new InsertBuilder;
	builder.setTable(toTableName(O.stringof));

	static foreach(memberName; __traits(allMembers, O)) {{
		alias member = __traits(getMember, O, memberName);
		static if(is(typeof(member) T)) {

			static if(is(T == Nullable!P, P)) {
				auto v = __traits(getMember, t, memberName);
				if(v.isNull)
					builder.addFieldWithSql(memberName, "NULL");
				else
					builder.addVariable(memberName, v.value);
			} else static if(is(T == int))
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			else static if(is(T == Serial)) {
				auto v = __traits(getMember, t, memberName).value;
				if(v) {
					builder.addVariable(memberName, v);
				} else {
					// skip and let it auto-fill
				}
			} else static if(is(T == string)) {
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			} else static if(is(T == double))
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			else static if(is(T == bool))
				builder.addVariable(memberName, __traits(getMember, t, memberName));
			else static if(is(T == Timestamp)) {
				auto v = __traits(getMember, t, memberName).value;
				if(v.length)
					builder.addVariable(memberName, v);
			} else static if(is(T == enum))
				builder.addVariable(memberName, cast(int) __traits(getMember, t, memberName));
		}
	}}

	import std.conv;
	version(dbgenerate_sqlite) {
		builder.execute(db);
		foreach(row; db.query("SELECT max(id) FROM " ~ toTableName(O.stringof)))
			t.id.value = to!int(row[0]);
	} else {
		foreach(row; builder.execute(db, "RETURNING id")) // FIXME: postgres-ism
			t.id.value = to!int(row[0]);
	}
}

/// Thrown by [find] when the query returns no row.
class RecordNotFoundException : Exception {
	this() { super("RecordNotFoundException"); }
}

/++
	Returns a given struct populated from the database. Assumes types known to this module.

	MyItem item = db.find!(MyItem.id)(3);

	If you just give a type, it assumes the relevant index is "id".

	Throws: [RecordNotFoundException] if no row has that id.
+/
auto find(alias T)(Database db, int id) {

	// FIXME: if T is an index, search by it.
	// if it is unique, return an individual item.
	// if not, return the array

	foreach(record; db.query("SELECT * FROM " ~ toTableName(T.stringof) ~ " WHERE id = ?", id)) {
		T t;
		populateFromDbRow(t, record);

		return t;
		// if there is ever a second record, that's a wtf, but meh.
	}
	throw new RecordNotFoundException();
}

/*
	Copies every column of `record` whose name matches a member of `T` into
	that member of `t`. Columns without a matching member are ignored.
*/
private void populateFromDbRow(T)(ref T t, Row record) {
	foreach(field, value; record) {
		sw: switch(field) {
			static foreach(memberName; __traits(allMembers, T)) {
				case memberName:
					static if(is(typeof(__traits(getMember, T, memberName)))) {
						populateFromDbVal(__traits(getMember, t, memberName), value);
					}
				break sw;
			}
			default:
				// intentionally blank
		}
	}
}

/*
	Converts the database's string representation `value` to `V` and stores it
	in `val`. Types this module does not know about are left untouched.
*/
private void populateFromDbVal(V)(ref V val, string value) {
	import std.conv;
	static if(is(V == Constraint!constraintSql, string constraintSql)) {

	} else static if(is(V == Nullable!P, P)) {
		// FIXME
		if(value.length && value != "null") {
			val.isNull = false;
			val.value = to!P(value);
		} else {
			// make the null state explicit rather than relying on what val held before
			val.isNull = true;
		}
	} else static if(is(V == bool)) {
		val = value == "t" || value == "1" || value == "true";
	} else static if(is(V == int) || is(V == string) || is(V == double)) {
		val = to!V(value);
	} else static if(is(V == enum)) {
		val = cast(V) to!int(value);
	} else static if(is(V == Timestamp)) {
		val.value = value;
	} else static if(is(V == Serial)) {
		val.value = to!int(value);
	}
}

/++
	Gets all the children of that type. Specifically, it looks in T for a ForeignKey referencing B and queries on that.

	To do a join through a many-to-many relationship, you could get the children of the join table, then get the children of that...
	Or better yet, use real sql. This is more intended to get info where there is one parent row and then many child
	rows, not for a combined thing.

	`T` must have exactly one foreign key referencing `B`; this is checked at compile time.
+/
QueryBuilderHelper!(T[]) children(T, B)(B base) {
	int countOfAssociations() {
		int count = 0;
		static foreach(memberName; __traits(allMembers, T))
		static foreach(attr; __traits(getAttributes, __traits(getMember, T, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == B))
					count++;
			}
		}}
		return count;
	}
	static assert(countOfAssociations() == 1, T.stringof ~ " does not have exactly one foreign key of type " ~ B.stringof);
	// name of the member of T carrying that foreign key
	string keyName() {
		static foreach(memberName; __traits(allMembers, T))
		static foreach(attr; __traits(getAttributes, __traits(getMember, T, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == B))
					return memberName;
			}
		}}
	}

	// return QueryBuilderHelper!(T[])(toTableName(T.stringof)).where!(mixin(keyName ~ " => base.id"));

	// changing mixin cuz of regression in dmd 2.088
	mixin("return QueryBuilderHelper!(T[])(toTableName(T.stringof)).where!("~keyName ~ " => base.id);");
}

/++
	Finds the single row associated with a foreign key in `base`.

	`T` is used to find the key, unless ambiguous, in which case you must pass `key`.

	To do a join through a many-to-many relationship, go to [children] or use real sql.

	Throws: [RecordNotFoundException] (from [find]) if the referenced row does not exist.
+/
T associated(B, T, string key = null)(B base, Database db) {
	// number of foreign keys in B that reference T (restricted to `key` when one is given)
	int countOfAssociations() {
		int count = 0;
		static foreach(memberName; __traits(allMembers, B))
		static foreach(attr; __traits(getAttributes, __traits(getMember, B, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == T))
					static if(key is null || key == memberName)
						count++;
			}
		}}
		return count;
	}

	static if(key is null) {
		enum coa = countOfAssociations();
		static assert(coa != 0, B.stringof ~ " has no association of type " ~ T.stringof);
		static assert(coa == 1, B.stringof ~ " has multiple associations of type " ~ T.stringof ~ "; please specify the key you want");
		static foreach(memberName; __traits(allMembers, B))
		static foreach(attr; __traits(getAttributes, __traits(getMember, B, memberName))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == T))
					return db.find!T(__traits(getMember, base, memberName));
			}
		}}
	} else {
		static assert(countOfAssociations() == 1, B.stringof ~ " does not have a key named " ~ key ~ " of type " ~ T.stringof);
		// only the attributes of the requested member are of interest here
		static foreach(attr; __traits(getAttributes, __traits(getMember, B, key))) {{
			static if(is(attr == ForeignKey!(K, policy), alias K, string policy)) {
				static if(is(attr.ReferencedTable == T)) {
					return db.find!T(__traits(getMember, base, key));
				}
			}
		}}
		assert(0);
	}
}


/++
	It will return an aggregate row with a member of type of each table in the join.

	Could do an anonymous object for other things in the sql...

	Not implemented yet; the body is empty.
+/
auto join(TableA, TableB, ThroughTable = void)() {}

/++
	A typed wrapper around a `SelectBuilder`. `T` is either a row type, in
	which case [execute] returns a single row, or an array of a row type, in
	which case it returns all matching rows.
+/
struct QueryBuilderHelper(T) {
	static if(is(T == R[], R))
		alias TType = R;
	else
		alias TType = T;

	SelectBuilder selectBuilder;

	/// Starts a `SELECT *` on the given table.
	this(string tableName) {
		selectBuilder = new SelectBuilder();
		selectBuilder.table = tableName;
		selectBuilder.fields = ["*"];
	}

	/++
		Runs the query on `db` and converts the rows to `TType`.

		For a non-array `T` the query is limited to one row and `T.init` is
		returned if there is none.
	+/
	T execute(Database db) {
		selectBuilder.db = db;
		static if(is(T == R[], R)) {

		} else {
			selectBuilder.limit = 1;
		}

		T ret;
		bool first = true;
		foreach(row; db.query(selectBuilder.toString())) {
			TType t;
			populateFromDbRow(t, row);

			static if(is(T == R[], R))
				ret ~= t;
			else {
				if(first) {
					ret = t;
					first = false;
				} else {
					assert(0);
				}
			}
		}
		return ret;
	}

	/++
		Adds an `ORDER BY` criterion of the form `"field"`, `"field ASC"` or
		`"field DESC"`. The field name and direction are checked at compile time.
	+/
	typeof(this) orderBy(string criterion)() {
		// index of the first space, or the length if there is none
		static size_t nameLength() {
			size_t idx = 0;
			while(idx < criterion.length && criterion[idx] != ' ')
				idx++;
			return idx;
		}

		static string name() {
			return criterion[0 .. nameLength()];
		}

		static string direction() {
			import std.string;
			return criterion[nameLength() .. $].strip;
		}

		static assert(is(typeof(__traits(getMember, TType, name()))), TType.stringof ~ " has no field " ~ name());
		static assert(direction().length == 0 || direction() == "ASC" || direction() == "DESC", "sort direction must be empty, ASC, or DESC");

		selectBuilder.orderBys ~= criterion;
		return this;
	}
}

/// Starts a query returning all rows of `T`'s table; refine it with [where] and `orderBy`.
QueryBuilderHelper!(T[]) from(T)() {
	return QueryBuilderHelper!(T[])(toTableName(T.stringof));
}

/// ditto
/++
	Adds conditions to a [QueryBuilderHelper]. Each compile-time condition is a
	lambda of the form `fieldName => value`: the parameter name selects the
	column and the returned value is what it is compared against. An `int[]`
	value for an integer column becomes an `IN (...)` list. Any `sqlCondition`
	strings are appended to the where clause unchanged.
+/
template where(conditions...) {
	Qbh where(Qbh)(Qbh this_, string[] sqlCondition...) {
		assert(this_.selectBuilder !is null);

		// pulls the parameter name out of the stringof of the lambda's type:
		// the text between the last space and the closing paren
		static string extractName(string s) {
			if(s.length == 0) assert(0);
			auto i = s.length - 1;
			while(i) {
				if(s[i] == ')') {
					// got to close paren, now backward to non-identifier char to get name
					auto end = i;
					while(i) {
						if(s[i] == ' ')
							return s[i + 1 .. end];
						i--;
					}
					assert(0);
				}
				i--;
			}
			assert(0);
		}

		static foreach(cond; conditions) {{
			// I hate this but __parameters doesn't work here for some reason
			// see my old thread: https://forum.dlang.org/post/awjuoemsnmxbfgzhgkgx@forum.dlang.org
			enum name = extractName(typeof(cond!int).stringof);
			auto value = cond(null);

			// FIXME: convert the value as necessary
			static if(is(typeof(value) == Serial))
				auto dbvalue = value.value;
			else static if(is(typeof(value) == enum))
				auto dbvalue = cast(int) value;
			else
				auto dbvalue = value;

			import std.conv;

			static assert(is(typeof(__traits(getMember, Qbh.TType, name))), Qbh.TType.stringof ~ " has no member " ~ name);

			alias MemberType = typeof(__traits(getMember, Qbh.TType, name));
			// the three integer-like column types are all handled the same way
			enum isIntegerKey = is(MemberType == int) || is(MemberType == Nullable!int) || is(MemberType == Serial);

			static if(isIntegerKey && is(typeof(value) : const(int)[])) {
				string s;
				foreach(v; value) {
					if(s.length) s ~= ", ";
					s ~= to!string(v);
				}
				if(s.length)
					this_.selectBuilder.wheres ~= name ~ " IN (" ~ s ~ ")";
				else
					this_.selectBuilder.wheres ~= "1 = 0"; // nothing is in an empty list, and "IN ()" is not portable sql
			} else {
				static if(isIntegerKey)
					static assert(is(typeof(value) : const(int)) || is(typeof(value) == Serial), Qbh.TType.stringof ~ " is a integer key, but you passed an incompatible " ~ typeof(value).stringof);
				else
					static assert(is(MemberType == typeof(value)), Qbh.TType.stringof ~ "." ~ name ~ " is not of type " ~ typeof(value).stringof);

				// numbered by how many conditions the builder already holds, so that
				// chained where calls never hand out the same placeholder twice
				auto placeholder = "?_internal" ~ to!string(this_.selectBuilder.wheres.length);
				this_.selectBuilder.wheres ~= name ~ " = " ~ placeholder;
				this_.selectBuilder.setVariable(placeholder, dbvalue);
			}
		}}

		this_.selectBuilder.wheres ~= sqlCondition;
		return this_;
	}
}