1 /*
   2  * Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 package sun.security.ssl;
  26 
  27 import java.io.IOException;
  28 import java.nio.ByteBuffer;
  29 import java.security.*;
  30 import java.text.MessageFormat;
  31 import java.util.List;
  32 import java.util.ArrayList;
  33 import java.util.Locale;
  34 import java.util.Arrays;
  35 import java.util.Collection;
  36 import javax.crypto.Mac;
  37 import javax.crypto.SecretKey;
  38 import javax.net.ssl.SSLPeerUnverifiedException;
  39 import static sun.security.ssl.ClientAuthType.CLIENT_AUTH_REQUIRED;
  40 import sun.security.ssl.ClientHello.ClientHelloMessage;
  41 import sun.security.ssl.SSLExtension.ExtensionConsumer;
  42 import sun.security.ssl.SSLExtension.SSLExtensionSpec;
  43 import sun.security.ssl.SSLHandshake.HandshakeMessage;
  44 import sun.security.ssl.SessionTicketExtension.SessionTicketSpec;
  45 import sun.security.util.HexDumpEncoder;
  46 
  47 import static sun.security.ssl.SSLExtension.*;
  48 
  49 /**
  50  * Pack of the "pre_shared_key" extension.
  51  */
  52 final class PreSharedKeyExtension {
  53     static final HandshakeProducer chNetworkProducer =
  54             new CHPreSharedKeyProducer();
  55     static final ExtensionConsumer chOnLoadConsumer =
  56             new CHPreSharedKeyConsumer();
  57     static final HandshakeAbsence chOnLoadAbsence =
  58             new CHPreSharedKeyAbsence();
  59     static final HandshakeConsumer chOnTradeConsumer =
  60             new CHPreSharedKeyUpdate();
  61     static final SSLStringizer chStringizer =
  62             new CHPreSharedKeyStringizer();
  63 
  64     static final HandshakeProducer shNetworkProducer =
  65             new SHPreSharedKeyProducer();
  66     static final ExtensionConsumer shOnLoadConsumer =
  67             new SHPreSharedKeyConsumer();
  68     static final HandshakeAbsence shOnLoadAbsence =
  69             new SHPreSharedKeyAbsence();
  70     static final SSLStringizer shStringizer =
  71             new SHPreSharedKeyStringizer();
  72 
  73     private static final class PskIdentity {
  74         final byte[] identity;
  75         final int obfuscatedAge;
  76 
  77         PskIdentity(byte[] identity, int obfuscatedAge) {
  78             this.identity = identity;
  79             this.obfuscatedAge = obfuscatedAge;
  80         }
  81 
  82         int getEncodedLength() {
  83             return 2 + identity.length + 4;
  84         }
  85 
  86         void writeEncoded(ByteBuffer m) throws IOException {
  87             Record.putBytes16(m, identity);
  88             Record.putInt32(m, obfuscatedAge);
  89         }
  90 
  91         @Override
  92         public String toString() {
  93             return "{" + Utilities.toHexString(identity) + ", " +
  94                 obfuscatedAge + "}";
  95         }
  96     }
  97 
  98     private static final
  99             class CHPreSharedKeySpec implements SSLExtensionSpec {
 100         final List<PskIdentity> identities;
 101         final List<byte[]> binders;
 102 
 103         CHPreSharedKeySpec(List<PskIdentity> identities, List<byte[]> binders) {
 104             this.identities = identities;
 105             this.binders = binders;
 106         }
 107 
 108         CHPreSharedKeySpec(HandshakeContext context,
 109                 ByteBuffer m) throws IOException {
 110             // struct {
 111             //     PskIdentity identities<7..2^16-1>;
 112             //     PskBinderEntry binders<33..2^16-1>;
 113             // } OfferedPsks;
 114             if (m.remaining() < 44) {
 115                 throw context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 116                     "Invalid pre_shared_key extension: " +
 117                     "insufficient data (length=" + m.remaining() + ")");
 118             }
 119 
 120             int idEncodedLength = Record.getInt16(m);
 121             if (idEncodedLength < 7) {
 122                 throw context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 123                     "Invalid pre_shared_key extension: " +
 124                     "insufficient identities (length=" + idEncodedLength + ")");
 125             }
 126 
 127             identities = new ArrayList<>();
 128             int idReadLength = 0;
 129             while (idReadLength < idEncodedLength) {
 130                 byte[] id = Record.getBytes16(m);
 131                 if (id.length < 1) {
 132                     throw context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 133                         "Invalid pre_shared_key extension: " +
 134                         "insufficient identity (length=" + id.length + ")");
 135                 }
 136                 int obfuscatedTicketAge = Record.getInt32(m);
 137 
 138                 PskIdentity pskId = new PskIdentity(id, obfuscatedTicketAge);
 139                 identities.add(pskId);
 140                 idReadLength += pskId.getEncodedLength();
 141             }
 142 
 143             if (m.remaining() < 35) {
 144                 throw context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 145                         "Invalid pre_shared_key extension: " +
 146                         "insufficient binders data (length=" +
 147                         m.remaining() + ")");
 148             }
 149 
 150             int bindersEncodedLen = Record.getInt16(m);
 151             if (bindersEncodedLen < 33) {
 152                 throw context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 153                         "Invalid pre_shared_key extension: " +
 154                         "insufficient binders (length=" +
 155                         bindersEncodedLen + ")");
 156             }
 157 
 158             binders = new ArrayList<>();
 159             int bindersReadLength = 0;
 160             while (bindersReadLength < bindersEncodedLen) {
 161                 byte[] binder = Record.getBytes8(m);
 162                 if (binder.length < 32) {
 163                     throw context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 164                             "Invalid pre_shared_key extension: " +
 165                             "insufficient binder entry (length=" +
 166                             binder.length + ")");
 167                 }
 168                 binders.add(binder);
 169                 bindersReadLength += 1 + binder.length;
 170             }
 171         }
 172 
 173         int getIdsEncodedLength() {
 174             int idEncodedLength = 0;
 175             for (PskIdentity curId : identities) {
 176                 idEncodedLength += curId.getEncodedLength();
 177             }
 178 
 179             return idEncodedLength;
 180         }
 181 
 182         int getBindersEncodedLength() {
 183             int binderEncodedLength = 0;
 184             for (byte[] curBinder : binders) {
 185                 binderEncodedLength += 1 + curBinder.length;
 186             }
 187 
 188             return binderEncodedLength;
 189         }
 190 
 191         byte[] getEncoded() throws IOException {
 192             int idsEncodedLength = getIdsEncodedLength();
 193             int bindersEncodedLength = getBindersEncodedLength();
 194             int encodedLength = 4 + idsEncodedLength + bindersEncodedLength;
 195             byte[] buffer = new byte[encodedLength];
 196             ByteBuffer m = ByteBuffer.wrap(buffer);
 197             Record.putInt16(m, idsEncodedLength);
 198             for (PskIdentity curId : identities) {
 199                 curId.writeEncoded(m);
 200             }
 201             Record.putInt16(m, bindersEncodedLength);
 202             for (byte[] curBinder : binders) {
 203                 Record.putBytes8(m, curBinder);
 204             }
 205 
 206             return buffer;
 207         }
 208 
 209         @Override
 210         public String toString() {
 211             MessageFormat messageFormat = new MessageFormat(
 212                 "\"PreSharedKey\": '{'\n" +
 213                 "  \"identities\": '{'\n" +
 214                 "{0}\n" +
 215                 "  '}'" +
 216                 "  \"binders\": \"{1}\",\n" +
 217                 "'}'",
 218                 Locale.ENGLISH);
 219 
 220             Object[] messageFields = {
 221                 Utilities.indent(identitiesString()),
 222                 Utilities.indent(bindersString())
 223             };
 224 
 225             return messageFormat.format(messageFields);
 226         }
 227 
 228         String identitiesString() {
 229             HexDumpEncoder hexEncoder = new HexDumpEncoder();
 230 
 231             StringBuilder result = new StringBuilder();
 232             for (PskIdentity curId : identities) {
 233                 result.append("  {\n"+ Utilities.indent(
 234                         hexEncoder.encode(curId.identity), "    ") +
 235                         "\n  }\n");
 236             }
 237 
 238             return result.toString();
 239         }
 240 
 241         String bindersString() {
 242             StringBuilder result = new StringBuilder();
 243             for (byte[] curBinder : binders) {
 244                 result.append("{" + Utilities.toHexString(curBinder) + "}\n");
 245             }
 246 
 247             return result.toString();
 248         }
 249     }
 250 
 251     private static final
 252             class CHPreSharedKeyStringizer implements SSLStringizer {
 253         @Override
 254         public String toString(ByteBuffer buffer) {
 255             try {
 256                 // As the HandshakeContext parameter of CHPreSharedKeySpec
 257                 // constructor is used for fatal alert only, we can use
 258                 // null HandshakeContext here as we don't care about exception.
 259                 //
 260                 // Please take care of this code if the CHPreSharedKeySpec
 261                 // constructor is updated in the future.
 262                 return (new CHPreSharedKeySpec(null, buffer)).toString();
 263             } catch (Exception ex) {
 264                 // For debug logging only, so please swallow exceptions.
 265                 return ex.getMessage();
 266             }
 267         }
 268     }
 269 
 270     private static final
 271             class SHPreSharedKeySpec implements SSLExtensionSpec {
 272         final int selectedIdentity;
 273 
 274         SHPreSharedKeySpec(int selectedIdentity) {
 275             this.selectedIdentity = selectedIdentity;
 276         }
 277 
 278         SHPreSharedKeySpec(HandshakeContext context,
 279                 ByteBuffer m) throws IOException {
 280             if (m.remaining() < 2) {
 281                 throw context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 282                         "Invalid pre_shared_key extension: " +
 283                         "insufficient selected_identity (length=" +
 284                         m.remaining() + ")");
 285             }
 286             this.selectedIdentity = Record.getInt16(m);
 287         }
 288 
 289         byte[] getEncoded() {
 290             return new byte[] {
 291                 (byte)((selectedIdentity >> 8) & 0xFF),
 292                 (byte)(selectedIdentity & 0xFF)
 293             };
 294         }
 295 
 296         @Override
 297         public String toString() {
 298             MessageFormat messageFormat = new MessageFormat(
 299                 "\"PreSharedKey\": '{'\n" +
 300                 "  \"selected_identity\"      : \"{0}\",\n" +
 301                 "'}'",
 302                 Locale.ENGLISH);
 303 
 304             Object[] messageFields = {
 305                 Utilities.byte16HexString(selectedIdentity)
 306             };
 307 
 308             return messageFormat.format(messageFields);
 309         }
 310     }
 311 
 312     private static final
 313             class SHPreSharedKeyStringizer implements SSLStringizer {
 314         @Override
 315         public String toString(ByteBuffer buffer) {
 316             try {
 317                 // As the HandshakeContext parameter of SHPreSharedKeySpec
 318                 // constructor is used for fatal alert only, we can use
 319                 // null HandshakeContext here as we don't care about exception.
 320                 //
 321                 // Please take care of this code if the SHPreSharedKeySpec
 322                 // constructor is updated in the future.
 323                 return (new SHPreSharedKeySpec(null, buffer)).toString();
 324             } catch (Exception ex) {
 325                 // For debug logging only, so please swallow exceptions.
 326                 return ex.getMessage();
 327             }
 328         }
 329     }
 330 
 331     private static final
 332             class CHPreSharedKeyConsumer implements ExtensionConsumer {
 333         // Prevent instantiation of this class.
 334         private CHPreSharedKeyConsumer() {
 335             // blank
 336         }
 337 
 338         @Override
 339         public void consume(ConnectionContext context,
 340                             HandshakeMessage message,
 341                             ByteBuffer buffer) throws IOException {
 342             ClientHelloMessage clientHello = (ClientHelloMessage) message;
 343             ServerHandshakeContext shc = (ServerHandshakeContext)context;
 344             // Is it a supported and enabled extension?
 345             if (!shc.sslConfig.isAvailable(SSLExtension.CH_PRE_SHARED_KEY)) {
 346                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 347                     SSLLogger.fine(
 348                             "Ignore unavailable pre_shared_key extension");
 349                 }
 350                 return;     // ignore the extension
 351             }
 352 
 353             // Parse the extension.
 354             CHPreSharedKeySpec pskSpec = null;
 355             try {
 356                 pskSpec = new CHPreSharedKeySpec(shc, buffer);
 357             } catch (IOException ioe) {
 358                 throw shc.conContext.fatal(Alert.UNEXPECTED_MESSAGE, ioe);
 359             }
 360 
 361             // The "psk_key_exchange_modes" extension should have been loaded.
 362             if (!shc.handshakeExtensions.containsKey(
 363                     SSLExtension.PSK_KEY_EXCHANGE_MODES)) {
 364                 throw shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 365                         "Client sent PSK but not PSK modes, or the PSK " +
 366                         "extension is not the last extension");
 367             }
 368 
 369             // error if id and binder lists are not the same length
 370             if (pskSpec.identities.size() != pskSpec.binders.size()) {
 371                 throw shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 372                         "PSK extension has incorrect number of binders");
 373             }
 374 
 375             if (shc.isResumption) {     // resumingSession may not be set
 376                 SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
 377                         shc.sslContext.engineGetServerSessionContext();
 378                 int idIndex = 0;
 379                 SSLSessionImpl s = null;
 380 
 381                 for (PskIdentity requestedId : pskSpec.identities) {
 382                     // If we are keeping state, see if the identity is in the cache
 383                     if (requestedId.identity.length == SessionId.MAX_LENGTH) {
 384                         s = sessionCache.get(requestedId.identity);
 385                     }
 386                     // See if the identity is a stateless ticket
 387                     if (s == null &&
 388                             requestedId.identity.length > SessionId.MAX_LENGTH &&
 389                             sessionCache.statelessEnabled()) {
 390                         ByteBuffer b =
 391                                 new SessionTicketSpec(requestedId.identity).
 392                                         decrypt(shc);
 393                         if (b != null) {
 394                             try {
 395                                 s = new SSLSessionImpl(shc, b);
 396                             } catch (IOException | RuntimeException e) {
 397                                 s = null;
 398                             }
 399                         }
 400                         if (b == null || s == null) {
 401                             if (SSLLogger.isOn &&
 402                                     SSLLogger.isOn("ssl,handshake")) {
 403                                 SSLLogger.fine(
 404                                         "Stateless session ticket invalid");
 405                             }
 406                         }
 407                     }
 408 
 409                     if (s != null && canRejoin(clientHello, shc, s)) {
 410                         if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 411                             SSLLogger.fine("Resuming session: ", s);
 412                         }
 413 
 414                         // binder will be checked later
 415                         shc.resumingSession = s;
 416                         shc.handshakeExtensions.put(SH_PRE_SHARED_KEY,
 417                             new SHPreSharedKeySpec(idIndex));   // for the index
 418                         break;
 419                     }
 420 
 421                     ++idIndex;
 422                 }
 423 
 424                 if (idIndex == pskSpec.identities.size()) {
 425                     // no resumable session
 426                     shc.isResumption = false;
 427                     shc.resumingSession = null;
 428                 }
 429             }
 430             // update the context
 431             shc.handshakeExtensions.put(
 432                 SSLExtension.CH_PRE_SHARED_KEY, pskSpec);
 433         }
 434     }
 435 
 436     private static boolean canRejoin(ClientHelloMessage clientHello,
 437         ServerHandshakeContext shc, SSLSessionImpl s) {
 438 
 439         boolean result = s.isRejoinable() && (s.getPreSharedKey() != null);
 440 
 441         // Check protocol version
 442         if (result && s.getProtocolVersion() != shc.negotiatedProtocol) {
 443             if (SSLLogger.isOn &&
 444                 SSLLogger.isOn("ssl,handshake,verbose")) {
 445 
 446                 SSLLogger.finest("Can't resume, incorrect protocol version");
 447             }
 448             result = false;
 449         }
 450 
 451         // Make sure that the server handshake context's localSupportedSignAlgs
 452         // field is populated.  This is particularly important when
 453         // client authentication was used in an initial session and it is
 454         // now being resumed.
 455         if (shc.localSupportedSignAlgs == null) {
 456             shc.localSupportedSignAlgs =
 457                     SignatureScheme.getSupportedAlgorithms(
 458                             shc.algorithmConstraints, shc.activeProtocols);
 459         }
 460 
 461         // Validate the required client authentication.
 462         if (result &&
 463             (shc.sslConfig.clientAuthType == CLIENT_AUTH_REQUIRED)) {
 464             try {
 465                 s.getPeerPrincipal();
 466             } catch (SSLPeerUnverifiedException e) {
 467                 if (SSLLogger.isOn &&
 468                         SSLLogger.isOn("ssl,handshake,verbose")) {
 469                     SSLLogger.finest(
 470                         "Can't resume, " +
 471                         "client authentication is required");
 472                 }
 473                 result = false;
 474             }
 475 
 476             // Make sure the list of supported signature algorithms matches
 477             Collection<SignatureScheme> sessionSigAlgs =
 478                 s.getLocalSupportedSignatureSchemes();
 479             if (result &&
 480                 !shc.localSupportedSignAlgs.containsAll(sessionSigAlgs)) {
 481 
 482                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 483                     SSLLogger.fine("Can't resume. Session uses different " +
 484                         "signature algorithms");
 485                 }
 486                 result = false;
 487             }
 488         }
 489 
 490         // ensure that the endpoint identification algorithm matches the
 491         // one in the session
 492         String identityAlg = shc.sslConfig.identificationProtocol;
 493         if (result && identityAlg != null) {
 494             String sessionIdentityAlg = s.getIdentificationProtocol();
 495             if (!identityAlg.equalsIgnoreCase(sessionIdentityAlg)) {
 496                 if (SSLLogger.isOn &&
 497                     SSLLogger.isOn("ssl,handshake,verbose")) {
 498 
 499                     SSLLogger.finest("Can't resume, endpoint id" +
 500                         " algorithm does not match, requested: " +
 501                         identityAlg + ", cached: " + sessionIdentityAlg);
 502                 }
 503                 result = false;
 504             }
 505         }
 506 
 507         // Ensure cipher suite can be negotiated
 508         if (result && (!shc.isNegotiable(s.getSuite()) ||
 509             !clientHello.cipherSuites.contains(s.getSuite()))) {
 510             if (SSLLogger.isOn &&
 511                     SSLLogger.isOn("ssl,handshake,verbose")) {
 512                 SSLLogger.finest(
 513                     "Can't resume, unavailable session cipher suite");
 514             }
 515             result = false;
 516         }
 517 
 518         return result;
 519     }
 520 
 521     private static final
 522             class CHPreSharedKeyUpdate implements HandshakeConsumer {
 523         // Prevent instantiation of this class.
 524         private CHPreSharedKeyUpdate() {
 525             // blank
 526         }
 527 
 528         @Override
 529         public void consume(ConnectionContext context,
 530                 HandshakeMessage message) throws IOException {
 531             ServerHandshakeContext shc = (ServerHandshakeContext)context;
 532             if (!shc.isResumption || shc.resumingSession == null) {
 533                 // not resuming---nothing to do
 534                 return;
 535             }
 536 
 537             CHPreSharedKeySpec chPsk = (CHPreSharedKeySpec)
 538                     shc.handshakeExtensions.get(SSLExtension.CH_PRE_SHARED_KEY);
 539             SHPreSharedKeySpec shPsk = (SHPreSharedKeySpec)
 540                     shc.handshakeExtensions.get(SSLExtension.SH_PRE_SHARED_KEY);
 541             if (chPsk == null || shPsk == null) {
 542                 throw shc.conContext.fatal(Alert.INTERNAL_ERROR,
 543                         "Required extensions are unavailable");
 544             }
 545 
 546             byte[] binder = chPsk.binders.get(shPsk.selectedIdentity);
 547 
 548             // set up PSK binder hash
 549             HandshakeHash pskBinderHash = shc.handshakeHash.copy();
 550             byte[] lastMessage = pskBinderHash.removeLastReceived();
 551             ByteBuffer messageBuf = ByteBuffer.wrap(lastMessage);
 552             // skip the type and length
 553             messageBuf.position(4);
 554             // read to find the beginning of the binders
 555             ClientHelloMessage.readPartial(shc.conContext, messageBuf);
 556             int length = messageBuf.position();
 557             messageBuf.position(0);
 558             pskBinderHash.receive(messageBuf, length);
 559 
 560             checkBinder(shc, shc.resumingSession, pskBinderHash, binder);
 561         }
 562     }
 563 
 564     private static void checkBinder(ServerHandshakeContext shc,
 565             SSLSessionImpl session,
 566             HandshakeHash pskBinderHash, byte[] binder) throws IOException {
 567         SecretKey psk = session.getPreSharedKey();
 568         if (psk == null) {
 569             throw shc.conContext.fatal(Alert.INTERNAL_ERROR,
 570                     "Session has no PSK");
 571         }
 572 
 573         SecretKey binderKey = deriveBinderKey(shc, psk, session);
 574         byte[] computedBinder =
 575                 computeBinder(shc, binderKey, session, pskBinderHash);
 576         if (!Arrays.equals(binder, computedBinder)) {
 577             throw shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 578                     "Incorect PSK binder value");
 579         }
 580     }
 581 
 582     // Class that produces partial messages used to compute binder hash
 583     static final class PartialClientHelloMessage extends HandshakeMessage {
 584 
 585         private final ClientHello.ClientHelloMessage msg;
 586         private final CHPreSharedKeySpec psk;
 587 
 588         PartialClientHelloMessage(HandshakeContext ctx,
 589                                   ClientHello.ClientHelloMessage msg,
 590                                   CHPreSharedKeySpec psk) {
 591             super(ctx);
 592 
 593             this.msg = msg;
 594             this.psk = psk;
 595         }
 596 
 597         @Override
 598         SSLHandshake handshakeType() {
 599             return msg.handshakeType();
 600         }
 601 
 602         private int pskTotalLength() {
 603             return psk.getIdsEncodedLength() +
 604                 psk.getBindersEncodedLength() + 8;
 605         }
 606 
 607         @Override
 608         int messageLength() {
 609 
 610             if (msg.extensions.get(SSLExtension.CH_PRE_SHARED_KEY) != null) {
 611                 return msg.messageLength();
 612             } else {
 613                 return msg.messageLength() + pskTotalLength();
 614             }
 615         }
 616 
 617         @Override
 618         void send(HandshakeOutStream hos) throws IOException {
 619             msg.sendCore(hos);
 620 
 621             // complete extensions
 622             int extsLen = msg.extensions.length();
 623             if (msg.extensions.get(SSLExtension.CH_PRE_SHARED_KEY) == null) {
 624                 extsLen += pskTotalLength();
 625             }
 626             hos.putInt16(extsLen - 2);
 627             // write the complete extensions
 628             for (SSLExtension ext : SSLExtension.values()) {
 629                 byte[] extData = msg.extensions.get(ext);
 630                 if (extData == null) {
 631                     continue;
 632                 }
 633                 // the PSK could be there from an earlier round
 634                 if (ext == SSLExtension.CH_PRE_SHARED_KEY) {
 635                     continue;
 636                 }
 637                 int extID = ext.id;
 638                 hos.putInt16(extID);
 639                 hos.putBytes16(extData);
 640             }
 641 
 642             // partial PSK extension
 643             int extID = SSLExtension.CH_PRE_SHARED_KEY.id;
 644             hos.putInt16(extID);
 645             byte[] encodedPsk = psk.getEncoded();
 646             hos.putInt16(encodedPsk.length);
 647             hos.write(encodedPsk, 0, psk.getIdsEncodedLength() + 2);
 648         }
 649     }
 650 
 651     private static final
 652             class CHPreSharedKeyProducer implements HandshakeProducer {
 653         // Prevent instantiation of this class.
 654         private CHPreSharedKeyProducer() {
 655             // blank
 656         }
 657 
 658         @Override
 659         public byte[] produce(ConnectionContext context,
 660                 HandshakeMessage message) throws IOException {
 661 
 662             // The producing happens in client side only.
 663             ClientHandshakeContext chc = (ClientHandshakeContext)context;
 664             if (!chc.isResumption || chc.resumingSession == null) {
 665                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 666                     SSLLogger.fine("No session to resume.");
 667                 }
 668                 return null;
 669             }
 670 
 671             // Make sure the list of supported signature algorithms matches
 672             Collection<SignatureScheme> sessionSigAlgs =
 673                 chc.resumingSession.getLocalSupportedSignatureSchemes();
 674             if (!chc.localSupportedSignAlgs.containsAll(sessionSigAlgs)) {
 675                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 676                     SSLLogger.fine("Existing session uses different " +
 677                         "signature algorithms");
 678                 }
 679                 return null;
 680             }
 681 
 682             // The session must have a pre-shared key
 683             SecretKey psk = chc.resumingSession.getPreSharedKey();
 684             if (psk == null) {
 685                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 686                     SSLLogger.fine("Existing session has no PSK.");
 687                 }
 688                 return null;
 689             }
 690 
 691             // The PSK ID can only be used in one connections, but this method
 692             // may be called twice in a connection if the server sends HRR.
 693             // ID is saved in the context so it can be used in the second call.
 694             if (chc.pskIdentity == null) {
 695                 chc.pskIdentity = chc.resumingSession.consumePskIdentity();
 696             }
 697 
 698             if (chc.pskIdentity == null) {
 699                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 700                     SSLLogger.fine(
 701                         "PSK has no identity, or identity was already used");
 702                 }
 703                 return null;
 704             }
 705 
 706             //The session cannot be used again. Remove it from the cache.
 707             SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
 708                 chc.sslContext.engineGetClientSessionContext();
 709             sessionCache.remove(chc.resumingSession.getSessionId());
 710 
 711             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 712                 SSLLogger.fine(
 713                     "Found resumable session. Preparing PSK message.");
 714             }
 715 
 716             List<PskIdentity> identities = new ArrayList<>();
 717             int ageMillis = (int)(System.currentTimeMillis() -
 718                     chc.resumingSession.getTicketCreationTime());
 719             int obfuscatedAge =
 720                     ageMillis + chc.resumingSession.getTicketAgeAdd();
 721             identities.add(new PskIdentity(chc.pskIdentity, obfuscatedAge));
 722 
 723             SecretKey binderKey =
 724                     deriveBinderKey(chc, psk, chc.resumingSession);
 725             ClientHelloMessage clientHello = (ClientHelloMessage)message;
 726             CHPreSharedKeySpec pskPrototype = createPskPrototype(
 727                 chc.resumingSession.getSuite().hashAlg.hashLength, identities);
 728             HandshakeHash pskBinderHash = chc.handshakeHash.copy();
 729 
 730             byte[] binder = computeBinder(chc, binderKey, pskBinderHash,
 731                     chc.resumingSession, chc, clientHello, pskPrototype);
 732 
 733             List<byte[]> binders = new ArrayList<>();
 734             binders.add(binder);
 735 
 736             CHPreSharedKeySpec pskMessage =
 737                     new CHPreSharedKeySpec(identities, binders);
 738             chc.handshakeExtensions.put(CH_PRE_SHARED_KEY, pskMessage);
 739             return pskMessage.getEncoded();
 740         }
 741 
 742         private CHPreSharedKeySpec createPskPrototype(
 743                 int hashLength, List<PskIdentity> identities) {
 744             List<byte[]> binders = new ArrayList<>();
 745             byte[] binderProto = new byte[hashLength];
 746             int i = identities.size();
 747             while (i-- > 0) {
 748                 binders.add(binderProto);
 749             }
 750 
 751             return new CHPreSharedKeySpec(identities, binders);
 752         }
 753     }
 754 
 755     private static byte[] computeBinder(
 756             HandshakeContext context, SecretKey binderKey,
 757             SSLSessionImpl session,
 758             HandshakeHash pskBinderHash) throws IOException {
 759 
 760         pskBinderHash.determine(
 761                 session.getProtocolVersion(), session.getSuite());
 762         pskBinderHash.update();
 763         byte[] digest = pskBinderHash.digest();
 764 
 765         return computeBinder(context, binderKey, session, digest);
 766     }
 767 
 768     private static byte[] computeBinder(
 769             HandshakeContext context, SecretKey binderKey,
 770             HandshakeHash hash, SSLSessionImpl session,
 771             HandshakeContext ctx, ClientHello.ClientHelloMessage hello,
 772             CHPreSharedKeySpec pskPrototype) throws IOException {
 773 
 774         PartialClientHelloMessage partialMsg =
 775                 new PartialClientHelloMessage(ctx, hello, pskPrototype);
 776 
 777         SSLEngineOutputRecord record = new SSLEngineOutputRecord(hash);
 778         HandshakeOutStream hos = new HandshakeOutStream(record);
 779         partialMsg.write(hos);
 780 
 781         hash.determine(session.getProtocolVersion(), session.getSuite());
 782         hash.update();
 783         byte[] digest = hash.digest();
 784 
 785         return computeBinder(context, binderKey, session, digest);
 786     }
 787 
 788     private static byte[] computeBinder(HandshakeContext context,
 789             SecretKey binderKey,
 790             SSLSessionImpl session, byte[] digest) throws IOException {
 791         try {
 792             CipherSuite.HashAlg hashAlg = session.getSuite().hashAlg;
 793             HKDF hkdf = new HKDF(hashAlg.name);
 794             byte[] label = ("tls13 finished").getBytes();
 795             byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(
 796                     label, new byte[0], hashAlg.hashLength);
 797             SecretKey finishedKey = hkdf.expand(
 798                     binderKey, hkdfInfo, hashAlg.hashLength, "TlsBinderKey");
 799 
 800             String hmacAlg =
 801                 "Hmac" + hashAlg.name.replace("-", "");
 802             try {
 803                 Mac hmac = Mac.getInstance(hmacAlg);
 804                 hmac.init(finishedKey);
 805                 return hmac.doFinal(digest);
 806             } catch (NoSuchAlgorithmException | InvalidKeyException ex) {
 807                 throw context.conContext.fatal(Alert.INTERNAL_ERROR, ex);
 808             }
 809         } catch (GeneralSecurityException ex) {
 810             throw context.conContext.fatal(Alert.INTERNAL_ERROR, ex);
 811         }
 812     }
 813 
 814     private static SecretKey deriveBinderKey(HandshakeContext context,
 815             SecretKey psk, SSLSessionImpl session) throws IOException {
 816         try {
 817             CipherSuite.HashAlg hashAlg = session.getSuite().hashAlg;
 818             HKDF hkdf = new HKDF(hashAlg.name);
 819             byte[] zeros = new byte[hashAlg.hashLength];
 820             SecretKey earlySecret = hkdf.extract(zeros, psk, "TlsEarlySecret");
 821 
 822             byte[] label = ("tls13 res binder").getBytes();
 823             MessageDigest md = MessageDigest.getInstance(hashAlg.name);
 824             byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(
 825                     label, md.digest(new byte[0]), hashAlg.hashLength);
 826             return hkdf.expand(earlySecret,
 827                     hkdfInfo, hashAlg.hashLength, "TlsBinderKey");
 828         } catch (GeneralSecurityException ex) {
 829             throw context.conContext.fatal(Alert.INTERNAL_ERROR, ex);
 830         }
 831     }
 832 
 833     private static final
 834             class CHPreSharedKeyAbsence implements HandshakeAbsence {
 835         @Override
 836         public void absent(ConnectionContext context,
 837                            HandshakeMessage message) throws IOException {
 838 
 839             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 840                 SSLLogger.fine(
 841                 "Handling pre_shared_key absence.");
 842             }
 843 
 844             ServerHandshakeContext shc = (ServerHandshakeContext)context;
 845 
 846             // Resumption is only determined by PSK, when enabled
 847             shc.resumingSession = null;
 848             shc.isResumption = false;
 849         }
 850     }
 851 
 852     private static final
 853             class SHPreSharedKeyConsumer implements ExtensionConsumer {
 854         // Prevent instantiation of this class.
 855         private SHPreSharedKeyConsumer() {
 856             // blank
 857         }
 858 
 859         @Override
 860         public void consume(ConnectionContext context,
 861             HandshakeMessage message, ByteBuffer buffer) throws IOException {
 862             // The consuming happens in client side only.
 863             ClientHandshakeContext chc = (ClientHandshakeContext)context;
 864 
 865             // Is it a response of the specific request?
 866             if (!chc.handshakeExtensions.containsKey(
 867                     SSLExtension.CH_PRE_SHARED_KEY)) {
 868                 throw chc.conContext.fatal(Alert.UNEXPECTED_MESSAGE,
 869                     "Server sent unexpected pre_shared_key extension");
 870             }
 871 
 872             SHPreSharedKeySpec shPsk = new SHPreSharedKeySpec(chc, buffer);
 873             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 874                 SSLLogger.fine(
 875                     "Received pre_shared_key extension: ", shPsk);
 876             }
 877 
 878             if (shPsk.selectedIdentity != 0) {
 879                 throw chc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
 880                     "Selected identity index is not in correct range.");
 881             }
 882 
 883             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 884                 SSLLogger.fine(
 885                         "Resuming session: ", chc.resumingSession);
 886             }
 887         }
 888     }
 889 
 890     private static final
 891             class SHPreSharedKeyAbsence implements HandshakeAbsence {
 892         @Override
 893         public void absent(ConnectionContext context,
 894                 HandshakeMessage message) throws IOException {
 895             ClientHandshakeContext chc = (ClientHandshakeContext)context;
 896 
 897             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 898                 SSLLogger.fine("Handling pre_shared_key absence.");
 899             }
 900 
 901             // The server refused to resume, or the client did not
 902             // request 1.3 resumption.
 903             chc.resumingSession = null;
 904             chc.isResumption = false;
 905         }
 906     }
 907 
 908     private static final
 909             class SHPreSharedKeyProducer implements HandshakeProducer {
 910         // Prevent instantiation of this class.
 911         private SHPreSharedKeyProducer() {
 912             // blank
 913         }
 914 
 915         @Override
 916         public byte[] produce(ConnectionContext context,
 917                 HandshakeMessage message) throws IOException {
 918             ServerHandshakeContext shc = (ServerHandshakeContext)context;
 919             SHPreSharedKeySpec psk = (SHPreSharedKeySpec)
 920                     shc.handshakeExtensions.get(SH_PRE_SHARED_KEY);
 921             if (psk == null) {
 922                 return null;
 923             }
 924 
 925             return psk.getEncoded();
 926         }
 927     }
 928 }