dev
This commit is contained in:
261
main.go
261
main.go
@@ -1,13 +1,22 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"protoc-gen-slc/tpl"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/mod/modfile"
|
||||
"google.golang.org/protobuf/compiler/protogen"
|
||||
"google.golang.org/protobuf/types/pluginpb"
|
||||
)
|
||||
|
||||
var ServicesName []string
|
||||
|
||||
func main() {
|
||||
protogen.Options{}.Run(func(gen *protogen.Plugin) error {
|
||||
gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
|
||||
@@ -19,176 +28,115 @@ func main() {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
generateNewServerFile(ServicesName)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func generateFiles(gen *protogen.Plugin, file *protogen.File) error {
|
||||
for _, service := range file.Services {
|
||||
ServicesName = append(ServicesName, service.GoName)
|
||||
// Generate server file
|
||||
if err := generateServerFile(gen, file, service); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate client file
|
||||
if err := generateClientFile(gen, file, service); err != nil {
|
||||
return err
|
||||
}
|
||||
// // Generate client file
|
||||
// if err := generateClientFile(gen, file, service); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// Generate logic file
|
||||
if err := generateLogicFile(gen, file, service); err != nil {
|
||||
return err
|
||||
}
|
||||
// // Generate logic file
|
||||
// if err := generateLogicFile(gen, file, service); err != nil {
|
||||
// return err
|
||||
// }
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateNewServerFile(services []string) error {
|
||||
moduleName := getModuleName()
|
||||
|
||||
//create new.go
|
||||
code := tpl.NewFile
|
||||
newImports := []string{
|
||||
"pb \"" + moduleName + "/pb\"",
|
||||
}
|
||||
code = strings.ReplaceAll(code, "{import}", strings.Join(newImports, "\n"))
|
||||
var register []string
|
||||
for _, service := range services {
|
||||
register = append(register, "pb.Register"+service+"Server(srv, New"+service+"Server())")
|
||||
}
|
||||
|
||||
code = strings.ReplaceAll(code, "{register}", strings.Join(register, "\n"))
|
||||
|
||||
// 格式化代码
|
||||
formattedCode, err := format.Source([]byte(code))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to format generated code: %w", err)
|
||||
}
|
||||
|
||||
StringToFile("./server/new.go", string(formattedCode))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateServerFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
|
||||
filename := fmt.Sprintf("%s_server.pb.go", strings.ToLower(service.GoName))
|
||||
g := gen.NewGeneratedFile(filename, file.GoImportPath)
|
||||
filename := fmt.Sprintf("./server/%s_server.go", strings.ToLower(service.GoName))
|
||||
moduleName := getModuleName()
|
||||
|
||||
// Package declaration
|
||||
g.P("// Code generated by protoc-gen-layered. DO NOT EDIT.")
|
||||
g.P()
|
||||
g.P("package ", file.GoPackageName)
|
||||
g.P()
|
||||
|
||||
// Imports
|
||||
g.P("import (")
|
||||
g.P("\t\"context\"")
|
||||
g.P("\t\"errors\"")
|
||||
g.P()
|
||||
g.P("\t\"google.golang.org/grpc\"")
|
||||
g.P("\t\"google.golang.org/grpc/codes\"")
|
||||
g.P("\t\"google.golang.org/grpc/status\"")
|
||||
g.P(")")
|
||||
g.P()
|
||||
|
||||
// Server struct
|
||||
g.P("type ", service.GoName, "Server struct {")
|
||||
g.P("\tUnimplemented", service.GoName, "Server")
|
||||
g.P("\tlogic *", service.GoName, "Logic")
|
||||
g.P("}")
|
||||
g.P()
|
||||
|
||||
// NewServer function
|
||||
g.P("func New", service.GoName, "Server(logic *", service.GoName, "Logic) *", service.GoName, "Server {")
|
||||
g.P("\treturn &", service.GoName, "Server{logic: logic}")
|
||||
g.P("}")
|
||||
g.P()
|
||||
|
||||
// Register function
|
||||
g.P("func Register", service.GoName, "Server(s *grpc.Server, logic *", service.GoName, "Logic) {")
|
||||
g.P("\tserver := New", service.GoName, "Server(logic)")
|
||||
g.P("\tRegister", service.GoName, "Server(s, server)")
|
||||
g.P("}")
|
||||
g.P()
|
||||
|
||||
// Service methods
|
||||
for _, method := range service.Methods {
|
||||
g.P("func (s *", service.GoName, "Server) ", methodSignature(g, method), " {")
|
||||
g.P("\t// Add your server-side logic here")
|
||||
g.P("\tresp, err := s.logic.", method.GoName, "(ctx, req)")
|
||||
g.P("\tif err != nil {")
|
||||
g.P("\t\treturn nil, status.Errorf(codes.Internal, \"%v\", err)")
|
||||
g.P("\t}")
|
||||
g.P("\treturn resp, nil")
|
||||
g.P("}")
|
||||
g.P()
|
||||
//create servers.
|
||||
code := tpl.Server
|
||||
imports := []string{
|
||||
"\"" + moduleName + "/internal/logic/" + strings.ToLower(service.GoName) + "\"",
|
||||
"pb \"" + moduleName + "/pb\"",
|
||||
}
|
||||
|
||||
code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n"))
|
||||
code = strings.ReplaceAll(code, "{service}", service.GoName)
|
||||
|
||||
var codeMethods []string
|
||||
for _, method := range service.Methods {
|
||||
commit := strings.TrimSpace(method.Comments.Leading.String())
|
||||
methodCode := tpl.Method
|
||||
methodCode = strings.ReplaceAll(methodCode, "{service}", service.GoName)
|
||||
methodCode = strings.ReplaceAll(methodCode, "{serviceLower}", strings.ToLower(service.GoName))
|
||||
methodCode = strings.ReplaceAll(methodCode, "{func}", method.GoName)
|
||||
methodCode = strings.ReplaceAll(methodCode, "{comment}", commit)
|
||||
methodCode = strings.ReplaceAll(methodCode, "{input}", method.Input.GoIdent.GoName)
|
||||
methodCode = strings.ReplaceAll(methodCode, "{output}", method.Output.GoIdent.GoName)
|
||||
codeMethods = append(codeMethods, methodCode)
|
||||
}
|
||||
code = strings.ReplaceAll(code, "{method}", strings.Join(codeMethods, "\n"))
|
||||
|
||||
// 格式化代码
|
||||
formattedCode, err := format.Source([]byte(code))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to format generated code: %w", err)
|
||||
}
|
||||
|
||||
StringToFile(filename, string(formattedCode))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateClientFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
|
||||
filename := fmt.Sprintf("%s_client.pb.go", strings.ToLower(service.GoName))
|
||||
g := gen.NewGeneratedFile(filename, file.GoImportPath)
|
||||
|
||||
// Package declaration
|
||||
g.P("// Code generated by protoc-gen-layered. DO NOT EDIT.")
|
||||
g.P()
|
||||
g.P("package ", file.GoPackageName)
|
||||
g.P()
|
||||
|
||||
// Imports
|
||||
g.P("import (")
|
||||
g.P("\t\"context\"")
|
||||
g.P()
|
||||
g.P("\t\"google.golang.org/grpc\"")
|
||||
g.P(")")
|
||||
g.P()
|
||||
|
||||
// Client struct
|
||||
g.P("type ", service.GoName, "Client struct {")
|
||||
g.P("\tcc grpc.ClientConnInterface")
|
||||
g.P("}")
|
||||
g.P()
|
||||
|
||||
// NewClient function
|
||||
g.P("func New", service.GoName, "Client(cc grpc.ClientConnInterface) *", service.GoName, "Client {")
|
||||
g.P("\treturn &", service.GoName, "Client{cc}")
|
||||
g.P("}")
|
||||
g.P()
|
||||
|
||||
// Client methods
|
||||
for _, method := range service.Methods {
|
||||
g.P("func (c *", service.GoName, "Client) ", methodSignature(g, method), " {")
|
||||
g.P("\tout := new(", method.Output.GoIdent, ")")
|
||||
g.P("\terr := c.cc.Invoke(ctx, \"", fullMethodName(file, service, method), "\", req, out)")
|
||||
g.P("\tif err != nil {")
|
||||
g.P("\t\treturn nil, err")
|
||||
g.P("\t}")
|
||||
g.P("\treturn out, nil")
|
||||
g.P("}")
|
||||
g.P()
|
||||
}
|
||||
|
||||
fmt.Println(filename, file.GoImportPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateLogicFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
|
||||
filename := fmt.Sprintf("%s_logic.pb.go", strings.ToLower(service.GoName))
|
||||
g := gen.NewGeneratedFile(filename, file.GoImportPath)
|
||||
|
||||
// Package declaration
|
||||
g.P("// Code generated by protoc-gen-layered. DO NOT EDIT.")
|
||||
g.P()
|
||||
g.P("package ", file.GoPackageName)
|
||||
g.P()
|
||||
|
||||
// Imports
|
||||
g.P("import (")
|
||||
g.P("\t\"context\"")
|
||||
g.P("\t\"errors\"")
|
||||
g.P(")")
|
||||
g.P()
|
||||
|
||||
// Logic struct
|
||||
g.P("type ", service.GoName, "Logic struct {")
|
||||
g.P("\t// Add your dependencies here")
|
||||
g.P("}")
|
||||
g.P()
|
||||
|
||||
// NewLogic function
|
||||
g.P("func New", service.GoName, "Logic() *", service.GoName, "Logic {")
|
||||
g.P("\treturn &", service.GoName, "Logic{}")
|
||||
g.P("}")
|
||||
g.P()
|
||||
|
||||
// Logic methods
|
||||
for _, method := range service.Methods {
|
||||
g.P("func (l *", service.GoName, "Logic) ", method.GoName, "(ctx context.Context, req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
|
||||
g.P("\t// Implement your business logic here")
|
||||
g.P("\treturn nil, errors.New(\"not implemented\")")
|
||||
g.P("}")
|
||||
g.P()
|
||||
}
|
||||
|
||||
fmt.Println(filename, file.GoImportPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func methodSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
|
||||
return fmt.Sprintf("%s(ctx context.Context, req *%s) (*%s, error)",
|
||||
return fmt.Sprintf("%s(ctx context.Context, req pb%s) (*%s, error)",
|
||||
method.GoName,
|
||||
method.Input.GoIdent,
|
||||
method.Output.GoIdent)
|
||||
@@ -200,3 +148,44 @@ func fullMethodName(file *protogen.File, service *protogen.Service, method *prot
|
||||
service.GoName,
|
||||
method.GoName)
|
||||
}
|
||||
|
||||
func getModuleName() (modulePath string) {
|
||||
// 获取当前工作目录
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
fmt.Errorf("failed to get current working directory: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取 go.mod 文件
|
||||
modFilePath := filepath.Join(cwd, "go.mod")
|
||||
modFileBytes, err := os.ReadFile(modFilePath)
|
||||
if err != nil {
|
||||
fmt.Errorf("failed to read go.mod file: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 go.mod 文件
|
||||
modFile, err := modfile.Parse(modFilePath, modFileBytes, nil)
|
||||
if err != nil {
|
||||
fmt.Errorf("failed to parse go.mod file: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取模块路径
|
||||
return modFile.Module.Mod.Path
|
||||
}
|
||||
|
||||
// 将字符串写入文件
|
||||
func StringToFile(path, content string) error {
|
||||
startF, err := os.Create(path)
|
||||
if err != nil {
|
||||
return errors.New("os.Create create file " + path + " error:" + err.Error())
|
||||
}
|
||||
defer startF.Close()
|
||||
_, err = io.WriteString(startF, content)
|
||||
if err != nil {
|
||||
return errors.New("io.WriteString to " + path + " error:" + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user