1 module dls.server;
2 
3 import dls.protocol.handlers;
4 import dls.protocol.jsonrpc;
5 
/++
Module constructor that registers message handlers with `pushHandler`.

For each module `dls.protocol.messages.<modName>` listed below, every member
for which `isHandler` holds is registered under a method name built from its
attributes (UDAs):
$(UL
    $(LI prefix: the first attribute if there is one, otherwise the module name
        converted to camelCase (e.g. `text_document` -> `textDocument`);)
    $(LI name: the second attribute if there is one, otherwise the member's own
        identifier.)
)
The registered method is `prefix ~ "/" ~ name`, or just `name` when the prefix
is empty. Since `.length` and `~ "/"` are applied to these attributes, they are
presumably strings.
+/
shared static this()
{
    import std.algorithm : map;
    import std.array : join, split;
    import std.meta : AliasSeq;
    import std.traits : hasUDA, select;
    import std.typecons : tuple;
    import std.string : capitalize;

    // Compile-time loop: the body is instantiated once per module name.
    foreach (modName; AliasSeq!("general", "client", "text_document", "window", "workspace"))
    {
        // Import the module and bind it to `mod` so its members can be enumerated.
        mixin("import dls.protocol.messages" ~ (modName.length ? "." ~ modName : "") ~ ";");
        mixin("alias mod = dls.protocol.messages" ~ (modName.length ? "." ~ modName : "") ~ ";");

        foreach (thing; __traits(allMembers, mod))
        {
            // `t` aliases the member symbol whose name is the string `thing`.
            mixin("alias t = " ~ thing ~ ";");

            static if (isHandler!t)
            {
                enum attrs = tuple(__traits(getAttributes, t));
                // Layout: [0] camelCased module name, [1] member name, then the
                // member's own attributes starting at index 2.
                enum attrsWithDefaults = tuple(modName[0] ~ modName.split('_')
                            .map!capitalize().join()[1 .. $], thing, attrs.expand);
                // parts[0] = prefix (first attribute, else the default at [0]);
                // parts[1] = name (second attribute, else the default at [1]).
                enum parts = tuple(attrsWithDefaults[attrs.length > 0 ? 2 : 0],
                            attrsWithDefaults[attrs.length > 1 ? 3 : 1]);
                // An empty prefix yields just the name, without a leading "/".
                enum method = select!(parts[0].length != 0)(parts[0] ~ "/", "") ~ parts[1];

                pushHandler(method, &t);
            }
        }
    }
}
38 
/++
Static language-server core: reads framed JSON-RPC messages from standard
input, dispatches them to registered handlers and sends messages back.
+/
abstract class Server
{

    import logger = std.experimental.logger;
    import dls.protocol.interfaces : InitializeParams;
    import std.algorithm : find, findSplit;
    import std.json : JSONValue;
    import std.typecons : Nullable;
    import std.string : strip, stripRight;

    // Lifecycle flags; written through `opDispatch` (e.g. `Server.initialized = true`).
    private static bool _initialized = false;
    private static bool _shutdown = false;
    private static bool _exit = false;
    // Parameters stored by the `initState` setter.
    private static InitializeParams _initState;

    /++ Returns the parameters stored by the `initState` setter. +/
    @property static auto initState()
    {
        return _initState;
    }

    /++
    Stores `params` and sets the global log level.

    Debug builds log everything. Otherwise the level follows `params.trace`:
    `off` -> off, `messages` -> info, `verbose` -> all, and off when the
    trace value is null.
    +/
    @property static void initState(InitializeParams params)
    {
        _initState = params;

        debug
        {
            logger.globalLogLevel = logger.LogLevel.all;
        }
        else
        {
            //dfmt off
            immutable map = [
                InitializeParams.Trace.off : logger.LogLevel.off,
                InitializeParams.Trace.messages : logger.LogLevel.info,
                InitializeParams.Trace.verbose : logger.LogLevel.all
            ];
            //dfmt on
            logger.globalLogLevel = params.trace.isNull ? logger.LogLevel.off : map[params.trace];
        }
    }

    /++
    Generic static setter: `Server.foo = value` assigns the private field `_foo`.
    +/
    @property static void opDispatch(string name, T)(T arg)
    {
        mixin("_" ~ name ~ " = arg;");
    }

    /++
    Reads messages from standard input until end-of-file or until `_exit` is set.

    Each message is a block of `Name: value` header lines terminated by an
    empty line, followed by `Content-Length` characters of content that are
    passed to `handleJSON`. Messages with a missing or malformed
    `Content-Length` are logged and skipped instead of stopping the loop.
    +/
    static void loop()
    {
        import std.conv : ConvException, to;
        import std.stdio : stdin;

        while (!stdin.eof && !_exit)
        {
            string[][] headers;
            string line;

            do
            {
                line = stdin.readln().stripRight();
                auto parts = line.findSplit(":");

                // Only lines containing a ':' are kept, as [name, value] pairs.
                if (parts[1].length)
                {
                    headers ~= [parts[0], parts[2]];
                }
            }
            while (line.length);

            if (headers.length == 0)
            {
                continue;
            }

            auto contentLengthResult = headers.find!((parts,
                    name) => parts.length && parts[0] == name)("Content-Length");

            if (contentLengthResult.length == 0)
            {
                logger.error("No valid Content-Length section in header");
                continue;
            }

            size_t contentLength;

            try
            {
                contentLength = contentLengthResult[0][1].strip().to!size_t;
            }
            catch (ConvException e)
            {
                // A non-numeric or negative value must not tear down the whole loop.
                logger.error("Invalid Content-Length value in header: " ~ contentLengthResult[0][1]);
                continue;
            }

            // File.rawRead rejects an empty buffer, so an empty body is not read at all.
            immutable content = contentLength == 0 ? ""
                : stdin.rawRead(new char[contentLength]).idup;
            // TODO: support UTF-16/32 according to Content-Type when it's supported

            handleJSON(content);
        }
    }

    /++
    Parses `content` as JSON and dispatches it.

    $(UL
        $(LI Objects with `method` and `id` are requests. Before initialization
            (or after shutdown), only `initialize`/`exit` are accepted; others get
            a `serverNotInitialized` error.)
        $(LI Objects with `method` but no `id` are notifications. They are dropped
            before initialization, except `exit`.)
        $(LI Anything else is treated as a response to one of our requests.)
    )
    Error replies are only sent when the failing message was a request.
    +/
    private static void handleJSON(T)(immutable(T[]) content)
    {
        import dls.util.json : convertFromJSON;
        import std.algorithm : canFind;
        import std.json : JSONException, parseJSON;
        import std.typecons : nullable;

        // Stays null unless the message is a request; used to decide whether an
        // error can be answered.
        RequestMessage request;

        try
        {
            immutable json = parseJSON(content);

            if ("method" in json)
            {
                if ("id" in json)
                {
                    request = convertFromJSON!RequestMessage(json);

                    if (!_shutdown && (_initialized || ["initialize",
                            "exit"].canFind(request.method)))
                    {
                        send(request.id, handler!RequestHandler(request.method)(request.params));
                    }
                    else
                    {
                        sendError(ErrorCodes.serverNotInitialized, request);
                    }
                }
                else
                {
                    auto notification = convertFromJSON!NotificationMessage(json);

                    // `exit` is a notification and must be honored even before initialization.
                    if (_initialized || notification.method == "exit")
                    {
                        handler!NotificationHandler(notification.method)(notification.params);
                    }
                }
            }
            else
            {
                auto response = convertFromJSON!ResponseMessage(json);

                if (response.error.isNull)
                {
                    handler!ResponseHandler(response.id.str)(response.id.str, response.result);
                }
                else
                {
                    logger.error(response.error.message);
                }
            }
        }
        catch (JSONException e)
        {
            sendError(ErrorCodes.parseError, request);
        }
        catch (HandlerNotFoundException e)
        {
            sendError(ErrorCodes.methodNotFound, request);
        }
        catch (MessageException e)
        {
            if (request is null)
            {
                // Thrown by a notification/response handler: there is no id to
                // answer, and dereferencing `request` would crash the server.
                logger.error(e.msg);
            }
            else
            {
                send(request.id, Nullable!JSONValue(), ResponseError.fromException(e));
            }
        }
    }

    /++ Sends an error response for `request`; does nothing when `request` is null. +/
    private static void sendError(ErrorCodes error, RequestMessage request)
    {
        if (request !is null)
        {
            send(request.id, Nullable!JSONValue(), ResponseError.fromErrorCode(error));
        }
    }

    /++
    Sends a request or a notification message.

    When `hasRegisteredHandler(method)` is true, a request with a fresh
    `"dls-<uuid>"` id is sent, the id is associated with `method` via
    `pushHandler`, and the id is returned. Otherwise a notification is sent
    and `null` is returned.
    +/
    static auto send(string method, Nullable!JSONValue params)
    {
        import dls.protocol.handlers : hasRegisteredHandler, pushHandler;
        import std.uuid : randomUUID;

        if (hasRegisteredHandler(method))
        {
            auto id = "dls-" ~ randomUUID().toString();
            pushHandler(id, method);
            send!RequestMessage(JSONValue(id), method, params, Nullable!ResponseError());
            return id;
        }

        send!NotificationMessage(JSONValue(), method, params, Nullable!ResponseError());
        return null;
    }

    /++ Converts `params` to JSON and forwards to `send(method, Nullable!JSONValue)`. +/
    static auto send(T)(string method, T params) if (!is(T : Nullable!JSONValue))
    {
        import dls.util.json : convertToJSON;
        // Needed for the UFCS `.nullable` call below.
        import std.typecons : nullable;

        return send(method, convertToJSON(params).nullable);
    }

    /++ Sends a response message. +/
    private static void send(JSONValue id, Nullable!JSONValue result,
            Nullable!ResponseError error = Nullable!ResponseError())
    {
        send!ResponseMessage(id, null, result, error);
    }

    /++
    Builds a `T` message and hands it to `dls.protocol.jsonrpc.send`.

    `payload` goes into `params` when `T` has that member, otherwise into
    `result`. `id`, `method` and `error` are copied only into the members
    that `T` actually declares.
    +/
    private static void send(T : Message)(JSONValue id, string method,
            Nullable!JSONValue payload, Nullable!ResponseError error)
    {
        import dls.protocol.jsonrpc : send;
        import std.meta : AliasSeq;
        import std.traits : select;

        auto message = new T();

        __traits(getMember, message, select!(__traits(hasMember, T,
                "params"))("params", "result")) = payload;

        foreach (member; AliasSeq!("id", "method", "error"))
        {
            static if (__traits(hasMember, T, member))
            {
                mixin("message." ~ member ~ " = " ~ member ~ ";");
            }
        }

        // Resolves to the local import above, not to this class's overloads.
        send(message);
    }
}