Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ChainStorm"
uuid = "4b48e6e7-310e-4c37-88f1-40fb96fd3041"
authors = ["murrellb <murrellb@gmail.com> and contributors"]
version = "1.0.0-DEV"
authors = ["murrellb <murrellb@gmail.com> and contributors"]

[deps]
BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
Expand All @@ -13,6 +13,7 @@ HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de"
InvariantPointAttention = "814a5788-1061-474f-a71f-873f087d5d12"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
ProteinChains = "b8e8f2a5-48d3-44f1-ba0d-c71cb7726ff8"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -29,7 +30,8 @@ HuggingFaceApi = "0.1.0"
InvariantPointAttention = "0.1.3"
JLD2 = "0.5.13"
LinearAlgebra = "1"
Onion = "0.1.0"
OneHotArrays = "0.2.10"
Onion = "0.2.0"
ProteinChains = "0.7.2"
Random = "1"
RandomFeatureMaps = "0.2.2"
Expand Down
1 change: 1 addition & 0 deletions src/ChainStorm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ChainStorm

using Flowfusion, ForwardBackward, Flux, RandomFeatureMaps, Onion, InvariantPointAttention, BatchedTransformations, ProteinChains, DLProteinFormats, HuggingFaceApi, JLD2
using OneHotArrays #Added

include("flow.jl")
include("model.jl")
Expand Down
63 changes: 50 additions & 13 deletions src/flow.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
const rotM = Flowfusion.Rotations(3)

schedule_f(t) = 1-(1-t)^2
const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21))
#const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21))

P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f))

function compound_state(b)
L,B = size(b.aas)
cmask = b.aas .< 100
X1locs = MaskedState(ContinuousState(b.locs), cmask, b.padmask)
X1rots = MaskedState(ManifoldState(rotM,eachslice(b.rots, dims=(3,4))), cmask, b.padmask)
X1aas = MaskedState(DiscreteState(21, Flux.onehotbatch(b.aas, 1:21)), cmask, b.padmask)
return (X1locs, X1rots, X1aas)
#X1aas = MaskedState(DiscreteState(21, Flux.onehotbatch(b.aas, 1:21)), cmask, b.padmask)
#return (X1locs, X1rots, X1aas)
return (X1locs, X1rots)
end

#Added seq_to_masked_state
"""
    seq_to_masked_state(seq_int; K = 21)

Wrap an integer-encoded amino-acid sequence as a fully-observed
`Flowfusion.MaskedState` over a `K`-class discrete state with a trailing
batch dimension of 1.

The returned state's tensor is a `K×N×1` one-hot array (`N = length(seq_int)`);
both the conditioning mask and the padding mask are `true` everywhere, i.e.
every position is treated as observed and none as padding.
"""
function seq_to_masked_state(seq_int; K = 21)
    N = length(seq_int)
    B = 1  # single-sample batch

    # onehotbatch on an N×B index matrix yields the K×N×B OneHotArray directly,
    # avoiding the previous reach into the internal `.indices` field of a
    # 2-D OneHotMatrix followed by a manual OneHotArray reconstruction.
    oh_3d = Flux.onehotbatch(reshape(seq_int, N, B), 1:K)

    ds = ForwardBackward.DiscreteState(K, oh_3d)

    # Fully observed: condition on every position, no padding.
    cmask = trues(N, B)  # conditioning mask
    lmask = trues(N, B)  # length/padding mask

    return Flowfusion.MaskedState(ds, cmask, lmask)
end


function zero_state(b)
L,B = size(b.aas)
cmask = b.aas .< 100
X0locs = MaskedState(ContinuousState(randn(Float32, size(b.locs))), cmask, b.padmask)
X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(Flowfusion.rand(rotM, L*B)), L, B)), cmask, b.padmask)
X0aas = MaskedState(DiscreteState(21, Flux.onehotbatch(similar(b.aas) .= 21, 1:21)), cmask, b.padmask)
return (X0locs, X0rots, X0aas)
#X0aas = MaskedState(DiscreteState(21, Flux.onehotbatch(similar(b.aas) .= 21, 1:21)), cmask, b.padmask)
#return (X0locs, X0rots, X0aas)
return (X0locs, X0rots)
end

function training_sample(b)
Expand All @@ -27,29 +57,36 @@ function training_sample(b)
t = rand(Float32, 1, size(b.aas,2))
Xt = bridge(P, X0, X1, t)
rotξ = Guide(Xt[2], X1[2])
return (; t, Xt, X1, rotξ, chainids = b.chainids, resinds = b.resinds)
#return (; t, Xt, X1, rotξ, chainids = b.chainids, resinds = b.resinds)
return (; t, Xt, X1, rotξ, aas = b.aas, chainids = b.chainids, resinds = b.resinds)
end

function losses(hatframes, aalogits, ts)
#function losses(hatframes, aalogits, ts)
function losses(hatframes, ts)
rotangent = Flowfusion.so3_tangent_coordinates_stack(values(linear(hatframes)), tensor(ts.Xt[2]))
hatloc, hatrot, hataas = (values(translation(hatframes)), rotangent, aalogits)
#hatloc, hatrot, hataas = (values(translation(hatframes)), rotangent, aalogits)
hatloc, hatrot = (values(translation(hatframes)), rotangent)
l_loc = floss(P[1], hatloc, ts.X1[1], scalefloss(P[1], ts.t, 2, 0.2f0)) / 2
l_rot = floss(P[2], hatrot, ts.rotξ, scalefloss(P[2], ts.t, 2, 0.2f0)) / 10
l_aas = floss(P[3], hataas, ts.X1[3], scalefloss(P[3], ts.t, 1, 0.2f0)) / 100
return l_loc, l_rot, l_aas
#l_aas = floss(P[3], hataas, ts.X1[3], scalefloss(P[3], ts.t, 1, 0.2f0)) / 100
#return l_loc, l_rot, l_aas
return l_loc, l_rot
end

function flowX1predictor(X0, b, model; d = identity, smooth = 0)
batch_dim = size(tensor(X0[1]), 4)
f, _ = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds))
#f, aalogtis = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds))
f = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.aas), d(b.chainids), d(b.resinds))
prev_trans = values(translation(f))
T = eltype(prev_trans)
function m(t, Xt)
print(".")
f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f)
#f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f)
f = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.aas), d(b.chainids), d(b.resinds), sc_frames = f)
values(translation(f)) .= prev_trans .* T(smooth) .+ values(translation(f)) .* T(1-smooth)
prev_trans = values(translation(f))
return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits))
#return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits))
return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4)))
end
return m
end
Expand Down
13 changes: 9 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@ function ChainStormV1(dim::Int = 384, depth::Int = 6, f_depth::Int = 6)
)
return ChainStormV1(layers)
end
function (fc::ChainStormV1)(t, Xt, chainids, resinds; sc_frames = nothing)

#function (fc::ChainStormV1)(t, Xt, chainids, resinds; sc_frames = nothing)
function (fc::ChainStormV1)(t, Xt, aas, chainids, resinds; sc_frames = nothing)
l = fc.layers
pmask = Flux.Zygote.@ignore self_att_padding_mask(Xt[1].lmask)
pre_z = Flux.Zygote.@ignore l.pair_rff(pair_encode(resinds, chainids))
pair_feats = l.pair_project(pre_z)
t_rff = Flux.Zygote.@ignore l.t_rff(t)
cond = reshape(l.cond_t_encoding(t_rff), :, 1, size(t,2))
frames = Translation(tensor(Xt[1])) ∘ Rotation(tensor(Xt[2]))
AA_one_hots = tensor(Xt[3])
AA_one_hots = tensor(Flux.onehotbatch(aas, 1:21))
#AA_one_hots = tensor(Xt[3])

x = l.AAencoder(AA_one_hots .+ 0)
for i in 1:l.depth
if sc_frames !== nothing
Expand All @@ -44,6 +48,7 @@ function (fc::ChainStormV1)(t, Xt, chainids, resinds; sc_frames = nothing)
frames = l.framemovers[i - l.depth + l.f_depth](frames, x, t = t)
end
end
aa_logits = l.AAdecoder(x .+ reshape(l.AApre_t_encoding(t_rff), :, 1, size(t,2)))
return frames, aa_logits
#aa_logits = l.AAdecoder(x .+ reshape(l.AApre_t_encoding(t_rff), :, 1, size(t,2)))
#return frames, aa_logits
return frames
end
36 changes: 36 additions & 0 deletions training_code/training.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#In addition to ChainStorm, also install these:
using Pkg
Pkg.add(["JLD2", "Flux", "CannotWaitForTheseOptimisers", "LearningSchedules", "DLProteinFormats"])
Pkg.add(["CUDA", "cuDNN"])

using ChainStorm, DLProteinFormats, Flux, CannotWaitForTheseOptimisers, LearningSchedules, JLD2
using DLProteinFormats: load, PDBSimpleFlat, batch_flatrecs, sample_batched_inds, length2batch
using CUDA
device = gpu

dat = load(PDBSimpleFlat);

model = ChainStormV1(384, 3, 3) |> device
sched = burnin_learning_schedule(0.000005f0, 0.001f0, 1.05f0, 0.99995f0)
opt_state = Flux.setup(Muon(eta = sched.lr), model)

for epoch in 1:100
    batchinds = sample_batched_inds(dat, l2b = length2batch(1500, 1.9))
    for (i, b) in enumerate(batchinds)
        bat = batch_flatrecs(dat[b])
        ts = training_sample(bat) |> device
        # Self-conditioning: after the first epoch, feed the model its own
        # frame prediction on half of the steps.
        sc_frames = nothing
        if epoch > 1 && rand() < 0.5
            # Current API: the model takes the (fixed) sequence `aas` as an
            # input and returns frames only (no amino-acid logits).
            sc_frames = model(ts.t, ts.Xt, ts.aas, ts.chainids, ts.resinds)
        end
        l, grad = Flux.withgradient(model) do m
            fr = m(ts.t, ts.Xt, ts.aas, ts.chainids, ts.resinds, sc_frames = sc_frames)
            l_loc, l_rot = losses(fr, ts)
            l_loc + l_rot
        end
        Flux.update!(opt_state, model, grad[1])
        # Step the learning-rate schedule every 10 batches.
        (mod(i, 10) == 0) && Flux.adjust!(opt_state, next_rate(sched))
        println(l)
    end
    jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state = cpu(opt_state))
end
58 changes: 58 additions & 0 deletions training_code/training_finetune.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using Dates
using ChainStorm, DLProteinFormats, Flux, CannotWaitForTheseOptimisers, LearningSchedules, JLD2
using DLProteinFormats: load, PDBSimpleFlat, batch_flatrecs, sample_batched_inds, length2batch
using CUDA, cuDNN

# Muon fallback predicate: any parameter with a dimension of size 21
# (presumably the amino-acid/class axis — TODO confirm) uses the fallback optimiser.
fallback_mask(x) = any(size(x) .== 21)

device = gpu
dat = load(PDBSimpleFlat);

# Fixed record used for periodic generation snapshots during finetuning.
sample = dat[9]
print(sample.len)
L = length(sample.chainids)
c = (;chainids = reshape(sample.chainids, :, 1),
resinds = view(sample.resinds, :, 1),
padmask = trues(L, 1),
aas = reshape(sample.AAs, :, 1),
locs = reshape(sample.locs, 3, 1, L, 1))

model = ChainStorm.load_model() |> device
model_cpu = nothing

sched = burnin_learning_schedule(0.00001f0, 0.000250f0, 1.05f0, 0.999995f0);
opt_state = Flux.setup(Muon(eta = sched.lr, fallback = fallback_mask), model);

start_time = Dates.format(Dates.now(), "yyyy-mm-dd_HHMM")

gen_sample_counter = 0
sample_counter = 10000  # start at the threshold so the first batch triggers a generation
for epoch in 1:100
    batchinds = sample_batched_inds(dat, l2b = length2batch(1500, 1.9))
    @info "Epoch $epoch"
    for (i, b) in enumerate(batchinds)
        # Explicit global declarations are required here: when this script is
        # run non-interactively, soft-scope assignment inside a top-level loop
        # creates new locals instead, and `+=` then throws UndefVarError.
        global sample_counter, gen_sample_counter, model_cpu
        sample_counter += 1
        bat = batch_flatrecs(dat[b])
        ts = training_sample(bat) |> device
        # Self-conditioning: after the first epoch, feed the model its own
        # frame prediction on half of the steps.
        sc_frames = nothing
        if epoch > 1 && rand() < 0.5
            sc_frames = model(ts.t, ts.Xt, ts.aas, ts.chainids, ts.resinds)
        end
        l, grad = Flux.withgradient(model) do m
            fr = m(ts.t, ts.Xt, ts.aas, ts.chainids, ts.resinds, sc_frames = sc_frames)
            l_loc, l_rot = losses(fr, ts)
            l_loc + l_rot
        end
        Flux.update!(opt_state, model, grad[1])
        # Step the learning-rate schedule every 10 batches.
        (mod(i, 10) == 0) && Flux.adjust!(opt_state, next_rate(sched))
        println(l)
        println("Sample counter: ", sample_counter, " Gen sample counter: ", gen_sample_counter)
        # Every 10000 training samples, dump a generated structure for the fixed record.
        if sample_counter >= 10000
            sample_counter = 0
            gen_sample_counter += 1
            model_cpu = cpu(model)
            g = flow_quickgen(c, model_cpu)
            export_pdb("notrunk_toymodel/gens/gen_$(start_time)_foldfinetune_sample_$(gen_sample_counter).pdb", (g..., ChainStorm.seq_to_masked_state(sample.AAs)), c.chainids, c.resinds)
        end
    end
    jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state = cpu(opt_state))
end
Loading