Skip to content

Commit

Permalink
added filters support in threat detection APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
ag060 committed Jan 10, 2025
1 parent 3be068a commit 926bdf6
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,39 @@ public SuspectSampleDataAction() {
}

public String fetchSampleData() {
HttpPost post =
new HttpPost(
String.format("%s/api/dashboard/list_malicious_requests", this.getBackendUrl()));
HttpPost post = new HttpPost(
String.format("%s/api/dashboard/list_malicious_requests", this.getBackendUrl()));
post.addHeader("Authorization", "Bearer " + this.getApiToken());
post.addHeader("Content-Type", "application/json");

Map<String, Object> body =
new HashMap<String, Object>() {
{
put("skip", skip);
put("limit", LIMIT);
put("sort", sort);
}
};
Map<String, Object> filter = new HashMap<>();
if (this.ips != null && !this.ips.isEmpty()) {
filter.put("ips", this.ips);
}

if (this.urls != null && !this.urls.isEmpty()) {
filter.put("urls", this.urls);
}

Map<String, Integer> time_range = new HashMap<>();
if (this.startTimestamp > 0) {
time_range.put("start", this.startTimestamp);
}

if (this.endTimestamp > 0) {
time_range.put("end", this.endTimestamp);
}

filter.put("detected_at_time_range", time_range);

Map<String, Object> body = new HashMap<String, Object>() {
{
put("skip", skip);
put("limit", LIMIT);
put("sort", sort);
put("filter", filter);
}
};
String msg = objectMapper.valueToTree(body).toString();

StringEntity requestEntity = new StringEntity(msg, ContentType.APPLICATION_JSON);
Expand All @@ -65,24 +84,22 @@ public String fetchSampleData() {
String responseBody = EntityUtils.toString(resp.getEntity());

ProtoMessageUtils.<ListMaliciousRequestsResponse>toProtoMessage(
ListMaliciousRequestsResponse.class, responseBody)
ListMaliciousRequestsResponse.class, responseBody)
.ifPresent(
m -> {
this.maliciousEvents =
m.getMaliciousEventsList().stream()
.map(
smr ->
new DashboardMaliciousEvent(
smr.getId(),
smr.getActor(),
smr.getFilterId(),
smr.getEndpoint(),
URLMethods.Method.fromString(smr.getMethod()),
smr.getApiCollectionId(),
smr.getIp(),
smr.getCountry(),
smr.getDetectedAt()))
.collect(Collectors.toList());
this.maliciousEvents = m.getMaliciousEventsList().stream()
.map(
smr -> new DashboardMaliciousEvent(
smr.getId(),
smr.getActor(),
smr.getFilterId(),
smr.getEndpoint(),
URLMethods.Method.fromString(smr.getMethod()),
smr.getApiCollectionId(),
smr.getIp(),
smr.getCountry(),
smr.getDetectedAt()))
.collect(Collectors.toList());
this.total = m.getTotal();
});
} catch (Exception e) {
Expand All @@ -94,16 +111,15 @@ public String fetchSampleData() {
}

public String fetchFilters() {
HttpGet get =
new HttpGet(String.format("%s/api/dashboard/fetch_filters", this.getBackendUrl()));
HttpGet get = new HttpGet(String.format("%s/api/dashboard/fetch_filters", this.getBackendUrl()));
get.addHeader("Authorization", "Bearer " + this.getApiToken());
get.addHeader("Content-Type", "application/json");

try (CloseableHttpResponse resp = this.httpClient.execute(get)) {
String responseBody = EntityUtils.toString(resp.getEntity());

ProtoMessageUtils.<FetchAlertFiltersResponse>toProtoMessage(
FetchAlertFiltersResponse.class, responseBody)
FetchAlertFiltersResponse.class, responseBody)
.ifPresent(
msg -> {
this.ips = msg.getActorsList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,6 @@ function SusDataTable({ currDateRange, rowClicked }) {
});

filters = [
{
key: "apiCollectionId",
label: "Collection",
title: "Collection",
choices: apiCollectionFilterChoices,
},
{
key: "sourceIps",
label: "Source IP",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.FetchAlertFiltersResponse;
import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ListMaliciousRequestsRequest;
import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ListMaliciousRequestsResponse;
import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.TimeRangeFilter;
import com.akto.proto.generated.threat_detection.service.malicious_alert_service.v1.RecordMaliciousEventRequest;
import com.akto.threat.backend.client.IPLookupClient;
import com.akto.threat.backend.constants.KafkaTopic;
import com.akto.threat.backend.constants.MongoDBCollection;
import com.akto.threat.backend.db.AggregateSampleMaliciousEventModel;
import com.akto.threat.backend.db.MaliciousEventModel;
import com.akto.threat.backend.utils.KafkaUtils;
import com.mongodb.BasicDBObject;
import com.mongodb.client.DistinctIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
Expand All @@ -28,6 +28,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.bson.Document;
import org.bson.conversions.Bson;

public class MaliciousEventService {
Expand Down Expand Up @@ -135,18 +136,39 @@ public ListMaliciousRequestsResponse listMaliciousRequests(
int limit = request.getLimit();
int skip = request.hasSkip() ? request.getSkip() : 0;
Map<String, Integer> sort = request.getSortMap();
ListMaliciousRequestsRequest.Filter filter = request.getFilter();

MongoCollection<MaliciousEventModel> coll =
this.mongoClient
.getDatabase(accountId)
.getCollection(
MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, MaliciousEventModel.class);

BasicDBObject query = new BasicDBObject();
Document query = new Document();
if (!filter.getActorsList().isEmpty()) {
query.append("actor", new Document("$in", filter.getActorsList()));
}

if (!filter.getUrlsList().isEmpty()) {
query.append("latestApiEndpoint", new Document("$in", filter.getUrlsList()));
}

if (!filter.getIpsList().isEmpty()) {
query.append("latestApiIp", new Document("$in", filter.getIpsList()));
}

if (filter.hasDetectedAtTimeRange()) {
TimeRangeFilter timeRange = filter.getDetectedAtTimeRange();
long start = timeRange.hasStart() ? timeRange.getStart() : 0;
long end = timeRange.hasEnd() ? timeRange.getEnd() : Long.MAX_VALUE;

query.append("detectedAt", new Document("$gte", start).append("$lte", end));
}

long total = coll.countDocuments(query);
try (MongoCursor<MaliciousEventModel> cursor =
coll.find(query)
.sort(new BasicDBObject("detectedAt", sort.getOrDefault("detectedAt", -1)))
.sort(new Document("detectedAt", sort.getOrDefault("detectedAt", -1)))
.skip(skip)
.limit(limit)
.cursor()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,30 @@ public ListThreatActorResponse listThreatActors(
.getDatabase(accountId)
.getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class);

ListThreatActorsRequest.Filter filter = request.getFilter();

List<Document> base = new ArrayList<>();

Document match = new Document();

if (!filter.getActorsList().isEmpty()) {
match.append("actor", new Document("$in", filter.getActorsList()));
}

if (!filter.getLatestIpsList().isEmpty()) {
match.append("latestApiIp", new Document("$in", filter.getLatestIpsList()));
}

if (filter.hasDetectedAtTimeRange()) {
long start = filter.getDetectedAtTimeRange().getStart();
long end = filter.getDetectedAtTimeRange().getEnd();
match.append("detectedAt", new Document("$gte", start).append("$lte", end));
}

if (!match.isEmpty()) {
base.add(new Document("$match", match));
}

base.add(new Document("$sort", new Document("detectedAt", -1)));
base.add(
new Document(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ public ListThreatApiResponse listThreatApis(String accountId, ListThreatApiReque
.getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class);

List<Document> base = new ArrayList<>();
ListThreatApiRequest.Filter filter = request.getFilter();

Document match = new Document();
if (!filter.getMethodsList().isEmpty()) {
match.append("latestApiMethod", new Document("$in", filter.getMethodsList()));
}

if (!filter.getUrlsList().isEmpty()) {
match.append("latestApiEndpoint", new Document("$in", filter.getUrlsList()));
}

if (filter.hasDetectedAtTimeRange()) {
long start = filter.getDetectedAtTimeRange().getStart();
long end = filter.getDetectedAtTimeRange().getEnd();
match.append("detectedAt", new Document("$gte", start).append("$lte", end));
}

if (!match.isEmpty()) {
base.add(new Document("$match", match));
}

base.add(new Document("$sort", new Document("detectedAt", -1))); // sort
base.add(
new Document(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,24 @@ message ListMaliciousRequestsResponse {
uint64 total = 2;
}

message TimeRangeFilter {
optional uint64 start = 1;
optional uint64 end = 2;
}

message ListMaliciousRequestsRequest {
message Filter {
repeated string actors = 1;
repeated string urls = 2;
repeated string ips = 3;
optional TimeRangeFilter detected_at_time_range = 4;
}

// The number of alerts to return
optional uint32 skip = 1;
uint32 limit = 2;
map<string, int32> sort = 3;
Filter filter = 4;
}

message FetchAlertFiltersRequest {}
Expand All @@ -43,9 +56,16 @@ message FetchAlertFiltersResponse {
}

message ListThreatActorsRequest {
message Filter {
repeated string actors = 1;
repeated string latest_ips = 2;
optional TimeRangeFilter detected_at_time_range = 3;
}

optional uint32 skip = 1;
uint32 limit = 2;
map<string, int32> sort = 3;
Filter filter = 4;
}

message ListThreatActorResponse {
Expand All @@ -62,9 +82,16 @@ message ListThreatActorResponse {
}

message ListThreatApiRequest {
message Filter {
repeated string urls = 1;
repeated string methods = 2;
optional TimeRangeFilter detected_at_time_range = 3;
}

optional uint32 skip = 1;
uint32 limit = 2;
map<string, int32> sort = 3;
Filter filter = 4;
}

message ListThreatApiResponse {
Expand Down

0 comments on commit 926bdf6

Please sign in to comment.