diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 88b40f3..2779fd5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: name: test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-go@v4 with: go-version: '1.21.0' diff --git a/go.mod b/go.mod index 99cab88..9d96bbc 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,9 @@ module github.com/zztkm/sqlc-gen-python-orm go 1.19 require ( - buf.build/gen/go/sqlc/sqlc/protocolbuffers/go v1.30.0-20230621221448-196413f69ab3.1 + buf.build/gen/go/sqlc/sqlc/protocolbuffers/go v1.31.0-20231002190240-3f2d312ab6fd.1 github.com/google/go-cmp v0.5.9 github.com/jinzhu/inflection v1.0.0 - github.com/sqlc-dev/sqlc-go v1.18.1 - google.golang.org/protobuf v1.30.0 + github.com/sqlc-dev/sqlc-go v1.22.0 + google.golang.org/protobuf v1.31.0 ) - -require github.com/tabbed/sqlc-go v1.18.0 // indirect diff --git a/go.sum b/go.sum index 0ae299e..23b946b 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,14 @@ -buf.build/gen/go/sqlc/sqlc/protocolbuffers/go v1.30.0-20230621221448-196413f69ab3.1 h1:ze0HODAjPRXSkiqSpDTYq2baS4IVtRtDLSZY2p1ZCX4= -buf.build/gen/go/sqlc/sqlc/protocolbuffers/go v1.30.0-20230621221448-196413f69ab3.1/go.mod h1:DSpReHp8PwHOeCfGymiiY4HSx2iVL358X7JRMciL7T0= +buf.build/gen/go/sqlc/sqlc/protocolbuffers/go v1.31.0-20231002190240-3f2d312ab6fd.1 h1:94JzirpGhebc3++MqmvWY0fi9TJxlle5M52NO4pTEZY= +buf.build/gen/go/sqlc/sqlc/protocolbuffers/go v1.31.0-20231002190240-3f2d312ab6fd.1/go.mod h1:x7kMRcmAiYQXko+NDqLP2agondNlbHKUNGPXqU1nrOU= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/sqlc-dev/sqlc-go v1.18.1 h1:mmudfN9G938piXnZGvrMEHp9RF4dD+InQIY1BaoQOvU= -github.com/sqlc-dev/sqlc-go v1.18.1/go.mod h1:v6c+FMh0YrbT9RU9+S5Sh62VXmVhdpTsQoXn1QxH294= -github.com/tabbed/sqlc-go v1.18.0 h1:GNE8b8xue8fKVptQnr3Z6DV8FqdokyDYML7O0kYtbe4= -github.com/tabbed/sqlc-go v1.18.0/go.mod h1:qx8ocsmviBDyRfLNuJQtdu0f5oqa8XBjKxMldl+Wm24= +github.com/sqlc-dev/sqlc-go v1.22.0 h1:ivUplxHRkw1WZ++rs80OfoJLYbpXMXYGtc79e7z/0HA= +github.com/sqlc-dev/sqlc-go v1.22.0/go.mod h1:/4snw3ucbglJfyLRxp8X2weM4pwT8w1NlEKm4PzxAuQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/internal/gen.go b/internal/gen.go index 69a2d13..86d75b0 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -207,6 +207,8 @@ func pyInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string { switch req.Settings.Engine { case "postgresql": return postgresType(req, col) + case "mysql": + return mysqlType(req, col) default: log.Println("unsupported engine type") return "Any" @@ -376,6 +378,15 @@ func sqlalchemySQL(s, engine string) string { s = strings.ReplaceAll(s, ":", `\\:`) if engine == "postgresql" { return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1") + } else if engine == "mysql" { + // All "?" in string s in string s are replaced with ":p1", ":p2", ... in that order + parts := strings.Split(s, "?") + for i := range parts { + if i != 0 { + parts[i] = fmt.Sprintf(":p%d%s", i, parts[i]) + } + } + return strings.Join(parts, "") } return s } diff --git a/internal/mysql_type.go b/internal/mysql_type.go new file mode 100644 index 0000000..7bad4fa --- /dev/null +++ b/internal/mysql_type.go @@ -0,0 +1,72 @@ +package python + +import ( + "log" + + "buf.build/gen/go/sqlc/sqlc/protocolbuffers/go/protos/plugin" + "github.com/sqlc-dev/sqlc-go/sdk" +) + +func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string { + columnType := sdk.DataType(col.Type) + + switch columnType { + + case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": + return "str" + + case "tinyint": + if col.Length == 1 { + return "bool" + } else { + return "int" + } + + case "int", "integer", "smallint", "mediumint", "year": + return "int" + + case "bigint": + return "int" + + case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": + // TODO: Proper blob support + return "Any" + + case "double", "double precision", "real", "float": + return "float" + + case "decimal", "dec", "fixed": + return "string" + + case "enum": + // TODO: Proper Enum support + return "string" + + case "date", "timestamp", "datetime", "time": + return "datetime.date" + + case "boolean", "bool": + return "bool" + + case "json": + return "Any" + + case "any": + return "Any" + + default: + for _, schema := range req.Catalog.Schemas { + for _, enum := range schema.Enums { + if columnType == enum.Name { + if schema.Name == req.Catalog.DefaultSchema { + return "models." + modelName(enum.Name, req.Settings) + } + return "models." + modelName(schema.Name+"_"+enum.Name, req.Settings) + } + } + } + log.Printf("Unknown MySQL type: %s\n", columnType) + return "Any" + + } +}