PostgreSQL Source Code git master
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Pages
oauth_server.py
Go to the documentation of this file.
1#! /usr/bin/env python3
2#
3# A mock OAuth authorization server, designed to be invoked from
4# OAuth/Server.pm. This listens on an ephemeral port number (printed to stdout
5# so that the Perl tests can contact it) and runs as a daemon until it is
6# signaled.
7#
8
9import base64
10import http.server
11import json
12import os
13import sys
14import time
15import urllib.parse
16from collections import defaultdict
17
18
class OAuthHandler(http.server.BaseHTTPRequestHandler):
    """
    Core implementation of the authorization server. The API is
    inheritance-based, with entry points at do_GET() and do_POST(). See the
    documentation for BaseHTTPRequestHandler.

    Three issuers are served from the same port (see _check_issuer()):
      - the default issuer at the root, using OpenID-style discovery;
      - "/alternate", which uses IETF-style (RFC 8414) .well-known URIs;
      - "/param", whose clients pass Base64-encoded JSON test parameters
        through their client_id (see do_POST()).
    """

    JsonObject = dict[str, object]  # TypeAlias is not available until 3.10

    def _check_issuer(self):
        """
        Switches the behavior of the provider depending on the issuer URI.

        Sets self._alt_issuer and self._parameterized, and strips the issuer
        component out of self.path so that the rest of the handler can
        dispatch on issuer-relative paths ("/token", "/authorize", ...).
        """
        self._alt_issuer = (
            self.path.startswith("/alternate/")
            or self.path == "/.well-known/oauth-authorization-server/alternate"
        )
        self._parameterized = self.path.startswith("/param/")

        if self._alt_issuer:
            # The /alternate issuer uses IETF-style .well-known URIs, where the
            # issuer path comes *after* the well-known suffix.
            if self.path.startswith("/.well-known/"):
                self.path = self.path.removesuffix("/alternate")
            else:
                self.path = self.path.removeprefix("/alternate")
        elif self._parameterized:
            self.path = self.path.removeprefix("/param")

    def _check_authn(self):
        """
        Checks the expected value of the Authorization header, if any.

        This is only enforced when the test supplied an "expected_secret"
        parameter for the current stage. The client must then use HTTP Basic
        auth, with the form-urlencoded client_id and secret as the username
        and password.

        Raises RuntimeError if the client used another auth scheme or sent
        the wrong credentials. Raises AssertionError if the header is missing.
        """
        secret = self._get_param("expected_secret", None)
        if secret is None:
            return

        assert "Authorization" in self.headers
        method, creds = self.headers["Authorization"].split()

        if method != "Basic":
            raise RuntimeError(f"client used {method} auth; expected Basic")

        # Both halves are form-urlencoded before being Base64-encoded together.
        username = urllib.parse.quote_plus(self.client_id)
        password = urllib.parse.quote_plus(secret)
        expected_creds = f"{username}:{password}"

        if creds.encode() != base64.b64encode(expected_creds.encode()):
            raise RuntimeError(
                f"client sent '{creds}'; expected b64encode('{expected_creds}')"
            )

    def do_GET(self):
        """
        Serves the discovery document for the selected issuer.

        Any other path gets a 404.
        """
        self._response_code = 200
        self._check_issuer()

        config_path = "/.well-known/openid-configuration"
        if self._alt_issuer:
            config_path = "/.well-known/oauth-authorization-server"

        if self.path == config_path:
            resp = self.config()
        else:
            self.send_error(404, "Not Found")
            return

        self._send_json(resp)

    def _parse_params(self) -> dict[str, list[str]]:
        """
        Parses apart the form-urlencoded request body and returns the resulting
        dict. For use by do_POST().

        As with urllib.parse.parse_qs(), each field name maps to the list of
        all values sent for it.
        """
        size = int(self.headers["Content-Length"])
        form = self.rfile.read(size)

        assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
        return urllib.parse.parse_qs(
            form.decode("utf-8"),
            strict_parsing=True,
            keep_blank_values=True,
            encoding="utf-8",
            errors="strict",
        )

    @property
    def client_id(self) -> str:
        """
        Returns the client_id sent in the POST body or the Authorization header.
        self._parse_params() must have been called first.

        The body takes precedence. Otherwise the username half of the Basic
        credentials is used, after undoing its form-urlencoding. Raises
        RuntimeError if neither is present.
        """
        if "client_id" in self._params:
            return self._params["client_id"][0]

        if "Authorization" not in self.headers:
            raise RuntimeError("client did not send any client_id")

        _, creds = self.headers["Authorization"].split()

        decoded = base64.b64decode(creds).decode("utf-8")
        username, _ = decoded.split(":", 1)

        return urllib.parse.unquote_plus(username)

    def do_POST(self):
        """
        Handles the device authorization (/authorize) and token (/token)
        endpoints.

        Any other path gets a 404.
        """
        self._response_code = 200
        self._check_issuer()

        self._params = self._parse_params()
        if self._parameterized:
            # Pull encoded test parameters out of the peer's client_id field.
            # This is expected to be Base64-encoded JSON.
            js = base64.b64decode(self.client_id)
            self._test_params = json.loads(js)

        self._check_authn()

        if self.path == "/authorize":
            resp = self.authorization()
        elif self.path == "/token":
            resp = self.token()
        else:
            self.send_error(404)
            return

        self._send_json(resp)

    def _should_modify(self) -> bool:
        """
        Returns True if the client has requested a modification to this stage of
        the exchange.

        NOTE(review): _test_params is only assigned in do_POST(), so the
        "discovery" stage (a GET request) never matches here as written.
        """
        if not hasattr(self, "_test_params"):
            return False

        stage = self._test_params.get("stage")

        return (
            stage == "all"
            or (
                stage == "discovery"
                and self.path == "/.well-known/openid-configuration"
            )
            or (stage == "device" and self.path == "/authorize")
            or (stage == "token" and self.path == "/token")
        )

    def _get_param(self, name, default):
        """
        If the client has requested a modification to this stage (see
        _should_modify()), this method searches the provided test parameters for
        a key of the given name, and returns it if found. Otherwise the provided
        default is returned.
        """
        if self._should_modify() and name in self._test_params:
            return self._test_params[name]

        return default

    @property
    def _content_type(self) -> str:
        """
        Returns "application/json" unless the test has requested something
        different.
        """
        return self._get_param("content_type", "application/json")

    @property
    def _interval(self) -> int:
        """
        Returns 0 unless the test has requested something different.

        A test may request None, which omits "interval" from the device
        authorization response (see authorization()).
        """
        return self._get_param("interval", 0)

    @property
    def _retry_code(self) -> str:
        """
        Returns "authorization_pending" unless the test has requested something
        different.
        """
        return self._get_param("retry_code", "authorization_pending")

    @property
    def _uri_spelling(self) -> str:
        """
        Returns "verification_uri" unless the test has requested something
        different.
        """
        return self._get_param("uri_spelling", "verification_uri")

    @property
    def _response_padding(self) -> JsonObject:
        """
        If the huge_response test parameter is set to True, returns a dict
        containing a gigantic string value, which can then be folded into a JSON
        response.
        """
        if not self._get_param("huge_response", False):
            return dict()

        return {"_pad_": "x" * 1024 * 1024}

    @property
    def _access_token(self):
        """
        The actual Bearer token sent back to the client on success. Tests may
        override this with the "token" test parameter.
        """
        token = self._get_param("token", None)
        if token is not None:
            return token

        token = "9243959234"
        if self._alt_issuer:
            token += "-alt"

        return token

    def _send_json(self, js: JsonObject) -> None:
        """
        Sends the provided JSON dict as an application/json response.
        self._response_code can be modified to send JSON error responses.

        The Content-Type header may be overridden by the test (see
        _content_type).
        """
        resp = json.dumps(js).encode("ascii")
        self.log_message("sending JSON response: %s", resp)

        self.send_response(self._response_code)
        self.send_header("Content-Type", self._content_type)
        self.send_header("Content-Length", str(len(resp)))
        self.end_headers()

        self.wfile.write(resp)

    def config(self) -> JsonObject:
        """
        Builds the discovery document for the current issuer.

        The advertised endpoints live under the issuer URL, which embeds this
        server's ephemeral port.
        """
        port = self.server.socket.getsockname()[1]

        issuer = f"http://localhost:{port}"
        if self._alt_issuer:
            issuer += "/alternate"
        elif self._parameterized:
            issuer += "/param"

        return {
            "issuer": issuer,
            "token_endpoint": issuer + "/token",
            "device_authorization_endpoint": issuer + "/authorize",
            "response_types_supported": ["token"],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256"],
            "grant_types_supported": [
                "authorization_code",
                "urn:ietf:params:oauth:grant-type:device_code",
            ],
        }

    @property
    def _token_state(self):
        """
        A cached _TokenState object for the connected client (as determined by
        the request's client_id), or a new one if it doesn't already exist.

        This relies on the existence of a defaultdict attached to the server;
        see main() below.
        """
        return self.server.token_state[self.client_id]

    def _remove_token_state(self) -> None:
        """
        Removes any cached _TokenState for the current client_id. Call this
        after the token exchange ends to get rid of unnecessary state.
        """
        # pop() with a default does not trigger defaultdict's factory, so
        # nothing is created when there is no entry to remove.
        self.server.token_state.pop(self.client_id, None)

    def authorization(self) -> JsonObject:
        """
        Handles the device authorization endpoint (/authorize).

        Returns the device authorization response. It also records, in this
        client's _TokenState, the minimum polling delay that token() later
        enforces.
        """
        uri = "https://example.com/"
        if self._alt_issuer:
            uri = "https://example.org/"

        resp = {
            "device_code": "postgres",
            "user_code": "postgresuser",
            self._uri_spelling: uri,
            "expires_in": 5,
            **self._response_padding,
        }

        interval = self._interval
        if interval is not None:
            resp["interval"] = interval
            self._token_state.min_delay = interval
        else:
            self._token_state.min_delay = 5  # default

        # Check the scope.
        if "scope" in self._params:
            assert self._params["scope"][0], "empty scopes should be omitted"

        return resp

    def token(self) -> JsonObject:
        """
        Handles the token endpoint (/token).

        Depending on the test parameters, this returns one of three things:
          - a configured error response;
          - a retry ("pending") response, until the requested number of
            retries is reached;
          - the access token.

        While retries are pending, it also asserts that the client waits
        longer than the interval recorded by authorization() between polls.
        """
        if err := self._get_param("error_code", None):
            self._response_code = self._get_param("error_status", 400)

            resp = {"error": err}
            if desc := self._get_param("error_desc", ""):
                resp["error_description"] = desc

            return resp

        if self._should_modify() and "retries" in self._test_params:
            retries = self._test_params["retries"]

            # Check to make sure the token interval is being respected.
            now = time.monotonic()
            if self._token_state.last_try is not None:
                delay = now - self._token_state.last_try
                assert (
                    delay > self._token_state.min_delay
                ), f"client waited only {delay} seconds between token requests (expected {self._token_state.min_delay})"

            self._token_state.last_try = now

            # If we haven't reached the required number of retries yet, return a
            # "pending" response.
            if self._token_state.retries < retries:
                self._token_state.retries += 1

                self._response_code = 400
                return {"error": self._retry_code}

        # Clean up any retry tracking state now that the exchange is ending.
        self._remove_token_state()

        return {
            "access_token": self._access_token,
            "token_type": "bearer",
            **self._response_padding,
        }
357
358
def main():
    """
    Serves the mock authorization server on an ephemeral localhost port until
    the process is terminated by a signal.

    The chosen port number is written to stdout. stdout is then closed so the
    parent (OAuth/Server.pm) can simply read it to EOF.
    """

    server = http.server.HTTPServer(("127.0.0.1", 0), OAuthHandler)

    # Bookkeeping shared by all request handlers so they can track state
    # across token requests (see OAuthHandler._token_state). Using a
    # defaultdict means a fresh record is created for any key not seen yet.
    class _TokenState:
        retries = 0
        min_delay = None
        last_try = None

    server.token_state = defaultdict(_TokenState)

    # Reporting the bound port doubles as the "ready for requests" signal.
    print(server.server_address[1])

    # Close stdout at both the Python and file-descriptor level so the parent
    # sees EOF. Closing sys.stdout flushes the port number out first.
    fd = sys.stdout.fileno()
    sys.stdout.close()
    os.close(fd)

    # Runs until the parent sends a termination signal.
    server.serve_forever()


if __name__ == "__main__":
    main()
391 main()
void print(const void *obj)
Definition: print.c:36
JsonObject authorization(self)
None _send_json(self, JsonObject js)
JsonObject token(self)
def _get_param(self, name, default)
JsonObject config(self)
dict[str, str] _parse_params(self)
Definition: oauth_server.py:86
const char * str
#define write(a, b, c)
Definition: win32.h:14
#define read(a, b, c)
Definition: win32.h:13
const void size_t len