From a51f93ee312830f76d89e20aaa652bd6d2bc1056 Mon Sep 17 00:00:00 2001 From: Cody Date: Wed, 25 Jun 2025 19:13:20 -0400 Subject: [PATCH] Further work to support ESMFold mining --- go.sum | 1 + models/esmfold.go | 187 +++++++++++++++++++++++++++------------------- 2 files changed, 111 insertions(+), 77 deletions(-) diff --git a/go.sum b/go.sum index a4418ce..cefb4ba 100644 --- a/go.sum +++ b/go.sum @@ -1243,6 +1243,7 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/models/esmfold.go b/models/esmfold.go index 7576984..221d4e0 100644 --- a/models/esmfold.go +++ b/models/esmfold.go @@ -10,19 +10,25 @@ import ( "gobius/config" "gobius/ipfs" "gobius/utils" - "io" "net/http" "path/filepath" "strings" "time" - + "os" + "os/exec" + "crypto/tls" + "github.com/google/uuid" "github.com/mr-tron/base58" "github.com/rs/zerolog" ) -type ESMFoldV1Input struct { - Sequence string `json:"sequence"` +type ESMFoldV1Inner struct { + Prompt string `json:"prompt"` +} + +type ESMFoldV1Prompt struct { + Input ESMFoldV1Inner `json:"input"` } type ESMFoldV1Model struct { @@ -51,7 +57,7 @@ var ESMFoldV1ModelTemplate = Model{ "version": 1, "input": []map[string]any{ { - "variable": "sequence", + "variable": "prompt", "type": "string", "required": true, "default": "", @@ -79,40 +85,46 @@ func NewESMFoldV1Model(client ipfs.IPFSClient, appConfig *config.AppConfig, logg return nil } - http := &http.Client{ - Transport: &http.Transport{MaxIdleConnsPerHost: 10}, - } + httpClient := &http.Client{ + Transport: &http.Transport{ + MaxIdleConnsPerHost: 10, + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + logger.Debug().Str("redirect_url", req.URL.String()).Msg("Following redirect") + return nil + }, + Timeout: 600 * time.Second, + } timeout := 600 * time.Second ipfsTimeout := 30 * time.Second - // Use model.ID (the hex string CID) as the key for the Cog map cogConfig, ok := appConfig.ML.Cog[model.ID] if ok { - // Parse inference timeout only if the string is not empty if cogConfig.HttpTimeout != "" { parsedTimeout, err := time.ParseDuration(cogConfig.HttpTimeout) if err != nil { logger.Warn().Err(err).Str("model", model.ID).Str("config_timeout", cogConfig.HttpTimeout).Msg("failed to parse model timeout from cog config, using default 120s") - // Keep default timeout } else { timeout = parsedTimeout } - } // Else: HttpTimeout is empty, silently use the default - - // Parse IPFS timeout only if the string is not empty + } + if cogConfig.IpfsTimeout != "" { parsedIpfsTimeout, err := time.ParseDuration(cogConfig.IpfsTimeout) if err != nil { logger.Warn().Err(err).Str("model", model.ID).Str("config_ipfs_timeout", cogConfig.IpfsTimeout).Msg("failed to parse IPFS timeout from cog config, using default 30s") - // Keep default ipfsTimeout } else { ipfsTimeout = parsedIpfsTimeout } - } // Else: IpfsTimeout is empty, silently use the default + } } - // perform validation on the template templateMeta, ok := ESMFoldV1ModelTemplate.Template.(map[string]any) if !ok { logger.Error().Str("model", model.ID).Msg("invalid template format") @@ -144,7 +156,7 @@ func NewESMFoldV1Model(client ipfs.IPFSClient, appConfig *config.AppConfig, logg }, }, ipfs: client, - client: http, + client: httpClient, logger: logger, inputFields: inputFields, } @@ -175,8 +187,8 @@ func (m *ESMFoldV1Model) HydrateInput(preprocessedInput map[string]any, seed uin } } - var inner ESMFoldV1Input - inner.Sequence, _ = input["sequence"].(string) + var inner ESMFoldV1Prompt + inner.Input.Prompt, _ = input["prompt"].(string) return inner, nil } @@ -184,63 +196,84 @@ func (m *ESMFoldV1Model) GetID() string { return m.Model.ID } -func (m *ESMFoldV1Model) GetFiles(ctx context.Context, gpu *common.GPU, taskid string, input any) ([]ipfs.IPFSFile, error) { - if err := ctx.Err(); err != nil { - m.logger.Warn().Err(err).Str("task", taskid).Msg("Context canceled before GetFiles execution") - return nil, err - } - - inner, ok := input.(ESMFoldV1Input) - if !ok { - return nil, fmt.Errorf("invalid input type: expected ESMFoldV1Input") - } - - payload := map[string]string{"sequence": inner.Sequence} - marshaledInput, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal input: %w", err) - } - - endpoint := fmt.Sprintf("%s:8080/predict", strings.TrimSuffix(gpu.Url, "/")) - req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(marshaledInput)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := m.client.Do(req) - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - m.logger.Error().Err(err).Str("task", taskid).Str("gpu", endpoint).Msg("model inference request timed out") - return nil, fmt.Errorf("model inference timed out: %w", err) - } - return nil, fmt.Errorf("failed to POST to GPU: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - if resp.StatusCode == http.StatusConflict { - m.logger.Warn().Str("task", taskid).Str("gpu", endpoint).Int("status", resp.StatusCode).Str("body", string(bodyBytes)).Msg("resource busy") - return nil, ErrResourceBusy - } - return nil, fmt.Errorf("server returned non-200 status: %d - %s", resp.StatusCode, string(bodyBytes)) - } - - pdbData, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read model response body: %w", err) - } - - if len(pdbData) == 0 { - return nil, errors.New("model returned empty PDB data") - } - - fileName := fmt.Sprintf("%d.%s.pdb", gpu.ID, uuid.New().String()) - path := filepath.Join(m.config.CachePath, fileName) - buffer := bytes.NewBuffer(pdbData) - - return []ipfs.IPFSFile{{Name: "result.pdb", Path: path, Buffer: buffer}}, nil +func (m *ESMFoldV1Model) GetFiles(ctx context.Context, gpu *common.GPU, taskid string, input interface{}) ([]ipfs.IPFSFile, error) { + if err := ctx.Err(); err != nil { + m.logger.Warn().Err(err).Str("task", taskid).Msg("Context canceled before GetFiles execution") + return nil, err + } + + inner, ok := input.(ESMFoldV1Prompt) + if !ok { + return nil, fmt.Errorf("invalid input type: expected ESMFoldV1Prompt") + } + + payload := map[string]string{"sequence": inner.Input.Prompt} + marshaledInput, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal input: %w", err) + } + + m.logger.Debug().Str("marshaledInput", string(marshaledInput)).Msg("Prepared input payload") + m.logger.Debug().Str("gpu_url", gpu.Url).Msg("GPU URL") + + endpoint := fmt.Sprintf("%s/predict", strings.TrimSuffix(gpu.Url, "/")) + m.logger.Debug().Str("endpoint", endpoint).Msg("Constructed endpoint URL") + + tmpFile, err := os.CreateTemp("", "curl_output_*.pdb") + if err != nil { + m.logger.Error().Err(err).Str("task", taskid).Msg("Failed to create temporary file") + return nil, fmt.Errorf("failed to create temporary file: %w", err) + } + tmpFileName := tmpFile.Name() + tmpFile.Close() + + curlArgs := []string{ + "-v", + "-k", + "-X", "POST", + "-H", "Content-Type: application/json", + "-H", "Accept: chemical/x-pdb", + "-d", string(marshaledInput), + "-o", tmpFileName, + endpoint, + } + + m.logger.Debug().Str("curl_command", fmt.Sprintf("curl %s", strings.Join(curlArgs, " "))).Msg("Executing curl") + + cmd := exec.CommandContext(ctx, "curl", curlArgs...) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + m.logger.Error(). + Err(err). + Str("task", taskid). + Str("gpu", endpoint). + Str("stderr", stderr.String()). + Msg("Failed to execute curl") + os.Remove(tmpFileName) + return nil, fmt.Errorf("failed to execute curl: %w, stderr: %s", err, stderr.String()) + } + + pdbData, err := os.ReadFile(tmpFileName) + if err != nil { + m.logger.Error().Err(err).Str("task", taskid).Msg("Failed to read curl output file") + os.Remove(tmpFileName) + return nil, fmt.Errorf("failed to read curl output file: %w", err) + } + os.Remove(tmpFileName) + + if len(pdbData) == 0 { + m.logger.Error().Str("task", taskid).Msg("curl returned empty PDB data") + return nil, errors.New("curl returned empty PDB data") + } + + m.logger.Debug().Int("data_length", len(pdbData)).Msg("Received curl response") + + fileName := fmt.Sprintf("%d.%s.pdb", gpu.ID, uuid.New().String()) + path := filepath.Join(m.config.CachePath, fileName) + buffer := bytes.NewBuffer(pdbData) + + return []ipfs.IPFSFile{{Name: "result.pdb", Path: path, Buffer: buffer}}, nil } func (m *ESMFoldV1Model) GetCID(ctx context.Context, gpu *common.GPU, taskid string, input any) ([]byte, error) {