112 const char *shadow_pass);
113 static int scram_exchange(
void *opaq,
const char *input,
int inputlen,
114 char **
output,
int *outputlen,
115 const char **logdetail);
181 char **salt,
uint8 *stored_key,
uint8 *server_key);
203 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
204 if (
port->ssl_in_use)
248 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
250 state->channel_binding_in_use =
true;
254 state->channel_binding_in_use =
false;
257 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
258 errmsg(
"client selected an invalid SASL authentication mechanism")));
279 (
errmsg(
"invalid SCRAM secret for user \"%s\"",
280 state->port->user_name)));
290 state->logdetail =
psprintf(
_(
"User \"%s\" does not have a valid SCRAM secret."),
291 state->port->user_name);
315 state->doomed =
true;
339 char **
output,
int *outputlen,
const char **logdetail)
367 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
368 errmsg(
"malformed SCRAM message"),
370 if (inputlen != strlen(input))
372 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
373 errmsg(
"malformed SCRAM message"),
374 errdetail(
"Message length does not match input length.")));
376 switch (
state->state)
405 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
406 errmsg(
"invalid SCRAM response"),
443 elog(
ERROR,
"invalid SCRAM exchange state");
448 *logdetail =
state->logdetail;
451 *outputlen = strlen(*
output);
468 const char *errstr = NULL;
477 password = (
const char *) prep_password;
482 (
errcode(ERRCODE_INTERNAL_ERROR),
483 errmsg(
"could not generate random salt")));
490 pfree(prep_password);
514 const char *errstr = NULL;
517 stored_key, server_key))
529 saltlen =
pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
545 salted_password, &errstr) < 0 ||
548 elog(
ERROR,
"could not compute server key: %s", errstr);
552 pfree(prep_password);
581 char *iterations_str;
585 char *decoded_salt_buf;
586 char *decoded_stored_buf;
587 char *decoded_server_buf;
595 if ((scheme_str = strtok(v,
"$")) == NULL)
597 if ((iterations_str = strtok(NULL,
":")) == NULL)
599 if ((salt_str = strtok(NULL,
"$")) == NULL)
601 if ((storedkey_str = strtok(NULL,
":")) == NULL)
603 if ((serverkey_str = strtok(NULL,
"")) == NULL)
607 if (strcmp(scheme_str,
"SCRAM-SHA-256") != 0)
611 *iterations = strtol(iterations_str, &p, 10);
612 if (*p || errno != 0)
620 decoded_salt_buf =
palloc(decoded_len);
622 decoded_salt_buf, decoded_len);
631 decoded_stored_buf =
palloc(decoded_len);
632 decoded_len =
pg_b64_decode(storedkey_str, strlen(storedkey_str),
633 decoded_stored_buf, decoded_len);
639 decoded_server_buf =
palloc(decoded_len);
640 decoded_len =
pg_b64_decode(serverkey_str, strlen(serverkey_str),
641 decoded_server_buf, decoded_len);
681 if (raw_salt == NULL)
686 encoded_salt = (
char *)
palloc(encoded_len + 1);
692 encoded_salt[encoded_len] =
'\0';
694 *salt = encoded_salt;
708 char *begin = *input;
713 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
714 errmsg(
"malformed SCRAM message"),
715 errdetail(
"Expected attribute \"%c\" but found \"%s\".",
721 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
722 errmsg(
"malformed SCRAM message"),
723 errdetail(
"Expected character \"=\" for attribute \"%c\".", attr)));
727 while (*end && *end !=
',')
755 if (*p < 0x21 || *p > 0x7E || *p == 0x2C )
774 if (
c >= 0x21 &&
c <= 0x7E)
792 static char buf[30 + 1];
795 for (
i = 0;
i <
sizeof(
buf) - 1;
i++)
802 if (
c >= 0x21 &&
c <= 0x7E)
820 char *begin = *input;
826 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
827 errmsg(
"malformed SCRAM message"),
828 errdetail(
"Attribute expected, but found end of string.")));
836 if (!((attr >=
'A' && attr <=
'Z') ||
837 (attr >=
'a' && attr <=
'z')))
839 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
840 errmsg(
"malformed SCRAM message"),
841 errdetail(
"Attribute expected, but found invalid character \"%s\".",
849 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
850 errmsg(
"malformed SCRAM message"),
851 errdetail(
"Expected character \"=\" for attribute \"%c\".", attr)));
855 while (*end && *end !=
',')
879 char *channel_binding_type;
946 state->cbind_flag = *p;
955 if (
state->channel_binding_in_use)
957 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
958 errmsg(
"malformed SCRAM message"),
959 errdetail(
"The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
964 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
965 errmsg(
"malformed SCRAM message"),
966 errdetail(
"Comma expected, but found character \"%s\".",
977 if (
state->channel_binding_in_use)
979 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
980 errmsg(
"malformed SCRAM message"),
981 errdetail(
"The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
983 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
984 if (
state->port->ssl_in_use)
986 (
errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
987 errmsg(
"SCRAM channel binding negotiation error"),
988 errdetail(
"The client supports SCRAM channel binding but thinks the server does not. "
989 "However, this server does support channel binding.")));
994 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
995 errmsg(
"malformed SCRAM message"),
996 errdetail(
"Comma expected, but found character \"%s\".",
1006 if (!
state->channel_binding_in_use)
1008 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1009 errmsg(
"malformed SCRAM message"),
1010 errdetail(
"The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));
1018 if (strcmp(channel_binding_type,
"tls-server-end-point") != 0)
1020 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1021 errmsg(
"unsupported SCRAM channel-binding type \"%s\"",
1026 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1027 errmsg(
"malformed SCRAM message"),
1028 errdetail(
"Unexpected channel-binding flag \"%s\".",
1037 (
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1038 errmsg(
"client uses authorization identity, but it is not supported")));
1041 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1042 errmsg(
"malformed SCRAM message"),
1043 errdetail(
"Unexpected attribute \"%s\" in client-first-message.",
1058 (
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1059 errmsg(
"client requires an unsupported SCRAM extension")));
1072 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1073 errmsg(
"non-printable characters in SCRAM nonce")));
1092 int client_nonce_len = strlen(
state->client_nonce);
1093 int server_nonce_len = strlen(
state->server_nonce);
1094 int final_nonce_len = strlen(
state->client_final_nonce);
1096 if (final_nonce_len != client_nonce_len + server_nonce_len)
1098 if (memcmp(
state->client_final_nonce,
state->client_nonce, client_nonce_len) != 0)
1100 if (memcmp(
state->client_final_nonce + client_nonce_len,
state->server_nonce, server_nonce_len) != 0)
1119 const char *errstr = NULL;
1129 strlen(
state->client_first_message_bare)) < 0 ||
1133 strlen(
state->server_first_message)) < 0 ||
1136 (
uint8 *)
state->client_final_message_without_proof,
1137 strlen(
state->client_final_message_without_proof)) < 0 ||
1138 pg_hmac_final(ctx, ClientSignature,
sizeof(ClientSignature)) < 0)
1140 elog(
ERROR,
"could not calculate client signature: %s",
1148 ClientKey[
i] =
state->ClientProof[
i] ^ ClientSignature[
i];
1152 elog(
ERROR,
"could not hash stored key: %s", errstr);
1202 (
errcode(ERRCODE_INTERNAL_ERROR),
1203 errmsg(
"could not generate random nonce")));
1209 state->server_nonce, encoded_len);
1210 if (encoded_len < 0)
1212 (
errcode(ERRCODE_INTERNAL_ERROR),
1213 errmsg(
"could not encode random nonce")));
1214 state->server_nonce[encoded_len] =
'\0';
1216 state->server_first_message =
1232 char *channel_binding;
1238 int client_proof_len;
1276 if (
state->channel_binding_in_use)
1278 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
1279 const char *cbind_data = NULL;
1280 size_t cbind_data_len = 0;
1281 size_t cbind_header_len;
1283 size_t cbind_input_len;
1285 int b64_message_len;
1290 cbind_data = be_tls_get_certificate_hash(
state->port,
1294 if (cbind_data == NULL || cbind_data_len == 0)
1295 elog(
ERROR,
"could not get server certificate hash");
1297 cbind_header_len = strlen(
"p=tls-server-end-point,,");
1298 cbind_input_len = cbind_header_len + cbind_data_len;
1299 cbind_input =
palloc(cbind_input_len);
1300 snprintf(cbind_input, cbind_input_len,
"p=tls-server-end-point,,");
1301 memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
1305 b64_message =
palloc(b64_message_len + 1);
1306 b64_message_len =
pg_b64_encode(cbind_input, cbind_input_len,
1307 b64_message, b64_message_len);
1308 if (b64_message_len < 0)
1309 elog(
ERROR,
"could not encode channel binding data");
1310 b64_message[b64_message_len] =
'\0';
1316 if (strcmp(channel_binding, b64_message) != 0)
1318 (
errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
1319 errmsg(
"SCRAM channel binding check failed")));
1322 elog(
ERROR,
"channel binding not supported by this build");
1333 if (!(strcmp(channel_binding,
"biws") == 0 &&
state->cbind_flag ==
'n') &&
1334 !(strcmp(channel_binding,
"eSws") == 0 &&
state->cbind_flag ==
'y'))
1336 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1337 errmsg(
"unexpected SCRAM channel-binding attribute in client-final-message")));
1347 }
while (attr !=
'p');
1350 client_proof =
palloc(client_proof_len);
1354 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1355 errmsg(
"malformed SCRAM message"),
1356 errdetail(
"Malformed proof in client-final-message.")));
1358 pfree(client_proof);
1362 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1363 errmsg(
"malformed SCRAM message"),
1364 errdetail(
"Garbage found at the end of client-final-message.")));
1366 state->client_final_message_without_proof =
palloc(proof - begin + 1);
1367 memcpy(
state->client_final_message_without_proof, input, proof - begin);
1368 state->client_final_message_without_proof[proof - begin] =
'\0';
1378 char *server_signature_base64;
1386 strlen(
state->client_first_message_bare)) < 0 ||
1390 strlen(
state->server_first_message)) < 0 ||
1393 (
uint8 *)
state->client_final_message_without_proof,
1394 strlen(
state->client_final_message_without_proof)) < 0 ||
1395 pg_hmac_final(ctx, ServerSignature,
sizeof(ServerSignature)) < 0)
1397 elog(
ERROR,
"could not calculate server signature: %s",
1405 server_signature_base64 =
palloc(siglen + 1);
1410 elog(
ERROR,
"could not encode server signature");
1411 server_signature_base64[siglen] =
'\0';
1424 return psprintf(
"v=%s", server_signature_base64);
1447 "salt length greater than SHA256 digest length");
1460 return (
char *) sha_digest;
static void * scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
static char * build_server_first_message(scram_state *state)
const pg_be_sasl_mech pg_be_scram_mech
static void read_client_first_message(scram_state *state, const char *input)
bool parse_scram_secret(const char *secret, int *iterations, char **salt, uint8 *stored_key, uint8 *server_key)
static char * read_attr_value(char **input, char attr)
static char * read_any_attr(char **input, char *attr_p)
static bool verify_client_proof(scram_state *state)
static bool verify_final_nonce(scram_state *state)
static char * sanitize_str(const char *s)
static char * scram_mock_salt(const char *username)
static void mock_scram_secret(const char *username, int *iterations, char **salt, uint8 *stored_key, uint8 *server_key)
static int scram_exchange(void *opaq, const char *input, int inputlen, char **output, int *outputlen, const char **logdetail)
static bool is_scram_printable(char *p)
static char * sanitize_char(char c)
char * pg_be_scram_build_secret(const char *password)
bool scram_verify_plain_password(const char *username, const char *password, const char *secret)
static void read_client_final_message(scram_state *state, const char *input)
static char * build_server_final_message(scram_state *state)
static void scram_get_mechanisms(Port *port, StringInfo buf)
int pg_b64_decode(const char *src, int len, char *dst, int dstlen)
int pg_b64_enc_len(int srclen)
int pg_b64_encode(const char *src, int len, char *dst, int dstlen)
int pg_b64_dec_len(int srclen)
#define StaticAssertStmt(condition, errmessage)
PasswordType get_password_type(const char *shadow_pass)
@ PASSWORD_TYPE_SCRAM_SHA_256
int pg_cryptohash_update(pg_cryptohash_ctx *ctx, const uint8 *data, size_t len)
int pg_cryptohash_init(pg_cryptohash_ctx *ctx)
void pg_cryptohash_free(pg_cryptohash_ctx *ctx)
pg_cryptohash_ctx * pg_cryptohash_create(pg_cryptohash_type type)
int pg_cryptohash_final(pg_cryptohash_ctx *ctx, uint8 *dest, size_t len)
elog(ERROR, "%s: %s", p2, msg)
int errdetail(const char *fmt,...)
int errcode(int sqlerrcode)
int errmsg(const char *fmt,...)
#define ereport(elevel,...)
pg_hmac_ctx * pg_hmac_create(pg_cryptohash_type type)
const char * pg_hmac_error(pg_hmac_ctx *ctx)
void pg_hmac_free(pg_hmac_ctx *ctx)
int pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
int pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
int pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
if(TABLE==NULL||TABLE_index==NULL)
Assert(fmt[strlen(fmt) - 1] !='\n')
char * pstrdup(const char *in)
void pfree(void *pointer)
void * palloc0(Size size)
#define MOCK_AUTH_NONCE_LEN
static void output(uint64 loop_count)
bool pg_strong_random(void *buf, size_t len)
char * psprintf(const char *fmt,...)
#define PG_SASL_EXCHANGE_FAILURE
#define PG_SASL_EXCHANGE_CONTINUE
#define PG_SASL_EXCHANGE_SUCCESS
pg_saslprep_rc pg_saslprep(const char *input, char **output)
int scram_SaltedPassword(const char *password, const char *salt, int saltlen, int iterations, uint8 *result, const char **errstr)
int scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
int scram_ServerKey(const uint8 *salted_password, uint8 *result, const char **errstr)
char * scram_build_secret(const char *salt, int saltlen, int iterations, const char *password, const char **errstr)
#define SCRAM_DEFAULT_ITERATIONS
#define SCRAM_SHA_256_PLUS_NAME
#define SCRAM_SHA_256_NAME
#define SCRAM_RAW_NONCE_LEN
#define SCRAM_DEFAULT_SALT_LEN
#define PG_SHA256_DIGEST_LENGTH
void appendStringInfoString(StringInfo str, const char *s)
void appendStringInfoChar(StringInfo str, char ch)
char * client_final_nonce
char * client_first_message_bare
char * client_final_message_without_proof
char * server_first_message
bool channel_binding_in_use
char * GetMockAuthenticationNonce(void)