From fe5f910af409ab7c693d976594fa3fd2bfe478a1 Mon Sep 17 00:00:00 2001 From: Theodor Date: Thu, 15 Jan 2026 13:19:08 +0100 Subject: [PATCH 1/4] edits for fold gen --- Project.toml | 4 ++- src/ChainStorm.jl | 1 + src/flow.jl | 57 +++++++++++++++++++++++++++++++-------- src/model.jl | 13 ++++++--- training_code/training.jl | 36 +++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 16 deletions(-) create mode 100644 training_code/training.jl diff --git a/Project.toml b/Project.toml index 304f560..60b86b1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ChainStorm" uuid = "4b48e6e7-310e-4c37-88f1-40fb96fd3041" -authors = ["murrellb and contributors"] version = "1.0.0-DEV" +authors = ["murrellb and contributors"] [deps] BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b" @@ -13,6 +13,7 @@ HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de" InvariantPointAttention = "814a5788-1061-474f-a71f-873f087d5d12" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6" ProteinChains = "b8e8f2a5-48d3-44f1-ba0d-c71cb7726ff8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -29,6 +30,7 @@ HuggingFaceApi = "0.1.0" InvariantPointAttention = "0.1.3" JLD2 = "0.5.13" LinearAlgebra = "1" +OneHotArrays = "0.2.10" Onion = "0.1.0" ProteinChains = "0.7.2" Random = "1" diff --git a/src/ChainStorm.jl b/src/ChainStorm.jl index 76c8e6f..fa3021c 100644 --- a/src/ChainStorm.jl +++ b/src/ChainStorm.jl @@ -1,6 +1,7 @@ module ChainStorm using Flowfusion, ForwardBackward, Flux, RandomFeatureMaps, Onion, InvariantPointAttention, BatchedTransformations, ProteinChains, DLProteinFormats, HuggingFaceApi, JLD2 +using OneHotArrays #Added include("flow.jl") include("model.jl") diff --git a/src/flow.jl b/src/flow.jl index 4c4a315..b4c8323 100644 --- a/src/flow.jl +++ b/src/flow.jl @@ -1,24 +1,54 @@ const rotM = 
Flowfusion.Rotations(3) schedule_f(t) = 1-(1-t)^2 -const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21)) +#const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21)) + +P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f)) function compound_state(b) L,B = size(b.aas) cmask = b.aas .< 100 X1locs = MaskedState(ContinuousState(b.locs), cmask, b.padmask) X1rots = MaskedState(ManifoldState(rotM,eachslice(b.rots, dims=(3,4))), cmask, b.padmask) - X1aas = MaskedState(DiscreteState(21, Flux.onehotbatch(b.aas, 1:21)), cmask, b.padmask) - return (X1locs, X1rots, X1aas) + #X1aas = MaskedState(DiscreteState(21, Flux.onehotbatch(b.aas, 1:21)), cmask, b.padmask) + #return (X1locs, X1rots, X1aas) + return (X1locs, X1rots) +end + +#Added seq_to_masked_state +function seq_to_masked_state(seq_int; K=21) + N = length(seq_int) + B = 1 # batch size + + # Step 1: Create 2D one-hot to get the indices + oh_2d = Flux.onehotbatch(seq_int, 1:K) # K×N OneHotMatrix + + # Step 2: Reshape indices to an N×B matrix so the one-hot array gains a batch dimension + indices_3d = reshape(oh_2d.indices, N, B) # N×B Matrix{UInt32} + + # Step 3: Create 3D OneHotArray (K×N×B) + oh_3d = OneHotArrays.OneHotArray(indices_3d, K) + + # Step 4: Create DiscreteState with 3D one-hot array + ds = ForwardBackward.DiscreteState(K, oh_3d) + + # Step 5: Create masks (N×B) + cmask = trues(N, B) # conditioning mask + lmask = trues(N, B) # length/padding mask + + # Step 6: Create MaskedState + return Flowfusion.MaskedState(ds, cmask, lmask) end + function zero_state(b) L,B = size(b.aas) cmask = b.aas .< 100 X0locs = MaskedState(ContinuousState(randn(Float32, size(b.locs))), cmask, b.padmask) X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(Flowfusion.rand(rotM, L*B)), L, 
B)), cmask, b.padmask) - X0aas = MaskedState(DiscreteState(21, Flux.onehotbatch(similar(b.aas) .= 21, 1:21)), cmask, b.padmask) - return (X0locs, X0rots, X0aas) + #X0aas = MaskedState(DiscreteState(21, Flux.onehotbatch(similar(b.aas) .= 21, 1:21)), cmask, b.padmask) + #return (X0locs, X0rots, X0aas) + return (X0locs, X0rots) end function training_sample(b) @@ -27,7 +57,8 @@ function training_sample(b) t = rand(Float32, 1, size(b.aas,2)) Xt = bridge(P, X0, X1, t) rotξ = Guide(Xt[2], X1[2]) - return (; t, Xt, X1, rotξ, chainids = b.chainids, resinds = b.resinds) + #return (; t, Xt, X1, rotξ, chainids = b.chainids, resinds = b.resinds) + return (; t, Xt, X1, rotξ, aas = b.aas, chainids = b.chainids, resinds = b.resinds) end function losses(hatframes, aalogits, ts) @@ -35,21 +66,25 @@ function losses(hatframes, aalogits, ts) hatloc, hatrot, hataas = (values(translation(hatframes)), rotangent, aalogits) l_loc = floss(P[1], hatloc, ts.X1[1], scalefloss(P[1], ts.t, 2, 0.2f0)) / 2 l_rot = floss(P[2], hatrot, ts.rotξ, scalefloss(P[2], ts.t, 2, 0.2f0)) / 10 - l_aas = floss(P[3], hataas, ts.X1[3], scalefloss(P[3], ts.t, 1, 0.2f0)) / 100 - return l_loc, l_rot, l_aas + #l_aas = floss(P[3], hataas, ts.X1[3], scalefloss(P[3], ts.t, 1, 0.2f0)) / 100 + #return l_loc, l_rot, l_aas + return l_loc, l_rot end function flowX1predictor(X0, b, model; d = identity, smooth = 0) batch_dim = size(tensor(X0[1]), 4) - f, _ = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) + #f, aalogits = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) + f = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.aas), d(b.chainids), d(b.resinds)) prev_trans = values(translation(f)) T = eltype(prev_trans) function m(t, Xt) print(".") - f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f) + #f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f) + f = 
model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.aas), d(b.chainids), d(b.resinds), sc_frames = f) values(translation(f)) .= prev_trans .* T(smooth) .+ values(translation(f)) .* T(1-smooth) prev_trans = values(translation(f)) - return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits)) + #return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits)) + return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))) end return m end diff --git a/src/model.jl b/src/model.jl index fab914b..6a71f8e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -23,7 +23,9 @@ function ChainStormV1(dim::Int = 384, depth::Int = 6, f_depth::Int = 6) ) return ChainStormV1(layers) end -function (fc::ChainStormV1)(t, Xt, chainids, resinds; sc_frames = nothing) + +#function (fc::ChainStormV1)(t, Xt, chainids, resinds; sc_frames = nothing) +function (fc::ChainStormV1)(t, Xt, aas, chainids, resinds; sc_frames = nothing) l = fc.layers pmask = Flux.Zygote.@ignore self_att_padding_mask(Xt[1].lmask) pre_z = Flux.Zygote.@ignore l.pair_rff(pair_encode(resinds, chainids)) @@ -31,7 +33,9 @@ function (fc::ChainStormV1)(t, Xt, chainids, resinds; sc_frames = nothing) t_rff = Flux.Zygote.@ignore l.t_rff(t) cond = reshape(l.cond_t_encoding(t_rff), :, 1, size(t,2)) frames = Translation(tensor(Xt[1])) ∘ Rotation(tensor(Xt[2])) - AA_one_hots = tensor(Xt[3]) + AA_one_hots = tensor(Flux.onehotbatch(aas, 1:21)) + #AA_one_hots = tensor(Xt[3]) + x = l.AAencoder(AA_one_hots .+ 0) for i in 1:l.depth if sc_frames !== nothing @@ -44,6 +48,7 @@ function (fc::ChainStormV1)(t, Xt, chainids, resinds; sc_frames = nothing) frames = l.framemovers[i - l.depth + l.f_depth](frames, x, t = t) end end - aa_logits = l.AAdecoder(x .+ reshape(l.AApre_t_encoding(t_rff), :, 1, size(t,2))) - return frames, aa_logits + #aa_logits = l.AAdecoder(x .+ 
reshape(l.AApre_t_encoding(t_rff), :, 1, size(t,2))) + #return frames, aa_logits + return frames end \ No newline at end of file diff --git a/training_code/training.jl b/training_code/training.jl new file mode 100644 index 0000000..a75ea61 --- /dev/null +++ b/training_code/training.jl @@ -0,0 +1,36 @@ +#In addition to ChainStorm, also install these: +using Pkg +Pkg.add(["JLD2", "Flux", "CannotWaitForTheseOptimisers", "LearningSchedules", "DLProteinFormats"]) +Pkg.add(["CUDA", "cuDNN"]) + +using ChainStorm, DLProteinFormats, Flux, CannotWaitForTheseOptimisers, LearningSchedules, JLD2 +using DLProteinFormats: load, PDBSimpleFlat, batch_flatrecs, sample_batched_inds, length2batch +using CUDA +device = gpu + +dat = load(PDBSimpleFlat); + +model = ChainStormV1(384, 3, 3) |> device +sched = burnin_learning_schedule(0.000005f0, 0.001f0, 1.05f0, 0.99995f0) +opt_state = Flux.setup(Muon(eta = sched.lr), model) + +for epoch in 1:100 + batchinds = sample_batched_inds(dat,l2b = length2batch(1500, 1.9)) + for (i, b) in enumerate(batchinds) + bat = batch_flatrecs(dat[b]) + ts = training_sample(bat) |> device + sc_frames = nothing + if epoch > 1 && rand() < 0.5 + sc_frames, _ = model(ts.t, ts.Xt, ts.chainids, ts.resinds) + end + l, grad = Flux.withgradient(model) do m + fr, aalogs = m(ts.t, ts.Xt, ts.chainids, ts.resinds, sc_frames = sc_frames) + l_loc, l_rot, l_aas = losses(fr, aalogs, ts) + l_loc + l_rot + l_aas + end + Flux.update!(opt_state, model, grad[1]) + (mod(i, 10) == 0) && Flux.adjust!(opt_state, next_rate(sched)) + println(l) + end + jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state=cpu(opt_state)) +end \ No newline at end of file From 386016272464aff7ab4de5cbdf7dd6fd0601605b Mon Sep 17 00:00:00 2001 From: Theodor Date: Tue, 20 Jan 2026 17:32:15 +0100 Subject: [PATCH 2/4] saving --- src/flow.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/flow.jl b/src/flow.jl index b4c8323..15dcbeb 100644 --- a/src/flow.jl 
+++ b/src/flow.jl @@ -61,9 +61,11 @@ function training_sample(b) return (; t, Xt, X1, rotξ, aas = b.aas, chainids = b.chainids, resinds = b.resinds) end -function losses(hatframes, aalogits, ts) +#function losses(hatframes, aalogits, ts) +function losses(hatframes, ts) rotangent = Flowfusion.so3_tangent_coordinates_stack(values(linear(hatframes)), tensor(ts.Xt[2])) - hatloc, hatrot, hataas = (values(translation(hatframes)), rotangent, aalogits) + #hatloc, hatrot, hataas = (values(translation(hatframes)), rotangent, aalogits) + hatloc, hatrot = (values(translation(hatframes)), rotangent) l_loc = floss(P[1], hatloc, ts.X1[1], scalefloss(P[1], ts.t, 2, 0.2f0)) / 2 l_rot = floss(P[2], hatrot, ts.rotξ, scalefloss(P[2], ts.t, 2, 0.2f0)) / 10 #l_aas = floss(P[3], hataas, ts.X1[3], scalefloss(P[3], ts.t, 1, 0.2f0)) / 100 From b0e66d1c6c69e83b0f6ca41e8579882cd6c8de69 Mon Sep 17 00:00:00 2001 From: Theodor Date: Fri, 23 Jan 2026 14:23:07 +0100 Subject: [PATCH 3/4] training finetune code --- training_code/training_finetune.jl | 58 ++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 training_code/training_finetune.jl diff --git a/training_code/training_finetune.jl b/training_code/training_finetune.jl new file mode 100644 index 0000000..97329b0 --- /dev/null +++ b/training_code/training_finetune.jl @@ -0,0 +1,58 @@ +using Dates +using ChainStorm, DLProteinFormats, Flux, CannotWaitForTheseOptimisers, LearningSchedules, JLD2 +using DLProteinFormats: load, PDBSimpleFlat, batch_flatrecs, sample_batched_inds, length2batch +using CUDA, cuDNN + +fallback_mask(x) = any(size(x) .== 21) + +device = gpu +dat = load(PDBSimpleFlat); +sample = dat[9] +print(sample.len) +L = length(sample.chainids) +c = (;chainids = reshape(sample.chainids, :, 1), + resinds = view(sample.resinds, :, 1), + padmask = trues(L, 1), + aas = reshape(sample.AAs, :, 1), + locs = reshape(sample.locs, 3, 1, L, 1)) + +model = ChainStorm.load_model() |> device +model_cpu = nothing + +sched 
= burnin_learning_schedule(0.00001f0, 0.000250f0, 1.05f0, 0.999995f0); +opt_state = Flux.setup(Muon(eta = sched.lr, fallback = fallback_mask), model); + +start_time = Dates.format(Dates.now(), "yyyy-mm-dd_HHMM") + +gen_sample_counter = 0 +sample_counter = 10000 +for epoch in 1:100 + batchinds = sample_batched_inds(dat,l2b = length2batch(1500, 1.9)) + @info "Epoch $epoch" + for (i, b) in enumerate(batchinds) + sample_counter += 1 + bat = batch_flatrecs(dat[b]) + ts = training_sample(bat) |> device + sc_frames = nothing + if epoch > 1 && rand() < 0.5 + sc_frames = model(ts.t, ts.Xt, ts.aas, ts.chainids, ts.resinds) + end + l, grad = Flux.withgradient(model) do m + fr = m(ts.t, ts.Xt, ts.aas, ts.chainids, ts.resinds, sc_frames = sc_frames) + l_loc, l_rot = losses(fr, ts) + l_loc + l_rot + end + Flux.update!(opt_state, model, grad[1]) + (mod(i, 10) == 0) && Flux.adjust!(opt_state, next_rate(sched)) + println(l) + println("Sample counter: ", sample_counter, " Gen sample counter: ", gen_sample_counter) + if sample_counter >= 10000 + sample_counter = 0 + gen_sample_counter += 1 + model_cpu = cpu(model) + g = flow_quickgen(c, model_cpu) + export_pdb("notrunk_toymodel/gens/gen_$(start_time)_foldfinetune_sample_$(gen_sample_counter).pdb", (g..., ChainStorm.seq_to_masked_state(sample.AAs)), c.chainids, c.resinds) + end + end + jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state=cpu(opt_state)) +end From 4952118a4337cbf939b0adfdc2c2df4b9c807905 Mon Sep 17 00:00:00 2001 From: Theodor Date: Fri, 6 Feb 2026 18:06:52 +0100 Subject: [PATCH 4/4] changed onion req to 0.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 60b86b1..244a07e 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ InvariantPointAttention = "0.1.3" JLD2 = "0.5.13" LinearAlgebra = "1" OneHotArrays = "0.2.10" -Onion = "0.1.0" +Onion = "0.2.0" ProteinChains = "0.7.2" Random = "1" RandomFeatureMaps = "0.2.2"