Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1785,8 +1785,8 @@ abstract class HttpServerTest<SERVER> extends WithHttpServer<SERVER> {
TEST_WRITER.get(0).any {
span ->
def tag = span.getTag('request.body.files_content') as String
tag?.contains("content_of_file_$maxFilesToInspect") &&
!tag.contains("content_of_file_${maxFilesToInspect + 1}")
// Exactly maxFilesToInspect files inspected; which file is excluded depends on iteration order
tag != null && tag.count('content_of_file_') == maxFilesToInspect
}

cleanup:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,36 @@
import static datadog.trace.api.gateway.Events.EVENTS;

import datadog.appsec.api.blocking.BlockingException;
import datadog.trace.api.Config;
import datadog.trace.api.gateway.BlockResponseFunction;
import datadog.trace.api.gateway.CallbackProvider;
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.http.MultipartContentDecoder;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import jakarta.servlet.http.Part;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MultipartHelper {

public static final int MAX_CONTENT_BYTES = Config.get().getAppSecMaxFileContentBytes();
public static final int MAX_FILES_TO_INSPECT = Config.get().getAppSecMaxFileContentCount();

private static final Logger log = LoggerFactory.getLogger(MultipartHelper.class);

private MultipartHelper() {}

/**
* Extracts non-null, non-empty filenames from a collection of multipart {@link Part}s using
* {@link Part#getSubmittedFileName()} (Servlet 5.0+, Jetty 11.x).
* {@link Part#getSubmittedFileName()} (Servlet 3.1+, Jetty 11.0.x).
*
* @return list of filenames; never {@code null}, may be empty
*/
Expand All @@ -39,11 +49,80 @@ public static List<String> extractFilenames(Collection<Part> parts) {
}
} catch (Exception ignored) {
// malformed or inaccessible part — skip and continue with remaining parts
log.debug("extractFilenames: skipping malformed part", ignored);
}
}
return filenames;
}

/**
* Extracts file content from a collection of multipart {@link Part}s. Form fields (those with a
* {@code null} submitted filename) are skipped. Reads up to {@link #MAX_CONTENT_BYTES} bytes per
* part, up to {@link #MAX_FILES_TO_INSPECT} parts total.
*
* @return list of decoded content strings; never {@code null}, may be empty
*/
public static List<String> extractContents(Collection<Part> parts) {
if (parts == null || parts.isEmpty()) {
return Collections.emptyList();
}
List<String> contents = new ArrayList<>(Math.min(parts.size(), MAX_FILES_TO_INSPECT));
for (Part part : parts) {
if (contents.size() >= MAX_FILES_TO_INSPECT) {
break;
}
try {
if (part.getSubmittedFileName() == null) {
continue; // form field — skip
}
contents.add(readFileContent(part));
} catch (Exception ignored) {
log.debug("extractContents: skipping malformed part", ignored);
}
}
return contents;
}

private static String readFileContent(Part part) {
try (InputStream is = part.getInputStream()) {
return MultipartContentDecoder.readInputStream(is, MAX_CONTENT_BYTES, part.getContentType());
} catch (Exception e) {
log.debug("readFileContent: stream read failed", e);
return "";
}
}

/**
* Fires the {@code requestFilesContent} IG event and returns a {@link BlockingException} if the
* WAF requests blocking, or {@code null} otherwise.
*/
public static BlockingException fireFilesContentEvent(
Collection<Part> parts, RequestContext reqCtx) {
CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC);
BiFunction<RequestContext, List<String>, Flow<Void>> callback =
cbp.getCallback(EVENTS.requestFilesContent());
if (callback == null) {
return null;
}
List<String> contents = extractContents(parts);
if (contents.isEmpty()) {
return null;
}
Flow<Void> flow = callback.apply(reqCtx, contents);
Flow.Action action = flow.getAction();
if (action instanceof Flow.Action.RequestBlockingAction) {
Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action;
BlockResponseFunction brf = reqCtx.getBlockResponseFunction();
if (brf != null) {
if (brf.tryCommitBlockingResponse(reqCtx.getTraceSegment(), rba)) {
reqCtx.getTraceSegment().effectivelyBlocked();
return new BlockingException("Blocked request (multipart file content)");
}
}
}
return null;
}

/**
* Fires the {@code requestFilesFilenames} IG event and returns a {@link BlockingException} if the
* WAF requests blocking, or {@code null} otherwise.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ static void after(
return;
}
t = MultipartHelper.fireFilenamesEvent(parts, reqCtx);
if (t == null) {
t = MultipartHelper.fireFilesContentEvent(parts, reqCtx);
}
}
}

Expand All @@ -188,6 +191,9 @@ static void after(
return;
}
t = MultipartHelper.fireFilenamesEvent(parts, reqCtx);
if (t == null) {
t = MultipartHelper.fireFilesContentEvent(parts, reqCtx);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import static org.mockito.Mockito.when;

import jakarta.servlet.http.Part;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -64,4 +69,79 @@ private Part part(String submittedFileName) {
when(p.getSubmittedFileName()).thenReturn(submittedFileName);
return p;
}

// ── extractContents ─────────────────────────────────────────────────────────

@Test
void extractContentsReturnsEmptyListForNull() {
assertEquals(emptyList(), MultipartHelper.extractContents(null));
}

@Test
void extractContentsReturnsEmptyListForEmpty() {
assertEquals(emptyList(), MultipartHelper.extractContents(emptyList()));
}

@Test
void extractContentsSkipsFormFieldParts() {
List<Part> parts = asList(part(null), part(null));
assertEquals(emptyList(), MultipartHelper.extractContents(parts));
}

@Test
void extractContentsIncludesFileWithEmptyFilename() throws IOException {
Part p = mock(Part.class);
when(p.getSubmittedFileName()).thenReturn("");
when(p.getInputStream())
.thenReturn(new ByteArrayInputStream("data".getBytes(StandardCharsets.UTF_8)));
when(p.getContentType()).thenReturn("text/plain; charset=UTF-8");
assertEquals(singletonList("data"), MultipartHelper.extractContents(singletonList(p)));
}

@Test
void extractContentsReadsFileContent() throws IOException {
Part p = mock(Part.class);
when(p.getSubmittedFileName()).thenReturn("photo.jpg");
when(p.getInputStream())
.thenReturn(new ByteArrayInputStream("file-content".getBytes(StandardCharsets.UTF_8)));
when(p.getContentType()).thenReturn("text/plain; charset=UTF-8");
assertEquals(singletonList("file-content"), MultipartHelper.extractContents(singletonList(p)));
}

@Test
void extractContentsTruncatesAtMaxContentBytes() throws IOException {
byte[] large = new byte[MultipartHelper.MAX_CONTENT_BYTES + 1];
Arrays.fill(large, (byte) 'A');
Part p = mock(Part.class);
when(p.getSubmittedFileName()).thenReturn("big.bin");
when(p.getInputStream()).thenReturn(new ByteArrayInputStream(large));
when(p.getContentType()).thenReturn(null);
List<String> contents = MultipartHelper.extractContents(singletonList(p));
assertEquals(1, contents.size());
assertEquals(MultipartHelper.MAX_CONTENT_BYTES, contents.get(0).length());
}

@Test
void extractContentsReturnsEmptyStringOnIOException() throws IOException {
Part p = mock(Part.class);
when(p.getSubmittedFileName()).thenReturn("file.txt");
when(p.getInputStream()).thenThrow(new IOException("simulated"));
assertEquals(singletonList(""), MultipartHelper.extractContents(singletonList(p)));
}

@Test
void extractContentsCappsAtMaxFilesToInspect() throws IOException {
int count = MultipartHelper.MAX_FILES_TO_INSPECT + 1;
List<Part> parts = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
Part p = mock(Part.class);
when(p.getSubmittedFileName()).thenReturn("file" + i + ".txt");
when(p.getInputStream())
.thenReturn(new ByteArrayInputStream("c".getBytes(StandardCharsets.UTF_8)));
when(p.getContentType()).thenReturn(null);
parts.add(p);
}
List<String> contents = MultipartHelper.extractContents(parts);
assertEquals(MultipartHelper.MAX_FILES_TO_INSPECT, contents.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import static datadog.trace.api.gateway.Events.EVENTS;

import datadog.appsec.api.blocking.BlockingException;
import datadog.trace.api.Config;
import datadog.trace.api.gateway.BlockResponseFunction;
import datadog.trace.api.gateway.CallbackProvider;
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.http.MultipartContentDecoder;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand Down Expand Up @@ -38,6 +40,9 @@ public class PartHelper {

private static final Logger log = LoggerFactory.getLogger(PartHelper.class);

public static final int MAX_CONTENT_BYTES = Config.get().getAppSecMaxFileContentBytes();
public static final int MAX_FILES_TO_INSPECT = Config.get().getAppSecMaxFileContentCount();

private PartHelper() {}

// Lazily resolves MultiPartInputStream.getParts() as a MethodHandle on first class access.
Expand Down Expand Up @@ -257,6 +262,75 @@ public static BlockingException fireFilenamesEvent(Collection<?> parts, RequestC
return null;
}

/**
* Extracts file content from a collection of multipart {@link Part}s. Form fields (those without
* a {@code filename} parameter in the {@code Content-Disposition} header) are skipped. Reads up
* to {@link #MAX_CONTENT_BYTES} bytes per part, up to {@link #MAX_FILES_TO_INSPECT} parts total.
*
* @return list of decoded content strings; never {@code null}, may be empty
*/
public static List<String> extractContents(Collection<?> parts) {
if (parts == null || parts.isEmpty()) {
return Collections.emptyList();
}
List<String> contents = new ArrayList<>(Math.min(parts.size(), MAX_FILES_TO_INSPECT));
for (Object obj : parts) {
if (contents.size() >= MAX_FILES_TO_INSPECT) {
break;
}
try {
Part part = (Part) obj;
if (filenameFromPart(part) == null) {
continue; // form field — skip
}
contents.add(readFileContent(part));
} catch (Exception e) {
log.debug("extractContents: skipping malformed part", e);
}
}
return contents;
}

private static String readFileContent(Part part) {
try (InputStream is = part.getInputStream()) {
return MultipartContentDecoder.readInputStream(is, MAX_CONTENT_BYTES, part.getContentType());
} catch (Exception e) {
log.debug("readFileContent: stream read failed", e);
return "";
}
}

/**
* Fires the {@code requestFilesContent} IG event for file-upload parts in {@code parts} and
* returns a {@link BlockingException} if the WAF requests blocking, or {@code null} otherwise.
*/
public static BlockingException fireFilesContentEvent(
Collection<?> parts, RequestContext reqCtx) {
CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC);
BiFunction<RequestContext, List<String>, Flow<Void>> callback =
cbp.getCallback(EVENTS.requestFilesContent());
if (callback == null) {
return null;
}
List<String> contents = extractContents(parts);
if (contents.isEmpty()) {
return null;
}
Flow<Void> flow = callback.apply(reqCtx, contents);
Flow.Action action = flow.getAction();
if (action instanceof Flow.Action.RequestBlockingAction) {
Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action;
BlockResponseFunction brf = reqCtx.getBlockResponseFunction();
if (brf != null) {
if (brf.tryCommitBlockingResponse(reqCtx.getTraceSegment(), rba)) {
reqCtx.getTraceSegment().effectivelyBlocked();
return new BlockingException("Blocked request (multipart file content)");
}
}
}
return null;
}

private static String readPartContent(Part part) {
Charset charset = charsetFromContentType(part.getContentType());
try (InputStream is = part.getInputStream()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,11 @@ static void after(
}
BlockingException bodyBlock = PartHelper.fireBodyProcessedEvent(parts, reqCtx);
BlockingException filenamesBlock = PartHelper.fireFilenamesEvent(parts, reqCtx);
t = bodyBlock != null ? bodyBlock : filenamesBlock;
BlockingException contentBlock =
bodyBlock == null && filenamesBlock == null
? PartHelper.fireFilesContentEvent(parts, reqCtx)
: null;
t = bodyBlock != null ? bodyBlock : (filenamesBlock != null ? filenamesBlock : contentBlock);
}
}

Expand Down Expand Up @@ -192,7 +196,11 @@ static void after(
}
BlockingException bodyBlock = PartHelper.fireBodyProcessedEvent(parts, reqCtx);
BlockingException filenamesBlock = PartHelper.fireFilenamesEvent(parts, reqCtx);
t = bodyBlock != null ? bodyBlock : filenamesBlock;
BlockingException contentBlock =
bodyBlock == null && filenamesBlock == null
? PartHelper.fireFilesContentEvent(parts, reqCtx)
: null;
t = bodyBlock != null ? bodyBlock : (filenamesBlock != null ? filenamesBlock : contentBlock);
}
}
}
Loading