1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ssl;
  27 
  28 import sun.security.action.GetPropertyAction;
  29 import sun.security.ssl.SSLExtension.ExtensionConsumer;
  30 import sun.security.ssl.SSLExtension.SSLExtensionSpec;
  31 import sun.security.ssl.SSLHandshake.HandshakeMessage;
  32 import sun.security.ssl.SupportedGroupsExtension.SupportedGroups;
  33 import sun.security.util.HexDumpEncoder;
  34 
  35 import javax.crypto.Cipher;
  36 import javax.crypto.KeyGenerator;
  37 import javax.crypto.SecretKey;
  38 import javax.crypto.spec.GCMParameterSpec;
  39 import javax.net.ssl.SSLProtocolException;
  40 
  41 import static sun.security.ssl.SSLExtension.CH_SESSION_TICKET;
  42 import static sun.security.ssl.SSLExtension.SH_SESSION_TICKET;
  43 
  44 import java.io.IOException;
  45 import java.nio.ByteBuffer;
  46 import java.security.NoSuchAlgorithmException;
  47 import java.security.SecureRandom;
  48 import java.text.MessageFormat;
  49 import java.util.Collection;
  50 import java.util.Locale;
  51 
  52 /**
  53  * SessionTicketExtension is an implementation of RFC 5077 with some internals
  54  * that are used for stateless operation in TLS 1.3.
  55  *
  56  * {@systemProperty jdk.tls.server.statelessKeyTimeout} can override the default
  57  * amount of time, in seconds, for how long a randomly-generated key and
  58  * parameters can be used before being regenerated.  The key material is used
  59  * to encrypt the stateless session ticket that is sent to the client that will
  60  * be used during resumption.  Default is 3600 seconds (1 hour)
  61  *
  62  */
  63 
  64 final class SessionTicketExtension {
  65 
  66     static final HandshakeProducer chNetworkProducer =
  67             new T12CHSessionTicketProducer();
  68     static final ExtensionConsumer chOnLoadConsumer =
  69             new T12CHSessionTicketConsumer();
  70     static final HandshakeProducer shNetworkProducer =
  71             new T12SHSessionTicketProducer();
  72     static final ExtensionConsumer shOnLoadConsumer =
  73             new T12SHSessionTicketConsumer();
  74 
  75     static final SSLStringizer steStringizer = new SessionTicketStringizer();
  76 
  77     // Time in milliseconds until key is changed for encrypting session state
  78     private static final int TIMEOUT_DEFAULT = 3600 * 1000;
  79     private static final int keyTimeout;
  80     private static int currentKeyID = new SecureRandom().nextInt();
  81     private static final int KEYLEN = 256;
  82 
  83     static {
  84         String s = GetPropertyAction.privilegedGetProperty(
  85                 "jdk.tls.server.statelessKeyTimeout");
  86         if (s != null) {
  87             int kt;
  88             try {
  89                 kt = Integer.parseInt(s) * 1000;  // change to ms
  90                 if (kt < 0 ||
  91                         kt > NewSessionTicket.MAX_TICKET_LIFETIME) {
  92                     if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
  93                         SSLLogger.warning("Invalid timeout for " +
  94                                 "jdk.tls.server.statelessKeyTimeout: " +
  95                                 kt + ".  Set to default value " +
  96                                 TIMEOUT_DEFAULT + "sec");
  97                     }
  98                     kt = TIMEOUT_DEFAULT;
  99                 }
 100             } catch (NumberFormatException e) {
 101                 kt = TIMEOUT_DEFAULT;
 102                 if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 103                     SSLLogger.warning("Invalid timeout for " +
 104                             "jdk.tls.server.statelessKeyTimeout: " + s +
 105                             ".  Set to default value " + TIMEOUT_DEFAULT +
 106                             "sec");
 107                 }
 108             }
 109             keyTimeout = kt;
 110         } else {
 111             keyTimeout = TIMEOUT_DEFAULT;
 112         }
 113     }
 114 
 115     // Crypto key context for session state.  Used with stateless operation.
 116     final static class StatelessKey {
 117         final long timeout;
 118         final SecretKey key;
 119         final int num;
 120 
 121         StatelessKey(HandshakeContext hc, int newNum) {
 122             SecretKey k = null;
 123             try {
 124                 KeyGenerator kg = KeyGenerator.getInstance("AES");
 125                 kg.init(KEYLEN, hc.sslContext.getSecureRandom());
 126                 k = kg.generateKey();
 127             } catch (NoSuchAlgorithmException e) {
 128                 // should not happen;
 129             }
 130             key = k;
 131             timeout = System.currentTimeMillis() + keyTimeout;
 132             num = newNum;
 133             hc.sslContext.keyHashMap.put(Integer.valueOf(num), this);
 134         }
 135 
 136         // Check if key needs to be changed
 137         boolean isExpired() {
 138             return ((System.currentTimeMillis()) > timeout);
 139         }
 140 
 141         // Check if this key is ready for deletion.
 142         boolean isInvalid(long sessionTimeout) {
 143             return ((System.currentTimeMillis()) > (timeout + sessionTimeout));
 144         }
 145     }
 146 
 147     private static final class KeyState {
 148 
 149         // Get a key with a specific key number
 150         static StatelessKey getKey(HandshakeContext hc, int num)  {
 151             StatelessKey ssk = hc.sslContext.keyHashMap.get(num);
 152 
 153             if (ssk == null || ssk.isInvalid(getSessionTimeout(hc))) {
 154                 return null;
 155             }
 156             return ssk;
 157         }
 158 
 159         // Get the current valid key, this will generate a new key if needed
 160         static StatelessKey getCurrentKey(HandshakeContext hc) {
 161             StatelessKey ssk = hc.sslContext.keyHashMap.get(currentKeyID);
 162 
 163             if (ssk != null && !ssk.isExpired()) {
 164                 return ssk;
 165             }
 166             return nextKey(hc);
 167         }
 168 
 169         // This method locks when the first getCurrentKey() finds it to be too
 170         // old and create a new key to replace the current key.  After the new
 171         // key established, the lock can be released so following
 172         // operations will start using the new key.
 173         // The first operation will take a longer code path by generating the
 174         // next key and cleaning up old keys.
 175         private static StatelessKey nextKey(HandshakeContext hc) {
 176             StatelessKey ssk;
 177 
 178             synchronized (hc.sslContext.keyHashMap) {
 179                 // If the current key is no longer expired, it was already
 180                 // updated by a previous operation and we can return.
 181                 ssk = hc.sslContext.keyHashMap.get(currentKeyID);
 182                 if (ssk != null && !ssk.isExpired()) {
 183                     return ssk;
 184                 }
 185                 int newNum;
 186                 if (currentKeyID == Integer.MAX_VALUE) {
 187                     newNum = 0;
 188                 } else {
 189                     newNum = currentKeyID + 1;
 190                 }
 191                 // Get new key
 192                 ssk = new StatelessKey(hc, newNum);
 193                 currentKeyID = newNum;
 194                 // Release lock since the new key is ready to be used.
 195             }
 196 
 197             // Clean up any old keys, then return the current key
 198             cleanup(hc);
 199             return ssk;
 200         }
 201 
 202         // Deletes any invalid SessionStateKeys.
 203         static void cleanup(HandshakeContext hc) {
 204             int sessionTimeout = getSessionTimeout(hc);
 205 
 206             StatelessKey ks;
 207             for (Object o : hc.sslContext.keyHashMap.keySet().toArray()) {
 208                 Integer i = (Integer)o;
 209                 ks = hc.sslContext.keyHashMap.get(i);
 210                 if (ks.isInvalid(sessionTimeout)) {
 211                     try {
 212                         ks.key.destroy();
 213                     } catch (Exception e) {
 214                         // Suppress
 215                     }
 216                     hc.sslContext.keyHashMap.remove(i);
 217                 }
 218             }
 219         }
 220 
 221         static int getSessionTimeout(HandshakeContext hc) {
 222             return hc.sslContext.engineGetServerSessionContext().
 223                     getSessionTimeout() * 1000;
 224         }
 225     }
 226 
 227     /**
 228      * This class contains the session state that is in the session ticket.
 229      * Using the key associated with the ticket, the class encrypts and
 230      * decrypts the data, but does not interpret the data.
 231      */
 232     static final class SessionTicketSpec implements SSLExtensionSpec {
 233         private static final int GCM_TAG_LEN = 128;
 234         ByteBuffer data;
 235         static final ByteBuffer zero = ByteBuffer.wrap(new byte[0]);
 236 
 237         SessionTicketSpec() {
 238             data = zero;
 239         }
 240 
 241         SessionTicketSpec(byte[] b) throws IOException {
 242             this(ByteBuffer.wrap(b));
 243         }
 244 
 245         SessionTicketSpec(ByteBuffer buf) throws IOException {
 246             if (buf == null) {
 247                 throw new SSLProtocolException(
 248                         "SessionTicket buffer too small");
 249             }
 250             if (buf.remaining() > 65536) {
 251                 throw new SSLProtocolException(
 252                         "SessionTicket buffer too large. " + buf.remaining());
 253             }
 254 
 255             data = buf;
 256         }
 257 
 258         public byte[] encrypt(HandshakeContext hc, SSLSessionImpl session)
 259                 throws IOException {
 260             byte[] encrypted;
 261             StatelessKey key = KeyState.getCurrentKey(hc);
 262             byte[] iv = new byte[16];
 263 
 264             try {
 265                 SecureRandom random = hc.sslContext.getSecureRandom();
 266                 random.nextBytes(iv);
 267                 Cipher c = Cipher.getInstance("AES/GCM/NoPadding");
 268                 c.init(Cipher.ENCRYPT_MODE, key.key,
 269                         new GCMParameterSpec(GCM_TAG_LEN, iv));
 270                 c.updateAAD(new byte[] {
 271                         (byte)(key.num >>> 24),
 272                         (byte)(key.num >>> 16),
 273                         (byte)(key.num >>> 8),
 274                         (byte)(key.num)}
 275                 );
 276                 encrypted = c.doFinal(session.write());
 277 
 278                 byte[] result = new byte[encrypted.length + Integer.BYTES +
 279                         iv.length];
 280                 result[0] = (byte)(key.num >>> 24);
 281                 result[1] = (byte)(key.num >>> 16);
 282                 result[2] = (byte)(key.num >>> 8);
 283                 result[3] = (byte)(key.num);
 284                 System.arraycopy(iv, 0, result, Integer.BYTES, iv.length);
 285                 System.arraycopy(encrypted, 0, result,
 286                         Integer.BYTES + iv.length, encrypted.length);
 287                 return result;
 288             } catch (Exception e) {
 289                 throw hc.conContext.fatal(Alert.UNEXPECTED_MESSAGE, e);
 290             }
 291         }
 292 
 293         ByteBuffer decrypt(HandshakeContext hc) {
 294             int keyID;
 295             byte[] iv;
 296             try {
 297                 keyID = data.getInt();
 298                 StatelessKey key = KeyState.getKey(hc, keyID);
 299                 if (key == null) {
 300                     return null;
 301                 }
 302 
 303                 iv = new byte[16];
 304                 data.get(iv);
 305                 Cipher c = Cipher.getInstance("AES/GCM/NoPadding");
 306                 c.init(Cipher.DECRYPT_MODE, key.key,
 307                         new GCMParameterSpec(GCM_TAG_LEN, iv));
 308                 c.updateAAD(new byte[] {
 309                         (byte)(keyID >>> 24),
 310                         (byte)(keyID >>> 16),
 311                         (byte)(keyID >>> 8),
 312                         (byte)(keyID)}
 313                 );
 314                 /*
 315                 return ByteBuffer.wrap(c.doFinal(data,
 316                         Integer.BYTES + iv.length,
 317                         data.length - (Integer.BYTES + iv.length)));
 318                  */
 319                 ByteBuffer out;
 320                 out = ByteBuffer.allocate(data.remaining() - GCM_TAG_LEN / 8);
 321                 c.doFinal(data, out);
 322                 out.flip();
 323                 return out;
 324             } catch (Exception e) {
 325                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 326                     SSLLogger.fine("Decryption failed." + e.getMessage());
 327                 }
 328             }
 329             return null;
 330         }
 331 
 332         byte[] getEncoded() {
 333             byte[] out = new byte[data.capacity()];
 334             data.duplicate().get(out);
 335             return out;
 336         }
 337 
 338         @Override
 339         public String toString() {
 340             if (data == null) {
 341                 return "<null>";
 342             }
 343             if (data.capacity() == 0) {
 344                 return "<empty>";
 345             }
 346 
 347             MessageFormat messageFormat = new MessageFormat(
 348                     "  \"ticket\" : '{'\n" +
 349                             "{0}\n" +
 350                             "  '}'",
 351                     Locale.ENGLISH);
 352             HexDumpEncoder hexEncoder = new HexDumpEncoder();
 353 
 354             Object[] messageFields = {
 355                     Utilities.indent(hexEncoder.encode(data.duplicate()),
 356                             "    "),
 357             };
 358 
 359             return messageFormat.format(messageFields);
 360         }
 361     }
 362 
 363     static final class SessionTicketStringizer implements SSLStringizer {
 364         SessionTicketStringizer() {}
 365 
 366         @Override
 367         public String toString(ByteBuffer buffer) {
 368             try {
 369                 return new SessionTicketSpec(buffer).toString();
 370             } catch (IOException e) {
 371                 return e.getMessage();
 372             }
 373         }
 374     }
 375 
 376     private static final class T12CHSessionTicketProducer
 377             extends SupportedGroups implements HandshakeProducer {
 378         T12CHSessionTicketProducer() {
 379         }
 380 
 381         @Override
 382         public byte[] produce(ConnectionContext context,
 383                 HandshakeMessage message) throws IOException {
 384 
 385             ClientHandshakeContext chc = (ClientHandshakeContext)context;
 386 
 387             // If the context does not allow stateless tickets, exit
 388             if (!((SSLSessionContextImpl)chc.sslContext.
 389                     engineGetClientSessionContext()).statelessEnabled()) {
 390                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 391                     SSLLogger.fine("Stateless resumption not supported");
 392                 }
 393                 return null;
 394             }
 395 
 396             chc.statelessResumption = true;
 397 
 398             // If resumption is not in progress, return an empty value
 399             if (!chc.isResumption || chc.resumingSession == null) {
 400                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 401                     SSLLogger.fine("Stateless resumption supported");
 402                 }
 403                 return new SessionTicketSpec().getEncoded();
 404             }
 405 
 406             if (chc.localSupportedSignAlgs == null) {
 407                 chc.localSupportedSignAlgs =
 408                         SignatureScheme.getSupportedAlgorithms(
 409                                 chc.algorithmConstraints, chc.activeProtocols);
 410             }
 411 
 412             return chc.resumingSession.getPskIdentity();
 413         }
 414 
 415     }
 416 
 417     private static final class T12CHSessionTicketConsumer
 418             implements ExtensionConsumer {
 419         T12CHSessionTicketConsumer() {
 420         }
 421 
 422         @Override
 423         public void consume(ConnectionContext context,
 424                 HandshakeMessage message, ByteBuffer buffer)
 425                 throws IOException {
 426             ServerHandshakeContext shc = (ServerHandshakeContext) context;
 427 
 428             // Skip if extension is not provided
 429             if (!shc.sslConfig.isAvailable(CH_SESSION_TICKET)) {
 430                 return;
 431             }
 432 
 433             // Skip consumption if we are already in stateless resumption
 434             if (shc.statelessResumption) {
 435                 return;
 436             }
 437             // If the context does not allow stateless tickets, exit
 438             SSLSessionContextImpl cache = (SSLSessionContextImpl)shc.sslContext
 439                     .engineGetServerSessionContext();
 440             if (!cache.statelessEnabled()) {
 441                 return;
 442             }
 443 
 444             if (buffer.remaining() == 0) {
 445                 shc.statelessResumption = true;
 446                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 447                     SSLLogger.fine("Client accepts session tickets.");
 448                 }
 449                 return;
 450             }
 451 
 452             // Parse the extension.
 453             SessionTicketSpec spec;
 454             try {
 455                  spec = new SessionTicketSpec(buffer);
 456             } catch (IOException | RuntimeException e) {
 457                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 458                     SSLLogger.fine("SessionTicket data invalid. Doing full " +
 459                             "handshake.");
 460                 }
 461                 return;
 462             }
 463             ByteBuffer b = spec.decrypt(shc);
 464             if (b != null) {
 465                 shc.resumingSession = new SSLSessionImpl(shc, b);
 466                 shc.isResumption = true;
 467                 shc.statelessResumption = true;
 468                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 469                     SSLLogger.fine("Valid stateless session ticket found");
 470                 }
 471             }
 472         }
 473     }
 474 
 475 
 476     private static final class T12SHSessionTicketProducer
 477             extends SupportedGroups implements HandshakeProducer {
 478         T12SHSessionTicketProducer() {
 479         }
 480 
 481         @Override
 482         public byte[] produce(ConnectionContext context,
 483                 HandshakeMessage message) {
 484 
 485             ServerHandshakeContext shc = (ServerHandshakeContext)context;
 486 
 487             // If boolean is false, the CH did not have this extension
 488             if (!shc.statelessResumption) {
 489                 return null;
 490             }
 491             // If the client has sent a SessionTicketExtension and stateless
 492             // is enabled on the server, return an empty message.
 493             // If the context does not allow stateless tickets, exit
 494             SSLSessionContextImpl cache = (SSLSessionContextImpl)shc.sslContext
 495                     .engineGetServerSessionContext();
 496             if (cache.statelessEnabled()) {
 497                 return new byte[0];
 498             }
 499 
 500             shc.statelessResumption = false;
 501             return null;
 502         }
 503     }
 504 
 505     private static final class T12SHSessionTicketConsumer
 506             implements ExtensionConsumer {
 507         T12SHSessionTicketConsumer() {
 508         }
 509 
 510         @Override
 511         public void consume(ConnectionContext context,
 512                 HandshakeMessage message, ByteBuffer buffer)
 513                 throws IOException {
 514             ClientHandshakeContext chc = (ClientHandshakeContext) context;
 515 
 516             // Skip if extension is not provided
 517             if (!chc.sslConfig.isAvailable(SH_SESSION_TICKET)) {
 518                 chc.statelessResumption = false;
 519                 return;
 520             }
 521 
 522             // If the context does not allow stateless tickets, exit
 523             if (!((SSLSessionContextImpl)chc.sslContext.
 524                     engineGetClientSessionContext()).statelessEnabled()) {
 525                 chc.statelessResumption = false;
 526                 return;
 527             }
 528 
 529             try {
 530                 if (new SessionTicketSpec(buffer) == null) {
 531                     return;
 532                 }
 533                 chc.statelessResumption = true;
 534             } catch (IOException e) {
 535                 throw chc.conContext.fatal(Alert.UNEXPECTED_MESSAGE, e);
 536             }
 537         }
 538     }
 539 }