diff --git a/osfclient/models/session.py b/osfclient/models/session.py index ab0a405..5153af7 100644 --- a/osfclient/models/session.py +++ b/osfclient/models/session.py @@ -1,4 +1,6 @@ import os +from contextlib import asynccontextmanager + import httpx from ..exceptions import UnauthorizedException @@ -57,8 +59,59 @@ async def put(self, url, *args, **kwargs): return response def stream(self, method, url, *args, **kwargs): - kwargs_ = self.modify_kwargs(kwargs) - return super(OSFSession, self).stream(method, url, *args, **kwargs_) + """Stream with safe redirect handling. + + When WaterButler returns a redirect to S3 (or other external storage), + we must not forward Content-Type and Accept headers because they would + break S3's presigned URL signature validation. + """ + return self._stream_with_safe_redirect(method, url, *args, **kwargs) + + @asynccontextmanager + async def _stream_with_safe_redirect(self, method, url, *args, **kwargs): + """Stream with redirect handling that strips API headers. + + - 301, 302, 303: Follow as GET with minimal headers + - 307, 308: Raise error (WB providers never use these status codes) + """ + kwargs_no_redirect = kwargs.copy() + kwargs_no_redirect.update(dict(follow_redirects=False)) + + redirect_location = None + async with super(OSFSession, self).stream( + method, url, *args, **kwargs_no_redirect + ) as response: + if response.status_code in (301, 302, 303): + if response.headers.get('location'): + redirect_location = response.headers.get('location') + else: + yield response + return + elif response.status_code in (307, 308): + raise RuntimeError( + f"HTTP {response.status_code} redirect is not supported." + ) + else: + yield response + return + + async with self._follow_redirect(redirect_location) as redirected_response: + yield redirected_response + + @asynccontextmanager + async def _follow_redirect(self, url: str): + """Follow a redirect with minimal headers. + + Only sends headers that won't interfere with presigned URL signatures. + """ + clean_headers = { + 'User-Agent': self.headers.get('User-Agent', 'osfclient v0.0.1'), + 'Accept-Charset': self.headers.get('Accept-Charset', 'utf-8'), + } + + async with httpx.AsyncClient(timeout=self._timeout) as clean_client: + async with clean_client.stream('GET', url, headers=clean_headers) as response: + yield response async def get(self, url, *args, **kwargs): kwargs_ = self.modify_kwargs(kwargs) diff --git a/osfclient/tests/test_session.py b/osfclient/tests/test_session.py index a4a70d9..da626f5 100644 --- a/osfclient/tests/test_session.py +++ b/osfclient/tests/test_session.py @@ -1,6 +1,5 @@ import asyncio -from mock import patch -from mock import MagicMock +from mock import patch, MagicMock, AsyncMock import pytest @@ -91,3 +90,112 @@ async def test_get(mock_get): assert response == mock_response mock_get.assert_called_once_with(url, follow_redirects=True) + + +# Tests for stream method with redirect handling + +@pytest.mark.asyncio +@patch('osfclient.models.session.httpx.AsyncClient.stream') +async def test_stream_no_redirect(mock_stream): + """When response is 200 OK, yield response directly without redirect handling.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} + + mock_stream.return_value.__aenter__ = AsyncMock(return_value=mock_response) + mock_stream.return_value.__aexit__ = AsyncMock(return_value=None) + + session = OSFSession() + async with session.stream('GET', 'http://localhost:7777/download') as response: + assert response.status_code == 200 + assert response is mock_response + + +@pytest.mark.asyncio +@patch('osfclient.models.session.httpx.AsyncClient') +@patch('osfclient.models.session.httpx.AsyncClient.stream') +async def test_stream_redirect(mock_parent_stream, mock_client_class): + """When response is a redirect, follow with clean headers.""" + # Mock the initial response (302 redirect) + mock_initial_response = MagicMock() + mock_initial_response.status_code = 302 + mock_initial_response.headers = {'location': 'https://s3.amazonaws.com/bucket/file?Signature=xxx'} + + mock_parent_stream.return_value.__aenter__ = AsyncMock(return_value=mock_initial_response) + mock_parent_stream.return_value.__aexit__ = AsyncMock(return_value=None) + + # Mock the redirect response (200 OK with content) + mock_redirect_response = MagicMock() + mock_redirect_response.status_code = 200 + + mock_redirect_stream = MagicMock() + mock_redirect_stream.return_value.__aenter__ = AsyncMock(return_value=mock_redirect_response) + mock_redirect_stream.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_client_instance = MagicMock() + mock_client_instance.stream = mock_redirect_stream + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + session = OSFSession() + async with session.stream('GET', 'http://localhost:7777/download') as response: + assert response.status_code == 200 + assert response is mock_redirect_response + + # Verify clean headers are used (no Content-Type, no Accept) + call_args = mock_redirect_stream.call_args + headers = call_args.kwargs.get('headers', {}) + assert 'User-Agent' in headers + assert 'Accept-Charset' in headers + assert 'Content-Type' not in headers + assert 'Accept' not in headers + + +@pytest.mark.asyncio +@patch('osfclient.models.session.httpx.AsyncClient') +@patch('osfclient.models.session.httpx.AsyncClient.stream') +async def test_stream_redirect_headers_not_forwarded(mock_parent_stream, mock_client_class): + """Verify that API-specific headers (Content-Type, Accept, Authorization) are NOT forwarded on redirect.""" + # Mock the initial response (302 redirect) + mock_initial_response = MagicMock() + mock_initial_response.status_code = 302 + mock_initial_response.headers = {'location': 'https://s3.amazonaws.com/bucket/file'} + + mock_parent_stream.return_value.__aenter__ = AsyncMock(return_value=mock_initial_response) + mock_parent_stream.return_value.__aexit__ = AsyncMock(return_value=None) + + # Mock the redirect response (200 OK) + mock_redirect_response = MagicMock() + mock_redirect_response.status_code = 200 + + mock_redirect_stream = MagicMock() + mock_redirect_stream.return_value.__aenter__ = AsyncMock(return_value=mock_redirect_response) + mock_redirect_stream.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_client_instance = MagicMock() + mock_client_instance.stream = mock_redirect_stream + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + session = OSFSession() + # Add Authorization header + session.token_auth('test-token') + + async with session.stream('GET', 'http://localhost:7777/download') as response: + pass + + # Verify headers passed to redirect request + call_args = mock_redirect_stream.call_args + headers = call_args.kwargs.get('headers', {}) + + # Verify headers that SHOULD be present + assert 'User-Agent' in headers + assert headers['User-Agent'] == 'osfclient v0.0.1' + assert 'Accept-Charset' in headers + + # Verify headers that MUST NOT be present (API-specific headers) + assert 'Content-Type' not in headers + assert 'Accept' not in headers + assert 'Authorization' not in headers