PostgreSQL Source Code git master
Loading...
Searching...
No Matches
oauth_server.py
Go to the documentation of this file.
#! /usr/bin/env python3
#
# A mock OAuth authorization server, designed to be invoked from
# OAuth/Server.pm. This listens on an ephemeral port number (printed to stdout
# so that the Perl tests can contact it) and runs as a daemon until it is
# signaled.
#
8
import base64
import functools
import http.server
import json
import os
import ssl
import sys
import time
import urllib.parse
from collections import defaultdict
from typing import Dict, List
20
# The server's TLS certificate and key live in the directory named by the
# "cert_dir" environment variable (presumably exported by the invoking Perl
# test harness -- confirm in OAuth/Server.pm). Fail loudly if it is missing,
# rather than with an opaque TypeError from concatenating None.
ssl_dir = os.getenv("cert_dir")
if ssl_dir is None:
    raise RuntimeError("the cert_dir environment variable must be set")
ssl_cert = ssl_dir + "/server-localhost-alt-names.crt"
ssl_key = ssl_dir + "/server-localhost-alt-names.key"
24
25
class OAuthHandler(http.server.BaseHTTPRequestHandler):
    """
    Core implementation of the authorization server. The API is
    inheritance-based, with entry points at do_GET() and do_POST(). See the
    documentation for BaseHTTPRequestHandler.

    Three issuers are served from the same port, selected by the request path
    (see _check_issuer()):

    - the default issuer, rooted at "/";
    - an "/alternate" issuer, which uses IETF-style .well-known URIs and hands
      out a different verification URI and token;
    - a "/param" issuer, whose clients pass Base64-encoded JSON test
      parameters in their client_id to modify the server's responses.
    """

    JsonObject = Dict[str, object]  # TypeAlias is not available until 3.10

    def _check_issuer(self):
        """
        Switches the behavior of the provider depending on the issuer URI.

        Sets self._alt_issuer and self._parameterized, then strips the
        issuer-specific path segment off self.path so that the rest of the
        request handling only sees the endpoint path (e.g. "/token").
        """
        self._alt_issuer = (
            self.path.startswith("/alternate/")
            or self.path == "/.well-known/oauth-authorization-server/alternate"
        )
        self._parameterized = self.path.startswith("/param/")

        # Strip off the magic path segment. (The more readable
        # str.removeprefix()/removesuffix() aren't available until Py3.9.)
        if self._alt_issuer:
            # The /alternate issuer uses IETF-style .well-known URIs.
            if self.path.startswith("/.well-known/"):
                self.path = self.path[: -len("/alternate")]
            else:
                self.path = self.path[len("/alternate") :]
        elif self._parameterized:
            self.path = self.path[len("/param") :]

    def _check_authn(self):
        """
        Checks the expected value of the Authorization header, if any.

        This is only enforced when the test parameters supply
        "expected_secret". The client must then use HTTP Basic auth, with the
        form-urlencoded client_id and secret as username and password.

        Raises RuntimeError if the client used another auth scheme or sent
        credentials that don't match.
        """
        secret = self._get_param("expected_secret", None)
        if secret is None:
            return

        assert "Authorization" in self.headers
        method, creds = self.headers["Authorization"].split()

        if method != "Basic":
            raise RuntimeError(f"client used {method} auth; expected Basic")

        # TODO: Remove "~" from the safe list after Py3.6 support is removed.
        # 3.7 does this by default.
        username = urllib.parse.quote_plus(self.client_id, safe="~")
        password = urllib.parse.quote_plus(secret, safe="~")
        expected_creds = f"{username}:{password}"

        if creds != base64.b64encode(expected_creds.encode()).decode():
            raise RuntimeError(
                f"client sent '{creds}'; expected b64encode('{expected_creds}')"
            )

    def do_GET(self):
        """
        Serves the discovery document for the issuer selected by the request
        path. Any other path gets a 404.
        """
        self._response_code = 200
        self._check_issuer()

        config_path = "/.well-known/openid-configuration"
        if self._alt_issuer:
            config_path = "/.well-known/oauth-authorization-server"

        if self.path == config_path:
            resp = self.config()
        else:
            self.send_error(404, "Not Found")
            return

        self._send_json(resp)

    def _parse_params(self) -> Dict[str, List[str]]:
        """
        Parses apart the form-urlencoded request body and returns the resulting
        dict, which maps each field name to the list of values sent for it.
        For use by do_POST().

        strict_parsing=True makes a malformed body raise ValueError instead of
        being silently ignored.
        """
        size = int(self.headers["Content-Length"])
        form = self.rfile.read(size)

        assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
        return urllib.parse.parse_qs(
            form.decode("utf-8"),
            strict_parsing=True,
            keep_blank_values=True,
            encoding="utf-8",
            errors="strict",
        )

    @property
    def client_id(self) -> str:
        """
        Returns the client_id sent in the POST body or the Authorization header.
        self._parse_params() must have been called first.

        Raises RuntimeError if the client sent neither.
        """
        if "client_id" in self._params:
            return self._params["client_id"][0]

        if "Authorization" not in self.headers:
            raise RuntimeError("client did not send any client_id")

        _, creds = self.headers["Authorization"].split()

        decoded = base64.b64decode(creds).decode("utf-8")
        username, _ = decoded.split(":", 1)

        return urllib.parse.unquote_plus(username)

    def do_POST(self):
        """
        Handles the device authorization ("/authorize") and token ("/token")
        endpoints. Any other path gets a 404.

        For the "/param" issuer, the client_id is decoded as Base64 JSON and
        stored in self._test_params, which steers the rest of the exchange
        (see _should_modify() and _get_param()).
        """
        self._response_code = 200
        self._check_issuer()

        self._params = self._parse_params()
        if self._parameterized:
            # Pull encoded test parameters out of the peer's client_id field.
            # This is expected to be Base64-encoded JSON.
            js = base64.b64decode(self.client_id)
            self._test_params = json.loads(js)

        self._check_authn()

        if self.path == "/authorize":
            resp = self.authorization()
        elif self.path == "/token":
            resp = self.token()
        else:
            self.send_error(404)
            return

        self._send_json(resp)

    def _should_modify(self) -> bool:
        """
        Returns True if the client has requested a modification to this stage of
        the exchange.
        """
        if not hasattr(self, "_test_params"):
            return False

        stage = self._test_params.get("stage")

        # NOTE(review): self._test_params is only assigned in do_POST(), and
        # the discovery document is served by do_GET(), so the "discovery"
        # branch below cannot currently match. Confirm whether it is still
        # needed.
        return (
            stage == "all"
            or (
                stage == "discovery"
                and self.path == "/.well-known/openid-configuration"
            )
            or (stage == "device" and self.path == "/authorize")
            or (stage == "token" and self.path == "/token")
        )

    def _get_param(self, name, default):
        """
        If the client has requested a modification to this stage (see
        _should_modify()), this method searches the provided test parameters for
        a key of the given name, and returns it if found. Otherwise the provided
        default is returned.
        """
        if self._should_modify() and name in self._test_params:
            return self._test_params[name]

        return default

    @property
    def _content_type(self) -> str:
        """
        Returns "application/json" unless the test has requested something
        different.
        """
        return self._get_param("content_type", "application/json")

    @property
    def _interval(self) -> int:
        """
        Returns 0 unless the test has requested something different.
        """
        return self._get_param("interval", 0)

    @property
    def _retry_code(self) -> str:
        """
        Returns "authorization_pending" unless the test has requested something
        different.
        """
        return self._get_param("retry_code", "authorization_pending")

    @property
    def _uri_spelling(self) -> str:
        """
        Returns "verification_uri" unless the test has requested something
        different.
        """
        return self._get_param("uri_spelling", "verification_uri")

    @property
    def _response_padding(self):
        """
        Returns a dict with any additional entries that should be folded into a
        JSON response, as determined by test parameters provided by the client:

        - huge_response: if set to True, the dict will contain a gigantic string
          value

        - nested_array: if set to nonzero, the dict will contain a deeply nested
          array so that the top-level object has the given depth

        - nested_object: if set to nonzero, the dict will contain a deeply
          nested JSON object so that the top-level object has the given depth
        """
        ret = dict()

        if self._get_param("huge_response", False):
            ret["_pad_"] = "x" * 1024 * 1024

        # reduce() without an initializer starts from range()'s first element
        # (0), which serves as the innermost leaf value.
        depth = self._get_param("nested_array", 0)
        if depth:
            ret["_arr_"] = functools.reduce(lambda x, _: [x], range(depth))

        depth = self._get_param("nested_object", 0)
        if depth:
            ret["_obj_"] = functools.reduce(lambda x, _: {"": x}, range(depth))

        return ret

    @property
    def _access_token(self):
        """
        The actual Bearer token sent back to the client on success. Tests may
        override this with the "token" test parameter.
        """
        token = self._get_param("token", None)
        if token is not None:
            return token

        token = "9243959234"
        if self._alt_issuer:
            token += "-alt"

        return token

    def _log_response(self, js: JsonObject) -> None:
        """
        Trims the response JSON, if necessary, and logs it for later debugging.
        """
        # At the moment the biggest problem for tests is the _pad_ member, which
        # is a megabyte in size, so truncate that to something more reasonable.
        if "_pad_" in js:
            pad = js["_pad_"]

            # Don't modify the original dict.
            js = dict(js)
            js["_pad_"] = pad[:64] + f"[...truncated from {len(pad)} bytes]"

        resp = json.dumps(js).encode("ascii")
        self.log_message("sending JSON response: %s", resp)

        # If you've tripped this assertion, please truncate the new addition as
        # above, or else come up with a new strategy.
        assert len(resp) < 1024, "_log_response must be adjusted for new JSON"

    def _send_json(self, js: JsonObject) -> None:
        """
        Sends the provided JSON dict as an application/json response (or with
        whatever Content-Type the test requested; see _content_type).
        self._response_code can be modified to send JSON error responses.
        """
        resp = json.dumps(js).encode("ascii")
        self._log_response(js)

        self.send_response(self._response_code)
        self.send_header("Content-Type", self._content_type)
        self.send_header("Content-Length", str(len(resp)))
        self.end_headers()

        self.wfile.write(resp)

    def config(self) -> JsonObject:
        """
        Returns the discovery document for the current issuer. All endpoints
        are built from the port this server's socket is bound to.
        """
        port = self.server.socket.getsockname()[1]

        # XXX This IPv4-only Issuer can't be changed to "localhost" unless our
        # server also listens on the corresponding IPv6 port when available.
        # Otherwise, other processes with ephemeral sockets could accidentally
        # interfere with our Curl client, causing intermittent failures.
        issuer = f"https://127.0.0.1:{port}"
        if self._alt_issuer:
            issuer += "/alternate"
        elif self._parameterized:
            issuer += "/param"

        return {
            "issuer": issuer,
            "token_endpoint": issuer + "/token",
            "device_authorization_endpoint": issuer + "/authorize",
            "response_types_supported": ["token"],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256"],
            "grant_types_supported": [
                "authorization_code",
                "urn:ietf:params:oauth:grant-type:device_code",
            ],
        }

    @property
    def _token_state(self):
        """
        A cached _TokenState object for the connected client (as determined by
        the request's client_id), or a new one if it doesn't already exist.

        This relies on the existence of a defaultdict attached to the server;
        see main() below.
        """
        return self.server.token_state[self.client_id]

    def _remove_token_state(self):
        """
        Removes any cached _TokenState for the current client_id. Call this
        after the token exchange ends to get rid of unnecessary state.
        """
        # pop() with a default neither raises for unknown clients nor triggers
        # the defaultdict factory.
        self.server.token_state.pop(self.client_id, None)

    def authorization(self) -> JsonObject:
        """
        Returns the device authorization response. It also records, in the
        client's token state, the minimum polling delay that token() will
        enforce.
        """
        uri = "https://example.com/"
        if self._alt_issuer:
            uri = "https://example.org/"

        resp = {
            "device_code": "postgres",
            "user_code": "postgresuser",
            self._uri_spelling: uri,
            "expires_in": 5,
            **self._response_padding,
        }

        interval = self._interval
        if interval is not None:
            resp["interval"] = interval
            self._token_state.min_delay = interval
        else:
            self._token_state.min_delay = 5  # default

        # Check the scope.
        if "scope" in self._params:
            assert self._params["scope"][0], "empty scopes should be omitted"

        return resp

    def token(self) -> JsonObject:
        """
        Returns the token endpoint response. This is one of:

        - the error the test requested ("error_code"/"error_status"/
          "error_desc");
        - a retry error, until the requested number of "retries" has been
          reached (asserting the client respected the polling interval);
        - otherwise, a Bearer token. In that case the client's retry state is
          discarded.
        """
        err = self._get_param("error_code", None)
        if err:
            self._response_code = self._get_param("error_status", 400)

            resp = {"error": err}

            desc = self._get_param("error_desc", "")
            if desc:
                resp["error_description"] = desc

            return resp

        if self._should_modify() and "retries" in self._test_params:
            retries = self._test_params["retries"]

            # Check to make sure the token interval is being respected.
            now = time.monotonic()
            if self._token_state.last_try is not None:
                delay = now - self._token_state.last_try
                assert (
                    delay > self._token_state.min_delay
                ), f"client waited only {delay} seconds between token requests (expected {self._token_state.min_delay})"

            self._token_state.last_try = now

            # If we haven't reached the required number of retries yet, return a
            # "pending" response.
            if self._token_state.retries < retries:
                self._token_state.retries += 1

                self._response_code = 400
                return {"error": self._retry_code}

        # Clean up any retry tracking state now that the exchange is ending.
        self._remove_token_state()

        return {
            "access_token": self._access_token,
            "token_type": "bearer",
            **self._response_padding,
        }
413
414
def main():
    """
    Starts the authorization server on localhost. The ephemeral port in use will
    be printed to stdout.

    Runs until the process is killed; it is expected that the parent sends a
    termination signal.
    """
    # XXX Listen exclusively on IPv4. Listening on a dual-stack socket would be
    # more true-to-life, but every OS/Python combination in the buildfarm and CI
    # would need to provide the functionality first.
    s = http.server.HTTPServer(("127.0.0.1", 0), OAuthHandler)

    # Speak HTTPS.
    # TODO: switch to HTTPSServer with Python 3.14
    ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    ssl_context.load_cert_chain(ssl_cert, ssl_key)

    s.socket = ssl_context.wrap_socket(s.socket, server_side=True)

    # Attach a "cache" dictionary to the server to allow the OAuthHandlers to
    # track state across token requests. The use of defaultdict ensures that new
    # entries will be created automatically.
    class _TokenState:
        retries = 0
        min_delay = None
        last_try = None

    s.token_state = defaultdict(_TokenState)

    # Give the parent the port number to contact (this is also the signal that
    # we're ready to receive requests).
    port = s.socket.getsockname()[1]
    print(port)

    # stdout is closed to allow the parent to just "read to the end". Close
    # the Python-level stream first so the buffered port number is flushed,
    # then close the OS-level descriptor itself.
    stdout = sys.stdout.fileno()
    sys.stdout.close()
    os.close(stdout)

    s.serve_forever()  # we expect our parent to send a termination signal
453
454
# Only start the server when run as a script (as OAuth/Server.pm does), not
# when imported.
if __name__ == "__main__":
    main()
void print(const void *obj)
Definition print.c:36
Dict[str, str] _parse_params(self)
JsonObject authorization(self)
None _send_json(self, JsonObject js)
None _log_response(self, JsonObject js)
_get_param(self, name, default)
JsonObject config(self)
const char * str
#define token
#define write(a, b, c)
Definition win32.h:14
#define read(a, b, c)
Definition win32.h:13
const void size_t len
static int fb(int x)
static struct cvec * range(struct vars *v, chr a, chr b, int cases)