1 /*
   2  * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ssl;
  27 
  28 import java.io.*;
  29 import java.nio.*;
  30 import java.util.*;
  31 import javax.net.ssl.*;
  32 import sun.security.ssl.SSLCipher.SSLWriteCipher;
  33 
  34 /**
  35  * DTLS {@code OutputRecord} implementation for {@code SSLEngine}.
  36  */
  37 final class DTLSOutputRecord extends OutputRecord implements DTLSRecord {
  38 
  39     private DTLSFragmenter fragmenter = null;
  40 
  41     int                 writeEpoch;
  42 
  43     int                 prevWriteEpoch;
  44     Authenticator       prevWriteAuthenticator;
  45     SSLWriteCipher      prevWriteCipher;
  46 
  47     private volatile boolean isCloseWaiting = false;
  48 
  49     DTLSOutputRecord(HandshakeHash handshakeHash) {
  50         super(handshakeHash, SSLWriteCipher.nullDTlsWriteCipher());
  51 
  52         this.writeEpoch = 0;
  53         this.prevWriteEpoch = 0;
  54         this.prevWriteCipher = SSLWriteCipher.nullDTlsWriteCipher();
  55 
  56         this.packetSize = DTLSRecord.maxRecordSize;
  57         this.protocolVersion = ProtocolVersion.NONE;
  58     }
  59 
  60     @Override
  61     public synchronized void close() throws IOException {
  62         if (!isClosed) {
  63             if (fragmenter != null && fragmenter.hasAlert()) {
  64                 isCloseWaiting = true;
  65             } else {
  66                 super.close();
  67             }
  68         }
  69     }
  70 
  71     boolean isClosed() {
  72         return isClosed || isCloseWaiting;
  73     }
  74 
  75     @Override
  76     void initHandshaker() {
  77         // clean up
  78         fragmenter = null;
  79     }
  80 
  81     @Override
  82     void finishHandshake() {
  83         // Nothing to do here currently.
  84     }
  85 
  86     @Override
  87     void changeWriteCiphers(SSLWriteCipher writeCipher,
  88             boolean useChangeCipherSpec) throws IOException {
  89         if (isClosed()) {
  90             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
  91                 SSLLogger.warning("outbound has closed, ignore outbound " +
  92                     "change_cipher_spec message");
  93             }
  94             return;
  95         }
  96 
  97         if (useChangeCipherSpec) {
  98             encodeChangeCipherSpec();
  99         }
 100 
 101         prevWriteCipher.dispose();
 102 
 103         this.prevWriteCipher = this.writeCipher;
 104         this.prevWriteEpoch = this.writeEpoch;
 105 
 106         this.writeCipher = writeCipher;
 107         this.writeEpoch++;
 108 
 109         this.isFirstAppOutputRecord = true;
 110 
 111         // set the epoch number
 112         this.writeCipher.authenticator.setEpochNumber(this.writeEpoch);
 113     }
 114 
 115     @Override
 116     void encodeAlert(byte level, byte description) throws IOException {
 117         if (isClosed()) {
 118             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 119                 SSLLogger.warning("outbound has closed, ignore outbound " +
 120                     "alert message: " + Alert.nameOf(description));
 121             }
 122             return;
 123         }
 124 
 125         if (fragmenter == null) {
 126            fragmenter = new DTLSFragmenter();
 127         }
 128 
 129         fragmenter.queueUpAlert(level, description);
 130     }
 131 
 132     @Override
 133     void encodeChangeCipherSpec() throws IOException {
 134         if (isClosed()) {
 135             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 136                 SSLLogger.warning("outbound has closed, ignore outbound " +
 137                     "change_cipher_spec message");
 138             }
 139             return;
 140         }
 141 
 142         if (fragmenter == null) {
 143            fragmenter = new DTLSFragmenter();
 144         }
 145         fragmenter.queueUpChangeCipherSpec();
 146     }
 147 
 148     @Override
 149     void encodeHandshake(byte[] source,
 150             int offset, int length) throws IOException {
 151         if (isClosed()) {
 152             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 153                 SSLLogger.warning("outbound has closed, ignore outbound " +
 154                         "handshake message",
 155                         ByteBuffer.wrap(source, offset, length));
 156             }
 157             return;
 158         }
 159 
 160         if (firstMessage) {
 161             firstMessage = false;
 162         }
 163 
 164         if (fragmenter == null) {
 165            fragmenter = new DTLSFragmenter();
 166         }
 167 
 168         fragmenter.queueUpHandshake(source, offset, length);
 169     }
 170 
 171     @Override
 172     Ciphertext encode(
 173         ByteBuffer[] srcs, int srcsOffset, int srcsLength,
 174         ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
 175 
 176         if (isClosed) {
 177             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 178                 SSLLogger.warning("outbound has closed, ignore outbound " +
 179                     "application data or cached messages");
 180             }
 181 
 182             return null;
 183         } else if (isCloseWaiting) {
 184             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 185                 SSLLogger.warning("outbound has closed, ignore outbound " +
 186                     "application data");
 187             }
 188 
 189             srcs = null;    // use no application data.
 190         }
 191 
 192         return encode(srcs, srcsOffset, srcsLength, dsts[0]);
 193     }
 194 
 195     private Ciphertext encode(ByteBuffer[] sources, int offset, int length,
 196             ByteBuffer destination) throws IOException {
 197 
 198         if (writeCipher.authenticator.seqNumOverflow()) {
 199             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 200                 SSLLogger.fine(
 201                     "sequence number extremely close to overflow " +
 202                     "(2^64-1 packets). Closing connection.");
 203             }
 204 
 205             throw new SSLHandshakeException("sequence number overflow");
 206         }
 207 
 208         // Don't process the incoming record until all of the buffered records
 209         // get handled.  May need retransmission if no sources specified.
 210         if (!isEmpty() || sources == null || sources.length == 0) {
 211             Ciphertext ct = acquireCiphertext(destination);
 212             if (ct != null) {
 213                 return ct;
 214             }
 215         }
 216 
 217         if (sources == null || sources.length == 0) {
 218             return null;
 219         }
 220 
 221         int srcsRemains = 0;
 222         for (int i = offset; i < offset + length; i++) {
 223             srcsRemains += sources[i].remaining();
 224         }
 225 
 226         if (srcsRemains == 0) {
 227             return null;
 228         }
 229 
 230         // not apply to handshake message
 231         int fragLen;
 232         if (packetSize > 0) {
 233             fragLen = Math.min(maxRecordSize, packetSize);
 234             fragLen = writeCipher.calculateFragmentSize(
 235                     fragLen, headerSize);
 236 
 237             fragLen = Math.min(fragLen, Record.maxDataSize);
 238         } else {
 239             fragLen = Record.maxDataSize;
 240         }
 241 
 242         if (fragmentSize > 0) {
 243             fragLen = Math.min(fragLen, fragmentSize);
 244         }
 245 
 246         int dstPos = destination.position();
 247         int dstLim = destination.limit();
 248         int dstContent = dstPos + headerSize +
 249                                 writeCipher.getExplicitNonceSize();
 250         destination.position(dstContent);
 251 
 252         int remains = Math.min(fragLen, destination.remaining());
 253         fragLen = 0;
 254         int srcsLen = offset + length;
 255         for (int i = offset; (i < srcsLen) && (remains > 0); i++) {
 256             int amount = Math.min(sources[i].remaining(), remains);
 257             int srcLimit = sources[i].limit();
 258             sources[i].limit(sources[i].position() + amount);
 259             destination.put(sources[i]);
 260             sources[i].limit(srcLimit);         // restore the limit
 261             remains -= amount;
 262             fragLen += amount;
 263         }
 264 
 265         destination.limit(destination.position());
 266         destination.position(dstContent);
 267 
 268         if (SSLLogger.isOn && SSLLogger.isOn("record")) {
 269             SSLLogger.fine(
 270                     "WRITE: " + protocolVersion + " " +
 271                     ContentType.APPLICATION_DATA.name +
 272                     ", length = " + destination.remaining());
 273         }
 274 
 275         // Encrypt the fragment and wrap up a record.
 276         long recordSN = encrypt(writeCipher,
 277                 ContentType.APPLICATION_DATA.id, destination,
 278                 dstPos, dstLim, headerSize,
 279                 protocolVersion);
 280 
 281         if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
 282             ByteBuffer temporary = destination.duplicate();
 283             temporary.limit(temporary.position());
 284             temporary.position(dstPos);
 285             SSLLogger.fine("Raw write", temporary);
 286         }
 287 
 288         // remain the limit unchanged
 289         destination.limit(dstLim);
 290 
 291         return new Ciphertext(ContentType.APPLICATION_DATA.id,
 292                 SSLHandshake.NOT_APPLICABLE.id, recordSN);
 293     }
 294 
 295     private Ciphertext acquireCiphertext(
 296             ByteBuffer destination) throws IOException {
 297         if (fragmenter != null) {
 298             return fragmenter.acquireCiphertext(destination);
 299         }
 300 
 301         return null;
 302     }
 303 
 304     @Override
 305     boolean isEmpty() {
 306         return (fragmenter == null) || fragmenter.isEmpty();
 307     }
 308 
 309     @Override
 310     void launchRetransmission() {
 311         // Note: Please don't retransmit if there are handshake messages
 312         // or alerts waiting in the queue.
 313         if ((fragmenter != null) && fragmenter.isRetransmittable()) {
 314             fragmenter.setRetransmission();
 315         }
 316     }
 317 
 318     // buffered record fragment
 319     private static class RecordMemo {
 320         byte            contentType;
 321         byte            majorVersion;
 322         byte            minorVersion;
 323         int             encodeEpoch;
 324         SSLWriteCipher  encodeCipher;
 325 
 326         byte[]          fragment;
 327     }
 328 
 329     private static class HandshakeMemo extends RecordMemo {
 330         byte            handshakeType;
 331         int             messageSequence;
 332         int             acquireOffset;
 333     }
 334 
 335     private final class DTLSFragmenter {
 336         private final LinkedList<RecordMemo> handshakeMemos =
 337                 new LinkedList<>();
 338         private int acquireIndex = 0;
 339         private int messageSequence = 0;
 340         private boolean flightIsReady = false;
 341 
 342         // Per section 4.1.1, RFC 6347:
 343         //
 344         // If repeated retransmissions do not result in a response, and the
 345         // PMTU is unknown, subsequent retransmissions SHOULD back off to a
 346         // smaller record size, fragmenting the handshake message as
 347         // appropriate.
 348         //
 349         // In this implementation, two times of retransmits would be attempted
 350         // before backing off.  The back off is supported only if the packet
 351         // size is bigger than 256 bytes.
 352         private int retransmits = 2;            // attemps of retransmits
 353 
 354         void queueUpHandshake(byte[] buf,
 355                 int offset, int length) throws IOException {
 356 
 357             // Cleanup if a new flight starts.
 358             if (flightIsReady) {
 359                 handshakeMemos.clear();
 360                 acquireIndex = 0;
 361                 flightIsReady = false;
 362             }
 363 
 364             HandshakeMemo memo = new HandshakeMemo();
 365 
 366             memo.contentType = ContentType.HANDSHAKE.id;
 367             memo.majorVersion = protocolVersion.major;
 368             memo.minorVersion = protocolVersion.minor;
 369             memo.encodeEpoch = writeEpoch;
 370             memo.encodeCipher = writeCipher;
 371 
 372             memo.handshakeType = buf[offset];
 373             memo.messageSequence = messageSequence++;
 374             memo.acquireOffset = 0;
 375             memo.fragment = new byte[length - 4];       // 4: header size
 376                                                         //    1: HandshakeType
 377                                                         //    3: message length
 378             System.arraycopy(buf, offset + 4, memo.fragment, 0, length - 4);
 379 
 380             handshakeHashing(memo, memo.fragment);
 381             handshakeMemos.add(memo);
 382 
 383             if ((memo.handshakeType == SSLHandshake.CLIENT_HELLO.id) ||
 384                 (memo.handshakeType == SSLHandshake.HELLO_REQUEST.id) ||
 385                 (memo.handshakeType ==
 386                         SSLHandshake.HELLO_VERIFY_REQUEST.id) ||
 387                 (memo.handshakeType == SSLHandshake.SERVER_HELLO_DONE.id) ||
 388                 (memo.handshakeType == SSLHandshake.FINISHED.id)) {
 389 
 390                 flightIsReady = true;
 391             }
 392         }
 393 
 394         void queueUpChangeCipherSpec() {
 395 
 396             // Cleanup if a new flight starts.
 397             if (flightIsReady) {
 398                 handshakeMemos.clear();
 399                 acquireIndex = 0;
 400                 flightIsReady = false;
 401             }
 402 
 403             RecordMemo memo = new RecordMemo();
 404 
 405             memo.contentType = ContentType.CHANGE_CIPHER_SPEC.id;
 406             memo.majorVersion = protocolVersion.major;
 407             memo.minorVersion = protocolVersion.minor;
 408             memo.encodeEpoch = writeEpoch;
 409             memo.encodeCipher = writeCipher;
 410 
 411             memo.fragment = new byte[1];
 412             memo.fragment[0] = 1;
 413 
 414             handshakeMemos.add(memo);
 415         }
 416 
 417         void queueUpAlert(byte level, byte description) throws IOException {
 418             RecordMemo memo = new RecordMemo();
 419 
 420             memo.contentType = ContentType.ALERT.id;
 421             memo.majorVersion = protocolVersion.major;
 422             memo.minorVersion = protocolVersion.minor;
 423             memo.encodeEpoch = writeEpoch;
 424             memo.encodeCipher = writeCipher;
 425 
 426             memo.fragment = new byte[2];
 427             memo.fragment[0] = level;
 428             memo.fragment[1] = description;
 429 
 430             handshakeMemos.add(memo);
 431         }
 432 
 433         Ciphertext acquireCiphertext(ByteBuffer dstBuf) throws IOException {
 434             if (isEmpty()) {
 435                 if (isRetransmittable()) {
 436                     setRetransmission();    // configure for retransmission
 437                 } else {
 438                     return null;
 439                 }
 440             }
 441 
 442             RecordMemo memo = handshakeMemos.get(acquireIndex);
 443             HandshakeMemo hsMemo = null;
 444             if (memo.contentType == ContentType.HANDSHAKE.id) {
 445                 hsMemo = (HandshakeMemo)memo;
 446             }
 447 
 448             // ChangeCipherSpec message is pretty small.  Don't worry about
 449             // the fragmentation of ChangeCipherSpec record.
 450             int fragLen;
 451             if (packetSize > 0) {
 452                 fragLen = Math.min(maxRecordSize, packetSize);
 453                 fragLen = memo.encodeCipher.calculateFragmentSize(
 454                         fragLen, 25);   // 25: header size
 455                                                 //   13: DTLS record
 456                                                 //   12: DTLS handshake message
 457                 fragLen = Math.min(fragLen, Record.maxDataSize);
 458             } else {
 459                 fragLen = Record.maxDataSize;
 460             }
 461 
 462             if (fragmentSize > 0) {
 463                 fragLen = Math.min(fragLen, fragmentSize);
 464             }
 465 
 466             int dstPos = dstBuf.position();
 467             int dstLim = dstBuf.limit();
 468             int dstContent = dstPos + headerSize +
 469                                     memo.encodeCipher.getExplicitNonceSize();
 470             dstBuf.position(dstContent);
 471 
 472             if (hsMemo != null) {
 473                 fragLen = Math.min(fragLen,
 474                         (hsMemo.fragment.length - hsMemo.acquireOffset));
 475 
 476                 dstBuf.put(hsMemo.handshakeType);
 477                 dstBuf.put((byte)((hsMemo.fragment.length >> 16) & 0xFF));
 478                 dstBuf.put((byte)((hsMemo.fragment.length >> 8) & 0xFF));
 479                 dstBuf.put((byte)(hsMemo.fragment.length & 0xFF));
 480                 dstBuf.put((byte)((hsMemo.messageSequence >> 8) & 0xFF));
 481                 dstBuf.put((byte)(hsMemo.messageSequence & 0xFF));
 482                 dstBuf.put((byte)((hsMemo.acquireOffset >> 16) & 0xFF));
 483                 dstBuf.put((byte)((hsMemo.acquireOffset >> 8) & 0xFF));
 484                 dstBuf.put((byte)(hsMemo.acquireOffset & 0xFF));
 485                 dstBuf.put((byte)((fragLen >> 16) & 0xFF));
 486                 dstBuf.put((byte)((fragLen >> 8) & 0xFF));
 487                 dstBuf.put((byte)(fragLen & 0xFF));
 488                 dstBuf.put(hsMemo.fragment, hsMemo.acquireOffset, fragLen);
 489             } else {
 490                 fragLen = Math.min(fragLen, memo.fragment.length);
 491                 dstBuf.put(memo.fragment, 0, fragLen);
 492             }
 493 
 494             dstBuf.limit(dstBuf.position());
 495             dstBuf.position(dstContent);
 496 
 497             if (SSLLogger.isOn && SSLLogger.isOn("record")) {
 498                 SSLLogger.fine(
 499                         "WRITE: " + protocolVersion + " " +
 500                         ContentType.nameOf(memo.contentType) +
 501                         ", length = " + dstBuf.remaining());
 502             }
 503 
 504             // Encrypt the fragment and wrap up a record.
 505             long recordSN = encrypt(memo.encodeCipher,
 506                     memo.contentType, dstBuf,
 507                     dstPos, dstLim, headerSize,
 508                     ProtocolVersion.valueOf(memo.majorVersion,
 509                             memo.minorVersion));
 510 
 511             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
 512                 ByteBuffer temporary = dstBuf.duplicate();
 513                 temporary.limit(temporary.position());
 514                 temporary.position(dstPos);
 515                 SSLLogger.fine(
 516                         "Raw write (" + temporary.remaining() + ")", temporary);
 517             }
 518 
 519             // remain the limit unchanged
 520             dstBuf.limit(dstLim);
 521 
 522             // Reset the fragmentation offset.
 523             if (hsMemo != null) {
 524                 hsMemo.acquireOffset += fragLen;
 525                 if (hsMemo.acquireOffset == hsMemo.fragment.length) {
 526                     acquireIndex++;
 527                 }
 528 
 529                 return new Ciphertext(hsMemo.contentType,
 530                         hsMemo.handshakeType, recordSN);
 531             } else {
 532                 if (isCloseWaiting &&
 533                         memo.contentType == ContentType.ALERT.id) {
 534                     close();
 535                 }
 536 
 537                 acquireIndex++;
 538                 return new Ciphertext(memo.contentType,
 539                         SSLHandshake.NOT_APPLICABLE.id, recordSN);
 540             }
 541         }
 542 
 543         private void handshakeHashing(HandshakeMemo hsFrag, byte[] hsBody) {
 544 
 545             byte hsType = hsFrag.handshakeType;
 546             if (!handshakeHash.isHashable(hsType)) {
 547                 // omitted from handshake hash computation
 548                 return;
 549             }
 550 
 551             // calculate the DTLS header
 552             byte[] temporary = new byte[12];    // 12: handshake header size
 553 
 554             // Handshake.msg_type
 555             temporary[0] = hsFrag.handshakeType;
 556 
 557             // Handshake.length
 558             temporary[1] = (byte)((hsBody.length >> 16) & 0xFF);
 559             temporary[2] = (byte)((hsBody.length >> 8) & 0xFF);
 560             temporary[3] = (byte)(hsBody.length & 0xFF);
 561 
 562             // Handshake.message_seq
 563             temporary[4] = (byte)((hsFrag.messageSequence >> 8) & 0xFF);
 564             temporary[5] = (byte)(hsFrag.messageSequence & 0xFF);
 565 
 566             // Handshake.fragment_offset
 567             temporary[6] = 0;
 568             temporary[7] = 0;
 569             temporary[8] = 0;
 570 
 571             // Handshake.fragment_length
 572             temporary[9] = temporary[1];
 573             temporary[10] = temporary[2];
 574             temporary[11] = temporary[3];
 575 
 576             handshakeHash.deliver(temporary, 0, 12);
 577             handshakeHash.deliver(hsBody, 0, hsBody.length);
 578         }
 579 
 580         boolean isEmpty() {
 581             if (!flightIsReady || handshakeMemos.isEmpty() ||
 582                     acquireIndex >= handshakeMemos.size()) {
 583                 return true;
 584             }
 585 
 586             return false;
 587         }
 588 
 589         boolean hasAlert() {
 590             for (RecordMemo memo : handshakeMemos) {
 591                 if (memo.contentType == ContentType.ALERT.id) {
 592                     return true;
 593                 }
 594             }
 595 
 596             return false;
 597         }
 598 
 599         boolean isRetransmittable() {
 600             return (flightIsReady && !handshakeMemos.isEmpty() &&
 601                                 (acquireIndex >= handshakeMemos.size()));
 602         }
 603 
 604         private void setRetransmission() {
 605             acquireIndex = 0;
 606             for (RecordMemo memo : handshakeMemos) {
 607                 if (memo instanceof HandshakeMemo) {
 608                     HandshakeMemo hmemo = (HandshakeMemo)memo;
 609                     hmemo.acquireOffset = 0;
 610                 }
 611             }
 612 
 613             // Shrink packet size if:
 614             // 1. maximum fragment size is allowed, in which case the packet
 615             //    size is configured bigger than maxRecordSize;
 616             // 2. maximum packet is bigger than 256 bytes;
 617             // 3. two times of retransmits have been attempted.
 618             if ((packetSize <= maxRecordSize) &&
 619                     (packetSize > 256) && ((retransmits--) <= 0)) {
 620 
 621                 // shrink packet size
 622                 shrinkPacketSize();
 623                 retransmits = 2;        // attemps of retransmits
 624             }
 625         }
 626 
 627         private void shrinkPacketSize() {
 628             packetSize = Math.max(256, packetSize / 2);
 629         }
 630     }
 631 }