Skip to content
162 changes: 110 additions & 52 deletions packages/pg/lib/sasl.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,27 @@ function continueSession(session, password, serverData) {
if (session.message !== 'SASLInitialResponse') {
throw new Error('SASL: Last message was not SASLInitialResponse')
}
if (typeof password !== 'string') {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: client password must be a string')
}
if (typeof serverData !== 'string') {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: serverData must be a string')
}

const sv = extractVariablesFromFirstServerMessage(serverData)
const sv = parseServerFirstMessage(serverData)

if (!sv.nonce.startsWith(session.clientNonce)) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: server nonce does not start with client nonce')
} else if (sv.nonce.length === session.clientNonce.length) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: server nonce is too short')
}

var saltBytes = Buffer.from(sv.salt, 'base64')

var saltedPassword = Hi(password, saltBytes, sv.iteration)

var clientKey = createHMAC(saltedPassword, 'Client Key')
var storedKey = crypto.createHash('sha256').update(clientKey).digest()
var clientKey = hmacSha256(saltedPassword, 'Client Key')
var storedKey = sha256(clientKey)

var clientFirstMessageBare = 'n=*,r=' + session.clientNonce
var serverFirstMessage = 'r=' + sv.nonce + ',s=' + sv.salt + ',i=' + sv.iteration
Expand All @@ -41,12 +49,12 @@ function continueSession(session, password, serverData) {

var authMessage = clientFirstMessageBare + ',' + serverFirstMessage + ',' + clientFinalMessageWithoutProof

var clientSignature = createHMAC(storedKey, authMessage)
var clientSignature = hmacSha256(storedKey, authMessage)
var clientProofBytes = xorBuffers(clientKey, clientSignature)
var clientProof = clientProofBytes.toString('base64')

var serverKey = createHMAC(saltedPassword, 'Server Key')
var serverSignatureBytes = createHMAC(serverKey, authMessage)
var serverKey = hmacSha256(saltedPassword, 'Server Key')
var serverSignatureBytes = hmacSha256(serverKey, authMessage)

session.message = 'SASLResponse'
session.serverSignature = serverSignatureBytes.toString('base64')
Expand All @@ -57,54 +65,87 @@ function finalizeSession(session, serverData) {
if (session.message !== 'SASLResponse') {
throw new Error('SASL: Last message was not SASLResponse')
}
if (typeof serverData !== 'string') {
throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: serverData must be a string')
}

var serverSignature

String(serverData)
.split(',')
.forEach(function (part) {
switch (part[0]) {
case 'v':
serverSignature = part.substr(2)
break
}
})
const { serverSignature } = parseServerFinalMessage(serverData)

if (serverSignature !== session.serverSignature) {
throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature does not match')
}
}

function extractVariablesFromFirstServerMessage(data) {
var nonce, salt, iteration

String(data)
.split(',')
.forEach(function (part) {
switch (part[0]) {
case 'r':
nonce = part.substr(2)
break
case 's':
salt = part.substr(2)
break
case 'i':
iteration = parseInt(part.substr(2), 10)
break
/**
* printable = %x21-2B / %x2D-7E
* ;; Printable ASCII except ",".
* ;; Note that any "printable" is also
* ;; a valid "value".
*/
function isPrintableChars(text) {
if (typeof text !== 'string') {
throw new TypeError('SASL: text must be a string')
}
return text
.split('')
.map((_, i) => text.charCodeAt(i))
.every((c) => (c >= 0x21 && c <= 0x2b) || (c >= 0x2d && c <= 0x7e))
}

/**
* base64-char = ALPHA / DIGIT / "/" / "+"
*
* base64-4 = 4base64-char
*
* base64-3 = 3base64-char "="
*
* base64-2 = 2base64-char "=="
*
* base64 = *base64-4 [base64-3 / base64-2]
*/
function isBase64(text) {
return /^(?:[a-zA-Z0-9+/]{4})*(?:[a-zA-Z0-9+/]{2}==|[a-zA-Z0-9+/]{3}=)?$/.test(text)
}

function parseAttributePairs(text) {
if (typeof text !== 'string') {
throw new TypeError('SASL: attribute pairs text must be a string')
}

return new Map(
text.split(',').map((attrValue) => {
if (!/^.=/.test(attrValue)) {
throw new Error('SASL: Invalid attribute pair entry')
}
const name = attrValue[0]
const value = attrValue.substring(2)
return [name, value]
})
)
}

function parseServerFirstMessage(data) {
const attrPairs = parseAttributePairs(data)

const nonce = attrPairs.get('r')
if (!nonce) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: nonce missing')
} else if (!isPrintableChars(nonce)) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: nonce must only contain printable characters')
}

const salt = attrPairs.get('s')
if (!salt) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: salt missing')
} else if (!isBase64(salt)) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: salt must be base64')
}

if (!iteration) {
const iterationText = attrPairs.get('i')
if (!iterationText) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: iteration missing')
} else if (!/^[1-9][0-9]*$/.test(iterationText)) {
throw new Error('SASL: SCRAM-SERVER-FIRST-MESSAGE: invalid iteration count')
}
const iteration = parseInt(iterationText, 10)

return {
nonce,
Expand All @@ -113,31 +154,48 @@ function extractVariablesFromFirstServerMessage(data) {
}
}

function parseServerFinalMessage(serverData) {
const attrPairs = parseAttributePairs(serverData)
const serverSignature = attrPairs.get('v')
if (!serverSignature) {
throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature is missing')
} else if (!isBase64(serverSignature)) {
throw new Error('SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature must be base64')
}
return {
serverSignature,
}
}

function xorBuffers(a, b) {
if (!Buffer.isBuffer(a)) a = Buffer.from(a)
if (!Buffer.isBuffer(b)) b = Buffer.from(b)
var res = []
if (a.length > b.length) {
for (var i = 0; i < b.length; i++) {
res.push(a[i] ^ b[i])
}
} else {
for (var j = 0; j < a.length; j++) {
res.push(a[j] ^ b[j])
}
}
return Buffer.from(res)
if (!Buffer.isBuffer(a)) {
throw new TypeError('first argument must be a Buffer')
}
if (!Buffer.isBuffer(b)) {
throw new TypeError('second argument must be a Buffer')
}
if (a.length !== b.length) {
throw new Error('Buffer lengths must match')
}
if (a.length === 0) {
throw new Error('Buffers cannot be empty')
}
return Buffer.from(a.map((_, i) => a[i] ^ b[i]))
}

function sha256(text) {
return crypto.createHash('sha256').update(text).digest()
}

function createHMAC(key, msg) {
function hmacSha256(key, msg) {
return crypto.createHmac('sha256', key).update(msg).digest()
}

function Hi(password, saltBytes, iterations) {
var ui1 = createHMAC(password, Buffer.concat([saltBytes, Buffer.from([0, 0, 0, 1])]))
var ui1 = hmacSha256(password, Buffer.concat([saltBytes, Buffer.from([0, 0, 0, 1])]))
var ui = ui1
for (var i = 0; i < iterations - 1; i++) {
ui1 = createHMAC(password, ui1)
ui1 = hmacSha256(password, ui1)
ui = xorBuffers(ui, ui1)
}

Expand Down
10 changes: 0 additions & 10 deletions packages/pg/test/test-helper.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,6 @@ assert.success = function (callback) {
}
}

assert.throws = function (offender) {
try {
offender()
} catch (e) {
assert.ok(e instanceof Error, 'Expected ' + offender + ' to throw instances of Error')
return
}
assert.ok(false, 'Expected ' + offender + ' to throw exception')
}

Comment on lines -114 to -123
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good call. Monkey-patching assert was such a bad idea I had such a long time ago. The use of globals throughout some of the tests is also gross....need to fix up one day soonish. 🤦

assert.lengthIs = function (actual, expectedLength) {
assert.equal(actual.length, expectedLength)
}
Expand Down
41 changes: 31 additions & 10 deletions packages/pg/test/unit/client/sasl-scram-tests.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ test('sasl/scram', function () {
test('fails when last session message was not SASLInitialResponse', function () {
assert.throws(
function () {
sasl.continueSession({})
sasl.continueSession({}, '', '')
},
{
message: 'SASL: Last message was not SASLInitialResponse',
Expand All @@ -53,6 +53,7 @@ test('sasl/scram', function () {
{
message: 'SASLInitialResponse',
},
'bad-password',
's=1,i=1'
)
},
Expand All @@ -69,6 +70,7 @@ test('sasl/scram', function () {
{
message: 'SASLInitialResponse',
},
'bad-password',
'r=1,i=1'
)
},
Expand All @@ -85,7 +87,8 @@ test('sasl/scram', function () {
{
message: 'SASLInitialResponse',
},
'r=1,s=1'
'bad-password',
'r=1,s=abcd'
)
},
{
Expand All @@ -102,7 +105,8 @@ test('sasl/scram', function () {
message: 'SASLInitialResponse',
clientNonce: '2',
},
'r=1,s=1,i=1'
'bad-password',
'r=1,s=abcd,i=1'
)
},
{
Expand All @@ -117,12 +121,12 @@ test('sasl/scram', function () {
clientNonce: 'a',
}

sasl.continueSession(session, 'password', 'r=ab,s=x,i=1')
sasl.continueSession(session, 'password', 'r=ab,s=abcd,i=1')

assert.equal(session.message, 'SASLResponse')
assert.equal(session.serverSignature, 'TtywIrpWDJ0tCSXM2mjkyiaa8iGZsZG7HllQxr8fYAo=')
assert.equal(session.serverSignature, 'jwt97IHWFn7FEqHykPTxsoQrKGOMXJl/PJyJ1JXTBKc=')

assert.equal(session.response, 'c=biws,r=ab,p=KAEPBUTjjofB0IM5UWcZApK1dSzFE0o5vnbWjBbvFHA=')
assert.equal(session.response, 'c=biws,r=ab,p=mU8grLfTjDrJer9ITsdHk0igMRDejG10EJPFbIBL3D0=')
})
})

Expand All @@ -138,15 +142,32 @@ test('sasl/scram', function () {
)
})

test('fails when server signature is not valid base64', function () {
assert.throws(
function () {
sasl.finalizeSession(
{
message: 'SASLResponse',
serverSignature: 'abcd',
},
'v=x1' // Purposefully invalid base64
)
},
{
message: 'SASL: SCRAM-SERVER-FINAL-MESSAGE: server signature must be base64',
}
)
})

test('fails when server signature does not match', function () {
assert.throws(
function () {
sasl.finalizeSession(
{
message: 'SASLResponse',
serverSignature: '3',
serverSignature: 'abcd',
},
'v=4'
'v=xyzq'
)
},
{
Expand All @@ -159,9 +180,9 @@ test('sasl/scram', function () {
sasl.finalizeSession(
{
message: 'SASLResponse',
serverSignature: '5',
serverSignature: 'abcd',
},
'v=5'
'v=abcd'
)
})
})
Expand Down