Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 68 additions & 47 deletions tools/operations-gen/internal/families/evm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ const (
anyType = "any"
// emptyReturnType is the Go type used for read functions with no return values.
emptyReturnType = "struct{}"

abiTypeFunction = "function"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introduced some constants too

abiTypeConstructor = "constructor"
stateMutabilityView = "view"
stateMutabilityPure = "pure"
accessOwner = "owner"
accessPublic = "public"
accessControlAllCallers = "AllCallersAllowed"
accessControlOnlyOwner = "OnlyOwner"
)

// evmTypeMap maps Solidity types to their Go equivalents.
Expand All @@ -34,6 +43,9 @@ var evmTypeMap = map[string]string{
"uint8": "uint8",
"uint16": "uint16",
"uint32": "uint32",
"uint40": "uint64",
"uint48": "uint64",
"uint56": "uint64",
"uint64": "uint64",
"uint96": "*big.Int",
"uint128": "*big.Int",
Expand Down Expand Up @@ -104,13 +116,12 @@ type structDef struct {
}

type functionInfo struct {
Name string
StateMutability string
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This StateMutability was never used

Parameters []parameterInfo
ReturnParams []parameterInfo
IsWrite bool
CallMethod string // Method name, with numeric suffix for overloaded functions
HasOnlyOwner bool
Name string
Parameters []parameterInfo
ReturnParams []parameterInfo
IsWrite bool
CallMethod string // Method name, with numeric suffix for overloaded functions
HasOnlyOwner bool
}

type parameterInfo struct {
Expand Down Expand Up @@ -385,12 +396,8 @@ func prepareTemplateData(info *contractInfo) templateData {
func prepareParameters(params []parameterInfo) []parameterData {
result := make([]parameterData, 0, len(params))
for i, param := range params {
name := sanitizeFieldName(param.Name)
if name == "" {
name = fmt.Sprintf("Field%d", i)
}
result = append(result, parameterData{
GoName: name,
GoName: fieldNameOrIndex(param.Name, i),
GoType: param.GoType,
JSONTag: param.Name,
})
Expand All @@ -410,13 +417,9 @@ func buildCallArgs(fi *functionInfo) (argsType string, callArgs string) {
}

argsType = fi.Name + "Args"
var callArgsList []string
callArgsList := make([]string, 0, len(fi.Parameters))
for i, p := range fi.Parameters {
fieldName := sanitizeFieldName(p.Name)
if fieldName == "" {
fieldName = fmt.Sprintf("Field%d", i)
}
callArgsList = append(callArgsList, "args."+fieldName)
callArgsList = append(callArgsList, "args."+fieldNameOrIndex(p.Name, i))
}
callArgs = ", " + strings.Join(callArgsList, ", ")

Expand All @@ -436,9 +439,9 @@ func resolveReturnType(fi *functionInfo) string {
func prepareWriteOp(fi *functionInfo) operationData {
argsType, callArgs := buildCallArgs(fi)

accessControl := "AllCallersAllowed"
accessControl := accessControlAllCallers
if fi.HasOnlyOwner {
accessControl = "OnlyOwner"
accessControl = accessControlOnlyOwner
}

return operationData{
Expand Down Expand Up @@ -504,7 +507,7 @@ func prepareContractMethod(fi *functionInfo, isWrite bool) contractMethodData {

var methodBody string
if isWrite {
methodBody = buildWriteMethodBody(fi, methodArgs)
methodBody = buildWriteMethodBody(fi.CallMethod, methodArgs)
} else {
methodBody = buildReadMethodBody(fi, methodArgs, resolveReturnType(fi))
}
Expand All @@ -519,13 +522,13 @@ func prepareContractMethod(fi *functionInfo, isWrite bool) contractMethodData {
}

// buildWriteMethodBody generates the body of a write (transact) method.
func buildWriteMethodBody(fi *functionInfo, methodArgs []string) string {
func buildWriteMethodBody(callMethod string, methodArgs []string) string {
if len(methodArgs) > 0 {
return fmt.Sprintf("return c.contract.Transact(opts, \"%s\", %s)",
fi.CallMethod, strings.Join(methodArgs, ", "))
callMethod, strings.Join(methodArgs, ", "))
}

return fmt.Sprintf("return c.contract.Transact(opts, \"%s\")", fi.CallMethod)
return fmt.Sprintf("return c.contract.Transact(opts, \"%s\")", callMethod)
}

// buildReadMethodBody generates the body of a read (call) method.
Expand Down Expand Up @@ -568,12 +571,8 @@ func buildMultiReturnMethodBody(fi *functionInfo, callArgsStr, returnType string
fmt.Fprintf(&b, "\t\treturn *outstruct, err\n")
fmt.Fprintf(&b, "\t}\n\n")
for i, p := range fi.ReturnParams {
fieldName := sanitizeFieldName(p.Name)
if fieldName == "" {
fieldName = fmt.Sprintf("Field%d", i)
}
fmt.Fprintf(&b, "\toutstruct.%s = *abi.ConvertType(out[%d], new(%s)).(*%s)\n",
fieldName, i, p.GoType, p.GoType)
fieldNameOrIndex(p.Name, i), i, p.GoType, p.GoType)
}
fmt.Fprintf(&b, "\n\treturn *outstruct, nil")

Expand Down Expand Up @@ -641,7 +640,7 @@ func readABIAndBytecode(

func extractConstructor(info *contractInfo, abiEntries []ABIEntry, typeMap map[string]string) {
for _, entry := range abiEntries {
if entry.Type == "constructor" {
if entry.Type == abiTypeConstructor {
info.Constructor = parseABIFunction(entry, info.PackageName, typeMap)
break
}
Expand All @@ -657,9 +656,9 @@ func extractFunctions(info *contractInfo, funcConfigs []evmFunctionConfig, abiEn

for _, fi := range funcInfos {
switch funcCfg.Access {
case "owner":
case accessOwner:
fi.HasOnlyOwner = true
case "public", "":
case accessPublic, "":
fi.HasOnlyOwner = false
default:
return fmt.Errorf("unknown access control '%s' for function %s (use 'owner' or 'public')",
Expand All @@ -679,7 +678,7 @@ func extractFunctions(info *contractInfo, funcConfigs []evmFunctionConfig, abiEn
func findFunctionInABI(entries []ABIEntry, funcName string, packageName string, typeMap map[string]string) []*functionInfo {
var candidates []ABIEntry
for _, entry := range entries {
if entry.Type == "function" && strings.EqualFold(entry.Name, funcName) {
if entry.Type == abiTypeFunction && strings.EqualFold(entry.Name, funcName) {
candidates = append(candidates, entry)
}
}
Expand Down Expand Up @@ -710,10 +709,9 @@ func findFunctionInABI(entries []ABIEntry, funcName string, packageName string,
// IsWrite is determined by stateMutability: anything other than "view" or "pure" is a write.
func parseABIFunction(entry ABIEntry, packageName string, typeMap map[string]string) *functionInfo {
fi := &functionInfo{
Name: core.Capitalize(entry.Name),
StateMutability: entry.StateMutability,
CallMethod: entry.Name,
IsWrite: entry.StateMutability != "view" && entry.StateMutability != "pure",
Name: core.Capitalize(entry.Name),
CallMethod: entry.Name,
IsWrite: entry.StateMutability != stateMutabilityView && entry.StateMutability != stateMutabilityPure,
}

for i, input := range entry.Inputs {
Expand Down Expand Up @@ -768,17 +766,25 @@ func parseABIParam(param ABIParam, packageName string, typeMap map[string]string

// solidityToGoType maps a Solidity type string to its Go equivalent using typeMap.
func solidityToGoType(solidityType string, typeMap map[string]string) string {
baseType := strings.TrimSuffix(solidityType, "[]")
if goType, ok := typeMap[baseType]; ok {
if strings.HasSuffix(solidityType, "[]") {
return "[]" + goType
// Array: uint8[] → []uint8, uint8[32] → [32]uint8
if i := strings.LastIndexByte(solidityType, '['); i != -1 {
// Guard malformed type strings like "[" or "uint8[" to avoid slicing panics.
if !strings.HasSuffix(solidityType, "]") || i+1 > len(solidityType)-1 {
return anyType
}
sizeStr := solidityType[i+1 : len(solidityType)-1]
_, numErr := strconv.Atoi(sizeStr)
if sizeStr == "" || numErr == nil {
inner := solidityToGoType(solidityType[:i], typeMap)
Comment on lines +776 to +778
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check the i value here to prevent index out of range panics when doing solidityType[i+1 : len(solidityType)-1]

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good one, added

if inner != anyType {
return "[" + sizeStr + "]" + inner
}

return goType
return anyType
}
}

if strings.HasPrefix(baseType, "tuple") {
return anyType
if goType, ok := typeMap[solidityType]; ok {
return goType
}

return anyType
Expand Down Expand Up @@ -847,12 +853,17 @@ func checkNeedsBigInt(info *contractInfo) bool {

// ---- Naming utilities ----

// trimUnderscores strips all leading underscores from s.
func trimUnderscores(s string) string {
return strings.TrimLeft(s, "_")
}

// sanitizeFieldName strips leading underscores and capitalizes the result,
// producing a valid exported Go identifier for struct fields.
// Returns "" when the result would start with a digit (e.g. "_1" → ""); callers fall back to "Field%d".
// e.g. "_to" → "To", "_value" → "Value", "balance" → "Balance"
func sanitizeFieldName(name string) string {
trimmed := strings.TrimLeft(name, "_")
trimmed := trimUnderscores(name)
if len(trimmed) == 0 || (trimmed[0] >= '0' && trimmed[0] <= '9') {
return ""
}
Expand All @@ -865,14 +876,24 @@ func sanitizeFieldName(name string) string {
// Returns "" when the result would start with a digit (e.g. "_1" → ""); callers fall back to "arg%d".
// e.g. "_to" → "to", "_value" → "value"
func sanitizeParamName(name string) string {
name = strings.TrimLeft(name, "_")
name = trimUnderscores(name)
if len(name) == 0 || (name[0] >= '0' && name[0] <= '9') {
return ""
}

return strings.ToLower(name[:1]) + name[1:]
}

// fieldNameOrIndex returns the sanitized exported field name for a struct field,
// or "Field{i}" when the sanitized result would be empty (e.g. numeric-only names).
func fieldNameOrIndex(name string, i int) string {
if n := sanitizeFieldName(name); n != "" {
return n
}

return fmt.Sprintf("Field%d", i)
}

func toSnakeCase(s string) string {
var result []rune
runes := []rune(s)
Expand Down
130 changes: 130 additions & 0 deletions tools/operations-gen/internal/families/evm/evm_golden_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package evm

import (
"flag"
"os"
"path/filepath"
"testing"
"text/template"

"gopkg.in/yaml.v3"

"github.com/smartcontractkit/chainlink-deployments-framework/tools/operations-gen/internal/core"
)

var update = flag.Bool("update", false, "update golden files")

// TestGenerateLinkToken is an end-to-end test that runs the generator against the
// real LinkToken ABI/bytecode and verifies that the generated output matches golden.
func TestGenerateLinkToken(t *testing.T) {
t.Parallel()
runGoldenGenerationTest(t, "operations_gen_config.yaml", "link_token.golden.go")
}

// TestGenerateManyChainMultiSig verifies generation against an MCMS-like ABI fixture.
func TestGenerateManyChainMultiSig(t *testing.T) {
t.Parallel()
runGoldenGenerationTest(t, "operations_gen_mcms_config.yaml", "many_chain_multi_sig.golden.go")
}

func runGoldenGenerationTest(t *testing.T, configFileName string, goldenFileName string) {
t.Helper()

evmTestdataDir, err := filepath.Abs(filepath.Join("..", "..", "..", "testdata", "evm"))
if err != nil {
t.Fatal(err)
}

configData, err := os.ReadFile(filepath.Join(evmTestdataDir, configFileName))
if err != nil {
t.Fatalf("reading config: %v", err)
}

var cfg core.Config
if err = yaml.Unmarshal(configData, &cfg); err != nil {
t.Fatalf("parsing config: %v", err)
}

// Override paths: inputs point to fixture dirs, output to a temp dir.
cfg.Input = mustYAMLNode(t, evmInputConfig{
ABIBasePath: filepath.Join(evmTestdataDir, "abi"),
BytecodeBasePath: filepath.Join(evmTestdataDir, "bytecode"),
})
tmpDir := t.TempDir()
cfg.Output = mustYAMLNode(t, evmOutputConfig{BasePath: tmpDir})
cfg.ConfigDir = ""

handler := Handler{}
tmpl, err := loadTemplateForTest()
if err != nil {
t.Fatalf("loadTemplate: %v", err)
}

if err = handler.Generate(cfg, tmpl); err != nil {
t.Fatalf("Generate: %v", err)
}

// Derive the output path from the first contract in the config, mirroring extractContractInfo.
var contractCfgs []evmContractConfig
if err = cfg.Contracts.Decode(&contractCfgs); err != nil || len(contractCfgs) == 0 {
t.Fatalf("decoding contract configs: %v", err)
}
first := contractCfgs[0]
pkgName := first.PackageName
if pkgName == "" {
pkgName = toSnakeCase(first.Name)
}
vPath := core.VersionToPath(first.Version)
if first.VersionPath != "" {
vPath = first.VersionPath
}
outputPath := core.ContractOutputPath(tmpDir, vPath, pkgName)

got, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("reading generated file %s: %v", outputPath, err)
}

goldenPath := filepath.Join(evmTestdataDir, goldenFileName)

if *update {
if err = os.WriteFile(goldenPath, got, 0o600); err != nil {
t.Fatalf("writing golden file: %v", err)
}

return
}

want, err := os.ReadFile(goldenPath)
if err != nil {
t.Fatalf("reading golden file %s: %v (run with -update to create it)", goldenPath, err)
}

if string(got) != string(want) {
t.Errorf("generated output does not match golden file %s\n\nrun: go test ./... -run %s -update", goldenPath, t.Name())
}
}

func mustYAMLNode(t *testing.T, value any) yaml.Node {
t.Helper()
b, err := yaml.Marshal(value)
if err != nil {
t.Fatalf("marshal yaml node: %v", err)
}
var n yaml.Node
if err = yaml.Unmarshal(b, &n); err != nil {
t.Fatalf("unmarshal yaml node: %v", err)
}

return n
}

func loadTemplateForTest() (*template.Template, error) {
path := filepath.Join("..", "..", "..", "templates", "evm", "operations.tmpl")
content, err := os.ReadFile(path)
if err != nil {
return nil, err
}

return template.New("operations").Parse(string(content))
}
Loading
Loading