001// ***************************************************************************************************************************
002// * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements.  See the NOTICE file *
003// * distributed with this work for additional information regarding copyright ownership.  The ASF licenses this file        *
004// * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance            *
005// * with the License.  You may obtain a copy of the License at                                                              *
006// *                                                                                                                         *
007// *  http://www.apache.org/licenses/LICENSE-2.0                                                                             *
008// *                                                                                                                         *
009// * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an  *
010// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the License for the        *
011// * specific language governing permissions and limitations under the License.                                              *
012// ***************************************************************************************************************************
013package org.apache.juneau.rest.util;
014
015import java.io.*;
016
017import jakarta.servlet.*;
018
019/**
020 * ServletInputStream wrapper around a normal input stream with support for limiting input.
021 *
022 * <h5 class='section'>See Also:</h5><ul>
023 * </ul>
024 */
025public final class BoundedServletInputStream extends ServletInputStream {
026
027   private final InputStream is;
028   private final ServletInputStream sis;
029   private long remain;
030
031   /**
032    * Wraps the specified input stream.
033    *
034    * @param is The input stream to wrap.
035    * @param max The maximum number of bytes to read from the stream.
036    */
037   public BoundedServletInputStream(InputStream is, long max) {
038      this.is = is;
039      this.sis = null;
040      this.remain = max;
041   }
042
043   /**
044    * Wraps the specified input stream.
045    *
046    * @param sis The input stream to wrap.
047    * @param max The maximum number of bytes to read from the stream.
048    */
049   public BoundedServletInputStream(ServletInputStream sis, long max) {
050      this.sis = sis;
051      this.is = sis;
052      this.remain = max;
053   }
054
055   /**
056    * Wraps the specified byte array.
057    *
058    * @param b The byte contents of the stream.
059    */
060   public BoundedServletInputStream(byte[] b) {
061      this(new ByteArrayInputStream(b), Long.MAX_VALUE);
062   }
063
064   @Override /* InputStream */
065   public int read() throws IOException {
066      decrement();
067      return is.read();
068   }
069
070   @Override /* InputStream */
071   public int read(byte[] b) throws IOException {
072      return read(b, 0, b.length);
073   }
074
075   @Override /* InputStream */
076   public int read(final byte[] b, final int off, final int len) throws IOException {
077      long numBytes = Math.min(len, remain);
078      int r = is.read(b, off, (int) numBytes);
079      if (r == -1)
080         return -1;
081      decrement(numBytes);
082      return r;
083   }
084
085   @Override /* InputStream */
086   public long skip(final long n) throws IOException {
087      long toSkip = Math.min(n, remain);
088      long r = is.skip(toSkip);
089      decrement(r);
090      return r;
091   }
092
093   @Override /* InputStream */
094   public int available() throws IOException {
095      if (remain <= 0)
096         return 0;
097      return is.available();
098   }
099
100   @Override /* InputStream */
101   public synchronized void reset() throws IOException {
102      is.reset();
103   }
104
105   @Override /* InputStream */
106   public synchronized void mark(int limit) {
107      is.mark(limit);
108   }
109
110   @Override /* InputStream */
111   public boolean markSupported() {
112      return is.markSupported();
113   }
114
115   @Override /* InputStream */
116   public void close() throws IOException {
117      is.close();
118   }
119
120   @Override /* ServletInputStream */
121   public boolean isFinished() {
122      return sis == null ? false : sis.isFinished();
123   }
124
125   @Override /* ServletInputStream */
126   public boolean isReady() {
127      return sis == null ? true : sis.isReady();
128   }
129
130   @Override /* ServletInputStream */
131   public void setReadListener(ReadListener arg0) {
132      if (sis != null)
133         sis.setReadListener(arg0);
134   }
135
136   private void decrement() throws IOException {
137      remain--;
138      if (remain < 0)
139         throw new IOException("Input limit exceeded.  See @Rest(maxInput).");
140   }
141
142   private void decrement(long count) throws IOException {
143      remain -= count;
144      if (remain < 0)
145         throw new IOException("Input limit exceeded.  See @Rest(maxInput).");
146   }
147}