diff --git a/internal/dms/pkg/constant/const.go b/internal/dms/pkg/constant/const.go index 3bf4b9056..5b7852cd8 100644 --- a/internal/dms/pkg/constant/const.go +++ b/internal/dms/pkg/constant/const.go @@ -252,6 +252,8 @@ func ParseDBType(s string) (DBType, error) { return DBTypeGaussDB, nil case "HANA": return DBTypeHANA, nil + case "PolarDB For MySQL": + return DBTypePolarDBMySQL, nil default: return "", fmt.Errorf("invalid db type: %s", s) @@ -273,6 +275,7 @@ const ( DBTypeDM DBType = "达梦(DM)" DBTypeGaussDB DBType = "GaussDB / openGauss" DBTypeHANA DBType = "HANA" + DBTypePolarDBMySQL DBType = "PolarDB For MySQL" ) var supportedDataExportDBTypes = map[DBType]struct{}{ @@ -291,6 +294,7 @@ var supportedDataExportDBTypes = map[DBType]struct{}{ DBTypeGaussDB: {}, DBTypeDB2: {}, DBTypeHANA: {}, + DBTypePolarDBMySQL: {}, } func CheckDBTypeIfDataExportSupported(dbtype string) bool { diff --git a/internal/dms/pkg/constant/const_test.go b/internal/dms/pkg/constant/const_test.go index 51b28b880..17baa17b1 100644 --- a/internal/dms/pkg/constant/const_test.go +++ b/internal/dms/pkg/constant/const_test.go @@ -14,6 +14,7 @@ func TestCheckDBTypeIfDataExportSupported_NewTypes(t *testing.T) { "GaussDB for MySQL": true, // GaussDB/openGauss: ParseDBType 的输入值是 "GaussDB for MySQL" "DB2": true, "HANA": true, + "PolarDB For MySQL": true, } for dbType, expectedSupported := range newTypes { t.Run(dbType, func(t *testing.T) { @@ -46,6 +47,42 @@ func TestCheckDBTypeIfDataExportSupported_ExistingTypes(t *testing.T) { } } +func TestParseDBType_PolarDB(t *testing.T) { + cases := map[string]struct { + input string + wantDBType DBType + wantErr bool + }{ + "valid PolarDB For MySQL": { + input: "PolarDB For MySQL", + wantDBType: DBTypePolarDBMySQL, + wantErr: false, + }, + "invalid lowercase polardb": { + input: "polardb for mysql", + wantDBType: "", + wantErr: true, + }, + "invalid partial match": { + input: "PolarDB", + wantDBType: "", + wantErr: true, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := ParseDBType(tc.input) + if (err != nil) != tc.wantErr { + t.Errorf("ParseDBType(%q) error = %v, wantErr %v", tc.input, err, tc.wantErr) + return + } + if got != tc.wantDBType { + t.Errorf("ParseDBType(%q) = %v, want %v", tc.input, got, tc.wantDBType) + } + }) + } +} + func TestCheckDBTypeIfDataExportSupported_UnsupportedTypes(t *testing.T) { // 验证未支持的类型返回 false unsupportedTypes := map[string]bool{ diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index e2ffff1b7..eb77409e2 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -937,6 +937,8 @@ func (sqlWorkbenchService *SqlWorkbenchService) convertDBType(dmsDBType string) return "MYSQL" case "GoldenDB": return "MYSQL" + case "PolarDB For MySQL": + return "MYSQL" default: return dmsDBType } diff --git a/internal/sql_workbench/service/sql_workbench_service_test.go b/internal/sql_workbench/service/sql_workbench_service_test.go index 60d5784b0..ce6673c02 100644 --- a/internal/sql_workbench/service/sql_workbench_service_test.go +++ b/internal/sql_workbench/service/sql_workbench_service_test.go @@ -22,6 +22,7 @@ func Test_convertDBType(t *testing.T) { "TiDB": {input: "TiDB", expected: "TIDB"}, "TDSQL For InnoDB": {input: "TDSQL For InnoDB", expected: "MYSQL"}, "GoldenDB": {input: "GoldenDB", expected: "MYSQL"}, + "PolarDB For MySQL": {input: "PolarDB For MySQL", expected: "MYSQL"}, "Unknown passthrough": {input: "UnknownDB", expected: "UnknownDB"}, } for name, tc := range cases { @@ -49,6 +50,7 @@ func Test_SupportDBType(t *testing.T) { "GoldenDB supported": {input: pkgConst.DBTypeGoldenDB, expected: true}, "PostgreSQL unsupported": {input: pkgConst.DBTypePostgreSQL, expected: false}, "SQL Server unsupported": {input: pkgConst.DBTypeSQLServer, expected: false}, + "PolarDB MySQL unsupported": {input: pkgConst.DBTypePolarDBMySQL, expected: false}, } for name, tc := range cases { t.Run(name, func(t *testing.T) {