PostgreSQL Source Code git master
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Pages
oauth_server.py
Go to the documentation of this file.
#! /usr/bin/env python3
#
# A mock OAuth authorization server, designed to be invoked from
# OAuth/Server.pm. This listens on an ephemeral port number (printed to stdout
# so that the Perl tests can contact it) and runs as a daemon until it is
# signaled.
#
8
import base64
import http.server
import json
import os
import sys
import time
import urllib.parse
from collections import defaultdict
from typing import Dict, List
18
19
class OAuthHandler(http.server.BaseHTTPRequestHandler):
    """
    Core implementation of the authorization server. The API is
    inheritance-based, with entry points at do_GET() and do_POST(). See the
    documentation for BaseHTTPRequestHandler.

    Per-request state kept on the handler instance:

    - _alt_issuer / _parameterized: which issuer flavor the request path
      selected (set by _check_issuer()).
    - _params: the parsed form body of a POST (set by do_POST()).
    - _test_params: optional behavior overrides decoded from the client_id
      of a "/param" request (set by do_POST()).
    - _response_code: the HTTP status that _send_json() will emit.
    """

    JsonObject = Dict[str, object]  # TypeAlias is not available until 3.10

    def _check_issuer(self):
        """
        Switches the behavior of the provider depending on the issuer URI.

        Sets self._alt_issuer and self._parameterized, then rewrites self.path
        so that the rest of the handler sees the path without the magic
        "/alternate" or "/param" segment.
        """
        self._alt_issuer = (
            self.path.startswith("/alternate/")
            or self.path == "/.well-known/oauth-authorization-server/alternate"
        )
        self._parameterized = self.path.startswith("/param/")

        # Strip off the magic path segment. (The more readable
        # str.removeprefix()/removesuffix() aren't available until Py3.9.)
        if self._alt_issuer:
            # The /alternate issuer uses IETF-style .well-known URIs.
            if self.path.startswith("/.well-known/"):
                self.path = self.path[: -len("/alternate")]
            else:
                self.path = self.path[len("/alternate") :]
        elif self._parameterized:
            self.path = self.path[len("/param") :]

    def _check_authn(self):
        """
        Checks the expected value of the Authorization header, if any.

        Nothing is checked unless the test supplied an "expected_secret"
        parameter. Otherwise the client must have sent HTTP Basic credentials
        equal to the form-urlencoded client_id and secret joined by a colon;
        a RuntimeError is raised on a wrong scheme or mismatched credentials.
        """
        secret = self._get_param("expected_secret", None)
        if secret is None:
            return

        assert "Authorization" in self.headers
        method, creds = self.headers["Authorization"].split()

        if method != "Basic":
            raise RuntimeError(f"client used {method} auth; expected Basic")

        # TODO: Remove "~" from the safe list after Py3.6 support is removed.
        # 3.7 does this by default.
        username = urllib.parse.quote_plus(self.client_id, safe="~")
        password = urllib.parse.quote_plus(secret, safe="~")
        expected_creds = f"{username}:{password}"

        if creds.encode() != base64.b64encode(expected_creds.encode()):
            raise RuntimeError(
                f"client sent '{creds}'; expected b64encode('{expected_creds}')"
            )

    def do_GET(self):
        """
        Serves the discovery document. Only the well-known configuration path
        for the selected issuer is handled; any other path gets a 404.
        """
        self._response_code = 200
        self._check_issuer()

        config_path = "/.well-known/openid-configuration"
        if self._alt_issuer:
            config_path = "/.well-known/oauth-authorization-server"

        if self.path == config_path:
            resp = self.config()
        else:
            self.send_error(404, "Not Found")
            return

        self._send_json(resp)

    def _parse_params(self) -> Dict[str, List[str]]:
        """
        Parses apart the form-urlencoded request body and returns the resulting
        dict. For use by do_POST().

        Each key maps to the list of values sent for it (blank values are
        kept), which is why callers index with [0].
        """
        size = int(self.headers["Content-Length"])
        form = self.rfile.read(size)

        assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
        return urllib.parse.parse_qs(
            form.decode("utf-8"),
            strict_parsing=True,
            keep_blank_values=True,
            encoding="utf-8",
            errors="strict",
        )

    @property
    def client_id(self) -> str:
        """
        Returns the client_id sent in the POST body or the Authorization header.
        self._parse_params() must have been called first.

        Raises RuntimeError if the client sent neither.
        """
        if "client_id" in self._params:
            return self._params["client_id"][0]

        if "Authorization" not in self.headers:
            raise RuntimeError("client did not send any client_id")

        _, creds = self.headers["Authorization"].split()

        # The header carries base64("username:password"); the username half is
        # the form-urlencoded client_id.
        decoded = base64.b64decode(creds).decode("utf-8")
        username, _ = decoded.split(":", 1)

        return urllib.parse.unquote_plus(username)

    def do_POST(self):
        """
        Handles the device authorization ("/authorize") and token ("/token")
        endpoints; any other path gets a 404.
        """
        self._response_code = 200
        self._check_issuer()

        self._params = self._parse_params()
        if self._parameterized:
            # Pull encoded test parameters out of the peer's client_id field.
            # This is expected to be Base64-encoded JSON.
            js = base64.b64decode(self.client_id)
            self._test_params = json.loads(js)

        self._check_authn()

        if self.path == "/authorize":
            resp = self.authorization()
        elif self.path == "/token":
            resp = self.token()
        else:
            self.send_error(404)
            return

        self._send_json(resp)

    def _should_modify(self) -> bool:
        """
        Returns True if the client has requested a modification to this stage of
        the exchange.
        """
        # _test_params only exists for parameterized POST requests.
        if not hasattr(self, "_test_params"):
            return False

        stage = self._test_params.get("stage")

        return (
            stage == "all"
            or (
                stage == "discovery"
                and self.path == "/.well-known/openid-configuration"
            )
            or (stage == "device" and self.path == "/authorize")
            or (stage == "token" and self.path == "/token")
        )

    def _get_param(self, name, default):
        """
        If the client has requested a modification to this stage (see
        _should_modify()), this method searches the provided test parameters for
        a key of the given name, and returns it if found. Otherwise the provided
        default is returned.
        """
        if self._should_modify() and name in self._test_params:
            return self._test_params[name]

        return default

    @property
    def _content_type(self) -> str:
        """
        Returns "application/json" unless the test has requested something
        different.
        """
        return self._get_param("content_type", "application/json")

    @property
    def _interval(self) -> int:
        """
        Returns 0 unless the test has requested something different.
        """
        return self._get_param("interval", 0)

    @property
    def _retry_code(self) -> str:
        """
        Returns "authorization_pending" unless the test has requested something
        different.
        """
        return self._get_param("retry_code", "authorization_pending")

    @property
    def _uri_spelling(self) -> str:
        """
        Returns "verification_uri" unless the test has requested something
        different.
        """
        return self._get_param("uri_spelling", "verification_uri")

    @property
    def _response_padding(self):
        """
        If the huge_response test parameter is set to True, returns a dict
        containing a gigantic string value, which can then be folded into a JSON
        response.
        """
        if not self._get_param("huge_response", False):
            return {}

        return {"_pad_": "x" * 1024 * 1024}

    @property
    def _access_token(self):
        """
        The actual Bearer token sent back to the client on success. Tests may
        override this with the "token" test parameter.
        """
        token = self._get_param("token", None)
        if token is not None:
            return token

        token = "9243959234"
        if self._alt_issuer:
            token += "-alt"

        return token

    def _send_json(self, js: JsonObject) -> None:
        """
        Sends the provided JSON dict as an application/json response.
        self._response_code can be modified to send JSON error responses.

        The Content-Type header comes from self._content_type, so a test can
        override it.
        """
        resp = json.dumps(js).encode("ascii")
        self.log_message("sending JSON response: %s", resp)

        self.send_response(self._response_code)
        self.send_header("Content-Type", self._content_type)
        self.send_header("Content-Length", str(len(resp)))
        self.end_headers()

        self.wfile.write(resp)

    def config(self) -> JsonObject:
        """
        Builds the discovery document for the issuer selected by the request
        path, advertising this server's own token and device authorization
        endpoints.
        """
        port = self.server.socket.getsockname()[1]

        issuer = f"http://127.0.0.1:{port}"
        if self._alt_issuer:
            issuer += "/alternate"
        elif self._parameterized:
            issuer += "/param"

        return {
            "issuer": issuer,
            "token_endpoint": issuer + "/token",
            "device_authorization_endpoint": issuer + "/authorize",
            "response_types_supported": ["token"],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256"],
            "grant_types_supported": [
                "authorization_code",
                "urn:ietf:params:oauth:grant-type:device_code",
            ],
        }

    @property
    def _token_state(self):
        """
        A cached _TokenState object for the connected client (as determined by
        the request's client_id), or a new one if it doesn't already exist.

        This relies on the existence of a defaultdict attached to the server;
        see main() below.
        """
        return self.server.token_state[self.client_id]

    def _remove_token_state(self):
        """
        Removes any cached _TokenState for the current client_id. Call this
        after the token exchange ends to get rid of unnecessary state.
        """
        # pop() with a default is a no-op when nothing is cached, and (unlike
        # indexing) never creates a fresh entry in the defaultdict.
        self.server.token_state.pop(self.client_id, None)

    def authorization(self) -> JsonObject:
        """
        Builds the device authorization response and records the polling
        interval that token() will later hold the client to.
        """
        uri = "https://example.com/"
        if self._alt_issuer:
            uri = "https://example.org/"

        resp = {
            "device_code": "postgres",
            "user_code": "postgresuser",
            self._uri_spelling: uri,
            "expires_in": 5,
            **self._response_padding,
        }

        # A test may set the interval to null to have it omitted from the
        # response; the client is then held to a five-second delay.
        interval = self._interval
        if interval is not None:
            resp["interval"] = interval
            self._token_state.min_delay = interval
        else:
            self._token_state.min_delay = 5  # default

        # Check the scope.
        if "scope" in self._params:
            assert self._params["scope"][0], "empty scopes should be omitted"

        return resp

    def token(self) -> JsonObject:
        """
        Builds the token endpoint response: a test-requested error, a "pending"
        retry response while the requested number of retries has not been
        reached, or the final bearer token.
        """
        err = self._get_param("error_code", None)
        if err:
            self._response_code = self._get_param("error_status", 400)

            resp = {"error": err}

            desc = self._get_param("error_desc", "")
            if desc:
                resp["error_description"] = desc

            return resp

        if self._should_modify() and "retries" in self._test_params:
            retries = self._test_params["retries"]

            # Check to make sure the token interval is being respected.
            now = time.monotonic()
            if self._token_state.last_try is not None:
                delay = now - self._token_state.last_try
                assert (
                    delay > self._token_state.min_delay
                ), f"client waited only {delay} seconds between token requests (expected {self._token_state.min_delay})"

            self._token_state.last_try = now

            # If we haven't reached the required number of retries yet, return a
            # "pending" response.
            if self._token_state.retries < retries:
                self._token_state.retries += 1

                self._response_code = 400
                return {"error": self._retry_code}

        # Clean up any retry tracking state now that the exchange is ending.
        self._remove_token_state()

        return {
            "access_token": self._access_token,
            "token_type": "bearer",
            **self._response_padding,
        }
365
366
def main():
    """
    Starts the authorization server on localhost. The ephemeral port in use will
    be printed to stdout.

    This call does not return on its own: the server runs until the process
    is signaled.
    """

    # Port 0 lets the OS choose a free port.
    s = http.server.HTTPServer(("127.0.0.1", 0), OAuthHandler)

    # Attach a "cache" dictionary to the server to allow the OAuthHandlers to
    # track state across token requests. The use of defaultdict ensures that new
    # entries will be created automatically.
    class _TokenState:
        retries = 0  # number of "pending" responses sent so far
        min_delay = None  # polling interval the client is held to
        last_try = None  # time.monotonic() value of the previous token request

    s.token_state = defaultdict(_TokenState)

    # Give the parent the port number to contact (this is also the signal that
    # we're ready to receive requests).
    port = s.socket.getsockname()[1]
    print(port)

    # stdout is closed to allow the parent to just "read to the end".
    stdout = sys.stdout.fileno()
    sys.stdout.close()
    os.close(stdout)

    s.serve_forever()  # we expect our parent to send a termination signal


if __name__ == "__main__":
    main()
void print(const void *obj)
Definition: print.c:36
Dict[str, str] _parse_params(self)
Definition: oauth_server.py:91
JsonObject authorization(self)
None _send_json(self, JsonObject js)
JsonObject token(self)
def _get_param(self, name, default)
JsonObject config(self)
const char * str
#define write(a, b, c)
Definition: win32.h:14
#define read(a, b, c)
Definition: win32.h:13
const void size_t len