module dls.server;

import dls.protocol.handlers;
import dls.protocol.jsonrpc;

/++
Registers every handler declared in the `dls.protocol.messages` submodules
listed below.

For each member `t` of those modules for which `isHandler!t` holds, a method
name is built from up to two UDAs attached to `t` (expected to be strings):
$(UL
    $(LI the first UDA, if present, replaces the default prefix, which is the
        module name converted to camelCase (e.g. `text_document` becomes
        `textDocument`);)
    $(LI the second UDA, if present, replaces the default suffix, which is the
        member's own name.)
)
The method name is `prefix ~ "/" ~ suffix`, or just `suffix` when the prefix
is empty. The handler is then registered with `pushHandler(method, &t)`.
+/
shared static this()
{
    import std.algorithm : map;
    import std.array : join, split;
    import std.meta : AliasSeq;
    import std.traits : hasUDA, select;
    import std.typecons : tuple;
    import std.string : capitalize;

    foreach (modName; AliasSeq!("general", "client", "text_document", "window", "workspace"))
    {
        mixin("import dls.protocol.messages" ~ (modName.length ? "." ~ modName : "") ~ ";");
        mixin("alias mod = dls.protocol.messages" ~ (modName.length ? "." ~ modName : "") ~ ";");

        foreach (thing; __traits(allMembers, mod))
        {
            mixin("alias t = " ~ thing ~ ";");

            static if (isHandler!t)
            {
                // UDAs attached to the handler (optional prefix, optional name)
                enum attrs = tuple(__traits(getAttributes, t));
                // Defaults first, then the UDAs: (camelCaseModuleName, memberName, attrs...)
                enum attrsWithDefaults = tuple(modName[0] ~ modName.split('_')
                        .map!capitalize().join()[1 .. $], thing, attrs.expand);
                // Take each UDA when it exists, otherwise the matching default
                enum parts = tuple(attrsWithDefaults[attrs.length > 0 ? 2 : 0],
                        attrsWithDefaults[attrs.length > 1 ? 3 : 1]);
                // An empty prefix yields a bare method name without a '/'
                enum method = select!(parts[0].length != 0)(parts[0] ~ "/", "") ~ parts[1];

                pushHandler(method, &t);
            }
        }
    }
}

/++
Server core: reads messages from standard input, dispatches them to the
registered handlers, and emits responses, requests and notifications through
`dls.protocol.jsonrpc.send`.

The class is `abstract` and exposes only static members and state.
+/
abstract class Server
{

    import logger = std.experimental.logger;
    import dls.protocol.interfaces : InitializeParams;
    import std.algorithm : find, findSplit;
    import std.json : JSONValue;
    import std.typecons : Nullable;
    import std.string : strip, stripRight;

    private static bool _initialized = false;
    private static bool _shutdown = false;
    private static bool _exit = false;
    private static InitializeParams _initState;

    /// Returns: the parameters stored by the `initState` setter.
    @property static auto initState()
    {
        return _initState;
    }

    /++
    Stores the initialization parameters and adjusts the global log level.

    In debug builds everything is logged. Otherwise the level follows
    `params.trace`: `off` -> `off`, `messages` -> `info`, `verbose` -> `all`,
    and `off` when `trace` is null.
    +/
    @property static void initState(InitializeParams params)
    {
        _initState = params;

        debug
        {
            logger.globalLogLevel = logger.LogLevel.all;
        }
        else
        {
            //dfmt off
            immutable map = [
                InitializeParams.Trace.off : logger.LogLevel.off,
                InitializeParams.Trace.messages : logger.LogLevel.info,
                InitializeParams.Trace.verbose : logger.LogLevel.all
            ];
            //dfmt on
            logger.globalLogLevel = params.trace.isNull ? logger.LogLevel.off : map[params.trace];
        }
    }

    /++
    Generic static setter: `Server.foo = value` assigns `value` to the private
    field `_foo` (for example `Server.initialized = true`).
    +/
    @property static void opDispatch(string name, T)(T arg)
    {
        mixin("_" ~ name ~ " = arg;");
    }

    /++
    Main message loop.

    Each iteration reads a header block, which consists of lines containing a
    ':' and ends at an empty line. It then reads a body of exactly
    `Content-Length` bytes and passes it to `handleJSON`. The loop runs until
    standard input reaches EOF or `_exit` is set.

    NOTE(review): the `Content-Length` name is matched case-sensitively.
    A non-numeric value makes `to!size_t` throw a `ConvException`, which is
    not caught here.
    +/
    static void loop()
    {
        import std.conv : to;
        import std.stdio : stdin;

        while (!stdin.eof && !_exit)
        {
            string[][] headers;
            string line;

            do
            {
                line = stdin.readln().stripRight();
                auto parts = line.findSplit(":");

                // Only lines that contain a ':' are kept as [name, value] pairs
                if (parts[1].length)
                {
                    headers ~= [parts[0], parts[2]];
                }
            }
            while (line.length);

            if (headers.length == 0)
            {
                continue;
            }

            auto contentLengthResult = headers.find!((parts,
                    name) => parts.length && parts[0] == name)("Content-Length");

            if (contentLengthResult.length == 0)
            {
                logger.error("No valid Content-Length section in header");
                continue;
            }

            immutable contentLength = contentLengthResult[0][1].strip().to!size_t;
            immutable content = stdin.rawRead(new char[contentLength]).idup;
            // TODO: support UTF-16/32 according to Content-Type when it's supported

            handleJSON(content);
        }
    }

    /++
    Parses one JSON-RPC message and dispatches it.

    $(UL
        $(LI A message with both `method` and `id` is a request. If the server
            is not shut down and is either initialized or the method is
            `initialize`, the request is answered with its handler's result.
            Otherwise it is answered with `serverNotInitialized`.)
        $(LI A message with `method` but no `id` is a notification. It is
            dispatched once the server is initialized. Before that, only `exit`
            is dispatched, because the LSP lifecycle must let a client stop a
            server it never initialized.)
        $(LI Any other message is a response to a request sent by the server.
            It is routed to the handler registered under its (string) id, or
            its error message is logged.)
    )

    Content that cannot be parsed as JSON is answered with a `parseError`
    whose id is `null`, as JSON-RPC 2.0 requires when no id can be read.
    Other failures are reported through `sendError`, which replies only when
    a request was decoded.
    +/
    private static void handleJSON(T)(immutable(T[]) content)
    {
        import dls.util.json : convertFromJSON;
        import std.algorithm : canFind;
        import std.json : JSONException, parseJSON;

        RequestMessage request;
        // Separates "content is not JSON at all" from later JSON failures
        bool parsed = false;

        try
        {
            immutable json = parseJSON(content);
            parsed = true;

            if ("method" in json)
            {
                if ("id" in json)
                {
                    request = convertFromJSON!RequestMessage(json);

                    if (!_shutdown && (_initialized || ["initialize"].canFind(request.method)))
                    {
                        send(request.id, handler!RequestHandler(request.method)(request.params));
                    }
                    else
                    {
                        sendError(ErrorCodes.serverNotInitialized, request);
                    }
                }
                else
                {
                    auto notification = convertFromJSON!NotificationMessage(json);

                    // Uninitialized servers drop notifications, except `exit`
                    if (_initialized || notification.method == "exit")
                    {
                        handler!NotificationHandler(notification.method)(notification.params);
                    }
                }
            }
            else
            {
                auto response = convertFromJSON!ResponseMessage(json);

                if (response.error.isNull)
                {
                    handler!ResponseHandler(response.id.str)(response.id.str, response.result);
                }
                else
                {
                    logger.error(response.error.message);
                }
            }
        }
        catch (JSONException e)
        {
            if (parsed)
            {
                sendError(ErrorCodes.parseError, request);
            }
            else
            {
                // No id could be extracted: JSON-RPC 2.0 mandates a null id here
                send(JSONValue(null), Nullable!JSONValue(),
                        ResponseError.fromErrorCode(ErrorCodes.parseError));
            }
        }
        catch (HandlerNotFoundException e)
        {
            sendError(ErrorCodes.methodNotFound, request);
        }
        catch (MessageException e)
        {
            send(request.id, Nullable!JSONValue(), ResponseError.fromException(e));
        }
    }

    /++
    Replies to `request` with an error response built from `error`.
    Does nothing when `request` is null, for example when the failing message
    was a notification or a response.
    +/
    private static void sendError(ErrorCodes error, RequestMessage request)
    {
        if (request !is null)
        {
            send(request.id, Nullable!JSONValue(), ResponseError.fromErrorCode(error));
        }
    }

    /++
    Sends a request or a notification message.

    When `hasRegisteredHandler(method)` is true, the message is sent as a
    request with a fresh `"dls-" ~ UUID` id. The id is associated with
    `method` through `pushHandler(id, method)`, so that `handleJSON` can later
    route the client's response by id. Otherwise the message is sent as a
    notification.

    Returns: the generated request id, or `null` for a notification.
    +/
    static auto send(string method, Nullable!JSONValue params)
    {
        import dls.protocol.handlers : hasRegisteredHandler, pushHandler;
        import std.uuid : randomUUID;

        if (hasRegisteredHandler(method))
        {
            auto id = "dls-" ~ randomUUID().toString();
            pushHandler(id, method);
            send!RequestMessage(JSONValue(id), method, params, Nullable!ResponseError());
            return id;
        }

        send!NotificationMessage(JSONValue(), method, params, Nullable!ResponseError());
        return null;
    }

    /++
    Convenience overload: converts `params` with `convertToJSON` and forwards
    to `send(string, Nullable!JSONValue)`.
    +/
    static auto send(T)(string method, T params) if (!is(T : Nullable!JSONValue))
    {
        import dls.util.json : convertToJSON;
        import std.typecons : nullable;

        return send(method, convertToJSON(params).nullable);
    }

    /++ Sends a response message. +/
    private static void send(JSONValue id, Nullable!JSONValue result,
            Nullable!ResponseError error = Nullable!ResponseError())
    {
        send!ResponseMessage(id, null, result, error);
    }

    /++
    Builds a message of type `T` and writes it with
    `dls.protocol.jsonrpc.send`.

    `payload` goes into `params` when `T` has that member and into `result`
    otherwise. `id`, `method` and `error` are copied only into the members
    that `T` actually declares.
    +/
    private static void send(T : Message)(JSONValue id, string method,
            Nullable!JSONValue payload, Nullable!ResponseError error)
    {
        // The local import makes the call below resolve to the protocol-level send
        import dls.protocol.jsonrpc : send;
        import std.meta : AliasSeq;
        import std.traits : select;

        auto message = new T();

        __traits(getMember, message, select!(__traits(hasMember, T,
                "params"))("params", "result")) = payload;

        foreach (member; AliasSeq!("id", "method", "error"))
        {
            static if (__traits(hasMember, T, member))
            {
                mixin("message." ~ member ~ " = " ~ member ~ ";");
            }
        }

        send(message);
    }
}