diff --git a/.vscode/settings.json b/.vscode/settings.json index e64211a..9a6d501 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -6,5 +6,10 @@ "source.organizeImports": "explicit" }, "editor.defaultFormatter": "charliermarsh.ruff" - } + }, + "python.testing.pytestArgs": [ + "dreadnode_cli" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } \ No newline at end of file diff --git a/CLI.md b/CLI.md index 1ab3889..b5d670c 100644 --- a/CLI.md +++ b/CLI.md @@ -88,7 +88,8 @@ $ dreadnode agent init [OPTIONS] STRIKE * `-d, --dir DIRECTORY`: The directory to initialize [default: .] * `-n, --name TEXT`: The project name (used for container naming) -* `-t, --template [rigging_basic|rigging_loop]`: The template to use for the agent [default: rigging_basic] +* `-t, --template [rigging_basic|rigging_loop|nerve_basic]`: The template to use for the agent [default: rigging_basic] +* `-s, --source TEXT`: Initialize the agent using a custom template from a github repository, ZIP archive URL or local folder * `--help`: Show this message and exit. ### `dreadnode agent latest` diff --git a/README.md b/README.md index 25f5c42..710e155 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,15 @@ dreadnode agent init -t # initialize a new agent in the specified directory dreadnode agent init -t --dir +# initialize a new agent using a custom template from a github repository +dreadnode agent init -s username/repository + +# initialize a new agent using a custom template from a github branch/tag +dreadnode agent init -s username/repository@custom-feature + +# initialize a new agent using a custom template from a ZIP archive URL +dreadnode agent init -s https://example.com/template-archive.zip + # push a new version of the agent dreadnode agent push diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index abbb15e..1cf4511 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -1,4 +1,5 @@ import pathlib +import shutil import time import typing as t @@ -21,9 +22,10 @@ format_strikes, format_templates, ) -from dreadnode_cli.agent.templates import Template, install_template +from dreadnode_cli.agent.templates import Template, install_template, install_template_from_dir from dreadnode_cli.config import UserConfig -from dreadnode_cli.utils import pretty_cli +from dreadnode_cli.types import GithubRepo +from dreadnode_cli.utils import download_and_unzip_archive, pretty_cli, repo_exists cli = typer.Typer(no_args_is_help=True) @@ -48,7 +50,17 @@ def init( template: t.Annotated[ Template, typer.Option("--template", "-t", help="The template to use for the agent") ] = Template.rigging_basic, + source: t.Annotated[ + str | None, + typer.Option( + "--source", + "-s", + help="Initialize the agent using a custom template from a github repository, ZIP archive URL or local folder", + ), + ] = None, ) -> None: + print(f":coffee: Fetching strike '{strike}' ...") + client = api.create_client() try: @@ -56,12 +68,11 @@ def init( except Exception as e: raise Exception(f"Failed to find strike '{strike}': {e}") from e - print() print(f":crossed_swords: Linking to strike '{strike_response.name}' ({strike_response.type})") - print() + project_name = Prompt.ask("Project name?", default=name or directory.name) - template = Template(Prompt.ask("Template?", choices=[t.value for t in Template], default=template.value)) + print() directory.mkdir(exist_ok=True) @@ -69,12 +80,62 @@ def init( AgentConfig.read(directory) if Prompt.ask(":axe: Agent config exists, overwrite?", choices=["y", "n"], default="n") == "n": return + print() except Exception: pass - AgentConfig(project_name=project_name, strike=strike).write(directory=directory) + context = {"project_name": project_name, "strike": strike_response} + + if source is None: + # initialize from builtin template + template = Template(Prompt.ask("Template?", choices=[t.value for t in Template], default=template.value)) + + install_template(template, directory, context) + else: + source_dir = pathlib.Path(source) + cleanup = False + + if not source_dir.exists(): + # source is not a local folder, so it can be: + # - full ZIP archive URL + # - github compatible reference + + try: + github_repo = GithubRepo(source) + + # Check if the repo is accessible + if repo_exists(github_repo): + source_dir = download_and_unzip_archive(github_repo.zip_url) + + # This could be a private repo that the user can access + # by getting an access token from our API + elif github_repo.namespace == "dreadnode" and ( + github_access_token := client.get_github_access_token([github_repo.repo]) + ): + print(":key: Accessed private repository") + source_dir = download_and_unzip_archive( + github_repo.api_zip_url, headers={"Authorization": f"Bearer {github_access_token.token}"} + ) - install_template(template, directory, {"project_name": project_name, "strike": strike_response}) + else: + raise Exception(f"Repository '{github_repo}' not found or inaccessible") + + except ValueError: + source_dir = download_and_unzip_archive(source) + + # make sure the temporary directory is cleaned up + cleanup = True + + try: + # initialize from local folder, validation performed inside install_template_from_dir + install_template_from_dir(source_dir, directory, context) + except Exception: + if cleanup and source_dir.exists(): + shutil.rmtree(source_dir) + raise + + # Wait to write this until after the template is installed + AgentConfig(project_name=project_name, strike=strike).write(directory=directory) print() print(f"Initialized [b]{directory}[/]") diff --git a/dreadnode_cli/agent/templates/__init__.py b/dreadnode_cli/agent/templates/__init__.py index d2b5c56..8d5934d 100644 --- a/dreadnode_cli/agent/templates/__init__.py +++ b/dreadnode_cli/agent/templates/__init__.py @@ -26,20 +26,59 @@ def template_description(template: Template) -> str: def install_template(template: Template, dest: pathlib.Path, context: dict[str, t.Any]) -> None: """Install a template into a directory.""" - src = TEMPLATES_DIR / template.value - env = Environment(loader=FileSystemLoader(src)) + install_template_from_dir(TEMPLATES_DIR / template.value, dest, context) + + +def install_template_from_dir(src: pathlib.Path, dest: pathlib.Path, context: dict[str, t.Any]) -> None: + """Install a template from a source directory into a destination directory.""" + + if not src.exists(): + raise Exception(f"Source directory '{src}' does not exist") + + elif not src.is_dir(): + raise Exception(f"Source '{src}' is not a directory") - for src_item in src.iterdir(): - dest_item = dest / src_item.name - content = src_item.read_text() + elif not (src / "Dockerfile").exists() and not (src / "Dockerfile.j2").exists(): + # if src has been downloaded from a ZIP archive, it may contain a single + # 'project-main' folder, that is the actual source we want to use. + # Check if src contains only one folder and update it if so. + subdirs = [d for d in src.iterdir() if d.is_dir()] + if len(subdirs) == 1: + src = subdirs[0] - if src_item.name.endswith(".j2"): - j2_template = env.get_template(src_item.name) - content = j2_template.render(context) - dest_item = dest / src_item.name.removesuffix(".j2") + # check again for Dockerfile in the subdirectory + if not (src / "Dockerfile").exists() and not (src / "Dockerfile.j2").exists(): + raise Exception("Source directory does not contain a Dockerfile") - if dest_item.exists(): - if Prompt.ask(f":axe: Overwrite {dest_item.name}?", choices=["y", "n"], default="n") == "n": + env = Environment(loader=FileSystemLoader(src)) + + # iterate over all items in the source directory + for src_item in src.glob("**/*"): + # get the relative path of the item + src_item_path = str(src_item.relative_to(src)) + # get the destination path + dest_item = dest / src_item_path + + # if the destination item is not the root directory and it exists, + # ask the user if they want to overwrite it + if dest_item != dest and dest_item.exists(): + if Prompt.ask(f":axe: Overwrite {dest_item}?", choices=["y", "n"], default="n") == "n": continue - dest_item.write_text(content) + # if the source item is a file + if src_item.is_file(): + # if the file has a .j2 extension, render it using Jinja2 + if src_item.name.endswith(".j2"): + # we can read as text + content = src_item.read_text() + j2_template = env.get_template(src_item_path) + content = j2_template.render(context) + dest_item = dest / src_item_path.removesuffix(".j2") + dest_item.write_text(content) + else: + # otherwise, copy the file as is + dest_item.write_bytes(src_item.read_bytes()) + + # if the source item is a directory, create it in the destination + elif src_item.is_dir(): + dest_item.mkdir(exist_ok=True) diff --git a/dreadnode_cli/agent/tests/test_templates.py b/dreadnode_cli/agent/tests/test_templates.py index 766094f..6a79830 100644 --- a/dreadnode_cli/agent/tests/test_templates.py +++ b/dreadnode_cli/agent/tests/test_templates.py @@ -1,6 +1,8 @@ import pathlib from unittest.mock import patch +import pytest + from dreadnode_cli.agent import templates @@ -11,3 +13,139 @@ def test_templates_install(tmp_path: pathlib.Path) -> None: assert (tmp_path / "requirements.txt").exists() assert (tmp_path / "Dockerfile").exists() assert (tmp_path / "agent.py").exists() + + +def test_templates_install_from_dir(tmp_path: pathlib.Path) -> None: + templates.install_template_from_dir(templates.TEMPLATES_DIR / "rigging_basic", tmp_path, {"name": "World"}) + + assert (tmp_path / "requirements.txt").exists() + assert (tmp_path / "Dockerfile").exists() + assert (tmp_path / "agent.py").exists() + + +def test_templates_install_from_dir_with_dockerfile_template(tmp_path: pathlib.Path) -> None: + # create source directory + source_dir = tmp_path / "source" + source_dir.mkdir() + + # create a Dockerfile.j2 template + dockerfile_content = """ +FROM python:3.9 +WORKDIR /app +ENV APP_NAME={{name}} +COPY . . +CMD ["python", "app.py"] +""" + (source_dir / "Dockerfile.j2").write_text(dockerfile_content) + + # create destination directory + dest_dir = tmp_path / "dest" + dest_dir.mkdir() + + # install template + templates.install_template_from_dir(source_dir, dest_dir, {"name": "TestContainer"}) + + # verify Dockerfile was rendered correctly + expected_dockerfile = """ +FROM python:3.9 +WORKDIR /app +ENV APP_NAME=TestContainer +COPY . . +CMD ["python", "app.py"] +""" + assert (dest_dir / "Dockerfile").exists() + assert (dest_dir / "Dockerfile").read_text().strip() == expected_dockerfile.strip() + + +def test_templates_install_from_dir_nested_structure(tmp_path: pathlib.Path) -> None: + # create source directory with nested structure + source_dir = tmp_path / "source" + source_dir.mkdir() + + # create some regular files + (source_dir / "Dockerfile").touch() + (source_dir / "README.md").write_text("# Test Project") + + # create nested folders with files + config_dir = source_dir / "config" + config_dir.mkdir() + (config_dir / "settings.json").write_text('{"debug": true}') + + templates_dir = source_dir / "templates" + templates_dir.mkdir() + (templates_dir / "base.html.j2").write_text("Hello {{name}}!") + + src_dir = source_dir / "src" + src_dir.mkdir() + (src_dir / "main.py").touch() + + # deeper nested folder + utils_dir = src_dir / "utils" + utils_dir.mkdir() + (utils_dir / "helpers.py").touch() + (utils_dir / "config.py.j2").write_text("APP_NAME = '{{name}}'") + + # create destination directory + dest_dir = tmp_path / "dest" + dest_dir.mkdir() + + # install template + templates.install_template_from_dir(source_dir, dest_dir, {"name": "TestApp"}) + + # verify regular files were copied + assert (dest_dir / "Dockerfile").exists() + assert (dest_dir / "README.md").read_text() == "# Test Project" + + # verify nested structure and files + assert (dest_dir / "config" / "settings.json").read_text() == '{"debug": true}' + assert (dest_dir / "src" / "main.py").exists() + assert (dest_dir / "src" / "utils" / "helpers.py").exists() + + # verify j2 templates were rendered correctly + assert (dest_dir / "templates" / "base.html").read_text() == "Hello TestApp!" + assert (dest_dir / "src" / "utils" / "config.py").read_text() == "APP_NAME = 'TestApp'" + + +def test_templates_install_from_dir_missing_source(tmp_path: pathlib.Path) -> None: + source_dir = tmp_path / "nonexistent" + with pytest.raises(Exception, match="Source directory '.*' does not exist"): + templates.install_template_from_dir(source_dir, tmp_path, {"name": "World"}) + + +def test_templates_install_from_dir_source_is_file(tmp_path: pathlib.Path) -> None: + source_file = tmp_path / "source.txt" + source_file.touch() + + with pytest.raises(Exception, match="Source '.*' is not a directory"): + templates.install_template_from_dir(source_file, tmp_path, {"name": "World"}) + + +def test_templates_install_from_dir_missing_dockerfile(tmp_path: pathlib.Path) -> None: + source_dir = tmp_path / "source" + source_dir.mkdir() + (source_dir / "agent.py").touch() + + with pytest.raises(Exception, match="Source directory does not contain a Dockerfile"): + templates.install_template_from_dir(source_dir, tmp_path, {"name": "World"}) + + +def test_templates_install_from_dir_single_inner_folder(tmp_path: pathlib.Path) -> None: + # create a source directory with a single inner folder to simulate a github zip archive + source_dir = tmp_path / "source" + source_dir.mkdir() + inner_dir = source_dir / "project-main" + inner_dir.mkdir() + + # create a Dockerfile in the inner directory + (inner_dir / "Dockerfile").touch() + (inner_dir / "agent.py").touch() + + dest_dir = tmp_path / "dest" + dest_dir.mkdir() + + # install from the outer directory - should detect and use inner directory + templates.install_template_from_dir(source_dir, dest_dir, {"name": "World"}) + + # assert files were copied from inner directory + assert (dest_dir / "Dockerfile").exists() + assert (dest_dir / "agent.py").exists() diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index 2142aaa..692b730 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -229,6 +229,19 @@ def submit_challenge_flag(self, challenge: str, flag: str) -> bool: response = self.request("POST", f"/api/challenges/{challenge}/submit-flag", json_data={"flag": flag}) return bool(response.json().get("correct", False)) + # Github + + class GithubTokenResponse(BaseModel): + token: str + expires_at: datetime + repos: list[str] + + def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse | None: + """Try to get a GitHub access token for the given repositories.""" + + response = self.request("POST", "/api/github/token", json_data={"repos": repos}) + return self.GithubTokenResponse(**response.json()) if response.status_code == 200 else None + # Strikes StrikeRunStatus = t.Literal[ diff --git a/dreadnode_cli/tests/test_types.py b/dreadnode_cli/tests/test_types.py new file mode 100644 index 0000000..1d40da6 --- /dev/null +++ b/dreadnode_cli/tests/test_types.py @@ -0,0 +1,203 @@ +import pytest + +from dreadnode_cli.types import GithubRepo + + +def test_github_repo_simple_format() -> None: + repo = GithubRepo("owner/repo") + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "main" + assert str(repo) == "owner/repo@main" + + +def test_github_repo_simple_format_with_ref() -> None: + repo = GithubRepo("owner/repo/tree/develop") + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "develop" + assert str(repo) == "owner/repo@develop" + + +@pytest.mark.parametrize( + "case", + [ + "https://github.com/owner/repo", + "http://github.com/owner/repo", + "https://github.com/owner/repo.git", + ], +) +def test_github_repo_https_url(case: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "main" + assert str(repo) == "owner/repo@main" + + +@pytest.mark.parametrize( + "case, expected_ref", + [ + ("https://github.com/owner/repo/tree/feature/custom-branch", "feature/custom-branch"), + ("https://github.com/owner/repo/blob/feature/custom-branch", "feature/custom-branch"), + ], +) +def test_github_repo_https_url_with_ref(case: str, expected_ref: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == expected_ref + assert str(repo) == f"owner/repo@{expected_ref}" + + +@pytest.mark.parametrize( + "case", + [ + "git@github.com:owner/repo", + "git@github.com:owner/repo.git", + ], +) +def test_github_repo_ssh_url(case: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "main" + assert str(repo) == "owner/repo@main" + + +@pytest.mark.parametrize( + "case, expected_ref", + [ + ("https://raw.githubusercontent.com/owner/repo/main", "main"), + ("https://raw.githubusercontent.com/owner/repo/feature-branch", "feature-branch"), + ("https://raw.githubusercontent.com/owner/repo/feature/branch", "feature/branch"), + ], +) +def test_github_repo_raw_githubusercontent(case: str, expected_ref: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == expected_ref + assert str(repo) == f"owner/repo@{expected_ref}" + + +@pytest.mark.parametrize( + "input_str, expected_ref", + [ + ("owner/repo/tree/feature/custom", "feature/custom"), + ("owner/repo/releases/tag/v1.0.0", "v1.0.0"), + ], +) +def test_github_repo_ref_handling(input_str: str, expected_ref: str) -> None: + """Test handling of different reference formats""" + repo = GithubRepo(input_str) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == expected_ref + assert repo.zip_url == f"https://github.com/owner/repo/zipball/{expected_ref}" + + +@pytest.mark.parametrize( + "case", + [ + "owner/repo.js", + "https://github.com/owner/repo.js", + "git@github.com:owner/repo.js.git", + ], +) +def test_github_repo_with_dots(case: str) -> None: + """Test repositories with dots in names""" + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo.js" + assert str(repo) == "owner/repo.js@main" + + +@pytest.mark.parametrize( + "case", + [ + "owner-name/repo-name", + "https://github.com/owner-name/repo-name", + "git@github.com:owner-name/repo-name.git", + ], +) +def test_github_repo_with_dashes(case: str) -> None: + """Test repositories with dashes in names""" + repo = GithubRepo(case) + assert repo.namespace == "owner-name" + assert repo.repo == "repo-name" + assert str(repo) == "owner-name/repo-name@main" + + +@pytest.mark.parametrize( + "case", + [ + " owner/repo ", + "\nowner/repo\n", + "\towner/repo\t", + ], +) +def test_github_repo_whitespace_handling(case: str) -> None: + """Test that whitespace is properly stripped""" + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert str(repo) == "owner/repo@main" + + +@pytest.mark.parametrize( + "case", + [ + "", # Empty string + "owner", # Missing repo + "owner/", # Missing repo + "/repo", # Missing owner + "owner/repo/extra", # Too many parts + "http://gitlab.com/owner/repo", # Wrong domain + "git@gitlab.com:owner/repo.git", # Wrong domain + ], +) +def test_github_repo_invalid_formats(case: str) -> None: + """Test that invalid formats raise ValueError""" + with pytest.raises(ValueError, match="Invalid GitHub repository format"): + GithubRepo(case) + + +def test_github_repo_string_methods_inheritance() -> None: + """Test that string methods work as expected""" + repo = GithubRepo("owner/repo") + assert repo.upper() == "OWNER/REPO@MAIN" + assert repo.split("/") == ["owner", "repo@main"] + assert repo.split("@") == ["owner/repo", "main"] + assert repo.replace("owner", "newowner") == "newowner/repo@main" + assert len(repo) == len("owner/repo@main") + + +def test_github_repo_comparisons() -> None: + """Test comparison operations""" + repo1 = GithubRepo("owner/repo") + repo2 = GithubRepo("owner/repo") + repo3 = GithubRepo("different/repo") + + assert repo1 == repo2 + assert repo1 != repo3 + assert repo1 == "owner/repo@main" + assert "owner/repo@main" == repo1 + assert repo1 in ["owner/repo@main", "other/repo@main"] + + +def test_github_repo_self_format() -> None: + """Test that GithubRepo can handle its own string representations""" + # Test basic format + repo1 = GithubRepo("owner/repo@main") + assert repo1.namespace == "owner" + assert repo1.repo == "repo" + assert repo1.ref == "main" + assert str(repo1) == "owner/repo@main" + + # Test creating from existing repo string + repo2 = GithubRepo(str(repo1)) + assert repo2.namespace == "owner" + assert repo2.repo == "repo" + assert repo2.ref == "main" + assert str(repo2) == str(repo1) diff --git a/dreadnode_cli/tests/test_utils.py b/dreadnode_cli/tests/test_utils.py index ed74bf4..d115742 100644 --- a/dreadnode_cli/tests/test_utils.py +++ b/dreadnode_cli/tests/test_utils.py @@ -1,6 +1,38 @@ +import os +import pathlib +import typing as t +import zipfile +from collections.abc import Generator from datetime import datetime, timedelta -from dreadnode_cli.utils import parse_jwt_token_expiration, time_to +import httpx +import pytest + +from dreadnode_cli.utils import ( + download_and_unzip_archive, + parse_jwt_token_expiration, + time_to, +) + + +# Mock the httpx.stream context manager +class MockResponse: + def __init__(self, zip_path: pathlib.Path): + self.status_code = 200 + with open(zip_path, "rb") as f: + self.content = f.read() + + def raise_for_status(self) -> None: + pass + + def iter_bytes(self, chunk_size: int) -> Generator[bytes, None, None]: + yield self.content + + def __enter__(self) -> "MockResponse": + return self + + def __exit__(self, *args: t.Any, **kwargs: t.Any) -> None: + pass def test_time_to() -> None: @@ -29,3 +61,55 @@ def test_parse_jwt_token_expiration() -> None: exp_date = parse_jwt_token_expiration(token) assert isinstance(exp_date, datetime) assert exp_date == datetime.fromtimestamp(1708656000) + + +def test_download_and_unzip_archive_success(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> None: + # create a mock zip file with test content + test_file_content = b"test content" + zip_path = tmp_path / "test.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("test.txt", test_file_content) + + monkeypatch.setattr(httpx, "stream", lambda *args, **kw: MockResponse(zip_path)) + + # test successful download and extraction + output_dir = download_and_unzip_archive("http://test.com/archive.zip") + extracted_file = pathlib.Path(output_dir) / "test.txt" + + assert os.path.exists(output_dir) + assert extracted_file.exists() + assert extracted_file.read_bytes() == test_file_content + + +def test_download_and_unzip_archive_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + def mock_stream(*args: t.Any, **kwargs: t.Any) -> MockResponse: + raise httpx.HTTPError("404 Not Found") + + monkeypatch.setattr(httpx, "stream", mock_stream) + + with pytest.raises(httpx.HTTPError, match="404 Not Found"): + download_and_unzip_archive("http://test.com/nonexistent.zip") + + +def test_download_and_unzip_archive_invalid_zip(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> None: + # create a mock file that's not a valid zip + invalid_zip = tmp_path / "invalid.zip" + invalid_zip.write_bytes(b"not a zip file") + + monkeypatch.setattr(httpx, "stream", lambda *args, **kw: MockResponse(invalid_zip)) + + with pytest.raises(zipfile.BadZipFile): + download_and_unzip_archive("http://test.com/invalid.zip") + + +def test_download_and_unzip_archive_path_traversal(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> None: + # create a mock zip file with path traversal attempt + test_file_content = b"test content" + zip_path = tmp_path / "traversal.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("../test.txt", test_file_content) + + monkeypatch.setattr(httpx, "stream", lambda *args, **kw: MockResponse(zip_path)) + + with pytest.raises(Exception, match="Attempted Path Traversal Attack Detected"): + download_and_unzip_archive("http://test.com/archive.zip") diff --git a/dreadnode_cli/types.py b/dreadnode_cli/types.py new file mode 100644 index 0000000..d394fd1 --- /dev/null +++ b/dreadnode_cli/types.py @@ -0,0 +1,136 @@ +import re +import typing as t + + +class GithubRepo(str): + """ + A string subclass that normalizes various GitHub repository string formats. + + Supported formats: + - Full URLs: https://github.com/owner/repo + - SSH URLs: git@github.com:owner/repo.git + - Simple format: owner/repo + - With ref: owner/repo/tree/main + - With complex ref: owner/repo/tree/feature/custom + - With ref (URL): https://github.com/owner/repo/tree/main + - With .git: owner/repo.git + - Raw URLs: https://raw.githubusercontent.com/owner/repo/main + - Release URLs: owner/repo/releases/tag/v1.0.0 + - ZIP URLs: https://github.com/owner/repo/zipball/main + - Own format: owner/repo@ref + """ + + # Instance properties + namespace: str + repo: str + ref: str + + # Regex patterns + SSH_PATTERN = re.compile(r"git@github\.com:([^/]+)/([^/]+?)(\.git)?$") + SIMPLE_PATTERN = re.compile(r"^([^/]+)/([^/]+?)(\.git)?$") + URL_PATTERN = re.compile(r"github\.com/([^/]+)/([^/]+?)(?:\.git|/(?:tree|blob)/(.+?))?$") + RAW_PATTERN = re.compile(r"raw\.githubusercontent\.com/([^/]+)/([^/]+)/(.+)") + RELEASE_PATTERN = re.compile(r"([^/]+)/([^/]+)/releases/tag/(.+)$") + OWN_FORMAT_PATTERN = re.compile(r"^([^/]+)/([^/@:]+)@(.+)$") + ZIPBALL_PATTERN = re.compile(r"github\.com/([^/]+)/([^/]+?)/zipball/(.+)$") + + def __new__(cls, value: t.Any, *args: t.Any, **kwargs: t.Any) -> "GithubRepo": + if not isinstance(value, str): + return super().__new__(cls, str(value)) + + namespace = None + repo = None + ref = "main" + + value = value.strip() + + # Try our own format first (owner/repo@ref) + match = cls.OWN_FORMAT_PATTERN.match(value) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) + + # Try as an SSH URL + elif value.startswith("git@"): + match = cls.SSH_PATTERN.search(value) + if match: + namespace, repo = match.group(1), match.group(2) + + # Try as a full URL + elif value.startswith(("http://", "https://")): + url_parts = value.split("//", 1)[1] + + # Try zipball pattern first + match = cls.ZIPBALL_PATTERN.search(url_parts) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) + + # Try raw githubusercontent pattern + elif url_parts.startswith("raw.githubusercontent.com"): + match = cls.RAW_PATTERN.search(url_parts) + if match: + namespace, repo, ref = match.group(1), match.group(2), match.group(3) + + # Try standard GitHub URL pattern + else: + match = cls.URL_PATTERN.search(url_parts) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) or ref + + # Try release tag format + elif "/releases/tag/" in value: + match = cls.RELEASE_PATTERN.match(value) + if match: + namespace, repo, ref = match.group(1), match.group(2), match.group(3) + + # Try simple owner/repo format + else: + # First try to extract any ref + tree_parts = value.split("/tree/") + blob_parts = value.split("/blob/") + + if len(tree_parts) > 1: + value, ref = tree_parts[0], tree_parts[1] + elif len(blob_parts) > 1: + value, ref = blob_parts[0], blob_parts[1] + + # Now check for owner/repo pattern + match = cls.SIMPLE_PATTERN.match(value) + if match: + namespace, repo = match.group(1), match.group(2) + + if not namespace or not repo: + raise ValueError(f"Invalid GitHub repository format: {value}") + + repo = repo.removesuffix(".git") + + obj = super().__new__(cls, f"{namespace}/{repo}@{ref}") + + obj.namespace = namespace + obj.repo = repo + obj.ref = ref + + return obj + + @property + def zip_url(self) -> str: + """ZIP archive URL for the repository.""" + return f"https://github.com/{self.namespace}/{self.repo}/zipball/{self.ref}" + + @property + def api_zip_url(self) -> str: + """API ZIP archive URL for the repository.""" + return f"https://api.github.com/repos/{self.namespace}/{self.repo}/zipball/{self.ref}" + + @property + def tree_url(self) -> str: + """URL to view the tree at this reference.""" + return f"https://github.com/{self.namespace}/{self.repo}/tree/{self.ref}" + + def __repr__(self) -> str: + return f"GithubRepo(namespace='{self.namespace}', repo='{self.repo}', ref='{self.ref}')" diff --git a/dreadnode_cli/utils.py b/dreadnode_cli/utils.py index 4e92979..2a910dd 100644 --- a/dreadnode_cli/utils.py +++ b/dreadnode_cli/utils.py @@ -1,13 +1,19 @@ import base64 import functools import json +import os +import pathlib import sys +import tempfile import typing as t +import zipfile from datetime import datetime +import httpx from rich import print from dreadnode_cli.defaults import DEBUG +from dreadnode_cli.types import GithubRepo P = t.ParamSpec("P") R = t.TypeVar("R") @@ -60,3 +66,44 @@ def parse_jwt_token_expiration(token: str) -> datetime: _, b64payload, _ = token.split(".") payload = base64.urlsafe_b64decode(b64payload + "==").decode("utf-8") return datetime.fromtimestamp(json.loads(payload).get("exp")) + + +def repo_exists(repo: GithubRepo) -> bool: + """Check if a repo exists (or is private) on GitHub.""" + response = httpx.get(f"https://github.com/repos/{repo.namespace}/{repo.repo}") + return response.status_code == 200 + + +def download_and_unzip_archive(url: str, *, headers: dict[str, str] | None = None) -> pathlib.Path: + """ + Downloads a ZIP archive from the given URL and unzips it into a temporary directory. + """ + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + local_zip_path = pathlib.Path(os.path.join(temp_dir, "archive.zip")) + + print(f":arrow_double_down: Downloading {url} ...") + + # download to temporary file + with httpx.stream("GET", url, follow_redirects=True, verify=True, headers=headers) as response: + response.raise_for_status() + with open(local_zip_path, "wb") as zip_file: + for chunk in response.iter_bytes(chunk_size=8192): + zip_file.write(chunk) + + # unzip to temporary directory + try: + with zipfile.ZipFile(local_zip_path, "r") as zf: + for member in zf.infolist(): + file_path = os.path.realpath(os.path.join(temp_dir, member.filename)) + if file_path.startswith(os.path.realpath(temp_dir)): + zf.extract(member, temp_dir) + else: + raise Exception("Attempted Path Traversal Attack Detected") + + finally: + # always remove the zip file + if local_zip_path.exists(): + os.remove(local_zip_path) + + return temp_dir