Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for endpoint discovery in spring mvc #8352

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dd-java-agent/appsec/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies {
testImplementation project(':utils:test-utils')
testImplementation group: 'org.hamcrest', name: 'hamcrest', version: '2.2'
testImplementation group: 'com.flipkart.zjsonpatch', name: 'zjsonpatch', version: '0.4.11'
testImplementation(group: 'org.skyscreamer', name: 'jsonassert', version: '1.5.1')

testFixturesApi project(':dd-java-agent:testing')
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.datadog.appsec.api.security.json;

import com.squareup.moshi.JsonWriter;
import com.squareup.moshi.ToJson;
import datadog.trace.api.appsec.api.security.model.Endpoint;
import java.io.IOException;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

public class EndpointAdapter {

@ToJson
public void toJson(@Nonnull final JsonWriter jsonWriter, @Nullable final Endpoint endpoint)
throws IOException {
if (endpoint == null) {
jsonWriter.nullValue();
} else {
jsonWriter.beginObject();
jsonWriter.name("type");
jsonWriter.value(endpoint.getType().name());
jsonWriter.name("method");
jsonWriter.value(endpoint.getMethod().getName());
jsonWriter.name("path");
jsonWriter.value(endpoint.getPath());
jsonWriter.name("operation-name");
jsonWriter.value(endpoint.getOperation().getName());
jsonWriter.name("request-body-type");
jsonWriter.jsonValue(endpoint.getRequestBodyType());
jsonWriter.name("response-body-type");
jsonWriter.jsonValue(endpoint.getResponseBodyType());
jsonWriter.name("response-code");
jsonWriter.jsonValue(endpoint.getResponseCode());
jsonWriter.name("authentication");
jsonWriter.jsonValue(endpoint.getAuthentication());
jsonWriter.name("metadata");
jsonWriter.jsonValue(endpoint.getMetadata());
jsonWriter.endObject();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.datadog.appsec.api.security.json;

import com.squareup.moshi.JsonAdapter;
import com.squareup.moshi.Moshi;
import datadog.trace.api.appsec.api.security.model.Endpoint;
import java.util.List;

public class EndpointsEncoding {

private static final JsonAdapter<Endpoints> JSON_ADAPTER =
new Moshi.Builder().add(new EndpointAdapter()).build().adapter(Endpoints.class);

public static String toJson(final List<Endpoint> endpoints) {
final Endpoints target = new Endpoints();
target.setEndpoints(endpoints);
return JSON_ADAPTER.toJson(target);
}

public static class Endpoints {

private List<Endpoint> endpoints;

public List<Endpoint> getEndpoints() {
return endpoints;
}

public void setEndpoints(final List<Endpoint> endpoints) {
this.endpoints = endpoints;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.datadog.appsec.report.AppSecEventWrapper;
import datadog.trace.api.Config;
import datadog.trace.api.UserIdCollectionMode;
import datadog.trace.api.appsec.api.security.model.Endpoint;
import datadog.trace.api.gateway.Events;
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.IGSpanInfo;
Expand Down Expand Up @@ -55,13 +56,17 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -171,6 +176,10 @@ public void init() {
subscriptionService.registerCallback(
EVENTS.requestBodyProcessed(), this::onRequestBodyProcessed);
}

if (Config.get().isApiSecurityEndpointCollectionEnabled()) {
subscriptionService.registerCallback(EVENTS.endpoints(), this::onEndpoints);
}
}

/**
Expand All @@ -197,6 +206,13 @@ public void reset() {
shellCmdSubInfo = null;
}

private void onEndpoints(final Iterator<Endpoint> endpoints) {
// TODO: do something with the endpoints
StreamSupport.stream(Spliterators.spliteratorUnknownSize(endpoints, Spliterator.ORDERED), false)
.limit(Config.get().getApiSecurityEndpointCollectionMessageLimit())
.forEach(System.out::println);
}

private Flow<Void> onUser(
final RequestContext ctx_, final UserIdCollectionMode mode, final String originalUser) {
if (mode == DISABLED) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.datadog.appsec.api.security.model

import com.datadog.appsec.api.security.json.EndpointsEncoding
import datadog.trace.api.appsec.api.security.model.Endpoint
import datadog.trace.test.util.DDSpecification
import org.skyscreamer.jsonassert.JSONAssert
import org.skyscreamer.jsonassert.JSONCompareMode

import static datadog.trace.api.appsec.api.security.model.Endpoint.Method.POST
import static datadog.trace.api.appsec.api.security.model.Endpoint.Operation.HTTP_REQUEST
import static datadog.trace.api.appsec.api.security.model.Endpoint.Type.REST

class EndpointsEncodingTest extends DDSpecification {

void 'test json encoding of endpoints'() {
when:
final json = EndpointsEncoding.toJson(test.v1)

then:
JSONAssert.assertEquals("Endpoints payload should match", test.v2, json, JSONCompareMode.NON_EXTENSIBLE)

where:
test << buildEndpoints()
}

static List<Tuple2<List<Endpoint>, String>> buildEndpoints() {
return [
Tuple.tuple([
new Endpoint(type: REST,
method: POST,
path: '/analytics/requests',
operation: HTTP_REQUEST,
requestBodyType: ['application/json'],
responseBodyType: ['application/json'],
responseCode: [200, 201],
authentication: ['JWT'],
metadata: ['dotnet-ignore-anti-forgery': true, 'deprecated': true])
],
"""
{
"endpoints": [
{
"type": "REST",
"method": "POST",
"path": "/analytics/requests",
"operation-name": "http.request",
"request-body-type": ["application/json"],
"response-body-type": ["application/json"],
"response-code": [200, 201],
"authentication": ["JWT"],
"metadata": {
"dotnet-ignore-anti-forgery": true,
"deprecated": true
}
}
]
}
""")
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package datadog.trace.instrumentation.springweb;

import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named;
import static net.bytebuddy.matcher.ElementMatchers.isMethod;
import static net.bytebuddy.matcher.ElementMatchers.isProtected;
import static net.bytebuddy.matcher.ElementMatchers.takesArgument;
import static net.bytebuddy.matcher.ElementMatchers.takesArguments;

import com.google.auto.service.AutoService;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.appsec.api.security.model.Endpoint;
import datadog.trace.api.gateway.CallbackProvider;
import datadog.trace.api.gateway.Events;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Consumer;
import net.bytebuddy.asm.Advice;
import org.springframework.context.ApplicationContext;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

@AutoService(InstrumenterModule.class)
public class AppSecDispatcherServletInstrumentation extends InstrumenterModule.AppSec
implements Instrumenter.ForSingleType, Instrumenter.HasMethodAdvice {

public AppSecDispatcherServletInstrumentation() {
super("spring-web");
}

@Override
public String instrumentedType() {
return "org.springframework.web.servlet.DispatcherServlet";
}

@Override
public String[] helperClassNames() {
return new String[] {packageName + ".RequestMappingInfoInterator"};
}

@Override
public void methodAdvice(MethodTransformer transformer) {
transformer.applyAdvice(
isMethod()
.and(isProtected())
.and(named("onRefresh"))
.and(takesArgument(0, named("org.springframework.context.ApplicationContext")))
.and(takesArguments(1)),
AppSecDispatcherServletInstrumentation.class.getName() + "$AppSecHandlerMappingAdvice");
}

public static class AppSecHandlerMappingAdvice {

@Advice.OnMethodExit(suppress = Throwable.class)
public static void afterRefresh(@Advice.Argument(0) final ApplicationContext springCtx) {

final CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC);
if (cbp == null) {
return;
}
final Consumer<Iterator<Endpoint>> callback = cbp.getCallback(Events.get().endpoints());
if (callback == null) {
return;
}
final RequestMappingHandlerMapping handler =
springCtx.getBean(RequestMappingHandlerMapping.class);
if (handler == null) {
return;
}
final Map<RequestMappingInfo, HandlerMethod> mappings = handler.getHandlerMethods();
if (mappings == null || mappings.isEmpty()) {
return;
}
callback.accept(new RequestMappingInfoInterator(mappings));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package datadog.trace.instrumentation.springweb;

import static datadog.trace.api.appsec.api.security.model.Endpoint.Operation.HTTP_REQUEST;
import static datadog.trace.api.appsec.api.security.model.Endpoint.Type.REST;

import datadog.trace.api.appsec.api.security.model.Endpoint;
import datadog.trace.api.appsec.api.security.model.Endpoint.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.Set;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.mvc.condition.MediaTypeExpression;
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;

public class RequestMappingInfoInterator implements Iterator<Endpoint> {

private final Iterator<Map.Entry<RequestMappingInfo, HandlerMethod>> delegate;
private final Queue<Endpoint> queue = new LinkedList<>();

public RequestMappingInfoInterator(final Map<RequestMappingInfo, HandlerMethod> mappings) {
delegate = mappings.entrySet().iterator();
fetchNext();
}

@Override
public boolean hasNext() {
return !queue.isEmpty();
}

@Override
public Endpoint next() {
Endpoint result = queue.poll();
if (result == null) {
throw new NoSuchElementException();
}
if (queue.isEmpty()) {
fetchNext();
}
return result;
}

private void fetchNext() {
if (!delegate.hasNext()) {
return;
}
final Map.Entry<RequestMappingInfo, HandlerMethod> nextEntry = delegate.next();
final RequestMappingInfo nextInfo = nextEntry.getKey();
final HandlerMethod nextHandler = nextEntry.getValue();
for (final String path : nextInfo.getPatternsCondition().getPatterns()) {
final List<Method> methods = new LinkedList<>();
if (nextInfo.getMethodsCondition().getMethods().isEmpty()) {
methods.add(Method.ALL);
} else {
for (final RequestMethod method : nextInfo.getMethodsCondition().getMethods()) {
methods.add(Method.parseMethod(method.name()));
}
}
for (final Method method : methods) {
final Endpoint endpoint = new Endpoint();
endpoint.setType(REST);
endpoint.setOperation(HTTP_REQUEST);
endpoint.setPath(path);
endpoint.setMethod(method);
endpoint.setRequestBodyType(
parseMediaTypes(nextInfo.getConsumesCondition().getExpressions()));
endpoint.setResponseBodyType(
parseMediaTypes(nextInfo.getProducesCondition().getExpressions()));
final Map<String, Object> metadata = new HashMap<>();
metadata.put("handler", nextHandler.toString());
endpoint.setMetadata(metadata);
queue.add(endpoint);
}
}
}

private List<String> parseMediaTypes(final Set<MediaTypeExpression> expressions) {
if (expressions.isEmpty()) {
return null;
}
final List<String> result = new ArrayList<>(expressions.size());
for (final MediaTypeExpression expression : expressions) {
result.add(expression.toString());
}
return result;
}
}
Loading
Loading