1 /++
	mvd stands for Multiple Virtual Dispatch. It lets you
	write functions that take any number of object arguments
	and pick an overload based on the dynamic (runtime) type
	of each of them.
6 
7 	---
8 	void foo(Object a, Object b) {} // 1
	void foo(MyClass a, Object b) {} // 2
10 	void foo(DerivedClass a, MyClass b) {} // 3
11 
12 	Object a = new MyClass();
13 	Object b = new Object();
14 
15 	mvd!foo(a, b); // will call overload #2
16 	---
17 
	The return values must be compatible; [mvd] returns the
	common type of all the overloads' return types, as computed by
	`std.traits.CommonType` (most likely the shared base class type
	of all return types, or `void` if there isn't one).
22 
23 	All non-class/interface types should be compatible among overloads.
24 	Otherwise you are liable to get compile errors. (Or it might work,
25 	that's up to the compiler's discretion.)
26 +/
27 module arsd.mvd;
28 
29 import std.traits;
30 
/++
	The type that every overload of `fn` can return through: `CommonType`
	applied to the `ReturnType` of each member of `fn`'s overload set
	(`void` when those return types have no common type).

	This exists just to make the documentation of [mvd] nicer looking.
+/
template CommonReturnOfOverloads(alias fn) {
	alias CommonReturnOfOverloads = CommonType!(staticMap!(
		ReturnType,
		__traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))
	));
}
33 
/++
	Calls the overload of `fn` that best matches the *dynamic* types of `args`.

	Only overloads of `fn` (as listed by `__traits(getOverloads)` on `fn`'s
	parent) taking exactly as many parameters as there are `args` are
	considered. For each class or interface parameter, the matching argument
	is `cast` to the parameter type; if a non-null argument fails that cast,
	the overload is rejected. A `null` argument is accepted by every class or
	interface parameter. Non-class arguments are assigned to their parameter
	as-is, so they must be implicitly convertible (a compile error otherwise).

	Each accepted overload gets a specificity score:
	$(LIST
		* class parameter: 1 + number of its base classes (`Object` scores 1),
		* interface parameter: 2 + number of interfaces it inherits, so any
		  interface ranks above `Object`,
		* non-class parameter: 0.
	)
	The overload with the highest total score is called.

	Returns: the chosen overload's result converted to
	[CommonReturnOfOverloads]; when that type is `void` the result, if any,
	is discarded.

	Throws: `Exception` if no overload accepts the arguments, or if two or
	more overloads share the highest score.

	See details on the [arsd.mvd] page.
+/
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
	alias R = typeof(return);

	R delegate() bestMatch;
	int bestScore;
	// Set when another accepted overload ties with bestScore. It only
	// becomes an error if no later overload scores strictly higher.
	bool ambiguous;

	// Describes args for error messages: the dynamic class name of non-null
	// class objects, "null <StaticType>" for null ones, and the static type
	// of everything else.
	string argsStr() {
		string s;
		foreach(arg; args) {
			if(s.length)
				s ~= ", ";
			static if (is(typeof(arg) == class)) {
				if (arg is null) {
					s ~= "null " ~ typeof(arg).stringof;
				} else {
					s ~= typeid(arg).name;
				}
			} else {
				s ~= typeof(arg).stringof;
			}
		}
		return s;
	}

	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
		// An overload with a different arity can never match. Filtering it
		// at compile time also keeps args[idx] below from going out of bounds.
		static if(Parameters!overload.length == T.length) {
			Parameters!overload pargs;
			int score = 0;
			foreach(idx, parg; pargs) {
				alias t = typeof(parg);
				static if(is(t == interface) || is(t == class)) {
					pargs[idx] = cast(t) args[idx];
					if(args[idx] !is null && pargs[idx] is null)
						continue ov; // failed cast, forget it
					// BaseClassesTuple only accepts classes, so interfaces
					// are ranked by their own inheritance depth instead.
					static if(is(t == class))
						score += BaseClassesTuple!t.length + 1;
					else
						score += InterfacesTuple!t.length + 2;
				} else
					pargs[idx] = args[idx];
			}

			// The first accepted overload always becomes the candidate, even
			// with a score of 0 (e.g. no class parameters at all).
			if(bestMatch is null || score > bestScore) {
				// Explicit R return type: each overload's result is converted
				// to the common return type inside the delegate.
				bestMatch = delegate R() {
					static if(is(R == void))
						overload(pargs);
					else
						return overload(pargs);
				};
				bestScore = score;
				ambiguous = false;
			} else if(score == bestScore) {
				ambiguous = true;
			}
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed with args (" ~ argsStr ~ ")");
	if(ambiguous)
		throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");

	return bestMatch();
}
89 
///
unittest {
	import std.exception : assertThrown;

	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	static struct Wrapper {
		// D doesn't allow overloading nested functions inside a unittest,
		// so the struct's static members serve as a namespace instead.
		static:
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }

		int bar(MyClass a) { return 4; }

		int baz(MyClass a, Object b) { return 5; }
		int baz(Object a, MyClass b) { return 6; }
	}

	with(Wrapper) {
		// the most specialized overload accepting the dynamic types is chosen
		assert(mvd!foo(new Object, new Object) == 1);
		assert(mvd!foo(new MyClass, new DerivedClass) == 2);
		assert(mvd!foo(new DerivedClass, new DerivedClass) == 3);
		assert(mvd!foo(new OtherClass, new OtherClass) == 1);
		assert(mvd!foo(new OtherClass, new MyClass) == 1);

		// only the dynamic type matters, not the static one
		Object a = new DerivedClass;
		Object b = new MyClass;
		assert(mvd!foo(a, b) == 3);

		// no overload of bar accepts an OtherClass
		assert(mvd!bar(new DerivedClass) == 4);
		assertThrown(mvd!bar(new OtherClass));

		// both baz overloads are equally specific for (MyClass, MyClass)
		assertThrown(mvd!baz(new MyClass, new MyClass));
	}
}
Suggestion Box / Bug Report