1 /*
   2  * Copyright (c) 1996, 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ssl;
  27 
  28 import java.io.*;
  29 import java.nio.*;
  30 import java.util.*;
  31 import javax.net.ssl.*;
  32 import sun.security.ssl.SSLCipher.SSLWriteCipher;
  33 
  34 /**
  35  * DTLS {@code OutputRecord} implementation for {@code SSLEngine}.
  36  */
  37 final class DTLSOutputRecord extends OutputRecord implements DTLSRecord {
  38 
  39     private DTLSFragmenter fragmenter = null;
  40 
  41     int                 writeEpoch;
  42 
  43     int                 prevWriteEpoch;
  44     Authenticator       prevWriteAuthenticator;
  45     SSLWriteCipher      prevWriteCipher;
  46 
  47     private volatile boolean isCloseWaiting = false;
  48 
  49     DTLSOutputRecord(HandshakeHash handshakeHash) {
  50         super(handshakeHash, SSLWriteCipher.nullDTlsWriteCipher());
  51 
  52         this.writeEpoch = 0;
  53         this.prevWriteEpoch = 0;
  54         this.prevWriteCipher = SSLWriteCipher.nullDTlsWriteCipher();
  55 
  56         this.packetSize = DTLSRecord.maxRecordSize;
  57         this.protocolVersion = ProtocolVersion.NONE;
  58     }
  59 
  60     @Override
  61     public void close() throws IOException {
  62         recordLock.lock();
  63         try {
  64             if (!isClosed) {
  65                 if (fragmenter != null && fragmenter.hasAlert()) {
  66                     isCloseWaiting = true;
  67                 } else {
  68                     super.close();
  69                 }
  70             }
  71         } finally {
  72             recordLock.unlock();
  73         }
  74     }
  75 
  76     boolean isClosed() {
  77         return isClosed || isCloseWaiting;
  78     }
  79 
  80     @Override
  81     void initHandshaker() {
  82         // clean up
  83         fragmenter = null;
  84     }
  85 
  86     @Override
  87     void finishHandshake() {
  88         // Nothing to do here currently.
  89     }
  90 
  91     @Override
  92     void changeWriteCiphers(SSLWriteCipher writeCipher,
  93             boolean useChangeCipherSpec) throws IOException {
  94         if (isClosed()) {
  95             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
  96                 SSLLogger.warning("outbound has closed, ignore outbound " +
  97                     "change_cipher_spec message");
  98             }
  99             return;
 100         }
 101 
 102         if (useChangeCipherSpec) {
 103             encodeChangeCipherSpec();
 104         }
 105 
 106         prevWriteCipher.dispose();
 107 
 108         this.prevWriteCipher = this.writeCipher;
 109         this.prevWriteEpoch = this.writeEpoch;
 110 
 111         this.writeCipher = writeCipher;
 112         this.writeEpoch++;
 113 
 114         this.isFirstAppOutputRecord = true;
 115 
 116         // set the epoch number
 117         this.writeCipher.authenticator.setEpochNumber(this.writeEpoch);
 118     }
 119 
 120     @Override
 121     void encodeAlert(byte level, byte description) throws IOException {
 122         if (isClosed()) {
 123             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 124                 SSLLogger.warning("outbound has closed, ignore outbound " +
 125                     "alert message: " + Alert.nameOf(description));
 126             }
 127             return;
 128         }
 129 
 130         if (fragmenter == null) {
 131            fragmenter = new DTLSFragmenter();
 132         }
 133 
 134         fragmenter.queueUpAlert(level, description);
 135     }
 136 
 137     @Override
 138     void encodeChangeCipherSpec() throws IOException {
 139         if (isClosed()) {
 140             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 141                 SSLLogger.warning("outbound has closed, ignore outbound " +
 142                     "change_cipher_spec message");
 143             }
 144             return;
 145         }
 146 
 147         if (fragmenter == null) {
 148            fragmenter = new DTLSFragmenter();
 149         }
 150         fragmenter.queueUpChangeCipherSpec();
 151     }
 152 
 153     @Override
 154     void encodeHandshake(byte[] source,
 155             int offset, int length) throws IOException {
 156         if (isClosed()) {
 157             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 158                 SSLLogger.warning("outbound has closed, ignore outbound " +
 159                         "handshake message",
 160                         ByteBuffer.wrap(source, offset, length));
 161             }
 162             return;
 163         }
 164 
 165         if (firstMessage) {
 166             firstMessage = false;
 167         }
 168 
 169         if (fragmenter == null) {
 170            fragmenter = new DTLSFragmenter();
 171         }
 172 
 173         fragmenter.queueUpHandshake(source, offset, length);
 174     }
 175 
 176     @Override
 177     Ciphertext encode(
 178         ByteBuffer[] srcs, int srcsOffset, int srcsLength,
 179         ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
 180 
 181         if (isClosed) {
 182             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 183                 SSLLogger.warning("outbound has closed, ignore outbound " +
 184                     "application data or cached messages");
 185             }
 186 
 187             return null;
 188         } else if (isCloseWaiting) {
 189             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 190                 SSLLogger.warning("outbound has closed, ignore outbound " +
 191                     "application data");
 192             }
 193 
 194             srcs = null;    // use no application data.
 195         }
 196 
 197         return encode(srcs, srcsOffset, srcsLength, dsts[0]);
 198     }
 199 
 200     private Ciphertext encode(ByteBuffer[] sources, int offset, int length,
 201             ByteBuffer destination) throws IOException {
 202 
 203         if (writeCipher.authenticator.seqNumOverflow()) {
 204             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
 205                 SSLLogger.fine(
 206                     "sequence number extremely close to overflow " +
 207                     "(2^64-1 packets). Closing connection.");
 208             }
 209 
 210             throw new SSLHandshakeException("sequence number overflow");
 211         }
 212 
 213         // Don't process the incoming record until all of the buffered records
 214         // get handled.  May need retransmission if no sources specified.
 215         if (!isEmpty() || sources == null || sources.length == 0) {
 216             Ciphertext ct = acquireCiphertext(destination);
 217             if (ct != null) {
 218                 return ct;
 219             }
 220         }
 221 
 222         if (sources == null || sources.length == 0) {
 223             return null;
 224         }
 225 
 226         int srcsRemains = 0;
 227         for (int i = offset; i < offset + length; i++) {
 228             srcsRemains += sources[i].remaining();
 229         }
 230 
 231         if (srcsRemains == 0) {
 232             return null;
 233         }
 234 
 235         // not apply to handshake message
 236         int fragLen;
 237         if (packetSize > 0) {
 238             fragLen = Math.min(maxRecordSize, packetSize);
 239             fragLen = writeCipher.calculateFragmentSize(
 240                     fragLen, headerSize);
 241 
 242             fragLen = Math.min(fragLen, Record.maxDataSize);
 243         } else {
 244             fragLen = Record.maxDataSize;
 245         }
 246 
 247         // Calculate more impact, for example TLS 1.3 padding.
 248         fragLen = calculateFragmentSize(fragLen);

 249 
 250         int dstPos = destination.position();
 251         int dstLim = destination.limit();
 252         int dstContent = dstPos + headerSize +
 253                                 writeCipher.getExplicitNonceSize();
 254         destination.position(dstContent);
 255 
 256         int remains = Math.min(fragLen, destination.remaining());
 257         fragLen = 0;
 258         int srcsLen = offset + length;
 259         for (int i = offset; (i < srcsLen) && (remains > 0); i++) {
 260             int amount = Math.min(sources[i].remaining(), remains);
 261             int srcLimit = sources[i].limit();
 262             sources[i].limit(sources[i].position() + amount);
 263             destination.put(sources[i]);
 264             sources[i].limit(srcLimit);         // restore the limit
 265             remains -= amount;
 266             fragLen += amount;
 267         }
 268 
 269         destination.limit(destination.position());
 270         destination.position(dstContent);
 271 
 272         if (SSLLogger.isOn && SSLLogger.isOn("record")) {
 273             SSLLogger.fine(
 274                     "WRITE: " + protocolVersion + " " +
 275                     ContentType.APPLICATION_DATA.name +
 276                     ", length = " + destination.remaining());
 277         }
 278 
 279         // Encrypt the fragment and wrap up a record.
 280         long recordSN = encrypt(writeCipher,
 281                 ContentType.APPLICATION_DATA.id, destination,
 282                 dstPos, dstLim, headerSize,
 283                 protocolVersion);
 284 
 285         if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
 286             ByteBuffer temporary = destination.duplicate();
 287             temporary.limit(temporary.position());
 288             temporary.position(dstPos);
 289             SSLLogger.fine("Raw write", temporary);
 290         }
 291 
 292         // remain the limit unchanged
 293         destination.limit(dstLim);
 294 
 295         return new Ciphertext(ContentType.APPLICATION_DATA.id,
 296                 SSLHandshake.NOT_APPLICABLE.id, recordSN);
 297     }
 298 
 299     private Ciphertext acquireCiphertext(
 300             ByteBuffer destination) throws IOException {
 301         if (fragmenter != null) {
 302             return fragmenter.acquireCiphertext(destination);
 303         }
 304 
 305         return null;
 306     }
 307 
 308     @Override
 309     boolean isEmpty() {
 310         return (fragmenter == null) || fragmenter.isEmpty();
 311     }
 312 
 313     @Override
 314     void launchRetransmission() {
 315         // Note: Please don't retransmit if there are handshake messages
 316         // or alerts waiting in the queue.
 317         if ((fragmenter != null) && fragmenter.isRetransmittable()) {
 318             fragmenter.setRetransmission();
 319         }
 320     }
 321 
 322     // buffered record fragment
 323     private static class RecordMemo {
 324         byte            contentType;
 325         byte            majorVersion;
 326         byte            minorVersion;
 327         int             encodeEpoch;
 328         SSLWriteCipher  encodeCipher;
 329 
 330         byte[]          fragment;
 331     }
 332 
 333     private static class HandshakeMemo extends RecordMemo {
 334         byte            handshakeType;
 335         int             messageSequence;
 336         int             acquireOffset;
 337     }
 338 
 339     private final class DTLSFragmenter {
 340         private final LinkedList<RecordMemo> handshakeMemos =
 341                 new LinkedList<>();
 342         private int acquireIndex = 0;
 343         private int messageSequence = 0;
 344         private boolean flightIsReady = false;
 345 
 346         // Per section 4.1.1, RFC 6347:
 347         //
 348         // If repeated retransmissions do not result in a response, and the
 349         // PMTU is unknown, subsequent retransmissions SHOULD back off to a
 350         // smaller record size, fragmenting the handshake message as
 351         // appropriate.
 352         //
 353         // In this implementation, two times of retransmits would be attempted
 354         // before backing off.  The back off is supported only if the packet
 355         // size is bigger than 256 bytes.
 356         private int retransmits = 2;            // attemps of retransmits
 357 
 358         void queueUpHandshake(byte[] buf,
 359                 int offset, int length) throws IOException {
 360 
 361             // Cleanup if a new flight starts.
 362             if (flightIsReady) {
 363                 handshakeMemos.clear();
 364                 acquireIndex = 0;
 365                 flightIsReady = false;
 366             }
 367 
 368             HandshakeMemo memo = new HandshakeMemo();
 369 
 370             memo.contentType = ContentType.HANDSHAKE.id;
 371             memo.majorVersion = protocolVersion.major;
 372             memo.minorVersion = protocolVersion.minor;
 373             memo.encodeEpoch = writeEpoch;
 374             memo.encodeCipher = writeCipher;
 375 
 376             memo.handshakeType = buf[offset];
 377             memo.messageSequence = messageSequence++;
 378             memo.acquireOffset = 0;
 379             memo.fragment = new byte[length - 4];       // 4: header size
 380                                                         //    1: HandshakeType
 381                                                         //    3: message length
 382             System.arraycopy(buf, offset + 4, memo.fragment, 0, length - 4);
 383 
 384             handshakeHashing(memo, memo.fragment);
 385             handshakeMemos.add(memo);
 386 
 387             if ((memo.handshakeType == SSLHandshake.CLIENT_HELLO.id) ||
 388                 (memo.handshakeType == SSLHandshake.HELLO_REQUEST.id) ||
 389                 (memo.handshakeType ==
 390                         SSLHandshake.HELLO_VERIFY_REQUEST.id) ||
 391                 (memo.handshakeType == SSLHandshake.SERVER_HELLO_DONE.id) ||
 392                 (memo.handshakeType == SSLHandshake.FINISHED.id)) {
 393 
 394                 flightIsReady = true;
 395             }
 396         }
 397 
 398         void queueUpChangeCipherSpec() {
 399 
 400             // Cleanup if a new flight starts.
 401             if (flightIsReady) {
 402                 handshakeMemos.clear();
 403                 acquireIndex = 0;
 404                 flightIsReady = false;
 405             }
 406 
 407             RecordMemo memo = new RecordMemo();
 408 
 409             memo.contentType = ContentType.CHANGE_CIPHER_SPEC.id;
 410             memo.majorVersion = protocolVersion.major;
 411             memo.minorVersion = protocolVersion.minor;
 412             memo.encodeEpoch = writeEpoch;
 413             memo.encodeCipher = writeCipher;
 414 
 415             memo.fragment = new byte[1];
 416             memo.fragment[0] = 1;
 417 
 418             handshakeMemos.add(memo);
 419         }
 420 
 421         void queueUpAlert(byte level, byte description) throws IOException {
 422             RecordMemo memo = new RecordMemo();
 423 
 424             memo.contentType = ContentType.ALERT.id;
 425             memo.majorVersion = protocolVersion.major;
 426             memo.minorVersion = protocolVersion.minor;
 427             memo.encodeEpoch = writeEpoch;
 428             memo.encodeCipher = writeCipher;
 429 
 430             memo.fragment = new byte[2];
 431             memo.fragment[0] = level;
 432             memo.fragment[1] = description;
 433 
 434             handshakeMemos.add(memo);
 435         }
 436 
 437         Ciphertext acquireCiphertext(ByteBuffer dstBuf) throws IOException {
 438             if (isEmpty()) {
 439                 if (isRetransmittable()) {
 440                     setRetransmission();    // configure for retransmission
 441                 } else {
 442                     return null;
 443                 }
 444             }
 445 
 446             RecordMemo memo = handshakeMemos.get(acquireIndex);
 447             HandshakeMemo hsMemo = null;
 448             if (memo.contentType == ContentType.HANDSHAKE.id) {
 449                 hsMemo = (HandshakeMemo)memo;
 450             }
 451 
 452             // ChangeCipherSpec message is pretty small.  Don't worry about
 453             // the fragmentation of ChangeCipherSpec record.
 454             int fragLen;
 455             if (packetSize > 0) {
 456                 fragLen = Math.min(maxRecordSize, packetSize);
 457                 fragLen = memo.encodeCipher.calculateFragmentSize(
 458                         fragLen, 25);   // 25: header size
 459                                                 //   13: DTLS record
 460                                                 //   12: DTLS handshake message
 461                 fragLen = Math.min(fragLen, Record.maxDataSize);
 462             } else {
 463                 fragLen = Record.maxDataSize;
 464             }
 465 
 466             // Calculate more impact, for example TLS 1.3 padding.
 467             fragLen = calculateFragmentSize(fragLen);

 468 
 469             int dstPos = dstBuf.position();
 470             int dstLim = dstBuf.limit();
 471             int dstContent = dstPos + headerSize +
 472                                     memo.encodeCipher.getExplicitNonceSize();
 473             dstBuf.position(dstContent);
 474 
 475             if (hsMemo != null) {
 476                 fragLen = Math.min(fragLen,
 477                         (hsMemo.fragment.length - hsMemo.acquireOffset));
 478 
 479                 dstBuf.put(hsMemo.handshakeType);
 480                 dstBuf.put((byte)((hsMemo.fragment.length >> 16) & 0xFF));
 481                 dstBuf.put((byte)((hsMemo.fragment.length >> 8) & 0xFF));
 482                 dstBuf.put((byte)(hsMemo.fragment.length & 0xFF));
 483                 dstBuf.put((byte)((hsMemo.messageSequence >> 8) & 0xFF));
 484                 dstBuf.put((byte)(hsMemo.messageSequence & 0xFF));
 485                 dstBuf.put((byte)((hsMemo.acquireOffset >> 16) & 0xFF));
 486                 dstBuf.put((byte)((hsMemo.acquireOffset >> 8) & 0xFF));
 487                 dstBuf.put((byte)(hsMemo.acquireOffset & 0xFF));
 488                 dstBuf.put((byte)((fragLen >> 16) & 0xFF));
 489                 dstBuf.put((byte)((fragLen >> 8) & 0xFF));
 490                 dstBuf.put((byte)(fragLen & 0xFF));
 491                 dstBuf.put(hsMemo.fragment, hsMemo.acquireOffset, fragLen);
 492             } else {
 493                 fragLen = Math.min(fragLen, memo.fragment.length);
 494                 dstBuf.put(memo.fragment, 0, fragLen);
 495             }
 496 
 497             dstBuf.limit(dstBuf.position());
 498             dstBuf.position(dstContent);
 499 
 500             if (SSLLogger.isOn && SSLLogger.isOn("record")) {
 501                 SSLLogger.fine(
 502                         "WRITE: " + protocolVersion + " " +
 503                         ContentType.nameOf(memo.contentType) +
 504                         ", length = " + dstBuf.remaining());
 505             }
 506 
 507             // Encrypt the fragment and wrap up a record.
 508             long recordSN = encrypt(memo.encodeCipher,
 509                     memo.contentType, dstBuf,
 510                     dstPos, dstLim, headerSize,
 511                     ProtocolVersion.valueOf(memo.majorVersion,
 512                             memo.minorVersion));
 513 
 514             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
 515                 ByteBuffer temporary = dstBuf.duplicate();
 516                 temporary.limit(temporary.position());
 517                 temporary.position(dstPos);
 518                 SSLLogger.fine(
 519                         "Raw write (" + temporary.remaining() + ")", temporary);
 520             }
 521 
 522             // remain the limit unchanged
 523             dstBuf.limit(dstLim);
 524 
 525             // Reset the fragmentation offset.
 526             if (hsMemo != null) {
 527                 hsMemo.acquireOffset += fragLen;
 528                 if (hsMemo.acquireOffset == hsMemo.fragment.length) {
 529                     acquireIndex++;
 530                 }
 531 
 532                 return new Ciphertext(hsMemo.contentType,
 533                         hsMemo.handshakeType, recordSN);
 534             } else {
 535                 if (isCloseWaiting &&
 536                         memo.contentType == ContentType.ALERT.id) {
 537                     close();
 538                 }
 539 
 540                 acquireIndex++;
 541                 return new Ciphertext(memo.contentType,
 542                         SSLHandshake.NOT_APPLICABLE.id, recordSN);
 543             }
 544         }
 545 
 546         private void handshakeHashing(HandshakeMemo hsFrag, byte[] hsBody) {
 547 
 548             byte hsType = hsFrag.handshakeType;
 549             if (!handshakeHash.isHashable(hsType)) {
 550                 // omitted from handshake hash computation
 551                 return;
 552             }
 553 
 554             // calculate the DTLS header
 555             byte[] temporary = new byte[12];    // 12: handshake header size
 556 
 557             // Handshake.msg_type
 558             temporary[0] = hsFrag.handshakeType;
 559 
 560             // Handshake.length
 561             temporary[1] = (byte)((hsBody.length >> 16) & 0xFF);
 562             temporary[2] = (byte)((hsBody.length >> 8) & 0xFF);
 563             temporary[3] = (byte)(hsBody.length & 0xFF);
 564 
 565             // Handshake.message_seq
 566             temporary[4] = (byte)((hsFrag.messageSequence >> 8) & 0xFF);
 567             temporary[5] = (byte)(hsFrag.messageSequence & 0xFF);
 568 
 569             // Handshake.fragment_offset
 570             temporary[6] = 0;
 571             temporary[7] = 0;
 572             temporary[8] = 0;
 573 
 574             // Handshake.fragment_length
 575             temporary[9] = temporary[1];
 576             temporary[10] = temporary[2];
 577             temporary[11] = temporary[3];
 578 
 579             handshakeHash.deliver(temporary, 0, 12);
 580             handshakeHash.deliver(hsBody, 0, hsBody.length);
 581         }
 582 
 583         boolean isEmpty() {
 584             if (!flightIsReady || handshakeMemos.isEmpty() ||
 585                     acquireIndex >= handshakeMemos.size()) {
 586                 return true;
 587             }
 588 
 589             return false;
 590         }
 591 
 592         boolean hasAlert() {
 593             for (RecordMemo memo : handshakeMemos) {
 594                 if (memo.contentType == ContentType.ALERT.id) {
 595                     return true;
 596                 }
 597             }
 598 
 599             return false;
 600         }
 601 
 602         boolean isRetransmittable() {
 603             return (flightIsReady && !handshakeMemos.isEmpty() &&
 604                                 (acquireIndex >= handshakeMemos.size()));
 605         }
 606 
 607         private void setRetransmission() {
 608             acquireIndex = 0;
 609             for (RecordMemo memo : handshakeMemos) {
 610                 if (memo instanceof HandshakeMemo) {
 611                     HandshakeMemo hmemo = (HandshakeMemo)memo;
 612                     hmemo.acquireOffset = 0;
 613                 }
 614             }
 615 
 616             // Shrink packet size if:
 617             // 1. maximum fragment size is allowed, in which case the packet
 618             //    size is configured bigger than maxRecordSize;
 619             // 2. maximum packet is bigger than 256 bytes;
 620             // 3. two times of retransmits have been attempted.
 621             if ((packetSize <= maxRecordSize) &&
 622                     (packetSize > 256) && ((retransmits--) <= 0)) {
 623 
 624                 // shrink packet size
 625                 shrinkPacketSize();
 626                 retransmits = 2;        // attemps of retransmits
 627             }
 628         }
 629 
 630         private void shrinkPacketSize() {
 631             packetSize = Math.max(256, packetSize / 2);
 632         }
 633     }
 634 }
--- EOF ---