/// I never finished this. The idea is to use compile-time (CT) reflection to make calling another process feel as simple as calling an in-process object. It will come eventually, but no promises.
2 module arsd.rpc;
3
4 /*
5 FIXME:
6 1) integrate with arsd.eventloop
7 2) make it easy to use with other processes; pipe to a process and talk to it that way. perhaps with shared memory too?
8 3) extend the serialization capabilities
9
10
11 @Throws!(List, Of, Exceptions)
12 classes are also RPC proxied
13 stdin/out/err also redirected
14 */
15
16 ///+ //example usage
/// Example RPC interface shared by the client and the server.
///
/// Member order matters: each member's wire function number is its 1-based
/// position in `__traits(allMembers, ExampleNetworkFunctions)`. Reordering,
/// inserting or removing members therefore breaks compatibility between a
/// client and a server built from different versions of this interface.
interface ExampleNetworkFunctions {
	string sayHello(string name);
	int add(int a, int b);
	S2 structTest(S1);
	void die();
}
23
/// Example server: implements every member of the RPC interface and mixes in
/// `NetworkServer`, which supplies `this(ushort port)` and `eventLoop()`.
class ExampleServer : ExampleNetworkFunctions {
	/// Greets `name`.
	override string sayHello(string name) { return "Hello, " ~ name; }

	/// Sum of the two arguments (ordinary int arithmetic).
	override int add(int a, int b) { return a + b; }

	/// Copies the fields of an S1 into an S2 (the two declare them in opposite order).
	override S2 structTest(S1 a) {
		S2 result;
		result.name = a.name;
		result.number = a.number;
		return result;
	}

	/// Always throws, so clients can exercise the error path.
	override void die() {
		throw new Exception("death requested");
	}

	mixin NetworkServer!ExampleNetworkFunctions;
}
44
/// Example argument type for `structTest`.
/// Fields are serialized in declaration order (via `tupleof`), so the field
/// order is part of the wire format.
struct S1 {
	int number;
	string name;
}
49
/// Example return type for `structTest`.
/// Fields are serialized in declaration order (via `tupleof`), so the field
/// order is part of the wire format.
struct S2 {
	string name;
	int number;
}
54
55 import std.stdio;
/// Demo driver. With any command-line argument it acts as a client of
/// localhost:5005; with none it runs the example server on port 5005.
void main(string[] args) {
	immutable bool runAsClient = args.length > 1;

	if(!runAsClient) {
		auto server = new ExampleServer(5005);
		server.eventLoop();
		return;
	}

	auto client = makeNetworkClient!ExampleNetworkFunctions("localhost", 5005);
	// each proxy method mirrors the interface, but instead of returning the
	// value it takes a success callback (given the return value) and a
	// failure callback (given the exception)
	client.sayHello("whoa", (a) { writeln(a); }, null);
	client.add(1, 2, (a) { writeln(a); }, null);
	client.add(10, 20, (a) { writeln(a); }, null);
	client.structTest(S1(20, "cool!"), (a) { writeln(a.name, " -- ", a.number); }, null);
	client.die(delegate () { writeln("shouldn't happen"); }, delegate(a) { writeln(a); });
	client.eventLoop();

	/*
	synchronous variant:

	auto client = makeNetworkClient!(ExampleNetworkFunctions, false)("localhost", 5005);
	writeln(client.sayHello("whoa"));
	writeln(client.add(1, 2));
	client.die();
	writeln(client.add(1, 2));
	*/
}
81 //+/
82
/// Turns the class it is mixed into into a TCP RPC server for `Interface`.
///
/// The class must implement `Interface`. Requests are dispatched to its
/// members by number: the 1-based position of the member in
/// `__traits(allMembers, Interface)`.
///
/// Request frame: int length, int functionNumber, int sequenceNumber, then
/// `length` bytes of serialized arguments.
/// Reply frame: int length, int sequenceNumber, int success (1 or 0), then
/// `length` bytes holding the serialized return value (success) or the
/// exception's msg, file and line (failure). All ints are written by
/// `serialize`, i.e. in host byte order.
mixin template NetworkServer(Interface) {
	import std.socket;
	private Socket socket;

	/// Opens a TCP socket listening on `port` (with SO_REUSEADDR set).
	public this(ushort port) {
		socket = new TcpSocket();
		socket.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, true);
		socket.bind(new InternetAddress(port));
		socket.listen(16);
	}

	/// Accepts clients and serves their requests; never returns.
	///
	/// Bytes from each connection are buffered until a complete request
	/// frame is available, so requests split across TCP reads (or several
	/// requests arriving in one read) are handled. Connections that close,
	/// fail, or send a malformed header are closed and dropped.
	final public void eventLoop() {
		auto check = new SocketSet();
		Socket[] connections;
		connections.reserve(16);
		ubyte[][Socket] pending; // per-connection bytes not yet forming a full request
		ubyte[4096] buffer;

		while(true) {
			check.reset();
			check.add(socket);
			foreach(connection; connections)
				check.add(connection);

			// 0 = timeout, negative = interrupted; nothing to service yet
			if(Socket.select(check, null, null) <= 0)
				continue;

			Socket[] stillOpen;
			foreach(connection; connections) {
				if(!check.isSet(connection) || serviceConnection(connection, pending, buffer[])) {
					stillOpen ~= connection;
				} else {
					pending.remove(connection);
					connection.close();
				}
			}
			connections = stillOpen;

			// accepted after servicing, so the new socket is never tested against this round's set
			if(check.isSet(socket))
				connections ~= socket.accept();
		}
	}

	// Reads the available data from `connection` and dispatches every complete
	// request frame; leftover bytes are kept in `pending` for the next read.
	// Returns false when the connection should be closed.
	final private bool serviceConnection(Socket connection, ref ubyte[][Socket] pending, ubyte[] scratch) {
		auto gotNum = connection.receive(scratch);
		if(gotNum <= 0)
			return false; // 0 = orderly shutdown, Socket.ERROR = failure

		ubyte[] data = scratch[0 .. gotNum].dup;
		if(auto previous = connection in pending)
			data = *previous ~ data;

		while(data.length >= 12) {
			int length, functionNumber, sequenceNumber;
			auto rest = deserializeInto(data, length);
			rest = deserializeInto(rest, functionNumber);
			rest = deserializeInto(rest, sequenceNumber);

			if(length < 0)
				return false; // malformed header; framing cannot be recovered
			if(rest.length < cast(size_t) length)
				break; // payload not fully received yet; wait for more data

			//writeln("got ", sequenceNumber, " calling ", functionNumber);
			callByNumber(functionNumber, sequenceNumber, rest[0 .. length], connection);
			data = rest[length .. $];
		}

		pending[connection] = data;
		return true;
	}

	// Runs request `functionNumber` with serialized arguments `buffer` and
	// sends exactly one reply frame tagged with `sequenceNumber`.
	// Void functions get an empty success reply (a synchronous client waits
	// for a reply to every call). Any Throwable, including an unknown function
	// number or malformed arguments, becomes an error reply carrying msg,
	// file and line.
	final private void callByNumber(int functionNumber, int sequenceNumber, const(ubyte)[] buffer, Socket connection) {
		ubyte[4096] sendBuffer;
		int length = 12;
		// header: length, sequence, success. Default: empty successful reply.
		serialize(sendBuffer[0 .. 4], cast(int) 0);
		serialize(sendBuffer[4 .. 8], sequenceNumber);
		serialize(sendBuffer[8 .. 12], cast(int) 1);

		// Generates one `case` per interface member: deserialize the
		// arguments, call the member, and serialize any return value.
		string callCode() {
			import std.conv;
			import std.traits;
			string code;
			foreach(memIdx, member; __traits(allMembers, Interface)) {
				code ~= "\t\tcase " ~ to!string(memIdx + 1) ~ ":\n";
				alias mem = PassThrough!(__traits(getMember, Interface, member));
				string argsString;
				foreach(i, arg; ParameterTypeTuple!mem) {
					if(i)
						argsString ~= ", ";
					auto istr = to!string(i);
					code ~= "\t\t\t" ~ arg.stringof ~ " arg" ~ istr ~ ";\n";
					code ~= "\t\t\tbuffer = deserializeInto(buffer, arg" ~ istr ~ ");\n";

					argsString ~= "arg" ~ istr;
				}

				static if(is(ReturnType!mem == void)) {
					code ~= "\n\t\t\t" ~ member ~ "(" ~ argsString ~ ");\n";
				} else {
					code ~= "\n\t\t\tauto ret = " ~ member ~ "(" ~ argsString ~ ");\n";
					code ~= "\t\t\tauto serialized = serialize(sendBuffer[12 .. $], ret);\n";
					code ~= "\t\t\tserialize(sendBuffer[0 .. 4], cast(int) serialized.length);\n";
					code ~= "\t\t\tlength += serialized.length;\n";
				}
				code ~= "\t\tbreak;\n";
			}
			return code;
		}

		try {
			switch(functionNumber) {
				// an exception rather than assert(0): assert(0) halts the whole server in -release builds
				default: throw new Exception("unknown function");
				//pragma(msg, callCode());
				mixin(callCode());
			}
		} catch(Throwable t) {
			//writeln("thrown: ", t);
			serialize(sendBuffer[8 .. 12], cast(int) 0); // no success

			auto place = sendBuffer[12 .. $];
			int l;
			auto s = serialize(place, t.msg);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, t.file);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, t.line);
			place = place[s.length .. $];
			l += s.length;

			serialize(sendBuffer[0 .. 4], l);
			length = 12 + l;
		}

		// send() may transmit only part of the data; keep going until done
		auto toSend = sendBuffer[0 .. length];
		while(toSend.length) {
			auto sent = connection.send(toSend);
			if(sent <= 0)
				break; // peer gone or error; eventLoop drops the connection once a receive on it fails
			toSend = toSend[sent .. $];
		}
	}
}
214
/// Identity alias: `PassThrough!sym` is `sym` itself. Used to bind the
/// result of `__traits(getMember, ...)` to a name via `alias`.
alias PassThrough(alias a) = a;
218
219 // general FIXME: what if we run out of buffer space?
220
/// Writes `s` into the front of `buffer` and returns the slice of `buffer`
/// that was actually used.
///
/// Encoding (host byte order, items packed back to back):
///   - arrays: an `int` element count, then each element, recursively;
///   - types without indirections (ints, floats, chars, plain structs): their raw bytes;
///   - other aggregates: each field of `tupleof`, recursively.
/// Pointers are rejected at compile time.
///
/// Throws: `Exception` if `buffer` is too small, or if an array has more than
/// `int.max` elements (its count would not fit the `int` length prefix).
/// This is an explicit check rather than an assert: `-release` removes
/// asserts and, because the pointer cast below makes this function @system,
/// also its bounds checks, so the copy would overrun `buffer`.
final public ubyte[] serialize(T)(ubyte[] buffer, in T s) {
	import std.traits;
	auto original = buffer;
	size_t totalLength = 0;

	static if(isArray!T) {
		// the count is sent as a 4-byte int; refuse lengths it cannot represent
		if(s.length > int.max)
			throw new Exception("array too long to serialize");
		/* length */ {
			auto used = serialize(buffer, cast(int) s.length);
			totalLength += used.length;
			buffer = buffer[used.length .. $];
		}
		foreach(i; s) {
			auto used = serialize(buffer, i);
			totalLength += used.length;
			buffer = buffer[used.length .. $];
		}
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(!hasIndirections!T) {
		// covers int, float, char, etc. most the builtins
		if(buffer.length < T.sizeof) {
			import std.string;
			throw new Exception(format("%s won't fit in %s buffer", T.stringof, buffer.length));
		}
		buffer[0 .. T.sizeof] = (cast(ubyte*)&s)[0 .. T.sizeof];
		totalLength += T.sizeof;
		buffer = buffer[T.sizeof .. $];
	} else {
		// structs, classes, etc.
		foreach(i, t; s.tupleof) {
			auto used = serialize(buffer, t);
			totalLength += used.length;
			buffer = buffer[used.length .. $];
		}
	}

	return original[0 .. totalLength];
}
257
/// Reads a value of type `T` from the front of `buffer` into `s`, using the
/// encoding written by `serialize`, and returns the unread remainder of
/// `buffer`.
///
/// Array counts are read as a 4-byte `int`, the same width `serialize`
/// writes. The input may come straight off the network, so truncated data,
/// negative counts, and counts larger than the remaining bytes throw an
/// `Exception`. They are not allowed to read out of bounds or trigger a huge
/// allocation.
final public inout(ubyte)[] deserializeInto(T)(inout(ubyte)[] buffer, ref T s) {
	import std.traits;

	static if(isArray!T) {
		// must match serialize(), which writes cast(int) s.length;
		// reading a size_t here would consume 8 bytes on 64-bit targets
		int length;
		buffer = deserializeInto(buffer, length);
		// in practice every serialized element takes at least one byte, so a
		// count larger than what is left cannot be valid
		if(length < 0 || cast(size_t) length > buffer.length)
			throw new Exception("invalid array length in serialized data");

		static if(isStaticArray!T) {
			if(length != T.length)
				throw new Exception("array length mismatch in serialized data");
			foreach(ref e; s)
				buffer = deserializeInto(buffer, e);
		} else {
			// fill a mutable array and then assign it, so elements such as
			// the immutable(char) of a string are never written through
			alias E = Unqual!(typeof(T.init[0]));
			auto tmp = new E[](length);
			foreach(ref e; tmp)
				buffer = deserializeInto(buffer, e);
			s = cast(T) tmp;
		}
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(!hasIndirections!T) {
		// covers int, float, char, etc. most the builtins
		if(buffer.length < T.sizeof)
			throw new Exception("truncated serialized data");
		(cast(ubyte*)(&s))[0 .. T.sizeof] = buffer[0 .. T.sizeof];
		buffer = buffer[T.sizeof .. $];
	} else {
		// structs, classes, etc.
		foreach(i, t; s.tupleof) {
			buffer = deserializeInto(buffer, s.tupleof[i]);
		}
	}

	return buffer;
}
283
/// Generates the body of an RPC client proxy for `Interface`; mix it into a
/// class (see `makeNetworkClient`).
///
/// For every member of `Interface` a `final` method with the same name and
/// parameters is generated. It serializes the arguments and sends a request
/// frame: int length, int function number (1-based position in
/// `__traits(allMembers, Interface)`), int sequence number, then the arguments.
///
/// When `useAsync` is true each generated method returns `void` and takes two
/// extra parameters: `onSuccess` (given the return value, or nothing for void
/// members) and `onError` (given a `Throwable` rebuilt from the server's msg,
/// file and line). The callbacks run from `eventLoop`.
/// When `useAsync` is false each generated method blocks until its reply
/// arrives, then returns the value or throws the rebuilt `Throwable`.
mixin template NetworkClient(Interface, bool useAsync = true) {
	private static string createClass() {
		import std.conv;
		import std.traits;

		// this doesn't actually inherit from the interface because
		// the return value needs to be handled async
		string code;
		code ~= "\n\timport std.socket;";
		code ~= "\n\tprivate Socket socket;";
		if(useAsync) {
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onSuccesses;";
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onErrors;";
		}
		code ~= "\n\tprivate uint lastSequenceNumber;";
		// bytes received from the server that do not yet form a complete reply
		code ~= "\n\tprivate ubyte[] pending;";
		code ~= q{
	private this(string host, ushort port) {
		this.socket = new TcpSocket();
		this.socket.connect(new InternetAddress(host, port));
	}

	// Sends every byte of data, retrying after partial sends.
	private void sendAll(const(ubyte)[] data) {
		while(data.length) {
			auto sent = socket.send(data);
			if(sent <= 0)
				throw new Exception("send failed");
			data = data[sent .. $];
		}
	}

	// Blocks until one complete reply frame (a 12 byte header plus its
	// payload) is buffered, returns its fields, and keeps any extra bytes
	// for the next call. Returns false if the connection closes or fails first.
	private bool readFrame(out uint seq, out uint success, out const(ubyte)[] payload) {
		ubyte[4096] chunk;
		while(true) {
			if(pending.length >= 12) {
				uint length;
				deserializeInto(pending[0 .. 4], length);
				if(pending.length - 12 >= length) {
					deserializeInto(pending[4 .. 8], seq);
					deserializeInto(pending[8 .. 12], success);
					auto end = 12 + cast(size_t) length;
					payload = pending[12 .. end];
					pending = pending[end .. $];
					return true;
				}
			}
			auto got = socket.receive(chunk[]);
			if(got <= 0)
				return false;
			pending ~= chunk[0 .. got];
		}
	}
		};

		if(useAsync)
			code ~= q{
	// Dispatches each reply to the callback registered under its sequence
	// number until the connection closes.
	final public void eventLoop() {
		uint seq, success;
		const(ubyte)[] payload;
		while(readFrame(seq, success, payload)) {
			auto handler = success ? (seq in onSuccesses) : (seq in onErrors);
			if(handler !is null && *handler !is null)
				(*handler)(payload);
		}
	}
			};
		code ~= "\n\tpublic:\n";

		foreach(memIdx, member; __traits(allMembers, Interface)) {
			alias mem = PassThrough!(__traits(getMember, Interface, member));
			string type;
			if(useAsync)
				type = "void";
			else {
				static if(is(ReturnType!mem == void))
					type = "void";
				else
					type = (ReturnType!mem).stringof;
			}
			code ~= "\t\tfinal " ~ type ~ " " ~ member ~ "(";
			bool hadArgument = false;
			// arguments
			foreach(i, arg; ParameterTypeTuple!mem) {
				if(hadArgument)
					code ~= ", ";
				// FIXME: this is one place the arg can get unknown if we don't have all the imports
				code ~= arg.stringof ~ " arg" ~ to!string(i);
				hadArgument = true;
			}

			if(useAsync) {
				if(hadArgument)
					code ~= ", ";

				static if(is(ReturnType!mem == void))
					code ~= "void delegate() onSuccess";
				else
					code ~= "void delegate(" ~ (ReturnType!mem).stringof ~ ") onSuccess";
				code ~= ", ";
				code ~= "void delegate(Throwable) onError";
			}
			code ~= ") {\n";
			code ~= "auto seq = ++lastSequenceNumber;";
			if(useAsync)
				code ~= q{
			onSuccesses[seq] = (const(ubyte)[] buffer) {
				onSuccesses.remove(seq);
				onErrors.remove(seq);

				import std.traits;

				static if(is(ParameterTypeTuple!(typeof(onSuccess)) == void)) {
					if(onSuccess !is null)
						onSuccess();
				} else {
					ParameterTypeTuple!(typeof(onSuccess)) args;
					foreach(i, arg; args)
						buffer = deserializeInto(buffer, args[i]);
					if(onSuccess !is null)
						onSuccess(args);
				}
			};
			onErrors[seq] = (const(ubyte)[] buffer) {
				onSuccesses.remove(seq);
				onErrors.remove(seq);
				auto t = new Throwable("");
				buffer = deserializeInto(buffer, t.msg);
				buffer = deserializeInto(buffer, t.file);
				buffer = deserializeInto(buffer, t.line);

				if(onError !is null)
					onError(t);
			};
				};

			code ~= q{
			ubyte[4096] bufferBase;
			auto buffer = bufferBase[12 .. $]; // leaving room for size, func number, and seq number
			ubyte[] serialized;
			int used;
			};
			// preparing the request
			foreach(i, arg; ParameterTypeTuple!mem) {
				code ~= "\t\t\tserialized = serialize(buffer, arg" ~ to!string(i) ~ ");\n";
				code ~= "\t\t\tused += serialized.length;\n";
				code ~= "\t\t\tbuffer = buffer[serialized.length .. $];\n";
			}

			code ~= "\t\t\tserialize(bufferBase[0 .. 4], used);\n";
			code ~= "\t\t\tserialize(bufferBase[4 .. 8], " ~ to!string(memIdx + 1) ~ ");\n";
			code ~= "\t\t\tserialize(bufferBase[8 .. 12], seq);\n";
			code ~= "\t\t\tsendAll(bufferBase[0 .. 12 + used]);\n";
			//code ~= `writeln("sending ", bufferBase[0 .. 12 + used]);`;

			if(!useAsync)
				code ~= q{
			// Replies arrive in request order; skip any stale ones left over
			// from earlier calls. NOTE: this blocks until the server replies
			// to this sequence number, including for void members.
			uint replySeq, success;
			const(ubyte)[] payload;
			do {
				if(!readFrame(replySeq, success, payload))
					throw new Exception("connection closed");
			} while(replySeq != seq);

			if(!success) {
				auto t = new Throwable("");
				payload = deserializeInto(payload, t.msg);
				payload = deserializeInto(payload, t.file);
				payload = deserializeInto(payload, t.line);
				throw t;
			}

			static if(!is(typeof(return) == void)) {
				typeof(return) returned;
				deserializeInto(payload, returned);
				return returned;
			}
				};

			code ~= "}\n";
			code ~= "\n";
		}
		return code;
	}

	//pragma(msg, createClass()); // for debugging help
	mixin(createClass());
}
523
/// Connects to `host`:`port` and returns a client proxy for `Interface`.
///
/// The proxy is an instance of a class local to this function that mixes in
/// `NetworkClient!(Interface, useAsync)`; see that template for the shape of
/// the generated methods.
auto makeNetworkClient(Interface, bool useAsync = true)(string host, ushort port) {
	class Thing {
		mixin NetworkClient!(Interface, useAsync);
	}

	auto proxy = new Thing(host, port);
	return proxy;
}
531
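// the protocol is:
/*

PLANNED DESIGN (not implemented yet; the current wire format is described below):

client connects
ulong interface hash

handshake complete

messages:

uint messageLength
uint sequence number
ushort function number, 0 is reserved for interface check
serialized arguments....



server responds with answers:

uint messageLength
uint re: sequence number
ubyte, 1 == success, 0 == error
serialized return value


CURRENT WIRE FORMAT (as implemented by NetworkServer and NetworkClient):

There is no handshake. Every header field is a 4-byte int in host byte order
(written by serialize), and messageLength counts only the bytes after the
12-byte header.

request:
int messageLength
int function number (1-based position in __traits(allMembers, Interface))
int sequence number
serialized arguments, in parameter order

response:
int messageLength
int sequence number of the request being answered
int success: 1 == success, 0 == error
serialized return value on success; on error, the exception's msg, file and line

*/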