// Copyright (c) 2009 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// See "SSPI Sample Application" at
// http://msdn.microsoft.com/en-us/library/aa918273.aspx
#include "net/http/http_auth_sspi_win.h"
#include "base/base64.h"
#include "base/logging.h"
#include "base/string_util.h"
#include "net/base/net_errors.h"
#include "net/base/net_util.h"
#include "net/http/http_auth.h"
namespace net {
// Creates an SSPI-backed handler for the HTTP auth-scheme |scheme| using the
// SSPI package named by |security_package|.
// NOTE(review): |security_package| appears to be stored as a raw pointer
// (it is later passed as a SEC_WCHAR*), so it presumably must outlive this
// object — confirm against the header.
HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme,
                           SEC_WCHAR* security_package)
    : scheme_(scheme),
      security_package_(security_package),
      max_token_length_(0) {
  // No context or credentials exist yet. GetNextSecurityToken() creates the
  // context, and OnFirstRound() acquires the credentials.
  SecInvalidateHandle(&ctxt_);
  SecInvalidateHandle(&cred_);
}
// Tears down any live security context first, then releases the credentials
// handle if one was acquired.
HttpAuthSSPI::~HttpAuthSSPI() {
  ResetSecurityContext();
  if (!SecIsValidHandle(&cred_))
    return;
  FreeCredentialsHandle(&cred_);
  SecInvalidateHandle(&cred_);
}
// An identity is needed only while no server token has been received yet,
// i.e. at the start of a handshake.
bool HttpAuthSSPI::NeedsIdentity() const {
  return decoded_server_auth_token_.size() == 0;
}
// True once ParseChallenge() has stored a non-empty token from the server.
bool HttpAuthSSPI::IsFinalRound() const {
  return decoded_server_auth_token_.size() != 0;
}
// Deletes the current security context, if any, and marks the handle invalid
// so a new authentication sequence can start.
void HttpAuthSSPI::ResetSecurityContext() {
  if (!SecIsValidHandle(&ctxt_))
    return;
  DeleteSecurityContext(&ctxt_);
  SecInvalidateHandle(&ctxt_);
}
// Parses a challenge of the form "<scheme> [<base64 auth-data>]".
//
// Returns false if the challenge is malformed, names an auth-scheme other
// than |scheme_|, or carries auth-data that is not valid base64. In those
// cases |decoded_server_auth_token_| is left unchanged. On success the
// decoded auth-data is stored in |decoded_server_auth_token_|. It is empty
// when the challenge carries no auth-data.
bool HttpAuthSSPI::ParseChallenge(std::string::const_iterator challenge_begin,
                                  std::string::const_iterator challenge_end) {
  // Verify the challenge's auth-scheme.
  HttpAuth::ChallengeTokenizer challenge_tok(challenge_begin, challenge_end);
  if (!challenge_tok.valid() ||
      !LowerCaseEqualsASCII(challenge_tok.scheme(),
                            StringToLowerASCII(scheme_).c_str()))
    return false;

  // Extract the auth-data. We can't use challenge_tok.GetNext() because
  // auth-data is base64-encoded and may contain '=' padding at the end,
  // which would be mistaken for a name=value pair.
  //
  // The scheme-name skip below is a fixed character count. It is therefore
  // only correct if |challenge_begin| points exactly at the scheme name, so
  // trim any leading LWS first. This is a no-op when there is none.
  HttpUtil::TrimLWS(&challenge_begin, &challenge_end);
  challenge_begin += scheme_.length();  // Skip over scheme name.
  HttpUtil::TrimLWS(&challenge_begin, &challenge_end);
  std::string encoded_auth_token(challenge_begin, challenge_end);

  // Strip off excess '=' padding until the length is a multiple of 4, which
  // our base64 decoder requires. Correctly padded input is left untouched.
  // (See https://bugzilla.mozilla.org/show_bug.cgi?id=230351.)
  size_t encoded_length = encoded_auth_token.length();
  while (encoded_length > 0 && encoded_length % 4 != 0 &&
         encoded_auth_token[encoded_length - 1] == '=')
    --encoded_length;
  encoded_auth_token.erase(encoded_length);

  std::string decoded_auth_token;
  if (!base::Base64Decode(encoded_auth_token, &decoded_auth_token))
    return false;
  decoded_server_auth_token_ = decoded_auth_token;
  return true;
}
int HttpAuthSSPI::GenerateCredentials(const std::wstring& username,
const std::wstring& password,
const GURL& origin,
const HttpRequestInfo* request,
const ProxyInfo* proxy,
std::string* out_credentials) {
// |username| may be in the form "DOMAIN\user". Parse it into the two
// components.
std::wstring domain;
std::wstring user;
SplitDomainAndUser(username, &domain, &user);
// Initial challenge.
if (!IsFinalRound()) {
int rv = OnFirstRound(domain, user, password);
if (rv != OK)
return rv;
}
void* out_buf;
int out_buf_len;
int rv = GetNextSecurityToken(
origin,
static_cast<void *>(const_cast<char *>(
decoded_server_auth_token_.c_str())),
decoded_server_auth_token_.length(),
&out_buf,
&out_buf_len);
if (rv != OK)
return rv;
// Base64 encode data in output buffer and prepend the scheme.
std::string encode_input(static_cast<char*>(out_buf), out_buf_len);
std::string encode_output;
bool ok = base::Base64Encode(encode_input, &encode_output);
// OK, we are done with |out_buf|
free(out_buf);
if (!ok)
return rv;
*out_credentials = scheme_ + " " + encode_output;
return OK;
}
int HttpAuthSSPI::OnFirstRound(const std::wstring& domain,
const std::wstring& user,
const std::wstring& password) {
int rv = DetermineMaxTokenLength(security_package_, &max_token_length_);
if (rv != OK) {
return rv;
}
rv = AcquireCredentials(security_package_, domain, user, password, &cred_);
return rv;
}
// Runs one step of the SSPI handshake via InitializeSecurityContext().
//
// |in_token| / |in_token_len| hold the decoded token from the server's most
// recent challenge. A length of 0 starts a new authentication sequence. That
// is only allowed while no security context exists yet; otherwise the
// handler is being wrongly reused and ERR_UNEXPECTED is returned.
//
// On OK, |*out_token| is a malloc()ed buffer of |*out_token_len| bytes that
// the caller must free(). It is NULL (with length 0) when the package
// produced no output token.
//
// Returns ERR_OUT_OF_MEMORY if the output buffer cannot be allocated. Returns
// ERR_UNEXPECTED on a restart attempt, or when InitializeSecurityContext()
// fails; in that case the security context is also reset.
int HttpAuthSSPI::GetNextSecurityToken(
    const GURL& origin,
    const void * in_token,
    int in_token_len,
    void** out_token,
    int* out_token_len) {
  SECURITY_STATUS status;
  // |expiry| and |ctxt_attr| are required out-params; their values are unused.
  TimeStamp expiry;
  DWORD ctxt_attr;
  CtxtHandle* ctxt_ptr;
  SecBufferDesc in_buffer_desc, out_buffer_desc;
  SecBufferDesc* in_buffer_desc_ptr;
  SecBuffer in_buffer, out_buffer;
  if (in_token_len > 0) {
    // Prepare input buffer.
    in_buffer_desc.ulVersion = SECBUFFER_VERSION;
    in_buffer_desc.cBuffers = 1;
    in_buffer_desc.pBuffers = &in_buffer;
    in_buffer.BufferType = SECBUFFER_TOKEN;
    in_buffer.cbBuffer = in_token_len;
    // SecBuffer::pvBuffer is non-const, hence the const_cast.
    in_buffer.pvBuffer = const_cast<void*>(in_token);
    // Continue the existing context: |ctxt_| is passed as the input context
    // and also receives the updated one (phNewContext below).
    ctxt_ptr = &ctxt_;
    in_buffer_desc_ptr = &in_buffer_desc;
  } else {
    // If there is no input token, then we are starting a new authentication
    // sequence. If we have already initialized our security context, then
    // we're incorrectly reusing the auth handler for a new sequence.
    if (SecIsValidHandle(&ctxt_)) {
      LOG(ERROR) << "Cannot restart authentication sequence";
      return ERR_UNEXPECTED;
    }
    // No input context and no input buffer: SSPI creates a fresh context,
    // which is written to |ctxt_| via phNewContext.
    ctxt_ptr = NULL;
    in_buffer_desc_ptr = NULL;
  }
  // Prepare output buffer.
  // It is sized to |max_token_length_|, which OnFirstRound() sets from the
  // package info. NOTE(review): if OnFirstRound() never ran, this is still 0
  // from the constructor. malloc(0) may then return NULL, which would be
  // reported as ERR_OUT_OF_MEMORY.
  out_buffer_desc.ulVersion = SECBUFFER_VERSION;
  out_buffer_desc.cBuffers = 1;
  out_buffer_desc.pBuffers = &out_buffer;
  out_buffer.BufferType = SECBUFFER_TOKEN;
  out_buffer.cbBuffer = max_token_length_;
  out_buffer.pvBuffer = malloc(out_buffer.cbBuffer);
  if (!out_buffer.pvBuffer)
    return ERR_OUT_OF_MEMORY;
  // The service principal name of the destination server. See
  // http://msdn.microsoft.com/en-us/library/ms677949%28VS.85%29.aspx
  // It is "HTTP/" followed by GetHostAndPort(origin); presumably that is
  // "host[:port]" — confirm against net_util.
  std::wstring target(L"HTTP/");
  target.append(ASCIIToWide(GetHostAndPort(origin)));
  // |target| outlives the call below, so pointing into its buffer is safe.
  wchar_t* target_name = const_cast<wchar_t*>(target.c_str());
  // This returns a token that is passed to the remote server.
  status = InitializeSecurityContext(&cred_,  // phCredential
                                     ctxt_ptr,  // phContext
                                     target_name,  // pszTargetName
                                     0,  // fContextReq
                                     0,  // Reserved1 (must be 0)
                                     SECURITY_NATIVE_DREP,  // TargetDataRep
                                     in_buffer_desc_ptr,  // pInput
                                     0,  // Reserved2 (must be 0)
                                     &ctxt_,  // phNewContext
                                     &out_buffer_desc,  // pOutput
                                     &ctxt_attr,  // pfContextAttr
                                     &expiry);  // ptsExpiry
  // On success, the function returns SEC_I_CONTINUE_NEEDED on the first call
  // and SEC_E_OK on the second call. On failure, the function returns an
  // error code.
  // NOTE(review): any other status is treated as a failure here. That
  // includes SEC_I_COMPLETE_NEEDED and SEC_I_COMPLETE_AND_CONTINUE, which
  // SSPI documents as requiring CompleteAuthToken().
  if (status != SEC_I_CONTINUE_NEEDED && status != SEC_E_OK) {
    LOG(ERROR) << "InitializeSecurityContext failed: " << status;
    ResetSecurityContext();
    free(out_buffer.pvBuffer);
    return ERR_UNEXPECTED;  // TODO(wtc): map error code.
  }
  // SSPI reports the number of bytes actually written in cbBuffer. When it
  // wrote nothing, release the buffer and hand back NULL.
  if (!out_buffer.cbBuffer) {
    free(out_buffer.pvBuffer);
    out_buffer.pvBuffer = NULL;
  }
  // Ownership of the buffer passes to the caller.
  *out_token = out_buffer.pvBuffer;
  *out_token_len = out_buffer.cbBuffer;
  return OK;
}
// Splits |combined|, expected as "DOMAIN\user", at its first backslash.
// Everything before the backslash goes to |*domain| and everything after it
// (including any further backslashes) goes to |*user|. Without a backslash,
// |*domain| is cleared and |*user| receives all of |combined|.
void SplitDomainAndUser(const std::wstring& combined,
                        std::wstring* domain,
                        std::wstring* user) {
  const std::wstring::size_type separator = combined.find(L'\\');
  if (separator != std::wstring::npos) {
    *domain = std::wstring(combined, 0, separator);
    *user = std::wstring(combined, separator + 1);
    return;
  }
  domain->clear();
  *user = combined;
}
int DetermineMaxTokenLength(const std::wstring& package,
ULONG* max_token_length) {
PSecPkgInfo pkg_info;
SECURITY_STATUS status = QuerySecurityPackageInfo(
const_cast<wchar_t *>(package.c_str()), &pkg_info);
if (status != SEC_E_OK) {
LOG(ERROR) << "Security package " << package << " not found";
return ERR_UNEXPECTED;
}
*max_token_length = pkg_info->cbMaxToken;
FreeContextBuffer(pkg_info);
return OK;
}
int AcquireCredentials(const SEC_WCHAR* package,
const std::wstring& domain,
const std::wstring& user,
const std::wstring& password,
CredHandle* cred) {
SEC_WINNT_AUTH_IDENTITY identity;
identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
identity.User =
reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(user.c_str()));
identity.UserLength = user.size();
identity.Domain =
reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(domain.c_str()));
identity.DomainLength = domain.size();
identity.Password =
reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(password.c_str()));
identity.PasswordLength = password.size();
TimeStamp expiry;
// Pass the username/password to get the credentials handle.
// Note: If the 5th argument is NULL, it uses the default cached credentials
// for the logged in user, which can be used for single sign-on.
SECURITY_STATUS status = AcquireCredentialsHandle(
NULL, // pszPrincipal
const_cast<SEC_WCHAR*>(package), // pszPackage
SECPKG_CRED_OUTBOUND, // fCredentialUse
NULL, // pvLogonID
&identity, // pAuthData
NULL, // pGetKeyFn (not used)
NULL, // pvGetKeyArgument (not used)
cred, // phCredential
&expiry); // ptsExpiry
if (status != SEC_E_OK)
return ERR_UNEXPECTED;
return OK;
}
} // namespace net