Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 54 additions & 14 deletions sagemaker-train/tests/unit/ai_registry/test_dataset_domain_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,58 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for domain-id tagging in DataSet."""
import json
import tempfile
import os
import pytest
from unittest.mock import Mock, patch, MagicMock
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique


# One GSM8K-style record laid out in the RLVR dataset format. Key order is
# kept stable because the record is serialized as-is into the test fixture file.
SAMPLE_DATASET = {
    "data_source": "openai/gsm8k",
    "prompt": [
        {
            "content": (
                "Natalia sold clips to 48 of her friends in April, and then "
                "she sold half as many clips in May. How many clips did "
                "Natalia sell altogether in April and May? Let's think step "
                "by step and output the final answer after \"####\"."
            ),
            "role": "user",
        },
    ],
    "ability": "math",
    "reward_model": {"ground_truth": "72", "style": "rule"},
    "extra_info": {
        "answer": (
            "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
            "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
            "#### 72"
        ),
        "index": 0,
        "question": (
            "Natalia sold clips to 48 of her friends in April, and then "
            "she sold half as many clips in May. How many clips did "
            "Natalia sell altogether in April and May?"
        ),
        "split": "train",
    },
}


@pytest.fixture
def sample_dataset_file():
    """Yield the path of a temporary ``.jsonl`` file holding ``SAMPLE_DATASET``.

    The file is created with ``delete=False`` so it outlives the writing
    handle and can be reopened by path by the code under test. It is removed
    again during fixture teardown if it still exists.
    """
    handle = tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False)
    with handle:
        # One JSON document, no trailing newline -- same bytes json.dump emits.
        handle.write(json.dumps(SAMPLE_DATASET))
        dataset_path = handle.name

    yield dataset_path

    # Teardown: the test may already have removed the file, so check first.
    if not os.path.exists(dataset_path):
        return
    os.remove(dataset_path)


class TestDataSetDomainId:
"""Test domain-id is added to SearchKeywords when available."""

@patch('sagemaker.core.helper.session_helper.Session')
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
@patch('sagemaker.ai_registry.dataset.AIRHub')
@patch('sagemaker.ai_registry.dataset.validate_dataset')
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
def test_domain_id_added_when_available(
self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
):
"""Test that domain-id is added to tags when available."""
# Setup mocks
mock_domain_id = "d-test123456"
mock_get_domain_id.return_value = mock_domain_id
mock_session.return_value = Mock()
mock_session_instance = Mock()
mock_session.return_value = mock_session_instance
mock_get_session.return_value = mock_session_instance
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"

# Mock AIRHub methods
mock_air_hub.upload_to_s3 = Mock()
Expand All @@ -46,11 +77,11 @@ def test_domain_id_added_when_available(
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
})

# Create dataset
# Create dataset with real file
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
dataset = DataSet.create(
name="test-dataset",
source="test-data.jsonl",
source=sample_dataset_file,
customization_technique=CustomizationTechnique.SFT
)

Expand All @@ -67,14 +98,18 @@ def test_domain_id_added_when_available(
@patch('sagemaker.core.helper.session_helper.Session')
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
@patch('sagemaker.ai_registry.dataset.AIRHub')
@patch('sagemaker.ai_registry.dataset.validate_dataset')
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
def test_domain_id_not_added_when_unavailable(
self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
):
"""Test that domain-id is not added when unavailable (non-Studio)."""
# Setup mocks - domain_id returns None
mock_get_domain_id.return_value = None
mock_session.return_value = Mock()
mock_session_instance = Mock()
mock_session.return_value = mock_session_instance
mock_get_session.return_value = mock_session_instance
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"

# Mock AIRHub methods
mock_air_hub.upload_to_s3 = Mock()
Expand All @@ -89,11 +124,11 @@ def test_domain_id_not_added_when_unavailable(
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
})

# Create dataset
# Create dataset with real file
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
dataset = DataSet.create(
name="test-dataset",
source="test-data.jsonl",
source=sample_dataset_file,
customization_technique=CustomizationTechnique.SFT
)

Expand All @@ -110,14 +145,19 @@ def test_domain_id_not_added_when_unavailable(
@patch('sagemaker.core.helper.session_helper.Session')
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
@patch('sagemaker.ai_registry.dataset.AIRHub')
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
def test_domain_id_added_without_customization_technique(
self, mock_air_hub, mock_get_domain_id, mock_session
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
):
"""Test that domain-id is added even without customization_technique."""
# Setup mocks
mock_domain_id = "d-test789"
mock_get_domain_id.return_value = mock_domain_id
mock_session.return_value = Mock()
mock_session_instance = Mock()
mock_session.return_value = mock_session_instance
mock_get_session.return_value = mock_session_instance
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"

# Mock AIRHub methods
mock_air_hub.upload_to_s3 = Mock()
Expand All @@ -132,11 +172,11 @@ def test_domain_id_added_without_customization_technique(
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
})

# Create dataset WITHOUT customization_technique
# Create dataset WITHOUT customization_technique using real file
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
dataset = DataSet.create(
name="test-dataset",
source="test-data.jsonl"
source=sample_dataset_file
# No customization_technique
)

Expand Down
Loading