diff --git a/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir new file mode 100644 index 0000000000..feb4a88cc4 --- /dev/null +++ b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir @@ -0,0 +1,15 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.4.1' %s | FileCheck %s + +// AllToAll was in the initial StableHLO opset, but changed in v1.5.0 to have +// tuple arguments. Ensure that serializing for 1.4.1 is valid and targets the +// v1.4.0 opset. +// +// This will catch issues in op `isLegal` checks: +// op.minVersion() <= target <= op.maxVersion() + +// CHECK-LABEL: vhlo.func_v1 @all_to_all +func.func public @all_to_all(%arg0: tensor<8x8x1xui16>) -> tensor<1x8x8xui16> { + // CHECK: vhlo.all_to_all_v1 + %0 = "stablehlo.all_to_all"(%arg0) <{concat_dimension = 2 : i64, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, split_count = 8 : i64, split_dimension = 0 : i64}> : (tensor<8x8x1xui16>) -> tensor<1x8x8xui16> + return %0 : tensor<1x8x8xui16> +} diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp index 02b956617e..e1d2225587 100644 --- a/stablehlo/transforms/VhloToVersion.cpp +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -92,6 +92,13 @@ FailureOr validateTargetVersion(llvm::StringRef versionRef, << " is greater than current version " << Version::getCurrentVersion(); + // Opset changes warrant a minor version bump, so this conversion assumes + // patch v0 since it is written against the opset at version `X.Y.0`. + if (targetVersion.getPatch() != 0) { + targetVersion = + vhlo::Version(targetVersion.getMajor(), targetVersion.getMinor(), 0); + } + return targetVersion; }