diff --git a/sagemaker-train/tests/unit/ai_registry/test_dataset_domain_id.py b/sagemaker-train/tests/unit/ai_registry/test_dataset_domain_id.py index 7e72064057..d1e73ee665 100644 --- a/sagemaker-train/tests/unit/ai_registry/test_dataset_domain_id.py +++ b/sagemaker-train/tests/unit/ai_registry/test_dataset_domain_id.py @@ -11,27 +11,58 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Unit tests for domain-id tagging in DataSet.""" +import json +import tempfile +import os import pytest from unittest.mock import Mock, patch, MagicMock from sagemaker.ai_registry.dataset import DataSet from sagemaker.ai_registry.dataset_utils import CustomizationTechnique +# Sample RLVR format dataset (GSM8K style) +SAMPLE_DATASET = { + "data_source": "openai/gsm8k", + "prompt": [{"content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\".", "role": "user"}], + "ability": "math", + "reward_model": {"ground_truth": "72", "style": "rule"}, + "extra_info": {"answer": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72", "index": 0, "question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?", "split": "train"} +} + + +@pytest.fixture +def sample_dataset_file(): + """Create a temporary JSONL file with sample dataset.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f: + json.dump(SAMPLE_DATASET, f) + temp_path = f.name + + yield temp_path + + # Cleanup + if os.path.exists(temp_path): + os.remove(temp_path) + + class TestDataSetDomainId: """Test domain-id is added to SearchKeywords when available.""" @patch('sagemaker.core.helper.session_helper.Session') @patch('sagemaker.ai_registry.dataset._get_current_domain_id') @patch('sagemaker.ai_registry.dataset.AIRHub') - @patch('sagemaker.ai_registry.dataset.validate_dataset') + @patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.defaults.TrainDefaults.get_role') def test_domain_id_added_when_available( - self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session + self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file ): """Test that domain-id is added to tags when available.""" # Setup mocks mock_domain_id = "d-test123456" mock_get_domain_id.return_value = mock_domain_id - mock_session.return_value = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_get_session.return_value = mock_session_instance + mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role" # Mock AIRHub methods mock_air_hub.upload_to_s3 = Mock() @@ -46,11 +77,11 @@ def test_domain_id_added_when_available( 'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}' }) - # Create dataset + # Create dataset with real file with patch('sagemaker.ai_registry.dataset.DataSet.wait'): dataset = DataSet.create( name="test-dataset", - source="test-data.jsonl", + source=sample_dataset_file, customization_technique=CustomizationTechnique.SFT ) @@ -67,14 +98,18 @@ def test_domain_id_added_when_available( @patch('sagemaker.core.helper.session_helper.Session') @patch('sagemaker.ai_registry.dataset._get_current_domain_id') @patch('sagemaker.ai_registry.dataset.AIRHub') - @patch('sagemaker.ai_registry.dataset.validate_dataset') + @patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.defaults.TrainDefaults.get_role') def test_domain_id_not_added_when_unavailable( - self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session + self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file ): """Test that domain-id is not added when unavailable (non-Studio).""" # Setup mocks - domain_id returns None mock_get_domain_id.return_value = None - mock_session.return_value = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_get_session.return_value = mock_session_instance + mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role" # Mock AIRHub methods mock_air_hub.upload_to_s3 = Mock() @@ -89,11 +124,11 @@ def test_domain_id_not_added_when_unavailable( 'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}' }) - # Create dataset + # Create dataset with real file with patch('sagemaker.ai_registry.dataset.DataSet.wait'): dataset = DataSet.create( name="test-dataset", - source="test-data.jsonl", + source=sample_dataset_file, customization_technique=CustomizationTechnique.SFT ) @@ -110,14 +145,19 @@ def test_domain_id_not_added_when_unavailable( @patch('sagemaker.core.helper.session_helper.Session') @patch('sagemaker.ai_registry.dataset._get_current_domain_id') @patch('sagemaker.ai_registry.dataset.AIRHub') + @patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.defaults.TrainDefaults.get_role') def test_domain_id_added_without_customization_technique( - self, mock_air_hub, mock_get_domain_id, mock_session + self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file ): """Test that domain-id is added even without customization_technique.""" # Setup mocks mock_domain_id = "d-test789" mock_get_domain_id.return_value = mock_domain_id - mock_session.return_value = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_get_session.return_value = mock_session_instance + mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role" # Mock AIRHub methods mock_air_hub.upload_to_s3 = Mock() @@ -132,11 +172,11 @@ def test_domain_id_added_without_customization_technique( 'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}' }) - # Create dataset WITHOUT customization_technique + # Create dataset WITHOUT customization_technique using real file with patch('sagemaker.ai_registry.dataset.DataSet.wait'): dataset = DataSet.create( name="test-dataset", - source="test-data.jsonl" + source=sample_dataset_file # No customization_technique )