/* * Copyright (C) 2009 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.google.polo.pairing; import com.google.polo.encoding.HexadecimalEncoder; import com.google.polo.encoding.SecretEncoder; import com.google.polo.exception.BadSecretException; import com.google.polo.exception.NoConfigurationException; import com.google.polo.exception.PoloException; import com.google.polo.exception.ProtocolErrorException; import com.google.polo.pairing.PairingListener.LogLevel; import com.google.polo.pairing.message.ConfigurationMessage; import com.google.polo.pairing.message.EncodingOption; import com.google.polo.pairing.message.OptionsMessage; import com.google.polo.pairing.message.OptionsMessage.ProtocolRole; import com.google.polo.pairing.message.PoloMessage; import com.google.polo.pairing.message.PoloMessage.PoloMessageType; import com.google.polo.pairing.message.SecretAckMessage; import com.google.polo.pairing.message.SecretMessage; import com.google.polo.wire.PoloWireInterface; import java.io.IOException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.cert.Certificate; import java.util.Arrays; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; /** * Implements the logic of and holds state for a single occurrence of the * pairing protocol. * <p> * This abstract class implements the logic common to both client and server * perspectives of the protocol. Notably, the 'pairing' phase of the * protocol has the same logic regardless of client/server status * ({link PairingSession#doPairingPhase()}). Other phases of the protocol are * specific to client/server status; see {@link ServerPairingSession} and * {@link ClientPairingSession}. * <p> * The protocol is initiated by called * {@link PairingSession#doPair(PairingListener)} * The listener implementation is responsible for showing the shared secret * to the user * ({@link PairingListener#onPerformOutputDeviceRole(PairingSession, byte[])}), * or in accepting the user input * ({@link PairingListener#onPerformInputDeviceRole(PairingSession)}), * depending on the role negotiated during initialization. * <p> * When operating in the input role, the session will block execution after * calling {@link PairingListener#onPerformInputDeviceRole(PairingSession)} to * wait for the secret. The listener, or some activity resulting from it, must * publish the input secret to the session via * {@link PairingSession#setSecret(byte[])}. */ public abstract class PairingSession { protected enum ProtocolState { STATE_UNINITIALIZED, STATE_INITIALIZING, STATE_CONFIGURING, STATE_PAIRING, STATE_SUCCESS, STATE_FAILURE, } /** * Enable extra verbose debug logging. */ private static final boolean DEBUG_VERBOSE = false; /** * Controls whether to verify the secret portion of the SecretAck message. * <p> * NOTE(mikey): One implementation does not send the secret back in * the SecretAck. This should be fixed, but in the meantime it is not * essential that we verify it, since *any* acknowledgment from the * sender is enough to indicate protocol success. */ private static final boolean VERIFY_SECRET_ACK = false; /** * Timeout, in milliseconds, for polling the secret queue for a response from * the listener. This timeout is relevant only to periodically check the * mAbort flag to terminate the protocol, which is set by calling teardown(). */ private static final int SECRET_POLL_TIMEOUT_MS = 500; /** * Performs the initialization phase of the protocol. * * @throws PoloException if a protocol error occurred * @throws IOException if an error occurred in input/output */ protected abstract void doInitializationPhase() throws PoloException, IOException; /** * Performs the configuration phase of the protocol. * * @throws PoloException if a protocol error occurred * @throws IOException if an error occurred in input/output */ protected abstract void doConfigurationPhase() throws PoloException, IOException; /** * Internal representation of challenge-response. */ protected PoloChallengeResponse mChallenge; /** * Implementation of the transport layer. */ private final PoloWireInterface mProtocol; /** * Context for the pairing session. */ protected final PairingContext mPairingContext; /** * Local endpoint's supported options. * <p> * If this session is acting as a server, this message will be sent to the * client in the Initialization phase. If acting as a client, this member is * used to store local options and compute the Configuration message (but * is never transmitted directly). */ protected OptionsMessage mLocalOptions; /** * Encoding scheme used for the session. */ protected SecretEncoder mEncoder; /** * Name of the service being paired. */ protected String mServiceName; /** * Name of the peer. */ protected String mPeerName; /** * Configuration message for current session. * <p> * This is computed by the client and sent to the server. */ protected ConfigurationMessage mSessionConfig; /** * Listener that will receive callbacks upon protocol events. */ protected PairingListener mListener; /** * Internal state of the pairing session. */ protected ProtocolState mState; /** * Threadsafe queue for receiving the messages sent by peer, user-given secret * from the listener, or exceptions caught by async threads. */ protected BlockingQueue<QueueMessage> mMessageQueue; /** * Flag set when the session should be aborted. */ protected boolean mAbort; /** * Reader thread. */ private final Thread mThread; /** * Constructor. * * @param protocol the wire interface to operate against * @param pairingContext a PairingContext for the session */ public PairingSession(PoloWireInterface protocol, PairingContext pairingContext) { mProtocol = protocol; mPairingContext = pairingContext; mState = ProtocolState.STATE_UNINITIALIZED; mMessageQueue = new LinkedBlockingQueue<QueueMessage>(); Certificate clientCert = mPairingContext.getClientCertificate(); Certificate serverCert = mPairingContext.getServerCertificate(); mChallenge = new PoloChallengeResponse(clientCert, serverCert, new PoloChallengeResponse.DebugLogger() { public void debug(String message) { logDebug(message); } public void verbose(String message) { if (DEBUG_VERBOSE) { logDebug(message); } } }); mLocalOptions = new OptionsMessage(); if (mPairingContext.isServer()) { mLocalOptions.setProtocolRolePreference(ProtocolRole.DISPLAY_DEVICE); } else { mLocalOptions.setProtocolRolePreference(ProtocolRole.INPUT_DEVICE); } mThread = new Thread(new Runnable() { public void run() { logDebug("Starting reader"); try { while (!mAbort) { try { PoloMessage message = mProtocol.getNextMessage(); logDebug("Received: " + message.getClass()); mMessageQueue.put(new QueueMessage(message)); } catch (PoloException exception) { logDebug("Exception while getting message: " + exception); mMessageQueue.put(new QueueMessage(exception)); break; } catch (IOException exception) { logDebug("Exception while getting message: " + exception); mMessageQueue.put(new QueueMessage(new PoloException(exception))); break; } } } catch (InterruptedException ie) { logDebug("Interrupted: " + ie); } finally { logDebug("Reader is done"); } } }); mThread.start(); } public void teardown() { try { // Send any error. mProtocol.sendErrorMessage(new Exception()); mPairingContext.getPeerInputStream().close(); mPairingContext.getPeerOutputStream().close(); } catch (IOException e) { // oh well. } // Unblock the blocking wait on the secret queue. mAbort = true; mThread.interrupt(); } protected void log(LogLevel level, String message) { if (mListener != null) { mListener.onLogMessage(level, message); } } /** * Logs a debug message to the active listener. */ public void logDebug(String message) { log(LogLevel.LOG_DEBUG, message); } /** * Logs an informational message to the active listener. */ public void logInfo(String message) { log(LogLevel.LOG_INFO, message); } /** * Logs an error message to the active listener. */ public void logError(String message) { log(LogLevel.LOG_ERROR, message); } /** * Adds an encoding to the supported input role encodings. This method can * only be called before the session has started. * <p> * If no input encodings have been added, then this endpoint cannot act as * the input device protocol role. * * @param encoding the {@link EncodingOption} to add */ public void addInputEncoding(EncodingOption encoding) { if (mState != ProtocolState.STATE_UNINITIALIZED) { throw new IllegalStateException("Cannot add encodings once session " + "has been started."); } // Legal values of GAMMALEN must be: // - an even number of bytes // - at least 2 bytes if ((encoding.getSymbolLength() < 2) || ((encoding.getSymbolLength() % 2) != 0)) { throw new IllegalArgumentException("Bad symbol length: " + encoding.getSymbolLength()); } mLocalOptions.addInputEncoding(encoding); } /** * Adds an encoding to the supported output role encodings. This method can * only be called before the session has started. * <p> * If no output encodings have been added, then this endpoint cannot act as * the output device protocol role. * * @param encoding the {@link EncodingOption} to add */ public void addOutputEncoding(EncodingOption encoding) { if (mState != ProtocolState.STATE_UNINITIALIZED) { throw new IllegalStateException("Cannot add encodings once session " + "has been started."); } mLocalOptions.addOutputEncoding(encoding); } /** * Changes the internal state. * * @param newState the new state */ private void setState(ProtocolState newState) { logInfo("New state: " + newState); mState = newState; } /** * Runs the pairing protocol. * <p> * Supported input and output encodings must be specified * first, using * {@link PairingSession#addInputEncoding(EncodingOption)} and * {@link PairingSession#addOutputEncoding(EncodingOption)}, * respectively. * * @param listener the {@link PairingListener} for the session * @return {@code true} if pairing was successful */ public boolean doPair(PairingListener listener) { mListener = listener; mListener.onSessionCreated(this); if (mPairingContext.isServer()) { logDebug("Protocol started (SERVER mode)"); } else { logDebug("Protocol started (CLIENT mode)"); } logDebug("Local options: " + mLocalOptions.toString()); Certificate clientCert = mPairingContext.getClientCertificate(); if (DEBUG_VERBOSE) { logDebug("Client certificate:"); logDebug(clientCert.toString()); } Certificate serverCert = mPairingContext.getServerCertificate(); if (DEBUG_VERBOSE) { logDebug("Server certificate:"); logDebug(serverCert.toString()); } boolean success = false; try { setState(ProtocolState.STATE_INITIALIZING); doInitializationPhase(); setState(ProtocolState.STATE_CONFIGURING); doConfigurationPhase(); setState(ProtocolState.STATE_PAIRING); doPairingPhase(); success = true; } catch (ProtocolErrorException e) { logDebug("Remote protocol failure: " + e); } catch (PoloException e) { try { logDebug("Local protocol failure, attempting to send error: " + e); mProtocol.sendErrorMessage(e); } catch (IOException e1) { logDebug("Error message send failed"); } } catch (IOException e) { logDebug("IOException: " + e); } if (success) { setState(ProtocolState.STATE_SUCCESS); } else { setState(ProtocolState.STATE_FAILURE); } mListener.onSessionEnded(this); return success; } /** * Returns {@code true} if the session is in a terminal state (success or * failure). */ public boolean hasCompleted() { switch (mState) { case STATE_SUCCESS: case STATE_FAILURE: return true; default: return false; } } public boolean hasSucceeded() { return mState == ProtocolState.STATE_SUCCESS; } public String getServiceName() { return mServiceName; } /** * Sets the secret, as received from a user. This method is only meaningful * when the endpoint is acting as the input device role. * * @param secret the secret, as a byte sequence * @return {@code true} if the secret was captured */ public boolean setSecret(byte[] secret) { if (!isInputDevice()) { throw new IllegalStateException("Secret can only be set for " + "input role session."); } else if (mState != ProtocolState.STATE_PAIRING) { throw new IllegalStateException("Secret can only be set while " + "in pairing state."); } return mMessageQueue.offer(new QueueMessage(secret)); } /** * Executes the pairing phase of the protocol. * * @throws PoloException if a protocol error occurred * @throws IOException if an error in the input/output occurred */ protected void doPairingPhase() throws PoloException, IOException { if (isInputDevice()) { new Thread(new Runnable() { public void run() { logDebug("Calling listener for user input..."); try { mListener.onPerformInputDeviceRole(PairingSession.this); } catch (PoloException exception) { logDebug("Sending exception: " + exception); mMessageQueue.offer(new QueueMessage(exception)); } finally { logDebug("Listener finished."); } } }).start(); logDebug("Waiting for secret from Listener or ..."); QueueMessage message = waitForMessage(); if (message == null || !message.hasSecret()) { throw new PoloException( "Illegal state - no secret available: " + message); } byte[] userGamma = message.mSecret; if (userGamma == null) { throw new PoloException("Invalid secret."); } boolean match = mChallenge.checkGamma(userGamma); if (match != true) { throw new BadSecretException("Secret failed local check."); } byte[] userNonce = mChallenge.extractNonce(userGamma); byte[] genAlpha = mChallenge.getAlpha(userNonce); logDebug("Sending Secret reply..."); SecretMessage secretMessage = new SecretMessage(genAlpha); mProtocol.sendMessage(secretMessage); logDebug("Waiting for SecretAck..."); SecretAckMessage secretAck = (SecretAckMessage) getNextMessage(PoloMessageType.SECRET_ACK); if (VERIFY_SECRET_ACK) { byte[] inbandAlpha = secretAck.getSecret(); if (!Arrays.equals(inbandAlpha, genAlpha)) { throw new BadSecretException("Inband secret did not match. " + "Expected [" + PoloUtil.bytesToHexString(genAlpha) + "], got [" + PoloUtil.bytesToHexString(inbandAlpha) + "]"); } } } else { int symbolLength = mSessionConfig.getEncoding().getSymbolLength(); int nonceLength = symbolLength / 2; int bytesNeeded = nonceLength / mEncoder.symbolsPerByte(); byte[] nonce = new byte[bytesNeeded]; SecureRandom random; try { random = SecureRandom.getInstance("SHA1PRNG"); } catch (NoSuchAlgorithmException e) { throw new PoloException(e); } random.nextBytes(nonce); // Display gamma logDebug("Calling listener to display output..."); byte[] gamma = mChallenge.getGamma(nonce); mListener.onPerformOutputDeviceRole(this, gamma); logDebug("Waiting for Secret..."); SecretMessage secretMessage = (SecretMessage) getNextMessage(PoloMessageType.SECRET); byte[] localAlpha = mChallenge.getAlpha(nonce); byte[] inbandAlpha = secretMessage.getSecret(); boolean matched = Arrays.equals(localAlpha, inbandAlpha); if (!matched) { throw new BadSecretException("Inband secret did not match. " + "Expected [" + PoloUtil.bytesToHexString(localAlpha) + "], got [" + PoloUtil.bytesToHexString(inbandAlpha) + "]"); } logDebug("Sending SecretAck..."); byte[] genAlpha = mChallenge.getAlpha(nonce); SecretAckMessage secretAck = new SecretAckMessage(inbandAlpha); mProtocol.sendMessage(secretAck); } } public SecretEncoder getEncoder() { return mEncoder; } /** * Sets the current session's configuration from a * {@link ConfigurationMessage}. * * @param message the session's config * @throws PoloException if the config was not valid for some reason */ protected void setConfiguration(ConfigurationMessage message) throws PoloException { if (message == null || message.getEncoding() == null) { throw new NoConfigurationException("No configuration is possible."); } if (message.getEncoding().getSymbolLength() % 2 != 0) { throw new PoloException("Symbol length must be even."); } if (message.getEncoding().getSymbolLength() < 2) { throw new PoloException("Symbol length must be >= 2 symbols."); } switch (message.getEncoding().getType()) { case ENCODING_HEXADECIMAL: mEncoder = new HexadecimalEncoder(); break; default: throw new PoloException("Unsupported encoding type."); } mSessionConfig = message; } /** * Returns the role of this endpoint in the current session. */ protected ProtocolRole getLocalRole() { assert (mSessionConfig != null); if (!mPairingContext.isServer()) { return mSessionConfig.getClientRole(); } else { return (mSessionConfig.getClientRole() == ProtocolRole.DISPLAY_DEVICE) ? ProtocolRole.INPUT_DEVICE : ProtocolRole.DISPLAY_DEVICE; } } /** * Returns {@code true} if this endpoint will act as the input device. */ protected boolean isInputDevice() { return (getLocalRole() == ProtocolRole.INPUT_DEVICE); } /** * Returns {@code true} if peer's name is set. */ public boolean hasPeerName() { return mPeerName != null; } /** * Returns peer's name if set, {@code null} otherwise. */ public String getPeerName() { return mPeerName; } protected PoloMessage getNextMessage(PoloMessageType type) throws PoloException { QueueMessage message = waitForMessage(); if (message != null && message.hasPoloMessage()) { if (!type.equals(message.mPoloMessage.getType())) { throw new PoloException( "Unexpected message type: " + message.mPoloMessage.getType()); } return message.mPoloMessage; } throw new PoloException("Invalid state - expected polo message"); } /** * Returns next queued message. The method blocks until the secret or the * polo message is available. * * @return the queued message, or null on error * @throws PoloException if exception was queued */ private QueueMessage waitForMessage() throws PoloException { while (!mAbort) { try { QueueMessage message = mMessageQueue.poll(SECRET_POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS); if (message != null) { if (message.hasPoloException()) { throw new PoloException(message.mPoloException); } return message; } } catch (InterruptedException e) { break; } } // Aborted or interrupted. return null; } /** * Sends message to the peer. * * @param message the message * @throws PoloException if a protocol error occurred * @throws IOException if an error in the input/output occurred */ protected void sendMessage(PoloMessage message) throws IOException, PoloException { mProtocol.sendMessage(message); } /** * Queued message, that can carry information about secret, next read message, * or exception caught by reader or input threads. */ private static final class QueueMessage { final PoloMessage mPoloMessage; final PoloException mPoloException; final byte[] mSecret; private QueueMessage( PoloMessage message, byte[] secret, PoloException exception) { int nonNullCount = 0; if (message != null) { ++nonNullCount; } mPoloMessage = message; if (exception != null) { assert(nonNullCount == 0); ++nonNullCount; } mPoloException = exception; if (secret != null) { assert(nonNullCount == 0); ++nonNullCount; } mSecret = secret; assert(nonNullCount == 1); } public QueueMessage(PoloMessage message) { this(message, null, null); } public QueueMessage(byte[] secret) { this(null, secret, null); } public QueueMessage(PoloException exception) { this(null, null, exception); } public boolean hasPoloMessage() { return mPoloMessage != null; } public boolean hasPoloException() { return mPoloException != null; } public boolean hasSecret() { return mSecret != null; } @Override public String toString() { StringBuilder builder = new StringBuilder("QueueMessage("); if (hasPoloMessage()) { builder.append("poloMessage = " + mPoloMessage); } if (hasPoloException()) { builder.append("poloException = " + mPoloException); } if (hasSecret()) { builder.append("secret = " + Arrays.toString(mSecret)); } return builder.append(")").toString(); } } }