1 /*
  2  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  3  *
  4  * This code is free software; you can redistribute it and/or modify it
  5  * under the terms of the GNU General Public License version 2 only, as
  6  * published by the Free Software Foundation.  Oracle designates this
  7  * particular file as subject to the "Classpath" exception as provided
  8  * by Oracle in the LICENSE file that accompanied this code.
  9  *
 10  * This code is distributed in the hope that it will be useful, but WITHOUT
 11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 12  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 13  * version 2 for more details (a copy is included in the LICENSE file that
 14  * accompanied this code).
 15  *
 16  * You should have received a copy of the GNU General Public License version
 17  * 2 along with this work; if not, write to the Free Software Foundation,
 18  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 19  *
 20  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 21  * or visit www.oracle.com if you need additional information or have any
 22  * questions.
 23  */
 24 
 25 /* infback.c -- inflate using a call-back interface
 26  * Copyright (C) 1995-2022 Mark Adler
 27  * For conditions of distribution and use, see copyright notice in zlib.h
 28  */
 29 
 30 /*
 31    This code is largely copied from inflate.c.  Normally either infback.o or
 32    inflate.o would be linked into an application--not both.  The interface
 33    with inffast.c is retained so that optimized assembler-coded versions of
 34    inflate_fast() can be used with either inflate.c or infback.c.
 35  */
 36 
 37 #include "zutil.h"
 38 #include "inftrees.h"
 39 #include "inflate.h"
 40 #include "inffast.h"
 41 
 42 /*
 43    strm provides memory allocation functions in zalloc and zfree, or
 44    Z_NULL to use the library memory allocation functions.
 45 
 46    windowBits is in the range 8..15, and window is a user-supplied
 47    window and output buffer that is 2**windowBits bytes.
 48  */
 49 int ZEXPORT inflateBackInit_(z_streamp strm, int windowBits,
 50                              unsigned char FAR *window, const char *version,
 51                              int stream_size) {
 52     struct inflate_state FAR *state;
 53 
 54     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
 55         stream_size != (int)(sizeof(z_stream)))
 56         return Z_VERSION_ERROR;
 57     if (strm == Z_NULL || window == Z_NULL ||
 58         windowBits < 8 || windowBits > 15)
 59         return Z_STREAM_ERROR;
 60     strm->msg = Z_NULL;                 /* in case we return an error */
 61     if (strm->zalloc == (alloc_func)0) {
 62 #ifdef Z_SOLO
 63         return Z_STREAM_ERROR;
 64 #else
 65         strm->zalloc = zcalloc;
 66         strm->opaque = (voidpf)0;
 67 #endif
 68     }
 69     if (strm->zfree == (free_func)0)
 70 #ifdef Z_SOLO
 71         return Z_STREAM_ERROR;
 72 #else
 73     strm->zfree = zcfree;
 74 #endif
 75     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
 76                                                sizeof(struct inflate_state));
 77     if (state == Z_NULL) return Z_MEM_ERROR;
 78     Tracev((stderr, "inflate: allocated\n"));
 79     strm->state = (struct internal_state FAR *)state;
 80     state->dmax = 32768U;
 81     state->wbits = (uInt)windowBits;
 82     state->wsize = 1U << windowBits;
 83     state->window = window;
 84     state->wnext = 0;
 85     state->whave = 0;
 86     state->sane = 1;
 87     return Z_OK;
 88 }
 89 
 90 /*
 91    Return state with length and distance decoding tables and index sizes set to
 92    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
 93    If BUILDFIXED is defined, then instead this routine builds the tables the
 94    first time it's called, and returns those tables the first time and
 95    thereafter.  This reduces the size of the code by about 2K bytes, in
 96    exchange for a little execution time.  However, BUILDFIXED should not be
 97    used for threaded applications, since the rewriting of the tables and virgin
 98    may not be thread-safe.
 99  */
100 local void fixedtables(struct inflate_state FAR *state) {
101 #ifdef BUILDFIXED
102     static int virgin = 1;
103     static code *lenfix, *distfix;
104     static code fixed[544];
105 
106     /* build fixed huffman tables if first call (may not be thread safe) */
107     if (virgin) {
108         unsigned sym, bits;
109         static code *next;
110 
111         /* literal/length table */
112         sym = 0;
113         while (sym < 144) state->lens[sym++] = 8;
114         while (sym < 256) state->lens[sym++] = 9;
115         while (sym < 280) state->lens[sym++] = 7;
116         while (sym < 288) state->lens[sym++] = 8;
117         next = fixed;
118         lenfix = next;
119         bits = 9;
120         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
121 
122         /* distance table */
123         sym = 0;
124         while (sym < 32) state->lens[sym++] = 5;
125         distfix = next;
126         bits = 5;
127         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
128 
129         /* do this just once */
130         virgin = 0;
131     }
132 #else /* !BUILDFIXED */
133 #   include "inffixed.h"
134 #endif /* BUILDFIXED */
135     state->lencode = lenfix;
136     state->lenbits = 9;
137     state->distcode = distfix;
138     state->distbits = 5;
139 }
140 
141 /* Macros for inflateBack(): */
142 
/* Refresh the local working copies (input, output, bit accumulator) from
   strm and state after inflate_fast() has run */
#define LOAD() \
    do { \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
        bits = state->bits; \
        hold = state->hold; \
    } while (0)

/* Publish the local working copies into strm and state so that
   inflate_fast() can pick them up */
#define RESTORE() \
    do { \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
        state->bits = bits; \
        state->hold = hold; \
    } while (0)
164 
/* Empty the bit accumulator: no bits held, count reset to zero */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)
171 
/* Make sure at least one input byte is available, asking in() for more when
   the current chunk is exhausted.  If in() supplies nothing, next is set to
   Z_NULL and inflateBack() is left with Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Append one input byte above the bits already held in the accumulator,
   or leave inflateBack() with an error if no input can be obtained. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next++) << bits; \
        have--; \
        bits += 8; \
    } while (0)

/* Keep pulling bytes until the accumulator holds at least n bits.  Leaves
   inflateBack() with an error if the input runs dry first. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) { \
            PULLBYTE(); \
        } \
    } while (0)
204 
/* Value of the n least significant bits held in the accumulator (n < 16);
   the accumulator itself is not modified */
#define BITS(n) \
    (((1U << (n)) - 1) & (unsigned)hold)

/* Discard the n least significant bits of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard the 0..7 bits that precede the next byte boundary */
#define BYTEBITS() \
    do { \
        hold >>= (bits & 7); \
        bits &= ~7U; \
    } while (0)
222 
/* Guarantee free space in the output window.  When the window is full it is
   handed to out() in its entirety and writing restarts at its beginning; a
   non-zero result from out() leaves inflateBack() with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            left = state->wsize; \
            put = state->window; \
            state->whave = left; \
            if (out(out_desc, put, left)) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
238 
239 /*
240    strm provides the memory allocation functions and window buffer on input,
241    and provides information on the unused input on return.  For Z_DATA_ERROR
242    returns, strm will also provide an error message.
243 
244    in() and out() are the call-back input and output functions.  When
245    inflateBack() needs more input, it calls in().  When inflateBack() has
246    filled the window with output, or when it completes with data in the
247    window, it calls out() to write out the data.  The application must not
248    change the provided input until in() is called again or inflateBack()
249    returns.  The application must not change the window/output buffer until
250    inflateBack() returns.
251 
252    in() and out() are called with a descriptor parameter provided in the
253    inflateBack() call.  This parameter can be a structure that provides the
254    information required to do the read or write, as well as accumulated
255    information on the input and output such as totals and check values.
256 
257    in() should return zero on failure.  out() should return non-zero on
258    failure.  If either in() or out() fails, than inflateBack() returns a
259    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
260    was in() or out() that caused in the error.  Otherwise,  inflateBack()
261    returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
262    error, or Z_MEM_ERROR if it could not allocate memory for the state.
263    inflateBack() can also return Z_STREAM_ERROR if the input parameters
264    are not correct, i.e. strm is Z_NULL or the state was not initialized.
265  */
266 int ZEXPORT inflateBack(z_streamp strm, in_func in, void FAR *in_desc,
267                         out_func out, void FAR *out_desc) {
268     struct inflate_state FAR *state;
269     z_const unsigned char FAR *next;    /* next input */
270     unsigned char FAR *put;     /* next output */
271     unsigned have, left;        /* available input and output */
272     unsigned long hold;         /* bit buffer */
273     unsigned bits;              /* bits in bit buffer */
274     unsigned copy;              /* number of stored or match bytes to copy */
275     unsigned char FAR *from;    /* where to copy match bytes from */
276     code here;                  /* current decoding table entry */
277     code last;                  /* parent table entry */
278     unsigned len;               /* length to copy for repeats, bits to drop */
279     int ret;                    /* return code */
280     static const unsigned short order[19] = /* permutation of code lengths */
281         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
282 
283     /* Check that the strm exists and that the state was initialized */
284     if (strm == Z_NULL || strm->state == Z_NULL)
285         return Z_STREAM_ERROR;
286     state = (struct inflate_state FAR *)strm->state;
287 
288     /* Reset the state */
289     strm->msg = Z_NULL;
290     state->mode = TYPE;
291     state->last = 0;
292     state->whave = 0;
293     next = strm->next_in;
294     have = next != Z_NULL ? strm->avail_in : 0;
295     hold = 0;
296     bits = 0;
297     put = state->window;
298     left = state->wsize;
299 
300     /* Inflate until end of block marked as last */
301     for (;;)
302         switch (state->mode) {
303         case TYPE:
304             /* determine and dispatch block type */
305             if (state->last) {
306                 BYTEBITS();
307                 state->mode = DONE;
308                 break;
309             }
310             NEEDBITS(3);
311             state->last = BITS(1);
312             DROPBITS(1);
313             switch (BITS(2)) {
314             case 0:                             /* stored block */
315                 Tracev((stderr, "inflate:     stored block%s\n",
316                         state->last ? " (last)" : ""));
317                 state->mode = STORED;
318                 break;
319             case 1:                             /* fixed block */
320                 fixedtables(state);
321                 Tracev((stderr, "inflate:     fixed codes block%s\n",
322                         state->last ? " (last)" : ""));
323                 state->mode = LEN;              /* decode codes */
324                 break;
325             case 2:                             /* dynamic block */
326                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
327                         state->last ? " (last)" : ""));
328                 state->mode = TABLE;
329                 break;
330             case 3:
331                 strm->msg = (char *)"invalid block type";
332                 state->mode = BAD;
333             }
334             DROPBITS(2);
335             break;
336 
337         case STORED:
338             /* get and verify stored block length */
339             BYTEBITS();                         /* go to byte boundary */
340             NEEDBITS(32);
341             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
342                 strm->msg = (char *)"invalid stored block lengths";
343                 state->mode = BAD;
344                 break;
345             }
346             state->length = (unsigned)hold & 0xffff;
347             Tracev((stderr, "inflate:       stored length %u\n",
348                     state->length));
349             INITBITS();
350 
351             /* copy stored block from input to output */
352             while (state->length != 0) {
353                 copy = state->length;
354                 PULL();
355                 ROOM();
356                 if (copy > have) copy = have;
357                 if (copy > left) copy = left;
358                 zmemcpy(put, next, copy);
359                 have -= copy;
360                 next += copy;
361                 left -= copy;
362                 put += copy;
363                 state->length -= copy;
364             }
365             Tracev((stderr, "inflate:       stored end\n"));
366             state->mode = TYPE;
367             break;
368 
369         case TABLE:
370             /* get dynamic table entries descriptor */
371             NEEDBITS(14);
372             state->nlen = BITS(5) + 257;
373             DROPBITS(5);
374             state->ndist = BITS(5) + 1;
375             DROPBITS(5);
376             state->ncode = BITS(4) + 4;
377             DROPBITS(4);
378 #ifndef PKZIP_BUG_WORKAROUND
379             if (state->nlen > 286 || state->ndist > 30) {
380                 strm->msg = (char *)"too many length or distance symbols";
381                 state->mode = BAD;
382                 break;
383             }
384 #endif
385             Tracev((stderr, "inflate:       table sizes ok\n"));
386 
387             /* get code length code lengths (not a typo) */
388             state->have = 0;
389             while (state->have < state->ncode) {
390                 NEEDBITS(3);
391                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
392                 DROPBITS(3);
393             }
394             while (state->have < 19)
395                 state->lens[order[state->have++]] = 0;
396             state->next = state->codes;
397             state->lencode = (code const FAR *)(state->next);
398             state->lenbits = 7;
399             ret = inflate_table(CODES, state->lens, 19, &(state->next),
400                                 &(state->lenbits), state->work);
401             if (ret) {
402                 strm->msg = (char *)"invalid code lengths set";
403                 state->mode = BAD;
404                 break;
405             }
406             Tracev((stderr, "inflate:       code lengths ok\n"));
407 
408             /* get length and distance code code lengths */
409             state->have = 0;
410             while (state->have < state->nlen + state->ndist) {
411                 for (;;) {
412                     here = state->lencode[BITS(state->lenbits)];
413                     if ((unsigned)(here.bits) <= bits) break;
414                     PULLBYTE();
415                 }
416                 if (here.val < 16) {
417                     DROPBITS(here.bits);
418                     state->lens[state->have++] = here.val;
419                 }
420                 else {
421                     if (here.val == 16) {
422                         NEEDBITS(here.bits + 2);
423                         DROPBITS(here.bits);
424                         if (state->have == 0) {
425                             strm->msg = (char *)"invalid bit length repeat";
426                             state->mode = BAD;
427                             break;
428                         }
429                         len = (unsigned)(state->lens[state->have - 1]);
430                         copy = 3 + BITS(2);
431                         DROPBITS(2);
432                     }
433                     else if (here.val == 17) {
434                         NEEDBITS(here.bits + 3);
435                         DROPBITS(here.bits);
436                         len = 0;
437                         copy = 3 + BITS(3);
438                         DROPBITS(3);
439                     }
440                     else {
441                         NEEDBITS(here.bits + 7);
442                         DROPBITS(here.bits);
443                         len = 0;
444                         copy = 11 + BITS(7);
445                         DROPBITS(7);
446                     }
447                     if (state->have + copy > state->nlen + state->ndist) {
448                         strm->msg = (char *)"invalid bit length repeat";
449                         state->mode = BAD;
450                         break;
451                     }
452                     while (copy--)
453                         state->lens[state->have++] = (unsigned short)len;
454                 }
455             }
456 
457             /* handle error breaks in while */
458             if (state->mode == BAD) break;
459 
460             /* check for end-of-block code (better have one) */
461             if (state->lens[256] == 0) {
462                 strm->msg = (char *)"invalid code -- missing end-of-block";
463                 state->mode = BAD;
464                 break;
465             }
466 
467             /* build code tables -- note: do not change the lenbits or distbits
468                values here (9 and 6) without reading the comments in inftrees.h
469                concerning the ENOUGH constants, which depend on those values */
470             state->next = state->codes;
471             state->lencode = (code const FAR *)(state->next);
472             state->lenbits = 9;
473             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
474                                 &(state->lenbits), state->work);
475             if (ret) {
476                 strm->msg = (char *)"invalid literal/lengths set";
477                 state->mode = BAD;
478                 break;
479             }
480             state->distcode = (code const FAR *)(state->next);
481             state->distbits = 6;
482             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
483                             &(state->next), &(state->distbits), state->work);
484             if (ret) {
485                 strm->msg = (char *)"invalid distances set";
486                 state->mode = BAD;
487                 break;
488             }
489             Tracev((stderr, "inflate:       codes ok\n"));
490             state->mode = LEN;
491                 /* fallthrough */
492 
493         case LEN:
494             /* use inflate_fast() if we have enough input and output */
495             if (have >= 6 && left >= 258) {
496                 RESTORE();
497                 if (state->whave < state->wsize)
498                     state->whave = state->wsize - left;
499                 inflate_fast(strm, state->wsize);
500                 LOAD();
501                 break;
502             }
503 
504             /* get a literal, length, or end-of-block code */
505             for (;;) {
506                 here = state->lencode[BITS(state->lenbits)];
507                 if ((unsigned)(here.bits) <= bits) break;
508                 PULLBYTE();
509             }
510             if (here.op && (here.op & 0xf0) == 0) {
511                 last = here;
512                 for (;;) {
513                     here = state->lencode[last.val +
514                             (BITS(last.bits + last.op) >> last.bits)];
515                     if ((unsigned)(last.bits + here.bits) <= bits) break;
516                     PULLBYTE();
517                 }
518                 DROPBITS(last.bits);
519             }
520             DROPBITS(here.bits);
521             state->length = (unsigned)here.val;
522 
523             /* process literal */
524             if (here.op == 0) {
525                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
526                         "inflate:         literal '%c'\n" :
527                         "inflate:         literal 0x%02x\n", here.val));
528                 ROOM();
529                 *put++ = (unsigned char)(state->length);
530                 left--;
531                 state->mode = LEN;
532                 break;
533             }
534 
535             /* process end of block */
536             if (here.op & 32) {
537                 Tracevv((stderr, "inflate:         end of block\n"));
538                 state->mode = TYPE;
539                 break;
540             }
541 
542             /* invalid code */
543             if (here.op & 64) {
544                 strm->msg = (char *)"invalid literal/length code";
545                 state->mode = BAD;
546                 break;
547             }
548 
549             /* length code -- get extra bits, if any */
550             state->extra = (unsigned)(here.op) & 15;
551             if (state->extra != 0) {
552                 NEEDBITS(state->extra);
553                 state->length += BITS(state->extra);
554                 DROPBITS(state->extra);
555             }
556             Tracevv((stderr, "inflate:         length %u\n", state->length));
557 
558             /* get distance code */
559             for (;;) {
560                 here = state->distcode[BITS(state->distbits)];
561                 if ((unsigned)(here.bits) <= bits) break;
562                 PULLBYTE();
563             }
564             if ((here.op & 0xf0) == 0) {
565                 last = here;
566                 for (;;) {
567                     here = state->distcode[last.val +
568                             (BITS(last.bits + last.op) >> last.bits)];
569                     if ((unsigned)(last.bits + here.bits) <= bits) break;
570                     PULLBYTE();
571                 }
572                 DROPBITS(last.bits);
573             }
574             DROPBITS(here.bits);
575             if (here.op & 64) {
576                 strm->msg = (char *)"invalid distance code";
577                 state->mode = BAD;
578                 break;
579             }
580             state->offset = (unsigned)here.val;
581 
582             /* get distance extra bits, if any */
583             state->extra = (unsigned)(here.op) & 15;
584             if (state->extra != 0) {
585                 NEEDBITS(state->extra);
586                 state->offset += BITS(state->extra);
587                 DROPBITS(state->extra);
588             }
589             if (state->offset > state->wsize - (state->whave < state->wsize ?
590                                                 left : 0)) {
591                 strm->msg = (char *)"invalid distance too far back";
592                 state->mode = BAD;
593                 break;
594             }
595             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
596 
597             /* copy match from window to output */
598             do {
599                 ROOM();
600                 copy = state->wsize - state->offset;
601                 if (copy < left) {
602                     from = put + copy;
603                     copy = left - copy;
604                 }
605                 else {
606                     from = put - state->offset;
607                     copy = left;
608                 }
609                 if (copy > state->length) copy = state->length;
610                 state->length -= copy;
611                 left -= copy;
612                 do {
613                     *put++ = *from++;
614                 } while (--copy);
615             } while (state->length != 0);
616             break;
617 
618         case DONE:
619             /* inflate stream terminated properly */
620             ret = Z_STREAM_END;
621             goto inf_leave;
622 
623         case BAD:
624             ret = Z_DATA_ERROR;
625             goto inf_leave;
626 
627         default:
628             /* can't happen, but makes compilers happy */
629             ret = Z_STREAM_ERROR;
630             goto inf_leave;
631         }
632 
633     /* Write leftover output and return unused input */
634   inf_leave:
635     if (left < state->wsize) {
636         if (out(out_desc, state->window, state->wsize - left) &&
637             ret == Z_STREAM_END)
638             ret = Z_BUF_ERROR;
639     }
640     strm->next_in = next;
641     strm->avail_in = have;
642     return ret;
643 }
644 
645 int ZEXPORT inflateBackEnd(z_streamp strm) {
646     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
647         return Z_STREAM_ERROR;
648     ZFREE(strm, strm->state);
649     strm->state = Z_NULL;
650     Tracev((stderr, "inflate: end\n"));
651     return Z_OK;
652 }