1 module dls.server;
2 
3 import dls.protocol.handlers;
4 import dls.protocol.jsonrpc;
5 
/++
Module constructor: registers every protocol handler declared in the
`dls.protocol.messages.*` modules listed below.

For each member `t` of those modules for which `isHandler!t` holds, a
JSON-RPC method name is built at compile time and passed to `pushHandler`
together with the address of `t`:

$(UL
    $(LI no attribute on `t`: `<camelCasedModuleName>/<symbolName>`)
    $(LI one attribute: `<attribute 0>/<symbolName>`)
    $(LI two or more attributes: `<attribute 0>/<attribute 1>`)
)

When the first part is an empty string, the `/` separator is omitted and
the method name is the second part alone.
+/
shared static this()
{
    import std.algorithm : map;
    import std.array : join, split;
    import std.meta : AliasSeq;
    import std.traits : select;
    import std.typecons : tuple;
    import std.string : capitalize;

    foreach (modName; AliasSeq!("general", "client", "text_document", "window", "workspace"))
    {
        mixin("import dls.protocol.messages" ~ (modName.length ? "." ~ modName : "") ~ ";");
        mixin("alias mod = dls.protocol.messages" ~ (modName.length ? "." ~ modName : "") ~ ";");

        foreach (thing; __traits(allMembers, mod))
        {
            mixin("alias t = " ~ thing ~ ";");

            static if (isHandler!t)
            {
                enum attrs = tuple(__traits(getAttributes, t));
                // Defaults come first (snake_case module name turned into
                // camelCase, then the symbol name), followed by whatever
                // attributes the handler carries.
                enum attrsWithDefaults = tuple(modName[0] ~ modName.split('_')
                            .map!capitalize().join()[1 .. $], thing, attrs.expand);
                // Pick the attribute when present, the default otherwise.
                enum parts = tuple(attrsWithDefaults[attrs.length > 0 ? 2 : 0],
                            attrsWithDefaults[attrs.length > 1 ? 3 : 1]);
                enum method = select!(parts[0].length != 0)(parts[0] ~ "/", "") ~ parts[1];

                pushHandler(method, &t);
            }
        }
    }
}
38 
/++
Static entry point of the language server: reads framed JSON-RPC messages
from standard input, dispatches them to the registered handlers and sends
requests, notifications and responses back through `dls.protocol.jsonrpc`.

The class only has static members and is declared `abstract`, so it is never
instantiated.
+/
abstract class Server
{

    import logger = std.experimental.logger;
    import dls.protocol.interfaces : InitializeParams;
    import std.algorithm : find, findSplit;
    import std.json : JSONValue;
    import std.typecons : Nullable;
    import std.string : strip, stripRight;

    private static bool _initialized = false;
    private static bool _shutdown = false;
    private static bool _exit = false;
    private static InitializeParams _initState;

    /++ Returns the parameters stored by the `initState` setter. +/
    @property static auto initState()
    {
        return _initState;
    }

    /++
    Stores the initialization parameters and configures the global log level.

    In debug builds every message is logged. Otherwise the level is derived
    from `params.trace`, and logging is turned off when `trace` is null.
    +/
    @property static void initState(InitializeParams params)
    {
        _initState = params;

        debug
        {
            logger.globalLogLevel = logger.LogLevel.all;
        }
        else
        {
            //dfmt off
            immutable map = [
                InitializeParams.Trace.off : logger.LogLevel.off,
                InitializeParams.Trace.messages : logger.LogLevel.info,
                InitializeParams.Trace.verbose : logger.LogLevel.all
            ];
            //dfmt on
            logger.globalLogLevel = params.trace.isNull ? logger.LogLevel.off : map[params.trace];
        }
    }

    /++
    Generic setter for the private static fields: `Server.name = arg` assigns
    `arg` to the field called `_name` (for instance `_initialized`,
    `_shutdown` or `_exit`).
    +/
    @property static void opDispatch(string name, T)(T arg)
    {
        mixin("_" ~ name ~ " = arg;");
    }

    /++
    Main loop: reads messages from standard input until end of file is
    reached or the `_exit` flag is set.

    Each message is a block of `Name: value` header lines terminated by an
    empty line, followed by a body of `Content-Length` characters that is
    handed to `handleJSON`. A message whose `Content-Length` header is missing
    or is not a valid unsigned integer is logged and skipped.
    +/
    static void loop()
    {
        import std.conv : ConvException, to;
        import std.stdio : stdin;

        while (!stdin.eof && !_exit)
        {
            string[][] headers;
            string line;

            // Collect "name: value" pairs until the empty line that ends the header.
            do
            {
                line = stdin.readln().stripRight();
                auto parts = line.findSplit(":");

                // parts[1] is the separator; it is empty when no ':' was found.
                if (parts[1].length)
                {
                    headers ~= [parts[0], parts[2]];
                }
            }
            while (line.length);

            if (headers.length == 0)
            {
                continue;
            }

            auto contentLengthResult = headers.find!((parts,
                    name) => parts.length && parts[0] == name)("Content-Length");

            if (contentLengthResult.length == 0)
            {
                logger.error("No valid Content-Length section in header");
                continue;
            }

            size_t contentLength;

            try
            {
                contentLength = contentLengthResult[0][1].strip().to!size_t;
            }
            catch (ConvException e)
            {
                // A malformed value must not take the whole server down.
                logger.error("Invalid Content-Length value in header");
                continue;
            }

            immutable content = stdin.rawRead(new char[contentLength]).idup;
            // TODO: support UTF-16/32 according to Content-Type when it's supported

            handleJSON(content);
        }
    }

    /++
    Parses one JSON message and dispatches it.

    $(UL
        $(LI A message with both `method` and `id` is a request: its handler
            is called and the result sent back, unless the server has been
            shut down or is not initialized yet (only `initialize` is accepted
            before initialization), in which case a `serverNotInitialized`
            error is returned.)
        $(LI A message with `method` but no `id` is a notification: its
            handler is called only once the server is initialized.)
        $(LI Anything else is treated as a response: its handler is looked up
            by id, or the error message is logged when the response carries
            an error.)
    )

    Error responses are only sent when a request was successfully decoded,
    since an id is needed to answer.
    +/
    private static void handleJSON(T)(immutable(T[]) content)
    {
        import dls.util.json : convertFromJSON;
        import std.algorithm : canFind;
        import std.json : JSONException, parseJSON;

        RequestMessage request;

        try
        {
            immutable json = parseJSON(content);

            if ("method" in json)
            {
                if ("id" in json)
                {
                    request = convertFromJSON!RequestMessage(json);

                    if (!_shutdown && (_initialized || ["initialize"].canFind(request.method)))
                    {
                        send(request.id, handler!RequestHandler(request.method)(request.params));
                    }
                    else
                    {
                        sendError(ErrorCodes.serverNotInitialized, request);
                    }
                }
                else
                {
                    auto notification = convertFromJSON!NotificationMessage(json);

                    if (_initialized)
                    {
                        handler!NotificationHandler(notification.method)(notification.params);
                    }
                }
            }
            else
            {
                auto response = convertFromJSON!ResponseMessage(json);

                if (response.error.isNull)
                {
                    handler!ResponseHandler(response.id.str)(response.id.str, response.result);
                }
                else
                {
                    logger.error(response.error.message);
                }
            }
        }
        catch (JSONException e)
        {
            sendError(ErrorCodes.parseError, request);
        }
        catch (HandlerNotFoundException e)
        {
            sendError(ErrorCodes.methodNotFound, request);
        }
        catch (MessageException e)
        {
            // `request` is still null when the exception comes from a
            // notification or response handler: there is no id to answer to,
            // so the error can only be logged.
            if (request !is null)
            {
                send(request.id, Nullable!JSONValue(), ResponseError.fromException(e));
            }
            else
            {
                logger.error(e.msg);
            }
        }
    }

    /++
    Sends an error response built from `error` for `request`.
    Does nothing when `request` is null.
    +/
    private static void sendError(ErrorCodes error, RequestMessage request)
    {
        if (request !is null)
        {
            send(request.id, Nullable!JSONValue(), ResponseError.fromErrorCode(error));
        }
    }

    /++
    Sends a request or a notification message.

    When a handler is registered for `method`, the message is sent as a
    request with a freshly generated id (`"dls-"` followed by a random UUID),
    and that id is registered with `pushHandler` before sending. Otherwise
    the message is sent as a notification.

    Returns: the id of the request, or `null` when a notification was sent.
    +/
    static auto send(string method, Nullable!JSONValue params)
    {
        import dls.protocol.handlers : hasRegisteredHandler, pushHandler;
        import std.uuid : randomUUID;

        if (hasRegisteredHandler(method))
        {
            auto id = "dls-" ~ randomUUID().toString();
            pushHandler(id, method);
            send!RequestMessage(JSONValue(id), method, params, Nullable!ResponseError());
            return id;
        }

        send!NotificationMessage(JSONValue(), method, params, Nullable!ResponseError());
        return null;
    }

    /++
    Convenience overload: converts `params` to JSON with `convertToJSON`
    and forwards to the `Nullable!JSONValue` overload.
    +/
    static auto send(T)(string method, T params) if (!is(T : Nullable!JSONValue))
    {
        import dls.util.json : convertToJSON;
        import std.typecons : nullable;

        return send(method, convertToJSON(params).nullable);
    }

    /++ Sends a response message. +/
    private static void send(JSONValue id, Nullable!JSONValue result,
            Nullable!ResponseError error = Nullable!ResponseError())
    {
        send!ResponseMessage(id, null, result, error);
    }

    /++
    Builds a message of type `T`, fills in the fields that `T` declares and
    sends it through `dls.protocol.jsonrpc.send`.

    `payload` is stored in `params` when `T` has such a member, in `result`
    otherwise. `id`, `method` and `error` are only assigned when `T` has a
    member of the same name.
    +/
    private static void send(T : Message)(JSONValue id, string method,
            Nullable!JSONValue payload, Nullable!ResponseError error)
    {
        // This local import makes the unqualified `send(message)` call below
        // resolve to the jsonrpc function rather than to the overloads of this class.
        import dls.protocol.jsonrpc : send;
        import std.meta : AliasSeq;
        import std.traits : select;

        auto message = new T();

        __traits(getMember, message, select!(__traits(hasMember, T,
                "params"))("params", "result")) = payload;

        foreach (member; AliasSeq!("id", "method", "error"))
        {
            static if (__traits(hasMember, T, member))
            {
                mixin("message." ~ member ~ " = " ~ member ~ ";");
            }
        }

        send(message);
    }
}