Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions osfclient/models/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from contextlib import asynccontextmanager

import httpx

from ..exceptions import UnauthorizedException
Expand Down Expand Up @@ -57,8 +59,59 @@ async def put(self, url, *args, **kwargs):
return response

def stream(self, method, url, *args, **kwargs):
    """Stream with safe redirect handling.

    When WaterButler returns a redirect to S3 (or other external storage),
    we must not forward Content-Type and Accept headers because they would
    break S3's presigned URL signature validation.

    All arguments are passed through unchanged to
    ``_stream_with_safe_redirect``.

    Returns:
        The async context manager produced by
        ``_stream_with_safe_redirect``; use it with ``async with``.
    """
    # Delegate straight away: returning the parent's stream() first would
    # leave the redirect-safe path below it unreachable.
    return self._stream_with_safe_redirect(method, url, *args, **kwargs)

@asynccontextmanager
async def _stream_with_safe_redirect(self, method, url, *args, **kwargs):
    """Stream with redirect handling that strips API headers.

    - 301, 302, 303: Follow as GET with minimal headers
    - 307, 308: Raise error (WB providers never use these status codes)

    Any other status, or a 301/302/303 without a usable ``Location``
    header, yields the initial response unchanged.

    Args:
        method: HTTP method for the initial request.
        url: URL for the initial request.
        *args: Extra positional arguments for the parent ``stream``.
        **kwargs: Extra keyword arguments for the parent ``stream``;
            ``follow_redirects`` is always overridden to ``False``.

    Yields:
        The initial response, or the response obtained from the redirect
        target via ``_follow_redirect``.

    Raises:
        RuntimeError: If the initial response status is 307 or 308.
    """
    from urllib.parse import urljoin, urlsplit

    # Redirects are handled by hand below, so the underlying client must
    # never follow them itself.
    initial_kwargs = {**kwargs, 'follow_redirects': False}

    async with super(OSFSession, self).stream(
        method, url, *args, **initial_kwargs
    ) as response:
        if response.status_code in (307, 308):
            raise RuntimeError(
                f"HTTP {response.status_code} redirect is not supported."
            )

        redirect_location = None
        if response.status_code in (301, 302, 303):
            redirect_location = response.headers.get('location')

        if not redirect_location:
            # Not a redirect (or a redirect with nowhere to go): hand the
            # caller the initial response while its stream is still open.
            yield response
            return

        # Location may be a relative reference. The follow-up request is
        # sent by a separate client with no base URL, so resolve it against
        # the URL that answered. Absolute URLs are passed through untouched
        # so a presigned query string is never rewritten.
        if not urlsplit(redirect_location).scheme:
            redirect_location = urljoin(str(response.url), redirect_location)

    # The initial stream is closed by now; only the redirect target is open.
    async with self._follow_redirect(redirect_location) as redirected_response:
        yield redirected_response

@asynccontextmanager
async def _follow_redirect(self, url: str):
    """Follow a redirect with minimal headers.

    The request is sent with GET through a separate, short-lived
    ``httpx.AsyncClient`` that is given this session's timeout but none of
    its headers. Only ``User-Agent`` and ``Accept-Charset`` are passed
    explicitly, so that headers which could interfere with presigned URL
    signatures are not forwarded.
    NOTE(review): the fresh client may still add httpx's own default
    headers; confirm these are acceptable to the storage backend.

    Args:
        url: Redirect target to request.

    Yields:
        The streaming response for ``url``; it is closed, together with the
        temporary client, when the context exits.
    """
    clean_headers = {
        'User-Agent': self.headers.get('User-Agent', 'osfclient v0.0.1'),
        'Accept-Charset': self.headers.get('Accept-Charset', 'utf-8'),
    }

    # Use the public ``timeout`` property instead of reaching into httpx's
    # private ``_timeout`` attribute.
    async with httpx.AsyncClient(timeout=self.timeout) as clean_client:
        async with clean_client.stream('GET', url, headers=clean_headers) as response:
            yield response

async def get(self, url, *args, **kwargs):
kwargs_ = self.modify_kwargs(kwargs)
Expand Down
112 changes: 110 additions & 2 deletions osfclient/tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
from mock import patch
from mock import MagicMock
from mock import patch, MagicMock, AsyncMock

import pytest

Expand Down Expand Up @@ -91,3 +90,112 @@ async def test_get(mock_get):

assert response == mock_response
mock_get.assert_called_once_with(url, follow_redirects=True)


# Tests for stream method with redirect handling

@pytest.mark.asyncio
@patch('osfclient.models.session.httpx.AsyncClient.stream')
async def test_stream_no_redirect(mock_stream):
    """A 200 OK response is yielded as-is, with no redirect handling."""
    ok_response = MagicMock(status_code=200, headers={})

    # The patched parent stream() has to act as an async context manager.
    stream_cm = mock_stream.return_value
    stream_cm.__aenter__ = AsyncMock(return_value=ok_response)
    stream_cm.__aexit__ = AsyncMock(return_value=None)

    session = OSFSession()
    async with session.stream('GET', 'http://localhost:7777/download') as streamed:
        assert streamed.status_code == 200
        assert streamed is ok_response


@pytest.mark.asyncio
@patch('osfclient.models.session.httpx.AsyncClient')
@patch('osfclient.models.session.httpx.AsyncClient.stream')
async def test_stream_redirect(mock_parent_stream, mock_client_class):
    """When response is a redirect, follow it as GET with clean headers.

    Checks that the first hop is sent with automatic redirects disabled,
    that the second hop targets exactly the Location URL, and that only
    the minimal header set is passed along.
    """
    initial_url = 'http://localhost:7777/download'
    redirect_target = 'https://s3.amazonaws.com/bucket/file?Signature=xxx'

    # Mock the initial response (302 redirect)
    mock_initial_response = MagicMock()
    mock_initial_response.status_code = 302
    mock_initial_response.headers = {'location': redirect_target}

    mock_parent_stream.return_value.__aenter__ = AsyncMock(return_value=mock_initial_response)
    mock_parent_stream.return_value.__aexit__ = AsyncMock(return_value=None)

    # Mock the redirect response (200 OK with content)
    mock_redirect_response = MagicMock()
    mock_redirect_response.status_code = 200

    mock_redirect_stream = MagicMock()
    mock_redirect_stream.return_value.__aenter__ = AsyncMock(return_value=mock_redirect_response)
    mock_redirect_stream.return_value.__aexit__ = AsyncMock(return_value=None)

    # Stand-in for the short-lived client that performs the second hop
    mock_client_instance = MagicMock()
    mock_client_instance.stream = mock_redirect_stream
    mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
    mock_client_instance.__aexit__ = AsyncMock(return_value=None)
    mock_client_class.return_value = mock_client_instance

    session = OSFSession()
    async with session.stream('GET', initial_url) as response:
        assert response.status_code == 200
        assert response is mock_redirect_response

    # The first hop must not let the client follow the redirect by itself,
    # otherwise the session's API headers would be forwarded.
    mock_parent_stream.assert_called_once()
    assert mock_parent_stream.call_args.kwargs.get('follow_redirects') is False

    # The redirect must actually be followed: one GET to the Location URL.
    mock_redirect_stream.assert_called_once()
    call_args = mock_redirect_stream.call_args
    assert call_args.args == ('GET', redirect_target)

    # Verify clean headers are used (no Content-Type, no Accept)
    headers = call_args.kwargs.get('headers', {})
    assert 'User-Agent' in headers
    assert 'Accept-Charset' in headers
    assert 'Content-Type' not in headers
    assert 'Accept' not in headers


@pytest.mark.asyncio
@patch('osfclient.models.session.httpx.AsyncClient')
@patch('osfclient.models.session.httpx.AsyncClient.stream')
async def test_stream_redirect_headers_not_forwarded(mock_parent_stream, mock_client_class):
    """API-only headers (Content-Type, Accept, Authorization) must not reach the redirect target."""
    # First hop: a 302 pointing at external storage.
    redirecting = MagicMock(
        status_code=302,
        headers={'location': 'https://s3.amazonaws.com/bucket/file'},
    )
    first_hop_cm = mock_parent_stream.return_value
    first_hop_cm.__aenter__ = AsyncMock(return_value=redirecting)
    first_hop_cm.__aexit__ = AsyncMock(return_value=None)

    # Second hop: the redirect target answers 200.
    final_response = MagicMock(status_code=200)
    redirect_stream = MagicMock()
    second_hop_cm = redirect_stream.return_value
    second_hop_cm.__aenter__ = AsyncMock(return_value=final_response)
    second_hop_cm.__aexit__ = AsyncMock(return_value=None)

    # Stand-in for the client created to perform the second hop.
    clean_client = MagicMock()
    clean_client.stream = redirect_stream
    clean_client.__aenter__ = AsyncMock(return_value=clean_client)
    clean_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_class.return_value = clean_client

    session = OSFSession()
    # Give the session an Authorization header that must stay behind.
    session.token_auth('test-token')

    async with session.stream('GET', 'http://localhost:7777/download') as response:
        pass

    # Headers handed to the redirect request.
    sent_headers = redirect_stream.call_args.kwargs.get('headers', {})

    # Headers that SHOULD be present.
    assert 'User-Agent' in sent_headers
    assert sent_headers['User-Agent'] == 'osfclient v0.0.1'
    assert 'Accept-Charset' in sent_headers

    # Headers that MUST NOT be present (API-specific headers).
    for forbidden in ('Content-Type', 'Accept', 'Authorization'):
        assert forbidden not in sent_headers