PostgreSQL Source Code git master
Loading...
Searching...
No Matches
oauth_server.py
Go to the documentation of this file.
1#! /usr/bin/env python3
2#
3# A mock OAuth authorization server, designed to be invoked from
4# OAuth/Server.pm. This listens on an ephemeral port number (printed to stdout
5# so that the Perl tests can contact it) and runs as a daemon until it is
6# signaled.
7#
8
9import base64
10import functools
11import http.server
12import json
13import os
14import sys
15import time
16import urllib.parse
17from collections import defaultdict
18from typing import Dict
19
20
class OAuthHandler(http.server.BaseHTTPRequestHandler):
    """
    Core implementation of the authorization server. The API is
    inheritance-based, with entry points at do_GET() and do_POST(). See the
    documentation for BaseHTTPRequestHandler.

    main() hands this class to http.server.HTTPServer. State that has to
    survive across token requests is kept on the server object (see
    _token_state), not on the handler.
    """

    # Type alias for the JSON objects this handler builds, logs, and sends.
    JsonObject = Dict[str, object]  # TypeAlias is not available until 3.10
29
def _check_issuer(self):
    """
    Determines which issuer the request is addressed to and normalizes the
    request path accordingly.

    Sets self._alt_issuer (the /alternate issuer) and self._parameterized (the
    /param issuer), then strips the issuer-specific segment from self.path so
    that routing can compare against the bare endpoint paths.
    """
    path = self.path
    alternate_discovery = "/.well-known/oauth-authorization-server/alternate"

    self._alt_issuer = path.startswith("/alternate/") or path == alternate_discovery
    self._parameterized = path.startswith("/param/")

    # Plain slicing stands in for str.removeprefix()/removesuffix(), which
    # are not available until Py3.9.
    if self._alt_issuer and path.startswith("/.well-known/"):
        # IETF-style .well-known URIs carry the issuer path as a suffix.
        self.path = path[: -len("/alternate")]
    elif self._alt_issuer:
        self.path = path[len("/alternate") :]
    elif self._parameterized:
        self.path = path[len("/param") :]
50
def _check_authn(self):
    """
    Verifies the client's HTTP Basic credentials when a test demands it.

    If the "expected_secret" test parameter is set, the request must carry an
    Authorization header of the form "Basic <base64>", whose decoded value is
    "<client_id>:<secret>" with both halves form-urlencoded. Without an
    expected secret this is a no-op.

    Raises:
        AssertionError: no Authorization header was sent.
        RuntimeError: a scheme other than Basic was used, or the credentials
            do not match.
    """
    secret = self._get_param("expected_secret", None)
    if secret is None:
        return

    assert "Authorization" in self.headers
    method, creds = self.headers["Authorization"].split()

    if method != "Basic":
        raise RuntimeError(f"client used {method} auth; expected Basic")

    # TODO: Remove "~" from the safe list after Py3.6 support is removed.
    # 3.7 does this by default.
    username = urllib.parse.quote_plus(self.client_id, safe="~")
    password = urllib.parse.quote_plus(secret, safe="~")
    expected_creds = f"{username}:{password}"

    # Compare in the encoded domain so that malformed Base64 from the client
    # shows up as a credential mismatch rather than a decoding error.
    expected_b64 = base64.b64encode(expected_creds.encode("utf-8")).decode("ascii")
    if creds != expected_b64:
        raise RuntimeError(
            f"client sent '{creds}'; expected b64encode('{expected_creds}')"
        )
75
def do_GET(self):
    """
    GET entry point: serves the discovery document of the addressed issuer.

    The /alternate issuer publishes its metadata at
    /.well-known/oauth-authorization-server; every other issuer uses
    /.well-known/openid-configuration. Any other path gets a 404.
    """
    self._response_code = 200
    self._check_issuer()

    if self._alt_issuer:
        discovery_path = "/.well-known/oauth-authorization-server"
    else:
        discovery_path = "/.well-known/openid-configuration"

    if self.path != discovery_path:
        self.send_error(404, "Not Found")
        return

    self._send_json(self.config())
91
def _parse_params(self) -> Dict[str, list]:
    """
    Reads the request body and decodes it as application/x-www-form-urlencoded.
    For use by do_POST().

    Blank values are kept (e.g. "scope=" yields {"scope": [""]}), and strict
    parsing is enabled, so malformed input raises ValueError.

    Returns:
        A dict mapping each field name to the list of values sent for it.
    """
    size = int(self.headers["Content-Length"])
    form = self.rfile.read(size)

    assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
    return urllib.parse.parse_qs(
        form.decode("utf-8"),
        strict_parsing=True,
        keep_blank_values=True,
        encoding="utf-8",
        errors="strict",
    )
108
@property
def client_id(self) -> str:
    """
    The client_id of the current request.

    Taken from the form body when present; otherwise recovered from the
    username half of an HTTP Basic Authorization header. self._params must
    already be populated (see _parse_params()).

    Raises:
        RuntimeError: the client sent neither.
    """
    params = self._params
    if "client_id" not in params:
        if "Authorization" not in self.headers:
            raise RuntimeError("client did not send any client_id")

        _scheme, encoded = self.headers["Authorization"].split()
        user_pass = base64.b64decode(encoded).decode("utf-8")
        user, _password = user_pass.split(":", 1)
        return urllib.parse.unquote_plus(user)

    return params["client_id"][0]
127
def do_POST(self):
    """
    POST entry point: dispatches /authorize and /token requests.

    For the /param issuer, the client_id is Base64-encoded JSON carrying the
    test parameters for this request. It is decoded into self._test_params
    before authentication is checked, so that _get_param() can see it.
    Unknown paths get a 404.
    """
    self._response_code = 200
    self._check_issuer()

    self._params = self._parse_params()
    if self._parameterized:
        # Pull encoded test parameters out of the peer's client_id field.
        # This is expected to be Base64-encoded JSON.
        js = base64.b64decode(self.client_id)
        self._test_params = json.loads(js)

    self._check_authn()

    if self.path == "/authorize":
        resp = self.authorization()
    elif self.path == "/token":
        resp = self.token()
    else:
        self.send_error(404)
        return

    self._send_json(resp)
150
def _should_modify(self) -> bool:
    """
    Reports whether the client's test parameters target the current stage of
    the exchange.

    False when self._test_params has not been set. Otherwise the "stage"
    parameter decides: "all" matches every request, while "discovery",
    "device" and "token" each match exactly one endpoint path.
    """
    if not hasattr(self, "_test_params"):
        return False

    stage = self._test_params.get("stage")
    if stage == "all":
        return True

    stage_paths = (
        ("discovery", "/.well-known/openid-configuration"),
        ("device", "/authorize"),
        ("token", "/token"),
    )
    return any(stage == name and self.path == path for name, path in stage_paths)
170
def _get_param(self, name, default):
    """
    Looks up a test parameter, falling back to the given default.

    The client's test parameters are consulted only when _should_modify()
    says they apply to the current stage. Otherwise, or when name is absent
    from them, default is returned unchanged.
    """
    if not self._should_modify():
        return default

    return self._test_params.get(name, default)
182
@property
def _content_type(self) -> str:
    """
    Content-Type header value for JSON responses: "application/json", unless
    the "content_type" test parameter overrides it.
    """
    standard = "application/json"
    return self._get_param("content_type", standard)
190
@property
def _interval(self) -> int:
    """
    Polling interval reported by the device authorization endpoint: 0, unless
    the "interval" test parameter overrides it.
    """
    no_wait = 0
    return self._get_param("interval", no_wait)
197
@property
def _retry_code(self) -> str:
    """
    OAuth error code sent while the client must keep polling:
    "authorization_pending", unless the "retry_code" test parameter
    overrides it.
    """
    pending = "authorization_pending"
    return self._get_param("retry_code", pending)
205
@property
def _uri_spelling(self) -> str:
    """
    JSON key under which the device authorization response carries the
    verification URI: "verification_uri", unless the "uri_spelling" test
    parameter overrides it.
    """
    field_name = "verification_uri"
    return self._get_param("uri_spelling", field_name)
213
@property
def _response_padding(self):
    """
    Returns a dict with any additional entries that should be folded into a
    JSON response, as determined by test parameters provided by the client:

    - huge_response: if set to True, the dict will contain a gigantic string
      value (1024 * 1024 copies of "x")

    - nested_array: if set to nonzero, the dict will contain a deeply nested
      array so that the top-level object has the given depth

    - nested_object: if set to nonzero, the dict will contain a deeply
      nested JSON object so that the top-level object has the given depth

    With none of these parameters set, the dict is empty.
    """
    ret = {}

    if self._get_param("huge_response", False):
        ret["_pad_"] = "x" * 1024 * 1024

    # reduce() is deliberately called without an initializer: the first item
    # of range(depth), 0, seeds the innermost value, and each remaining item
    # adds one level of nesting. Counting the enclosing top-level object,
    # that gives exactly `depth` levels (depth=1 -> {"_arr_": 0},
    # depth=3 -> {"_arr_": [[0]]}).
    depth = self._get_param("nested_array", 0)
    if depth:
        ret["_arr_"] = functools.reduce(lambda x, _: [x], range(depth))

    depth = self._get_param("nested_object", 0)
    if depth:
        ret["_obj_"] = functools.reduce(lambda x, _: {"": x}, range(depth))

    return ret
243
@property
def _access_token(self):
    """
    Bearer token sent back to the client when the token exchange succeeds.

    A "token" test parameter replaces it outright. Otherwise a fixed value is
    used, with "-alt" appended for the /alternate issuer.
    """
    override = self._get_param("token", None)
    if override is not None:
        return override

    return "9243959234" + ("-alt" if self._alt_issuer else "")
259
def _log_response(self, js: "JsonObject") -> None:
    """
    Logs the outgoing JSON for later debugging, shortening oversized members.

    At the moment the only large member is "_pad_" (a megabyte in size). It
    is cut to its first 64 characters plus a note of its original length.
    The caller's dict is left untouched.

    Raises:
        AssertionError: the logged form is still 1024 bytes or longer, which
            means some new member needs truncating here as well.
    """
    loggable = js
    if "_pad_" in js:
        full_pad = js["_pad_"]
        # Shallow-copy so the caller's dict is not modified.
        loggable = dict(js)
        loggable["_pad_"] = full_pad[:64] + f"[...truncated from {len(full_pad)} bytes]"

    encoded = json.dumps(loggable).encode("ascii")
    self.log_message("sending JSON response: %s", encoded)

    # Tripping this means a new addition must be truncated like _pad_ above,
    # or a different logging strategy is needed.
    assert len(encoded) < 1024, "_log_response must be adjusted for new JSON"
279
def _send_json(self, js: "JsonObject") -> None:
    """
    Sends the provided JSON dict as the HTTP response.

    The status is self._response_code (set it beforehand to send a JSON error
    response), and the Content-Type is self._content_type, which tests may
    override. The body is ASCII-only JSON and is logged via _log_response().
    """
    resp = json.dumps(js).encode("ascii")
    self._log_response(js)

    self.send_response(self._response_code)
    self.send_header("Content-Type", self._content_type)
    self.send_header("Content-Length", str(len(resp)))
    self.end_headers()

    self.wfile.write(resp)
294
def config(self) -> "JsonObject":
    """
    Builds the discovery document for the issuer addressed by this request.

    The issuer is http://127.0.0.1:<listening port>, followed by "/alternate"
    or "/param" for those issuers. The token and device authorization
    endpoints are paths beneath it.
    """
    port = self.server.socket.getsockname()[1]

    if self._alt_issuer:
        issuer_suffix = "/alternate"
    elif self._parameterized:
        issuer_suffix = "/param"
    else:
        issuer_suffix = ""
    issuer = f"http://127.0.0.1:{port}{issuer_suffix}"

    token_endpoint = f"{issuer}/token"
    device_endpoint = f"{issuer}/authorize"

    return {
        "issuer": issuer,
        "token_endpoint": token_endpoint,
        "device_authorization_endpoint": device_endpoint,
        "response_types_supported": ["token"],
        "subject_types_supported": ["public"],
        "id_token_signing_alg_values_supported": ["RS256"],
        "grant_types_supported": [
            "authorization_code",
            "urn:ietf:params:oauth:grant-type:device_code",
        ],
    }
316
@property
def _token_state(self):
    """
    Token-exchange bookkeeping for the connected client, keyed by
    self.client_id.

    Looked up in self.server.token_state, which main() creates as a
    defaultdict, so a fresh _TokenState appears on first access.
    """
    states = self.server.token_state
    return states[self.client_id]
327
def _remove_token_state(self):
    """
    Removes any cached _TokenState for the current client_id. Call this
    after the token exchange ends to get rid of unnecessary state.

    Does nothing if no state was ever created for this client.
    """
    # pop() with a default performs a single lookup (client_id is recomputed
    # from the request on every read) and tolerates a missing entry.
    self.server.token_state.pop(self.client_id, None)
335
def authorization(self) -> "JsonObject":
    """
    Answers a device authorization request (the /authorize endpoint).

    Returns a fixed device_code/user_code pair and a verification URI, under
    the key chosen by _uri_spelling. If _interval is not None, it is
    advertised and also stored as this client's minimum delay between token
    polls (checked by token()). Otherwise the minimum delay falls back to 5.
    """
    verification_uri = (
        "https://example.org/" if self._alt_issuer else "https://example.com/"
    )

    resp = {
        "device_code": "postgres",
        "user_code": "postgresuser",
        self._uri_spelling: verification_uri,
        "expires_in": 5,
    }

    interval = self._interval
    state = self._token_state
    if interval is None:
        state.min_delay = 5  # default
    else:
        resp["interval"] = interval
        state.min_delay = interval

    # Check the scope: clients must omit an empty scope rather than send one.
    if "scope" in self._params:
        assert self._params["scope"][0], "empty scopes should be omitted"

    return resp
361
def token(self) -> "JsonObject":
    """
    Answers a token request (the /token endpoint).

    Test parameters steer the outcome:

    - error_code (with optional error_status and error_desc) returns that
      OAuth error, with HTTP status 400 unless error_status says otherwise.
    - retries makes the first N polls from a client fail with _retry_code
      (HTTP 400). Each poll after the first must arrive strictly more than
      the client's stored min_delay after the previous one.

    Otherwise, or once the retries are used up, the client's retry state is
    discarded and the access token is returned.
    """
    err = self._get_param("error_code", None)
    if err:
        self._response_code = self._get_param("error_status", 400)

        resp = {"error": err}

        desc = self._get_param("error_desc", "")
        if desc:
            resp["error_description"] = desc

        return resp

    if self._should_modify() and "retries" in self._test_params:
        retries = self._test_params["retries"]

        # Check to make sure the token interval is being respected.
        now = time.monotonic()
        # Each _token_state read recomputes client_id; fetch the state once.
        state = self._token_state
        if state.last_try is not None:
            delay = now - state.last_try
            assert delay > state.min_delay, (
                f"client waited only {delay} seconds between token requests "
                f"(expected {state.min_delay})"
            )

        state.last_try = now

        # If we haven't reached the required number of retries yet, return a
        # "pending" response.
        if state.retries < retries:
            state.retries += 1

            self._response_code = 400
            return {"error": self._retry_code}

    # Clean up any retry tracking state now that the exchange is ending.
    self.server.token_state.pop(self.client_id, None)

    return {
        "access_token": self._access_token,
        "token_type": "bearer",
    }
404
405
def main():
    """
    Starts the authorization server on localhost. The ephemeral port in use
    will be printed to stdout, after which stdout is closed and the server
    handles requests until the process is terminated.
    """

    s = http.server.HTTPServer(("127.0.0.1", 0), OAuthHandler)

    # Attach a "cache" dictionary to the server to allow the OAuthHandlers to
    # track state across token requests. The use of defaultdict ensures that new
    # entries will be created automatically.
    class _TokenState:
        retries = 0  # "pending" responses already sent to this client
        min_delay = None  # required gap between polls; set by authorization()
        last_try = None  # time.monotonic() value of the previous token poll

    s.token_state = defaultdict(_TokenState)

    # Give the parent the port number to contact (this is also the signal that
    # we're ready to receive requests).
    port = s.socket.getsockname()[1]
    print(port)

    # stdout is closed to allow the parent to just "read to the end".
    # When stdout is not interactive (presumably a pipe to the test harness),
    # print() output is block-buffered. Closing the Python stream first
    # flushes the port number out, and keeps the interpreter from writing to
    # a dead descriptor at exit. os.close() then releases the descriptor.
    stdout = sys.stdout.fileno()
    sys.stdout.close()
    os.close(stdout)

    s.serve_forever()  # we expect our parent to send a termination signal
435
436
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    main()
void print(const void *obj)
Definition print.c:36
Dict[str, str] _parse_params(self)
JsonObject authorization(self)
None _send_json(self, JsonObject js)
None _log_response(self, JsonObject js)
_get_param(self, name, default)
JsonObject config(self)
const char * str
#define token
#define write(a, b, c)
Definition win32.h:14
#define read(a, b, c)
Definition win32.h:13
const void size_t len
static int fb(int x)
static struct cvec * range(struct vars *v, chr a, chr b, int cases)