112 const char *shadow_pass);
114 char **
output,
int *outputlen,
115 const char **logdetail);
185 int *iterations,
int *key_length,
char **salt,
215 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
216 if (
port->ssl_in_use)
260 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
262 state->channel_binding_in_use =
true;
266 state->channel_binding_in_use =
false;
269 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
270 errmsg(
"client selected an invalid SASL authentication mechanism")));
294 (
errmsg(
"invalid SCRAM secret for user \"%s\"",
295 state->port->user_name)));
305 state->logdetail =
psprintf(
_(
"User \"%s\" does not have a valid SCRAM secret."),
306 state->port->user_name);
332 state->doomed =
true;
356 char **
output,
int *outputlen,
const char **logdetail)
384 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
385 errmsg(
"malformed SCRAM message"),
387 if (inputlen != strlen(
input))
389 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
390 errmsg(
"malformed SCRAM message"),
391 errdetail(
"Message length does not match input length.")));
393 switch (
state->state)
422 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
423 errmsg(
"invalid SCRAM response"),
460 elog(
ERROR,
"invalid SCRAM exchange state");
465 *logdetail =
state->logdetail;
468 *outputlen = strlen(*
output);
485 const char *errstr = NULL;
494 password = (
const char *) prep_password;
499 (
errcode(ERRCODE_INTERNAL_ERROR),
500 errmsg(
"could not generate random salt")));
508 pfree(prep_password);
534 const char *errstr = NULL;
537 &encoded_salt, stored_key, server_key))
549 saltlen =
pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
565 salt, saltlen, iterations,
566 salted_password, &errstr) < 0 ||
568 computed_key, &errstr) < 0)
570 elog(
ERROR,
"could not compute server key: %s", errstr);
574 pfree(prep_password);
580 return memcmp(computed_key, server_key, key_length) == 0;
598 char **salt,
uint8 *stored_key,
uint8 *server_key)
604 char *iterations_str;
608 char *decoded_salt_buf;
609 char *decoded_stored_buf;
610 char *decoded_server_buf;
618 if ((scheme_str = strtok(v,
"$")) == NULL)
620 if ((iterations_str = strtok(NULL,
":")) == NULL)
622 if ((salt_str = strtok(NULL,
"$")) == NULL)
624 if ((storedkey_str = strtok(NULL,
":")) == NULL)
626 if ((serverkey_str = strtok(NULL,
"")) == NULL)
630 if (strcmp(scheme_str,
"SCRAM-SHA-256") != 0)
636 *iterations = strtol(iterations_str, &p, 10);
637 if (*p || errno != 0)
645 decoded_salt_buf =
palloc(decoded_len);
647 decoded_salt_buf, decoded_len);
656 decoded_stored_buf =
palloc(decoded_len);
657 decoded_len =
pg_b64_decode(storedkey_str, strlen(storedkey_str),
658 decoded_stored_buf, decoded_len);
659 if (decoded_len != *key_length)
661 memcpy(stored_key, decoded_stored_buf, *key_length);
664 decoded_server_buf =
palloc(decoded_len);
665 decoded_len =
pg_b64_decode(serverkey_str, strlen(serverkey_str),
666 decoded_server_buf, decoded_len);
667 if (decoded_len != *key_length)
669 memcpy(server_key, decoded_server_buf, *key_length);
691 int *iterations,
int *key_length,
char **salt,
711 if (raw_salt == NULL)
716 encoded_salt = (
char *)
palloc(encoded_len + 1);
722 encoded_salt[encoded_len] =
'\0';
724 *salt = encoded_salt;
738 char *begin = *
input;
743 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
744 errmsg(
"malformed SCRAM message"),
745 errdetail(
"Expected attribute \"%c\" but found \"%s\".",
751 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
752 errmsg(
"malformed SCRAM message"),
753 errdetail(
"Expected character \"=\" for attribute \"%c\".", attr)));
757 while (*end && *end !=
',')
785 if (*p < 0x21 || *p > 0x7E || *p == 0x2C )
804 if (
c >= 0x21 &&
c <= 0x7E)
822 static char buf[30 + 1];
825 for (
i = 0;
i <
sizeof(
buf) - 1;
i++)
832 if (
c >= 0x21 &&
c <= 0x7E)
850 char *begin = *
input;
856 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
857 errmsg(
"malformed SCRAM message"),
858 errdetail(
"Attribute expected, but found end of string.")));
866 if (!((attr >=
'A' && attr <=
'Z') ||
867 (attr >=
'a' && attr <=
'z')))
869 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
870 errmsg(
"malformed SCRAM message"),
871 errdetail(
"Attribute expected, but found invalid character \"%s\".",
879 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
880 errmsg(
"malformed SCRAM message"),
881 errdetail(
"Expected character \"=\" for attribute \"%c\".", attr)));
885 while (*end && *end !=
',')
909 char *channel_binding_type;
976 state->cbind_flag = *p;
985 if (
state->channel_binding_in_use)
987 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
988 errmsg(
"malformed SCRAM message"),
989 errdetail(
"The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
994 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
995 errmsg(
"malformed SCRAM message"),
996 errdetail(
"Comma expected, but found character \"%s\".",
1007 if (
state->channel_binding_in_use)
1009 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1010 errmsg(
"malformed SCRAM message"),
1011 errdetail(
"The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
1013 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
1014 if (
state->port->ssl_in_use)
1016 (
errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
1017 errmsg(
"SCRAM channel binding negotiation error"),
1018 errdetail(
"The client supports SCRAM channel binding but thinks the server does not. "
1019 "However, this server does support channel binding.")));
1024 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1025 errmsg(
"malformed SCRAM message"),
1026 errdetail(
"Comma expected, but found character \"%s\".",
1036 if (!
state->channel_binding_in_use)
1038 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1039 errmsg(
"malformed SCRAM message"),
1040 errdetail(
"The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));
1048 if (strcmp(channel_binding_type,
"tls-server-end-point") != 0)
1050 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1051 errmsg(
"unsupported SCRAM channel-binding type \"%s\"",
1056 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1057 errmsg(
"malformed SCRAM message"),
1058 errdetail(
"Unexpected channel-binding flag \"%s\".",
1067 (
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1068 errmsg(
"client uses authorization identity, but it is not supported")));
1071 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1072 errmsg(
"malformed SCRAM message"),
1073 errdetail(
"Unexpected attribute \"%s\" in client-first-message.",
1088 (
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1089 errmsg(
"client requires an unsupported SCRAM extension")));
1102 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1103 errmsg(
"non-printable characters in SCRAM nonce")));
1122 int client_nonce_len = strlen(
state->client_nonce);
1123 int server_nonce_len = strlen(
state->server_nonce);
1124 int final_nonce_len = strlen(
state->client_final_nonce);
1126 if (final_nonce_len != client_nonce_len + server_nonce_len)
1128 if (memcmp(
state->client_final_nonce,
state->client_nonce, client_nonce_len) != 0)
1130 if (memcmp(
state->client_final_nonce + client_nonce_len,
state->server_nonce, server_nonce_len) != 0)
1149 const char *errstr = NULL;
1159 strlen(
state->client_first_message_bare)) < 0 ||
1163 strlen(
state->server_first_message)) < 0 ||
1166 (
uint8 *)
state->client_final_message_without_proof,
1167 strlen(
state->client_final_message_without_proof)) < 0 ||
1170 elog(
ERROR,
"could not calculate client signature: %s",
1177 for (
i = 0;
i <
state->key_length;
i++)
1178 ClientKey[
i] =
state->ClientProof[
i] ^ ClientSignature[
i];
1182 client_StoredKey, &errstr) < 0)
1183 elog(
ERROR,
"could not hash stored key: %s", errstr);
1185 if (memcmp(client_StoredKey,
state->StoredKey,
state->key_length) != 0)
1233 (
errcode(ERRCODE_INTERNAL_ERROR),
1234 errmsg(
"could not generate random nonce")));
1240 state->server_nonce, encoded_len);
1241 if (encoded_len < 0)
1243 (
errcode(ERRCODE_INTERNAL_ERROR),
1244 errmsg(
"could not encode random nonce")));
1245 state->server_nonce[encoded_len] =
'\0';
1247 state->server_first_message =
1263 char *channel_binding;
1269 int client_proof_len;
1307 if (
state->channel_binding_in_use)
1309 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
1310 const char *cbind_data = NULL;
1311 size_t cbind_data_len = 0;
1312 size_t cbind_header_len;
1314 size_t cbind_input_len;
1316 int b64_message_len;
1321 cbind_data = be_tls_get_certificate_hash(
state->port,
1325 if (cbind_data == NULL || cbind_data_len == 0)
1326 elog(
ERROR,
"could not get server certificate hash");
1328 cbind_header_len = strlen(
"p=tls-server-end-point,,");
1329 cbind_input_len = cbind_header_len + cbind_data_len;
1330 cbind_input =
palloc(cbind_input_len);
1331 snprintf(cbind_input, cbind_input_len,
"p=tls-server-end-point,,");
1332 memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
1336 b64_message =
palloc(b64_message_len + 1);
1337 b64_message_len =
pg_b64_encode(cbind_input, cbind_input_len,
1338 b64_message, b64_message_len);
1339 if (b64_message_len < 0)
1340 elog(
ERROR,
"could not encode channel binding data");
1341 b64_message[b64_message_len] =
'\0';
1347 if (strcmp(channel_binding, b64_message) != 0)
1349 (
errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
1350 errmsg(
"SCRAM channel binding check failed")));
1353 elog(
ERROR,
"channel binding not supported by this build");
1364 if (!(strcmp(channel_binding,
"biws") == 0 &&
state->cbind_flag ==
'n') &&
1365 !(strcmp(channel_binding,
"eSws") == 0 &&
state->cbind_flag ==
'y'))
1367 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1368 errmsg(
"unexpected SCRAM channel-binding attribute in client-final-message")));
1378 }
while (attr !=
'p');
1381 client_proof =
palloc(client_proof_len);
1383 client_proof_len) !=
state->key_length)
1385 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1386 errmsg(
"malformed SCRAM message"),
1387 errdetail(
"Malformed proof in client-final-message.")));
1388 memcpy(
state->ClientProof, client_proof,
state->key_length);
1389 pfree(client_proof);
1393 (
errcode(ERRCODE_PROTOCOL_VIOLATION),
1394 errmsg(
"malformed SCRAM message"),
1395 errdetail(
"Garbage found at the end of client-final-message.")));
1397 state->client_final_message_without_proof =
palloc(proof - begin + 1);
1398 memcpy(
state->client_final_message_without_proof,
input, proof - begin);
1399 state->client_final_message_without_proof[proof - begin] =
'\0';
1409 char *server_signature_base64;
1417 strlen(
state->client_first_message_bare)) < 0 ||
1421 strlen(
state->server_first_message)) < 0 ||
1424 (
uint8 *)
state->client_final_message_without_proof,
1425 strlen(
state->client_final_message_without_proof)) < 0 ||
1428 elog(
ERROR,
"could not calculate server signature: %s",
1436 server_signature_base64 =
palloc(siglen + 1);
1438 state->key_length, server_signature_base64,
1441 elog(
ERROR,
"could not encode server signature");
1442 server_signature_base64[siglen] =
'\0';
1455 return psprintf(
"v=%s", server_signature_base64);
1479 "salt length greater than SHA256 digest length");
1498 return (
char *) sha_digest;
static void * scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
static char * build_server_first_message(scram_state *state)
const pg_be_sasl_mech pg_be_scram_mech
static void read_client_first_message(scram_state *state, const char *input)
static char * read_attr_value(char **input, char attr)
bool parse_scram_secret(const char *secret, int *iterations, pg_cryptohash_type *hash_type, int *key_length, char **salt, uint8 *stored_key, uint8 *server_key)
static char * read_any_attr(char **input, char *attr_p)
static bool verify_client_proof(scram_state *state)
static bool verify_final_nonce(scram_state *state)
static char * sanitize_str(const char *s)
static char * scram_mock_salt(const char *username, pg_cryptohash_type hash_type, int key_length)
static int scram_exchange(void *opaq, const char *input, int inputlen, char **output, int *outputlen, const char **logdetail)
static bool is_scram_printable(char *p)
static char * sanitize_char(char c)
char * pg_be_scram_build_secret(const char *password)
bool scram_verify_plain_password(const char *username, const char *password, const char *secret)
static void read_client_final_message(scram_state *state, const char *input)
static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type, int *iterations, int *key_length, char **salt, uint8 *stored_key, uint8 *server_key)
static char * build_server_final_message(scram_state *state)
static void scram_get_mechanisms(Port *port, StringInfo buf)
int scram_sha_256_iterations
int pg_b64_decode(const char *src, int len, char *dst, int dstlen)
int pg_b64_enc_len(int srclen)
int pg_b64_encode(const char *src, int len, char *dst, int dstlen)
int pg_b64_dec_len(int srclen)
#define StaticAssertDecl(condition, errmessage)
PasswordType get_password_type(const char *shadow_pass)
@ PASSWORD_TYPE_SCRAM_SHA_256
int pg_cryptohash_update(pg_cryptohash_ctx *ctx, const uint8 *data, size_t len)
int pg_cryptohash_init(pg_cryptohash_ctx *ctx)
void pg_cryptohash_free(pg_cryptohash_ctx *ctx)
pg_cryptohash_ctx * pg_cryptohash_create(pg_cryptohash_type type)
int pg_cryptohash_final(pg_cryptohash_ctx *ctx, uint8 *dest, size_t len)
elog(ERROR, "%s: %s", p2, msg)
int errdetail(const char *fmt,...)
int errcode(int sqlerrcode)
int errmsg(const char *fmt,...)
#define ereport(elevel,...)
pg_hmac_ctx * pg_hmac_create(pg_cryptohash_type type)
const char * pg_hmac_error(pg_hmac_ctx *ctx)
void pg_hmac_free(pg_hmac_ctx *ctx)
int pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
int pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
int pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
if(TABLE==NULL||TABLE_index==NULL)
Assert(fmt[strlen(fmt) - 1] !='\n')
char * pstrdup(const char *in)
void pfree(void *pointer)
void * palloc0(Size size)
#define MOCK_AUTH_NONCE_LEN
bool pg_strong_random(void *buf, size_t len)
char * psprintf(const char *fmt,...)
#define PG_SASL_EXCHANGE_FAILURE
#define PG_SASL_EXCHANGE_CONTINUE
#define PG_SASL_EXCHANGE_SUCCESS
pg_saslprep_rc pg_saslprep(const char *input, char **output)
int scram_ServerKey(const uint8 *salted_password, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)
int scram_SaltedPassword(const char *password, pg_cryptohash_type hash_type, int key_length, const char *salt, int saltlen, int iterations, uint8 *result, const char **errstr)
char * scram_build_secret(pg_cryptohash_type hash_type, int key_length, const char *salt, int saltlen, int iterations, const char *password, const char **errstr)
int scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)
#define SCRAM_SHA_256_PLUS_NAME
#define SCRAM_SHA_256_NAME
#define SCRAM_RAW_NONCE_LEN
#define SCRAM_DEFAULT_SALT_LEN
#define SCRAM_MAX_KEY_LEN
#define SCRAM_SHA_256_KEY_LEN
#define SCRAM_SHA_256_DEFAULT_ITERATIONS
#define PG_SHA256_DIGEST_LENGTH
void appendStringInfoString(StringInfo str, const char *s)
void appendStringInfoChar(StringInfo str, char ch)
char * client_final_nonce
char * client_first_message_bare
char * client_final_message_without_proof
char * server_first_message
bool channel_binding_in_use
pg_cryptohash_type hash_type
char * GetMockAuthenticationNonce(void)