Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aikido_zen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

# Re-export functions :
from aikido_zen.context.users import set_user
from aikido_zen.context.set_tenant_id import set_tenant_id
from aikido_zen.vulnerabilities.idor.enable_idor_protection import (
enable_idor_protection,
)
from aikido_zen.helpers.check_gevent import check_gevent
from aikido_zen.helpers.python_version_not_supported import python_version_not_supported
from aikido_zen.middleware import should_block_request
Expand Down
1 change: 1 addition & 0 deletions aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
self.cookies = {}
self.query = {}
self.protection_forced_off = None
self.tenant_id = None

# Parse WSGI/ASGI/... request :
self.method = self.remote_address = self.url = None
Expand Down
31 changes: 31 additions & 0 deletions aikido_zen/context/set_tenant_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Exports set_tenant_id for setting the tenant ID on the current context.
"""

from aikido_zen.helpers.logging import logger
from . import get_current_context


def set_tenant_id(tenant_id):
"""
Sets the tenant ID on the current request context.
Used for IDOR protection to verify SQL queries filter on the correct tenant.
"""
if not isinstance(tenant_id, (str, int)):
logger.info(
"set_tenant_id(...) expects a string or integer, found %s instead.",
type(tenant_id),
)
return

str_id = str(tenant_id)
if len(str_id) == 0:
logger.info("set_tenant_id(...) expects a non-empty value.")
return

context = get_current_context()
if not context:
logger.debug("No context set, cannot set tenant_id")
return

context.tenant_id = str_id
82 changes: 82 additions & 0 deletions aikido_zen/context/set_tenant_id_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
from . import current_context, Context
from .set_tenant_id import set_tenant_id


@pytest.fixture(autouse=True)
def run_around_tests():
yield
current_context.set(None)


def _create_context():
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(req=wsgi_request, body=None, source="flask")
context.set_as_current_context()
return context


def test_set_tenant_id_string():
ctx = _create_context()
set_tenant_id("tenant_123")
assert ctx.tenant_id == "tenant_123"


def test_set_tenant_id_integer():
ctx = _create_context()
set_tenant_id(42)
assert ctx.tenant_id == "42"


def test_set_tenant_id_empty_string(caplog):
ctx = _create_context()
set_tenant_id("")
assert ctx.tenant_id is None
assert "non-empty" in caplog.text


def test_set_tenant_id_invalid_type(caplog):
ctx = _create_context()
set_tenant_id(12.34)
assert ctx.tenant_id is None
assert "expects a string or integer" in caplog.text


def test_set_tenant_id_none(caplog):
ctx = _create_context()
set_tenant_id(None)
assert ctx.tenant_id is None
assert "expects a string or integer" in caplog.text


def test_set_tenant_id_dict(caplog):
ctx = _create_context()
set_tenant_id({"id": 1})
assert ctx.tenant_id is None
assert "expects a string or integer" in caplog.text


def test_set_tenant_id_no_context(caplog):
import logging

# No context set — should not raise, tenant_id is not applied anywhere
with caplog.at_level(logging.DEBUG, logger="Zen"):
set_tenant_id("tenant_123")
assert "No context set" in caplog.text


def test_set_tenant_id_overwrites():
ctx = _create_context()
set_tenant_id("first")
assert ctx.tenant_id == "first"
set_tenant_id("second")
assert ctx.tenant_id == "second"
8 changes: 8 additions & 0 deletions aikido_zen/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,11 @@ class AikidoSSRF(AikidoException):
"""Exception because of SSRF"""

kind = "ssrf"


class AikidoIDOR(AikidoException):
"""Exception because of an IDOR vulnerability (missing or wrong tenant filter)"""

def __init__(self, message):
super().__init__(message)
self.message = message
5 changes: 5 additions & 0 deletions aikido_zen/sinks/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import patch_function, before, on_import
from aikido_zen.vulnerabilities.idor.check_idor import run_idor_check


@before
Expand All @@ -17,6 +18,10 @@ def _execute(func, instance, args, kwargs):

vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))

# asyncpg uses variadic positional args for params: execute(query, *args)
query_params = args[1:] if len(args) > 1 else None
run_idor_check(query, "postgres", query_params)


@on_import("asyncpg.connection", "asyncpg", version_requirement="0.27.0")
def patch(m):
Expand Down
4 changes: 4 additions & 0 deletions aikido_zen/sinks/clickhouse_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import before, on_import, patch_function
from aikido_zen.vulnerabilities import run_vulnerability_scan
from aikido_zen.vulnerabilities.idor.check_idor import run_idor_check


@before
Expand All @@ -13,6 +14,9 @@ def _execute(func, instance, args, kwargs):

run_vulnerability_scan("sql_injection", op, args=(query, "clickhouse"))

query_params = get_argument(args, kwargs, 1, "params")
run_idor_check(query, "clickhouse", query_params)


@on_import("clickhouse_driver", package="clickhouse_driver")
def patch(m):
Expand Down
7 changes: 7 additions & 0 deletions aikido_zen/sinks/mysqlclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import aikido_zen.vulnerabilities as vulns
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import patch_function, on_import, before
from aikido_zen.vulnerabilities.idor.check_idor import run_idor_check


@before
Expand All @@ -20,6 +21,9 @@ def _execute(func, instance, args, kwargs):
kind="sql_injection", op="MySQLdb.Cursor.execute", args=(query, "mysql")
)

query_params = get_argument(args, kwargs, 1, "args")
run_idor_check(query, "mysql", query_params)


@before
def _executemany(func, instance, args, kwargs):
Expand All @@ -30,6 +34,9 @@ def _executemany(func, instance, args, kwargs):
kind="sql_injection", op="MySQLdb.Cursor.executemany", args=(query, "mysql")
)

query_params = get_argument(args, kwargs, 1, "args")
run_idor_check(query, "mysql", query_params)


@on_import("MySQLdb.cursors", "mysqlclient", version_requirement="1.5.0")
def patch(m):
Expand Down
4 changes: 4 additions & 0 deletions aikido_zen/sinks/psycopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import patch_function, on_import, before
from aikido_zen.vulnerabilities.idor.check_idor import run_idor_check


@before
Expand All @@ -25,6 +26,9 @@ def _execute(func, instance, args, kwargs):
op = f"psycopg.{instance.__class__.__name__}.{func.__name__}"
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))

query_params = get_argument(args, kwargs, 1, "params")
run_idor_check(query, "postgres", query_params)


@on_import("psycopg.cursor", "psycopg", version_requirement="3.1.0")
def patch(m):
Expand Down
4 changes: 4 additions & 0 deletions aikido_zen/sinks/psycopg2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import on_import, before, patch_function, after
from aikido_zen.vulnerabilities.idor.check_idor import run_idor_check


@after
Expand Down Expand Up @@ -42,6 +43,9 @@ def psycopg2_patch(func, instance, args, kwargs):

vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))

query_params = get_argument(args, kwargs, 1, "vars")
run_idor_check(query, "postgres", query_params)


@on_import("psycopg2")
def patch(m):
Expand Down
7 changes: 7 additions & 0 deletions aikido_zen/sinks/pymysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import patch_function, on_import, before
from aikido_zen.vulnerabilities.idor.check_idor import run_idor_check


@before
Expand All @@ -20,6 +21,9 @@ def _execute(func, instance, args, kwargs):
kind="sql_injection", op="pymysql.Cursor.execute", args=(query, "mysql")
)

query_params = get_argument(args, kwargs, 1, "args")
run_idor_check(query, "mysql", query_params)


@before
def _executemany(func, instance, args, kwargs):
Expand All @@ -30,6 +34,9 @@ def _executemany(func, instance, args, kwargs):
kind="sql_injection", op="pymysql.Cursor.executemany", args=(query, "mysql")
)

query_params = get_argument(args, kwargs, 1, "args")
run_idor_check(query, "mysql", query_params)


@on_import("pymysql.cursors", "pymysql", version_requirement="0.9.0")
def patch(m):
Expand Down
21 changes: 21 additions & 0 deletions aikido_zen/storage/idor_protection_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
class IdorProtectionConfig:
def __init__(self, tenant_column_name, excluded_tables):
self.tenant_column_name = tenant_column_name
self.excluded_tables = excluded_tables


class IdorProtectionStore:
def __init__(self):
self.config = None

def get(self):
return self.config

def set(self, config):
self.config = config

def clear(self):
self.config = None


idor_protection_store = IdorProtectionStore()
Empty file.
Loading
Loading